// gel_stream/common/stream.rs

1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::future::Future;
5#[cfg(feature = "tokio")]
6use std::{
7    any::Any,
8    io::IoSlice,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13use crate::{Ssl, SslError, TlsDriver, TlsHandshake, TlsServerParameterProvider};
14
/// A transport that [`UpgradableStream`] can wrap: any `'static`, `Send`,
/// `Unpin` Tokio reader/writer.
#[cfg(feature = "tokio")]
pub trait Stream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {
    /// Try to recover the concrete type `S` from this stream.
    ///
    /// Returns `Ok(stream)` when `Self` is exactly `S`; otherwise the original
    /// value is handed back untouched as `Err(self)`. No allocation happens:
    /// the value is parked in an `Option` so it can be moved out through a
    /// `&mut dyn Any` view.
    ///
    /// NOTE(review): the upstream note says only Tokio `TcpStream` is supported
    /// for rustls — presumably a TLS driver uses this to reach the raw socket.
    /// A `Box<dyn Stream>` is itself a `Stream` (via the blanket impl), so
    /// downcasting a boxed stream compares against the `Box` type, not the
    /// boxed value — confirm callers do not rely on that.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        let mut slot = Some(self);
        if let Some(matched) = (&mut slot as &mut dyn Any).downcast_mut::<Option<S>>() {
            return Ok(matched.take().expect("slot is populated until taken here"));
        }
        Err(slot.expect("slot is still populated after a failed downcast"))
    }
}

/// Every type that meets the [`Stream`] bounds is a [`Stream`].
#[cfg(feature = "tokio")]
impl<T> Stream for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static {}
34
35#[cfg(not(feature = "tokio"))]
36pub trait Stream: 'static {}
37#[cfg(not(feature = "tokio"))]
38impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
39#[cfg(not(feature = "tokio"))]
40impl Stream for () {}
41
42pub trait StreamUpgrade: Stream {
43    fn secure_upgrade(&mut self) -> impl Future<Output = Result<(), SslError>> + Send;
44    fn handshake(&self) -> Option<&TlsHandshake>;
45}
46
47#[allow(private_bounds)]
48#[derive(derive_more::Debug)]
49pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
50    inner: UpgradableStreamInner<S, D>,
51}
52
53#[allow(private_bounds)]
54impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
55    #[inline(always)]
56    pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
57        UpgradableStream {
58            inner: UpgradableStreamInner::BaseClient(base, config),
59        }
60    }
61
62    #[inline(always)]
63    pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
64        UpgradableStream {
65            inner: UpgradableStreamInner::BaseServer(base, config),
66        }
67    }
68
69    /// Consume the `UpgradableStream` and return the underlying stream as a [`Box<dyn Stream>`].
70    pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
71        match self.inner {
72            UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
73            UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
74            UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
75            UpgradableStreamInner::Upgrading => Err(self),
76        }
77    }
78
79    pub fn handshake(&self) -> Option<&TlsHandshake> {
80        match &self.inner {
81            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
82            _ => None,
83        }
84    }
85}
86
87impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
88    async fn secure_upgrade(&mut self) -> Result<(), SslError> {
89        match std::mem::replace(&mut self.inner, UpgradableStreamInner::Upgrading) {
90            UpgradableStreamInner::BaseClient(base, config) => {
91                let Some(config) = config else {
92                    return Err(SslError::SslUnsupportedByClient);
93                };
94                let (upgraded, handshake) = D::upgrade_client(config, base).await?;
95                self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
96                Ok(())
97            }
98            UpgradableStreamInner::BaseServer(base, config) => {
99                let Some(config) = config else {
100                    return Err(SslError::SslUnsupportedByClient);
101                };
102                let (upgraded, handshake) = D::upgrade_server(config, base).await?;
103                self.inner = UpgradableStreamInner::Upgraded(upgraded, handshake);
104                Ok(())
105            }
106            UpgradableStreamInner::Upgraded(..) => Err(SslError::SslAlreadyUpgraded),
107            UpgradableStreamInner::Upgrading => Err(SslError::SslAlreadyUpgraded),
108        }
109    }
110
111    fn handshake(&self) -> Option<&TlsHandshake> {
112        match &self.inner {
113            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
114            _ => None,
115        }
116    }
117}
118
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Read from whichever stream is active: plaintext before the upgrade,
    /// TLS after it. Fails with `InvalidInput` while an upgrade is running.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match &mut this.inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_read(cx, buf),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot read while upgrading",
            ))),
        }
    }
}
139
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncWrite for UpgradableStream<S, D> {
    // Every method forwards to the active stream (plaintext before the
    // upgrade, TLS after it) and fails with `InvalidInput` mid-upgrade.

    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_write(cx, buf),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write(cx, buf),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write while upgrading",
            ))),
        }
    }

    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_flush(cx),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_flush(cx),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot flush while upgrading",
            ))),
        }
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => Pin::new(plain).poll_shutdown(cx),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_shutdown(cx),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot shutdown while upgrading",
            ))),
        }
    }

    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        match &self.inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => plain.is_write_vectored(),
            UpgradableStreamInner::Upgraded(tls, _) => tls.is_write_vectored(),
            // No stream to ask mid-upgrade; report no vectored support.
            UpgradableStreamInner::Upgrading => false,
        }
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match &mut self.get_mut().inner {
            UpgradableStreamInner::BaseClient(plain, _)
            | UpgradableStreamInner::BaseServer(plain, _) => {
                Pin::new(plain).poll_write_vectored(cx, bufs)
            }
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_write_vectored(cx, bufs),
            UpgradableStreamInner::Upgrading => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Cannot write vectored while upgrading",
            ))),
        }
    }
}
228
229#[derive(derive_more::Debug)]
230enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
231    #[debug("BaseClient(..)")]
232    BaseClient(S, Option<D::ClientParams>),
233    #[debug("BaseServer(..)")]
234    BaseServer(S, Option<TlsServerParameterProvider>),
235    #[debug("Upgraded(..)")]
236    Upgraded(D::Stream, TlsHandshake),
237    #[debug("Upgrading")]
238    Upgrading,
239}
240
/// A stream that can have already-read bytes pushed back onto it, so that
/// they are delivered again by later reads.
pub trait Rewindable {
    /// Push `bytes` back onto the stream.
    ///
    /// # Errors
    /// Implementations may refuse; e.g. an `UpgradableStream` errors while an
    /// upgrade is in progress.
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
244
/// Wraps a stream `S` with a pushback buffer: bytes handed to
/// [`RewindStream::rewind`] are served by reads before any further data is
/// taken from `inner`.
pub struct RewindStream<S> {
    /// Pushed-back bytes, in the order they will be read next.
    buffer: Vec<u8>,
    inner: S,
}

impl<S> RewindStream<S> {
    /// Wrap `inner` with an empty pushback buffer.
    pub fn new(inner: S) -> Self {
        RewindStream {
            buffer: Vec::new(),
            inner,
        }
    }

