// async_tls_lite/connector.rs

// Reference: https://github.com/async-rs/async-tls/blob/v0.7.1/src/connector.rs

use std::io;
use std::sync::Arc;

use futures_util::io::{AsyncRead, AsyncWrite};
use rustls::{ClientConfig, ClientSession};
use webpki::DNSNameRef;

use crate::{handshake, TlsStream};

12#[derive(Clone)]
13pub struct TlsConnector {
14    inner: Arc<ClientConfig>,
15}
16
17impl From<Arc<ClientConfig>> for TlsConnector {
18    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
19        TlsConnector { inner }
20    }
21}
22
23impl Default for TlsConnector {
24    fn default() -> Self {
25        let mut config = ClientConfig::new();
26        config
27            .root_store
28            .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
29        Arc::new(config).into()
30    }
31}
32
33impl TlsConnector {
34    pub fn new() -> Self {
35        Default::default()
36    }
37
38    pub async fn connect<IO>(
39        &self,
40        domain: impl AsRef<str>,
41        stream: IO,
42    ) -> io::Result<TlsStream<ClientSession, IO>>
43    where
44        IO: AsyncRead + AsyncWrite + Unpin,
45    {
46        let domain = match DNSNameRef::try_from_ascii_str(domain.as_ref()) {
47            Ok(domain) => domain,
48            Err(_) => {
49                return Err(io::Error::new(
50                    io::ErrorKind::InvalidInput,
51                    "invalid domain",
52                ));
53            }
54        };
55
56        let session = ClientSession::new(&self.inner, domain);
57
58        handshake(session, stream).await
59    }
60}