1#[cfg(feature = "tokio")]
2use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
3
4#[cfg(feature = "tokio")]
5use std::{
6 any::Any,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use std::{future::Future, num::NonZeroUsize, ops::Deref};
11
12use crate::{
13 LocalAddress, PeerCred, RemoteAddress, ResolvedTarget, Ssl, SslError, StreamMetadata,
14 TlsDriver, TlsHandshake, TlsServerParameterProvider, Transport, DEFAULT_PREVIEW_BUFFER_SIZE,
15};
16
/// Platform-neutral access to the OS socket behind a stream (Unix flavour: a file descriptor).
#[cfg(unix)]
pub trait AsHandle {
    /// Borrows the underlying file descriptor for as long as `self` stays borrowed.
    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_>;
}

/// Platform-neutral access to the OS socket behind a stream (Windows flavour: a socket handle).
#[cfg(windows)]
pub trait AsHandle {
    /// Borrows the underlying socket for as long as `self` stays borrowed.
    fn as_handle(&self) -> std::os::windows::io::BorrowedSocket<'_>;
}
28
/// A bidirectional async byte stream that also exposes connection metadata and its OS handle.
#[cfg(feature = "tokio")]
pub trait Stream:
    tokio::io::AsyncRead + tokio::io::AsyncWrite + StreamMetadata + Send + Unpin + AsHandle + 'static
{
    /// Attempts to recover the concrete stream type `S` from `self`.
    ///
    /// Succeeds when `Self` is exactly `S`, or when `Self` is `Box<S>` (the box is unwrapped).
    ///
    /// # Errors
    ///
    /// On a type mismatch the original value is handed back unchanged.
    fn downcast<S: Stream + 'static>(self) -> Result<S, Self>
    where
        Self: Sized + 'static,
    {
        // `Any` only works through references, so park `self` in an `Option` that can be
        // `take()`n once the concrete type has been identified.
        let mut holder = Some(self);
        let any = &mut holder as &mut dyn Any;
        if let Some(slot) = any.downcast_mut::<Option<S>>() {
            return Ok(slot.take().expect("holder was populated above"));
        }
        if let Some(slot) = any.downcast_mut::<Option<Box<S>>>() {
            return Ok(*slot.take().expect("holder was populated above"));
        }
        Err(holder.take().expect("holder was populated above"))
    }

    /// Type-erases `self` into a `Box<dyn Stream + Send>`.
    ///
    /// If `self` already is such a box, it is returned as-is instead of being boxed a second time.
    fn boxed(self) -> Box<dyn Stream + Send>
    where
        Self: Sized + 'static,
    {
        let mut holder = Some(self);
        let any = &mut holder as &mut dyn Any;
        // Exact return type: hand it back directly. Previously only `Box<dyn Stream>` was
        // recognised, so an already-boxed `Box<dyn Stream + Send>` was wrapped in another box.
        if let Some(slot) = any.downcast_mut::<Option<Box<dyn Stream + Send>>>() {
            return slot.take().expect("holder was populated above");
        }
        if let Some(slot) = any.downcast_mut::<Option<Box<dyn Stream>>>() {
            return slot.take().expect("holder was populated above");
        }
        Box::new(holder.take().expect("holder was populated above"))
    }
}
64
/// Any type with the required I/O, metadata and OS-handle capabilities is automatically a
/// [`Stream`].
#[cfg(feature = "tokio")]
impl<T> Stream for T where
    T: 'static
        + Send
        + Unpin
        + AsHandle
        + StreamMetadata
        + tokio::io::AsyncRead
        + tokio::io::AsyncWrite
{
}
76
77impl PeerCred for Box<dyn Stream + Send> {
79 #[cfg(all(unix, feature = "tokio"))]
80 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
81 self.as_ref().peer_cred()
82 }
83}
84
85impl LocalAddress for Box<dyn Stream + Send> {
86 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
87 self.as_ref().local_address()
88 }
89}
90
91impl RemoteAddress for Box<dyn Stream + Send> {
92 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
93 self.as_ref().remote_address()
94 }
95}
96
97impl StreamMetadata for Box<dyn Stream + Send> {
98 fn transport(&self) -> Transport {
99 self.as_ref().transport()
100 }
101}
102
103#[cfg(not(feature = "tokio"))]
104pub trait Stream: StreamMetadata + Unpin + AsHandle + 'static {}
105#[cfg(not(feature = "tokio"))]
106impl<S: Stream, D: TlsDriver> Stream for UpgradableStream<S, D> {}
107#[cfg(not(feature = "tokio"))]
108impl Stream for () {}
109
110pub trait StreamUpgrade: Stream + Sized {
112 fn secure_upgrade(self) -> impl Future<Output = Result<Self, SslError>> + Send;
114 fn secure_upgrade_preview(
116 self,
117 options: PreviewConfiguration,
118 ) -> impl Future<Output = Result<(Preview, Self), SslError>> + Send;
119 fn handshake(&self) -> Option<&TlsHandshake>;
121}
122
/// Traffic profile that `StreamOptimizationExt::optimize_for` uses to choose socket options.
#[cfg(feature = "optimization")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum StreamOptimization {
    /// Default profile: Nagle's algorithm stays enabled.
    #[default]
    General,
    /// Latency-sensitive request/response traffic: Nagle's algorithm is disabled.
    Interactive,
    /// High-throughput transfer in the given direction(s): socket buffers are enlarged.
    BulkStreaming(BulkStreamDirection),
}

