gel_stream/common/
stream.rs

1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
3
4#[cfg(feature = "tokio")]
5use std::{
6    any::Any,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use std::{future::Future, num::NonZeroUsize, ops::Deref};
11
12use crate::{
13    LocalAddress, PeerCred, RemoteAddress, ResolvedTarget, Ssl, SslError, StreamMetadata,
14    TlsDriver, TlsHandshake, TlsServerParameterProvider, Transport, DEFAULT_PREVIEW_BUFFER_SIZE,
15};
16
/// Streams that can lend out the OS file descriptor they are built on.
#[cfg(unix)]
pub trait AsHandle {
    /// Borrow the file descriptor backing this stream.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_>;
}
22
/// Streams that can lend out the Windows socket handle they are built on.
#[cfg(windows)]
pub trait AsHandle {
    /// Borrow the socket handle backing this stream.
    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket<'_>;
}
28
/// A convenience trait for streams from this crate.
///
/// Bundles the tokio I/O traits with this crate's metadata and handle traits
/// so that heterogeneous streams can be handled as `dyn Stream`.
#[cfg(feature = "tokio")]
pub trait Stream:
    tokio::io::AsyncRead + tokio::io::AsyncWrite + StreamMetadata + Send + Unpin + AsHandle + 'static
{
    /// Attempt to downcast a generic stream to a specific stream type.
    ///
    /// Succeeds when `Self` is exactly `S`, or when `Self` is `Box<S>` (the
    /// box is then unwrapped). Otherwise the stream is handed back in `Err`.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // `Any` can only inspect the value behind a reference, so park the
        // stream in an `Option` that can be `take`n once the type matches.
        let mut holder = Some(self);
        let stream = &mut holder as &mut dyn Any;
        if let Some(stream) = stream.downcast_mut::<Option<S>>() {
            return Ok(stream.take().unwrap());
        }
        if let Some(stream) = stream.downcast_mut::<Option<Box<S>>>() {
            return Ok(*stream.take().unwrap());
        }
        Err(holder.take().unwrap())
    }

    /// Box the stream as a `Box<dyn Stream + Send>`.
    ///
    /// If `Self` is already a boxed trait object (`Box<dyn Stream + Send>` or
    /// `Box<dyn Stream>`), it is returned directly instead of being wrapped in
    /// a second box.
    fn boxed(self) -> Box<dyn Stream + Send>
    where
        Self: Sized + 'static,
    {
        let mut holder = Some(self);
        let stream = &mut holder as &mut dyn Any;
        // `dyn Stream + Send` and `dyn Stream` are distinct types with distinct
        // `TypeId`s, so both spellings must be checked to avoid double boxing.
        if let Some(stream) = stream.downcast_mut::<Option<Box<dyn Stream + Send>>>() {
            return stream.take().unwrap();
        }
        if let Some(stream) = stream.downcast_mut::<Option<Box<dyn Stream>>>() {
            return stream.take().unwrap();
        }
        Box::new(holder.take().unwrap())
    }
}
64
/// Blanket implementation: any type meeting the [`Stream`] supertrait bounds
/// is a [`Stream`] automatically.
#[cfg(feature = "tokio")]
impl<T> Stream for T
where
    T: 'static
        + Send
        + Unpin
        + AsHandle
        + StreamMetadata
        + tokio::io::AsyncWrite
        + tokio::io::AsyncRead,
{
}
76
77// NOTE: Once we're on Rust 1.87, we can use trait upcasting and get rid of this impl.
78impl PeerCred for Box<dyn Stream + Send> {
79    #[cfg(all(unix, feature = "tokio"))]
80    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
81        self.as_ref().peer_cred()
82    }
83}
84
85impl LocalAddress for Box<dyn Stream + Send> {
86    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
87        self.as_ref().local_address()
88    }
89}
90
91impl RemoteAddress for Box<dyn Stream + Send> {
92    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
93        self.as_ref().remote_address()
94    }
95}
96
97impl StreamMetadata for Box<dyn Stream + Send> {
98    fn transport(&self) -> Transport {
99        self.as_ref().transport()
100    }
101}
102
103#[cfg(not(feature = "tokio"))]
104pub trait Stream: StreamMetadata + Unpin + AsHandle + 'static {}
105#[cfg(not(feature = "tokio"))]
106impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
107#[cfg(not(feature = "tokio"))]
108impl Stream for () {}
109
110/// A trait for streams that can be upgraded to a TLS stream.
111pub trait StreamUpgrade: Stream + Sized {
112    /// Upgrade the stream to a TLS stream.
113    fn secure_upgrade(self) -> impl Future<Output = Result<Self, SslError>> + Send;
114    /// Upgrade the stream to a TLS stream, and preview the initial bytes.
115    fn secure_upgrade_preview(
116        self,
117        options: PreviewConfiguration,
118    ) -> impl Future<Output = Result<(Preview, Self), SslError>> + Send;
119    /// Get the TLS handshake information, if the stream is upgraded.
120    fn handshake(&self) -> Option<&TlsHandshake>;
121}
122
/// Socket-tuning profiles accepted by `StreamOptimizationExt::optimize_for`.
#[cfg(feature = "optimization")]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum StreamOptimization {
    /// Optimize for general use (the default).
    #[default]
    General,
    /// Optimize for interactive use with low latency.
    Interactive,
    /// Optimize for bulk streaming in the given direction(s).
    BulkStreaming(BulkStreamDirection),
}
134
/// Which direction(s) of a bulk transfer `StreamOptimization::BulkStreaming`
/// should tune for.
#[cfg(feature = "optimization")]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum BulkStreamDirection {
    /// Tune the sending side only.
    Send,
    /// Tune the receiving side only.
    Receive,
    /// Tune both sides (the default).
    #[default]
    Both,
}
143
/// Run `f` against a [`socket2::SockRef`] that borrows `stream`'s OS socket
/// (its file descriptor on Unix, its socket handle on Windows).
///
/// Returns whatever `f` returns. `S` may be unsized (e.g. a trait object),
/// since only a shared borrow of the stream is needed.
#[cfg(any(feature = "optimization", feature = "keepalive"))]
fn with_socket2<S: AsHandle + ?Sized>(
    stream: &S,
    f: &mut dyn for<'a> FnMut(socket2::SockRef<'a>) -> Result<(), std::io::Error>,
) -> Result<(), std::io::Error> {
    #[cfg(unix)]
    let handle = stream.as_fd();
    #[cfg(windows)]
    let handle = stream.as_handle();
    f(socket2::SockRef::from(&handle))
}
156
/// Extension trait for tuning the OS-level socket options of a [`Stream`].
#[cfg(feature = "optimization")]
pub trait StreamOptimizationExt: Stream + Sized {
    /// Optimize the stream for the given optimization.
    ///
    /// On Unix, streams whose transport is `Transport::Unix` are left
    /// untouched. Otherwise every call first resets a baseline, then applies
    /// the settings for the chosen profile:
    /// - Linux-family targets: `TCP_CORK`, `TCP_QUICKACK` and thin linear
    ///   timeouts are switched off.
    /// - All targets: 256 KiB send and receive buffers.
    ///
    /// # Errors
    ///
    /// Returns the first failing socket-option call's error. The error kind
    /// is preserved and the failing setter's name is prefixed to the message.
    fn optimize_for(&mut self, optimization: StreamOptimization) -> Result<(), std::io::Error> {
        // Builds an error mapper that prefixes `setter` to the error message.
        fn tag(setter: &'static str) -> impl Fn(std::io::Error) -> std::io::Error {
            move |err| std::io::Error::new(err.kind(), format!("{}: {}", setter, err))
        }

        #[cfg(unix)]
        if self.transport() == Transport::Unix {
            return Ok(());
        }

        let mut configure = move |sock: socket2::SockRef<'_>| -> Result<(), std::io::Error> {
            // Reset to a baseline so the profile below starts from known values.
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            {
                sock.set_cork(false).map_err(tag("set_cork"))?;
                sock.set_quickack(false).map_err(tag("set_quickack"))?;
                sock.set_thin_linear_timeouts(false)
                    .map_err(tag("set_thin_linear_timeouts"))?;
            }
            sock.set_send_buffer_size(256 * 1024)
                .map_err(tag("set_send_buffer_size"))?;
            sock.set_recv_buffer_size(256 * 1024)
                .map_err(tag("set_recv_buffer_size"))?;

            match optimization {
                StreamOptimization::General | StreamOptimization::Interactive => {
                    // Only the interactive profile disables Nagle's algorithm.
                    let nodelay = optimization == StreamOptimization::Interactive;
                    sock.set_nodelay(nodelay).map_err(tag("set_nodelay"))?;
                    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
                    sock.set_thin_linear_timeouts(true)
                        .map_err(tag("set_thin_linear_timeouts"))?;
                }
                StreamOptimization::BulkStreaming(direction) => {
                    sock.set_nodelay(false).map_err(tag("set_nodelay"))?;

                    let tune_send = matches!(
                        direction,
                        BulkStreamDirection::Send | BulkStreamDirection::Both
                    );
                    let tune_recv = matches!(
                        direction,
                        BulkStreamDirection::Receive | BulkStreamDirection::Both
                    );

                    if tune_send {
                        sock.set_send_buffer_size(16 * 1024 * 1024)
                            .map_err(tag("set_send_buffer_size"))?;
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        sock.set_cork(true).map_err(tag("set_cork"))?;
                    }

                    if tune_recv {
                        sock.set_recv_buffer_size(16 * 1024 * 1024)
                            .map_err(tag("set_recv_buffer_size"))?;
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        sock.set_quickack(true).map_err(tag("set_quickack"))?;
                    }
                }
            }
            Ok(())
        };

