monoio_rustls_fork_shadow_tls/
client.rs

1use std::sync::Arc;
2
3use monoio::io::{AsyncReadRent, AsyncWriteRent, OwnedReadHalf, OwnedWriteHalf};
4use rustls_fork_shadow_tls::{ClientConfig, ClientConnection};
5
6use crate::{stream::Stream, TlsError};
7
8/// A wrapper around an underlying raw stream which implements the TLS protocol.
9pub type TlsStream<IO> = Stream<IO, ClientConnection>;
10/// TlsStream for read only.
11pub type TlsStreamReadHalf<IO> = OwnedReadHalf<TlsStream<IO>>;
12/// TlsStream for write only.
13pub type TlsStreamWriteHalf<IO> = OwnedWriteHalf<TlsStream<IO>>;
14
15/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
16#[derive(Clone)]
17pub struct TlsConnector {
18    inner: Arc<ClientConfig>,
19    #[cfg(feature = "unsafe_io")]
20    unsafe_io: bool,
21}
22
23impl From<Arc<ClientConfig>> for TlsConnector {
24    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
25        TlsConnector {
26            inner,
27            #[cfg(feature = "unsafe_io")]
28            unsafe_io: false,
29        }
30    }
31}
32
33impl From<ClientConfig> for TlsConnector {
34    fn from(inner: ClientConfig) -> TlsConnector {
35        TlsConnector {
36            inner: Arc::new(inner),
37            #[cfg(feature = "unsafe_io")]
38            unsafe_io: false,
39        }
40    }
41}
42
43impl TlsConnector {
44    /// Enable unsafe-io.
45    /// # Safety
46    /// Users must make sure the buffer ptr and len is valid until io finished.
47    /// So the Future cannot be dropped directly. Consider using CancellableIO.
48    #[cfg(feature = "unsafe_io")]
49    pub unsafe fn unsafe_io(self, enabled: bool) -> Self {
50        Self {
51            unsafe_io: enabled,
52            ..self
53        }
54    }
55
56    pub async fn connect<IO>(
57        &self,
58        domain: rustls_fork_shadow_tls::ServerName,
59        stream: IO,
60    ) -> Result<TlsStream<IO>, TlsError>
61    where
62        IO: AsyncReadRent + AsyncWriteRent,
63    {
64        let session = ClientConnection::new(self.inner.clone(), domain)?;
65        #[cfg(feature = "unsafe_io")]
66        let mut stream = if self.unsafe_io {
67            // # Safety
68            // Users already maked unsafe io.
69            unsafe { Stream::new_unsafe(stream, session) }
70        } else {
71            Stream::new(stream, session)
72        };
73        #[cfg(not(feature = "unsafe_io"))]
74        let mut stream = Stream::new(stream, session);
75        stream.handshake().await?;
76        Ok(stream)
77    }
78
79    pub async fn connect_with_session_id_generator<IO>(
80        &self,
81        domain: rustls_fork_shadow_tls::ServerName,
82        stream: IO,
83        generator: impl Fn(&[u8]) -> [u8; 32],
84    ) -> Result<TlsStream<IO>, TlsError>
85    where
86        IO: AsyncReadRent + AsyncWriteRent,
87    {
88        let session =
89            ClientConnection::new_with_session_id_generator(self.inner.clone(), domain, generator)?;
90        let mut stream = Stream::new(stream, session);
91        stream.handshake().await?;
92        Ok(stream)
93    }
94}