compio_tls/
adapter.rs

1use std::{fmt::Debug, io};
2
3use compio_io::{
4    AsyncRead, AsyncWrite,
5    compat::{AsyncStream, SyncStream},
6};
7
8use crate::TlsStream;
9
/// Backend-specific connector stored inside [`TlsConnector`].
///
/// The set of variants is decided at compile time by the `native-tls` and
/// `rustls` cargo features.
#[derive(Clone)]
enum TlsConnectorInner {
    /// Connector provided by the `native_tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsConnector),
    /// Connector provided by the `futures_rustls` crate.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsConnector),
    /// Placeholder used when neither backend feature is enabled. Its payload
    /// is uninhabited, so a value of this variant can never be constructed.
    #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
    None(std::convert::Infallible),
}
19
20impl Debug for TlsConnectorInner {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            #[cfg(feature = "native-tls")]
24            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
25            #[cfg(feature = "rustls")]
26            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
27            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
28            Self::None(f) => match *f {},
29        }
30    }
31}
32
33/// A wrapper around a [`native_tls::TlsConnector`] or [`rustls::ClientConfig`],
34/// providing an async `connect` method.
35#[derive(Debug, Clone)]
36pub struct TlsConnector(TlsConnectorInner);
37
/// Wraps an already configured `native_tls` connector.
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    fn from(connector: native_tls::TlsConnector) -> Self {
        let inner = TlsConnectorInner::NativeTls(connector);
        TlsConnector(inner)
    }
}
44
/// Builds a connector from a shared `rustls` client configuration.
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    fn from(config: std::sync::Arc<rustls::ClientConfig>) -> Self {
        // Convert the shared config into the `futures_rustls` connector first.
        let connector = futures_rustls::TlsConnector::from(config);
        TlsConnector(TlsConnectorInner::Rustls(connector))
    }
}
51
52impl TlsConnector {
53    /// Connects the provided stream with this connector, assuming the provided
54    /// domain.
55    ///
56    /// This function will internally call `TlsConnector::connect` to connect
57    /// the stream and returns a future representing the resolution of the
58    /// connection operation. The returned future will resolve to either
59    /// `TlsStream<S>` or `Error` depending if it's successful or not.
60    ///
61    /// This is typically used for clients who have already established, for
62    /// example, a TCP connection to a remote server. That stream is then
63    /// provided here to perform the client half of a connection to a
64    /// TLS-powered server.
65    pub async fn connect<S: AsyncRead + AsyncWrite + 'static>(
66        &self,
67        domain: &str,
68        stream: S,
69    ) -> io::Result<TlsStream<S>> {
70        match &self.0 {
71            #[cfg(feature = "native-tls")]
72            TlsConnectorInner::NativeTls(c) => {
73                handshake_native_tls(c.connect(domain, SyncStream::new(stream))).await
74            }
75            #[cfg(feature = "rustls")]
76            TlsConnectorInner::Rustls(c) => {
77                let client = c
78                    .connect(
79                        domain.to_string().try_into().map_err(io::Error::other)?,
80                        AsyncStream::new(stream),
81                    )
82                    .await?;
83                Ok(TlsStream::from(client))
84            }
85            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
86            TlsConnectorInner::None(f) => match *f {},
87        }
88    }
89}
90
/// Backend-specific acceptor stored inside [`TlsAcceptor`].
///
/// The set of variants is decided at compile time by the `native-tls` and
/// `rustls` cargo features.
#[derive(Clone)]
enum TlsAcceptorInner {
    /// Acceptor provided by the `native_tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsAcceptor),
    /// Acceptor provided by the `futures_rustls` crate.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsAcceptor),
    /// Placeholder used when neither backend feature is enabled. Its payload
    /// is uninhabited, so a value of this variant can never be constructed.
    #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
    None(std::convert::Infallible),
}

