areq 0.1.0-alpha5

Async runtime-agnostic HTTP requests
Documentation
//! The HTTP client over TLS.

use {
    crate::proto::{Error, Handshake, Session},
    futures_lite::prelude::*,
    futures_rustls::{
        TlsConnector,
        client::TlsStream,
        pki_types::ServerName,
        rustls::{ClientConfig, RootCertStore},
    },
    std::{io, sync::Arc},
    url::Host,
};

pub use crate::negotiate::{Negotiate, Select};

/// A handshake layer that establishes a TLS session on the underlying I/O
/// and then hands the encrypted stream to the handshake chosen by the
/// negotiator `N`.
pub struct Tls<N> {
    /// Negotiator that lists the supported protocol identifiers and selects
    /// the inner handshake once the TLS session is established.
    inner: N,
    /// Connector used to perform the TLS handshake.
    connector: TlsConnector,
}

impl<N> Tls<N> {
    /// Creates a TLS layer that trusts the root certificates bundled in the
    /// `webpki-roots` crate.
    ///
    /// The client configuration is built once per concrete negotiator type
    /// and shared by every later call made with that same type. The ALPN
    /// protocol list stored in the shared configuration is the one reported
    /// by the `inner` value of the first call for that type.
    ///
    /// # Panics
    ///
    /// Panics if the lock guarding the shared configurations is poisoned.
    #[cfg(feature = "webpki-roots")]
    #[cfg_attr(docsrs, doc(cfg(feature = "webpki-roots")))]
    pub fn with_webpki_roots(inner: N) -> Self
    where
        N: Negotiate + 'static,
    {
        use std::{any::TypeId, collections::BTreeMap, sync::Mutex};

        // One lazily built configuration per negotiator type.
        static CACHE: Mutex<BTreeMap<TypeId, Arc<ClientConfig>>> =
            Mutex::new(BTreeMap::new());

        let config = {
            let mut cache = CACHE.lock().expect("lock configs");
            let shared = cache.entry(TypeId::of::<N>()).or_insert_with(|| {
                let roots = RootCertStore {
                    roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
                };

                let mut built = ClientConfig::builder()
                    .with_root_certificates(roots)
                    .with_no_client_auth();

                // Advertise every protocol the negotiator supports via ALPN.
                built.alpn_protocols.extend(inner.support().map(Vec::from));
                Arc::new(built)
            });

            Arc::clone(shared)
        };

        Self::with_connector(inner, TlsConnector::from(config))
    }

    /// Creates a TLS layer that trusts only the certificates read from the
    /// PEM-encoded `cert` data.
    ///
    /// # Errors
    ///
    /// Returns an error if the certificates cannot be read from `cert` or
    /// cannot be added to the root store.
    pub fn with_cert(inner: N, cert: &[u8]) -> Result<Self, Error>
    where
        N: Negotiate,
    {
        let config = Arc::new(read_tls_config(cert, &inner)?);

        Ok(Self {
            inner,
            connector: TlsConnector::from(config),
        })
    }

    /// Creates a TLS layer from an already configured connector.
    ///
    /// The connector is used as given; no ALPN protocols are added to it
    /// here.
    pub fn with_connector(inner: N, connector: TlsConnector) -> Self {
        Self { inner, connector }
    }
}

impl<I, B, N> Handshake<I, B> for Tls<N>
where
    I: AsyncRead + AsyncWrite + Unpin,
    N: Negotiate<Handshake: Handshake<TlsStream<I>, B>>,
{
    type Client = <N::Handshake as Handshake<TlsStream<I>, B>>::Client;

    async fn handshake(
        self,
        se: Session<I>,
    ) -> Result<(Self::Client, impl Future<Output = ()>), Error> {
        let Session { addr, io } = se;

        let name = as_server_name(&addr.host)?.to_owned();
        let tls = self.connector.connect(name, io).await?;

        let (_, conn) = tls.get_ref();
        let proto = conn
            .alpn_protocol()
            // if the remote server doesn't specify a protocol,
            // fall back to the first supported one by default
            .unwrap_or_else(|| self.inner.support().next().unwrap_or_default());

        let handshake = self
            .inner
            .negotiate(proto)
            .ok_or_else(|| Error::UnsupportedProtocol(Box::from(proto)))?;

        let se = Session { addr, io: tls };
        let (client, conn) = handshake.handshake(se).await?;
        Ok((client, conn))
    }
}

fn as_server_name(host: &Host) -> Result<ServerName<'_>, Error> {
    match host {
        Host::Domain(domain) => {
            ServerName::try_from(domain.as_str()).map_err(|_| Error::InvalidHost)
        }
        Host::Ipv4(ip) => Ok(ServerName::from(*ip)),
        Host::Ipv6(ip) => Ok(ServerName::from(*ip)),
    }
}

/// Builds a TLS client configuration that trusts the certificates found in
/// the PEM-encoded `cert` data and advertises the protocols supported by
/// `inner` through ALPN.
///
/// # Errors
///
/// Returns an error if
/// - a PEM section cannot be read,
/// - a certificate is rejected by the root store, or
/// - `cert` contains no certificates at all. Such a configuration could
///   never verify a server, so it is rejected here instead of failing later
///   during the TLS handshake.
fn read_tls_config<N>(mut cert: &[u8], inner: &N) -> Result<ClientConfig, io::Error>
where
    N: Negotiate,
{
    let mut root = RootCertStore::empty();
    for der in rustls_pemfile::certs(&mut cert) {
        root.add(der?).map_err(io::Error::other)?;
    }

    // Sections that are not certificates are skipped by the PEM reader, so
    // empty input, DER input or a wrong PEM type all end up here.
    if root.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "no certificates found in the provided PEM data",
        ));
    }

    let mut conf = ClientConfig::builder()
        .with_root_certificates(root)
        .with_no_client_auth();

    conf.alpn_protocols.extend(inner.support().map(Vec::from));
    Ok(conf)
}