compio_tls/adapter/
mod.rs1use std::io;
2
3use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
4
5use crate::TlsStream;
6
7#[cfg(feature = "rustls")]
8mod rtls;
9
/// Backend-specific client state held by [`TlsConnector`].
///
/// There is one variant per enabled TLS feature; which one is populated is
/// decided by the `From` conversion used to build the connector.
#[derive(Clone, Debug)]
enum TlsConnectorInner {
    /// A client connector from the `native-tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsConnector),
    /// A client connector from the local `rtls` module, wrapping a shared
    /// rustls `ClientConfig`.
    #[cfg(feature = "rustls")]
    Rustls(rtls::TlsConnector),
}
17
18#[derive(Debug, Clone)]
21pub struct TlsConnector(TlsConnectorInner);
22
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    /// Wraps an already-configured native-tls client connector.
    fn from(connector: native_tls::TlsConnector) -> Self {
        let inner = TlsConnectorInner::NativeTls(connector);
        TlsConnector(inner)
    }
}
29
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    /// Wraps a shared rustls client configuration in the rustls backend.
    fn from(config: std::sync::Arc<rustls::ClientConfig>) -> Self {
        let backend = rtls::TlsConnector(config);
        TlsConnector(TlsConnectorInner::Rustls(backend))
    }
}
36
37impl TlsConnector {
38 pub async fn connect<S: AsyncRead + AsyncWrite>(
51 &self,
52 domain: &str,
53 stream: S,
54 ) -> io::Result<TlsStream<S>> {
55 match &self.0 {
56 #[cfg(feature = "native-tls")]
57 TlsConnectorInner::NativeTls(c) => {
58 handshake_native_tls(c.connect(domain, SyncStream::new(stream))).await
59 }
60 #[cfg(feature = "rustls")]
61 TlsConnectorInner::Rustls(c) => handshake_rustls(c.connect(domain, stream)).await,
62 }
63 }
64}
65
/// Backend-specific server state held by [`TlsAcceptor`].
///
/// There is one variant per enabled TLS feature. Only `Clone` is derived here,
/// unlike the connector's inner enum, which also derives `Debug`.
#[derive(Clone)]
enum TlsAcceptorInner {
    /// A server acceptor from the `native-tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsAcceptor),
    /// A server acceptor from the local `rtls` module, wrapping a shared
    /// rustls `ServerConfig`.
    #[cfg(feature = "rustls")]
    Rustls(rtls::TlsAcceptor),
}
73
74#[derive(Clone)]
77pub struct TlsAcceptor(TlsAcceptorInner);
78
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    /// Wraps an already-configured native-tls server acceptor.
    fn from(acceptor: native_tls::TlsAcceptor) -> Self {
        let inner = TlsAcceptorInner::NativeTls(acceptor);
        TlsAcceptor(inner)
    }
}
85
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    /// Wraps a shared rustls server configuration in the rustls backend.
    fn from(config: std::sync::Arc<rustls::ServerConfig>) -> Self {
        let backend = rtls::TlsAcceptor(config);
        TlsAcceptor(TlsAcceptorInner::Rustls(backend))
    }
}
92
93impl TlsAcceptor {
94 pub async fn accept<S: AsyncRead + AsyncWrite>(&self, stream: S) -> io::Result<TlsStream<S>> {
105 match &self.0 {
106 #[cfg(feature = "native-tls")]
107 TlsAcceptorInner::NativeTls(c) => {
108 handshake_native_tls(c.accept(SyncStream::new(stream))).await
109 }
110 #[cfg(feature = "rustls")]
111 TlsAcceptorInner::Rustls(c) => handshake_rustls(c.accept(stream)).await,
112 }
113 }
114}
115
/// Drives a native-tls handshake to completion over a `SyncStream`.
///
/// Each time the backend reports `WouldBlock`, buffered outgoing bytes are
/// flushed first. Only if nothing was flushed (`flush_write_buf` returned 0)
/// is more input read with `fill_read_buf`. The handshake is then resumed.
/// On success, the write buffer is flushed once more before the stream is
/// converted into a [`TlsStream`].
///
/// # Errors
///
/// A `HandshakeError::Failure` is wrapped with `io::Error::other`. I/O errors
/// from flushing or filling the buffers are propagated as-is.
#[cfg(feature = "native-tls")]
async fn handshake_native_tls<S: AsyncRead + AsyncWrite>(
    mut res: Result<
        native_tls::TlsStream<SyncStream<S>>,
        native_tls::HandshakeError<SyncStream<S>>,
    >,
) -> io::Result<TlsStream<S>> {
    use native_tls::HandshakeError;

    loop {
        // Either finish (Ok / Failure) or extract the paused handshake.
        let mut pending = match res {
            Ok(mut established) => {
                established.get_mut().flush_write_buf().await?;
                return Ok(TlsStream::from(established));
            }
            Err(HandshakeError::Failure(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::WouldBlock(mid_stream)) => mid_stream,
        };

        // Prefer sending pending output; read only when there was nothing to send.
        let written = pending.get_mut().flush_write_buf().await?;
        if written == 0 {
            pending.get_mut().fill_read_buf().await?;
        }
        res = pending.handshake();
    }
}
143
/// Drives a rustls handshake (via the local `rtls` module) to completion.
///
/// Each time `rtls` reports `WouldBlock`, buffered outgoing bytes are flushed
/// first. Only if nothing was flushed (`flush_write_buf` returned 0) is more
/// input read with `fill_read_buf`. The handshake is then resumed. On success,
/// the resulting [`TlsStream`] is flushed before being returned.
///
/// `C` is the connection type, which must deref to `rustls::ConnectionCommon<D>`.
/// `D` is forwarded explicitly to `handshake::<D>()` (presumably the
/// client/server connection-data type; confirm against `rtls`).
///
/// # Errors
///
/// `HandshakeError::Rustls` is wrapped with `io::Error::other`, and
/// `HandshakeError::System` is returned unchanged. I/O errors from flushing or
/// filling the buffers are propagated as-is.
#[cfg(feature = "rustls")]
async fn handshake_rustls<S: AsyncRead + AsyncWrite, C, D>(
    mut res: Result<TlsStream<S>, rtls::HandshakeError<S, C>>,
) -> io::Result<TlsStream<S>>
where
    C: std::ops::DerefMut<Target = rustls::ConnectionCommon<D>>,
{
    use rtls::HandshakeError;

    loop {
        // Either finish (Ok / Rustls / System) or extract the paused handshake.
        let mut pending = match res {
            Ok(mut established) => {
                established.flush().await?;
                return Ok(established);
            }
            Err(HandshakeError::Rustls(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::System(err)) => return Err(err),
            Err(HandshakeError::WouldBlock(mid_stream)) => mid_stream,
        };

        // Prefer sending pending output; read only when there was nothing to send.
        let written = pending.get_mut().flush_write_buf().await?;
        if written == 0 {
            pending.get_mut().fill_read_buf().await?;
        }
        res = pending.handshake::<D>();
    }
}