        with_socket2(self, &mut configure)
    }
}
234
/// Every [`Stream`] gets [`StreamOptimizationExt`] for free.
#[cfg(feature = "optimization")]
impl<T> StreamOptimizationExt for T where T: Stream {}
237
238/// A trait for streams that can be peeked asynchronously.
239pub trait PeekableStream: Stream {
240    #[cfg(feature = "tokio")]
241    fn poll_peek(
242        self: Pin<&mut Self>,
243        cx: &mut Context<'_>,
244        buf: &mut tokio::io::ReadBuf,
245    ) -> Poll<std::io::Result<usize>>;
246    #[cfg(feature = "tokio")]
247    fn peek(self: Pin<&mut Self>, buf: &mut [u8]) -> impl Future<Output = std::io::Result<usize>> {
248        async {
249            let mut this = self;
250            std::future::poll_fn(move |cx| this.as_mut().poll_peek(cx, &mut ReadBuf::new(buf)))
251                .await
252        }
253    }
254}
255
256/// A preview of the initial bytes of the stream.
257#[derive(Debug, Clone, PartialEq, Eq)]
258#[must_use]
259pub struct Preview {
260    buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
261}
262
263impl Preview {
264    pub(crate) fn new(
265        buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
266    ) -> Self {
267        Self { buffer }
268    }
269}
270
271impl Deref for Preview {
272    type Target = [u8];
273    fn deref(&self) -> &Self::Target {
274        &self.buffer
275    }
276}
277
278impl AsRef<[u8]> for Preview {
279    fn as_ref(&self) -> &[u8] {
280        &self.buffer
281    }
282}
283
284impl<const N: usize> PartialEq<[u8; N]> for Preview {
285    fn eq(&self, other: &[u8; N]) -> bool {
286        self.buffer.as_slice() == other
287    }
288}
289
290impl<const N: usize> PartialEq<&[u8; N]> for Preview {
291    fn eq(&self, other: &&[u8; N]) -> bool {
292        self.buffer.as_slice() == *other
293    }
294}
295
296impl PartialEq<[u8]> for Preview {
297    fn eq(&self, other: &[u8]) -> bool {
298        self.buffer.as_slice() == other
299    }
300}
301
302/// Configuration for the initial preview of the client connection.
303#[derive(Debug, Clone, Copy)]
304pub struct PreviewConfiguration {
305    /// The maximum number of bytes to preview. Recommended value is 8 bytes.
306    pub max_preview_bytes: NonZeroUsize,
307    /// The maximum duration to preview for. Recommended value is 10 seconds.
308    pub max_preview_duration: std::time::Duration,
309}
310
311impl Default for PreviewConfiguration {
312    fn default() -> Self {
313        Self {
314            max_preview_bytes: NonZeroUsize::new(DEFAULT_PREVIEW_BUFFER_SIZE as usize).unwrap(),
315            max_preview_duration: std::time::Duration::from_secs(10),
316        }
317    }
318}
319
/// Behaviour flags for an `UpgradableStream`; all default to off.
#[derive(Debug, Default)]
struct UpgradableStreamOptions {
    /// When set, reads turn an `UnexpectedEof` error into `Ok(())`.
    ignore_missing_close_notify: bool,
}
324
325#[allow(private_bounds)]
326#[derive(derive_more::Debug, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
327pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
328    #[write]
329    #[descriptor]
330    inner: UpgradableStreamInner<S, D>,
331    options: UpgradableStreamOptions,
332}
333
334#[allow(private_bounds)]
335impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
336    #[inline(always)]
337    pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
338        UpgradableStream {
339            inner: UpgradableStreamInner::BaseClient(base, config),
340            options: Default::default(),
341        }
342    }
343
344    #[inline(always)]
345    pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
346        UpgradableStream {
347            inner: UpgradableStreamInner::BaseServer(base, config),
348            options: Default::default(),
349        }
350    }
351
352    #[inline(always)]
353    pub(crate) fn new_server_preview(
354        base: RewindStream<S>,
355        config: Option<TlsServerParameterProvider>,
356    ) -> Self {
357        UpgradableStream {
358            inner: UpgradableStreamInner::BaseServerPreview(base, config),
359            options: Default::default(),
360        }
361    }
362
363    /// Consume the `UpgradableStream` and return the underlying stream as a [`Box<dyn Stream>`].
364    pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
365        match self.inner {
366            UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
367            UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
368            UpgradableStreamInner::BaseServerPreview(base, _) => Ok(Box::new(base)),
369            UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
370            UpgradableStreamInner::UpgradedPreview(upgraded, _) => Ok(Box::new(upgraded)),
371        }
372    }
373
374    pub fn handshake(&self) -> Option<&TlsHandshake> {
375        match &self.inner {
376            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
377            _ => None,
378        }
379    }
380
381    pub fn ignore_missing_close_notify(&mut self) {
382        self.options.ignore_missing_close_notify = true;
383    }
384
385    /// Uncleanly shut down the stream. This may cause errors on the peer side
386    /// when using TLS.
387    pub fn unclean_shutdown(self) -> Result<(), Self> {
388        match self.inner {
389            UpgradableStreamInner::BaseClient(..) => Ok(()),
390            UpgradableStreamInner::BaseServer(..) => Ok(()),
391            UpgradableStreamInner::BaseServerPreview(..) => Ok(()),
392            UpgradableStreamInner::Upgraded(upgraded, cfg) => {
393                if let Err(e) = D::unclean_shutdown(upgraded) {
394                    Err(Self {
395                        inner: UpgradableStreamInner::Upgraded(e, cfg),
396                        options: self.options,
397                    })
398                } else {
399                    Ok(())
400                }
401            }
402            UpgradableStreamInner::UpgradedPreview(upgraded, cfg) => {
403                let (stm, buf) = upgraded.into_inner();
404                if let Err(e) = D::unclean_shutdown(stm) {
405                    Err(Self {
406                        inner: UpgradableStreamInner::UpgradedPreview(
407                            RewindStream {
408                                buffer: buf,
409                                inner: e,
410                            },
411                            cfg,
412                        ),
413                        options: self.options,
414                    })
415                } else {
416                    Ok(())
417                }
418            }
419        }
420    }
421}
422
423impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
424    fn secure_upgrade(self) -> impl Future<Output = Result<Self, SslError>> + Send {
425        async move {
426            let (upgraded, handshake) = match self.inner {
427                UpgradableStreamInner::BaseClient(base, config) => {
428                    let Some(config) = config else {
429                        return Err(SslError::SslUnsupported);
430                    };
431                    D::upgrade_client(config, base).await?
432                }
433                UpgradableStreamInner::BaseServer(base, config) => {
434                    let Some(config) = config else {
435                        return Err(SslError::SslUnsupported);
436                    };
437                    D::upgrade_server(config, base).await?
438                }
439                UpgradableStreamInner::BaseServerPreview(base, config) => {
440                    let Some(config) = config else {
441                        return Err(SslError::SslUnsupported);
442                    };
443                    D::upgrade_server(config, base).await?
444                }
445                _ => {
446                    return Err(SslError::SslAlreadyUpgraded);
447                }
448            };
449            Ok(Self {
450                inner: UpgradableStreamInner::Upgraded(upgraded, handshake),
451                options: self.options,
452            })
453        }
454    }
455
456    fn secure_upgrade_preview(
457        self,
458        options: PreviewConfiguration,
459    ) -> impl Future<Output = Result<(Preview, Self), SslError>> + Send {
460        async move {
461            let (mut upgraded, handshake) = match self.inner {
462                UpgradableStreamInner::BaseClient(base, config) => {
463                    let Some(config) = config else {
464                        return Err(SslError::SslUnsupported);
465                    };
466                    D::upgrade_client(config, base).await?
467                }
468                UpgradableStreamInner::BaseServer(base, config) => {
469                    let Some(config) = config else {
470                        return Err(SslError::SslUnsupported);
471                    };
472                    D::upgrade_server(config, base).await?
473                }
474                UpgradableStreamInner::BaseServerPreview(base, config) => {
475                    let Some(config) = config else {
476                        return Err(SslError::SslUnsupported);
477                    };
478                    D::upgrade_server(config, base).await?
479                }
480                _ => {
481                    return Err(SslError::SslAlreadyUpgraded);
482                }
483            };
484            let mut buffer = smallvec::SmallVec::with_capacity(options.max_preview_bytes.get());
485            buffer.resize(options.max_preview_bytes.get(), 0);
486            upgraded.read_exact(&mut buffer).await?;
487            let mut rewind = RewindStream::new(upgraded);
488            rewind.rewind(&buffer);
489            Ok((
490                Preview { buffer },
491                Self {
492                    inner: UpgradableStreamInner::UpgradedPreview(rewind, handshake),
493                    options: self.options,
494                },
495            ))
496        }
497    }
498
499    fn handshake(&self) -> Option<&TlsHandshake> {
500        match &self.inner {
501            UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
502            _ => None,
503        }
504    }
505}
506
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Read from whichever stream is currently active.
    ///
    /// If `ignore_missing_close_notify` has been enabled, an `UnexpectedEof`
    /// error from the active stream is converted into `Ok(())`.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let tolerate_eof = this.options.ignore_missing_close_notify;

        let polled = match &mut this.inner {
            UpgradableStreamInner::BaseClient(stream, _)
            | UpgradableStreamInner::BaseServer(stream, _) => Pin::new(stream).poll_read(cx, buf),
            UpgradableStreamInner::BaseServerPreview(stream, _) => {
                Pin::new(stream).poll_read(cx, buf)
            }
            UpgradableStreamInner::Upgraded(stream, _) => Pin::new(stream).poll_read(cx, buf),
            UpgradableStreamInner::UpgradedPreview(stream, _) => {
                Pin::new(stream).poll_read(cx, buf)
            }
        };

        match polled {
            std::task::Poll::Ready(Err(ref err))
                if tolerate_eof && err.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                std::task::Poll::Ready(Ok(()))
            }
            other => other,
        }
    }
}
535
536impl<S: Stream, D: TlsDriver> LocalAddress for UpgradableStream<S, D> {
537    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
538        self.inner
539            .with_inner_metadata(|inner| inner.local_address())
540    }
541}
542
543impl<S: Stream, D: TlsDriver> RemoteAddress for UpgradableStream<S, D> {
544    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
545        self.inner
546            .with_inner_metadata(|inner| inner.remote_address())
547    }
548}
549
550impl<S: Stream, D: TlsDriver> StreamMetadata for UpgradableStream<S, D> {
551    fn transport(&self) -> Transport {
552        self.inner.with_inner_metadata(|inner| inner.transport())
553    }
554}
555
556#[derive(
557    derive_more::Debug, derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor,
558)]
559enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
560    #[debug("BaseClient(..)")]
561    BaseClient(
562        #[read]
563        #[write]
564        #[descriptor]
565        S,
566        Option<D::ClientParams>,
567    ),
568    #[debug("BaseServer(..)")]
569    BaseServer(
570        #[read]
571        #[write]
572        #[descriptor]
573        S,
574        Option<TlsServerParameterProvider>,
575    ),
576    #[debug("Preview(..)")]
577    BaseServerPreview(
578        #[read]
579        #[write]
580        #[descriptor]
581        RewindStream<S>,
582        Option<TlsServerParameterProvider>,
583    ),
584    #[debug("Upgraded(..)")]
585    Upgraded(
586        #[read]
587        #[write]
588        #[descriptor]
589        D::Stream,
590        TlsHandshake,
591    ),
592    #[debug("Upgraded(..)")]
593    UpgradedPreview(
594        #[read]
595        #[write]
596        #[descriptor]
597        RewindStream<D::Stream>,
598        TlsHandshake,
599    ),
600}
601
602impl<S: Stream, D: TlsDriver> UpgradableStreamInner<S, D> {
603    #[inline(always)]
604    fn with_inner_metadata<T>(&self, f: impl FnOnce(&dyn StreamMetadata) -> T) -> T {
605        match self {
606            UpgradableStreamInner::BaseClient(base, _) => f(base),
607            UpgradableStreamInner::BaseServer(base, _) => f(base),
608            UpgradableStreamInner::BaseServerPreview(base, _) => f(base),
609            UpgradableStreamInner::Upgraded(upgraded, _) => f(upgraded),
610            UpgradableStreamInner::UpgradedPreview(upgraded, _) => f(upgraded),
611        }
612    }
613
614    #[inline(always)]
615    fn as_inner_handle(&self) -> &dyn AsHandle {
616        match self {
617            UpgradableStreamInner::BaseClient(base, _) => base,
618            UpgradableStreamInner::BaseServer(base, _) => base,
619            UpgradableStreamInner::BaseServerPreview(base, _) => base,
620            UpgradableStreamInner::Upgraded(upgraded, _) => upgraded,
621            UpgradableStreamInner::UpgradedPreview(upgraded, _) => upgraded,
622        }
623    }
624}
625
626impl<S: Stream, D: TlsDriver> AsHandle for UpgradableStream<S, D> {
627    #[cfg(windows)]
628    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
629        self.inner.as_inner_handle().as_handle()
630    }
631
632    #[cfg(unix)]
633    fn as_fd(&self) -> std::os::fd::BorrowedFd {
634        self.inner.as_inner_handle().as_fd()
635    }
636}
637
/// Streams that can have already-read bytes pushed back so they can be read again.
pub trait Rewindable {
    /// Push `bytes` back onto the stream.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the in-memory `RewindStream` implementation
    /// in this module always returns `Ok(())`.
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
641
642/// A stream that can be rewound.
643#[derive(derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
644pub(crate) struct RewindStream<S> {
645    buffer: Vec<u8>,
646    #[write]
647    #[descriptor]
648    inner: S,
649}
650
651impl<S> RewindStream<S> {
652    pub fn new(inner: S) -> Self {
653        RewindStream {
654            buffer: Vec::new(),
655            inner,
656        }
657    }
658
659    pub fn rewind(&mut self, data: &[u8]) {
660        self.buffer.extend_from_slice(data);
661    }
662
663    pub fn into_inner(self) -> (S, Vec<u8>) {
664        (self.inner, self.buffer)
665    }
666}
667
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serve rewound bytes first. Reads fall through to the inner stream only
    /// once the rewind buffer is empty.
    ///
    /// A single call never mixes rewound and fresh bytes. While the buffer is
    /// non-empty, at most `min(buf.remaining(), buffered)` bytes are returned.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let to_read = std::cmp::min(buf.remaining(), this.buffer.len());
        // Copy directly out of the rewind buffer, then discard the consumed
        // prefix. This avoids allocating a temporary Vec on every read.
        buf.put_slice(&this.buffer[..to_read]);
        this.buffer.drain(..to_read);
        Poll::Ready(Ok(()))
    }
}
686
687impl<S: Stream> Rewindable for RewindStream<S> {
688    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
689        self.rewind(bytes);
690        Ok(())
691    }
692}
693
694impl<S: LocalAddress> LocalAddress for RewindStream<S> {
695    fn local_address(&self) -> std::io::Result<ResolvedTarget> {
696        self.inner.local_address()
697    }
698}
699
700impl<S: RemoteAddress> RemoteAddress for RewindStream<S> {
701    fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
702        self.inner.remote_address()
703    }
704}
705
706impl<S: PeerCred> PeerCred for RewindStream<S> {
707    #[cfg(all(unix, feature = "tokio"))]
708    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
709        self.inner.peer_cred()
710    }
711}
712
713impl<S: StreamMetadata> StreamMetadata for RewindStream<S> {
714    fn transport(&self) -> Transport {
715        self.inner.transport()
716    }
717}
718
719impl<S: PeekableStream> PeekableStream for RewindStream<S> {
720    fn poll_peek(
721        mut self: Pin<&mut Self>,
722        cx: &mut Context<'_>,
723        buf: &mut ReadBuf<'_>,
724    ) -> Poll<std::io::Result<usize>> {
725        if !self.buffer.is_empty() {
726            let to_read = std::cmp::min(buf.remaining(), self.buffer.len());
727            buf.put_slice(&self.buffer[..to_read]);
728            Poll::Ready(Ok(to_read))
729        } else {
730            Pin::new(&mut self.inner).poll_peek(cx, buf)
731        }
732    }
733}
734
735impl<S: Stream + AsHandle> AsHandle for RewindStream<S> {
736    #[cfg(windows)]
737    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
738        self.inner.as_handle()
739    }
740
741    #[cfg(unix)]
742    fn as_fd(&self) -> std::os::fd::BorrowedFd {
743        self.inner.as_fd()
744    }
745}
746
747impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
748where
749    D::Stream: Rewindable,
750{
751    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
752        match &mut self.inner {
753            UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
754            UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
755            UpgradableStreamInner::BaseServerPreview(stm, _) => Ok(stm.rewind(bytes)),
756            UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
757            UpgradableStreamInner::UpgradedPreview(stm, _) => Ok(stm.rewind(bytes)),
758        }
759    }
760}
761
762impl<S: PeekableStream, D: TlsDriver> PeekableStream for UpgradableStream<S, D>
763where
764    D::Stream: PeekableStream,
765{
766    #[cfg(feature = "tokio")]
767    fn poll_peek(
768        self: Pin<&mut Self>,
769        cx: &mut Context<'_>,
770        buf: &mut tokio::io::ReadBuf,
771    ) -> Poll<std::io::Result<usize>> {
772        match &mut self.get_mut().inner {
773            UpgradableStreamInner::BaseClient(base, _) => Pin::new(base).poll_peek(cx, buf),
774            UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_peek(cx, buf),
775            UpgradableStreamInner::BaseServerPreview(base, _) => Pin::new(base).poll_peek(cx, buf),
776            UpgradableStreamInner::Upgraded(upgraded, _) => Pin::new(upgraded).poll_peek(cx, buf),
777            UpgradableStreamInner::UpgradedPreview(upgraded, _) => {
778                Pin::new(upgraded).poll_peek(cx, buf)
779            }
780        }
781    }
782}
783
784impl<S: PeerCred + Stream, D: TlsDriver> PeerCred for UpgradableStream<S, D> {
785    #[cfg(all(unix, feature = "tokio"))]
786    fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
787        self.inner.with_inner_metadata(|inner| inner.peer_cred())
788    }
789}