pingora_core/protocols/l4/
stream.rs

1// Copyright 2024 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Transport layer connection
16
17use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::raw_connect::ProxyDigest;
41use crate::protocols::{
42    GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
43    TimingDigest, UniqueID, UniqueIDType,
44};
45use crate::upstreams::peer::Tracer;
46
/// The raw socket underneath a stream: either a TCP socket or, on Unix
/// targets, a Unix domain socket.
#[derive(Debug)]
enum RawStream {
    /// A TCP socket.
    Tcp(TcpStream),
    /// A Unix domain socket (only compiled on Unix targets).
    #[cfg(unix)]
    Unix(UnixStream),
}
53
54impl AsyncRead for RawStream {
55    fn poll_read(
56        self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58        buf: &mut ReadBuf<'_>,
59    ) -> Poll<io::Result<()>> {
60        // Safety: Basic enum pin projection
61        unsafe {
62            match &mut Pin::get_unchecked_mut(self) {
63                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
64                #[cfg(unix)]
65                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66            }
67        }
68    }
69}
70
71impl AsyncWrite for RawStream {
72    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
73        // Safety: Basic enum pin projection
74        unsafe {
75            match &mut Pin::get_unchecked_mut(self) {
76                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
77                #[cfg(unix)]
78                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
79            }
80        }
81    }
82
83    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
84        // Safety: Basic enum pin projection
85        unsafe {
86            match &mut Pin::get_unchecked_mut(self) {
87                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
88                #[cfg(unix)]
89                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
90            }
91        }
92    }
93
94    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
95        // Safety: Basic enum pin projection
96        unsafe {
97            match &mut Pin::get_unchecked_mut(self) {
98                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
99                #[cfg(unix)]
100                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
101            }
102        }
103    }
104
105    fn poll_write_vectored(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        bufs: &[std::io::IoSlice<'_>],
109    ) -> Poll<io::Result<usize>> {
110        // Safety: Basic enum pin projection
111        unsafe {
112            match &mut Pin::get_unchecked_mut(self) {
113                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
114                #[cfg(unix)]
115                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
116            }
117        }
118    }
119
120    fn is_write_vectored(&self) -> bool {
121        match self {
122            RawStream::Tcp(s) => s.is_write_vectored(),
123            #[cfg(unix)]
124            RawStream::Unix(s) => s.is_write_vectored(),
125        }
126    }
127}
128
129#[cfg(unix)]
130impl AsRawFd for RawStream {
131    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
132        match self {
133            RawStream::Tcp(s) => s.as_raw_fd(),
134            RawStream::Unix(s) => s.as_raw_fd(),
135        }
136    }
137}
138
#[cfg(windows)]
impl AsRawSocket for RawStream {
    /// Returns the raw socket handle of the TCP socket.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        match self {
            RawStream::Tcp(tcp) => AsRawSocket::as_raw_socket(tcp),
        }
    }
}
147
/// A [`RawStream`] plus the state needed to capture receive (RX) timestamps
/// while reading.
#[derive(Debug)]
struct RawStreamWrapper {
    /// The wrapped socket.
    pub(crate) stream: RawStream,
    /// store the last rx timestamp of the stream.
    pub(crate) rx_ts: Option<SystemTime>,
    /// enable reading rx timestamp
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    #[cfg(target_os = "linux")]
    /// This can be reused across multiple recvmsg calls. The cmsg buffer may
    /// come from old sockets created by older version of pingora and so,
    /// this vector can only grow.
    reusable_cmsg_space: Vec<u8>,
}
162
163impl RawStreamWrapper {
164    pub fn new(stream: RawStream) -> Self {
165        RawStreamWrapper {
166            stream,
167            rx_ts: None,
168            #[cfg(target_os = "linux")]
169            enable_rx_ts: false,
170            #[cfg(target_os = "linux")]
171            reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
172        }
173    }
174
175    #[cfg(target_os = "linux")]
176    pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
177        self.enable_rx_ts = enable_rx_ts;
178    }
179}
180
impl AsyncRead for RawStreamWrapper {
    /// Non-Linux read path: forwards the read straight to the wrapped socket.
    /// No RX timestamp is collected here.
    #[cfg(not(target_os = "linux"))]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Safety: Basic enum pin projection
        unsafe {
            let rs_wrapper = Pin::get_unchecked_mut(self);
            match &mut rs_wrapper.stream {
                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                #[cfg(unix)]
                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
            }
        }
    }

    /// Linux read path.
    ///
    /// When `enable_rx_ts` is off, or the socket is a Unix socket, this is a
    /// plain forwarded read. When it is on for a TCP socket, the data is read
    /// with `recvmsg` so that a `ScmTimestampsns` control message, if present,
    /// can be stored in `rx_ts`.
    #[cfg(target_os = "linux")]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        use futures::ready;
        use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};

        // if we do not need rx timestamp, then use the standard path
        if !self.enable_rx_ts {
            // Safety: Basic enum pin projection
            unsafe {
                let rs_wrapper = Pin::get_unchecked_mut(self);
                match &mut rs_wrapper.stream {
                    RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                }
            }
        }

        // Safety: Basic pin projection to get mutable stream
        let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
        match &mut rs_wrapper.stream {
            RawStream::Tcp(s) => {
                loop {
                    // Returns `Pending` until the socket reports read readiness;
                    // a readiness error is returned to the caller.
                    ready!(s.poll_read_ready(cx))?;
                    // Safety: maybe uninitialized bytes will only be passed to recvmsg
                    let b = unsafe {
                        &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
                            as *mut [u8])
                    };
                    let mut iov = [IoSliceMut::new(b)];
                    // Empty the control-message buffer before handing it to recvmsg again.
                    rs_wrapper.reusable_cmsg_space.clear();

                    // Perform the read via `try_io` with `recvmsg` on the raw fd,
                    // converting the errno into an `io::Error`.
                    match s.try_io(Interest::READABLE, || {
                        recvmsg::<SockaddrStorage>(
                            s.as_raw_fd(),
                            &mut iov,
                            Some(&mut rs_wrapper.reusable_cmsg_space),
                            MsgFlags::empty(),
                        )
                        .map_err(|errno| errno.into())
                    }) {
                        Ok(r) => {
                            // Only the first `ScmTimestampsns` control message is used;
                            // when none is present `rx_ts` keeps its previous value.
                            if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
                                .cmsgs()
                                .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
                            {
                                // The returned timestamp is a real (i.e. not monotonic) timestamp
                                // https://docs.kernel.org/networking/timestamping.html
                                rs_wrapper.rx_ts =
                                    SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
                            }
                            // Safety: We trust `recvmsg` to have filled up `r.bytes` bytes in the buffer.
                            unsafe {
                                buf.assume_init(r.bytes);
                            }
                            buf.advance(r.bytes);
                            return Poll::Ready(Ok(()));
                        }
                        // Not actually readable: go back to waiting for readiness.
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
            }
            // Unix RX timestamp only works with datagram for now, so we do not care about it
            RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
        }
    }
}
270
271impl AsyncWrite for RawStreamWrapper {
272    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
273        // Safety: Basic enum pin projection
274        unsafe {
275            match &mut Pin::get_unchecked_mut(self).stream {
276                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
277                #[cfg(unix)]
278                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
279            }
280        }
281    }
282
283    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
284        // Safety: Basic enum pin projection
285        unsafe {
286            match &mut Pin::get_unchecked_mut(self).stream {
287                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
288                #[cfg(unix)]
289                RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
290            }
291        }
292    }
293
294    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
295        // Safety: Basic enum pin projection
296        unsafe {
297            match &mut Pin::get_unchecked_mut(self).stream {
298                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
299                #[cfg(unix)]
300                RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
301            }
302        }
303    }
304
305    fn poll_write_vectored(
306        self: Pin<&mut Self>,
307        cx: &mut Context<'_>,
308        bufs: &[std::io::IoSlice<'_>],
309    ) -> Poll<io::Result<usize>> {
310        // Safety: Basic enum pin projection
311        unsafe {
312            match &mut Pin::get_unchecked_mut(self).stream {
313                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
314                #[cfg(unix)]
315                RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
316            }
317        }
318    }
319
320    fn is_write_vectored(&self) -> bool {
321        self.stream.is_write_vectored()
322    }
323}
324
325#[cfg(unix)]
326impl AsRawFd for RawStreamWrapper {
327    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
328        self.stream.as_raw_fd()
329    }
330}
331
#[cfg(windows)]
impl AsRawSocket for RawStreamWrapper {
    /// Returns the raw socket handle of the wrapped stream.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        AsRawSocket::as_raw_socket(&self.stream)
    }
}
338
// Large read buffering helps reduce syscalls with little trade-off.
// The SSL layer always does "small" reads of 16k (TLS record size), so an L4 read buffer helps a lot.
const BUF_READ_SIZE: usize = 64 * 1024;
// Small write buffer to match the MSS. Too large a write buffer delays real-time communication.
// This buffering effectively implements something similar to Nagle's algorithm.
// The benefit is that user space can control when to flush, whereas Nagle's algorithm can't be controlled.
// Userspace buffering also reduces both syscalls and small packets.
const BUF_WRITE_SIZE: usize = 1460;
347
348// NOTE: with writer buffering, users need to call flush() to make sure the data is actually
349// sent. Otherwise data could be stuck in the buffer forever or get lost when stream is closed.
350
/// A concrete type for transport layer connection + extra fields for logging
#[derive(Debug)]
pub struct Stream {
    /// The underlying socket, wrapped with read and write buffering.
    stream: BufStream<RawStreamWrapper>,
    // the data put back at the front of the read buffer, in order to replay the read
    rewind_read_buf: Vec<u8>,
    /// When `true`, writes go through the buffered writer; otherwise they are
    /// sent directly to the underlying socket wrapper.
    buffer_write: bool,
    /// Proxy digest attached through `GetProxyDigest::set_proxy_digest`, if any.
    proxy_digest: Option<Arc<ProxyDigest>>,
    /// Socket digest attached through `GetSocketDigest::set_socket_digest`, if any.
    socket_digest: Option<Arc<SocketDigest>>,
    /// When this connection is established
    pub established_ts: SystemTime,
    /// The distributed tracing object for this stream
    pub tracer: Option<Tracer>,
    /// Accumulated time spent with a read pending.
    read_pending_time: AccumulatedDuration,
    /// Accumulated time spent with a write or flush pending.
    write_pending_time: AccumulatedDuration,
    /// Last rx timestamp associated with the last recvmsg call.
    pub rx_ts: Option<SystemTime>,
}
369
370impl Stream {
371    /// set TCP nodelay for this connection if `self` is TCP
372    pub fn set_nodelay(&mut self) -> Result<()> {
373        if let RawStream::Tcp(s) = &self.stream.get_mut().stream {
374            s.set_nodelay(true)
375                .or_err(ConnectError, "failed to set_nodelay")?;
376        }
377        Ok(())
378    }
379
380    /// set TCP keepalive settings for this connection if `self` is TCP
381    pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
382        if let RawStream::Tcp(s) = &self.stream.get_mut().stream {
383            debug!("Setting tcp keepalive");
384            set_tcp_keepalive(s, ka)?;
385        }
386        Ok(())
387    }
388
389    #[cfg(target_os = "linux")]
390    pub fn set_rx_timestamp(&mut self) -> Result<()> {
391        use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
392
393        if let RawStream::Tcp(s) = &self.stream.get_mut().stream {
394            let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
395                | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
396            setsockopt(s.as_raw_fd(), sockopt::Timestamping, &timestamp_options)
397                .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
398            self.stream.get_mut().enable_rx_ts(true);
399        }
400
401        Ok(())
402    }
403
404    #[cfg(not(target_os = "linux"))]
405    pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
406        Ok(())
407    }
408
409    /// Put Some data back to the head of the stream to be read again
410    pub(crate) fn rewind(&mut self, data: &[u8]) {
411        if !data.is_empty() {
412            self.rewind_read_buf.extend_from_slice(data);
413        }
414    }
415}
416
417impl From<TcpStream> for Stream {
418    fn from(s: TcpStream) -> Self {
419        Stream {
420            stream: BufStream::with_capacity(
421                BUF_READ_SIZE,
422                BUF_WRITE_SIZE,
423                RawStreamWrapper::new(RawStream::Tcp(s)),
424            ),
425            rewind_read_buf: Vec::new(),
426            buffer_write: true,
427            established_ts: SystemTime::now(),
428            proxy_digest: None,
429            socket_digest: None,
430            tracer: None,
431            read_pending_time: AccumulatedDuration::new(),
432            write_pending_time: AccumulatedDuration::new(),
433            rx_ts: None,
434        }
435    }
436}
437
438#[cfg(unix)]
439impl From<UnixStream> for Stream {
440    fn from(s: UnixStream) -> Self {
441        Stream {
442            stream: BufStream::with_capacity(
443                BUF_READ_SIZE,
444                BUF_WRITE_SIZE,
445                RawStreamWrapper::new(RawStream::Unix(s)),
446            ),
447            rewind_read_buf: Vec::new(),
448            buffer_write: true,
449            established_ts: SystemTime::now(),
450            proxy_digest: None,
451            socket_digest: None,
452            tracer: None,
453            read_pending_time: AccumulatedDuration::new(),
454            write_pending_time: AccumulatedDuration::new(),
455            rx_ts: None,
456        }
457    }
458}
459
460#[cfg(unix)]
461impl AsRawFd for Stream {
462    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
463        self.stream.get_ref().as_raw_fd()
464    }
465}
466
#[cfg(windows)]
impl AsRawSocket for Stream {
    /// Returns the raw socket handle of the underlying socket.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        let wrapper = self.stream.get_ref();
        wrapper.as_raw_socket()
    }
}
473
474#[cfg(unix)]
475impl UniqueID for Stream {
476    fn id(&self) -> UniqueIDType {
477        self.as_raw_fd()
478    }
479}
480
#[cfg(windows)]
impl UniqueID for Stream {
    /// Uses the raw socket handle, cast to `usize`, as the stream's ID.
    fn id(&self) -> usize {
        let handle = AsRawSocket::as_raw_socket(self);
        handle as usize
    }
}
487
// Empty impl: `Stream` relies entirely on the default methods of `Ssl`.
// NOTE(review): presumably the defaults report "no TLS info", since this is a
// plain L4 stream; confirm against the trait definition.
impl Ssl for Stream {}
489
490#[async_trait]
491impl Peek for Stream {
492    async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
493        use tokio::io::AsyncReadExt;
494        self.read_exact(buf).await?;
495        // rewind regardless of what is read
496        self.rewind(buf);
497        Ok(true)
498    }
499}
500
501#[async_trait]
502impl Shutdown for Stream {
503    async fn shutdown(&mut self) {
504        AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
505            debug!("Failed to shutdown connection: {:?}", e);
506        });
507    }
508}
509
510impl GetTimingDigest for Stream {
511    fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
512        let mut digest = Vec::with_capacity(2); // expect to have both L4 stream and TLS layer
513        digest.push(Some(TimingDigest {
514            established_ts: self.established_ts,
515        }));
516        digest
517    }
518
519    fn get_read_pending_time(&self) -> Duration {
520        self.read_pending_time.total
521    }
522
523    fn get_write_pending_time(&self) -> Duration {
524        self.write_pending_time.total
525    }
526}
527
528impl GetProxyDigest for Stream {
529    fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
530        self.proxy_digest.clone()
531    }
532
533    fn set_proxy_digest(&mut self, digest: ProxyDigest) {
534        self.proxy_digest = Some(Arc::new(digest));
535    }
536}
537
538impl GetSocketDigest for Stream {
539    fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
540        self.socket_digest.clone()
541    }
542
543    fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
544        self.socket_digest = Some(Arc::new(socket_digest))
545    }
546}
547
548impl Drop for Stream {
549    fn drop(&mut self) {
550        if let Some(t) = self.tracer.as_ref() {
551            t.0.on_disconnected();
552        }
553        /* use nodelay/local_addr function to detect socket status */
554        let ret = match &self.stream.get_ref().stream {
555            RawStream::Tcp(s) => s.nodelay().err(),
556            #[cfg(unix)]
557            RawStream::Unix(s) => s.local_addr().err(),
558        };
559        if let Some(e) = ret {
560            match e.kind() {
561                tokio::io::ErrorKind::Other => {
562                    if let Some(ecode) = e.raw_os_error() {
563                        if ecode == 9 {
564                            // Or we could panic here
565                            error!("Crit: socket {:?} is being double closed", self.stream);
566                        }
567                    }
568                }
569                _ => {
570                    debug!("Socket is already broken {:?}", e);
571                }
572            }
573        } else {
574            // try flush the write buffer. We use now_or_never() because
575            // 1. Drop cannot be async
576            // 2. write should usually be ready, unless the buf is full.
577            let _ = self.flush().now_or_never();
578        }
579        debug!("Dropping socket {:?}", self.stream);
580    }
581}
582
583impl AsyncRead for Stream {
584    fn poll_read(
585        mut self: Pin<&mut Self>,
586        cx: &mut Context<'_>,
587        buf: &mut ReadBuf<'_>,
588    ) -> Poll<io::Result<()>> {
589        let result = if !self.rewind_read_buf.is_empty() {
590            let mut data_to_read = self.rewind_read_buf.as_slice();
591            let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
592            // put the remaining data in another Vec
593            let remaining_buf = Vec::from(data_to_read);
594            let _ = std::mem::replace(&mut self.rewind_read_buf, remaining_buf);
595            result
596        } else {
597            Pin::new(&mut self.stream).poll_read(cx, buf)
598        };
599        self.read_pending_time.poll_time(&result);
600        self.rx_ts = self.stream.get_ref().rx_ts;
601        result
602    }
603}
604
605impl AsyncWrite for Stream {
606    fn poll_write(
607        mut self: Pin<&mut Self>,
608        cx: &mut Context,
609        buf: &[u8],
610    ) -> Poll<io::Result<usize>> {
611        let result = if self.buffer_write {
612            Pin::new(&mut self.stream).poll_write(cx, buf)
613        } else {
614            Pin::new(&mut self.stream.get_mut()).poll_write(cx, buf)
615        };
616        self.write_pending_time.poll_write_time(&result, buf.len());
617        result
618    }
619
620    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
621        let result = Pin::new(&mut self.stream).poll_flush(cx);
622        self.write_pending_time.poll_time(&result);
623        result
624    }
625
626    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
627        Pin::new(&mut self.stream).poll_shutdown(cx)
628    }
629
630    fn poll_write_vectored(
631        mut self: Pin<&mut Self>,
632        cx: &mut Context<'_>,
633        bufs: &[std::io::IoSlice<'_>],
634    ) -> Poll<io::Result<usize>> {
635        let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
636
637        let result = if self.buffer_write {
638            Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
639        } else {
640            Pin::new(&mut self.stream.get_mut()).poll_write_vectored(cx, bufs)
641        };
642
643        self.write_pending_time.poll_write_time(&result, total_size);
644        result
645    }
646
647    fn is_write_vectored(&self) -> bool {
648        if self.buffer_write {
649            self.stream.is_write_vectored() // it is true
650        } else {
651            self.stream.get_ref().is_write_vectored()
652        }
653    }
654}
655
656pub mod async_write_vec {
657    use bytes::Buf;
658    use futures::ready;
659    use std::future::Future;
660    use std::io::IoSlice;
661    use std::pin::Pin;
662    use std::task::{Context, Poll};
663    use tokio::io;
664    use tokio::io::AsyncWrite;
665
666    /*
667        the missing write_buf https://github.com/tokio-rs/tokio/pull/3156#issuecomment-738207409
668        https://github.com/tokio-rs/tokio/issues/2610
669        In general vectored write is lost when accessing the trait object: Box<S: AsyncWrite>
670    */
671
672    #[must_use = "futures do nothing unless you `.await` or poll them"]
673    pub struct WriteVec<'a, W, B> {
674        writer: &'a mut W,
675        buf: &'a mut B,
676    }
677
678    #[must_use = "futures do nothing unless you `.await` or poll them"]
679    pub struct WriteVecAll<'a, W, B> {
680        writer: &'a mut W,
681        buf: &'a mut B,
682    }
683
684    pub trait AsyncWriteVec {
685        fn poll_write_vec<B: Buf>(
686            self: Pin<&mut Self>,
687            _cx: &mut Context<'_>,
688            _buf: &mut B,
689        ) -> Poll<io::Result<usize>>;
690
691        fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
692        where
693            Self: Sized,
694            B: Buf,
695        {
696            WriteVec {
697                writer: self,
698                buf: src,
699            }
700        }
701
702        fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
703        where
704            Self: Sized,
705            B: Buf,
706        {
707            WriteVecAll {
708                writer: self,
709                buf: src,
710            }
711        }
712    }
713
714    impl<W, B> Future for WriteVec<'_, W, B>
715    where
716        W: AsyncWriteVec + Unpin,
717        B: Buf,
718    {
719        type Output = io::Result<usize>;
720
721        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
722            let me = &mut *self;
723            Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
724        }
725    }
726
727    impl<W, B> Future for WriteVecAll<'_, W, B>
728    where
729        W: AsyncWriteVec + Unpin,
730        B: Buf,
731    {
732        type Output = io::Result<()>;
733
734        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
735            let me = &mut *self;
736            while me.buf.has_remaining() {
737                let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
738                if n == 0 {
739                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
740                }
741            }
742            Poll::Ready(Ok(()))
743        }
744    }
745
746    /* from https://github.com/tokio-rs/tokio/blob/master/tokio-util/src/lib.rs#L177 */
747    impl<T> AsyncWriteVec for T
748    where
749        T: AsyncWrite,
750    {
751        fn poll_write_vec<B: Buf>(
752            self: Pin<&mut Self>,
753            ctx: &mut Context,
754            buf: &mut B,
755        ) -> Poll<io::Result<usize>> {
756            const MAX_BUFS: usize = 64;
757
758            if !buf.has_remaining() {
759                return Poll::Ready(Ok(0));
760            }
761
762            let n = if self.is_write_vectored() {
763                let mut slices = [IoSlice::new(&[]); MAX_BUFS];
764                let cnt = buf.chunks_vectored(&mut slices);
765                ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
766            } else {
767                ready!(self.poll_write(ctx, buf.chunk()))?
768            };
769
770            buf.advance(n);
771
772            Poll::Ready(Ok(n))
773        }
774    }
775}
776
777pub use async_write_vec::AsyncWriteVec;
778
/// Accumulates the time spent in "pending" intervals across repeated polls.
#[derive(Debug)]
struct AccumulatedDuration {
    /// Sum of all completed intervals.
    total: Duration,
    /// Start of the interval currently running, if any.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    /// Creates a timer with nothing accumulated and no interval running.
    fn new() -> Self {
        Self {
            total: Duration::ZERO,
            last_start: None,
        }
    }

