Skip to main content

compio_rustls/
connector.rs

1use std::{
2    io,
3    sync::Arc,
4};
5
6use compio_io::{
7    AsyncRead,
8    AsyncWrite,
9};
10use rustls::{
11    ClientConfig,
12    ClientConnection,
13    pki_types::ServerName,
14};
15
16use crate::stream::TlsStream;
17
18/// A wrapper around a [`rustls::ClientConfig`].
19///
20/// **Note:** Clones are cheap.
21#[derive(Clone)]
22pub struct TlsConnector {
23    rustls_client_config: Arc<ClientConfig>,
24}
25
26impl TlsConnector {
27    pub fn new(rustls_client_config: Arc<ClientConfig>) -> Self {
28        Self {
29            rustls_client_config,
30        }
31    }
32
33    pub async fn connect<S>(&self, domain: ServerName<'static>, stream: S) -> io::Result<TlsStream<S, ClientConnection>>
34    where
35        S: AsyncRead + AsyncWrite,
36    {
37        self.connect_impl(domain, stream, None, |_| ()).await
38    }
39
40    pub async fn connect_with<S, F>(
41        &self,
42        domain: ServerName<'static>,
43        stream: S,
44        f: F,
45    ) -> io::Result<TlsStream<S, ClientConnection>>
46    where
47        S: AsyncRead + AsyncWrite,
48        F: FnOnce(&mut ClientConnection),
49    {
50        self.connect_impl(domain, stream, None, f).await
51    }
52
53    async fn connect_impl<S, F>(
54        &self,
55        domain: ServerName<'static>,
56        stream: S,
57        alpn_protocols: Option<Vec<Vec<u8>>>,
58        f: F,
59    ) -> io::Result<TlsStream<S, ClientConnection>>
60    where
61        S: AsyncRead + AsyncWrite,
62        F: FnOnce(&mut ClientConnection),
63    {
64        let alpn = alpn_protocols.unwrap_or_else(|| self.rustls_client_config.alpn_protocols.clone());
65        let mut session = ClientConnection::new_with_alpn(self.rustls_client_config.clone(), domain, alpn)
66            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
67
68        f(&mut session);
69
70        let mut tls_stream = TlsStream::new(stream, session);
71        tls_stream.handshake().await?;
72
73        Ok(tls_stream)
74    }
75
76    pub fn with_alpn(&self, alpn_protocols: Vec<Vec<u8>>) -> TlsConnectorWithAlpn<'_> {
77        TlsConnectorWithAlpn {
78            inner: self,
79            alpn_protocols,
80        }
81    }
82
83    /// Get a read-only reference to underlying config
84    pub fn config(&self) -> &Arc<ClientConfig> {
85        &self.rustls_client_config
86    }
87}
88
89pub struct TlsConnectorWithAlpn<'c> {
90    inner:          &'c TlsConnector,
91    alpn_protocols: Vec<Vec<u8>>,
92}
93
94impl<'c> TlsConnectorWithAlpn<'c> {
95    pub async fn connect<S>(self, domain: ServerName<'static>, stream: S) -> io::Result<TlsStream<S, ClientConnection>>
96    where
97        S: AsyncRead + AsyncWrite,
98    {
99        self.inner
100            .connect_impl(domain, stream, Some(self.alpn_protocols), |_| ())
101            .await
102    }
103
104    pub async fn connect_with<S, F>(
105        self,
106        domain: ServerName<'static>,
107        stream: S,
108        f: F,
109    ) -> io::Result<TlsStream<S, ClientConnection>>
110    where
111        S: AsyncRead + AsyncWrite,
112        F: FnOnce(&mut ClientConnection),
113    {
114        self.inner
115            .connect_impl(domain, stream, Some(self.alpn_protocols), f)
116            .await
117    }
118}