gel_stream/common/
stream.rs

1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::future::Future;
5#[cfg(feature = "tokio")]
6use std::{
7    any::Any,
8    io::IoSlice,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13use crate::{Ssl, SslError, TlsDriver, TlsHandshake, TlsServerParameterProvider};
14
/// A convenience trait for streams from this crate.
///
/// With the `tokio` feature enabled, a stream is any asynchronous, sendable,
/// `'static` reader/writer that is also `Unpin`.
#[cfg(feature = "tokio")]
pub trait Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
    /// Attempts to convert this stream into the concrete stream type `S`.
    ///
    /// Returns `Ok` holding the stream as an `S` when `Self` and `S` are the
    /// same type; otherwise the stream is handed back untouched in `Err`.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // Note that we only support Tokio TcpStream for rustls.
        //
        // Park the stream in an `Option` so it can be moved out through the
        // type-erased `&mut dyn Any` view once the concrete type is confirmed.
        let mut slot = Some(self);
        let erased: &mut dyn Any = &mut slot;
        if let Some(matched) = erased.downcast_mut::<Option<S>>() {
            return Ok(matched.take().unwrap());
        }
        Err(slot.take().unwrap())
    }
}
32
/// Every Tokio reader/writer that is `Unpin + Send + 'static` is a [`Stream`].
#[cfg(feature = "tokio")]
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stream for T {}
35
36#[cfg(not(feature = "tokio"))]
37pub trait Stream: 'static {}
38#[cfg(not(feature = "tokio"))]
39impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
40#[cfg(not(feature = "tokio"))]
41impl Stream for () {}
42
43/// A trait for streams that can be upgraded to a TLS stream.
44pub trait StreamUpgrade: Stream {
45    fn secure_upgrade(&mut self) -> impl Future<Output = Result<(), SslError>> + Send;
46    fn handshake(&self) -> Option<&TlsHandshake>;
47}
48
/// Behavioural switches for an `UpgradableStream`.
#[derive(Debug, Default)]
struct UpgradableStreamOptions {
    /// When `true`, a read failing with `UnexpectedEof` is reported to the
    /// caller as a successful read rather than as an error.
    ignore_missing_close_notify: bool,
}
53
54#[allow(private_bounds)]
55#[derive(derive_more::Debug)]
56pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
57    inner: UpgradableStreamInner<S, D>,
58    options: UpgradableStreamOptions,
59}
60
61#[allow(private_bounds)]
62impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
63    #[inline(always)]
64    pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
65        UpgradableStream {
66            inner: UpgradableStreamInner::BaseClient(base, config),
67            options: Default::default(),
68        }
69    }
70
71    #[inline(always)]
72    pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
73        UpgradableStream {
74            inner: UpgradableStreamInner::BaseServer(base, config),
75            options: Default::default(),
76        }
77    }
78
79    /// Consume the `UpgradableStream` and return the underlying stream as a [`Box<dyn Stream>`].
80    pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
81        match self.inner {
82            UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
83            UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
84            UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
85            UpgradableStreamInner::Upgrading => Err(self),
86        }
87    }
88
89    pub fn handshake(&self) -> Option<&TlsHandshake> {
90        match &self.inner {
91            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
92            _ => None,
93        }
94    }
95
96    pub fn ignore_missing_close_notify(&mut self) {
97        self.options.ignore_missing_close_notify = true;
98    }
99
100    /// Uncleanly shut down the stream. This may cause errors on the peer side
101    /// when using TLS.
102    pub fn unclean_shutdown(self) -> Result<(), Self> {
103        match self.inner {
104            UpgradableStreamInner::BaseClient(..) => Ok(()),
105            UpgradableStreamInner::BaseServer(..) => Ok(()),
106            UpgradableStreamInner::Upgraded(upgraded, cfg) => {
107                if let Err(e) = D::unclean_shutdown(upgraded) {
108                    Err(Self {
109                        inner: UpgradableStreamInner::Upgraded(e, cfg),
110                        options: self.options,
111                    })
112                } else {
113                    Ok(())
114                }
115            }
116            UpgradableStreamInner::Upgrading => Err(self),
117        }
118    }
119}
120
121impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
122    async fn secure_upgrade(&mut self) -> Result<(), SslError> {
123        match std::mem::replace(&mut self.inner, UpgradableStreamInner::Upgrading) {
124            UpgradableStreamInner::BaseClient(base, config) => {
125                let Some(config) = config else {
126                    return Err(SslError::SslUnsupportedByClient);
127                };
128                let (upgraded, handshake) = D::upgrade_client(config, base).await?;
129                self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
130                Ok(())
131            }
132            UpgradableStreamInner::BaseServer(base, config) => {
133                let Some(config) = config else {
134                    return Err(SslError::SslUnsupportedByClient);
135                };
136                let (upgraded, handshake) = D::upgrade_server(config, base).await?;
137                self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
138                Ok(())
139            }
140            UpgradableStreamInner::Upgraded(..) => Err(SslError::SslAlreadyUpgraded),
141            UpgradableStreamInner::Upgrading => Err(SslError::SslAlreadyUpgraded),
142        }
143    }
144
145    fn handshake(&self) -> Option<&TlsHandshake> {
146        match &self.inner {
147            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
148            _ => None,
149        }
150    }
151}
152
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Reads from whichever stream is currently active (base or upgraded).
    ///
    /// While an upgrade is in flight the read fails with `InvalidInput`.
    /// If `ignore_missing_close_notify` has been enabled, a read that fails
    /// with `UnexpectedEof` is reported as a successful read instead.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // Copy the flag out first: `get_mut` below consumes the pinned handle.
        let ignore_missing_close_notify = self.options.ignore_missing_close_notify;
        let inner = &mut self.get_mut().inner;
        let res = match inner {
            UpgradableStreamInner::BaseClient(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(upgraded, _) => Pin::new(upgraded).poll_read(cx, buf),
            UpgradableStreamInner::Upgrading => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot read while upgrading",
            ))),
        };
        let hit_unexpected_eof = matches!(
            &res,
            std::task::Poll::Ready(Err(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof
        );
        if ignore_missing_close_notify && hit_unexpected_eof {
            // Swallow the error silently; a library must not write
            // diagnostics to stderr on its own.
            return std::task::Poll::Ready(Ok(()));
        }
        res
    }
}
182
/// Write-side I/O is forwarded to whichever stream is currently active; every
/// operation fails with `InvalidInput` while an upgrade is in flight.
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> AsyncWrite for UpgradableStream<S, D> {
    /// Writes `buf` to the active stream.
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(base, _)
            | UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_write(cx, buf),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write(cx, buf),
            UpgradableStreamInner::Upgrading => {
                let err = std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Cannot write while upgrading",
                );
                Poll::Ready(Err(err))
            }
        }
    }

    /// Flushes the active stream.
    #[inline(always)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(base, _)
            | UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_flush(cx),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_flush(cx),
            UpgradableStreamInner::Upgrading => {
                let err = std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Cannot flush while upgrading",
                );
                Poll::Ready(Err(err))
            }
        }
    }

    /// Shuts down the write side of the active stream.
    #[inline(always)]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(base, _)
            | UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_shutdown(cx),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_shutdown(cx),
            UpgradableStreamInner::Upgrading => {
                let err = std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Cannot shutdown while upgrading",
                );
                Poll::Ready(Err(err))
            }
        }
    }

    /// Reports whether the active stream supports vectored writes; always
    /// `false` while an upgrade is in flight.
    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        match &self.inner {
            UpgradableStreamInner::BaseClient(base, _)
            | UpgradableStreamInner::BaseServer(base, _) => base.is_write_vectored(),
            UpgradableStreamInner::Upgraded(tls, _) => tls.is_write_vectored(),
            UpgradableStreamInner::Upgrading => false,
        }
    }

    /// Writes the buffers in `bufs` to the active stream.
    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(base, _)
            | UpgradableStreamInner::BaseServer(base, _) => {
                Pin::new(base).poll_write_vectored(cx, bufs)
            }
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write_vectored(cx, bufs),
            UpgradableStreamInner::Upgrading => {
                let err = std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Cannot write vectored while upgrading",
                );
                Poll::Ready(Err(err))
            }
        }
    }
}
271
272#[derive(derive_more::Debug)]
273enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
274    #[debug("BaseClient(..)")]
275    BaseClient(S, Option<D::ClientParams>),
276    #[debug("BaseServer(..)")]
277    BaseServer(S, Option<TlsServerParameterProvider>),
278    #[debug("Upgraded(..)")]
279    Upgraded(D::Stream, TlsHandshake),
280    #[debug("Upgrading")]
281    Upgrading,
282}
283
/// A stream onto which bytes can be pushed back ("rewound").
pub trait Rewindable {
    /// Hands `bytes` back to the stream.
    ///
    /// Returns an I/O error when the stream cannot accept rewound data in its
    /// current state.
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
287
/// A stream wrapper that holds a buffer of "rewound" bytes alongside the
/// wrapped stream.
pub struct RewindStream<S> {
    /// Pending rewound bytes, in the order they should be read.
    buffer: Vec<u8>,
    /// The wrapped stream.
    inner: S,
}

impl<S> RewindStream<S> {
    /// Wraps `inner` with an empty rewind buffer.
    pub fn new(inner: S) -> Self {
        RewindStream {
            buffer: Vec::new(),
            inner,
        }
    }