/// The side(s) of a `StreamOptimization::BulkStreaming` connection that carry bulk data.
#[cfg(feature = "optimization")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum BulkStreamDirection {
    /// Only outgoing data is bulk.
    Send,
    /// Only incoming data is bulk.
    Receive,
    /// Both directions carry bulk data (the default).
    #[default]
    Both,
}
143
/// Runs `f` against a non-owning `socket2::SockRef` view of `stream`'s OS socket.
///
/// The socket is only borrowed for the duration of the call; ownership stays with `stream`.
#[cfg(any(feature = "optimization", feature = "keepalive"))]
fn with_socket2<S: AsHandle + Sized>(
    stream: &S,
    f: &mut dyn for<'a> FnMut(socket2::SockRef<'a>) -> Result<(), std::io::Error>,
) -> Result<(), std::io::Error> {
    #[cfg(unix)]
    let handle = stream.as_fd();
    #[cfg(windows)]
    let handle = stream.as_handle();
    f(socket2::SockRef::from(&handle))
}
156
/// Extension for tuning the OS socket behind a [`Stream`] to a traffic profile.
#[cfg(feature = "optimization")]
pub trait StreamOptimizationExt: Stream + Sized {
    /// Applies socket options suited to `optimization` to the underlying socket.
    ///
    /// On Unix targets, Unix-domain sockets are left untouched and `Ok(())` is returned. Every
    /// other socket is first reset to a baseline:
    ///
    /// * on Android/Fuchsia/Linux, `TCP_CORK`, `TCP_QUICKACK` and thin linear timeouts are
    ///   switched off;
    /// * send and receive buffers are set to 256 KiB.
    ///
    /// The profile-specific options are then layered on top:
    ///
    /// * `General`: `nodelay = false`; thin linear timeouts on (Linux family).
    /// * `Interactive`: `nodelay = true`; thin linear timeouts on (Linux family).
    /// * `BulkStreaming`: `nodelay = false`. A sending side gets a 16 MiB send buffer plus
    ///   `TCP_CORK`; a receiving side gets a 16 MiB receive buffer plus `TCP_QUICKACK` (cork and
    ///   quickack on the Linux family only).
    ///
    /// # Errors
    ///
    /// Stops at the first option that fails and returns its error. The error message is
    /// prefixed with the name of the failing setter.
    #[cfg(feature = "optimization")]
    fn optimize_for(&mut self, optimization: StreamOptimization) -> Result<(), std::io::Error> {
        // Calls `$s.$method(..)` and tags any error with the setter's name for diagnostics.
        macro_rules! try_optimize {
            ( $s:ident . $method:ident ( $($args:tt)* ) ) => {
                $s.$method($($args)*).map_err(|e: std::io::Error| {
                    std::io::Error::new(e.kind(), format!("{}: {}", stringify!($method), e))
                })
            };
        }

        #[cfg(unix)]
        if self.transport() == Transport::Unix {
            return Ok(());
        }

        let mut apply = move |sock: socket2::SockRef<'_>| {
            // Baseline shared by every profile; the order of these calls is intentional.
            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
            {
                try_optimize!(sock.set_cork(false))?;
                try_optimize!(sock.set_quickack(false))?;
                try_optimize!(sock.set_thin_linear_timeouts(false))?;
            }
            try_optimize!(sock.set_send_buffer_size(256 * 1024))?;
            try_optimize!(sock.set_recv_buffer_size(256 * 1024))?;

            match optimization {
                StreamOptimization::General | StreamOptimization::Interactive => {
                    // Only interactive traffic turns Nagle's algorithm off.
                    let nodelay = optimization == StreamOptimization::Interactive;
                    try_optimize!(sock.set_nodelay(nodelay))?;
                    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
                    try_optimize!(sock.set_thin_linear_timeouts(true))?;
                }
                StreamOptimization::BulkStreaming(direction) => {
                    try_optimize!(sock.set_nodelay(false))?;

                    let tunes_send = matches!(
                        direction,
                        BulkStreamDirection::Send | BulkStreamDirection::Both
                    );
                    let tunes_recv = matches!(
                        direction,
                        BulkStreamDirection::Receive | BulkStreamDirection::Both
                    );

                    if tunes_send {
                        try_optimize!(sock.set_send_buffer_size(16 * 1024 * 1024))?;
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        try_optimize!(sock.set_cork(true))?;
                    }

                    if tunes_recv {
                        try_optimize!(sock.set_recv_buffer_size(16 * 1024 * 1024))?;
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        try_optimize!(sock.set_quickack(true))?;
                    }
                }
            }
            Ok(())
        };