/// Prints only the name of the active backend, without formatting the
/// wrapped acceptor, so no `Debug` bound is placed on the backend types.
impl Debug for TlsAcceptorInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
            // The payload is uninhabited, so this arm can never actually run.
            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
            Self::None(never) => match *never {},
        }
    }
}

/// A wrapper around a [`native_tls::TlsAcceptor`] or [`rustls::ServerConfig`],
/// providing an async `accept` method.
///
/// Implements [`Debug`] and [`Clone`], matching the client-side connector
/// wrapper in this module.
#[derive(Debug, Clone)]
pub struct TlsAcceptor(TlsAcceptorInner);
105
/// Wraps an already configured `native_tls` acceptor.
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    fn from(acceptor: native_tls::TlsAcceptor) -> Self {
        let inner = TlsAcceptorInner::NativeTls(acceptor);
        TlsAcceptor(inner)
    }
}
112
/// Builds an acceptor from a shared `rustls` server configuration.
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    fn from(config: std::sync::Arc<rustls::ServerConfig>) -> Self {
        // Convert the shared config into the `futures_rustls` acceptor first.
        let acceptor = futures_rustls::TlsAcceptor::from(config);
        TlsAcceptor(TlsAcceptorInner::Rustls(acceptor))
    }
}
119
120impl TlsAcceptor {
121    /// Accepts a new client connection with the provided stream.
122    ///
123    /// This function will internally call `TlsAcceptor::accept` to connect
124    /// the stream and returns a future representing the resolution of the
125    /// connection operation. The returned future will resolve to either
126    /// `TlsStream<S>` or `Error` depending if it's successful or not.
127    ///
128    /// This is typically used after a new socket has been accepted from a
129    /// `TcpListener`. That socket is then passed to this function to perform
130    /// the server half of accepting a client connection.
131    pub async fn accept<S: AsyncRead + AsyncWrite + 'static>(
132        &self,
133        stream: S,
134    ) -> io::Result<TlsStream<S>> {
135        match &self.0 {
136            #[cfg(feature = "native-tls")]
137            TlsAcceptorInner::NativeTls(c) => {
138                handshake_native_tls(c.accept(SyncStream::new(stream))).await
139            }
140            #[cfg(feature = "rustls")]
141            TlsAcceptorInner::Rustls(c) => {
142                let server = c.accept(AsyncStream::new(stream)).await?;
143                Ok(TlsStream::from(server))
144            }
145            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
146            TlsAcceptorInner::None(f) => match *f {},
147        }
148    }
149}
150
/// Drives a `native_tls` handshake to completion on top of a [`SyncStream`].
///
/// `res` is the outcome of the first `connect`/`accept` attempt. While the
/// backend reports `WouldBlock`, the buffered stream is flushed; when that
/// flush returns `0`, the read buffer is refilled instead. The handshake is
/// then retried. On success the write buffer is flushed once more before the
/// stream is wrapped in a [`TlsStream`].
///
/// # Errors
///
/// Returns an [`io::Error`] when flushing or filling the stream buffers fails,
/// or when the backend reports a handshake failure (wrapped with
/// [`io::Error::other`]).
#[cfg(feature = "native-tls")]
async fn handshake_native_tls<S: AsyncRead + AsyncWrite>(
    mut res: Result<
        native_tls::TlsStream<SyncStream<S>>,
        native_tls::HandshakeError<SyncStream<S>>,
    >,
) -> io::Result<TlsStream<S>> {
    use native_tls::HandshakeError;

    loop {
        // Either finish (success or hard failure) or take the interrupted
        // handshake out so it can be resumed below.
        let mut pending = match res {
            Ok(mut established) => {
                established.get_mut().flush_write_buf().await?;
                return Ok(TlsStream::from(established));
            }
            Err(HandshakeError::Failure(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake,
        };

        // Push out anything the backend queued; only when the flush returned
        // zero do we wait for more input from the peer.
        let flushed = pending.get_mut().flush_write_buf().await?;
        if flushed == 0 {
            pending.get_mut().fill_read_buf().await?;
        }
        res = pending.handshake();
    }
}