// gel_stream/common/stream.rs

1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::future::Future;
5#[cfg(feature = "tokio")]
6use std::{
7    any::Any,
8    io::IoSlice,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13use crate::{Ssl, SslError, TlsDriver, TlsHandshake, TlsServerParameterProvider};
14
/// A convenience trait for streams from this crate.
///
/// With the `tokio` feature enabled, every `AsyncRead + AsyncWrite + Unpin +
/// Send + 'static` type implements this trait through a blanket impl.
#[cfg(feature = "tokio")]
pub trait Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
    /// Tries to convert `self` into the concrete stream type `S`.
    ///
    /// Succeeds only when `Self` and `S` are the same type (checked at runtime
    /// through [`Any`]); otherwise `self` is handed back unchanged in `Err`.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // `Any` only hands out references, so park the value in an `Option`
        // that can be emptied through the downcast reference with `take`.
        let mut slot = Some(self);
        if let Some(same_type) = (&mut slot as &mut dyn Any).downcast_mut::<Option<S>>() {
            return Ok(same_type.take().unwrap());
        }
        Err(slot.unwrap())
    }
}
32
// Blanket impl: any Tokio stream meeting the bounds is a `Stream`. The trait
// has no required methods, so the body is empty.
#[cfg(feature = "tokio")]
impl<T> Stream for T
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
}
35
36#[cfg(not(feature = "tokio"))]
37pub trait Stream: 'static {}
38#[cfg(not(feature = "tokio"))]
39impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
40#[cfg(not(feature = "tokio"))]
41impl Stream for () {}
42
43/// A trait for streams that can be upgraded to a TLS stream.
44pub trait StreamUpgrade: Stream {
45    fn secure_upgrade(&mut self) -> impl Future<Output = Result<(), SslError>> + Send;
46    fn handshake(&self) -> Option<&TlsHandshake>;
47}
48
/// Per-stream behavior flags for `UpgradableStream`.
#[derive(Debug, Default)]
struct UpgradableStreamOptions {
    /// When set, an `UnexpectedEof` read error is reported to the caller as a
    /// clean end-of-stream instead. Defaults to `false`.
    ignore_missing_close_notify: bool,
}
53
54#[allow(private_bounds)]
55#[derive(derive_more::Debug)]
56pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
57    inner: UpgradableStreamInner<S, D>,
58    options: UpgradableStreamOptions,
59}
60
61#[allow(private_bounds)]
62impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
63    #[inline(always)]
64    pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
65        UpgradableStream {
66            inner: UpgradableStreamInner::BaseClient(base, config),
67            options: Default::default(),
68        }
69    }
70
71    #[inline(always)]
72    pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
73        UpgradableStream {
74            inner: UpgradableStreamInner::BaseServer(base, config),
75            options: Default::default(),
76        }
77    }
78
79    /// Consume the `UpgradableStream` and return the underlying stream as a [`Box<dyn Stream>`].
80    pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
81        match self.inner {
82            UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
83            UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
84            UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
85            UpgradableStreamInner::Upgrading => Err(self),
86        }
87    }
88
89    pub fn handshake(&self) -> Option<&TlsHandshake> {
90        match &self.inner {
91            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
92            _ => None,
93        }
94    }
95
96    pub fn ignore_missing_close_notify(&mut self) {
97        self.options.ignore_missing_close_notify = true;
98    }
99
100    /// Uncleanly shut down the stream. This may cause errors on the peer side
101    /// when using TLS.
102    pub fn unclean_shutdown(self) -> Result<(), Self> {
103        match self.inner {
104            UpgradableStreamInner::BaseClient(..) => Ok(()),
105            UpgradableStreamInner::BaseServer(..) => Ok(()),
106            UpgradableStreamInner::Upgraded(upgraded, cfg) => {
107                if let Err(e) = D::unclean_shutdown(upgraded) {
108                    Err(Self {
109                        inner: UpgradableStreamInner::Upgraded(e, cfg),
110                        options: self.options,
111                    })
112                } else {
113                    Ok(())
114                }
115            }
116            UpgradableStreamInner::Upgrading => Err(self),
117        }
118    }
119}
120
121impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
122    async fn secure_upgrade(&mut self) -> Result<(), SslError> {
123        match std::mem::replace(&mut self.inner, UpgradableStreamInner::Upgrading) {
124            UpgradableStreamInner::BaseClient(base, config) => {
125                let Some(config) = config else {
126                    return Err(SslError::SslUnsupportedByClient);
127                };
128                let (upgraded, handshake) = D::upgrade_client(config, base).await?;
129                self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
130                Ok(())
131            }
132            UpgradableStreamInner::BaseServer(base, config) => {
133                let Some(config) = config else {
134                    return Err(SslError::SslUnsupportedByClient);
135                };
136                let (upgraded, handshake) = D::upgrade_server(config, base).await?;
137                self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
138                Ok(())
139            }
140            UpgradableStreamInner::Upgraded(..) => Err(SslError::SslAlreadyUpgraded),
141            UpgradableStreamInner::Upgrading => Err(SslError::SslAlreadyUpgraded),
142        }
143    }
144
145    fn handshake(&self) -> Option<&TlsHandshake> {
146        match &self.inner {
147            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
148            _ => None,
149        }
150    }
151}
152
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Reads from whichever stream is currently active (plaintext or TLS).
    ///
    /// Fails with `InvalidInput` while upgrading. When
    /// `ignore_missing_close_notify` is set, an `UnexpectedEof` error is turned
    /// into `Ok(())` with nothing read, which callers see as end-of-stream.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let outcome = match &mut this.inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_read(cx, buf),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot read while upgrading",
            ))),
        };
        match outcome {
            Poll::Ready(Err(ref err))
                if this.options.ignore_missing_close_notify
                    && err.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                Poll::Ready(Ok(()))
            }
            passthrough => passthrough,
        }
    }
}
181
/// Every write-side operation is forwarded to whichever stream is currently
/// active (plaintext or TLS), and fails with `InvalidInput` while upgrading.
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncWrite for UpgradableStream<S, D> {
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_write(cx, buf),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_write(cx, buf),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write while upgrading",
            ))),
        }
    }

    #[inline(always)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_flush(cx),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_flush(cx),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot flush while upgrading",
            ))),
        }
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_shutdown(cx),
            UpgradableStreamInner::Upgraded(secure, _) => Pin::new(secure).poll_shutdown(cx),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot shutdown while upgrading",
            ))),
        }
    }

    /// Reports the active stream's vectored-write support; `false` while upgrading.
    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        match &self.inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => plain.is_write_vectored(),
            UpgradableStreamInner::Upgraded(secure, _) => secure.is_write_vectored(),
            UpgradableStreamInner::Upgrading => false,
        }
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => {
                Pin::new(plain).poll_write_vectored(cx, bufs)
            }
            UpgradableStreamInner::Upgraded(secure, _) => {
                Pin::new(secure).poll_write_vectored(cx, bufs)
            }
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write vectored while upgrading",
            ))),
        }
    }
}
270
271#[derive(derive_more::Debug)]
272enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
273    #[debug("BaseClient(..)")]
274    BaseClient(S, Option<D::ClientParams>),
275    #[debug("BaseServer(..)")]
276    BaseServer(S, Option<TlsServerParameterProvider>),
277    #[debug("Upgraded(..)")]
278    Upgraded(D::Stream, TlsHandshake),
279    #[debug("Upgrading")]
280    Upgrading,
281}
282
/// A stream that can take back bytes it has already yielded, so that they are
/// returned again by later reads.
pub trait Rewindable {
    /// Makes `bytes` available to be read again.
    ///
    /// # Errors
    ///
    /// Implementations may fail when rewinding is not possible in their
    /// current state.
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
286
/// Wraps a stream so that already-read bytes can be pushed back ("rewound").
///
/// With the `tokio` feature, reads return the rewound bytes before any new
/// data from the wrapped stream.
pub struct RewindStream<S> {
    /// Rewound bytes still waiting to be read again, oldest first.
    buffer: Vec<u8>,
    /// The wrapped stream.
    inner: S,
}

impl<S> RewindStream<S> {
    /// Wraps `inner` with an empty rewind buffer.
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            buffer: Vec::new(),
        }
    }

    /// Queues `data` to be read again.
    ///
    /// The bytes are appended after any rewound bytes that have not been read
    /// yet, so successive calls are replayed in call order.
    pub fn rewind(&mut self, data: &[u8]) {
        let pending = &mut self.buffer;
        pending.extend_from_slice(data);
    }

    /// Unwraps the stream, returning it with any rewound bytes not yet read.
    pub fn into_inner(self) -> (S, Vec<u8>) {
        let Self { buffer, inner } = self;
        (inner, buffer)
    }
}
308
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serves rewound bytes first, and only forwards reads to the wrapped
    /// stream once the rewind buffer is empty.
    ///
    /// A single call never mixes buffered and fresh data. If rewound bytes are
    /// pending, up to `buf.remaining()` of them are copied, and the call
    /// completes immediately without polling the inner stream.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let n = buf.remaining().min(this.buffer.len());
        // Copy straight from the buffer, then discard the consumed prefix.
        // This avoids allocating a temporary Vec on every read.
        buf.put_slice(&this.buffer[..n]);
        this.buffer.drain(..n);
        Poll::Ready(Ok(()))
    }
}
327
/// Writes bypass the rewind buffer entirely. Every operation is forwarded
/// unchanged to the wrapped stream.
#[cfg(feature = "tokio")]
impl<S: AsyncWrite + Unpin> AsyncWrite for RewindStream<S> {
    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_write(cx, buf)
    }

    #[inline(always)]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_flush(cx)
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_shutdown(cx)
    }

    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        let Self { inner: wrapped, .. } = self;
        wrapped.is_write_vectored()
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let wrapped = &mut self.get_mut().inner;
        Pin::new(wrapped).poll_write_vectored(cx, bufs)
    }
}
369
370impl<S: Stream> Rewindable for RewindStream<S> {
371    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
372        self.rewind(bytes);
373        Ok(())
374    }
375}
376
377impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
378where
379    D::Stream: Rewindable,
380{
381    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
382        match &mut self.inner {
383            UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
384            UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
385            UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
386            UpgradableStreamInner::Upgrading => Err(std::io::Error::new(
387                std::io::ErrorKind::Unsupported,
388                "Cannot rewind a stream that is upgrading",
389            )),
390        }
391    }
392}