        with_socket2(self, &mut apply)
    }
}
234
/// Every sized [`Stream`] gets `optimize_for` for free.
#[cfg(feature = "optimization")]
impl<S> StreamOptimizationExt for S where S: Stream + Sized {}
237
238pub trait PeekableStream: Stream {
240 #[cfg(feature = "tokio")]
241 fn poll_peek(
242 self: Pin<&mut Self>,
243 cx: &mut Context<'_>,
244 buf: &mut tokio::io::ReadBuf,
245 ) -> Poll<std::io::Result<usize>>;
246 #[cfg(feature = "tokio")]
247 fn peek(self: Pin<&mut Self>, buf: &mut [u8]) -> impl Future<Output = std::io::Result<usize>> {
248 async {
249 let mut this = self;
250 std::future::poll_fn(move |cx| this.as_mut().poll_peek(cx, &mut ReadBuf::new(buf)))
251 .await
252 }
253 }
254}
255
256#[derive(Debug, Clone, PartialEq, Eq)]
258#[must_use]
259pub struct Preview {
260 buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
261}
262
263impl Preview {
264 pub(crate) fn new(
265 buffer: smallvec::SmallVec<[u8; DEFAULT_PREVIEW_BUFFER_SIZE as usize]>,
266 ) -> Self {
267 Self { buffer }
268 }
269}
270
271impl Deref for Preview {
272 type Target = [u8];
273 fn deref(&self) -> &Self::Target {
274 &self.buffer
275 }
276}
277
278impl AsRef<[u8]> for Preview {
279 fn as_ref(&self) -> &[u8] {
280 &self.buffer
281 }
282}
283
284impl<const N: usize> PartialEq<[u8; N]> for Preview {
285 fn eq(&self, other: &[u8; N]) -> bool {
286 self.buffer.as_slice() == other
287 }
288}
289
290impl<const N: usize> PartialEq<&[u8; N]> for Preview {
291 fn eq(&self, other: &&[u8; N]) -> bool {
292 self.buffer.as_slice() == *other
293 }
294}
295
296impl PartialEq<[u8]> for Preview {
297 fn eq(&self, other: &[u8]) -> bool {
298 self.buffer.as_slice() == other
299 }
300}
301
302#[derive(Debug, Clone, Copy)]
304pub struct PreviewConfiguration {
305 pub max_preview_bytes: NonZeroUsize,
307 pub max_preview_duration: std::time::Duration,
309}
310
311impl Default for PreviewConfiguration {
312 fn default() -> Self {
313 Self {
314 max_preview_bytes: NonZeroUsize::new(DEFAULT_PREVIEW_BUFFER_SIZE as usize).unwrap(),
315 max_preview_duration: std::time::Duration::from_secs(10),
316 }
317 }
318}
319
/// Behavioural toggles for an `UpgradableStream`; all of them are off by default.
#[derive(Debug, Default)]
struct UpgradableStreamOptions {
    /// When set, an `UnexpectedEof` from the inner stream is reported as a clean EOF.
    ignore_missing_close_notify: bool,
}
324
325#[allow(private_bounds)]
326#[derive(derive_more::Debug, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
327pub struct UpgradableStream<S: Stream, D: TlsDriver = Ssl> {
328 #[write]
329 #[descriptor]
330 inner: UpgradableStreamInner<S, D>,
331 options: UpgradableStreamOptions,
332}
333
334#[allow(private_bounds)]
335impl<S: Stream, D: TlsDriver> UpgradableStream<S, D> {
336 #[inline(always)]
337 pub(crate) fn new_client(base: S, config: Option<D::ClientParams>) -> Self {
338 UpgradableStream {
339 inner: UpgradableStreamInner::BaseClient(base, config),
340 options: Default::default(),
341 }
342 }
343
344 #[inline(always)]
345 pub(crate) fn new_server(base: S, config: Option<TlsServerParameterProvider>) -> Self {
346 UpgradableStream {
347 inner: UpgradableStreamInner::BaseServer(base, config),
348 options: Default::default(),
349 }
350 }
351
352 #[inline(always)]
353 pub(crate) fn new_server_preview(
354 base: RewindStream<S>,
355 config: Option<TlsServerParameterProvider>,
356 ) -> Self {
357 UpgradableStream {
358 inner: UpgradableStreamInner::BaseServerPreview(base, config),
359 options: Default::default(),
360 }
361 }
362
363 pub fn into_boxed(self) -> Result<Box<dyn Stream>, Self> {
365 match self.inner {
366 UpgradableStreamInner::BaseClient(base, _) => Ok(Box::new(base)),
367 UpgradableStreamInner::BaseServer(base, _) => Ok(Box::new(base)),
368 UpgradableStreamInner::BaseServerPreview(base, _) => Ok(Box::new(base)),
369 UpgradableStreamInner::Upgraded(upgraded, _) => Ok(Box::new(upgraded)),
370 UpgradableStreamInner::UpgradedPreview(upgraded, _) => Ok(Box::new(upgraded)),
371 }
372 }
373
374 pub fn handshake(&self) -> Option<&TlsHandshake> {
375 match &self.inner {
376 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
377 _ => None,
378 }
379 }
380
381 pub fn ignore_missing_close_notify(&mut self) {
382 self.options.ignore_missing_close_notify = true;
383 }
384
385 pub fn unclean_shutdown(self) -> Result<(), Self> {
388 match self.inner {
389 UpgradableStreamInner::BaseClient(..) => Ok(()),
390 UpgradableStreamInner::BaseServer(..) => Ok(()),
391 UpgradableStreamInner::BaseServerPreview(..) => Ok(()),
392 UpgradableStreamInner::Upgraded(upgraded, cfg) => {
393 if let Err(e) = D::unclean_shutdown(upgraded) {
394 Err(Self {
395 inner: UpgradableStreamInner::Upgraded(e, cfg),
396 options: self.options,
397 })
398 } else {
399 Ok(())
400 }
401 }
402 UpgradableStreamInner::UpgradedPreview(upgraded, cfg) => {
403 let (stm, buf) = upgraded.into_inner();
404 if let Err(e) = D::unclean_shutdown(stm) {
405 Err(Self {
406 inner: UpgradableStreamInner::UpgradedPreview(
407 RewindStream {
408 buffer: buf,
409 inner: e,
410 },
411 cfg,
412 ),
413 options: self.options,
414 })
415 } else {
416 Ok(())
417 }
418 }
419 }
420 }
421}
422
423impl<S: Stream, D: TlsDriver> StreamUpgrade for UpgradableStream<S, D> {
424 async fn secure_upgrade(self) -> Result<Self, SslError> {
425 let (upgraded, handshake) = match self.inner {
426 UpgradableStreamInner::BaseClient(base, config) => {
427 let Some(config) = config else {
428 return Err(SslError::SslUnsupported);
429 };
430 D::upgrade_client(config, base).await?
431 }
432 UpgradableStreamInner::BaseServer(base, config) => {
433 let Some(config) = config else {
434 return Err(SslError::SslUnsupported);
435 };
436 D::upgrade_server(config, base).await?
437 }
438 UpgradableStreamInner::BaseServerPreview(base, config) => {
439 let Some(config) = config else {
440 return Err(SslError::SslUnsupported);
441 };
442 D::upgrade_server(config, base).await?
443 }
444 _ => {
445 return Err(SslError::SslAlreadyUpgraded);
446 }
447 };
448 Ok(Self {
449 inner: UpgradableStreamInner::Upgraded(upgraded, handshake),
450 options: self.options,
451 })
452 }
453
454 async fn secure_upgrade_preview(
455 self,
456 options: PreviewConfiguration,
457 ) -> Result<(Preview, Self), SslError> {
458 let (mut upgraded, handshake) = match self.inner {
459 UpgradableStreamInner::BaseClient(base, config) => {
460 let Some(config) = config else {
461 return Err(SslError::SslUnsupported);
462 };
463 D::upgrade_client(config, base).await?
464 }
465 UpgradableStreamInner::BaseServer(base, config) => {
466 let Some(config) = config else {
467 return Err(SslError::SslUnsupported);
468 };
469 D::upgrade_server(config, base).await?
470 }
471 UpgradableStreamInner::BaseServerPreview(base, config) => {
472 let Some(config) = config else {
473 return Err(SslError::SslUnsupported);
474 };
475 D::upgrade_server(config, base).await?
476 }
477 _ => {
478 return Err(SslError::SslAlreadyUpgraded);
479 }
480 };
481 let mut buffer = smallvec::SmallVec::with_capacity(options.max_preview_bytes.get());
482 buffer.resize(options.max_preview_bytes.get(), 0);
483 upgraded.read_exact(&mut buffer).await?;
484 let mut rewind = RewindStream::new(upgraded);
485 rewind.rewind(&buffer);
486 Ok((
487 Preview { buffer },
488 Self {
489 inner: UpgradableStreamInner::UpgradedPreview(rewind, handshake),
490 options: self.options,
491 },
492 ))
493 }
494
495 fn handshake(&self) -> Option<&TlsHandshake> {
496 match &self.inner {
497 UpgradableStreamInner::Upgraded(_, handshake) => Some(handshake),
498 _ => None,
499 }
500 }
501}
502
#[cfg(feature = "tokio")]
impl<S: Stream, D: TlsDriver> tokio::io::AsyncRead for UpgradableStream<S, D> {
    /// Reads from whichever stream currently backs this state.
    ///
    /// After `ignore_missing_close_notify` has been called, an `UnexpectedEof` error from the
    /// inner stream is reported as `Ok(())`, i.e. a clean end-of-stream.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let treat_eof_as_close = this.options.ignore_missing_close_notify;
        let polled = match &mut this.inner {
            UpgradableStreamInner::BaseClient(stream, _)
            | UpgradableStreamInner::BaseServer(stream, _) => Pin::new(stream).poll_read(cx, buf),
            UpgradableStreamInner::BaseServerPreview(stream, _) => {
                Pin::new(stream).poll_read(cx, buf)
            }
            UpgradableStreamInner::Upgraded(stream, _) => Pin::new(stream).poll_read(cx, buf),
            UpgradableStreamInner::UpgradedPreview(stream, _) => {
                Pin::new(stream).poll_read(cx, buf)
            }
        };
        match polled {
            std::task::Poll::Ready(Err(err))
                if treat_eof_as_close && err.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                std::task::Poll::Ready(Ok(()))
            }
            other => other,
        }
    }
}
530
531impl<S: Stream, D: TlsDriver> LocalAddress for UpgradableStream<S, D> {
532 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
533 self.inner
534 .with_inner_metadata(|inner| inner.local_address())
535 }
536}
537
538impl<S: Stream, D: TlsDriver> RemoteAddress for UpgradableStream<S, D> {
539 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
540 self.inner
541 .with_inner_metadata(|inner| inner.remote_address())
542 }
543}
544
545impl<S: Stream, D: TlsDriver> StreamMetadata for UpgradableStream<S, D> {
546 fn transport(&self) -> Transport {
547 self.inner.with_inner_metadata(|inner| inner.transport())
548 }
549}
550
551#[derive(
552 derive_more::Debug, derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor,
553)]
554enum UpgradableStreamInner<S: Stream, D: TlsDriver> {
555 #[debug("BaseClient(..)")]
556 BaseClient(
557 #[read]
558 #[write]
559 #[descriptor]
560 S,
561 Option<D::ClientParams>,
562 ),
563 #[debug("BaseServer(..)")]
564 BaseServer(
565 #[read]
566 #[write]
567 #[descriptor]
568 S,
569 Option<TlsServerParameterProvider>,
570 ),
571 #[debug("Preview(..)")]
572 BaseServerPreview(
573 #[read]
574 #[write]
575 #[descriptor]
576 RewindStream<S>,
577 Option<TlsServerParameterProvider>,
578 ),
579 #[debug("Upgraded(..)")]
580 Upgraded(
581 #[read]
582 #[write]
583 #[descriptor]
584 D::Stream,
585 TlsHandshake,
586 ),
587 #[debug("Upgraded(..)")]
588 UpgradedPreview(
589 #[read]
590 #[write]
591 #[descriptor]
592 RewindStream<D::Stream>,
593 TlsHandshake,
594 ),
595}
596
597impl<S: Stream, D: TlsDriver> UpgradableStreamInner<S, D> {
598 #[inline(always)]
599 fn with_inner_metadata<T>(&self, f: impl FnOnce(&dyn StreamMetadata) -> T) -> T {
600 match self {
601 UpgradableStreamInner::BaseClient(base, _) => f(base),
602 UpgradableStreamInner::BaseServer(base, _) => f(base),
603 UpgradableStreamInner::BaseServerPreview(base, _) => f(base),
604 UpgradableStreamInner::Upgraded(upgraded, _) => f(upgraded),
605 UpgradableStreamInner::UpgradedPreview(upgraded, _) => f(upgraded),
606 }
607 }
608
609 #[inline(always)]
610 fn as_inner_handle(&self) -> &dyn AsHandle {
611 match self {
612 UpgradableStreamInner::BaseClient(base, _) => base,
613 UpgradableStreamInner::BaseServer(base, _) => base,
614 UpgradableStreamInner::BaseServerPreview(base, _) => base,
615 UpgradableStreamInner::Upgraded(upgraded, _) => upgraded,
616 UpgradableStreamInner::UpgradedPreview(upgraded, _) => upgraded,
617 }
618 }
619}
620
621impl<S: Stream, D: TlsDriver> AsHandle for UpgradableStream<S, D> {
622 #[cfg(windows)]
623 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
624 self.inner.as_inner_handle().as_handle()
625 }
626
627 #[cfg(unix)]
628 fn as_fd(&self) -> std::os::fd::BorrowedFd {
629 self.inner.as_inner_handle().as_fd()
630 }
631}
632
/// A stream that can have already-read bytes pushed back so they are read again.
pub trait Rewindable {
    /// Queues `bytes` to be returned by subsequent reads.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the in-memory `RewindStream` implementation never fails.
    fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()>;
}
636
637#[derive(derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
639pub(crate) struct RewindStream<S> {
640 buffer: Vec<u8>,
641 #[write]
642 #[descriptor]
643 inner: S,
644}
645
646impl<S> RewindStream<S> {
647 pub fn new(inner: S) -> Self {
648 RewindStream {
649 buffer: Vec::new(),
650 inner,
651 }
652 }
653
654 pub fn rewind(&mut self, data: &[u8]) {
655 self.buffer.extend_from_slice(data);
656 }
657
658 pub fn into_inner(self) -> (S, Vec<u8>) {
659 (self.inner, self.buffer)
660 }
661}
662
#[cfg(feature = "tokio")]
impl<S: AsyncRead + Unpin> AsyncRead for RewindStream<S> {
    /// Serves queued pushback bytes first, then delegates to the inner stream.
    ///
    /// Each call copies at most `buf.remaining()` queued bytes. The inner stream is only polled
    /// once the pushback buffer is empty.
    #[inline(always)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        }
        let n = buf.remaining().min(this.buffer.len());
        // Copy straight from the queue, then drop the consumed prefix. This avoids collecting
        // the drained bytes into a temporary Vec on every read.
        buf.put_slice(&this.buffer[..n]);
        this.buffer.drain(..n);
        Poll::Ready(Ok(()))
    }
}
681
682impl<S: Stream> Rewindable for RewindStream<S> {
683 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
684 self.rewind(bytes);
685 Ok(())
686 }
687}
688
689impl<S: LocalAddress> LocalAddress for RewindStream<S> {
690 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
691 self.inner.local_address()
692 }
693}
694
695impl<S: RemoteAddress> RemoteAddress for RewindStream<S> {
696 fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
697 self.inner.remote_address()
698 }
699}
700
701impl<S: PeerCred> PeerCred for RewindStream<S> {
702 #[cfg(all(unix, feature = "tokio"))]
703 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
704 self.inner.peer_cred()
705 }
706}
707
708impl<S: StreamMetadata> StreamMetadata for RewindStream<S> {
709 fn transport(&self) -> Transport {
710 self.inner.transport()
711 }
712}
713
// Gated like the sibling `AsyncRead` impl: `poll_peek` only exists with `tokio`, and
// `Pin`/`Context`/`ReadBuf` are only imported under that feature.
#[cfg(feature = "tokio")]
impl<S: PeekableStream> PeekableStream for RewindStream<S> {
    /// Peeks without consuming anything.
    ///
    /// While pushback bytes are queued, only those are exposed (up to `buf.remaining()`), and the
    /// inner stream is not polled. Otherwise the peek is delegated to the inner stream.
    fn poll_peek(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        if this.buffer.is_empty() {
            return Pin::new(&mut this.inner).poll_peek(cx, buf);
        }
        let n = buf.remaining().min(this.buffer.len());
        buf.put_slice(&this.buffer[..n]);
        Poll::Ready(Ok(n))
    }
}
729
730impl<S: Stream + AsHandle> AsHandle for RewindStream<S> {
731 #[cfg(windows)]
732 fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
733 self.inner.as_handle()
734 }
735
736 #[cfg(unix)]
737 fn as_fd(&self) -> std::os::fd::BorrowedFd {
738 self.inner.as_fd()
739 }
740}
741
742impl<S: Stream + Rewindable, D: TlsDriver> Rewindable for UpgradableStream<S, D>
743where
744 D::Stream: Rewindable,
745{
746 fn rewind(&mut self, bytes: &[u8]) -> std::io::Result<()> {
747 match &mut self.inner {
748 UpgradableStreamInner::BaseClient(stm, _) => stm.rewind(bytes),
749 UpgradableStreamInner::BaseServer(stm, _) => stm.rewind(bytes),
750 UpgradableStreamInner::BaseServerPreview(stm, _) => {
751 stm.rewind(bytes);
752 Ok(())
753 }
754 UpgradableStreamInner::Upgraded(stm, _) => stm.rewind(bytes),
755 UpgradableStreamInner::UpgradedPreview(stm, _) => {
756 stm.rewind(bytes);
757 Ok(())
758 }
759 }
760 }
761}
762
763impl<S: PeekableStream, D: TlsDriver> PeekableStream for UpgradableStream<S, D>
764where
765 D::Stream: PeekableStream,
766{
767 #[cfg(feature = "tokio")]
768 fn poll_peek(
769 self: Pin<&mut Self>,
770 cx: &mut Context<'_>,
771 buf: &mut tokio::io::ReadBuf,
772 ) -> Poll<std::io::Result<usize>> {
773 match &mut self.get_mut().inner {
774 UpgradableStreamInner::BaseClient(base, _) => Pin::new(base).poll_peek(cx, buf),
775 UpgradableStreamInner::BaseServer(base, _) => Pin::new(base).poll_peek(cx, buf),
776 UpgradableStreamInner::BaseServerPreview(base, _) => Pin::new(base).poll_peek(cx, buf),
777 UpgradableStreamInner::Upgraded(upgraded, _) => Pin::new(upgraded).poll_peek(cx, buf),
778 UpgradableStreamInner::UpgradedPreview(upgraded, _) => {
779 Pin::new(upgraded).poll_peek(cx, buf)
780 }
781 }
782 }
783}
784
785impl<S: PeerCred + Stream, D: TlsDriver> PeerCred for UpgradableStream<S, D> {
786 #[cfg(all(unix, feature = "tokio"))]
787 fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
788 self.inner.with_inner_metadata(|inner| inner.peer_cred())
789 }
790}