Skip to main content

pingora_core/protocols/l4/
stream.rs

1// Copyright 2026 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Transport layer connection
16
17use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::l4::virt;
41use crate::protocols::raw_connect::ProxyDigest;
42use crate::protocols::{
43    GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
44    TimingDigest, UniqueID, UniqueIDType,
45};
46use crate::upstreams::peer::Tracer;
47
/// The concrete transport underneath a `Stream`, one variant per supported socket kind.
#[derive(Debug)]
enum RawStream {
    /// A TCP connection.
    Tcp(TcpStream),
    /// A Unix domain socket connection; only compiled on Unix targets.
    #[cfg(unix)]
    Unix(UnixStream),
    /// A virtual socket stream from the `virt` module.
    Virtual(virt::VirtualSocketStream),
}
55
56impl AsyncRead for RawStream {
57    fn poll_read(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &mut ReadBuf<'_>,
61    ) -> Poll<io::Result<()>> {
62        // Safety: Basic enum pin projection
63        unsafe {
64            match &mut Pin::get_unchecked_mut(self) {
65                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66                #[cfg(unix)]
67                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
68                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
69            }
70        }
71    }
72}
73
74impl AsyncWrite for RawStream {
75    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
76        // Safety: Basic enum pin projection
77        unsafe {
78            match &mut Pin::get_unchecked_mut(self) {
79                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
80                #[cfg(unix)]
81                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
82                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
83            }
84        }
85    }
86
87    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
88        // Safety: Basic enum pin projection
89        unsafe {
90            match &mut Pin::get_unchecked_mut(self) {
91                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
92                #[cfg(unix)]
93                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
94                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
95            }
96        }
97    }
98
99    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
100        // Safety: Basic enum pin projection
101        unsafe {
102            match &mut Pin::get_unchecked_mut(self) {
103                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
104                #[cfg(unix)]
105                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
106                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
107            }
108        }
109    }
110
111    fn poll_write_vectored(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        bufs: &[std::io::IoSlice<'_>],
115    ) -> Poll<io::Result<usize>> {
116        // Safety: Basic enum pin projection
117        unsafe {
118            match &mut Pin::get_unchecked_mut(self) {
119                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
120                #[cfg(unix)]
121                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
122                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
123            }
124        }
125    }
126
127    fn is_write_vectored(&self) -> bool {
128        match self {
129            RawStream::Tcp(s) => s.is_write_vectored(),
130            #[cfg(unix)]
131            RawStream::Unix(s) => s.is_write_vectored(),
132            RawStream::Virtual(s) => s.is_write_vectored(),
133        }
134    }
135}
136
137#[cfg(unix)]
138impl AsRawFd for RawStream {
139    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
140        match self {
141            RawStream::Tcp(s) => s.as_raw_fd(),
142            RawStream::Unix(s) => s.as_raw_fd(),
143            RawStream::Virtual(_) => -1, // Virtual stream does not have a real fd
144        }
145    }
146}
147
#[cfg(windows)]
impl AsRawSocket for RawStream {
    /// Returns the raw socket of the wrapped stream, or `!0` for a virtual stream.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        // Virtual stream does not have a real socket: report INVALID_SOCKET (!0).
        const INVALID_SOCKET: std::os::windows::io::RawSocket = !0;

