use {
crate::proto::{Error, Handshake, Session},
futures_lite::prelude::*,
futures_rustls::{
TlsConnector,
client::TlsStream,
pki_types::ServerName,
rustls::{ClientConfig, RootCertStore},
},
std::{io, sync::Arc},
url::Host,
};
pub use crate::negotiate::{Negotiate, Select};
/// A [`Handshake`] that wraps an inner protocol negotiator with TLS.
///
/// The TLS session is established through `connector`. The ALPN protocol
/// selected during the TLS handshake is then handed to `inner`, which picks
/// the protocol handshake that runs over the encrypted stream.
///
/// `Tls` is `Clone` (when `N` is) because [`Handshake::handshake`] consumes
/// the value; cloning lets one configured instance serve many connections
/// without rebuilding the rustls configuration each time.
#[derive(Clone)]
pub struct Tls<N> {
    /// Negotiator mapping the ALPN protocol to a protocol handshake.
    inner: N,
    /// Connector holding the shared rustls client configuration.
    connector: TlsConnector,
}
impl<N> Tls<N> {
    /// Creates a TLS wrapper that trusts the root certificates bundled in
    /// the `webpki-roots` crate.
    ///
    /// The generated [`ClientConfig`] advertises the ALPN protocols reported
    /// by `inner.support()`. Configs are cached process-wide per concrete
    /// type `N`, so repeated calls for the same type share one
    /// `Arc<ClientConfig>`.
    ///
    /// NOTE(review): the cache is keyed by type only; if two values of the
    /// same `N` can report different `support()` lists, the first value seen
    /// determines the advertised ALPN list — confirm this cannot happen.
    #[cfg(feature = "webpki-roots")]
    #[cfg_attr(docsrs, doc(cfg(feature = "webpki-roots")))]
    pub fn with_webpki_roots(inner: N) -> Self
    where
        N: Negotiate + 'static,
    {
        use std::{
            any::TypeId,
            collections::BTreeMap,
            sync::{Mutex, PoisonError},
        };

        /// Process-wide cache of client configs keyed by negotiator type.
        ///
        /// `BTreeMap` is used because `BTreeMap::new` is `const`, which lets
        /// the cache live in a plain `static`.
        struct Configs(Mutex<BTreeMap<TypeId, Arc<ClientConfig>>>);

        impl Configs {
            const fn new() -> Self {
                Self(Mutex::new(BTreeMap::new()))
            }

            /// Returns the cached config for `N`, building it on first use.
            fn get<N>(&self, inner: &N) -> Arc<ClientConfig>
            where
                N: Negotiate + 'static,
            {
                let init = || {
                    let root = RootCertStore {
                        roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
                    };
                    let mut conf = ClientConfig::builder()
                        .with_root_certificates(root)
                        .with_no_client_auth();
                    conf.alpn_protocols.extend(inner.support().map(Vec::from));
                    Arc::new(conf)
                };

                // `or_insert_with` only inserts after `init` returns, so a
                // panic inside `init` (e.g. from a user `support()` impl)
                // never leaves the map half-updated. Recover from poisoning
                // rather than turning one panic into a permanent failure for
                // every later caller.
                let mut confs = self.0.lock().unwrap_or_else(PoisonError::into_inner);
                Arc::clone(confs.entry(TypeId::of::<N>()).or_insert_with(init))
            }
        }

        static CONFS: Configs = Configs::new();
        let connector = TlsConnector::from(CONFS.get(&inner));
        Self::with_connector(inner, connector)
    }

    /// Creates a TLS wrapper that trusts only the PEM-encoded certificates
    /// in `cert`.
    ///
    /// The generated [`ClientConfig`] advertises the ALPN protocols reported
    /// by `inner.support()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the PEM data cannot be read or a certificate is
    /// rejected by the root store.
    pub fn with_cert(inner: N, cert: &[u8]) -> Result<Self, Error>
    where
        N: Negotiate,
    {
        let conf = read_tls_config(cert, &inner)?;
        let connector = TlsConnector::from(Arc::new(conf));
        Ok(Self::with_connector(inner, connector))
    }

    /// Creates a TLS wrapper from an already-configured connector.
    ///
    /// The caller is responsible for configuring ALPN on the connector's
    /// [`ClientConfig`] so that it matches what `inner` can negotiate.
    pub fn with_connector(inner: N, connector: TlsConnector) -> Self {
        Self { inner, connector }
    }
}
impl<I, B, N> Handshake<I, B> for Tls<N>
where
I: AsyncRead + AsyncWrite + Unpin,
N: Negotiate<Handshake: Handshake<TlsStream<I>, B>>,
{
type Client = <N::Handshake as Handshake<TlsStream<I>, B>>::Client;
async fn handshake(
self,
se: Session<I>,
) -> Result<(Self::Client, impl Future<Output = ()>), Error> {
let Session { addr, io } = se;
let name = as_server_name(&addr.host)?.to_owned();
let tls = self.connector.connect(name, io).await?;
let (_, conn) = tls.get_ref();
let proto = conn
.alpn_protocol()
.unwrap_or_else(|| self.inner.support().next().unwrap_or_default());
let handshake = self
.inner
.negotiate(proto)
.ok_or_else(|| Error::UnsupportedProtocol(Box::from(proto)))?;
let se = Session { addr, io: tls };
let (client, conn) = handshake.handshake(se).await?;
Ok((client, conn))
}
}
fn as_server_name(host: &Host) -> Result<ServerName<'_>, Error> {
match host {
Host::Domain(domain) => {
ServerName::try_from(domain.as_str()).map_err(|_| Error::InvalidHost)
}
Host::Ipv4(ip) => Ok(ServerName::from(*ip)),
Host::Ipv6(ip) => Ok(ServerName::from(*ip)),
}
}
/// Builds a [`ClientConfig`] that trusts only the PEM-encoded certificates in
/// `cert` and advertises the ALPN protocols reported by `inner.support()`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the PEM data cannot be read;
/// - a certificate is rejected by the root store;
/// - `cert` contains no PEM certificates at all.
///
/// The last case matters because non-PEM input, e.g. raw DER, is skipped
/// silently by the PEM reader. Without this check, an empty root store would
/// make every later handshake fail with an unknown-issuer error rather than
/// failing here.
fn read_tls_config<N>(mut cert: &[u8], inner: &N) -> Result<ClientConfig, io::Error>
where
    N: Negotiate,
{
    let mut root = RootCertStore::empty();
    for der in rustls_pemfile::certs(&mut cert) {
        root.add(der?).map_err(io::Error::other)?;
    }
    if root.roots.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "no PEM certificates found",
        ));
    }
    let mut conf = ClientConfig::builder()
        .with_root_certificates(root)
        .with_no_client_auth();
    conf.alpn_protocols.extend(inner.support().map(Vec::from));
    Ok(conf)
}