compio_tls/adapter/
mod.rs

1use std::io;
2
3use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
4
5use crate::TlsStream;
6
7#[cfg(feature = "rustls")]
8mod rtls;
9
/// Backend-specific connector stored inside [`TlsConnector`].
///
/// Each variant is compiled in only when its cargo feature is enabled.
#[derive(Clone, Debug)]
enum TlsConnectorInner {
    /// Connector supplied by the `native-tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsConnector),
    /// Connector from the crate-local `rtls` module (the `rustls` backend).
    #[cfg(feature = "rustls")]
    Rustls(rtls::TlsConnector),
}
17
18/// A wrapper around a [`native_tls::TlsConnector`] or [`rustls::ClientConfig`],
19/// providing an async `connect` method.
20#[derive(Debug, Clone)]
21pub struct TlsConnector(TlsConnectorInner);
22
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    /// Wraps a `native-tls` connector so it can be driven asynchronously.
    fn from(value: native_tls::TlsConnector) -> Self {
        let inner = TlsConnectorInner::NativeTls(value);
        Self(inner)
    }
}
29
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    /// Wraps a shared `rustls` client configuration in a [`TlsConnector`].
    fn from(value: std::sync::Arc<rustls::ClientConfig>) -> Self {
        let connector = rtls::TlsConnector(value);
        Self(TlsConnectorInner::Rustls(connector))
    }
}
36
37impl TlsConnector {
38    /// Connects the provided stream with this connector, assuming the provided
39    /// domain.
40    ///
41    /// This function will internally call `TlsConnector::connect` to connect
42    /// the stream and returns a future representing the resolution of the
43    /// connection operation. The returned future will resolve to either
44    /// `TlsStream<S>` or `Error` depending if it's successful or not.
45    ///
46    /// This is typically used for clients who have already established, for
47    /// example, a TCP connection to a remote server. That stream is then
48    /// provided here to perform the client half of a connection to a
49    /// TLS-powered server.
50    pub async fn connect<S: AsyncRead + AsyncWrite>(
51        &self,
52        domain: &str,
53        stream: S,
54    ) -> io::Result<TlsStream<S>> {
55        match &self.0 {
56            #[cfg(feature = "native-tls")]
57            TlsConnectorInner::NativeTls(c) => {
58                handshake_native_tls(c.connect(domain, SyncStream::new(stream))).await
59            }
60            #[cfg(feature = "rustls")]
61            TlsConnectorInner::Rustls(c) => handshake_rustls(c.connect(domain, stream)).await,
62        }
63    }
64}
65
/// Backend-specific acceptor stored inside [`TlsAcceptor`].
///
/// Each variant is compiled in only when its cargo feature is enabled.
#[derive(Clone)]
enum TlsAcceptorInner {
    /// Acceptor supplied by the `native-tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsAcceptor),
    /// Acceptor from the crate-local `rtls` module (the `rustls` backend).
    #[cfg(feature = "rustls")]
    Rustls(rtls::TlsAcceptor),
}
73
74/// A wrapper around a [`native_tls::TlsAcceptor`] or [`rustls::ServerConfig`],
75/// providing an async `accept` method.
76#[derive(Clone)]
77pub struct TlsAcceptor(TlsAcceptorInner);
78
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    /// Wraps a `native-tls` acceptor so it can be driven asynchronously.
    fn from(value: native_tls::TlsAcceptor) -> Self {
        let inner = TlsAcceptorInner::NativeTls(value);
        Self(inner)
    }
}
85
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    /// Wraps a shared `rustls` server configuration in a [`TlsAcceptor`].
    fn from(value: std::sync::Arc<rustls::ServerConfig>) -> Self {
        let acceptor = rtls::TlsAcceptor(value);
        Self(TlsAcceptorInner::Rustls(acceptor))
    }
}
92
93impl TlsAcceptor {
94    /// Accepts a new client connection with the provided stream.
95    ///
96    /// This function will internally call `TlsAcceptor::accept` to connect
97    /// the stream and returns a future representing the resolution of the
98    /// connection operation. The returned future will resolve to either
99    /// `TlsStream<S>` or `Error` depending if it's successful or not.
100    ///
101    /// This is typically used after a new socket has been accepted from a
102    /// `TcpListener`. That socket is then passed to this function to perform
103    /// the server half of accepting a client connection.
104    pub async fn accept<S: AsyncRead + AsyncWrite>(&self, stream: S) -> io::Result<TlsStream<S>> {
105        match &self.0 {
106            #[cfg(feature = "native-tls")]
107            TlsAcceptorInner::NativeTls(c) => {
108                handshake_native_tls(c.accept(SyncStream::new(stream))).await
109            }
110            #[cfg(feature = "rustls")]
111            TlsAcceptorInner::Rustls(c) => handshake_rustls(c.accept(stream)).await,
112        }
113    }
114}
115
/// Drives a `native-tls` handshake until it finishes or fails.
///
/// `res` is the outcome of the first handshake attempt. While the backend
/// reports `WouldBlock`, the buffered I/O of the underlying `SyncStream` is
/// performed asynchronously and the handshake is retried.
///
/// # Errors
///
/// * A `HandshakeError::Failure` is wrapped with [`io::Error::other`].
/// * Any error from flushing or filling the stream buffers is returned as is.
#[cfg(feature = "native-tls")]
async fn handshake_native_tls<S: AsyncRead + AsyncWrite>(
    mut res: Result<
        native_tls::TlsStream<SyncStream<S>>,
        native_tls::HandshakeError<SyncStream<S>>,
    >,
) -> io::Result<TlsStream<S>> {
    use native_tls::HandshakeError;

    loop {
        let mut pending = match res {
            Ok(mut established) => {
                // Push out whatever the final handshake step left buffered.
                established.get_mut().flush_write_buf().await?;
                return Ok(TlsStream::from(established));
            }
            Err(HandshakeError::Failure(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::WouldBlock(mid_stream)) => mid_stream,
        };

        // Read more input only when the flush reported 0 (presumably nothing
        // was waiting to be sent — confirm against `SyncStream`).
        let flushed = pending.get_mut().flush_write_buf().await?;
        if flushed == 0 {
            pending.get_mut().fill_read_buf().await?;
        }
        res = pending.handshake();
    }
}
143
/// Drives a `rustls` handshake until it finishes or fails.
///
/// `res` is the outcome of the first handshake attempt. While the backend
/// reports `WouldBlock`, the buffered I/O of the mid-handshake stream is
/// performed asynchronously and the handshake is retried.
///
/// # Errors
///
/// * A `HandshakeError::Rustls` is wrapped with [`io::Error::other`].
/// * A `HandshakeError::System` is returned unchanged.
/// * Any error from flushing or filling the stream buffers is returned as is.
#[cfg(feature = "rustls")]
async fn handshake_rustls<S, C, D>(
    mut res: Result<TlsStream<S>, rtls::HandshakeError<S, C>>,
) -> io::Result<TlsStream<S>>
where
    S: AsyncRead + AsyncWrite,
    C: std::ops::DerefMut<Target = rustls::ConnectionCommon<D>>,
{
    use rtls::HandshakeError;

    loop {
        let mut pending = match res {
            Ok(mut established) => {
                // Push out whatever the final handshake step left buffered.
                established.flush().await?;
                return Ok(established);
            }
            Err(HandshakeError::Rustls(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::System(err)) => return Err(err),
            Err(HandshakeError::WouldBlock(mid_stream)) => mid_stream,
        };

        // Read more input only when the flush reported 0 (presumably nothing
        // was waiting to be sent — confirm against the stream adapter).
        let flushed = pending.get_mut().flush_write_buf().await?;
        if flushed == 0 {
            pending.get_mut().fill_read_buf().await?;
        }
        res = pending.handshake::<D>();
    }
}