use crate::{HyperHttpSnafu, HyperSnafu, IoSnafu, Result, RustlsSnafu};
use http_body_util::{Empty, Full};
use hyper::body::Bytes;
use hyper::body::Incoming;
use hyper::header::HeaderValue;
use hyper::http::uri::Scheme;
use hyper::{Request, Response, Uri};
use hyper_util::rt::TokioIo;
use snafu::ResultExt;
use std::io::Error as IoError;
use std::sync::Arc;
use std::sync::Once;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;
/// Combines [`AsyncRead`] and [`AsyncWrite`] into a single trait so that a
/// bidirectional stream can be named as one `dyn AsyncReadWrite` trait object.
trait AsyncReadWrite: AsyncRead + AsyncWrite {}

/// Every type that is both readable and writable is an [`AsyncReadWrite`].
impl<S: AsyncRead + AsyncWrite> AsyncReadWrite for S {}
/// Ensures the attempt to install the default rustls crypto provider in
/// `HttpClientBuilder::build` is made at most once.
static CRYPTO_PROVIDER_INIT: Once = Once::new();
/// HTTP/1 client whose connections are opened through a `crate::Tor` handle
/// supplied on every request.
pub struct HttpClient {
/// TLS connector used to wrap the stream when the request URI scheme is `https`.
pub tls_config: TlsConnector,
}
/// Builder for [`HttpClient`].
pub struct HttpClientBuilder {
// Caller-supplied connector; `None` makes `build` create one from the
// platform's native root certificates.
tls_config: Option<TlsConnector>,
}
impl HttpClientBuilder {
    /// Creates a builder with no TLS connector configured.
    pub fn new() -> Self {
        Self { tls_config: None }
    }

    /// Uses `tls_config` for HTTPS connections instead of the platform default.
    pub fn tls_config(self, tls_config: TlsConnector) -> Self {
        Self {
            tls_config: Some(tls_config),
        }
    }

    /// Finishes the builder and returns the configured [`HttpClient`].
    ///
    /// On the first call the aws-lc-rs provider is offered to rustls as the
    /// default crypto provider. When no connector was supplied, one is built
    /// from the platform's native root certificates.
    ///
    /// # Errors
    ///
    /// Returns an error if a platform certificate cannot be added to the root
    /// store.
    ///
    /// # Panics
    ///
    /// Panics if no connector was supplied and the platform certificates
    /// cannot be loaded.
    pub fn build(self) -> Result<HttpClient> {
        CRYPTO_PROVIDER_INIT.call_once(|| {
            let provider = rustls::crypto::aws_lc_rs::default_provider();
            // The outcome is deliberately discarded.
            let _ = rustls::crypto::CryptoProvider::install_default(provider);
        });

        let tls_config = match self.tls_config {
            Some(connector) => connector,
            None => Self::platform_tls_connector()?,
        };
        Ok(HttpClient { tls_config })
    }

    /// Builds a connector that trusts the platform's native root certificates
    /// and presents no client certificate.
    fn platform_tls_connector() -> Result<TlsConnector> {
        let mut roots = rustls::RootCertStore::empty();
        rustls_native_certs::load_native_certs()
            .expect("could not load platform certs")
            .into_iter()
            .try_for_each(|cert| roots.add(cert))
            .context(RustlsSnafu)?;

        let client_config = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();
        Ok(TlsConnector::from(Arc::new(client_config)))
    }
}
impl Default for HttpClientBuilder {
fn default() -> Self {
Self::new()
}
}
impl HttpClient {
pub async fn new() -> Result<Self> {
HttpClientBuilder::new().build()
}
pub fn builder() -> HttpClientBuilder {
HttpClientBuilder::new()
}
pub async fn head<T>(&self, tor: &crate::Tor, uri: T) -> Result<Response<Incoming>>
where
Uri: TryFrom<T>,
<Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
{
let req = Request::head(uri)
.body(Empty::<Bytes>::new())
.context(HyperHttpSnafu)?;
let resp = self.send_request(tor, req).await?;
Ok(resp)
}
pub async fn get<T>(&self, tor: &crate::Tor, uri: T) -> Result<Response<Incoming>>
where
Uri: TryFrom<T>,
<Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
{
let req = Request::get(uri)
.body(Empty::<Bytes>::new())
.context(HyperHttpSnafu)?;
let resp = self.send_request(tor, req).await?;
Ok(resp)
}
pub async fn post<T>(
&self,
tor: &crate::Tor,
uri: T,
content_type: &str,
body: Bytes,
) -> Result<Response<Incoming>>
where
Uri: TryFrom<T>,
<Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
{
let req = Request::post(uri)
.header(hyper::header::CONTENT_TYPE, content_type)
.body(Full::<Bytes>::from(body))
.context(HyperHttpSnafu)?;
let resp = self.send_request(tor, req).await?;
Ok(resp)
}
async fn send_request<B>(&self, tor: &crate::Tor, req: Request<B>) -> Result<Response<Incoming>>
where
B: hyper::body::Body + Send + 'static, B::Data: Send, B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, {
let stream = self.create_stream(tor, req.uri()).await.context(IoSnafu)?;
let (mut request_sender, connection) =
hyper::client::conn::http1::handshake(TokioIo::new(stream))
.await
.context(HyperSnafu)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("Error: {e:?}");
}
});
let mut final_req_builder = Request::builder().uri(req.uri()).method(req.method());
for (key, value) in req.headers() {
final_req_builder = final_req_builder.header(key, value);
}
if !req.headers().contains_key(hyper::header::HOST)
&& let Some(authority) = req.uri().authority()
&& let Ok(host_header_value) = HeaderValue::from_str(authority.as_str())
{
final_req_builder = final_req_builder.header(hyper::header::HOST, host_header_value);
}
let final_req = final_req_builder
.body(req.into_body())
.context(HyperHttpSnafu)?;
let resp = request_sender
.send_request(final_req)
.await
.context(HyperSnafu)?;
Ok(resp)
}
async fn create_stream(
&self,
tor: &crate::Tor,
url: &Uri,
) -> Result<Box<dyn AsyncReadWrite + Unpin + Send>, IoError> {
let host = url
.host()
.ok_or_else(|| IoError::new(std::io::ErrorKind::InvalidInput, "Missing host"))?;
let https = url.scheme() == Some(&Scheme::HTTPS);
let port = match url.port_u16() {
Some(port) => port,
None if https => 443,
None => 80,
};
let stream = tor.connect((host, port)).await.map_err(IoError::other)?;
if https {
let server_name = url
.host()
.unwrap_or_default()
.to_string()
.try_into()
.unwrap();
let wrapped_stream = self
.tls_config
.connect(server_name, stream)
.await
.map_err(IoError::other)?;
Ok(Box::new(wrapped_stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
} else {
Ok(Box::new(stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
}
}
}