1use std::{borrow::Cow, io, mem::MaybeUninit};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut};
4use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
5
6#[cfg(feature = "rustls")]
7mod rtls;
8
/// Backend-specific TLS stream, selected by the `native-tls` / `rustls` cargo features.
///
/// Each variant runs its TLS implementation synchronously over a [`SyncStream`]
/// that wraps the async stream `S`.
#[derive(Debug)]
// NOTE(review): the lint is silenced instead of boxing the larger variant; presumably
// acceptable because a stream only ever holds one backend — confirm the rationale.
#[allow(clippy::large_enum_variant)]
enum TlsStreamInner<S> {
    /// Stream driven by the `native-tls` crate.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsStream<SyncStream<S>>),
    /// Stream driven by the crate-local `rtls` wrapper around a rustls connection.
    #[cfg(feature = "rustls")]
    Rustls(rtls::TlsStream<SyncStream<S>>),
}
17
18impl<S> TlsStreamInner<S> {
19 fn get_mut(&mut self) -> &mut SyncStream<S> {
20 match self {
21 #[cfg(feature = "native-tls")]
22 Self::NativeTls(s) => s.get_mut(),
23 #[cfg(feature = "rustls")]
24 Self::Rustls(s) => s.get_mut(),
25 }
26 }
27
28 pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
29 match self {
30 #[cfg(feature = "native-tls")]
31 Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from),
32 #[cfg(feature = "rustls")]
33 Self::Rustls(s) => s.negotiated_alpn().map(Cow::from),
34 }
35 }
36}
37
38impl<S> io::Read for TlsStreamInner<S> {
39 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
40 match self {
41 #[cfg(feature = "native-tls")]
42 Self::NativeTls(s) => io::Read::read(s, buf),
43 #[cfg(feature = "rustls")]
44 Self::Rustls(s) => io::Read::read(s, buf),
45 }
46 }
47
48 #[cfg(feature = "read_buf")]
49 fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
50 match self {
51 #[cfg(feature = "native-tls")]
52 Self::NativeTls(s) => io::Read::read_buf(s, buf),
53 #[cfg(feature = "rustls")]
54 Self::Rustls(s) => io::Read::read_buf(s, buf),
55 }
56 }
57}
58
59impl<S> io::Write for TlsStreamInner<S> {
60 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
61 match self {
62 #[cfg(feature = "native-tls")]
63 Self::NativeTls(s) => io::Write::write(s, buf),
64 #[cfg(feature = "rustls")]
65 Self::Rustls(s) => io::Write::write(s, buf),
66 }
67 }
68
69 fn flush(&mut self) -> io::Result<()> {
70 match self {
71 #[cfg(feature = "native-tls")]
72 Self::NativeTls(s) => io::Write::flush(s),
73 #[cfg(feature = "rustls")]
74 Self::Rustls(s) => io::Write::flush(s),
75 }
76 }
77}
78
/// A TLS stream over an async stream `S`.
///
/// The TLS work is done by `native-tls` or `rustls` (depending on enabled features),
/// running synchronously over a [`SyncStream`] adapter. This type implements
/// `AsyncRead`/`AsyncWrite` by driving that adapter's read and write buffers
/// against `S`.
#[derive(Debug)]
pub struct TlsStream<S>(TlsStreamInner<S>);
88
89impl<S> TlsStream<S> {
90 #[cfg(feature = "rustls")]
91 pub(crate) fn new_rustls_client(s: SyncStream<S>, conn: rustls::ClientConnection) -> Self {
92 Self(TlsStreamInner::Rustls(rtls::TlsStream::new_client(s, conn)))
93 }
94
95 #[cfg(feature = "rustls")]
96 pub(crate) fn new_rustls_server(s: SyncStream<S>, conn: rustls::ServerConnection) -> Self {
97 Self(TlsStreamInner::Rustls(rtls::TlsStream::new_server(s, conn)))
98 }
99
100 pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
102 self.0.negotiated_alpn()
103 }
104}
105
/// Wraps a native-tls stream running over a [`SyncStream`].
#[cfg(feature = "native-tls")]
#[doc(hidden)]
impl<S> From<native_tls::TlsStream<SyncStream<S>>> for TlsStream<S> {
    fn from(stream: native_tls::TlsStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::NativeTls(stream);
        Self(inner)
    }
}
113
impl<S: AsyncRead> AsyncRead for TlsStream<S> {
    /// Reads decrypted application data into `buf`.
    ///
    /// The TLS backend is driven synchronously through [`SyncStream`]. Whenever it
    /// reports `WouldBlock`, more data is pulled from the underlying stream with
    /// `fill_read_buf` and the read is retried. On success, `res` bytes are marked
    /// initialized in `buf` and `res` is returned alongside the buffer. Any other
    /// error is returned alongside the buffer.
    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
        // The buffer's writable memory, viewed as possibly-uninitialized bytes.
        let slice: &mut [MaybeUninit<u8>] = buf.as_mut_slice();

