Skip to main content

pingora_core/protocols/l4/
stream.rs

// Copyright 2026 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Transport layer connection
16
17use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::l4::virt;
41use crate::protocols::raw_connect::ProxyDigest;
42use crate::protocols::{
43    GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
44    TimingDigest, UniqueID, UniqueIDType,
45};
46use crate::upstreams::peer::Tracer;
47
/// The concrete transport underneath a [`Stream`].
#[derive(Debug)]
enum RawStream {
    /// A tokio TCP connection.
    Tcp(TcpStream),
    /// A tokio unix domain socket connection (unix targets only).
    #[cfg(unix)]
    Unix(UnixStream),
    /// A virtual socket; it has no real OS file descriptor (see `AsRawFd for RawStream`).
    Virtual(virt::VirtualSocketStream),
}
55
56impl AsyncRead for RawStream {
57    fn poll_read(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &mut ReadBuf<'_>,
61    ) -> Poll<io::Result<()>> {
62        // Safety: Basic enum pin projection
63        unsafe {
64            match &mut Pin::get_unchecked_mut(self) {
65                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66                #[cfg(unix)]
67                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
68                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
69            }
70        }
71    }
72}
73
74impl AsyncWrite for RawStream {
75    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
76        // Safety: Basic enum pin projection
77        unsafe {
78            match &mut Pin::get_unchecked_mut(self) {
79                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
80                #[cfg(unix)]
81                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
82                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
83            }
84        }
85    }
86
87    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
88        // Safety: Basic enum pin projection
89        unsafe {
90            match &mut Pin::get_unchecked_mut(self) {
91                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
92                #[cfg(unix)]
93                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
94                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
95            }
96        }
97    }
98
99    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
100        // Safety: Basic enum pin projection
101        unsafe {
102            match &mut Pin::get_unchecked_mut(self) {
103                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
104                #[cfg(unix)]
105                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
106                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
107            }
108        }
109    }
110
111    fn poll_write_vectored(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        bufs: &[std::io::IoSlice<'_>],
115    ) -> Poll<io::Result<usize>> {
116        // Safety: Basic enum pin projection
117        unsafe {
118            match &mut Pin::get_unchecked_mut(self) {
119                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
120                #[cfg(unix)]
121                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
122                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
123            }
124        }
125    }
126
127    fn is_write_vectored(&self) -> bool {
128        match self {
129            RawStream::Tcp(s) => s.is_write_vectored(),
130            #[cfg(unix)]
131            RawStream::Unix(s) => s.is_write_vectored(),
132            RawStream::Virtual(s) => s.is_write_vectored(),
133        }
134    }
135}
136
137#[cfg(unix)]
138impl AsRawFd for RawStream {
139    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
140        match self {
141            RawStream::Tcp(s) => s.as_raw_fd(),
142            RawStream::Unix(s) => s.as_raw_fd(),
143            RawStream::Virtual(_) => -1, // Virtual stream does not have a real fd
144        }
145    }
146}
147
#[cfg(windows)]
impl AsRawSocket for RawStream {
    /// The OS socket handle of the wrapped stream.
    ///
    /// `RawStream::Virtual` is not cfg-gated, so it must be handled here too; without that arm
    /// this match is non-exhaustive and the crate fails to build on Windows.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        match self {
            RawStream::Tcp(s) => s.as_raw_socket(),
            // A virtual stream has no real socket. Mirror the unix `-1` sentinel with
            // Windows' INVALID_SOCKET value (all bits set).
            RawStream::Virtual(_) => !0,
        }
    }
}
156
/// A [`RawStream`] plus the state used (on Linux) to capture per-read rx timestamps.
#[derive(Debug)]
struct RawStreamWrapper {
    /// The wrapped transport.
    pub(crate) stream: RawStream,
    /// store the last rx timestamp of the stream.
    /// Only written by the Linux recvmsg read path; it keeps its previous value when a read
    /// carries no timestamp.
    pub(crate) rx_ts: Option<SystemTime>,
    /// enable reading rx timestamp (switches TCP reads to the recvmsg path)
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    #[cfg(target_os = "linux")]
    /// This can be reused across multiple recvmsg calls. The cmsg buffer may
    /// come from old sockets created by older version of pingora and so,
    /// this vector can only grow.
    reusable_cmsg_space: Vec<u8>,
}
171
172impl RawStreamWrapper {
173    pub fn new(stream: RawStream) -> Self {
174        RawStreamWrapper {
175            stream,
176            rx_ts: None,
177            #[cfg(target_os = "linux")]
178            enable_rx_ts: false,
179            #[cfg(target_os = "linux")]
180            reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
181        }
182    }
183
184    #[cfg(target_os = "linux")]
185    pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
186        self.enable_rx_ts = enable_rx_ts;
187    }
188}
189
impl AsyncRead for RawStreamWrapper {
    /// Non-Linux read path: rx timestamps are not supported, so the read is forwarded to the
    /// wrapped transport unchanged.
    #[cfg(not(target_os = "linux"))]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Safety: Basic enum pin projection
        unsafe {
            let rs_wrapper = Pin::get_unchecked_mut(self);
            match &mut rs_wrapper.stream {
                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                #[cfg(unix)]
                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
            }
        }
    }

    /// Linux read path.
    ///
    /// When `enable_rx_ts` is false (the value `RawStreamWrapper::new` sets), the read is
    /// forwarded to the wrapped transport unchanged. When it is true, TCP reads go through
    /// `recvmsg` so that an `ScmTimestampsns` control message can be turned into `self.rx_ts`.
    /// Unix and virtual streams always take their regular read path.
    #[cfg(target_os = "linux")]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        use futures::ready;
        use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};

        // if we do not need rx timestamp, then use the standard path
        if !self.enable_rx_ts {
            // Safety: Basic enum pin projection
            unsafe {
                let rs_wrapper = Pin::get_unchecked_mut(self);
                match &mut rs_wrapper.stream {
                    RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Virtual(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                }
            }
        }

        // Safety: Basic pin projection to get mutable stream
        let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
        match &mut rs_wrapper.stream {
            RawStream::Tcp(s) => {
                loop {
                    // Returns Poll::Pending (after registering the waker) until tokio reports
                    // the socket readable.
                    ready!(s.poll_read_ready(cx))?;
                    // Safety: maybe uninitialized bytes will only be passed to recvmsg
                    let b = unsafe {
                        &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
                            as *mut [u8])
                    };
                    let mut iov = [IoSliceMut::new(b)];
                    // Resets only the length; the allocation is kept for the next call.
                    // NOTE(review): the kernel truncates control messages that do not fit
                    // (MSG_CTRUNC), so this buffer's capacity must cover the whole
                    // timestamping message — confirm against how it is allocated.
                    rs_wrapper.reusable_cmsg_space.clear();

                    // Per tokio's `try_io` docs, a WouldBlock result clears the cached
                    // readiness, so the next `poll_read_ready` waits for a fresh event.
                    match s.try_io(Interest::READABLE, || {
                        recvmsg::<SockaddrStorage>(
                            s.as_raw_fd(),
                            &mut iov,
                            Some(&mut rs_wrapper.reusable_cmsg_space),
                            MsgFlags::empty(),
                        )
                        .map_err(|errno| errno.into())
                    }) {
                        Ok(r) => {
                            // Only the first timestamping message is used. If none arrived,
                            // `rx_ts` keeps its previous value.
                            if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
                                .cmsgs()
                                .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
                            {
                                // The returned timestamp is a real (i.e. not monotonic) timestamp
                                // https://docs.kernel.org/networking/timestamping.html
                                // `checked_add` yields None if the result overflows SystemTime.
                                rs_wrapper.rx_ts =
                                    SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
                            }
                            // Safety: We trust `recvmsg` to have filled up `r.bytes` bytes in the buffer.
                            unsafe {
                                buf.assume_init(r.bytes);
                            }
                            buf.advance(r.bytes);
                            return Poll::Ready(Ok(()));
                        }
                        // Spurious readiness: loop back and wait for the socket again.
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
            }
            // Unix RX timestamp only works with datagram for now, so we do not care about it
            RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
            RawStream::Virtual(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
        }
    }
}
282
283impl AsyncWrite for RawStreamWrapper {
284    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
285        // Safety: Basic enum pin projection
286        unsafe {
287            match &mut Pin::get_unchecked_mut(self).stream {
288                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
289                #[cfg(unix)]
290                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
291                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
292            }
293        }
294    }
295
296    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
297        // Safety: Basic enum pin projection
298        unsafe {
299            match &mut Pin::get_unchecked_mut(self).stream {
300                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
301                #[cfg(unix)]
302                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
303                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
304            }
305        }
306    }
307
308    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
309        // Safety: Basic enum pin projection
310        unsafe {
311            match &mut Pin::get_unchecked_mut(self).stream {
312                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
313                #[cfg(unix)]
314                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
315                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
316            }
317        }
318    }
319
320    fn poll_write_vectored(
321        self: Pin<&mut Self>,
322        cx: &mut Context<'_>,
323        bufs: &[std::io::IoSlice<'_>],
324    ) -> Poll<io::Result<usize>> {
325        // Safety: Basic enum pin projection
326        unsafe {
327            match &mut Pin::get_unchecked_mut(self).stream {
328                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
329                #[cfg(unix)]
330                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
331                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
332            }
333        }
334    }
335
336    fn is_write_vectored(&self) -> bool {
337        self.stream.is_write_vectored()
338    }
339}
340
341#[cfg(unix)]
342impl AsRawFd for RawStreamWrapper {
343    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
344        self.stream.as_raw_fd()
345    }
346}
347
348#[cfg(windows)]
349impl AsRawSocket for RawStreamWrapper {
350    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
351        self.stream.as_raw_socket()
352    }
353}
354
// Reading in large chunks cuts the number of syscalls at little cost. The TLS layer always
// reads one record (16k) at a time, so an L4 read buffer larger than that pays off.
const BUF_READ_SIZE: usize = 1 << 16; // 64 KiB
// The write buffer is deliberately small, sized to the MSS, because a larger one would hold
// back real-time traffic. It behaves much like Nagle's algorithm, except that user space
// decides when to flush. It saves both syscalls and small packets.
const BUF_WRITE_SIZE: usize = 1460;

