gel_stream/common/stream.rs

1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
3
4#[cfg(feature = "tokio")]
5use std::{
6    any::Any,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use std::{future::Future, num::NonZeroUsize, ops::Deref};
11
12use crate::{
13    LocalAddress, PeerCred, RemoteAddress, ResolvedTarget, Ssl, SslError, StreamMetadata,
14    TlsDriver, TlsHandshake, TlsServerParameterProvider, Transport, DEFAULT_PREVIEW_BUFFER_SIZE,
15};
16
/// A stream whose underlying OS resource can be lent out as a borrowed file
/// descriptor (Unix flavour).
#[cfg(unix)]
pub trait AsHandle {
    /// Borrow the underlying file descriptor; the borrow is tied to `&self`.
    fn as_fd(&self) -> std::os::fd::BorrowedFd;
}
22
/// A stream whose underlying OS socket can be lent out as a borrowed socket
/// handle (Windows flavour).
#[cfg(windows)]
pub trait AsHandle {
    /// Borrow the underlying socket; the borrow is tied to `&self`.
    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket;
}
28
/// A convenience trait for streams from this crate.
///
/// It bundles tokio's read/write traits with the crate's metadata and
/// handle accessors, and requires `Send + Unpin + 'static` so streams can be
/// moved between tasks and type-erased.
#[cfg(feature = "tokio")]
pub trait Stream:
    tokio::io::AsyncRead + tokio::io::AsyncWrite + StreamMetadata + Send + Unpin + AsHandle + 'static
{
    /// Attempt to downcast a generic stream to a specific stream type.
    ///
    /// Succeeds when `Self` is exactly `S` or `Box<S>`. Otherwise the
    /// original stream is handed back unchanged in `Err`.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // `Any::downcast_mut` needs a place to move the value out of, so park
        // `self` in an `Option` and `take()` it once the type matches.
        let mut holder = Some(self);
        let any = &mut holder as &mut dyn Any;
        if let Some(slot) = any.downcast_mut::<Option<S>>() {
            return Ok(slot.take().unwrap());
        }
        if let Some(slot) = any.downcast_mut::<Option<Box<S>>>() {
            return Ok(*slot.take().unwrap());
        }
        Err(holder.take().unwrap())
    }

    /// Box the stream as a `Box<dyn Stream + Send>`.
    ///
    /// If `self` is already a boxed `dyn Stream` trait object it is returned
    /// as-is rather than being wrapped in a second box.
    fn boxed(self) -> Box<dyn Stream + Send>
    where
        Self: Sized + 'static,
    {
        let mut holder = Some(self);
        let any = &mut holder as &mut dyn Any;
        // Already the exact return type: avoid `Box<Box<dyn Stream + Send>>`.
        if let Some(slot) = any.downcast_mut::<Option<Box<dyn Stream + Send>>>() {
            return slot.take().unwrap();
        }
        if let Some(slot) = any.downcast_mut::<Option<Box<dyn Stream>>>() {
            return slot.take().unwrap();
        }
        Box::new(holder.take().unwrap())
    }
}
64
/// Every type that satisfies the [`Stream`] supertrait bounds is a [`Stream`].
#[cfg(feature = "tokio")]
impl<T> Stream for T where
    T: AsHandle
        + StreamMetadata
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
        + Send
        + Unpin
        + 'static
{
}
76
77// NOTE: Once we're on Rust 1.87, we can use trait upcasting and get rid of this impl.
78impl PeerCred for Box<dyn Stream + Send> {
79    #[cfg(all(unix, feature = "tokio"))]
80    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
81        self.as_ref().peer_cred()
82    }
83}
84
85impl LocalAddress for Box<dyn Stream + Send> {
86    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
87        self.as_ref().local_address()
88    }
89}
90
91impl RemoteAddress for Box<dyn Stream + Send> {
92    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
93        self.as_ref().remote_address()
94    }
95}
96
97impl StreamMetadata for Box<dyn Stream + Send> {
98    fn transport(&self) -> Transport {
99        self.as_ref().transport()
100    }
101}
102
103#[cfg(not(feature = "tokio"))]
104pub trait Stream: StreamMetadata + Unpin + AsHandle + 'static {}
105#[cfg(not(feature = "tokio"))]
106impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
107#[cfg(not(feature = "tokio"))]
108impl Stream for () {}
109
110/// A trait for streams that can be upgraded to a TLS stream.
111pub trait StreamUpgrade: Stream + Sized {
112    /// Upgrade the stream to a TLS stream.
113    fn secure_upgrade(self) -> impl Future<Output = Result<Self, SslError>> + Send;
114    /// Upgrade the stream to a TLS stream, and preview the initial bytes.
115    fn secure_upgrade_preview(
116        self,
117        options: PreviewConfiguration,
118    ) -> impl Future<Output = Result<(Preview, Self), SslError>> + Send;
119    /// Get the TLS handshake information, if the stream is upgraded.
120    fn handshake(&self) -> Option<&TlsHandshake>;
121}
122
/// Socket tuning profile consumed by `StreamOptimizationExt::optimize_for`.
#[cfg(feature = "optimization")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum StreamOptimization {
    /// Optimize for general use (the default profile).
    #[default]
    General,
    /// Optimize for interactive use with low latency.
    Interactive,
    /// Optimize for bulk streaming in the given direction(s).
    BulkStreaming(BulkStreamDirection),
}
134
/// Which side(s) of a connection a bulk-streaming profile should tune.
#[cfg(feature = "optimization")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum BulkStreamDirection {
    /// Tune for large outbound transfers only.
    Send,
    /// Tune for large inbound transfers only.
    Receive,
    /// Tune both directions (the default).
    #[default]
    Both,
}
143
/// Invoke `f` with a [`socket2::SockRef`] borrowed from `stream`'s OS handle.
///
/// The handle is only borrowed for the duration of the call, and `f`'s result
/// is returned unchanged.
#[cfg(any(feature = "optimization", feature = "keepalive"))]
fn with_socket2<S: AsHandle + Sized>(
    stream: &S,
    f: &mut dyn for<'a> FnMut(socket2::SockRef<'a>) -> Result<(), std::io::Error>,
) -> Result<(), std::io::Error> {
    // Bind the borrowed handle first so it outlives the `SockRef` built from it.
    #[cfg(unix)]
    let handle = stream.as_fd();
    #[cfg(windows)]
    let handle = stream.as_handle();
    f(socket2::SockRef::from(&handle))
}
156
/// Extension trait that tunes TCP socket options for a given workload.
#[cfg(feature = "optimization")]
pub trait StreamOptimizationExt: Stream + Sized {
    /// Optimize the stream for the given optimization.
    ///
    /// Unix-domain sockets are left untouched. For other transports a
    /// baseline is applied first: 256 KiB send/receive buffers and, on
    /// Linux/Android/Fuchsia, cork, quick-ack and thin linear timeouts all
    /// cleared. The profile-specific settings are then layered on top.
    ///
    /// # Errors
    ///
    /// Returns the first socket-option failure, with the failing setter's
    /// name prefixed to the error message.
    #[cfg(feature = "optimization")]
    fn optimize_for(&mut self, optimization: StreamOptimization) -> Result<(), std::io::Error> {
        // Call a socket2 setter, annotating any error with the setter's name.
        macro_rules! try_optimize(
            ( $s:ident . $method:ident ( $($args:tt)* ) ) => {
                $s.$method($($args)*).map_err(|e: std::io::Error| std::io::Error::new(e.kind(), format!("{}: {}", stringify!($method), e)))
            };
        );

        // Everything below is TCP tuning; Unix-domain sockets are skipped.
        #[cfg(unix)]
        if self.transport() == Transport::Unix {
            return Ok(());
        }

        const BASELINE_BUFFER: usize = 256 * 1024;
        const BULK_BUFFER: usize = 16 * 1024 * 1024;

        let mut apply = move |sock: socket2::SockRef<'_>| {
            // Reset the Linux-only knobs so each profile starts from a known state.
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            {
                try_optimize!(sock.set_cork(false))?;
                try_optimize!(sock.set_quickack(false))?;
                try_optimize!(sock.set_thin_linear_timeouts(false))?;
            }

            try_optimize!(sock.set_send_buffer_size(BASELINE_BUFFER))?;
            try_optimize!(sock.set_recv_buffer_size(BASELINE_BUFFER))?;

            match optimization {
                StreamOptimization::General | StreamOptimization::Interactive => {
                    // Nagle's algorithm is disabled only for interactive traffic.
                    let nodelay = matches!(optimization, StreamOptimization::Interactive);
                    try_optimize!(sock.set_nodelay(nodelay))?;
                    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
                    try_optimize!(sock.set_thin_linear_timeouts(true))?;
                }
                StreamOptimization::BulkStreaming(direction) => {
                    try_optimize!(sock.set_nodelay(false))?;

                    let bulk_send = matches!(
                        direction,
                        BulkStreamDirection::Send | BulkStreamDirection::Both
                    );
                    let bulk_recv = matches!(
                        direction,
                        BulkStreamDirection::Receive | BulkStreamDirection::Both
                    );

                    if bulk_send {
                        try_optimize!(sock.set_send_buffer_size(BULK_BUFFER))?;
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        try_optimize!(sock.set_cork(true))?;
                    }

                    if bulk_recv {
                        try_optimize!(sock.set_recv_buffer_size(BULK_BUFFER))?;
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        try_optimize!(sock.set_quickack(true))?;
                    }
                }
            }
            Ok(())
        };

