cebolla 0.3.0

A convenience layer over Arti for building and connecting to Tor hidden services
Documentation
use crate::{HyperHttpSnafu, HyperSnafu, IoSnafu, Result, RustlsSnafu};

use http_body_util::{Empty, Full};
use hyper::body::Bytes;
use hyper::body::Incoming;
use hyper::header::HeaderValue;
use hyper::http::uri::Scheme;
use hyper::{Request, Response, Uri};
use hyper_util::rt::TokioIo;
use snafu::ResultExt;
use std::io::Error as IoError;
use std::sync::Arc;
use std::sync::Once;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsConnector;

/// Object-safe union of [`AsyncRead`] and [`AsyncWrite`].
///
/// Lets a plain Tor stream and a TLS-wrapped stream be returned behind the same
/// `Box<dyn AsyncReadWrite + Unpin + Send>`.
trait AsyncReadWrite: AsyncWrite + AsyncRead {}

// Every type that is both readable and writable gets the marker automatically.
impl<S: AsyncRead + AsyncWrite> AsyncReadWrite for S {}

/// Guards `HttpClientBuilder::build` so that it attempts to install the aws-lc-rs
/// rustls crypto provider as the default at most once per process.
static CRYPTO_PROVIDER_INIT: Once = Once::new();

/// HTTP client that sends its requests through Tor.
///
/// Every request opens a fresh stream through the supplied `crate::Tor` handle.
/// URIs with an `https` scheme are additionally wrapped in TLS using
/// [`HttpClient::tls_config`].
///
/// The client is cheap to clone because the TLS connector shares its
/// configuration.
#[derive(Clone)]
pub struct HttpClient {
    /// Connector used to wrap the Tor stream in TLS for `https` URIs.
    pub tls_config: TlsConnector,
}

/// Builder for [`HttpClient`], mainly for supplying a custom TLS connector.
#[derive(Clone)]
pub struct HttpClientBuilder {
    /// Caller-supplied TLS connector. `None` means `build` creates one that
    /// trusts the platform's native root certificates.
    tls_config: Option<TlsConnector>,
}

impl HttpClientBuilder {
    /// Returns a builder with no custom TLS connector.
    ///
    /// Unless [`HttpClientBuilder::tls_config`] is called,
    /// [`HttpClientBuilder::build`] trusts the platform's native root certificates.
    pub fn new() -> Self {
        Self { tls_config: None }
    }

    /// Makes the resulting [`HttpClient`] use `tls_config` for HTTPS
    /// connections instead of the default native-roots connector.
    pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
        self.tls_config = Some(tls_config);
        self
    }

    /// Consumes the builder and produces an [`HttpClient`].
    ///
    /// The first call in the process tries to install aws-lc-rs as the default
    /// rustls crypto provider. If no connector was supplied, one is built from
    /// the platform's native root certificates, with no client authentication.
    ///
    /// # Errors
    ///
    /// Returns an error if the rustls root store rejects a native certificate.
    ///
    /// # Panics
    ///
    /// Panics if loading the platform's native certificates reports an error,
    /// because the loader's result is `expect`ed.
    pub fn build(self) -> Result<HttpClient> {
        CRYPTO_PROVIDER_INIT.call_once(|| {
            // Installation fails when some provider is already the default.
            // That is acceptable, so the result is deliberately discarded.
            let provider = rustls::crypto::aws_lc_rs::default_provider();
            let _ = rustls::crypto::CryptoProvider::install_default(provider);
        });

        let tls_config = match self.tls_config {
            Some(connector) => connector,
            None => Self::native_roots_connector()?,
        };

        Ok(HttpClient { tls_config })
    }

    /// Builds a [`TlsConnector`] that trusts the platform's native root certificates.
    fn native_roots_connector() -> Result<TlsConnector> {
        let native_certs =
            rustls_native_certs::load_native_certs().expect("could not load platform certs");

        let mut roots = rustls::RootCertStore::empty();
        // Stops at the first certificate the store refuses, like a `for` + `?` loop.
        native_certs
            .into_iter()
            .try_for_each(|cert| roots.add(cert))
            .context(RustlsSnafu)?;

        let client_config = rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth();

        Ok(TlsConnector::from(Arc::new(client_config)))
    }
}

