1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
3
4#[cfg(feature = "tokio")]
5use std::{
6 any::Any,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use std::{future::Future, num::NonZeroUsize, ops::Deref};
11
12use crate::{
13 LocalAddress, PeerCred, RemoteAddress, ResolvedTarget, Ssl, SslError, StreamMetadata,
14 TlsDriver, TlsHandshake, TlsServerParameterProvider, Transport, DEFAULT_PREVIEW_BUFFER_SIZE,
15};
16
/// Access to the OS file descriptor underlying a stream (Unix flavour).
///
/// This module uses it to build a `socket2::SockRef` for adjusting socket options.
#[cfg(unix)]
pub trait AsHandle {
    /// Borrow the file descriptor backing this stream.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_>;
}
22
/// Access to the OS socket underlying a stream (Windows flavour).
///
/// This module uses it to build a `socket2::SockRef` for adjusting socket options.
#[cfg(windows)]
pub trait AsHandle {
    /// Borrow the socket backing this stream.
    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket<'_>;
}
28
/// A bidirectional async byte stream usable by this crate.
///
/// Anything that is `AsyncRead + AsyncWrite`, exposes [`StreamMetadata`] and an OS
/// handle via [`AsHandle`], and is `Send + Unpin + 'static` is a `Stream` through a
/// blanket implementation.
#[cfg(feature = "tokio")]
pub trait Stream:
    tokio::io::AsyncRead + tokio::io::AsyncWrite + StreamMetadata + Send + Unpin + AsHandle + 'static
{
    /// Try to recover the concrete stream type `S` from `self`.
    ///
    /// Succeeds when `Self` is exactly `S` or a `Box<S>`. Otherwise the original
    /// value is handed back unchanged in `Err`.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // Park the value in an `Option` so it can be moved out through the
        // `&mut dyn Any` once the type check succeeds.
        let mut holder = Some(self);
        let stream = &mut holder as &mut dyn Any;
        if let Some(stream) = stream.downcast_mut::<Option<S>>() {
            return Ok(stream.take().unwrap());
        }
        if let Some(stream) = stream.downcast_mut::<Option<Box<S>>>() {
            return Ok(*stream.take().unwrap());
        }
        Err(holder.take().unwrap())
    }

    /// Type-erase `self` into a `Box<dyn Stream + Send>`.
    ///
    /// A value that is already a boxed `Stream` trait object is returned as-is
    /// instead of being wrapped in a second box.
    fn boxed(self) -> Box<dyn Stream + Send>
    where
        Self: Sized + 'static,
    {
        let mut holder = Some(self);
        let stream = &mut holder as &mut dyn Any;
        // `dyn Stream + Send` and `dyn Stream` are distinct types with distinct
        // `TypeId`s, so both spellings must be checked to avoid double boxing.
        if let Some(stream) = stream.downcast_mut::<Option<Box<dyn Stream + Send>>>() {
            return stream.take().unwrap();
        }
        if let Some(stream) = stream.downcast_mut::<Option<Box<dyn Stream>>>() {
            return stream.take().unwrap();
        }
        Box::new(holder.take().unwrap())
    }
}
64
/// Blanket implementation: any type satisfying all of [`Stream`]'s supertrait
/// bounds is a `Stream`.
#[cfg(feature = "tokio")]
impl<T> Stream for T where
    T: 'static
        + Send
        + Unpin
        + AsHandle
        + StreamMetadata
        + tokio::io::AsyncWrite
        + tokio::io::AsyncRead
{
}
76
77impl PeerCred for Box<dyn Stream + Send> {
79 #[cfg(all(unix, feature = "tokio"))]
80 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
81 self.as_ref().peer_cred()
82 }
83}
84
85impl LocalAddress for Box<dyn Stream + Send> {
86 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
87 self.as_ref().local_address()
88 }
89}
90
91impl RemoteAddress for Box<dyn Stream + Send> {
92 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
93 self.as_ref().remote_address()
94 }
95}
96
97impl StreamMetadata for Box<dyn Stream + Send> {
98 fn transport(&self) -> Transport {
99 self.as_ref().transport()
100 }
101}
102
103#[cfg(not(feature = "tokio"))]
104pub trait Stream: StreamMetadata + Unpin + AsHandle + 'static {}
105#[cfg(not(feature = "tokio"))]
106impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
107#[cfg(not(feature = "tokio"))]
108impl Stream for () {}
109
110pub trait StreamUpgrade: Stream + Sized {
112 fn secure_upgrade(self) -> impl Future<Output = Result<Self, SslError>> + Send;
114 fn secure_upgrade_preview(
116 self,
117 options: PreviewConfiguration,
118 ) -> impl Future<Output = Result<(Preview, Self), SslError>> + Send;
119 fn handshake(&self) -> Option<&TlsHandshake>;
121}
122
/// Socket tuning profile applied by `StreamOptimizationExt::optimize_for`.
#[cfg(feature = "optimization")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum StreamOptimization {
    /// Default profile: `TCP_NODELAY` off (Nagle's algorithm stays enabled).
    #[default]
    General,
    /// Latency-sensitive traffic: `TCP_NODELAY` on.
    Interactive,
    /// High-throughput transfer: 16 MiB socket buffers for the given direction(s).
    BulkStreaming(BulkStreamDirection),
}
134
/// Which side(s) of a connection `StreamOptimization::BulkStreaming` tunes for.
#[cfg(feature = "optimization")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum BulkStreamDirection {
    /// Enlarge the send buffer (and cork on Android/Fuchsia/Linux).
    Send,
    /// Enlarge the receive buffer (and enable quick-ack on Android/Fuchsia/Linux).
    Receive,
    /// Apply both the `Send` and `Receive` adjustments.
    #[default]
    Both,
}
143
/// Run `f` against a `socket2::SockRef` borrowed from `stream`'s OS handle.
///
/// The socket is only borrowed for the duration of the call. `f`'s result is
/// returned unchanged.
#[cfg(any(feature = "optimization", feature = "keepalive"))]
fn with_socket2<S: AsHandle + Sized>(
    stream: &S,
    f: &mut dyn for<'a> FnMut(socket2::SockRef<'a>) -> Result<(), std::io::Error>,
) -> Result<(), std::io::Error> {
    // Bind the borrowed handle first so it outlives the `SockRef` built from it.
    #[cfg(unix)]
    let handle = stream.as_fd();
    #[cfg(windows)]
    let handle = stream.as_handle();
    f(socket2::SockRef::from(&handle))
}
156
/// Socket-level tuning for TCP streams; implemented for every sized [`Stream`]
/// when the `optimization` feature is enabled.
#[cfg(feature = "optimization")]
pub trait StreamOptimizationExt: Stream + Sized {
    /// Apply the socket options for the chosen [`StreamOptimization`] profile.
    ///
    /// On Unix, streams whose transport is `Transport::Unix` are left untouched.
    /// For other streams, a baseline is applied first: 256 KiB send and receive
    /// buffers, and `TCP_CORK`, `TCP_QUICKACK` and `TCP_THIN_LINEAR_TIMEOUTS`
    /// cleared on Android/Fuchsia/Linux. The profile-specific settings follow.
    ///
    /// # Errors
    ///
    /// Returns the first failing socket-option call. The error keeps its
    /// `ErrorKind`, and its message is prefixed with the `socket2` method name
    /// (e.g. `set_nodelay: ...`). Options applied before the failure remain applied.
    #[cfg(feature = "optimization")]
    fn optimize_for(&mut self, optimization: StreamOptimization) -> Result<(), std::io::Error> {
        // Call a `SockRef` setter and prefix any error message with the setter's name.
        macro_rules! try_optimize(
            ( $s:ident . $method:ident ( $($args:tt)* ) ) => {
                $s.$method($($args)*).map_err(|e: std::io::Error| std::io::Error::new(e.kind(), format!("{}: {}", stringify!($method), e)))
            };
        );

        // Unix-domain sockets have no TCP options to tune.
        #[cfg(unix)]
        if self.transport() == Transport::Unix {
            return Ok(());
        }

        let mut with_socket2_fn = move |s: socket2::SockRef<'_>| {
            // Baseline applied before every profile, so options set by an earlier
            // call with a different profile are reset here.
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            try_optimize!(s.set_cork(false))?;

            // NOTE(review): on Linux TCP_QUICKACK is not permanent (the kernel may
            // leave quick-ack mode on its own), so this is a one-shot hint — confirm
            // that is acceptable for the profiles below.
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            try_optimize!(s.set_quickack(false))?;

            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            try_optimize!(s.set_thin_linear_timeouts(false))?;

            try_optimize!(s.set_send_buffer_size(256 * 1024))?;
            try_optimize!(s.set_recv_buffer_size(256 * 1024))?;

            match optimization {
                StreamOptimization::General => {
                    try_optimize!(s.set_nodelay(false))?;
                    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
                    try_optimize!(s.set_thin_linear_timeouts(true))?;
                }
                StreamOptimization::Interactive => {
                    // Disable Nagle so small writes go out immediately.
                    try_optimize!(s.set_nodelay(true))?;
                    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
                    try_optimize!(s.set_thin_linear_timeouts(true))?;
                }
                StreamOptimization::BulkStreaming(direction) => {
                    try_optimize!(s.set_nodelay(false))?;
                    // Sending side: 16 MiB send buffer, plus cork so small writes are
                    // coalesced into full segments.
                    match direction {
                        BulkStreamDirection::Send | BulkStreamDirection::Both => {
                            try_optimize!(s.set_send_buffer_size(16 * 1024 * 1024))?;
                            #[cfg(any(
                                target_os = "android",
                                target_os = "fuchsia",
                                target_os = "linux"
                            ))]
                            try_optimize!(s.set_cork(true))?;
                        }
                        BulkStreamDirection::Receive => {}
                    }