        with_socket2(self, &mut apply)
    }
}
234
/// Every sized [`Stream`] gets `optimize_for` for free.
#[cfg(feature = "optimization")]
impl<T> StreamOptimizationExt for T where T: Stream + Sized {}
237
238/// A trait for streams that can be peeked asynchronously.
239pub trait PeekableStream: Stream {
240    #[cfg(feature = "tokio")]
241    fn poll_peek(
242        self: Pin<&mut Self>,
243        cx: &mut Context<'_>,
244        buf: &mut tokio::io::ReadBuf,
245    ) -> Poll<std::io::Result<usize>>;
246    #[cfg(feature = "tokio")]
247    fn peek(self: Pin<&mut Self>, buf: &mut [u8]) -> impl Future<Output = std::io::Result<usize>> {
248        async {
249            let mut this = self;
250            std::future::poll_fn(move |cx| this.as_mut().poll_peek(cx, &mut ReadBuf::new(buf)))
251                .await
252        }
253    }
254}
255
256/// A preview of the initial bytes of the stream.
257#[derive(Debug, Clone, PartialEq, Eq)]
258#[must_use]
259pub struct Preview {
260    buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
261}
262
263impl Preview {
264    pub(crate) fn new(
265        buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
266    ) -> Self {
267        Self { buffer }
268    }
269}
270
271impl Deref for Preview {
272    type Target = [u8];
273    fn deref(&self) -> &Self::Target {
274        &self.buffer
275    }
276}
277
278impl AsRef<[u8]> for Preview {
279    fn as_ref(&self) -> &[u8] {
280        &self.buffer
281    }
282}
283
284impl<const N: usize> PartialEq<[u8; N]> for Preview {
285    fn eq(&self, other: &[u8; N]) -> bool {
286        self.buffer.as_slice() == other
287    }
288}
289
290impl<const N: usize> PartialEq<&[u8; N]> for Preview {
291    fn eq(&self, other: &&[u8; N]) -> bool {
292        self.buffer.as_slice() == *other
293    }
294}
295
296impl PartialEq<[u8]> for Preview {
297    fn eq(&self, other: &[u8]) -> bool {
298        self.buffer.as_slice() == other
299    }
300}
301
/// Limits for the initial preview of a client connection.
#[derive(Debug, Clone, Copy)]
pub struct PreviewConfiguration {
    /// The maximum number of bytes to preview. Recommended value is 8 bytes.
    pub max_preview_bytes: NonZeroUsize,
    /// The maximum duration to preview for. Recommended value is 10 seconds.
    pub max_preview_duration: std::time::Duration,
}
310
311impl Default for PreviewConfiguration {
312    fn default() -> Self {
313        Self {
314            max_preview_bytes: NonZeroUsize::new(DEFAULT_PREVIEW_BUFFER_SIZE as usize).unwrap(),
315            max_preview_duration: std::time::Duration::from_secs(10),
316        }
317    }
318}
319
/// Per-stream behaviour toggles for [`UpgradableStream`]; all off by default.
#[derive(Default, Debug)]
struct UpgradableStreamOptions {
    /// When set, `UnexpectedEof` read errors are reported as a clean EOF.
    ignore_missing_close_notify: bool,
}
324
325#[allow(private_bounds)]
326#[derive(derive_more::Debug, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
327pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
328    #[write]
329    #[descriptor]
330    inner: UpgradableStreamInner<S, D>,
331    options: UpgradableStreamOptions,
332}
333
334#[allow(private_bounds)]
335impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
336    #[inline(always)]
337    pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
338        UpgradableStream {
339            inner: UpgradableStreamInner::BaseClient(base, config),
340            options: Default::default(),
341        }
342    }
343
344    #[inline(always)]
345    pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
346        UpgradableStream {
347            inner: UpgradableStreamInner::BaseServer(base, config),
348            options: Default::default(),
349        }
350    }
351
352    #[inline(always)]
353    pub(crate) fn new_server_preview(
354        base: RewindStream<S>,
355        config: Option<TlsServerParameterProvider>,
356    ) -> Self {
357        UpgradableStream {
358            inner: UpgradableStreamInner::BaseServerPreview(base, config),
359            options: Default::default(),
360        }
361    }
362
363    /// Consume the `UpgradableStream` and return the underlying stream as a [`Box<dyn Stream>`].
364    pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
365        match self.inner {
366            UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
367            UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
368            UpgradableStreamInner::BaseServerPreview(base, _) => Ok(Box::new(base)),
369            UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
370            UpgradableStreamInner::UpgradedPreview(upgraded, _) => Ok(Box::new(upgraded)),
371        }
372    }
373
374    pub fn handshake(&self) -> Option<&TlsHandshake> {
375        match &self.inner {
376            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
377            _ => None,
378        }
379    }
380
381    pub fn ignore_missing_close_notify(&mut self) {
382        self.options.ignore_missing_close_notify = true;
383    }
384
385    /// Uncleanly shut down the stream. This may cause errors on the peer side
386    /// when using TLS.
387    pub fn unclean_shutdown(self) -> Result<(), Self> {
388        match self.inner {
389            UpgradableStreamInner::BaseClient(..) => Ok(()),
390            UpgradableStreamInner::BaseServer(..) => Ok(()),
391            UpgradableStreamInner::BaseServerPreview(..) => Ok(()),
392            UpgradableStreamInner::Upgraded(upgraded, cfg) => {
393                if let Err(e) = D::unclean_shutdown(upgraded) {
394                    Err(Self {
395                        inner: UpgradableStreamInner::Upgraded(e, cfg),
396                        options: self.options,
397                    })
398                } else {
399                    Ok(())
400                }
401            }
402            UpgradableStreamInner::UpgradedPreview(upgraded, cfg) => {
403                let (stm, buf) = upgraded.into_inner();
404                if let Err(e) = D::unclean_shutdown(stm) {
405                    Err(Self {
406                        inner: UpgradableStreamInner::UpgradedPreview(
407                            RewindStream {
408                                buffer: buf,
409                                inner: e,
410                            },
411                            cfg,
412                        ),
413                        options: self.options,
414                    })
415                } else {
416                    Ok(())
417                }
418            }
419        }
420    }
421}
422
423impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
424    async fn secure_upgrade(self) -> Result<Self, SslError> {
425        let (upgraded, handshake) = match self.inner {
426            UpgradableStreamInner::BaseClient(base, config) => {
427                let Some(config) = config else {
428                    return Err(SslError::SslUnsupported);
429                };
430                D::upgrade_client(config, base).await?
431            }
432            UpgradableStreamInner::BaseServer(base, config) => {
433                let Some(config) = config else {
434                    return Err(SslError::SslUnsupported);
435                };
436                D::upgrade_server(config, base).await?
437            }
438            UpgradableStreamInner::BaseServerPreview(base, config) => {
439                let Some(config) = config else {
440                    return Err(SslError::SslUnsupported);
441                };
442                D::upgrade_server(config, base).await?
443            }
444            _ => {
445                return Err(SslError::SslAlreadyUpgraded);
446            }
447        };
448        Ok(Self {
449            inner: UpgradableStreamInner::Upgraded(upgraded, handshake),
450            options: self.options,
451        })
452    }
453
454    async fn secure_upgrade_preview(
455        self,
456        options: PreviewConfiguration,
457    ) -> Result<(Preview, Self), SslError> {
458        let (mut upgraded, handshake) = match self.inner {
459            UpgradableStreamInner::BaseClient(base, config) => {
460                let Some(config) = config else {
461                    return Err(SslError::SslUnsupported);
462                };
463                D::upgrade_client(config, base).await?
464            }
465            UpgradableStreamInner::BaseServer(base, config) => {
466                let Some(config) = config else {
467                    return Err(SslError::SslUnsupported);
468                };
469                D::upgrade_server(config, base).await?
470            }
471            UpgradableStreamInner::BaseServerPreview(base, config) => {
472                let Some(config) = config else {
473                    return Err(SslError::SslUnsupported);
474                };
475                D::upgrade_server(config, base).await?
476            }
477            _ => {
478                return Err(SslError::SslAlreadyUpgraded);
479            }
480        };
481        let mut buffer = smallvec::SmallVec::with_capacity(options.max_preview_bytes.get());
482        buffer.resize(options.max_preview_bytes.get(), 0);
483        upgraded.read_exact(&mut buffer).await?;
484        let mut rewind = RewindStream::new(upgraded);
485        rewind.rewind(&buffer);
486        Ok((
487            Preview { buffer },
488            Self {
489                inner: UpgradableStreamInner::UpgradedPreview(rewind, handshake),
490                options: self.options,
491            },
492        ))
493    }
494
495    fn handshake(&self) -> Option<&TlsHandshake> {
496        match &self.inner {
497            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
498            _ => None,
499        }
500    }
501}
502
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Read from whichever inner stream is currently active.
    ///
    /// When `ignore_missing_close_notify` is enabled, an `UnexpectedEof`
    /// error is turned into `Ok(())` with nothing read, i.e. a clean EOF.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let eof_is_clean = this.options.ignore_missing_close_notify;
        let outcome = match &mut this.inner {
            UpgradableStreamInner::BaseClient(base, _)
            | UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::BaseServerPreview(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(tls, _) => Pin::new(tls).poll_read(cx, buf),
            UpgradableStreamInner::UpgradedPreview(tls, _) => Pin::new(tls).poll_read(cx, buf),
        };
        match outcome {
            Poll::Ready(Err(err))
                if eof_is_clean && err.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                Poll::Ready(Ok(()))
            }
            other => other,
        }
    }
}
530
531impl<S: Stream, D: TlsDriver> LocalAddress for UpgradableStream<S, D> {
532    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
533        self.inner
534            .with_inner_metadata(|inner| inner.local_address())
535    }
536}
537
538impl<S: Stream, D: TlsDriver> RemoteAddress for UpgradableStream<S, D> {
539    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
540        self.inner
541            .with_inner_metadata(|inner| inner.remote_address())
542    }
543}
544
545impl<S: Stream, D: TlsDriver> StreamMetadata for UpgradableStream<S, D> {
546    fn transport(&self) -> Transport {
547        self.inner.with_inner_metadata(|inner| inner.transport())
548    }
549}
550
551#[derive(
552    derive_more::Debug, derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor,
553)]
554enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
555    #[debug("BaseClient(..)")]
556    BaseClient(
557        #[read]
558        #[write]
559        #[descriptor]
560        S,
561        Option<D::ClientParams>,
562    ),
563    #[debug("BaseServer(..)")]
564    BaseServer(
565        #[read]
566        #[write]
567        #[descriptor]
568        S,
569        Option<TlsServerParameterProvider>,
570    ),
571    #[debug("Preview(..)")]
572    BaseServerPreview(
573        #[read]
574        #[write]
575        #[descriptor]
576        RewindStream<S>,
577        Option<TlsServerParameterProvider>,
578    ),
579    #[debug("Upgraded(..)")]
580    Upgraded(
581        #[read]
582        #[write]
583        #[descriptor]
584        D::Stream,
585        TlsHandshake,
586    ),
587    #[debug("Upgraded(..)")]
588    UpgradedPreview(
589        #[read]
590        #[write]
591        #[descriptor]
592        RewindStream<D::Stream>,
593        TlsHandshake,
594    ),
595}
596
597impl<S: Stream, D: TlsDriver> UpgradableStreamInner<S, D> {
598    #[inline(always)]
599    fn with_inner_metadata<T>(&self, f: impl FnOnce(&dyn StreamMetadata) -> T) -> T {
600        match self {
601            UpgradableStreamInner::BaseClient(base, _) => f(base),
602            UpgradableStreamInner::BaseServer(base, _) => f(base),
603            UpgradableStreamInner::BaseServerPreview(base, _) => f(base),
604            UpgradableStreamInner::Upgraded(upgraded, _) => f(upgraded),
605            UpgradableStreamInner::UpgradedPreview(upgraded, _) => f(upgraded),
606        }
607    }
608
609    #[inline(always)]
610    fn as_inner_handle(&self) -> &dyn AsHandle {
611        match self {
612            UpgradableStreamInner::BaseClient(base, _) => base,
613            UpgradableStreamInner::BaseServer(base, _) => base,
614            UpgradableStreamInner::BaseServerPreview(base, _) => base,
615            UpgradableStreamInner::Upgraded(upgraded, _) => upgraded,
616            UpgradableStreamInner::UpgradedPreview(upgraded, _) => upgraded,
617        }
618    }
619}
620
621impl<S: Stream, D: TlsDriver> AsHandle for UpgradableStream<S, D> {
622    #[cfg(windows)]
623    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
624        self.inner.as_inner_handle().as_handle()
625    }
626
627    #[cfg(unix)]
628    fn as_fd(&self) -> std::os::fd::BorrowedFd {
629        self.inner.as_inner_handle().as_fd()
630    }
631}
632
/// A stream that can be handed bytes to replay to subsequent reads.
pub trait Rewindable {
    /// Queue `bytes` to be returned by future reads.
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
636
637/// A stream that can be rewound.
638#[derive(derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
639pub(crate) struct RewindStream<S> {
640    buffer: Vec<u8>,
641    #[write]
642    #[descriptor]
643    inner: S,
644}
645
646impl<S> RewindStream<S> {
647    pub fn new(inner: S) -> Self {
648        RewindStream {
649            buffer: Vec::new(),
650            inner,
651        }
652    }
653
654    pub fn rewind(&mut self, data: &[u8]) {
655        self.buffer.extend_from_slice(data);
656    }
657
658    pub fn into_inner(self) -> (S, Vec<u8>) {
659        (self.inner, self.buffer)
660    }
661}
662
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serve queued rewind bytes first; once they are exhausted, read from
    /// the inner stream.
    ///
    /// A single call never mixes buffered and fresh bytes: while the buffer is
    /// non-empty only buffered bytes (up to `buf.remaining()`) are returned.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let to_read = buf.remaining().min(this.buffer.len());
        // Copy straight from the buffer, then drop the consumed prefix; this
        // avoids the temporary `Vec` a `drain(..).collect()` would allocate.
        buf.put_slice(&this.buffer[..to_read]);
        this.buffer.drain(..to_read);
        Poll::Ready(Ok(()))
    }
}
681
682impl<S: Stream> Rewindable for RewindStream<S> {
683    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
684        self.rewind(bytes);
685        Ok(())
686    }
687}
688
689impl<S: LocalAddress> LocalAddress for RewindStream<S> {
690    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
691        self.inner.local_address()
692    }
693}
694
695impl<S: RemoteAddress> RemoteAddress for RewindStream<S> {
696    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
697        self.inner.remote_address()
698    }
699}
700
701impl<S: PeerCred> PeerCred for RewindStream<S> {
702    #[cfg(all(unix, feature = "tokio"))]
703    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
704        self.inner.peer_cred()
705    }
706}
707
708impl<S: StreamMetadata> StreamMetadata for RewindStream<S> {
709    fn transport(&self) -> Transport {
710        self.inner.transport()
711    }
712}
713
714impl<S: PeekableStream> PeekableStream for RewindStream<S> {
715    fn poll_peek(
716        mut self: Pin<&mut Self>,
717        cx: &mut Context<'_>,
718        buf: &mut ReadBuf<'_>,
719    ) -> Poll<std::io::Result<usize>> {
720        if !self.buffer.is_empty() {
721            let to_read = std::cmp::min(buf.remaining(), self.buffer.len());
722            buf.put_slice(&self.buffer[..to_read]);
723            Poll::Ready(Ok(to_read))
724        } else {
725            Pin::new(&mut self.inner).poll_peek(cx, buf)
726        }
727    }
728}
729
730impl<S: Stream + AsHandle> AsHandle for RewindStream<S> {
731    #[cfg(windows)]
732    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
733        self.inner.as_handle()
734    }
735
736    #[cfg(unix)]
737    fn as_fd(&self) -> std::os::fd::BorrowedFd {
738        self.inner.as_fd()
739    }
740}
741
742impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
743where
744    D::Stream: Rewindable,
745{
746    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
747        match &mut self.inner {
748            UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
749            UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
750            UpgradableStreamInner::BaseServerPreview(stm, _) => {
751                stm.rewind(bytes);
752                Ok(())
753            }
754            UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
755            UpgradableStreamInner::UpgradedPreview(stm, _) => {
756                stm.rewind(bytes);
757                Ok(())
758            }
759        }
760    }
761}
762
763impl<S: PeekableStream, D: TlsDriver> PeekableStream for UpgradableStream<S, D>
764where
765    D::Stream: PeekableStream,
766{
767    #[cfg(feature = "tokio")]
768    fn poll_peek(
769        self: Pin<&mut Self>,
770        cx: &mut Context<'_>,
771        buf: &mut tokio::io::ReadBuf,
772    ) -> Poll<std::io::Result<usize>> {
773        match &mut self.get_mut().inner {
774            UpgradableStreamInner::BaseClient(base, _) => Pin::new(base).poll_peek(cx, buf),
775            UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_peek(cx, buf),
776            UpgradableStreamInner::BaseServerPreview(base, _) => Pin::new(base).poll_peek(cx, buf),
777            UpgradableStreamInner::Upgraded(upgraded, _) => Pin::new(upgraded).poll_peek(cx, buf),
778            UpgradableStreamInner::UpgradedPreview(upgraded, _) => {
779                Pin::new(upgraded).poll_peek(cx, buf)
780            }
781        }
782    }
783}
784
785impl<S: PeerCred + Stream, D: TlsDriver> PeerCred for UpgradableStream<S, D> {
786    #[cfg(all(unix, feature = "tokio"))]
787    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
788        self.inner.with_inner_metadata(|inner| inner.peer_cred())
789    }
790}