use crate::{Error, Request, ResponseLazy};
#[cfg(feature = "https")]
use native_tls::{TlsConnector, TlsStream};
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::net::TcpStream;
#[cfg(feature = "timeout")]
use std::net::ToSocketAddrs;
#[cfg(feature = "timeout")]
use std::time::Duration;
use std::time::Instant;
type UnsecuredStream = BufReader<TcpStream>;
#[cfg(feature = "https")]
type SecuredStream = TlsStream<TcpStream>;
pub(crate) enum HttpStream {
Unsecured(UnsecuredStream, Option<Instant>),
#[cfg(feature = "https")]
Secured(Box<SecuredStream>, Option<Instant>),
}
impl HttpStream {
fn create_unsecured(reader: UnsecuredStream, timeout_at: Option<Instant>) -> HttpStream {
HttpStream::Unsecured(reader, timeout_at)
}
#[cfg(feature = "https")]
fn create_secured(reader: SecuredStream, timeout_at: Option<Instant>) -> HttpStream {
HttpStream::Secured(Box::new(reader), timeout_at)
}
}
impl Read for HttpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let timeout = |tcp: &TcpStream, timeout_at: Option<Instant>| {
if let Some(timeout_at) = timeout_at {
let now = Instant::now();
if timeout_at <= now {
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"The request's timeout was reached.",
));
} else {
tcp.set_read_timeout(Some(timeout_at - now)).ok();
}
}
Ok(())
};
match self {
HttpStream::Unsecured(inner, timeout_at) => {
timeout(inner.get_ref(), *timeout_at)?;
inner.read(buf)
}
#[cfg(feature = "https")]
HttpStream::Secured(inner, timeout_at) => {
timeout(inner.get_ref(), *timeout_at)?;
inner.read(buf)
}
}
}
}
pub struct Connection {
request: Request,
#[cfg(feature = "timeout")]
timeout: Option<u64>,
}
impl Connection {
pub(crate) fn new(request: Request) -> Connection {
#[cfg(feature = "timeout")]
{
let timeout = request
.timeout
.or_else(|| match std::env::var("TINYGET_TIMEOUT") {
Ok(t) => t.parse::<u64>().ok(),
Err(_) => None,
});
Connection { request, timeout }
}
#[cfg(not(feature = "timeout"))]
{
Connection { request }
}
}
#[cfg(feature = "https")]
pub(crate) fn send_https(self) -> Result<ResponseLazy, Error> {
let bytes = self.request.as_bytes();
let dns_name = &self.request.host;
let dns_name = dns_name.split(':').next().unwrap();
let sess = match TlsConnector::new() {
Ok(sess) => sess,
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
};
let tcp = self.connect()?;
let mut tls = match sess.connect(dns_name, tcp) {
Ok(tls) => tls,
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
};
tls.write(&bytes)?;
let response = ResponseLazy::from_stream(HttpStream::create_secured(tls, None))?;
handle_redirects(self, response)
}
#[cfg(all(feature = "https", feature = "timeout"))]
pub(crate) fn send_https_timeout(self, timeout: Duration) -> Result<ResponseLazy, Error> {
let bytes = self.request.as_bytes();
let timeout_duration = self.timeout.map(Duration::from_secs);
let timeout_at = timeout_duration.map(|d| Instant::now() + d);
let dns_name = &self.request.host;
let dns_name = dns_name.split(':').next().unwrap();
let sess = match TlsConnector::new() {
Ok(sess) => sess,
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
};
let tcp = self.connect_timeout(timeout)?;
let mut tls = match sess.connect(dns_name, tcp) {
Ok(tls) => tls,
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
};
tls.write(&bytes)?;
let response = ResponseLazy::from_stream(HttpStream::create_secured(tls, timeout_at))?;
handle_redirects(self, response)
}
pub(crate) fn send(self) -> Result<ResponseLazy, Error> {
let bytes = self.request.as_bytes();
let tcp = self.connect()?;
let mut stream = BufWriter::new(tcp);
stream.write_all(&bytes)?;
let tcp = match stream.into_inner() {
Ok(tcp) => tcp,
Err(_) => {
return Err(Error::Other(
"IntoInnerError after writing the request into the TcpStream.",
));
}
};
let stream = HttpStream::create_unsecured(BufReader::new(tcp), None);
let response = ResponseLazy::from_stream(stream)?;
handle_redirects(self, response)
}
fn connect(&self) -> Result<TcpStream, Error> {
TcpStream::connect(&self.request.host).map_err(Error::from)
}
#[cfg(feature = "timeout")]
#[allow(dead_code)]
pub(crate) fn send_timeout(self, timeout: Duration) -> Result<ResponseLazy, Error> {
let bytes = self.request.as_bytes();
let timeout_duration = self.timeout.map(Duration::from_secs);
let timeout_at = timeout_duration.map(|d| Instant::now() + d);
let tcp = self.connect_timeout(timeout)?;
let mut stream = BufWriter::new(tcp);
stream.get_ref().set_write_timeout(timeout_duration).ok();
stream.write_all(&bytes)?;
let tcp = match stream.into_inner() {
Ok(tcp) => tcp,
Err(_) => {
return Err(Error::Other(
"IntoInnerError after writing the request into the TcpStream.",
));
}
};
let stream = HttpStream::create_unsecured(BufReader::new(tcp), timeout_at);
let response = ResponseLazy::from_stream(stream)?;
handle_redirects(self, response)
}
#[cfg(feature = "timeout")]
fn connect_timeout(&self, timeout: Duration) -> Result<TcpStream, Error> {
let addr = self
.request
.host
.to_socket_addrs()?
.next()
.ok_or(Error::Other("Failed to resolve host to SocketAddr"))?;
TcpStream::connect_timeout(&addr, timeout).map_err(Error::from)
}
}
fn handle_redirects(connection: Connection, response: ResponseLazy) -> Result<ResponseLazy, Error> {
let status_code = response.status_code;
let url = response.headers.get("location");
if let Some(request) = get_redirect(connection, status_code, url) {
request?.send_lazy()
} else {
Ok(response)
}
}
fn get_redirect(
connection: Connection,
status_code: i32,
url: Option<&String>,
) -> Option<Result<Request, Error>> {
match status_code {
301 | 302 | 303 | 307 => match url {
Some(url) => Some(connection.request.redirect_to(url.clone())),
None => Some(Err(Error::RedirectLocationMissing)),
},
_ => None,
}
}