compio_tls/
stream.rs

1use std::{borrow::Cow, io, mem::MaybeUninit};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut};
4use compio_io::{
5    AsyncRead, AsyncWrite,
6    compat::{AsyncStream, SyncStream},
7};
8
/// Backend-specific state behind [`TlsStream`].
///
/// Each enabled TLS backend feature (`native-tls`, `rustls`) contributes one
/// variant. With no backend enabled, the lone `None` variant holds an
/// [`Infallible`](core::convert::Infallible), so the enum can never be
/// constructed and every `match` on it is closed with an empty `match`.
// `large_enum_variant` is allowed because the backend stream types can differ
// greatly in size. NOTE(review): presumably the variants are left unboxed to
// avoid an extra indirection on each I/O call — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum TlsStreamInner<S> {
    /// Stream driven by `native-tls` through the synchronous `SyncStream` adapter.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls::TlsStream<SyncStream<S>>),
    /// Stream driven by rustls through the asynchronous `AsyncStream` adapter.
    #[cfg(feature = "rustls")]
    Rustls(futures_rustls::TlsStream<AsyncStream<S>>),
    /// Uninhabited placeholder that keeps `S` used when no backend is compiled in.
    #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
    None(core::convert::Infallible, core::marker::PhantomData<S>),
}
19
20impl<S> TlsStreamInner<S> {
21    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
22        match self {
23            #[cfg(feature = "native-tls")]
24            Self::NativeTls(s) => s.negotiated_alpn().ok().flatten().map(Cow::from),
25            #[cfg(feature = "rustls")]
26            Self::Rustls(s) => s.get_ref().1.alpn_protocol().map(Cow::from),
27            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
28            Self::None(f, ..) => match *f {},
29        }
30    }
31}
32
33/// A wrapper around an underlying raw stream which implements the TLS or SSL
34/// protocol.
35///
36/// A `TlsStream<S>` represents a handshake that has been completed successfully
37/// and both the server and the client are ready for receiving and sending
38/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
39/// to a `TlsStream` are encrypted when passing through to `S`.
40#[derive(Debug)]
41pub struct TlsStream<S>(TlsStreamInner<S>);
42
43impl<S> TlsStream<S> {
44    /// Returns the negotiated ALPN protocol.
45    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
46        self.0.negotiated_alpn()
47    }
48}
49
#[cfg(feature = "native-tls")]
#[doc(hidden)]
impl<S> From<native_tls::TlsStream<SyncStream<S>>> for TlsStream<S> {
    /// Wraps a `native-tls` stream running over the synchronous adapter.
    fn from(stream: native_tls::TlsStream<SyncStream<S>>) -> Self {
        let inner = TlsStreamInner::NativeTls(stream);
        TlsStream(inner)
    }
}
57
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::client::TlsStream<AsyncStream<S>>> for TlsStream<S> {
    /// Wraps a client-side rustls stream as the `Client` case of the unified
    /// `futures_rustls::TlsStream`.
    fn from(stream: futures_rustls::client::TlsStream<AsyncStream<S>>) -> Self {
        let unified = futures_rustls::TlsStream::Client(stream);
        TlsStream(TlsStreamInner::Rustls(unified))
    }
}
67
#[cfg(feature = "rustls")]
#[doc(hidden)]
impl<S> From<futures_rustls::server::TlsStream<AsyncStream<S>>> for TlsStream<S> {
    /// Wraps a server-side rustls stream as the `Server` case of the unified
    /// `futures_rustls::TlsStream`.
    fn from(stream: futures_rustls::server::TlsStream<AsyncStream<S>>) -> Self {
        let unified = futures_rustls::TlsStream::Server(stream);
        TlsStream(TlsStreamInner::Rustls(unified))
    }
}
77
78impl<S: AsyncRead + AsyncWrite + 'static> AsyncRead for TlsStream<S> {
79    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
80        let slice = buf.as_mut_slice();
81        slice.fill(MaybeUninit::new(0));
82        // SAFETY: The memory has been initialized
83        let slice =
84            unsafe { std::slice::from_raw_parts_mut::<u8>(slice.as_mut_ptr().cast(), slice.len()) };
85        match &mut self.0 {
86            #[cfg(feature = "native-tls")]
87            TlsStreamInner::NativeTls(s) => loop {
88                match io::Read::read(s, slice) {
89                    Ok(res) => {
90                        unsafe { buf.set_buf_init(res) };
91                        return BufResult(Ok(res), buf);
92                    }
93                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
94                        match s.get_mut().fill_read_buf().await {
95                            Ok(_) => continue,
96                            Err(e) => return BufResult(Err(e), buf),
97                        }
98                    }
99                    res => return BufResult(res, buf),
100                }
101            },
102            #[cfg(feature = "rustls")]
103            TlsStreamInner::Rustls(s) => {
104                let res = futures_util::AsyncReadExt::read(s, slice).await;
105                let res = match res {
106                    Ok(len) => {
107                        unsafe { buf.set_buf_init(len) };
108                        Ok(len)
109                    }
110                    // TLS streams may return UnexpectedEof when the connection is closed.
111                    // https://docs.rs/rustls/latest/rustls/manual/_03_howto/index.html#unexpected-eof
112                    Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(0),
113                    _ => res,
114                };
115                BufResult(res, buf)
116            }
117            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
118            TlsStreamInner::None(f, ..) => match *f {},
119        }
120    }
121}
122
/// Flushes a `native-tls` stream all the way down to the underlying stream.
///
/// The TLS layer's `flush` is retried until it succeeds. Each time it reports
/// `WouldBlock`, the sync adapter's write buffer is first drained with
/// `flush_write_buf`. Any other error is returned immediately. Once the TLS
/// layer has flushed, the adapter's write buffer is drained one final time.
#[cfg(feature = "native-tls")]
async fn flush_impl(
    stream: &mut native_tls::TlsStream<SyncStream<impl AsyncWrite>>,
) -> io::Result<()> {
    while let Err(err) = io::Write::flush(stream) {
        if err.kind() != io::ErrorKind::WouldBlock {
            return Err(err);
        }
        stream.get_mut().flush_write_buf().await?;
    }
    stream.get_mut().flush_write_buf().await.map(|_| ())
}
137
138impl<S: AsyncRead + AsyncWrite + 'static> AsyncWrite for TlsStream<S> {
139    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
140        let slice = buf.as_slice();
141        match &mut self.0 {
142            #[cfg(feature = "native-tls")]
143            TlsStreamInner::NativeTls(s) => loop {
144                let res = io::Write::write(s, slice);
145                match res {
146                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => match flush_impl(s).await {
147                        Ok(_) => continue,
148                        Err(e) => return BufResult(Err(e), buf),
149                    },
150                    _ => return BufResult(res, buf),
151                }
152            },
153            #[cfg(feature = "rustls")]
154            TlsStreamInner::Rustls(s) => {
155                let res = futures_util::AsyncWriteExt::write(s, slice).await;
156                BufResult(res, buf)
157            }
158            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
159            TlsStreamInner::None(f, ..) => match *f {},
160        }
161    }
162
163    async fn flush(&mut self) -> io::Result<()> {
164        match &mut self.0 {
165            #[cfg(feature = "native-tls")]
166            TlsStreamInner::NativeTls(s) => flush_impl(s).await,
167            #[cfg(feature = "rustls")]
168            TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::flush(s).await,
169            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
170            TlsStreamInner::None(f, ..) => match *f {},
171        }
172    }
173
174    async fn shutdown(&mut self) -> io::Result<()> {
175        self.flush().await?;
176        match &mut self.0 {
177            #[cfg(feature = "native-tls")]
178            TlsStreamInner::NativeTls(s) => s.get_mut().get_mut().shutdown().await,
179            #[cfg(feature = "rustls")]
180            TlsStreamInner::Rustls(s) => futures_util::AsyncWriteExt::close(s).await,
181            #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
182            TlsStreamInner::None(f, ..) => match *f {},
183        }
184    }
185}