// NOTE: because writes are buffered, callers must flush() to make sure data is actually sent.
// Otherwise it may sit in the buffer indefinitely or be lost when the stream is closed.
366
/// A concrete type for transport layer connection + extra fields for logging
///
/// I/O goes through a `BufStream` around the raw transport; see the `AsyncRead` and
/// `AsyncWrite` impls for `Stream` below.
#[derive(Debug)]
pub struct Stream {
    // Use `Option` to be able to swap to adjust the buffer size. Always safe to unwrap
    stream: Option<BufStream<RawStreamWrapper>>,
    // the data put back at the front of the read buffer, in order to replay the read.
    // Used as a stack: `poll_read` pops from the end, so the chunk rewound last is read first.
    rewind_read_buf: Vec<Vec<u8>>,
    // When false, writes bypass the BufStream write buffer and go to the raw stream directly.
    buffer_write: bool,
    // Set through `GetProxyDigest::set_proxy_digest`.
    proxy_digest: Option<Arc<ProxyDigest>>,
    // Set through `GetSocketDigest::set_socket_digest`.
    socket_digest: Option<Arc<SocketDigest>>,
    /// When this connection is established (wall-clock time at which this `Stream` was
    /// constructed from a raw socket)
    pub established_ts: SystemTime,
    /// The distributed tracing object for this stream
    pub tracer: Option<Tracer>,
    // Accumulated time reads spent pending; reported by `get_read_pending_time`.
    read_pending_time: AccumulatedDuration,
    // Accumulated time writes/flushes spent pending; reported by `get_write_pending_time`.
    write_pending_time: AccumulatedDuration,
    /// Last rx timestamp associated with the last recvmsg call.
    pub rx_ts: Option<SystemTime>,
}
386
387impl Stream {
388    fn stream(&self) -> &BufStream<RawStreamWrapper> {
389        self.stream.as_ref().expect("stream should always be set")
390    }
391
392    fn stream_mut(&mut self) -> &mut BufStream<RawStreamWrapper> {
393        self.stream.as_mut().expect("stream should always be set")
394    }
395
396    /// set TCP nodelay for this connection if `self` is TCP
397    pub fn set_nodelay(&mut self) -> Result<()> {
398        match &self.stream_mut().get_mut().stream {
399            RawStream::Tcp(s) => {
400                s.set_nodelay(true)
401                    .or_err(ConnectError, "failed to set_nodelay")?;
402            }
403            RawStream::Virtual(s) => {
404                s.set_socket_option(virt::VirtualSockOpt::NoDelay)
405                    .or_err(ConnectError, "failed to set_nodelay on virtual socket")?;
406            }
407            _ => (),
408        }
409        Ok(())
410    }
411
412    /// set TCP keepalive settings for this connection if `self` is TCP
413    pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
414        match &self.stream_mut().get_mut().stream {
415            RawStream::Tcp(s) => {
416                debug!("Setting tcp keepalive");
417                set_tcp_keepalive(s, ka)?;
418            }
419            RawStream::Virtual(s) => {
420                s.set_socket_option(virt::VirtualSockOpt::KeepAlive(ka.clone()))
421                    .or_err(ConnectError, "failed to set_keepalive on virtual socket")?;
422            }
423            _ => (),
424        }
425        Ok(())
426    }
427
428    #[cfg(target_os = "linux")]
429    pub fn set_rx_timestamp(&mut self) -> Result<()> {
430        use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
431
432        if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
433            let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
434                | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
435            setsockopt(s.as_raw_fd(), sockopt::Timestamping, &timestamp_options)
436                .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
437            self.stream_mut().get_mut().enable_rx_ts(true);
438        }
439
440        Ok(())
441    }
442
443    #[cfg(not(target_os = "linux"))]
444    pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
445        Ok(())
446    }
447
448    /// Put Some data back to the head of the stream to be read again
449    pub(crate) fn rewind(&mut self, data: &[u8]) {
450        if !data.is_empty() {
451            self.rewind_read_buf.push(data.to_vec());
452        }
453    }
454
455    /// Set the buffer of BufStream
456    /// It is only set later because of the malloc overhead in critical accept() path
457    pub(crate) fn set_buffer(&mut self) {
458        use std::mem;
459        // Since BufStream doesn't provide an API to adjust the buf directly,
460        // we take the raw stream out of it and put it in a new BufStream with the size we want
461        let stream = mem::take(&mut self.stream);
462        let stream =
463            stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner()));
464        let _ = mem::replace(&mut self.stream, stream);
465    }
466}
467
468impl From<TcpStream> for Stream {
469    fn from(s: TcpStream) -> Self {
470        Stream {
471            stream: Some(BufStream::with_capacity(
472                0,
473                0,
474                RawStreamWrapper::new(RawStream::Tcp(s)),
475            )),
476            rewind_read_buf: Vec::new(),
477            buffer_write: true,
478            established_ts: SystemTime::now(),
479            proxy_digest: None,
480            socket_digest: None,
481            tracer: None,
482            read_pending_time: AccumulatedDuration::new(),
483            write_pending_time: AccumulatedDuration::new(),
484            rx_ts: None,
485        }
486    }
487}
488
489impl From<virt::VirtualSocketStream> for Stream {
490    fn from(s: virt::VirtualSocketStream) -> Self {
491        Stream {
492            stream: Some(BufStream::with_capacity(
493                0,
494                0,
495                RawStreamWrapper::new(RawStream::Virtual(s)),
496            )),
497            rewind_read_buf: Vec::new(),
498            buffer_write: true,
499            established_ts: SystemTime::now(),
500            proxy_digest: None,
501            socket_digest: None,
502            tracer: None,
503            read_pending_time: AccumulatedDuration::new(),
504            write_pending_time: AccumulatedDuration::new(),
505            rx_ts: None,
506        }
507    }
508}
509
510#[cfg(unix)]
511impl From<UnixStream> for Stream {
512    fn from(s: UnixStream) -> Self {
513        Stream {
514            stream: Some(BufStream::with_capacity(
515                0,
516                0,
517                RawStreamWrapper::new(RawStream::Unix(s)),
518            )),
519            rewind_read_buf: Vec::new(),
520            buffer_write: true,
521            established_ts: SystemTime::now(),
522            proxy_digest: None,
523            socket_digest: None,
524            tracer: None,
525            read_pending_time: AccumulatedDuration::new(),
526            write_pending_time: AccumulatedDuration::new(),
527            rx_ts: None,
528        }
529    }
530}
531
532#[cfg(unix)]
533impl AsRawFd for Stream {
534    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
535        self.stream().get_ref().as_raw_fd()
536    }
537}
538
539#[cfg(windows)]
540impl AsRawSocket for Stream {
541    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
542        self.stream().get_ref().as_raw_socket()
543    }
544}
545
546#[cfg(unix)]
547impl UniqueID for Stream {
548    fn id(&self) -> UniqueIDType {
549        self.as_raw_fd()
550    }
551}
552
553#[cfg(windows)]
554impl UniqueID for Stream {
555    fn id(&self) -> usize {
556        self.as_raw_socket() as usize
557    }
558}
559
// This impl provides no methods of its own, so a plain L4 `Stream` uses the `Ssl` trait's
// default methods.
// NOTE(review): presumably those defaults report "no TLS" — confirm in `protocols::Ssl`.
impl Ssl for Stream {}
561
562#[async_trait]
563impl Peek for Stream {
564    async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
565        use tokio::io::AsyncReadExt;
566        self.read_exact(buf).await?;
567        // rewind regardless of what is read
568        self.rewind(buf);
569        Ok(true)
570    }
571}
572
573#[async_trait]
574impl Shutdown for Stream {
575    async fn shutdown(&mut self) {
576        AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
577            debug!("Failed to shutdown connection: {:?}", e);
578        });
579    }
580}
581
582impl GetTimingDigest for Stream {
583    fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
584        let mut digest = Vec::with_capacity(2); // expect to have both L4 stream and TLS layer
585        digest.push(Some(TimingDigest {
586            established_ts: self.established_ts,
587        }));
588        digest
589    }
590
591    fn get_read_pending_time(&self) -> Duration {
592        self.read_pending_time.total
593    }
594
595    fn get_write_pending_time(&self) -> Duration {
596        self.write_pending_time.total
597    }
598}
599
600impl GetProxyDigest for Stream {
601    fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
602        self.proxy_digest.clone()
603    }
604
605    fn set_proxy_digest(&mut self, digest: ProxyDigest) {
606        self.proxy_digest = Some(Arc::new(digest));
607    }
608}
609
610impl GetSocketDigest for Stream {
611    fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
612        self.socket_digest.clone()
613    }
614
615    fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
616        self.socket_digest = Some(Arc::new(socket_digest))
617    }
618}
619
620impl Drop for Stream {
621    fn drop(&mut self) {
622        if let Some(t) = self.tracer.as_ref() {
623            t.0.on_disconnected();
624        }
625        /* use nodelay/local_addr function to detect socket status */
626        let ret = match &self.stream().get_ref().stream {
627            RawStream::Tcp(s) => s.nodelay().err(),
628            #[cfg(unix)]
629            RawStream::Unix(s) => s.local_addr().err(),
630            RawStream::Virtual(_) => {
631                // TODO: should this do something?
632                None
633            }
634        };
635        if let Some(e) = ret {
636            match e.kind() {
637                tokio::io::ErrorKind::Other => {
638                    if let Some(ecode) = e.raw_os_error() {
639                        if ecode == 9 {
640                            // Or we could panic here
641                            error!("Crit: socket {:?} is being double closed", self.stream);
642                        }
643                    }
644                }
645                _ => {
646                    debug!("Socket is already broken {:?}", e);
647                }
648            }
649        } else {
650            // try flush the write buffer. We use now_or_never() because
651            // 1. Drop cannot be async
652            // 2. write should usually be ready, unless the buf is full.
653            let _ = self.flush().now_or_never();
654        }
655        debug!("Dropping socket {:?}", self.stream);
656    }
657}
658
659impl AsyncRead for Stream {
660    fn poll_read(
661        mut self: Pin<&mut Self>,
662        cx: &mut Context<'_>,
663        buf: &mut ReadBuf<'_>,
664    ) -> Poll<io::Result<()>> {
665        let result = if !self.rewind_read_buf.is_empty() {
666            let data_to_read = self.rewind_read_buf.pop().unwrap(); // safe
667            let mut data_to_read = data_to_read.as_slice();
668            let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
669            // return the remaining data back to the head of rewind_read_buf
670            if !data_to_read.is_empty() {
671                let remaining_buf = Vec::from(data_to_read);
672                self.rewind_read_buf.push(remaining_buf);
673            }
674            result
675        } else {
676            Pin::new(&mut self.stream_mut()).poll_read(cx, buf)
677        };
678        self.read_pending_time.poll_time(&result);
679        self.rx_ts = self.stream().get_ref().rx_ts;
680        result
681    }
682}
683
684impl AsyncWrite for Stream {
685    fn poll_write(
686        mut self: Pin<&mut Self>,
687        cx: &mut Context,
688        buf: &[u8],
689    ) -> Poll<io::Result<usize>> {
690        let result = if self.buffer_write {
691            Pin::new(&mut self.stream_mut()).poll_write(cx, buf)
692        } else {
693            Pin::new(&mut self.stream_mut().get_mut()).poll_write(cx, buf)
694        };
695        self.write_pending_time.poll_write_time(&result, buf.len());
696        result
697    }
698
699    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
700        let result = Pin::new(&mut self.stream_mut()).poll_flush(cx);
701        self.write_pending_time.poll_time(&result);
702        result
703    }
704
705    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
706        Pin::new(&mut self.stream_mut()).poll_shutdown(cx)
707    }
708
709    fn poll_write_vectored(
710        mut self: Pin<&mut Self>,
711        cx: &mut Context<'_>,
712        bufs: &[std::io::IoSlice<'_>],
713    ) -> Poll<io::Result<usize>> {
714        let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
715
716        let result = if self.buffer_write {
717            Pin::new(&mut self.stream_mut()).poll_write_vectored(cx, bufs)
718        } else {
719            Pin::new(&mut self.stream_mut().get_mut()).poll_write_vectored(cx, bufs)
720        };
721
722        self.write_pending_time.poll_write_time(&result, total_size);
723        result
724    }
725
726    fn is_write_vectored(&self) -> bool {
727        if self.buffer_write {
728            self.stream().is_write_vectored() // it is true
729        } else {
730            self.stream().get_ref().is_write_vectored()
731        }
732    }
733}
734
735pub mod async_write_vec {
736    use bytes::Buf;
737    use futures::ready;
738    use std::future::Future;
739    use std::io::IoSlice;
740    use std::pin::Pin;
741    use std::task::{Context, Poll};
742    use tokio::io;
743    use tokio::io::AsyncWrite;
744
745    /*
746        the missing write_buf https://github.com/tokio-rs/tokio/pull/3156#issuecomment-738207409
747        https://github.com/tokio-rs/tokio/issues/2610
748        In general vectored write is lost when accessing the trait object: Box<S: AsyncWrite>
749    */
750
751    #[must_use = "futures do nothing unless you `.await` or poll them"]
752    pub struct WriteVec<'a, W, B> {
753        writer: &'a mut W,
754        buf: &'a mut B,
755    }
756
757    #[must_use = "futures do nothing unless you `.await` or poll them"]
758    pub struct WriteVecAll<'a, W, B> {
759        writer: &'a mut W,
760        buf: &'a mut B,
761    }
762
763    pub trait AsyncWriteVec {
764        fn poll_write_vec<B: Buf>(
765            self: Pin<&mut Self>,
766            _cx: &mut Context<'_>,
767            _buf: &mut B,
768        ) -> Poll<io::Result<usize>>;
769
770        fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
771        where
772            Self: Sized,
773            B: Buf,
774        {
775            WriteVec {
776                writer: self,
777                buf: src,
778            }
779        }
780
781        fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
782        where
783            Self: Sized,
784            B: Buf,
785        {
786            WriteVecAll {
787                writer: self,
788                buf: src,
789            }
790        }
791    }
792
793    impl<W, B> Future for WriteVec<'_, W, B>
794    where
795        W: AsyncWriteVec + Unpin,
796        B: Buf,
797    {
798        type Output = io::Result<usize>;
799
800        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
801            let me = &mut *self;
802            Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
803        }
804    }
805
806    impl<W, B> Future for WriteVecAll<'_, W, B>
807    where
808        W: AsyncWriteVec + Unpin,
809        B: Buf,
810    {
811        type Output = io::Result<()>;
812
813        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
814            let me = &mut *self;
815            while me.buf.has_remaining() {
816                let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
817                if n == 0 {
818                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
819                }
820            }
821            Poll::Ready(Ok(()))
822        }
823    }
824
825    /* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */
826    impl<T> AsyncWriteVec for T
827    where
828        T: AsyncWrite,
829    {
830        fn poll_write_vec<B: Buf>(
831            self: Pin<&mut Self>,
832            ctx: &mut Context,
833            buf: &mut B,
834        ) -> Poll<io::Result<usize>> {
835            const MAX_BUFS: usize = 64;
836
837            if !buf.has_remaining() {
838                return Poll::Ready(Ok(0));
839            }
840
841            let n = if self.is_write_vectored() {
842                let mut slices = [IoSlice::new(&[]); MAX_BUFS];
843                let cnt = buf.chunks_vectored(&mut slices);
844                ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
845            } else {
846                ready!(self.poll_write(ctx, buf.chunk()))?
847            };
848
849            buf.advance(n);
850
851            Poll::Ready(Ok(n))
852        }
853    }
854}
855
856pub use async_write_vec::AsyncWriteVec;
857
/// Running total of time spent between a poll that is still pending and the poll that
/// finally completes.
#[derive(Debug)]
struct AccumulatedDuration {
    /// Sum of all completed measurements.
    total: Duration,
    /// When the in-progress measurement began, if any.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    /// An empty accumulator with nothing in progress.
    fn new() -> Self {
        Self {
            total: Duration::ZERO,
            last_start: None,
        }
    }

