use crate::http::address::Address;
use crate::http::cookie::Cookie;
use crate::http::headers::{Header, HeaderLike, HeaderType, Headers};
use crate::http::method::Method;
use crate::http::{Request, Response, StatusCode};
use std::error::Error;
use std::io::Write;
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
#[cfg(feature = "tls")]
use rustls::{Certificate, ClientConfig, ClientConnection, RootCertStore, StreamOwned};
#[cfg(feature = "tls")]
use rustls_native_certs::load_native_certs;
#[cfg(feature = "tls")]
use std::convert::TryInto;
#[cfg(feature = "tls")]
use std::sync::Arc;
/// A blocking HTTP/1.1 client.
///
/// With the `tls` feature enabled, the client builds a TLS configuration on its
/// first HTTPS request and reuses it for every later one.
#[derive(Default)]
pub struct Client {
    /// Cached TLS configuration; stays `None` until the first HTTPS request.
    #[cfg(feature = "tls")]
    tls_config: Option<Arc<ClientConfig>>,
}
impl Client {
pub fn new() -> Self {
Self::default()
}
pub fn get(&mut self, url: impl AsRef<str>) -> Result<ClientRequest, Box<dyn Error>> {
let url = Self::parse_url(url).ok_or("Invalid URL")?;
let request = Request {
method: Method::Get,
uri: url.path,
headers: url.host_headers,
query: url.query,
version: "HTTP/1.1".to_string(),
content: None,
address: Address::new(url.host).unwrap(),
};
Ok(ClientRequest {
address: url.host,
client: self,
protocol: url.protocol,
request,
follow_redirects: false,
cookies: Vec::new(),
})
}
pub fn post(
&mut self,
url: impl AsRef<str>,
data: Vec<u8>,
) -> Result<ClientRequest, Box<dyn Error>> {
let url = Self::parse_url(url).ok_or("Invalid URL")?;
let content_length = Header::new("Content-Length", data.len().to_string());
let mut request = Request {
method: Method::Post,
uri: url.path,
headers: url.host_headers,
query: url.query,
version: "HTTP/1.1".to_string(),
content: Some(data),
address: Address::new(url.host).unwrap(),
};
request.headers.push(content_length);
Ok(ClientRequest {
address: url.host,
client: self,
protocol: url.protocol,
request,
follow_redirects: false,
cookies: Vec::new(),
})
}
pub fn put(
&mut self,
url: impl AsRef<str>,
data: Vec<u8>,
) -> Result<ClientRequest, Box<dyn Error>> {
let url = Self::parse_url(url).ok_or("Invalid URL")?;
let content_length = Header::new("Content-Length", data.len().to_string());
let mut request = Request {
method: Method::Put,
uri: url.path,
headers: url.host_headers,
query: url.query,
version: "HTTP/1.1".to_string(),
content: Some(data),
address: Address::new(url.host).unwrap(),
};
request.headers.push(content_length);
Ok(ClientRequest {
address: url.host,
client: self,
protocol: url.protocol,
request,
follow_redirects: false,
cookies: Vec::new(),
})
}
pub fn delete(&mut self, url: impl AsRef<str>) -> Result<ClientRequest, Box<dyn Error>> {
let url = Self::parse_url(url).ok_or("Invalid URL")?;
let request = Request {
method: Method::Delete,
uri: url.path,
headers: url.host_headers,
query: url.query,
version: "HTTP/1.1".to_string(),
content: None,
address: Address::new(url.host).unwrap(),
};
Ok(ClientRequest {
address: url.host,
client: self,
protocol: url.protocol,
request,
follow_redirects: false,
cookies: Vec::new(),
})
}
pub fn request(
&self,
address: impl ToSocketAddrs,
request: Request,
) -> Result<Response, Box<dyn Error>> {
let mut stream = TcpStream::connect(address)?;
let request_bytes: Vec<u8> = request.into();
stream.write_all(&request_bytes)?;
let response = Response::from_stream(&mut stream)?;
Ok(response)
}
#[cfg(not(feature = "tls"))]
pub fn request_tls(
&mut self,
_: impl ToSocketAddrs,
_: Request,
) -> Result<Response, Box<dyn Error>> {
Err("TLS feature is not enabled".into())
}
#[cfg(feature = "tls")]
pub fn request_tls(
&mut self,
address: impl ToSocketAddrs,
request: Request,
) -> Result<Response, Box<dyn Error>> {
if self.tls_config.is_none() {
let mut roots = RootCertStore::empty();
for cert in load_native_certs()? {
roots.add(&Certificate(cert.0))?;
}
let conf = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
self.tls_config = Some(Arc::new(conf))
}
let conn = ClientConnection::new(
self.tls_config.as_ref().unwrap().clone(),
request
.headers
.get(&HeaderType::Host)
.unwrap()
.try_into()
.unwrap(),
)?;
let sock = TcpStream::connect(address)?;
let mut tls = StreamOwned::new(conn, sock);
let request_bytes: Vec<u8> = request.into();
tls.write_all(&request_bytes)?;
let response = Response::from_stream(&mut tls)?;
Ok(response)
}
pub(crate) fn parse_url(url: impl AsRef<str>) -> Option<ParsedUrl> {
let url = url.as_ref();
if let Some(stripped) = url.strip_prefix("http://") {
let protocol = Protocol::Http;
let (host, path) = stripped.split_once('/').unwrap_or((stripped, ""));
let mut headers = Headers::new();
headers.add(HeaderType::Host, host);
let host = format!("{}:80", host);
let host = host.to_socket_addrs().ok()?.next()?;
let (path, query) = path.split_once('?').unwrap_or((path, ""));
Some(ParsedUrl {
protocol,
host,
host_headers: headers,
path: format!("/{}", path),
query: query.to_string(),
})
} else if let Some(stripped) = url.strip_prefix("https://") {
let protocol = Protocol::Https;
let (host, path) = stripped.split_once('/').unwrap_or((stripped, ""));
let mut headers = Headers::new();
headers.add(HeaderType::Host, host);
let host = format!("{}:443", host);
let host = host.to_socket_addrs().ok()?.next()?;
let (path, query) = path.split_once('?').unwrap_or((path, ""));
Some(ParsedUrl {
protocol,
host,
host_headers: headers,
path: format!("/{}", path),
query: query.to_string(),
})
} else {
None
}
}
}
/// A request built by a [`Client`], configured through builder methods and
/// dispatched with `send`.
pub struct ClientRequest<'a> {
    /// Client that performs the request; it also holds the cached TLS configuration.
    client: &'a mut Client,
    /// The HTTP request that will be written to the connection.
    request: Request,
    /// Selects plain TCP (`Http`) or TLS (`Https`).
    protocol: Protocol,
    /// Resolved socket address to connect to.
    address: SocketAddr,
    /// Cookies sent as a `Cookie` header when the request is sent.
    cookies: Vec<Cookie>,
    /// Whether redirect responses are followed automatically.
    follow_redirects: bool,
}
impl<'a> ClientRequest<'a> {
pub fn with_header(mut self, header: impl HeaderLike, value: impl AsRef<str>) -> Self {
self.request.headers.add(header, value);
self
}
pub fn with_cookie(mut self, cookie: Cookie) -> Self {
self.cookies.push(cookie);
self
}
pub fn with_redirects(mut self, follow_redirects: bool) -> Self {
self.follow_redirects = follow_redirects;
self
}
pub fn send(mut self) -> Result<Response, Box<dyn Error>> {
if let Some(header) = Cookie::to_header(&self.cookies) {
self.request.headers.push(header);
}
let response = match self.protocol {
Protocol::Http => self.client.request(self.address, self.request.clone()),
Protocol::Https => self.client.request_tls(self.address, self.request.clone()),
};
if self.follow_redirects
&& response
.as_ref()
.map(|r| {
r.status_code == StatusCode::MovedPermanently
|| r.status_code == StatusCode::TemporaryRedirect
|| r.status_code == StatusCode::Found
})
.unwrap_or(false)
{
response
.and_then(|r| {
r.headers
.get(&HeaderType::Location)
.map_or(Err("No location header".into()), |s| Ok(s.to_string()))
})
.and_then(|l| {
if l.starts_with('/') {
self.request.uri = l;
} else {
let new_url = Client::parse_url(l).ok_or("Invalid URL")?;
let request = Request {
method: self.request.method,
uri: new_url.path,
headers: new_url.host_headers,
query: new_url.query,
version: "HTTP/1.1".to_string(),
content: self.request.content,
address: Address::new(new_url.host).unwrap(),
};
self.protocol = new_url.protocol;
self.address = new_url.host;
self.request = request;
}
self.send()
})
} else {
response
}
}
pub fn into_inner(self) -> Request {
self.request
}
}
/// The pieces of an `http://` or `https://` URL, as produced by `Client::parse_url`.
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct ParsedUrl {
    /// URL scheme.
    pub(crate) protocol: Protocol,
    /// First socket address the URL's host resolved to.
    pub(crate) host: SocketAddr,
    /// Headers derived from the URL (currently only `Host`).
    pub(crate) host_headers: Headers,
    /// Request path; always begins with `/`.
    pub(crate) path: String,
    /// Query string without the leading `?`; empty when absent.
    pub(crate) query: String,
}
/// URL scheme, which decides whether a request is sent over plain TCP or TLS.
///
/// A fieldless enum, so it is cheaply `Copy` and can be hashed or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum Protocol {
    /// `http://`: plain TCP.
    Http,
    /// `https://`: TLS.
    Https,
}