    /// Pushes `data` back onto the front of the stream.
    ///
    /// Rewound bytes logically precede everything still pending, so `data`
    /// is inserted *before* any bytes left over from an earlier rewind.
    /// Appending instead would make a partially consumed rewind buffer yield
    /// bytes out of order.
    pub fn rewind(&mut self, data: &[u8]) {
        // `splice` over the empty range `0..0` inserts at the front without
        // removing anything; the returned iterator is dropped immediately,
        // which performs the insertion.
        self.buffer.splice(0..0, data.iter().copied());
    }

    /// Splits the wrapper into the wrapped stream and any rewound bytes that
    /// have not been read yet.
    pub fn into_inner(self) -> (S, Vec<u8>) {
        (self.inner, self.buffer)
    }
}
309
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serves rewound bytes first; only when none are pending is the read
    /// forwarded to the wrapped stream.
    ///
    /// A single call never mixes rewound bytes with bytes from the wrapped
    /// stream: if any rewound bytes are pending, at most `buf.remaining()` of
    /// them are copied and the call completes immediately.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let to_read = std::cmp::min(buf.remaining(), this.buffer.len());
        // Copy straight out of the pending buffer, then discard the consumed
        // prefix; no temporary allocation is needed.
        buf.put_slice(&this.buffer[..to_read]);
        this.buffer.drain(..to_read);
        Poll::Ready(Ok(()))
    }
}
328
/// Writes bypass the rewind buffer entirely and go to the wrapped stream.
#[cfg(feature = "tokio")]
impl<S: AsyncWrite + Unpin> AsyncWrite for RewindStream<S> {
    /// Forwards the write to the wrapped stream.
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped stream.
    #[inline(always)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Forwards the shutdown to the wrapped stream.
    #[inline(always)]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }

    /// Reports whether the wrapped stream supports vectored writes.
    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        let wrapped = &self.inner;
        wrapped.is_write_vectored()
    }

    /// Forwards the vectored write to the wrapped stream.
    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write_vectored(cx, bufs)
    }
}
370
371impl<S: Stream> Rewindable for RewindStream<S> {
372    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
373        self.rewind(bytes);
374        Ok(())
375    }
376}
377
378impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
379where
380    D::Stream: Rewindable,
381{
382    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
383        match &mut self.inner {
384            UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
385            UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
386            UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
387            UpgradableStreamInner::Upgrading => Err(std::io::Error::new(
388                std::io::ErrorKind::Unsupported,
389                "Cannot rewind a stream that is upgrading",
390            )),
391        }
392    }
393}