monoio_rustls_fork_shadow_tls/
client.rs1use std::sync::Arc;
2
3use monoio::io::{AsyncReadRent, AsyncWriteRent, OwnedReadHalf, OwnedWriteHalf};
4use rustls_fork_shadow_tls::{ClientConfig, ClientConnection};
5
6use crate::{stream::Stream, TlsError};
7
8pub type TlsStream<IO> = Stream<IO, ClientConnection>;
10pub type TlsStreamReadHalf<IO> = OwnedReadHalf<TlsStream<IO>>;
12pub type TlsStreamWriteHalf<IO> = OwnedWriteHalf<TlsStream<IO>>;
14
15#[derive(Clone)]
17pub struct TlsConnector {
18 inner: Arc<ClientConfig>,
19 #[cfg(feature = "unsafe_io")]
20 unsafe_io: bool,
21}
22
23impl From<Arc<ClientConfig>> for TlsConnector {
24 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
25 TlsConnector {
26 inner,
27 #[cfg(feature = "unsafe_io")]
28 unsafe_io: false,
29 }
30 }
31}
32
33impl From<ClientConfig> for TlsConnector {
34 fn from(inner: ClientConfig) -> TlsConnector {
35 TlsConnector {
36 inner: Arc::new(inner),
37 #[cfg(feature = "unsafe_io")]
38 unsafe_io: false,
39 }
40 }
41}
42
43impl TlsConnector {
44 #[cfg(feature = "unsafe_io")]
49 pub unsafe fn unsafe_io(self, enabled: bool) -> Self {
50 Self {
51 unsafe_io: enabled,
52 ..self
53 }
54 }
55
56 pub async fn connect<IO>(
57 &self,
58 domain: rustls_fork_shadow_tls::ServerName,
59 stream: IO,
60 ) -> Result<TlsStream<IO>, TlsError>
61 where
62 IO: AsyncReadRent + AsyncWriteRent,
63 {
64 let session = ClientConnection::new(self.inner.clone(), domain)?;
65 #[cfg(feature = "unsafe_io")]
66 let mut stream = if self.unsafe_io {
67 unsafe { Stream::new_unsafe(stream, session) }
70 } else {
71 Stream::new(stream, session)
72 };
73 #[cfg(not(feature = "unsafe_io"))]
74 let mut stream = Stream::new(stream, session);
75 stream.handshake().await?;
76 Ok(stream)
77 }
78
79 pub async fn connect_with_session_id_generator<IO>(
80 &self,
81 domain: rustls_fork_shadow_tls::ServerName,
82 stream: IO,
83 generator: impl Fn(&[u8]) -> [u8; 32],
84 ) -> Result<TlsStream<IO>, TlsError>
85 where
86 IO: AsyncReadRent + AsyncWriteRent,
87 {
88 let session =
89 ClientConnection::new_with_session_id_generator(self.inner.clone(), domain, generator)?;
90 let mut stream = Stream::new(stream, session);
91 stream.handshake().await?;
92 Ok(stream)
93 }
94}