    /// Push `data` back so it is the next thing read.
    ///
    /// `data` goes *in front of* any bytes still buffered from earlier
    /// rewinds: if bytes were read out of a partially consumed buffer and
    /// then rewound, the next read must return them before the leftovers.
    pub fn rewind(&mut self, data: &[u8]) {
        // Insert at the front (appending would reorder the stream whenever
        // the buffer is non-empty).
        self.buffer.splice(0..0, data.iter().copied());
    }

    /// Consume the wrapper, returning the inner stream and any pushed-back
    /// bytes that have not been read yet.
    pub fn into_inner(self) -> (S, Vec<u8>) {
        (self.inner, self.buffer)
    }
}
266
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serve pushed-back bytes first; once the buffer is empty, read from the
    /// inner stream.
    ///
    /// A read satisfied from the buffer returns immediately without touching
    /// `inner`, even if the buffer holds fewer bytes than `buf` could accept.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let n = buf.remaining().min(this.buffer.len());
        // Copy straight out of the buffer, then discard the consumed prefix;
        // no temporary Vec is allocated per read.
        buf.put_slice(&this.buffer[..n]);
        this.buffer.drain(..n);
        Poll::Ready(Ok(()))
    }
}
285
#[cfg(feature = "tokio")]
impl<S: AsyncWrite + Unpin> AsyncWrite for RewindStream<S> {
    // Rewinding only affects reads: every write-side method forwards
    // directly to the inner stream.

    #[inline(always)]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_write(cx, buf)
    }

    #[inline(always)]
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_flush(cx)
    }

    #[inline(always)]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_shutdown(cx)
    }

    #[inline(always)]
    fn is_write_vectored(&self) -> bool {
        let inner = &self.inner;
        inner.is_write_vectored()
    }

    #[inline(always)]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_write_vectored(cx, bufs)
    }
}
327
328impl<S: Stream> Rewindable for RewindStream<S> {
329    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
330        self.rewind(bytes);
331        Ok(())
332    }
333}
334
335impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
336where
337    D::Stream: Rewindable,
338{
339    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
340        match &mut self.inner {
341            UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
342            UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
343            UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
344            UpgradableStreamInner::Upgrading => Err(std::io::Error::new(
345                std::io::ErrorKind::Unsupported,
346                "Cannot rewind a stream that is upgrading",
347            )),
348        }
349    }
350}