        match self {
            Self::Tcp(tcp) => tcp.as_raw_socket(),
            Self::Virtual(_) => INVALID_SOCKET,
        }
    }
}
158
/// A [`RawStream`] together with the state used to record receive (rx) timestamps.
#[derive(Debug)]
struct RawStreamWrapper {
    /// The wrapped transport stream.
    pub(crate) stream: RawStream,
    /// The last rx timestamp recorded for the stream, if any.
    pub(crate) rx_ts: Option<SystemTime>,
    /// Whether reads should collect the rx timestamp (Linux only).
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    #[cfg(target_os = "linux")]
    /// Control-message buffer reused across multiple `recvmsg` calls. The cmsg
    /// buffer may come from old sockets created by an older version of pingora,
    /// and so this vector can only grow.
    reusable_cmsg_space: Vec<u8>,
}
173
174impl RawStreamWrapper {
175    pub fn new(stream: RawStream) -> Self {
176        RawStreamWrapper {
177            stream,
178            rx_ts: None,
179            #[cfg(target_os = "linux")]
180            enable_rx_ts: false,
181            #[cfg(target_os = "linux")]
182            reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
183        }
184    }
185
186    #[cfg(target_os = "linux")]
187    pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
188        self.enable_rx_ts = enable_rx_ts;
189    }
190}
191
impl AsyncRead for RawStreamWrapper {
    /// Non-Linux read path: forwards the read straight to the wrapped stream.
    #[cfg(not(target_os = "linux"))]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Safety: Basic enum pin projection
        unsafe {
            let rs_wrapper = Pin::get_unchecked_mut(self);
            match &mut rs_wrapper.stream {
                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                #[cfg(unix)]
                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
            }
        }
    }

    /// Linux read path.
    ///
    /// With rx timestamping disabled this forwards to the wrapped stream. With
    /// it enabled, TCP reads go through `recvmsg` so that a timestamp control
    /// message, when present, can be stored in `rx_ts`; Unix and virtual
    /// streams are always read normally.
    #[cfg(target_os = "linux")]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        use futures::ready;
        use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};

        // if we do not need rx timestamp, then use the standard path
        if !self.enable_rx_ts {
            // Safety: Basic enum pin projection
            unsafe {
                let rs_wrapper = Pin::get_unchecked_mut(self);
                match &mut rs_wrapper.stream {
                    RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Virtual(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                }
            }
        }

        // Safety: Basic pin projection to get mutable stream
        let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
        match &mut rs_wrapper.stream {
            RawStream::Tcp(s) => {
                // Retry until `recvmsg` either succeeds or fails with something
                // other than `WouldBlock`; each retry waits for readiness again.
                loop {
                    ready!(s.poll_read_ready(cx))?;
                    // Safety: maybe uninitialized bytes will only be passed to recvmsg
                    let b = unsafe {
                        &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
                            as *mut [u8])
                    };
                    let mut iov = [IoSliceMut::new(b)];
                    // Start each call with an empty control-message buffer.
                    rs_wrapper.reusable_cmsg_space.clear();

                    match s.try_io(Interest::READABLE, || {
                        recvmsg::<SockaddrStorage>(
                            s.as_raw_fd(),
                            &mut iov,
                            Some(&mut rs_wrapper.reusable_cmsg_space),
                            MsgFlags::empty(),
                        )
                        .map_err(|errno| errno.into())
                    }) {
                        Ok(r) => {
                            // Only overwrite `rx_ts` when this read carried a
                            // timestamp message; otherwise the previous value stays.
                            if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
                                .cmsgs()
                                .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
                            {
                                // The returned timestamp is a real (i.e. not monotonic) timestamp
                                // https://docs.kernel.org/networking/timestamping.html
                                rs_wrapper.rx_ts =
                                    SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
                            }
                            // Safety: We trust `recvmsg` to have filled up `r.bytes` bytes in the buffer.
                            unsafe {
                                buf.assume_init(r.bytes);
                            }
                            buf.advance(r.bytes);
                            return Poll::Ready(Ok(()));
                        }
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
            }
            // Unix RX timestamp only works with datagram for now, so we do not care about it
            RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
            RawStream::Virtual(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
        }
    }
}
284
285impl AsyncWrite for RawStreamWrapper {
286    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
287        // Safety: Basic enum pin projection
288        unsafe {
289            match &mut Pin::get_unchecked_mut(self).stream {
290                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
291                #[cfg(unix)]
292                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
293                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
294            }
295        }
296    }
297
298    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
299        // Safety: Basic enum pin projection
300        unsafe {
301            match &mut Pin::get_unchecked_mut(self).stream {
302                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
303                #[cfg(unix)]
304                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
305                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
306            }
307        }
308    }
309
310    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
311        // Safety: Basic enum pin projection
312        unsafe {
313            match &mut Pin::get_unchecked_mut(self).stream {
314                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
315                #[cfg(unix)]
316                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
317                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
318            }
319        }
320    }
321
322    fn poll_write_vectored(
323        self: Pin<&mut Self>,
324        cx: &mut Context<'_>,
325        bufs: &[std::io::IoSlice<'_>],
326    ) -> Poll<io::Result<usize>> {
327        // Safety: Basic enum pin projection
328        unsafe {
329            match &mut Pin::get_unchecked_mut(self).stream {
330                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
331                #[cfg(unix)]
332                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
333                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
334            }
335        }
336    }
337
338    fn is_write_vectored(&self) -> bool {
339        self.stream.is_write_vectored()
340    }
341}
342
343#[cfg(unix)]
344impl AsRawFd for RawStreamWrapper {
345    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
346        self.stream.as_raw_fd()
347    }
348}
349
#[cfg(windows)]
impl AsRawSocket for RawStreamWrapper {
    /// Reports the raw socket of the wrapped `RawStream`.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        let inner = &self.stream;
        inner.as_raw_socket()
    }
}
356
// Large read buffering helps reduce syscalls with little trade-off.
// The Ssl layer always does "small" reads of 16k (TLS record size), so an L4 read buffer helps a lot.
const BUF_READ_SIZE: usize = 64 * 1024;
// Small write buffer to match the MSS. Too large a write buffer delays real-time communication.
// This buffering effectively implements something similar to Nagle's algorithm.
// The benefit is that user space can control when to flush, whereas Nagle's can't be controlled.
// Userspace buffering also reduces both syscalls and small packets.
const BUF_WRITE_SIZE: usize = 1460;
365
366// NOTE: with writer buffering, users need to call flush() to make sure the data is actually
367// sent. Otherwise data could be stuck in the buffer forever or get lost when stream is closed.
368
/// A concrete type for transport layer connection + extra fields for logging
#[derive(Debug)]
pub struct Stream {
    // Use `Option` to be able to swap to adjust the buffer size. Always safe to unwrap
    stream: Option<BufStream<RawStreamWrapper>>,
    // The data put back at the front of the read buffer, in order to replay the read.
    // The most recently pushed chunk is the one served first by the next read.
    rewind_read_buf: Vec<Vec<u8>>,
    // When true, writes go through the `BufStream` write buffer; when false they
    // are sent to the wrapped stream directly.
    buffer_write: bool,
    // Optional digest of the proxy used by this connection, shared via `Arc`.
    proxy_digest: Option<Arc<ProxyDigest>>,
    // Optional digest of the underlying socket, shared via `Arc`.
    socket_digest: Option<Arc<SocketDigest>>,
    /// When this connection is established
    pub established_ts: SystemTime,
    /// The distributed tracing object for this stream
    pub tracer: Option<Tracer>,
    // Accumulated time spent with a read pending.
    read_pending_time: AccumulatedDuration,
    // Accumulated time spent with a write or flush pending.
    write_pending_time: AccumulatedDuration,
    /// Last rx timestamp associated with the last recvmsg call.
    pub rx_ts: Option<SystemTime>,
}
388
389impl Stream {
390    fn stream(&self) -> &BufStream<RawStreamWrapper> {
391        self.stream.as_ref().expect("stream should always be set")
392    }
393
394    fn stream_mut(&mut self) -> &mut BufStream<RawStreamWrapper> {
395        self.stream.as_mut().expect("stream should always be set")
396    }
397
398    /// set TCP nodelay for this connection if `self` is TCP
399    pub fn set_nodelay(&mut self) -> Result<()> {
400        match &self.stream_mut().get_mut().stream {
401            RawStream::Tcp(s) => {
402                s.set_nodelay(true)
403                    .or_err(ConnectError, "failed to set_nodelay")?;
404            }
405            RawStream::Virtual(s) => {
406                s.set_socket_option(virt::VirtualSockOpt::NoDelay)
407                    .or_err(ConnectError, "failed to set_nodelay on virtual socket")?;
408            }
409            _ => (),
410        }
411        Ok(())
412    }
413
414    /// set TCP keepalive settings for this connection if `self` is TCP
415    pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
416        match &self.stream_mut().get_mut().stream {
417            RawStream::Tcp(s) => {
418                debug!("Setting tcp keepalive");
419                set_tcp_keepalive(s, ka)?;
420            }
421            RawStream::Virtual(s) => {
422                s.set_socket_option(virt::VirtualSockOpt::KeepAlive(ka.clone()))
423                    .or_err(ConnectError, "failed to set_keepalive on virtual socket")?;
424            }
425            _ => (),
426        }
427        Ok(())
428    }
429
430    #[cfg(target_os = "linux")]
431    pub fn set_rx_timestamp(&mut self) -> Result<()> {
432        use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
433
434        if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
435            let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
436                | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
437            setsockopt(s.as_raw_fd(), sockopt::Timestamping, &timestamp_options)
438                .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
439            self.stream_mut().get_mut().enable_rx_ts(true);
440        }
441
442        Ok(())
443    }
444
445    #[cfg(not(target_os = "linux"))]
446    pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
447        Ok(())
448    }
449
450    /// Put Some data back to the head of the stream to be read again
451    pub(crate) fn rewind(&mut self, data: &[u8]) {
452        if !data.is_empty() {
453            self.rewind_read_buf.push(data.to_vec());
454        }
455    }
456
457    /// Set the buffer of BufStream
458    /// It is only set later because of the malloc overhead in critical accept() path
459    pub(crate) fn set_buffer(&mut self) {
460        use std::mem;
461        // Since BufStream doesn't provide an API to adjust the buf directly,
462        // we take the raw stream out of it and put it in a new BufStream with the size we want
463        let stream = mem::take(&mut self.stream);
464        let stream =
465            stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner()));
466        let _ = mem::replace(&mut self.stream, stream);
467    }
468}
469
470impl From<TcpStream> for Stream {
471    fn from(s: TcpStream) -> Self {
472        Stream {
473            stream: Some(BufStream::with_capacity(
474                0,
475                0,
476                RawStreamWrapper::new(RawStream::Tcp(s)),
477            )),
478            rewind_read_buf: Vec::new(),
479            buffer_write: true,
480            established_ts: SystemTime::now(),
481            proxy_digest: None,
482            socket_digest: None,
483            tracer: None,
484            read_pending_time: AccumulatedDuration::new(),
485            write_pending_time: AccumulatedDuration::new(),
486            rx_ts: None,
487        }
488    }
489}
490
491impl From<virt::VirtualSocketStream> for Stream {
492    fn from(s: virt::VirtualSocketStream) -> Self {
493        Stream {
494            stream: Some(BufStream::with_capacity(
495                0,
496                0,
497                RawStreamWrapper::new(RawStream::Virtual(s)),
498            )),
499            rewind_read_buf: Vec::new(),
500            buffer_write: true,
501            established_ts: SystemTime::now(),
502            proxy_digest: None,
503            socket_digest: None,
504            tracer: None,
505            read_pending_time: AccumulatedDuration::new(),
506            write_pending_time: AccumulatedDuration::new(),
507            rx_ts: None,
508        }
509    }
510}
511
512#[cfg(unix)]
513impl From<UnixStream> for Stream {
514    fn from(s: UnixStream) -> Self {
515        Stream {
516            stream: Some(BufStream::with_capacity(
517                0,
518                0,
519                RawStreamWrapper::new(RawStream::Unix(s)),
520            )),
521            rewind_read_buf: Vec::new(),
522            buffer_write: true,
523            established_ts: SystemTime::now(),
524            proxy_digest: None,
525            socket_digest: None,
526            tracer: None,
527            read_pending_time: AccumulatedDuration::new(),
528            write_pending_time: AccumulatedDuration::new(),
529            rx_ts: None,
530        }
531    }
532}
533
534#[cfg(unix)]
535impl AsRawFd for Stream {
536    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
537        self.stream().get_ref().as_raw_fd()
538    }
539}
540
#[cfg(windows)]
impl AsRawSocket for Stream {
    /// Reports the raw socket of the wrapped raw stream.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        let wrapper: &RawStreamWrapper = self.stream().get_ref();
        wrapper.as_raw_socket()
    }
}
547
548#[cfg(unix)]
549impl UniqueID for Stream {
550    fn id(&self) -> UniqueIDType {
551        self.as_raw_fd()
552    }
553}
554
#[cfg(windows)]
impl UniqueID for Stream {
    /// Uses the raw socket handle, cast to `usize`, as the stream's id.
    fn id(&self) -> usize {
        let socket = self.as_raw_socket();
        socket as usize
    }
}
561
// Nothing is overridden here, so `Stream` uses whatever the `Ssl` trait provides by default.
impl Ssl for Stream {}
563
564#[async_trait]
565impl Peek for Stream {
566    async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
567        use tokio::io::AsyncReadExt;
568        self.read_exact(buf).await?;
569        // rewind regardless of what is read
570        self.rewind(buf);
571        Ok(true)
572    }
573}
574
575#[async_trait]
576impl Shutdown for Stream {
577    async fn shutdown(&mut self) {
578        AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
579            debug!("Failed to shutdown connection: {:?}", e);
580        });
581    }
582}
583
584impl GetTimingDigest for Stream {
585    fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
586        let mut digest = Vec::with_capacity(2); // expect to have both L4 stream and TLS layer
587        digest.push(Some(TimingDigest {
588            established_ts: self.established_ts,
589        }));
590        digest
591    }
592
593    fn get_read_pending_time(&self) -> Duration {
594        self.read_pending_time.total
595    }
596
597    fn get_write_pending_time(&self) -> Duration {
598        self.write_pending_time.total
599    }
600}
601
602impl GetProxyDigest for Stream {
603    fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
604        self.proxy_digest.clone()
605    }
606
607    fn set_proxy_digest(&mut self, digest: ProxyDigest) {
608        self.proxy_digest = Some(Arc::new(digest));
609    }
610}
611
612impl GetSocketDigest for Stream {
613    fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
614        self.socket_digest.clone()
615    }
616
617    fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
618        self.socket_digest = Some(Arc::new(socket_digest))
619    }
620}
621
622impl Drop for Stream {
623    fn drop(&mut self) {
624        if let Some(t) = self.tracer.as_ref() {
625            t.0.on_disconnected();
626        }
627        /* use nodelay/local_addr function to detect socket status */
628        let ret = match &self.stream().get_ref().stream {
629            RawStream::Tcp(s) => s.nodelay().err(),
630            #[cfg(unix)]
631            RawStream::Unix(s) => s.local_addr().err(),
632            RawStream::Virtual(_) => {
633                // TODO: should this do something?
634                None
635            }
636        };
637        if let Some(e) = ret {
638            match e.kind() {
639                tokio::io::ErrorKind::Other => {
640                    if let Some(ecode) = e.raw_os_error() {
641                        if ecode == 9 {
642                            // Or we could panic here
643                            error!("Crit: socket {:?} is being double closed", self.stream);
644                        }
645                    }
646                }
647                _ => {
648                    debug!("Socket is already broken {:?}", e);
649                }
650            }
651        } else {
652            // try flush the write buffer. We use now_or_never() because
653            // 1. Drop cannot be async
654            // 2. write should usually be ready, unless the buf is full.
655            let _ = self.flush().now_or_never();
656        }
657        debug!("Dropping socket {:?}", self.stream);
658    }
659}
660
661impl AsyncRead for Stream {
662    fn poll_read(
663        mut self: Pin<&mut Self>,
664        cx: &mut Context<'_>,
665        buf: &mut ReadBuf<'_>,
666    ) -> Poll<io::Result<()>> {
667        let result = if !self.rewind_read_buf.is_empty() {
668            let data_to_read = self.rewind_read_buf.pop().unwrap(); // safe
669            let mut data_to_read = data_to_read.as_slice();
670            let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
671            // return the remaining data back to the head of rewind_read_buf
672            if !data_to_read.is_empty() {
673                let remaining_buf = Vec::from(data_to_read);
674                self.rewind_read_buf.push(remaining_buf);
675            }
676            result
677        } else {
678            Pin::new(&mut self.stream_mut()).poll_read(cx, buf)
679        };
680        self.read_pending_time.poll_time(&result);
681        self.rx_ts = self.stream().get_ref().rx_ts;
682        result
683    }
684}
685
686impl AsyncWrite for Stream {
687    fn poll_write(
688        mut self: Pin<&mut Self>,
689        cx: &mut Context,
690        buf: &[u8],
691    ) -> Poll<io::Result<usize>> {
692        let result = if self.buffer_write {
693            Pin::new(&mut self.stream_mut()).poll_write(cx, buf)
694        } else {
695            Pin::new(&mut self.stream_mut().get_mut()).poll_write(cx, buf)
696        };
697        self.write_pending_time.poll_write_time(&result, buf.len());
698        result
699    }
700
701    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
702        let result = Pin::new(&mut self.stream_mut()).poll_flush(cx);
703        self.write_pending_time.poll_time(&result);
704        result
705    }
706
707    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
708        Pin::new(&mut self.stream_mut()).poll_shutdown(cx)
709    }
710
711    fn poll_write_vectored(
712        mut self: Pin<&mut Self>,
713        cx: &mut Context<'_>,
714        bufs: &[std::io::IoSlice<'_>],
715    ) -> Poll<io::Result<usize>> {
716        let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
717
718        let result = if self.buffer_write {
719            Pin::new(&mut self.stream_mut()).poll_write_vectored(cx, bufs)
720        } else {
721            Pin::new(&mut self.stream_mut().get_mut()).poll_write_vectored(cx, bufs)
722        };
723
724        self.write_pending_time.poll_write_time(&result, total_size);
725        result
726    }
727
728    fn is_write_vectored(&self) -> bool {
729        if self.buffer_write {
730            self.stream().is_write_vectored() // it is true
731        } else {
732            self.stream().get_ref().is_write_vectored()
733        }
734    }
735}
736
737pub mod async_write_vec {
738    use bytes::Buf;
739    use futures::ready;
740    use std::future::Future;
741    use std::io::IoSlice;
742    use std::pin::Pin;
743    use std::task::{Context, Poll};
744    use tokio::io;
745    use tokio::io::AsyncWrite;
746
747    /*
748        the missing write_buf https://github.com/tokio-rs/tokio/pull/3156#issuecomment-738207409
749        https://github.com/tokio-rs/tokio/issues/2610
750        In general vectored write is lost when accessing the trait object: Box<S: AsyncWrite>
751    */
752
753    #[must_use = "futures do nothing unless you `.await` or poll them"]
754    pub struct WriteVec<'a, W, B> {
755        writer: &'a mut W,
756        buf: &'a mut B,
757    }
758
759    #[must_use = "futures do nothing unless you `.await` or poll them"]
760    pub struct WriteVecAll<'a, W, B> {
761        writer: &'a mut W,
762        buf: &'a mut B,
763    }
764
765    pub trait AsyncWriteVec {
766        fn poll_write_vec<B: Buf>(
767            self: Pin<&mut Self>,
768            _cx: &mut Context<'_>,
769            _buf: &mut B,
770        ) -> Poll<io::Result<usize>>;
771
772        fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
773        where
774            Self: Sized,
775            B: Buf,
776        {
777            WriteVec {
778                writer: self,
779                buf: src,
780            }
781        }
782
783        fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
784        where
785            Self: Sized,
786            B: Buf,
787        {
788            WriteVecAll {
789                writer: self,
790                buf: src,
791            }
792        }
793    }
794
795    impl<W, B> Future for WriteVec<'_, W, B>
796    where
797        W: AsyncWriteVec + Unpin,
798        B: Buf,
799    {
800        type Output = io::Result<usize>;
801
802        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
803            let me = &mut *self;
804            Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
805        }
806    }
807
808    impl<W, B> Future for WriteVecAll<'_, W, B>
809    where
810        W: AsyncWriteVec + Unpin,
811        B: Buf,
812    {
813        type Output = io::Result<()>;
814
815        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
816            let me = &mut *self;
817            while me.buf.has_remaining() {
818                let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
819                if n == 0 {
820                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
821                }
822            }
823            Poll::Ready(Ok(()))
824        }
825    }
826
827    /* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */
828    impl<T> AsyncWriteVec for T
829    where
830        T: AsyncWrite,
831    {
832        fn poll_write_vec<B: Buf>(
833            self: Pin<&mut Self>,
834            ctx: &mut Context,
835            buf: &mut B,
836        ) -> Poll<io::Result<usize>> {
837            const MAX_BUFS: usize = 64;
838
839            if !buf.has_remaining() {
840                return Poll::Ready(Ok(0));
841            }
842
843            let n = if self.is_write_vectored() {
844                let mut slices = [IoSlice::new(&[]); MAX_BUFS];
845                let cnt = buf.chunks_vectored(&mut slices);
846                ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
847            } else {
848                ready!(self.poll_write(ctx, buf.chunk()))?
849            };
850
851            buf.advance(n);
852
853            Poll::Ready(Ok(n))
854        }
855    }
856}
857
858pub use async_write_vec::AsyncWriteVec;
859
/// Sums up the time an operation spends waiting, across many polls.
#[derive(Debug)]
struct AccumulatedDuration {
    /// Sum of all completed waiting intervals.
    total: Duration,
    /// Start of the interval currently in progress, if any.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    /// Creates a tracker with zero accumulated time and no interval running.
    fn new() -> Self {
        Self {
            total: Duration::ZERO,
            last_start: None,
        }
    }

