compio_tls/stream/
mod.rs

1use std::{borrow::Cow, io, mem::MaybeUninit};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut};
4use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
5
6#[cfg(feature = "rustls")]
7mod rtls;
8
9#[derive(Debug)]
10#[allow(clippy::large_enum_variant)]
11enum TlsStreamInner<S> {
12    #[cfg(feature = "native-tls")]
13    NativeTls(native_tls::TlsStream<SyncStream<S>>),
14    #[cfg(feature = "rustls")]
15    Rustls(rtls::TlsStream<SyncStream<S>>),
16}
17
18impl<S> TlsStreamInner<S> {
19    fn get_mut(&mut self) -> &mut SyncStream<S> {
20        match self {
21            #[cfg(feature = "native-tls")]
22            Self::NativeTls(s) => s.get_mut(),
23            #[cfg(feature = "rustls")]
24            Self::Rustls(s) => s.get_mut(),
25        }
26    }
27
28    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
29        match self {
30            #[cfg(feature = "native-tls")]
31            Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from),
32            #[cfg(feature = "rustls")]
33            Self::Rustls(s) => s.negotiated_alpn().map(Cow::from),
34        }
35    }
36}
37
38impl<S> io::Read for TlsStreamInner<S> {
39    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
40        match self {
41            #[cfg(feature = "native-tls")]
42            Self::NativeTls(s) => io::Read::read(s, buf),
43            #[cfg(feature = "rustls")]
44            Self::Rustls(s) => io::Read::read(s, buf),
45        }
46    }
47
48    #[cfg(feature = "read_buf")]
49    fn read_buf(&mut self, buf: io::BorrowedCursor<'_>) -> io::Result<()> {
50        match self {
51            #[cfg(feature = "native-tls")]
52            Self::NativeTls(s) => io::Read::read_buf(s, buf),
53            #[cfg(feature = "rustls")]
54            Self::Rustls(s) => io::Read::read_buf(s, buf),
55        }
56    }
57}
58
59impl<S> io::Write for TlsStreamInner<S> {
60    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
61        match self {
62            #[cfg(feature = "native-tls")]
63            Self::NativeTls(s) => io::Write::write(s, buf),
64            #[cfg(feature = "rustls")]
65            Self::Rustls(s) => io::Write::write(s, buf),
66        }
67    }
68
69    fn flush(&mut self) -> io::Result<()> {
70        match self {
71            #[cfg(feature = "native-tls")]
72            Self::NativeTls(s) => io::Write::flush(s),
73            #[cfg(feature = "rustls")]
74            Self::Rustls(s) => io::Write::flush(s),
75        }
76    }
77}
78
79/// A wrapper around an underlying raw stream which implements the TLS or SSL
80/// protocol.
81///
82/// A `TlsStream<S>` represents a handshake that has been completed successfully
83/// and both the server and the client are ready for receiving and sending
84/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
85/// to a `TlsStream` are encrypted when passing through to `S`.
86#[derive(Debug)]
87pub struct TlsStream<S>(TlsStreamInner<S>);
88
89impl<S> TlsStream<S> {
90    #[cfg(feature = "rustls")]
91    pub(crate) fn new_rustls_client(s: SyncStream<S>, conn: rustls::ClientConnection) -> Self {
92        Self(TlsStreamInner::Rustls(rtls::TlsStream::new_client(s, conn)))
93    }
94
95    #[cfg(feature = "rustls")]
96    pub(crate) fn new_rustls_server(s: SyncStream<S>, conn: rustls::ServerConnection) -> Self {
97        Self(TlsStreamInner::Rustls(rtls::TlsStream::new_server(s, conn)))
98    }
99
100    /// Returns the negotiated ALPN protocol.
101    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
102        self.0.negotiated_alpn()
103    }
104}
105
#[cfg(feature = "native-tls")]
#[doc(hidden)]
impl<S> From<native_tls::TlsStream<SyncStream<S>>> for TlsStream<S> {
    // Adopts a native-tls stream (layered over the buffered adapter) as the
    // active backend.
    fn from(stream: native_tls::TlsStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::NativeTls(stream);
        TlsStream(inner)
    }
}
113
114impl<S: AsyncRead> AsyncRead for TlsStream<S> {
115    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
116        let slice: &mut [MaybeUninit<u8>] = buf.as_mut_slice();
117
118        #[cfg(feature = "read_buf")]
119        let mut f = {
120            let mut borrowed_buf = io::BorrowedBuf::from(slice);
121            move |s: &mut _| {
122                let mut cursor = borrowed_buf.unfilled();
123                std::io::Read::read_buf(s, cursor.reborrow())?;
124                Ok::<usize, io::Error>(cursor.written())
125            }
126        };
127
128        #[cfg(not(feature = "read_buf"))]
129        let mut f = {
130            slice.fill(MaybeUninit::new(0));
131            // SAFETY: The memory has been initialized
132            let slice =
133                unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) };
134            |s: &mut _| std::io::Read::read(s, slice)
135        };
136
137        loop {
138            match f(&mut self.0) {
139                Ok(res) => {
140                    unsafe { buf.set_buf_init(res) };
141                    return BufResult(Ok(res), buf);
142                }
143                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
144                    match self.0.get_mut().fill_read_buf().await {
145                        Ok(_) => continue,
146                        Err(e) => return BufResult(Err(e), buf),
147                    }
148                }
149                res => return BufResult(res, buf),
150            }
151        }
152    }
153}
154
155impl<S: AsyncWrite> AsyncWrite for TlsStream<S> {
156    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
157        let slice = buf.as_slice();
158        loop {
159            let res = io::Write::write(&mut self.0, slice);
160            match res {
161                Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await {
162                    Ok(_) => continue,
163                    Err(e) => return BufResult(Err(e), buf),
164                },
165                _ => return BufResult(res, buf),
166            }
167        }
168    }
169
170    async fn flush(&mut self) -> io::Result<()> {
171        loop {
172            match io::Write::flush(&mut self.0) {
173                Ok(()) => break,
174                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
175                    self.0.get_mut().flush_write_buf().await?;
176                }
177                Err(e) => return Err(e),
178            }
179        }
180        self.0.get_mut().flush_write_buf().await?;
181        Ok(())
182    }
183
184    async fn shutdown(&mut self) -> io::Result<()> {
185        self.flush().await?;
186        self.0.get_mut().get_mut().shutdown().await
187    }
188}