        // With `read_buf`, the backend writes directly into uninitialized memory via a
        // `BorrowedCursor`. Each call returns how many bytes that call wrote.
        // NOTE(review): `borrowed_buf` persists across retries. If a backend ever
        // advanced the cursor and then returned an error, the next attempt would start
        // at a non-zero offset while only its own count is returned. Presumably
        // backends never do that — confirm.
        #[cfg(feature = "read_buf")]
        let mut f = {
            let mut borrowed_buf = io::BorrowedBuf::from(slice);
            move |s: &mut _| {
                let mut cursor = borrowed_buf.unfilled();
                std::io::Read::read_buf(s, cursor.reborrow())?;
                Ok::<usize, io::Error>(cursor.written())
            }
        };

        // Without `read_buf`, `io::Read::read` needs `&mut [u8]`, so the memory is
        // zeroed first to make it valid to view as initialized bytes.
        #[cfg(not(feature = "read_buf"))]
        let mut f = {
            slice.fill(MaybeUninit::new(0));
            // SAFETY: every element of `slice` was initialized by the `fill` above, and
            // `MaybeUninit<u8>` has the same size and alignment as `u8`.
            let slice =
                unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) };
            |s: &mut _| std::io::Read::read(s, slice)
        };

        loop {
            match f(&mut self.0) {
                Ok(res) => {
                    // SAFETY: the backend reported writing `res` bytes at the start of
                    // `buf.as_mut_slice()`. This assumes that slice begins where
                    // `set_buf_init` counts from — NOTE(review): confirm against compio_buf.
                    unsafe { buf.set_buf_init(res) };
                    return BufResult(Ok(res), buf);
                }
                // The TLS layer has no more buffered input: read more from the
                // underlying stream into the sync adapter, then retry.
                // NOTE(review): any `Ok` (presumably including 0 at EOF) retries. The
                // loop then relies on the backend reporting EOF or an error rather than
                // `WouldBlock` once the adapter is exhausted — confirm `SyncStream`
                // behaves that way.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    match self.0.get_mut().fill_read_buf().await {
                        Ok(_) => continue,
                        Err(e) => return BufResult(Err(e), buf),
                    }
                }
                res => return BufResult(res, buf),
            }
        }
    }
}
154
155impl<S: AsyncWrite> AsyncWrite for TlsStream<S> {
156 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
157 let slice = buf.as_slice();
158 loop {
159 let res = io::Write::write(&mut self.0, slice);
160 match res {
161 Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await {
162 Ok(_) => continue,
163 Err(e) => return BufResult(Err(e), buf),
164 },
165 _ => return BufResult(res, buf),
166 }
167 }
168 }
169
170 async fn flush(&mut self) -> io::Result<()> {
171 loop {
172 match io::Write::flush(&mut self.0) {
173 Ok(()) => break,
174 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
175 self.0.get_mut().flush_write_buf().await?;
176 }
177 Err(e) => return Err(e),
178 }
179 }
180 self.0.get_mut().flush_write_buf().await?;
181 Ok(())
182 }
183
184 async fn shutdown(&mut self) -> io::Result<()> {
185 self.flush().await?;
186 self.0.get_mut().get_mut().shutdown().await
187 }
188}