1use std::{fmt::Debug, io};
2
3use compio_io::{
4 AsyncRead, AsyncWrite,
5 compat::{AsyncStream, SyncStream},
6};
7
8use crate::TlsStream;
9
/// Backend-specific client-side TLS connector wrapped by `TlsConnector`.
///
/// Exactly one variant exists per enabled backend feature; with no backend
/// enabled the enum is uninhabited (`Infallible`), so it can never be built.
#[derive(Clone)]
enum TlsConnectorInner {
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsConnector),
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsConnector),
    #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
    None(std::convert::Infallible),
}

// Only the backend name is printed; the wrapped connector itself is not
// formatted.
impl Debug for TlsConnectorInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.write_str("NativeTls"),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.write_str("Rustls"),
            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
            Self::None(never) => match *never {},
        }
    }
}
32
/// A client-side TLS connector backed by native-tls or rustls, depending on
/// the enabled crate features.
///
/// Build one via `From` from a `native_tls::TlsConnector` (feature
/// `native-tls`) or an `Arc<rustls::ClientConfig>` (feature `rustls`), then
/// call `connect` to run a client handshake over a stream. Its `Debug`
/// output names the backend only.
#[derive(Debug, Clone)]
pub struct TlsConnector(TlsConnectorInner);
37
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
    /// Wraps a native-tls connector so it can be driven over compio streams.
    fn from(connector: native_tls::TlsConnector) -> Self {
        let inner = TlsConnectorInner::NativeTls(connector);
        TlsConnector(inner)
    }
}
44
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
    /// Builds a rustls-backed connector from a shared client configuration.
    fn from(config: std::sync::Arc<rustls::ClientConfig>) -> Self {
        let connector: futures_rustls::TlsConnector = config.into();
        TlsConnector(TlsConnectorInner::Rustls(connector))
    }
}
51
52impl TlsConnector {
53 pub async fn connect<S: AsyncRead + AsyncWrite + 'static>(
66 &self,
67 domain: &str,
68 stream: S,
69 ) -> io::Result<TlsStream<S>> {
70 match &self.0 {
71 #[cfg(feature = "native-tls")]
72 TlsConnectorInner::NativeTls(c) => {
73 handshake_native_tls(c.connect(domain, SyncStream::new(stream))).await
74 }
75 #[cfg(feature = "rustls")]
76 TlsConnectorInner::Rustls(c) => {
77 let client = c
78 .connect(
79 domain.to_string().try_into().map_err(io::Error::other)?,
80 AsyncStream::new(stream),
81 )
82 .await?;
83 Ok(TlsStream::from(client))
84 }
85 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
86 TlsConnectorInner::None(f) => match *f {},
87 }
88 }
89}
90
/// Backend-specific server-side TLS acceptor wrapped by `TlsAcceptor`.
///
/// Exactly one variant exists per enabled backend feature; with no backend
/// enabled the enum is uninhabited (`Infallible`), so it can never be built.
#[derive(Clone)]
enum TlsAcceptorInner {
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsAcceptor),
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsAcceptor),
    #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
    None(std::convert::Infallible),
}

// Hand-written so the backend acceptors need not implement `Debug`: only the
// backend name is printed, matching how the connector side formats itself.
impl Debug for TlsAcceptorInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
            #[cfg(feature = "rustls")]
            Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
            Self::None(never) => match *never {},
        }
    }
}

/// A server-side TLS acceptor backed by native-tls or rustls, depending on
/// the enabled crate features.
///
/// Build one via `From` from a `native_tls::TlsAcceptor` (feature
/// `native-tls`) or an `Arc<rustls::ServerConfig>` (feature `rustls`), then
/// call `accept` to run a server handshake over a stream. Its `Debug`
/// output names the backend only.
#[derive(Debug, Clone)]
pub struct TlsAcceptor(TlsAcceptorInner);
105
#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
    /// Wraps a native-tls acceptor so it can be driven over compio streams.
    fn from(acceptor: native_tls::TlsAcceptor) -> Self {
        let inner = TlsAcceptorInner::NativeTls(acceptor);
        TlsAcceptor(inner)
    }
}
112
#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
    /// Builds a rustls-backed acceptor from a shared server configuration.
    fn from(config: std::sync::Arc<rustls::ServerConfig>) -> Self {
        let acceptor: futures_rustls::TlsAcceptor = config.into();
        TlsAcceptor(TlsAcceptorInner::Rustls(acceptor))
    }
}
119
120impl TlsAcceptor {
121 pub async fn accept<S: AsyncRead + AsyncWrite + 'static>(
132 &self,
133 stream: S,
134 ) -> io::Result<TlsStream<S>> {
135 match &self.0 {
136 #[cfg(feature = "native-tls")]
137 TlsAcceptorInner::NativeTls(c) => {
138 handshake_native_tls(c.accept(SyncStream::new(stream))).await
139 }
140 #[cfg(feature = "rustls")]
141 TlsAcceptorInner::Rustls(c) => {
142 let server = c.accept(AsyncStream::new(stream)).await?;
143 Ok(TlsStream::from(server))
144 }
145 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
146 TlsAcceptorInner::None(f) => match *f {},
147 }
148 }
149}
150
/// Drives a native-tls handshake to completion over a compio stream.
///
/// `res` is the outcome of the initial native-tls `connect`/`accept` call.
/// While native-tls reports `HandshakeError::WouldBlock`, the loop does two
/// things before resuming the handshake:
/// - it flushes any buffered outgoing bytes of the `SyncStream`;
/// - if nothing was flushed, it reads more incoming data into the buffer
///   instead.
///
/// Once the handshake succeeds, remaining buffered output is flushed and the
/// stream is returned.
///
/// # Errors
///
/// A `HandshakeError::Failure` is converted with `io::Error::other`. I/O
/// errors from flushing or reading are propagated unchanged.
#[cfg(feature = "native-tls")]
async fn handshake_native_tls<S: AsyncRead + AsyncWrite>(
    mut res: Result<
        native_tls::TlsStream<SyncStream<S>>,
        native_tls::HandshakeError<SyncStream<S>>,
    >,
) -> io::Result<TlsStream<S>> {
    use native_tls::HandshakeError;

    loop {
        let mut mid_stream = match res {
            Ok(mut established) => {
                // The final handshake step may have left bytes in the write buffer.
                established.get_mut().flush_write_buf().await?;
                return Ok(TlsStream::from(established));
            }
            Err(HandshakeError::Failure(err)) => return Err(io::Error::other(err)),
            Err(HandshakeError::WouldBlock(pending)) => pending,
        };

        // Prefer making progress on output; only wait for input when there was
        // nothing left to send.
        let flushed = mid_stream.get_mut().flush_write_buf().await?;
        if flushed == 0 {
            mid_stream.get_mut().fill_read_buf().await?;
        }
        res = mid_stream.handshake();
    }
}