pingora_core/protocols/l4/
stream.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Transport layer connection
16
17use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::raw_connect::ProxyDigest;
41use crate::protocols::{
42    GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
43    TimingDigest, UniqueID, UniqueIDType,
44};
45use crate::upstreams::peer::Tracer;
46
/// The concrete OS socket behind a connection: a TCP stream or, on Unix
/// targets only, a Unix domain socket stream.
#[derive(Debug)]
enum RawStream {
    /// A TCP connection.
    Tcp(TcpStream),
    /// A Unix domain socket connection (variant only exists on Unix targets).
    #[cfg(unix)]
    Unix(UnixStream),
}
53
54impl AsyncRead for RawStream {
55    fn poll_read(
56        self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58        buf: &mut ReadBuf<'_>,
59    ) -> Poll<io::Result<()>> {
60        // Safety: Basic enum pin projection
61        unsafe {
62            match &mut Pin::get_unchecked_mut(self) {
63                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
64                #[cfg(unix)]
65                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66            }
67        }
68    }
69}
70
71impl AsyncWrite for RawStream {
72    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
73        // Safety: Basic enum pin projection
74        unsafe {
75            match &mut Pin::get_unchecked_mut(self) {
76                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
77                #[cfg(unix)]
78                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
79            }
80        }
81    }
82
83    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
84        // Safety: Basic enum pin projection
85        unsafe {
86            match &mut Pin::get_unchecked_mut(self) {
87                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
88                #[cfg(unix)]
89                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
90            }
91        }
92    }
93
94    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
95        // Safety: Basic enum pin projection
96        unsafe {
97            match &mut Pin::get_unchecked_mut(self) {
98                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
99                #[cfg(unix)]
100                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
101            }
102        }
103    }
104
105    fn poll_write_vectored(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        bufs: &[std::io::IoSlice<'_>],
109    ) -> Poll<io::Result<usize>> {
110        // Safety: Basic enum pin projection
111        unsafe {
112            match &mut Pin::get_unchecked_mut(self) {
113                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
114                #[cfg(unix)]
115                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
116            }
117        }
118    }
119
120    fn is_write_vectored(&self) -> bool {
121        match self {
122            RawStream::Tcp(s) => s.is_write_vectored(),
123            #[cfg(unix)]
124            RawStream::Unix(s) => s.is_write_vectored(),
125        }
126    }
127}
128
129#[cfg(unix)]
130impl AsRawFd for RawStream {
131    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
132        match self {
133            RawStream::Tcp(s) => s.as_raw_fd(),
134            RawStream::Unix(s) => s.as_raw_fd(),
135        }
136    }
137}
138
#[cfg(windows)]
impl AsRawSocket for RawStream {
    /// Returns the raw socket handle of the wrapped TCP stream.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        match *self {
            RawStream::Tcp(ref tcp) => tcp.as_raw_socket(),
        }
    }
}
147
/// A [`RawStream`] together with the state used to record receive (rx)
/// timestamps for it.
#[derive(Debug)]
struct RawStreamWrapper {
    /// the wrapped socket
    pub(crate) stream: RawStream,
    /// store the last rx timestamp of the stream.
    pub(crate) rx_ts: Option<SystemTime>,
    /// enable reading rx timestamp
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    #[cfg(target_os = "linux")]
    /// Control-message buffer handed to `recvmsg`.
    /// This can be reused across multiple recvmsg calls. The cmsg buffer may
    /// come from old sockets created by older version of pingora and so,
    /// this vector can only grow.
    reusable_cmsg_space: Vec<u8>,
}
162
163impl RawStreamWrapper {
164    pub fn new(stream: RawStream) -> Self {
165        RawStreamWrapper {
166            stream,
167            rx_ts: None,
168            #[cfg(target_os = "linux")]
169            enable_rx_ts: false,
170            #[cfg(target_os = "linux")]
171            reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
172        }
173    }
174
175    #[cfg(target_os = "linux")]
176    pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
177        self.enable_rx_ts = enable_rx_ts;
178    }
179}
180
181impl AsyncRead for RawStreamWrapper {
182    #[cfg(not(target_os = "linux"))]
183    fn poll_read(
184        self: Pin<&mut Self>,
185        cx: &mut Context<'_>,
186        buf: &mut ReadBuf<'_>,
187    ) -> Poll<io::Result<()>> {
188        // Safety: Basic enum pin projection
189        unsafe {
190            let rs_wrapper = Pin::get_unchecked_mut(self);
191            match &mut rs_wrapper.stream {
192                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
193                #[cfg(unix)]
194                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
195            }
196        }
197    }
198
199    #[cfg(target_os = "linux")]
200    fn poll_read(
201        self: Pin<&mut Self>,
202        cx: &mut Context<'_>,
203        buf: &mut ReadBuf<'_>,
204    ) -> Poll<io::Result<()>> {
205        use futures::ready;
206        use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};
207
208        // if we do not need rx timestamp, then use the standard path
209        if !self.enable_rx_ts {
210            // Safety: Basic enum pin projection
211            unsafe {
212                let rs_wrapper = Pin::get_unchecked_mut(self);
213                match &mut rs_wrapper.stream {
214                    RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
215                    RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
216                }
217            }
218        }
219
220        // Safety: Basic pin projection to get mutable stream
221        let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
222        match &mut rs_wrapper.stream {
223            RawStream::Tcp(s) => {
224                loop {
225                    ready!(s.poll_read_ready(cx))?;
226                    // Safety: maybe uninitialized bytes will only be passed to recvmsg
227                    let b = unsafe {
228                        &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
229                            as *mut [u8])
230                    };
231                    let mut iov = [IoSliceMut::new(b)];
232                    rs_wrapper.reusable_cmsg_space.clear();
233
234                    match s.try_io(Interest::READABLE, || {
235                        recvmsg::<SockaddrStorage>(
236                            s.as_raw_fd(),
237                            &mut iov,
238                            Some(&mut rs_wrapper.reusable_cmsg_space),
239                            MsgFlags::empty(),
240                        )
241                        .map_err(|errno| errno.into())
242                    }) {
243                        Ok(r) => {
244                            if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
245                                .cmsgs()
246                                .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
247                            {
248                                // The returned timestamp is a real (i.e. not monotonic) timestamp
249                                // https://docs.kernel.org/networking/timestamping.html
250                                rs_wrapper.rx_ts =
251                                    SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
252                            }
253                            // Safety: We trust `recvmsg` to have filled up `r.bytes` bytes in the buffer.
254                            unsafe {
255                                buf.assume_init(r.bytes);
256                            }
257                            buf.advance(r.bytes);
258                            return Poll::Ready(Ok(()));
259                        }
260                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
261                        Err(e) => return Poll::Ready(Err(e)),
262                    }
263                }
264            }
265            // Unix RX timestamp only works with datagram for now, so we do not care about it
266            RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
267        }
268    }
269}
270
271impl AsyncWrite for RawStreamWrapper {
272    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
273        // Safety: Basic enum pin projection
274        unsafe {
275            match &mut Pin::get_unchecked_mut(self).stream {
276                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
277                #[cfg(unix)]
278                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
279            }
280        }
281    }
282
283    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
284        // Safety: Basic enum pin projection
285        unsafe {
286            match &mut Pin::get_unchecked_mut(self).stream {
287                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
288                #[cfg(unix)]
289                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
290            }
291        }
292    }
293
294    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
295        // Safety: Basic enum pin projection
296        unsafe {
297            match &mut Pin::get_unchecked_mut(self).stream {
298                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
299                #[cfg(unix)]
300                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
301            }
302        }
303    }
304
305    fn poll_write_vectored(
306        self: Pin<&mut Self>,
307        cx: &mut Context<'_>,
308        bufs: &[std::io::IoSlice<'_>],
309    ) -> Poll<io::Result<usize>> {
310        // Safety: Basic enum pin projection
311        unsafe {
312            match &mut Pin::get_unchecked_mut(self).stream {
313                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
314                #[cfg(unix)]
315                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
316            }
317        }
318    }
319
320    fn is_write_vectored(&self) -> bool {
321        self.stream.is_write_vectored()
322    }
323}
324
325#[cfg(unix)]
326impl AsRawFd for RawStreamWrapper {
327    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
328        self.stream.as_raw_fd()
329    }
330}
331
#[cfg(windows)]
impl AsRawSocket for RawStreamWrapper {
    /// Returns the raw socket handle of the wrapped stream.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        AsRawSocket::as_raw_socket(&self.stream)
    }
}
338
// Large read buffering helps reduce syscalls with little trade-off.
// The Ssl layer always does "small" reads of 16k (TLS record size), so the L4 read buffer helps a lot.
const BUF_READ_SIZE: usize = 64 * 1024;
// Small write buf to match MSS. Too large a write buf delays real time communication.
// This buffering effectively implements something similar to Nagle's algorithm.
// The benefit is that user space can control when to flush, whereas Nagle's can't be controlled.
// And userspace buffering reduces both syscalls and small packets.
const BUF_WRITE_SIZE: usize = 1460;
347
348// NOTE: with writer buffering, users need to call flush() to make sure the data is actually
349// sent. Otherwise data could be stuck in the buffer forever or get lost when stream is closed.
350
/// A concrete type for transport layer connection + extra fields for logging
#[derive(Debug)]
pub struct Stream {
    // Use `Option` to be able to swap to adjust the buffer size. Always safe to unwrap
    stream: Option<BufStream<RawStreamWrapper>>,
    // the data put back at the front of the read buffer, in order to replay the read
    // (used as a stack: the most recently pushed chunk is read first)
    rewind_read_buf: Vec<Vec<u8>>,
    // when true, writes go through the `BufStream` write buffer; when false they
    // go directly to the underlying raw stream
    buffer_write: bool,
    // set via `set_proxy_digest`
    proxy_digest: Option<Arc<ProxyDigest>>,
    // set via `set_socket_digest`
    socket_digest: Option<Arc<SocketDigest>>,
    /// When this connection is established
    pub established_ts: SystemTime,
    /// The distributed tracing object for this stream
    pub tracer: Option<Tracer>,
    // accumulated time spent with a read pending
    read_pending_time: AccumulatedDuration,
    // accumulated time spent with a write or flush pending
    write_pending_time: AccumulatedDuration,
    /// Last rx timestamp associated with the last recvmsg call.
    pub rx_ts: Option<SystemTime>,
}
370
371impl Stream {
372    fn stream(&self) -> &BufStream<RawStreamWrapper> {
373        self.stream.as_ref().expect("stream should always be set")
374    }
375
376    fn stream_mut(&mut self) -> &mut BufStream<RawStreamWrapper> {
377        self.stream.as_mut().expect("stream should always be set")
378    }
379
380    /// set TCP nodelay for this connection if `self` is TCP
381    pub fn set_nodelay(&mut self) -> Result<()> {
382        if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
383            s.set_nodelay(true)
384                .or_err(ConnectError, "failed to set_nodelay")?;
385        }
386        Ok(())
387    }
388
389    /// set TCP keepalive settings for this connection if `self` is TCP
390    pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
391        if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
392            debug!("Setting tcp keepalive");
393            set_tcp_keepalive(s, ka)?;
394        }
395        Ok(())
396    }
397
398    #[cfg(target_os = "linux")]
399    pub fn set_rx_timestamp(&mut self) -> Result<()> {
400        use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
401
402        if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
403            let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
404                | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
405            setsockopt(s.as_raw_fd(), sockopt::Timestamping, &timestamp_options)
406                .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
407            self.stream_mut().get_mut().enable_rx_ts(true);
408        }
409
410        Ok(())
411    }
412
413    #[cfg(not(target_os = "linux"))]
414    pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
415        Ok(())
416    }
417
418    /// Put Some data back to the head of the stream to be read again
419    pub(crate) fn rewind(&mut self, data: &[u8]) {
420        if !data.is_empty() {
421            self.rewind_read_buf.push(data.to_vec());
422        }
423    }
424
425    /// Set the buffer of BufStream
426    /// It is only set later because of the malloc overhead in critical accept() path
427    pub(crate) fn set_buffer(&mut self) {
428        use std::mem;
429        // Since BufStream doesn't provide an API to adjust the buf directly,
430        // we take the raw stream out of it and put it in a new BufStream with the size we want
431        let stream = mem::take(&mut self.stream);
432        let stream =
433            stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner()));
434        let _ = mem::replace(&mut self.stream, stream);
435    }
436}
437
438impl From<TcpStream> for Stream {
439    fn from(s: TcpStream) -> Self {
440        Stream {
441            stream: Some(BufStream::with_capacity(
442                0,
443                0,
444                RawStreamWrapper::new(RawStream::Tcp(s)),
445            )),
446            rewind_read_buf: Vec::new(),
447            buffer_write: true,
448            established_ts: SystemTime::now(),
449            proxy_digest: None,
450            socket_digest: None,
451            tracer: None,
452            read_pending_time: AccumulatedDuration::new(),
453            write_pending_time: AccumulatedDuration::new(),
454            rx_ts: None,
455        }
456    }
457}
458
459#[cfg(unix)]
460impl From<UnixStream> for Stream {
461    fn from(s: UnixStream) -> Self {
462        Stream {
463            stream: Some(BufStream::with_capacity(
464                0,
465                0,
466                RawStreamWrapper::new(RawStream::Unix(s)),
467            )),
468            rewind_read_buf: Vec::new(),
469            buffer_write: true,
470            established_ts: SystemTime::now(),
471            proxy_digest: None,
472            socket_digest: None,
473            tracer: None,
474            read_pending_time: AccumulatedDuration::new(),
475            write_pending_time: AccumulatedDuration::new(),
476            rx_ts: None,
477        }
478    }
479}
480
481#[cfg(unix)]
482impl AsRawFd for Stream {
483    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
484        self.stream().get_ref().as_raw_fd()
485    }
486}
487
#[cfg(windows)]
impl AsRawSocket for Stream {
    /// Returns the raw socket handle of the underlying socket.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        let wrapper = self.stream().get_ref();
        wrapper.as_raw_socket()
    }
}
494
495#[cfg(unix)]
496impl UniqueID for Stream {
497    fn id(&self) -> UniqueIDType {
498        self.as_raw_fd()
499    }
500}
501
#[cfg(windows)]
impl UniqueID for Stream {
    /// The stream's ID is its raw socket handle, converted to `usize`.
    fn id(&self) -> usize {
        let socket = self.as_raw_socket();
        socket as usize
    }
}
508
// Empty impl: `Stream` overrides nothing and relies entirely on the `Ssl` trait's defaults.
impl Ssl for Stream {}
510
511#[async_trait]
512impl Peek for Stream {
513    async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
514        use tokio::io::AsyncReadExt;
515        self.read_exact(buf).await?;
516        // rewind regardless of what is read
517        self.rewind(buf);
518        Ok(true)
519    }
520}
521
522#[async_trait]
523impl Shutdown for Stream {
524    async fn shutdown(&mut self) {
525        AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
526            debug!("Failed to shutdown connection: {:?}", e);
527        });
528    }
529}
530
531impl GetTimingDigest for Stream {
532    fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
533        let mut digest = Vec::with_capacity(2); // expect to have both L4 stream and TLS layer
534        digest.push(Some(TimingDigest {
535            established_ts: self.established_ts,
536        }));
537        digest
538    }
539
540    fn get_read_pending_time(&self) -> Duration {
541        self.read_pending_time.total
542    }
543
544    fn get_write_pending_time(&self) -> Duration {
545        self.write_pending_time.total
546    }
547}
548
549impl GetProxyDigest for Stream {
550    fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
551        self.proxy_digest.clone()
552    }
553
554    fn set_proxy_digest(&mut self, digest: ProxyDigest) {
555        self.proxy_digest = Some(Arc::new(digest));
556    }
557}
558
559impl GetSocketDigest for Stream {
560    fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
561        self.socket_digest.clone()
562    }
563
564    fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
565        self.socket_digest = Some(Arc::new(socket_digest))
566    }
567}
568
569impl Drop for Stream {
570    fn drop(&mut self) {
571        if let Some(t) = self.tracer.as_ref() {
572            t.0.on_disconnected();
573        }
574        /* use nodelay/local_addr function to detect socket status */
575        let ret = match &self.stream().get_ref().stream {
576            RawStream::Tcp(s) => s.nodelay().err(),
577            #[cfg(unix)]
578            RawStream::Unix(s) => s.local_addr().err(),
579        };
580        if let Some(e) = ret {
581            match e.kind() {
582                tokio::io::ErrorKind::Other => {
583                    if let Some(ecode) = e.raw_os_error() {
584                        if ecode == 9 {
585                            // Or we could panic here
586                            error!("Crit: socket {:?} is being double closed", self.stream);
587                        }
588                    }
589                }
590                _ => {
591                    debug!("Socket is already broken {:?}", e);
592                }
593            }
594        } else {
595            // try flush the write buffer. We use now_or_never() because
596            // 1. Drop cannot be async
597            // 2. write should usually be ready, unless the buf is full.
598            let _ = self.flush().now_or_never();
599        }
600        debug!("Dropping socket {:?}", self.stream);
601    }
602}
603
604impl AsyncRead for Stream {
605    fn poll_read(
606        mut self: Pin<&mut Self>,
607        cx: &mut Context<'_>,
608        buf: &mut ReadBuf<'_>,
609    ) -> Poll<io::Result<()>> {
610        let result = if !self.rewind_read_buf.is_empty() {
611            let data_to_read = self.rewind_read_buf.pop().unwrap(); // safe
612            let mut data_to_read = data_to_read.as_slice();
613            let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
614            // return the remaining data back to the head of rewind_read_buf
615            if !data_to_read.is_empty() {
616                let remaining_buf = Vec::from(data_to_read);
617                self.rewind_read_buf.push(remaining_buf);
618            }
619            result
620        } else {
621            Pin::new(&mut self.stream_mut()).poll_read(cx, buf)
622        };
623        self.read_pending_time.poll_time(&result);
624        self.rx_ts = self.stream().get_ref().rx_ts;
625        result
626    }
627}
628
629impl AsyncWrite for Stream {
630    fn poll_write(
631        mut self: Pin<&mut Self>,
632        cx: &mut Context,
633        buf: &[u8],
634    ) -> Poll<io::Result<usize>> {
635        let result = if self.buffer_write {
636            Pin::new(&mut self.stream_mut()).poll_write(cx, buf)
637        } else {
638            Pin::new(&mut self.stream_mut().get_mut()).poll_write(cx, buf)
639        };
640        self.write_pending_time.poll_write_time(&result, buf.len());
641        result
642    }
643
644    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
645        let result = Pin::new(&mut self.stream_mut()).poll_flush(cx);
646        self.write_pending_time.poll_time(&result);
647        result
648    }
649
650    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
651        Pin::new(&mut self.stream_mut()).poll_shutdown(cx)
652    }
653
654    fn poll_write_vectored(
655        mut self: Pin<&mut Self>,
656        cx: &mut Context<'_>,
657        bufs: &[std::io::IoSlice<'_>],
658    ) -> Poll<io::Result<usize>> {
659        let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
660
661        let result = if self.buffer_write {
662            Pin::new(&mut self.stream_mut()).poll_write_vectored(cx, bufs)
663        } else {
664            Pin::new(&mut self.stream_mut().get_mut()).poll_write_vectored(cx, bufs)
665        };
666
667        self.write_pending_time.poll_write_time(&result, total_size);
668        result
669    }
670
671    fn is_write_vectored(&self) -> bool {
672        if self.buffer_write {
673            self.stream().is_write_vectored() // it is true
674        } else {
675            self.stream().get_ref().is_write_vectored()
676        }
677    }
678}
679
680pub mod async_write_vec {
681    use bytes::Buf;
682    use futures::ready;
683    use std::future::Future;
684    use std::io::IoSlice;
685    use std::pin::Pin;
686    use std::task::{Context, Poll};
687    use tokio::io;
688    use tokio::io::AsyncWrite;
689
690    /*
691        the missing write_buf https://github.com/tokio-rs/tokio/pull/3156#issuecomment-738207409
692        https://github.com/tokio-rs/tokio/issues/2610
693        In general vectored write is lost when accessing the trait object: Box<S: AsyncWrite>
694    */
695
696    #[must_use = "futures do nothing unless you `.await` or poll them"]
697    pub struct WriteVec<'a, W, B> {
698        writer: &'a mut W,
699        buf: &'a mut B,
700    }
701
702    #[must_use = "futures do nothing unless you `.await` or poll them"]
703    pub struct WriteVecAll<'a, W, B> {
704        writer: &'a mut W,
705        buf: &'a mut B,
706    }
707
708    pub trait AsyncWriteVec {
709        fn poll_write_vec<B: Buf>(
710            self: Pin<&mut Self>,
711            _cx: &mut Context<'_>,
712            _buf: &mut B,
713        ) -> Poll<io::Result<usize>>;
714
715        fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
716        where
717            Self: Sized,
718            B: Buf,
719        {
720            WriteVec {
721                writer: self,
722                buf: src,
723            }
724        }
725
726        fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
727        where
728            Self: Sized,
729            B: Buf,
730        {
731            WriteVecAll {
732                writer: self,
733                buf: src,
734            }
735        }
736    }
737
738    impl<W, B> Future for WriteVec<'_, W, B>
739    where
740        W: AsyncWriteVec + Unpin,
741        B: Buf,
742    {
743        type Output = io::Result<usize>;
744
745        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
746            let me = &mut *self;
747            Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
748        }
749    }
750
751    impl<W, B> Future for WriteVecAll<'_, W, B>
752    where
753        W: AsyncWriteVec + Unpin,
754        B: Buf,
755    {
756        type Output = io::Result<()>;
757
758        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
759            let me = &mut *self;
760            while me.buf.has_remaining() {
761                let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
762                if n == 0 {
763                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
764                }
765            }
766            Poll::Ready(Ok(()))
767        }
768    }
769
770    /* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */
771    impl<T> AsyncWriteVec for T
772    where
773        T: AsyncWrite,
774    {
775        fn poll_write_vec<B: Buf>(
776            self: Pin<&mut Self>,
777            ctx: &mut Context,
778            buf: &mut B,
779        ) -> Poll<io::Result<usize>> {
780            const MAX_BUFS: usize = 64;
781
782            if !buf.has_remaining() {
783                return Poll::Ready(Ok(0));
784            }
785
786            let n = if self.is_write_vectored() {
787                let mut slices = [IoSlice::new(&[]); MAX_BUFS];
788                let cnt = buf.chunks_vectored(&mut slices);
789                ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
790            } else {
791                ready!(self.poll_write(ctx, buf.chunk()))?
792            };
793
794            buf.advance(n);
795
796            Poll::Ready(Ok(n))
797        }
798    }
799}
800
801pub use async_write_vec::AsyncWriteVec;
802
/// A stopwatch that sums the time spent between `start` and `stop` calls.
#[derive(Debug)]
struct AccumulatedDuration {
    /// Sum of all completed start..stop intervals.
    total: Duration,
    /// When the currently running interval began; `None` while stopped.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    /// Creates a stopped stopwatch with zero accumulated time.
    fn new() -> Self {
        Self {
            last_start: None,
            total: Duration::ZERO,
        }
    }