    /// Opens an interval unless one is already running.
    fn start(&mut self) {
        self.last_start.get_or_insert_with(Instant::now);
    }

    /// Closes the running interval, if any, and adds its length to `total`.
    fn stop(&mut self) {
        let waited = self
            .last_start
            .take()
            .map_or(Duration::ZERO, |began| began.elapsed());
        self.total += waited;
    }

    /// Updates the tracker from a write result: a full write of `buf_size`
    /// bytes or an error closes the interval, while a partial write or a
    /// pending result opens (or keeps) one.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop();
        } else {
            self.start();
        }
    }

    /// Updates the tracker from a read/flush result: any ready result closes
    /// the interval, a pending one opens (or keeps) it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop();
        } else {
            self.start();
        }
    }
}
912
913#[cfg(test)]
914#[cfg(target_os = "linux")]
915mod tests {
916    use super::*;
917    use std::sync::Arc;
918    use tokio::io::AsyncReadExt;
919    use tokio::io::AsyncWriteExt;
920    use tokio::net::TcpListener;
921    use tokio::sync::Notify;
922
923    #[cfg(target_os = "linux")]
924    #[tokio::test]
925    async fn test_rx_timestamp() {
926        let message = "hello world".as_bytes();
927        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
928        let addr = listener.local_addr().unwrap();
929        let notify = Arc::new(Notify::new());
930        let notify2 = notify.clone();
931
932        tokio::spawn(async move {
933            let (mut stream, _) = listener.accept().await.unwrap();
934            notify2.notified().await;
935            stream.write_all(message).await.unwrap();
936        });
937
938        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
939        stream.set_rx_timestamp().unwrap();
940        // Receive the message
941        // setsockopt for SO_TIMESTAMPING is asynchronous so sleep a little bit
942        // to let kernel do the work
943        std::thread::sleep(Duration::from_micros(100));
944        notify.notify_one();
945
946        let mut buffer = vec![0u8; message.len()];
947        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
948        assert_eq!(n, message.len());
949        assert!(stream.rx_ts.is_some());
950    }
951
952    #[cfg(target_os = "linux")]
953    #[tokio::test]
954    async fn test_rx_timestamp_standard_path() {
955        let message = "hello world".as_bytes();
956        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
957        let addr = listener.local_addr().unwrap();
958        let notify = Arc::new(Notify::new());
959        let notify2 = notify.clone();
960
961        tokio::spawn(async move {
962            let (mut stream, _) = listener.accept().await.unwrap();
963            notify2.notified().await;
964            stream.write_all(message).await.unwrap();
965        });
966
967        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
968        std::thread::sleep(Duration::from_micros(100));
969        notify.notify_one();
970
971        let mut buffer = vec![0u8; message.len()];
972        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
973        assert_eq!(n, message.len());
974        assert!(stream.rx_ts.is_none());
975    }
976
977    #[tokio::test]
978    async fn test_stream_rewind() {
979        let message = b"hello world";
980        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
981        let addr = listener.local_addr().unwrap();
982        let notify = Arc::new(Notify::new());
983        let notify2 = notify.clone();
984
985        tokio::spawn(async move {
986            let (mut stream, _) = listener.accept().await.unwrap();
987            notify2.notified().await;
988            stream.write_all(message).await.unwrap();
989        });
990
991        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
992
993        let rewind_test = b"this is Sparta!";
994        stream.rewind(rewind_test);
995
996        // partially read rewind_test because of the buffer size limit
997        let mut buffer = vec![0u8; message.len()];
998        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
999        assert_eq!(n, message.len());
1000        assert_eq!(buffer, rewind_test[..message.len()]);
1001
1002        // read the rest of rewind_test
1003        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
1004        assert_eq!(n, rewind_test.len() - message.len());
1005        assert_eq!(buffer[..n], rewind_test[message.len()..]);
1006
1007        // read the actual data
1008        notify.notify_one();
1009        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
1010        assert_eq!(n, message.len());
1011        assert_eq!(buffer, message);
1012    }
1013
1014    #[tokio::test]
1015    async fn test_stream_peek() {
1016        let message = b"hello world";
1017        dbg!("try peek");
1018        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1019        let addr = listener.local_addr().unwrap();
1020        let notify = Arc::new(Notify::new());
1021        let notify2 = notify.clone();
1022
1023        tokio::spawn(async move {
1024            let (mut stream, _) = listener.accept().await.unwrap();
1025            notify2.notified().await;
1026            stream.write_all(message).await.unwrap();
1027            drop(stream);
1028        });
1029
1030        notify.notify_one();
1031
1032        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
1033        let mut buffer = vec![0u8; 5];
1034        assert!(stream.try_peek(&mut buffer).await.unwrap());
1035        assert_eq!(buffer, message[0..5]);
1036        let mut buffer = vec![];
1037        stream.read_to_end(&mut buffer).await.unwrap();
1038        assert_eq!(buffer, message);
1039    }
1040
1041    #[tokio::test]
1042    async fn test_stream_two_subsequent_peek_calls_before_read() {
1043        let message = b"abcdefghijklmn";
1044
1045        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1046        let addr = listener.local_addr().unwrap();
1047        let notify = Arc::new(Notify::new());
1048        let notify2 = notify.clone();
1049
1050        tokio::spawn(async move {
1051            let (mut stream, _) = listener.accept().await.unwrap();
1052            notify2.notified().await;
1053            stream.write_all(message).await.unwrap();
1054            drop(stream);
1055        });
1056
1057        notify.notify_one();
1058
1059        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
1060
1061        // Peek 4 bytes
1062        let mut buffer = vec![0u8; 4];
1063        assert!(stream.try_peek(&mut buffer).await.unwrap());
1064        assert_eq!(buffer, message[0..4]);
1065
1066        // Peek 2 bytes
1067        let mut buffer = vec![0u8; 2];
1068        assert!(stream.try_peek(&mut buffer).await.unwrap());
1069        assert_eq!(buffer, message[0..2]);
1070
1071        // Read 1 byte: ['a']
1072        let mut buffer = vec![0u8; 1];
1073        stream.read_exact(&mut buffer).await.unwrap();
1074        assert_eq!(buffer, message[0..1]);
1075
1076        // Read as many bytes as possible, return 1 byte ['b']
1077        //  from the first retry buffer chunk
1078        let mut buffer = vec![0u8; 100];
1079        let n = stream.read(&mut buffer).await.unwrap();
1080        assert_eq!(n, 1);
1081        assert_eq!(buffer[..n], message[1..2]);
1082
1083        // Read the rest ['cdefghijklmn']
1084        let mut buffer = vec![];
1085        stream.read_to_end(&mut buffer).await.unwrap();
1086        assert_eq!(buffer, message[2..]);
1087    }
1088}