                    // Receiving side: 16 MiB receive buffer, plus quick-ack.
                    match direction {
                        BulkStreamDirection::Receive | BulkStreamDirection::Both => {
                            try_optimize!(s.set_recv_buffer_size(16 * 1024 * 1024))?;
                            #[cfg(any(
                                target_os = "android",
                                target_os = "fuchsia",
                                target_os = "linux"
                            ))]
                            try_optimize!(s.set_quickack(true))?;
                        }
                        BulkStreamDirection::Send => {}
                    }
                }
            }
            Ok(())
        };

        with_socket2(self, &mut with_socket2_fn)
    }
}
234
/// Every sized [`Stream`] gets `optimize_for` through this blanket implementation.
#[cfg(feature = "optimization")]
impl<T> StreamOptimizationExt for T where T: Stream + Sized {}
237
238pub trait PeekableStream: Stream {
240 #[cfg(feature = "tokio")]
241 fn poll_peek(
242 self: Pin<&mut Self>,
243 cx: &mut Context<'_>,
244 buf: &mut tokio::io::ReadBuf,
245 ) -> Poll<std::io::Result<usize>>;
246 #[cfg(feature = "tokio")]
247 fn peek(self: Pin<&mut Self>, buf: &mut [u8]) -> impl Future<Output = std::io::Result<usize>> {
248 async {
249 let mut this = self;
250 std::future::poll_fn(move |cx| this.as_mut().poll_peek(cx, &mut ReadBuf::new(buf)))
251 .await
252 }
253 }
254}
255
/// The initial decrypted bytes read from a stream by a preview upgrade.
///
/// Derefs to `[u8]` and can be compared with byte arrays and slices.
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub struct Preview {
    // Inline capacity is `DEFAULT_PREVIEW_BUFFER_SIZE`; larger previews spill to the heap.
    buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
}
262
263impl Preview {
264 pub(crate) fn new(
265 buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
266 ) -> Self {
267 Self { buffer }
268 }
269}
270
271impl Deref for Preview {
272 type Target = [u8];
273 fn deref(&self) -> &Self::Target {
274 &self.buffer
275 }
276}
277
278impl AsRef<[u8]> for Preview {
279 fn as_ref(&self) -> &[u8] {
280 &self.buffer
281 }
282}
283
284impl<const N: usize> PartialEq<[u8; N]> for Preview {
285 fn eq(&self, other: &[u8; N]) -> bool {
286 self.buffer.as_slice() == other
287 }
288}
289
290impl<const N: usize> PartialEq<&[u8; N]> for Preview {
291 fn eq(&self, other: &&[u8; N]) -> bool {
292 self.buffer.as_slice() == *other
293 }
294}
295
296impl PartialEq<[u8]> for Preview {
297 fn eq(&self, other: &[u8]) -> bool {
298 self.buffer.as_slice() == other
299 }
300}
301
/// Options for `StreamUpgrade::secure_upgrade_preview`.
#[derive(Debug, Clone, Copy)]
pub struct PreviewConfiguration {
    /// Number of bytes to preview. `UpgradableStream` reads exactly this many
    /// decrypted bytes before completing the preview upgrade.
    pub max_preview_bytes: NonZeroUsize,
    /// Presumably an upper bound on the time spent collecting the preview.
    /// NOTE(review): `UpgradableStream::secure_upgrade_preview` does not consult
    /// this field — confirm whether a timeout is expected to be enforced.
    pub max_preview_duration: std::time::Duration,
}
310
311impl Default for PreviewConfiguration {
312 fn default() -> Self {
313 Self {
314 max_preview_bytes: NonZeroUsize::new(DEFAULT_PREVIEW_BUFFER_SIZE as usize).unwrap(),
315 max_preview_duration: std::time::Duration::from_secs(10),
316 }
317 }
318}
319
/// Per-stream behaviour flags for `UpgradableStream`; all flags default to off.
#[derive(Default, Debug)]
struct UpgradableStreamOptions {
    /// When set, reads that fail with `ErrorKind::UnexpectedEof` are reported as a
    /// clean end-of-stream instead.
    ignore_missing_close_notify: bool,
}
324
/// A stream that starts out plaintext and can be upgraded in place to TLS using
/// the driver `D` (`Ssl` by default).
#[allow(private_bounds)]
#[derive(derive_more::Debug, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
    // Current plaintext/TLS state. The `derive_io` attributes presumably forward
    // writes and socket-descriptor access to it — confirm against derive_io docs.
    #[write]
    #[descriptor]
    inner: UpgradableStreamInner<S, D>,
    options: UpgradableStreamOptions,
}
333
334#[allow(private_bounds)]
335impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
336 #[inline(always)]
337 pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
338 UpgradableStream {
339 inner: UpgradableStreamInner::BaseClient(base, config),
340 options: Default::default(),
341 }
342 }
343
344 #[inline(always)]
345 pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
346 UpgradableStream {
347 inner: UpgradableStreamInner::BaseServer(base, config),
348 options: Default::default(),
349 }
350 }
351
352 #[inline(always)]
353 pub(crate) fn new_server_preview(
354 base: RewindStream<S>,
355 config: Option<TlsServerParameterProvider>,
356 ) -> Self {
357 UpgradableStream {
358 inner: UpgradableStreamInner::BaseServerPreview(base, config),
359 options: Default::default(),
360 }
361 }
362
363 pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
365 match self.inner {
366 UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
367 UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
368 UpgradableStreamInner::BaseServerPreview(base, _) => Ok(Box::new(base)),
369 UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
370 UpgradableStreamInner::UpgradedPreview(upgraded, _) => Ok(Box::new(upgraded)),
371 }
372 }
373
374 pub fn handshake(&self) -> Option<&TlsHandshake> {
375 match &self.inner {
376 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
377 _ => None,
378 }
379 }
380
381 pub fn ignore_missing_close_notify(&mut self) {
382 self.options.ignore_missing_close_notify = true;
383 }
384
385 pub fn unclean_shutdown(self) -> Result<(), Self> {
388 match self.inner {
389 UpgradableStreamInner::BaseClient(..) => Ok(()),
390 UpgradableStreamInner::BaseServer(..) => Ok(()),
391 UpgradableStreamInner::BaseServerPreview(..) => Ok(()),
392 UpgradableStreamInner::Upgraded(upgraded, cfg) => {
393 if let Err(e) = D::unclean_shutdown(upgraded) {
394 Err(Self {
395 inner: UpgradableStreamInner::Upgraded(e, cfg),
396 options: self.options,
397 })
398 } else {
399 Ok(())
400 }
401 }
402 UpgradableStreamInner::UpgradedPreview(upgraded, cfg) => {
403 let (stm, buf) = upgraded.into_inner();
404 if let Err(e) = D::unclean_shutdown(stm) {
405 Err(Self {
406 inner: UpgradableStreamInner::UpgradedPreview(
407 RewindStream {
408 buffer: buf,
409 inner: e,
410 },
411 cfg,
412 ),
413 options: self.options,
414 })
415 } else {
416 Ok(())
417 }
418 }
419 }
420 }
421}
422
423impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
424 fn secure_upgrade(self) -> impl Future<Output = Result<Self, SslError>> + Send {
425 async move {
426 let (upgraded, handshake) = match self.inner {
427 UpgradableStreamInner::BaseClient(base, config) => {
428 let Some(config) = config else {
429 return Err(SslError::SslUnsupported);
430 };
431 D::upgrade_client(config, base).await?
432 }
433 UpgradableStreamInner::BaseServer(base, config) => {
434 let Some(config) = config else {
435 return Err(SslError::SslUnsupported);
436 };
437 D::upgrade_server(config, base).await?
438 }
439 UpgradableStreamInner::BaseServerPreview(base, config) => {
440 let Some(config) = config else {
441 return Err(SslError::SslUnsupported);
442 };
443 D::upgrade_server(config, base).await?
444 }
445 _ => {
446 return Err(SslError::SslAlreadyUpgraded);
447 }
448 };
449 Ok(Self {
450 inner: UpgradableStreamInner::Upgraded(upgraded, handshake),
451 options: self.options,
452 })
453 }
454 }
455
456 fn secure_upgrade_preview(
457 self,
458 options: PreviewConfiguration,
459 ) -> impl Future<Output = Result<(Preview, Self), SslError>> + Send {
460 async move {
461 let (mut upgraded, handshake) = match self.inner {
462 UpgradableStreamInner::BaseClient(base, config) => {
463 let Some(config) = config else {
464 return Err(SslError::SslUnsupported);
465 };
466 D::upgrade_client(config, base).await?
467 }
468 UpgradableStreamInner::BaseServer(base, config) => {
469 let Some(config) = config else {
470 return Err(SslError::SslUnsupported);
471 };
472 D::upgrade_server(config, base).await?
473 }
474 UpgradableStreamInner::BaseServerPreview(base, config) => {
475 let Some(config) = config else {
476 return Err(SslError::SslUnsupported);
477 };
478 D::upgrade_server(config, base).await?
479 }
480 _ => {
481 return Err(SslError::SslAlreadyUpgraded);
482 }
483 };
484 let mut buffer = smallvec::SmallVec::with_capacity(options.max_preview_bytes.get());
485 buffer.resize(options.max_preview_bytes.get(), 0);
486 upgraded.read_exact(&mut buffer).await?;
487 let mut rewind = RewindStream::new(upgraded);
488 rewind.rewind(&buffer);
489 Ok((
490 Preview { buffer },
491 Self {
492 inner: UpgradableStreamInner::UpgradedPreview(rewind, handshake),
493 options: self.options,
494 },
495 ))
496 }
497 }
498
499 fn handshake(&self) -> Option<&TlsHandshake> {
500 match &self.inner {
501 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
502 _ => None,
503 }
504 }
505}
506
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Read from whichever stream (plaintext or TLS) is currently active.
    ///
    /// When `ignore_missing_close_notify` is enabled, an `UnexpectedEof` error from
    /// the active stream is converted into a clean end-of-stream.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let tolerate_missing_close_notify = this.options.ignore_missing_close_notify;
        let outcome = match &mut this.inner {
            UpgradableStreamInner::BaseClient(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::BaseServerPreview(base, _) => Pin::new(base).poll_read(cx, buf),
            UpgradableStreamInner::Upgraded(upgraded, _) => Pin::new(upgraded).poll_read(cx, buf),
            UpgradableStreamInner::UpgradedPreview(upgraded, _) => {
                Pin::new(upgraded).poll_read(cx, buf)
            }
        };
        match outcome {
            std::task::Poll::Ready(Err(err))
                if tolerate_missing_close_notify
                    && err.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                std::task::Poll::Ready(Ok(()))
            }
            other => other,
        }
    }
}
535
536impl<S: Stream, D: TlsDriver> LocalAddress for UpgradableStream<S, D> {
537 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
538 self.inner
539 .with_inner_metadata(|inner| inner.local_address())
540 }
541}
542
543impl<S: Stream, D: TlsDriver> RemoteAddress for UpgradableStream<S, D> {
544 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
545 self.inner
546 .with_inner_metadata(|inner| inner.remote_address())
547 }
548}
549
550impl<S: Stream, D: TlsDriver> StreamMetadata for UpgradableStream<S, D> {
551 fn transport(&self) -> Transport {
552 self.inner.with_inner_metadata(|inner| inner.transport())
553 }
554}
555
/// The current state of an `UpgradableStream`: plaintext (client, server, or
/// server with buffered bytes) or TLS (with or without preview bytes buffered).
#[derive(
    derive_more::Debug, derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor,
)]
enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
    /// Plaintext client stream and its TLS client parameters (`None` = upgrade unsupported).
    #[debug("BaseClient(..)")]
    BaseClient(
        #[read]
        #[write]
        #[descriptor]
        S,
        Option<D::ClientParams>,
    ),
    /// Plaintext server stream and its TLS server parameters (`None` = upgrade unsupported).
    #[debug("BaseServer(..)")]
    BaseServer(
        #[read]
        #[write]
        #[descriptor]
        S,
        Option<TlsServerParameterProvider>,
    ),
    /// Plaintext server stream whose buffered bytes are read before the socket's data.
    #[debug("Preview(..)")]
    BaseServerPreview(
        #[read]
        #[write]
        #[descriptor]
        RewindStream<S>,
        Option<TlsServerParameterProvider>,
    ),
    /// TLS stream after `secure_upgrade`, with its handshake result.
    #[debug("Upgraded(..)")]
    Upgraded(
        #[read]
        #[write]
        #[descriptor]
        D::Stream,
        TlsHandshake,
    ),
    /// TLS stream after `secure_upgrade_preview`; the previewed bytes are buffered
    /// in the `RewindStream`. Its Debug label is the same as `Upgraded`'s.
    #[debug("Upgraded(..)")]
    UpgradedPreview(
        #[read]
        #[write]
        #[descriptor]
        RewindStream<D::Stream>,
        TlsHandshake,
    ),
}
601
602impl<S: Stream, D: TlsDriver> UpgradableStreamInner<S, D> {
603 #[inline(always)]
604 fn with_inner_metadata<T>(&self, f: impl FnOnce(&dyn StreamMetadata) -> T) -> T {
605 match self {
606 UpgradableStreamInner::BaseClient(base, _) => f(base),
607 UpgradableStreamInner::BaseServer(base, _) => f(base),
608 UpgradableStreamInner::BaseServerPreview(base, _) => f(base),
609 UpgradableStreamInner::Upgraded(upgraded, _) => f(upgraded),
610 UpgradableStreamInner::UpgradedPreview(upgraded, _) => f(upgraded),
611 }
612 }
613
614 #[inline(always)]
615 fn as_inner_handle(&self) -> &dyn AsHandle {
616 match self {
617 UpgradableStreamInner::BaseClient(base, _) => base,
618 UpgradableStreamInner::BaseServer(base, _) => base,
619 UpgradableStreamInner::BaseServerPreview(base, _) => base,
620 UpgradableStreamInner::Upgraded(upgraded, _) => upgraded,
621 UpgradableStreamInner::UpgradedPreview(upgraded, _) => upgraded,
622 }
623 }
624}
625
626impl<S: Stream, D: TlsDriver> AsHandle for UpgradableStream<S, D> {
627 #[cfg(windows)]
628 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
629 self.inner.as_inner_handle().as_handle()
630 }
631
632 #[cfg(unix)]
633 fn as_fd(&self) -> std::os::fd::BorrowedFd {
634 self.inner.as_inner_handle().as_fd()
635 }
636}
637
/// A stream that can be handed bytes to return again from subsequent reads.
pub trait Rewindable {
    /// Queue `data` to be read again.
    fn rewind(&mut self, data: &[u8]) -> std::io::Result<()>;
}
641
/// Wraps a stream with a buffer of bytes that reads return before reading from
/// `inner`.
#[derive(derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
pub(crate) struct RewindStream<S> {
    // Bytes still to be returned by reads, front first.
    buffer: Vec<u8>,
    // Presumably writes and descriptor access go straight to `inner` via the
    // `derive_io` attributes — confirm against derive_io docs.
    #[write]
    #[descriptor]
    inner: S,
}
650
651impl<S> RewindStream<S> {
652 pub fn new(inner: S) -> Self {
653 RewindStream {
654 buffer: Vec::new(),
655 inner,
656 }
657 }
658
659 pub fn rewind(&mut self, data: &[u8]) {
660 self.buffer.extend_from_slice(data);
661 }
662
663 pub fn into_inner(self) -> (S, Vec<u8>) {
664 (self.inner, self.buffer)
665 }
666}
667
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Return rewound bytes first, then read from the inner stream.
    ///
    /// While the rewind buffer is non-empty, each poll copies as much of it as fits
    /// into `buf` and completes immediately without touching the inner stream.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let to_read = std::cmp::min(buf.remaining(), this.buffer.len());
        // Copy directly from the buffer, then drop the consumed prefix. This avoids
        // allocating a temporary `Vec` on every poll.
        buf.put_slice(&this.buffer[..to_read]);
        this.buffer.drain(..to_read);
        Poll::Ready(Ok(()))
    }
}
686
687impl<S: Stream> Rewindable for RewindStream<S> {
688 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
689 self.rewind(bytes);
690 Ok(())
691 }
692}
693
694impl<S: LocalAddress> LocalAddress for RewindStream<S> {
695 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
696 self.inner.local_address()
697 }
698}
699
700impl<S: RemoteAddress> RemoteAddress for RewindStream<S> {
701 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
702 self.inner.remote_address()
703 }
704}
705
706impl<S: PeerCred> PeerCred for RewindStream<S> {
707 #[cfg(all(unix, feature = "tokio"))]
708 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
709 self.inner.peer_cred()
710 }
711}
712
713impl<S: StreamMetadata> StreamMetadata for RewindStream<S> {
714 fn transport(&self) -> Transport {
715 self.inner.transport()
716 }
717}
718
719impl<S: PeekableStream> PeekableStream for RewindStream<S> {
720 fn poll_peek(
721 mut self: Pin<&mut Self>,
722 cx: &mut Context<'_>,
723 buf: &mut ReadBuf<'_>,
724 ) -> Poll<std::io::Result<usize>> {
725 if !self.buffer.is_empty() {
726 let to_read = std::cmp::min(buf.remaining(), self.buffer.len());
727 buf.put_slice(&self.buffer[..to_read]);
728 Poll::Ready(Ok(to_read))
729 } else {
730 Pin::new(&mut self.inner).poll_peek(cx, buf)
731 }
732 }
733}
734
735impl<S: Stream + AsHandle> AsHandle for RewindStream<S> {
736 #[cfg(windows)]
737 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
738 self.inner.as_handle()
739 }
740
741 #[cfg(unix)]
742 fn as_fd(&self) -> std::os::fd::BorrowedFd {
743 self.inner.as_fd()
744 }
745}
746
747impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
748where
749 D::Stream: Rewindable,
750{
751 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
752 match &mut self.inner {
753 UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
754 UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
755 UpgradableStreamInner::BaseServerPreview(stm, _) => Ok(stm.rewind(bytes)),
756 UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
757 UpgradableStreamInner::UpgradedPreview(stm, _) => Ok(stm.rewind(bytes)),
758 }
759 }
760}
761
762impl<S: PeekableStream, D: TlsDriver> PeekableStream for UpgradableStream<S, D>
763where
764 D::Stream: PeekableStream,
765{
766 #[cfg(feature = "tokio")]
767 fn poll_peek(
768 self: Pin<&mut Self>,
769 cx: &mut Context<'_>,
770 buf: &mut tokio::io::ReadBuf,
771 ) -> Poll<std::io::Result<usize>> {
772 match &mut self.get_mut().inner {
773 UpgradableStreamInner::BaseClient(base, _) => Pin::new(base).poll_peek(cx, buf),
774 UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_peek(cx, buf),
775 UpgradableStreamInner::BaseServerPreview(base, _) => Pin::new(base).poll_peek(cx, buf),
776 UpgradableStreamInner::Upgraded(upgraded, _) => Pin::new(upgraded).poll_peek(cx, buf),
777 UpgradableStreamInner::UpgradedPreview(upgraded, _) => {
778 Pin::new(upgraded).poll_peek(cx, buf)
779 }
780 }
781 }
782}
783
784impl<S: PeerCred + Stream, D: TlsDriver> PeerCred for UpgradableStream<S, D> {
785 #[cfg(all(unix, feature = "tokio"))]
786 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
787 self.inner.with_inner_metadata(|inner| inner.peer_cred())
788 }
789}