    /// Starts an interval unless one is already running.
    fn start(&mut self) {
        let _ = self.last_start.get_or_insert_with(Instant::now);
    }

    /// Ends the running interval, if any, and adds its length to `total`.
    fn stop(&mut self) {
        self.total += self
            .last_start
            .take()
            .map_or(Duration::ZERO, |began| began.elapsed());
    }

    /// Updates the timer from a write poll result: a complete write or an
    /// error stops it, a partial write or `Pending` starts (or keeps) it.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop();
        } else {
            self.start();
        }
    }

    /// Updates the timer from a poll result: `Ready` stops it, `Pending`
    /// starts (or keeps) it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop();
        } else {
            self.start();
        }
    }
}
831
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    /// With RX timestamping enabled, a read must record a timestamp.
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // Receive the message
        // setsockopt for SO_TIMESTAMPING is asynchronous so sleep a little bit
        // to let kernel do the work. Use the async sleep so the runtime thread
        // is not blocked.
        tokio::time::sleep(Duration::from_micros(100)).await;
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    /// Without RX timestamping, the standard read path records no timestamp.
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        tokio::time::sleep(Duration::from_micros(100)).await;
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_none());
    }

    /// Rewound data is read back before any data from the socket.
    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // partially read rewind_test because of the buffer size limit
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // read the rest of rewind_test
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // read the actual data
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    /// Peeked bytes are still returned by subsequent reads.
    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }
}
959}