impl Default for HttpClientBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl HttpClient {
    /// Creates a new `Client` with default configuration.
    pub async fn new() -> Result<Self> {
        HttpClientBuilder::new().build()
    }

    pub fn builder() -> HttpClientBuilder {
        HttpClientBuilder::new()
    }

    /// Sends an HTTP HEAD request to the specified URI.
    pub async fn head<T>(&self, tor: &crate::Tor, uri: T) -> Result<Response<Incoming>>
    where
        Uri: TryFrom<T>,
        <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
    {
        let req = Request::head(uri)
            .body(Empty::<Bytes>::new())
            .context(HyperHttpSnafu)?;

        let resp = self.send_request(tor, req).await?;
        Ok(resp)
    }

    /// Sends an HTTP GET request to the specified URI.
    pub async fn get<T>(&self, tor: &crate::Tor, uri: T) -> Result<Response<Incoming>>
    where
        Uri: TryFrom<T>,
        <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
    {
        let req = Request::get(uri)
            .body(Empty::<Bytes>::new())
            .context(HyperHttpSnafu)?;

        let resp = self.send_request(tor, req).await?;
        Ok(resp)
    }

    /// Sends an HTTP POST request to the specified URI with the given content type and body.
    pub async fn post<T>(
        &self,
        tor: &crate::Tor,
        uri: T,
        content_type: &str,
        body: Bytes,
    ) -> Result<Response<Incoming>>
    where
        Uri: TryFrom<T>,
        <Uri as TryFrom<T>>::Error: Into<hyper::http::Error>,
    {
        let req = Request::post(uri)
            .header(hyper::header::CONTENT_TYPE, content_type)
            .body(Full::<Bytes>::from(body))
            .context(HyperHttpSnafu)?;

        let resp = self.send_request(tor, req).await?;
        Ok(resp)
    }

    /// Sends an HTTP request and returns the response.
    async fn send_request<B>(&self, tor: &crate::Tor, req: Request<B>) -> Result<Response<Incoming>>
    where
        B: hyper::body::Body + Send + 'static, // B must implement Body and be sendable
        B::Data: Send,                         // B::Data must be sendable
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, // B::Error must be convertible to a boxed error
    {
        let stream = self.create_stream(tor, req.uri()).await.context(IoSnafu)?;

        let (mut request_sender, connection) =
            hyper::client::conn::http1::handshake(TokioIo::new(stream))
                .await
                .context(HyperSnafu)?;

        // Spawn a task to poll the connection and drive the HTTP state
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("Error: {e:?}");
            }
        });

        let mut final_req_builder = Request::builder().uri(req.uri()).method(req.method());

        for (key, value) in req.headers() {
            final_req_builder = final_req_builder.header(key, value);
        }

        if !req.headers().contains_key(hyper::header::HOST)
            && let Some(authority) = req.uri().authority()
            && let Ok(host_header_value) = HeaderValue::from_str(authority.as_str())
        {
            final_req_builder = final_req_builder.header(hyper::header::HOST, host_header_value);
        }

        let final_req = final_req_builder
            .body(req.into_body())
            .context(HyperHttpSnafu)?;

        let resp = request_sender
            .send_request(final_req)
            .await
            .context(HyperSnafu)?;

        Ok(resp)
    }

    /// Creates a stream for the specified URI, optionally wrapping it with TLS.
    async fn create_stream(
        &self,
        tor: &crate::Tor,
        url: &Uri,
    ) -> Result<Box<dyn AsyncReadWrite + Unpin + Send>, IoError> {
        let host = url
            .host()
            .ok_or_else(|| IoError::new(std::io::ErrorKind::InvalidInput, "Missing host"))?;
        let https = url.scheme() == Some(&Scheme::HTTPS);

        let port = match url.port_u16() {
            Some(port) => port,
            None if https => 443,
            None => 80,
        };

        let stream = tor.connect((host, port)).await.map_err(IoError::other)?;

        if https {
            let server_name = url
                .host()
                .unwrap_or_default()
                .to_string()
                .try_into()
                .unwrap();
            let wrapped_stream = self
                .tls_config
                .connect(server_name, stream)
                .await
                .map_err(IoError::other)?;
            Ok(Box::new(wrapped_stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
        } else {
            Ok(Box::new(stream) as Box<dyn AsyncReadWrite + Unpin + Send>)
        }
    }
}