1use std::{borrow::Cow, io, mem::MaybeUninit};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut};
4use compio_io::{
5 AsyncRead, AsyncWrite,
6 compat::{AsyncStream, SyncStream},
7};
8
/// Backend-specific TLS stream; exactly one variant is compiled in per enabled feature.
// The backend streams differ greatly in size; boxing the larger one would add an
// allocation and an indirection on every I/O call, so the size lint is silenced.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum TlsStreamInner<S> {
    /// native-tls stream driven through the blocking-style `SyncStream` adapter.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsStream<SyncStream<S>>),
    /// futures-rustls stream (client or server side) over the `AsyncStream` adapter.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsStream<AsyncStream<S>>),
    /// Uninhabited placeholder used when no TLS backend feature is enabled.
    #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
    None(std::convert::Infallible, std::marker::PhantomData<S>),
}
19
20impl<S> TlsStreamInner<S> {
21 pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
22 match self {
23 #[cfg(feature = "native-tls")]
24 Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from),
25 #[cfg(feature = "rustls")]
26 Self::Rustls(s) => s.get_ref().1.alpn_protocol().map(Cow::from),
27 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
28 Self::None(f, ..) => match *f {},
29 }
30 }
31}
32
33#[derive(Debug)]
41pub struct TlsStream<S>(TlsStreamInner<S>);
42
43impl<S> TlsStream<S> {
44 pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
46 self.0.negotiated_alpn()
47 }
48}
49
#[cfg(feature = "native-tls")]
#[doc(hidden)]
impl<S> From<native_tls::TlsStream<SyncStream<S>>> for TlsStream<S> {
    /// Wraps a native-tls stream running over the `SyncStream` adapter.
    fn from(value: native_tls::TlsStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::NativeTls(value);
        TlsStream(inner)
    }
}
57
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::client::TlsStream<AsyncStream<S>>> for TlsStream<S> {
    /// Wraps a client-side rustls stream as the `Client` side of the unified rustls stream.
    fn from(value: futures_rustls::client::TlsStream<AsyncStream<S>>) -> Self {
        let stream = futures_rustls::TlsStream::Client(value);
        TlsStream(TlsStreamInner::Rustls(stream))
    }
}
67
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::server::TlsStream<AsyncStream<S>>> for TlsStream<S> {
    /// Wraps a server-side rustls stream as the `Server` side of the unified rustls stream.
    fn from(value: futures_rustls::server::TlsStream<AsyncStream<S>>) -> Self {
        let stream = futures_rustls::TlsStream::Server(value);
        TlsStream(TlsStreamInner::Rustls(stream))
    }
}
77
78impl<S: AsyncRead + AsyncWrite + 'static> AsyncRead for TlsStream<S> {
79 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
80 let slice = buf.as_mut_slice();
81 slice.fill(MaybeUninit::new(0));
82 let slice =
84 unsafe { std::slice::from_raw_parts_mut::<u8>(slice.as_mut_ptr().cast(), slice.len()) };
85 match &mut self.0 {
86 #[cfg(feature = "native-tls")]
87 TlsStreamInner::NativeTls(s) => loop {
88 match io::Read::read(s, slice) {
89 Ok(res) => {
90 unsafe { buf.set_buf_init(res) };
91 return BufResult(Ok(res), buf);
92 }
93 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
94 match s.get_mut().fill_read_buf().await {
95 Ok(_) => continue,
96 Err(e) => return BufResult(Err(e), buf),
97 }
98 }
99 res => return BufResult(res, buf),
100 }
101 },
102 #[cfg(feature = "rustls")]
103 TlsStreamInner::Rustls(s) => {
104 let res = futures_util::AsyncReadExt::read(s, slice).await;
105 let res = match res {
106 Ok(len) => {
107 unsafe { buf.set_buf_init(len) };
108 Ok(len)
109 }
110 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0),
113 _ => res,
114 };
115 BufResult(res, buf)
116 }
117 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
118 TlsStreamInner::None(f, ..) => match *f {},
119 }
120 }
121}
122
/// Flushes a native-tls stream all the way down to the underlying async stream.
///
/// The TLS layer is flushed into the `SyncStream` write buffer. Whenever that
/// reports `WouldBlock`, the buffer is written out asynchronously and the flush is
/// retried. Finally, anything still buffered is written out.
#[cfg(feature = "native-tls")]
async fn flush_impl(s: &mut native_tls::TlsStream<SyncStream<impl AsyncWrite>>) -> io::Result<()> {
    while let Err(err) = io::Write::flush(s) {
        if err.kind() != io::ErrorKind::WouldBlock {
            return Err(err);
        }
        s.get_mut().flush_write_buf().await?;
    }
    // Write out whatever the final successful TLS flush left in the adapter buffer.
    s.get_mut().flush_write_buf().await?;
    Ok(())
}
137
138impl<S: AsyncRead + AsyncWrite + 'static> AsyncWrite for TlsStream<S> {
139 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
140 let slice = buf.as_slice();
141 match &mut self.0 {
142 #[cfg(feature = "native-tls")]
143 TlsStreamInner::NativeTls(s) => loop {
144 let res = io::Write::write(s, slice);
145 match res {
146 Err(e) if e.kind() == io::ErrorKind::WouldBlock => match flush_impl(s).await {
147 Ok(_) => continue,
148 Err(e) => return BufResult(Err(e), buf),
149 },
150 _ => return BufResult(res, buf),
151 }
152 },
153 #[cfg(feature = "rustls")]
154 TlsStreamInner::Rustls(s) => {
155 let res = futures_util::AsyncWriteExt::write(s, slice).await;
156 BufResult(res, buf)
157 }
158 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
159 TlsStreamInner::None(f, ..) => match *f {},
160 }
161 }
162
163 async fn flush(&mut self) -> io::Result<()> {
164 match &mut self.0 {
165 #[cfg(feature = "native-tls")]
166 TlsStreamInner::NativeTls(s) => flush_impl(s).await,
167 #[cfg(feature = "rustls")]
168 TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::flush(s).await,
169 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
170 TlsStreamInner::None(f, ..) => match *f {},
171 }
172 }
173
174 async fn shutdown(&mut self) -> io::Result<()> {
175 self.flush().await?;
176 match &mut self.0 {
177 #[cfg(feature = "native-tls")]
178 TlsStreamInner::NativeTls(s) => s.get_mut().get_mut().shutdown().await,
179 #[cfg(feature = "rustls")]
180 TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::close(s).await,
181 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
182 TlsStreamInner::None(f, ..) => match *f {},
183 }
184 }
185}