    /// Begin a measurement unless one is already running (the earliest start is kept).
    fn start(&mut self) {
        self.last_start = self.last_start.or_else(|| Some(Instant::now()));
    }

    /// Finish the running measurement, if any, adding its elapsed time to `total`.
    fn stop(&mut self) {
        let Some(began) = self.last_start.take() else {
            return;
        };
        self.total += began.elapsed();
    }

    /// Account for a write poll. A write counts as finished when it errored or when it
    /// wrote all `buf_size` bytes. Pending polls and partial writes keep the clock running.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop();
        } else {
            self.start();
        }
    }

    /// Account for a read/flush poll: any `Ready` outcome stops the clock, `Pending` starts it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop();
        } else {
            self.start();
        }
    }
}
910
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    // The whole module is gated on `target_os = "linux"`, so the individual
    // tests below need no platform `cfg` of their own.
    use super::*;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    /// With RX timestamping enabled, a read records a receive timestamp.
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // setsockopt for SO_TIMESTAMPING is asynchronous so sleep a little bit
        // to let kernel do the work
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        // Receive the message
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    /// Without RX timestamping enabled, a read leaves `rx_ts` unset.
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_none());
    }

    /// Rewound bytes are returned by reads before any data from the socket.
    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // partially read rewind_test because of the buffer size limit
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // read the rest of rewind_test
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // read the actual data
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    /// Peeked bytes are still returned by a subsequent read.
    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }

    /// Two peeks of different lengths before any read still yield the
    /// original byte sequence on subsequent reads.
    #[tokio::test]
    async fn test_stream_two_subsequent_peek_calls_before_read() {
        let message = b"abcdefghijklmn";

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        // Peek 4 bytes
        let mut buffer = vec![0u8; 4];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..4]);

        // Peek 2 bytes
        let mut buffer = vec![0u8; 2];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..2]);

        // Read 1 byte: ['a']
        let mut buffer = vec![0u8; 1];
        stream.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[0..1]);

        // Read as many bytes as possible, return 1 byte ['b']
        //  from the first retry buffer chunk
        let mut buffer = vec![0u8; 100];
        let n = stream.read(&mut buffer).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(buffer[..n], message[1..2]);

        // Read the rest ['cdefghijklmn']
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[2..]);
    }
}