    /// Starts an interval unless one is already running.
    fn start(&mut self) {
        self.last_start.get_or_insert_with(Instant::now);
    }

    /// Ends the running interval, if any, and adds its length to `total`.
    fn stop(&mut self) {
        if let Some(began) = self.last_start.take() {
            self.total += began.elapsed();
        }
    }

    /// Updates the stopwatch from a write poll result: a complete write
    /// (`n == buf_size`) or an error stops it; a partial write or a pending
    /// poll starts it.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop();
        } else {
            self.start();
        }
    }

    /// Updates the stopwatch from a poll result: ready stops it, pending starts it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop();
        } else {
            self.start();
        }
    }
}
855
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    /// Starts a one-shot TCP peer on an ephemeral loopback port.
    ///
    /// The peer accepts a single connection, waits until the returned
    /// [`Notify`] is signalled, writes `message`, and then drops the
    /// connection so that the client observes EOF after the payload.
    ///
    /// Returns the address to connect to and the handle used to release the
    /// write.
    async fn spawn_peer(message: &'static [u8]) -> (SocketAddr, Arc<Notify>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let peer_notify = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            peer_notify.notified().await;
            stream.write_all(message).await.unwrap();
            // Close the connection explicitly: tests using `read_to_end`
            // rely on seeing EOF right after the payload.
            drop(stream);
        });

        (addr, notify)
    }

    /// With RX timestamping enabled, a read must populate `rx_ts`.
    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let (addr, notify) = spawn_peer(message).await;

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // Receive the message
        // setsockopt for SO_TIMESTAMPING is asynchronous so sleep a little bit
        // to let kernel do the work.
        // Use the async sleep so the runtime thread is not blocked.
        tokio::time::sleep(Duration::from_micros(100)).await;
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    /// Without RX timestamping enabled, a read must leave `rx_ts` unset.
    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let (addr, notify) = spawn_peer(message).await;

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        // Mirror the timing of `test_rx_timestamp` without blocking the runtime.
        tokio::time::sleep(Duration::from_micros(100)).await;
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_none());
    }

    /// Rewound bytes must be returned before any data from the socket.
    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let (addr, notify) = spawn_peer(message).await;

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // partially read rewind_test because of the buffer size limit
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // read the rest of rewind_test
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // read the actual data
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    /// Peeking must not consume data: a later read still sees every byte.
    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        let (addr, notify) = spawn_peer(message).await;

        // Release the peer up front; it writes as soon as it has accepted.
        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }

    /// Two peeks in a row followed by reads must still yield the full payload
    /// in order.
    #[tokio::test]
    async fn test_stream_two_subsequent_peek_calls_before_read() {
        let message = b"abcdefghijklmn";
        let (addr, notify) = spawn_peer(message).await;

        // Release the peer up front; it writes as soon as it has accepted.
        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        // Peek 4 bytes
        let mut buffer = vec![0u8; 4];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..4]);

        // Peek 2 bytes
        let mut buffer = vec![0u8; 2];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..2]);

        // Read 1 byte: ['a']
        let mut buffer = vec![0u8; 1];
        stream.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[0..1]);

        // Read as many bytes as possible, return 1 byte ['b']
        //  from the first retry buffer chunk
        let mut buffer = vec![0u8; 100];
        let n = stream.read(&mut buffer).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(buffer[..n], message[1..2]);

        // Read the rest ['cdefghijklmn']
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[2..]);
    }
}