1use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::raw_connect::ProxyDigest;
41use crate::protocols::{
42 GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
43 TimingDigest, UniqueID, UniqueIDType,
44};
45use crate::upstreams::peer::Tracer;
46
/// The underlying transport socket of an L4 connection.
#[derive(Debug)]
enum RawStream {
    /// A TCP socket.
    Tcp(TcpStream),
    /// A Unix domain socket; this variant only exists on Unix targets.
    #[cfg(unix)]
    Unix(UnixStream),
}
53
54impl AsyncRead for RawStream {
55 fn poll_read(
56 self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 buf: &mut ReadBuf<'_>,
59 ) -> Poll<io::Result<()>> {
60 unsafe {
62 match &mut Pin::get_unchecked_mut(self) {
63 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
64 #[cfg(unix)]
65 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66 }
67 }
68 }
69}
70
71impl AsyncWrite for RawStream {
72 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
73 unsafe {
75 match &mut Pin::get_unchecked_mut(self) {
76 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
77 #[cfg(unix)]
78 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
79 }
80 }
81 }
82
83 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
84 unsafe {
86 match &mut Pin::get_unchecked_mut(self) {
87 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
88 #[cfg(unix)]
89 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
90 }
91 }
92 }
93
94 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
95 unsafe {
97 match &mut Pin::get_unchecked_mut(self) {
98 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
99 #[cfg(unix)]
100 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
101 }
102 }
103 }
104
105 fn poll_write_vectored(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 bufs: &[std::io::IoSlice<'_>],
109 ) -> Poll<io::Result<usize>> {
110 unsafe {
112 match &mut Pin::get_unchecked_mut(self) {
113 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
114 #[cfg(unix)]
115 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
116 }
117 }
118 }
119
120 fn is_write_vectored(&self) -> bool {
121 match self {
122 RawStream::Tcp(s) => s.is_write_vectored(),
123 #[cfg(unix)]
124 RawStream::Unix(s) => s.is_write_vectored(),
125 }
126 }
127}
128
129#[cfg(unix)]
130impl AsRawFd for RawStream {
131 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
132 match self {
133 RawStream::Tcp(s) => s.as_raw_fd(),
134 RawStream::Unix(s) => s.as_raw_fd(),
135 }
136 }
137}
138
#[cfg(windows)]
impl AsRawSocket for RawStream {
    /// Returns the raw socket handle of the wrapped TCP socket.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        // TCP is the only variant compiled on Windows.
        match self {
            Self::Tcp(tcp) => AsRawSocket::as_raw_socket(tcp),
        }
    }
}
147
/// A [`RawStream`] together with the state used to capture receive timestamps.
#[derive(Debug)]
struct RawStreamWrapper {
    /// The wrapped socket.
    pub(crate) stream: RawStream,
    /// Receive timestamp taken from the most recent read that carried one;
    /// stays `None` until such a read happens.
    pub(crate) rx_ts: Option<SystemTime>,
    /// When `true`, reads on TCP sockets go through `recvmsg` so that timestamp
    /// control messages can be collected.
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    /// Control-message buffer that is cleared and reused for every `recvmsg` call.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,
}
162
163impl RawStreamWrapper {
164 pub fn new(stream: RawStream) -> Self {
165 RawStreamWrapper {
166 stream,
167 rx_ts: None,
168 #[cfg(target_os = "linux")]
169 enable_rx_ts: false,
170 #[cfg(target_os = "linux")]
171 reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
172 }
173 }
174
175 #[cfg(target_os = "linux")]
176 pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
177 self.enable_rx_ts = enable_rx_ts;
178 }
179}
180
181impl AsyncRead for RawStreamWrapper {
182 #[cfg(not(target_os = "linux"))]
183 fn poll_read(
184 self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 buf: &mut ReadBuf<'_>,
187 ) -> Poll<io::Result<()>> {
188 unsafe {
190 let rs_wrapper = Pin::get_unchecked_mut(self);
191 match &mut rs_wrapper.stream {
192 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
193 #[cfg(unix)]
194 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
195 }
196 }
197 }
198
199 #[cfg(target_os = "linux")]
200 fn poll_read(
201 self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 buf: &mut ReadBuf<'_>,
204 ) -> Poll<io::Result<()>> {
205 use futures::ready;
206 use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};
207
208 if !self.enable_rx_ts {
210 unsafe {
212 let rs_wrapper = Pin::get_unchecked_mut(self);
213 match &mut rs_wrapper.stream {
214 RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
215 RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
216 }
217 }
218 }
219
220 let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
222 match &mut rs_wrapper.stream {
223 RawStream::Tcp(s) => {
224 loop {
225 ready!(s.poll_read_ready(cx))?;
226 let b = unsafe {
228 &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
229 as *mut [u8])
230 };
231 let mut iov = [IoSliceMut::new(b)];
232 rs_wrapper.reusable_cmsg_space.clear();
233
234 match s.try_io(Interest::READABLE, || {
235 recvmsg::<SockaddrStorage>(
236 s.as_raw_fd(),
237 &mut iov,
238 Some(&mut rs_wrapper.reusable_cmsg_space),
239 MsgFlags::empty(),
240 )
241 .map_err(|errno| errno.into())
242 }) {
243 Ok(r) => {
244 if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
245 .cmsgs()
246 .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
247 {
248 rs_wrapper.rx_ts =
251 SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
252 }
253 unsafe {
255 buf.assume_init(r.bytes);
256 }
257 buf.advance(r.bytes);
258 return Poll::Ready(Ok(()));
259 }
260 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
261 Err(e) => return Poll::Ready(Err(e)),
262 }
263 }
264 }
265 RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
267 }
268 }
269}
270
271impl AsyncWrite for RawStreamWrapper {
272 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
273 unsafe {
275 match &mut Pin::get_unchecked_mut(self).stream {
276 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
277 #[cfg(unix)]
278 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
279 }
280 }
281 }
282
283 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
284 unsafe {
286 match &mut Pin::get_unchecked_mut(self).stream {
287 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
288 #[cfg(unix)]
289 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
290 }
291 }
292 }
293
294 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
295 unsafe {
297 match &mut Pin::get_unchecked_mut(self).stream {
298 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
299 #[cfg(unix)]
300 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
301 }
302 }
303 }
304
305 fn poll_write_vectored(
306 self: Pin<&mut Self>,
307 cx: &mut Context<'_>,
308 bufs: &[std::io::IoSlice<'_>],
309 ) -> Poll<io::Result<usize>> {
310 unsafe {
312 match &mut Pin::get_unchecked_mut(self).stream {
313 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
314 #[cfg(unix)]
315 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
316 }
317 }
318 }
319
320 fn is_write_vectored(&self) -> bool {
321 self.stream.is_write_vectored()
322 }
323}
324
325#[cfg(unix)]
326impl AsRawFd for RawStreamWrapper {
327 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
328 self.stream.as_raw_fd()
329 }
330}
331
#[cfg(windows)]
impl AsRawSocket for RawStreamWrapper {
    /// Returns the raw socket handle of the wrapped socket.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        AsRawSocket::as_raw_socket(&self.stream)
    }
}
338
/// Read-buffer capacity (64 KiB) installed by `Stream::set_buffer`.
const BUF_READ_SIZE: usize = 64 * 1024;
/// Write-buffer capacity in bytes installed by `Stream::set_buffer`.
// NOTE(review): 1460 presumably matches a typical TCP MSS on a 1500-byte MTU link — confirm.
const BUF_WRITE_SIZE: usize = 1460;
347
/// A buffered L4 (TCP or Unix socket) connection with rewind support, timing
/// bookkeeping and optional digests.
#[derive(Debug)]
pub struct Stream {
    /// The buffered socket. It is taken out only inside `set_buffer`, which rebuilds
    /// it immediately; the accessors panic if it is ever `None`.
    stream: Option<BufStream<RawStreamWrapper>>,
    /// Bytes pushed back via `rewind`; reads are served from here before the socket.
    rewind_read_buf: Vec<u8>,
    /// When `true`, writes go through the `BufStream`; otherwise they bypass it and
    /// hit the wrapped socket directly.
    buffer_write: bool,
    /// Proxy digest attached via `set_proxy_digest`, if any.
    proxy_digest: Option<Arc<ProxyDigest>>,
    /// Socket digest attached via `set_socket_digest`, if any.
    socket_digest: Option<Arc<SocketDigest>>,
    /// Wall-clock time at which this `Stream` was created from its socket.
    pub established_ts: SystemTime,
    /// Optional tracer; its `on_disconnected` hook is invoked when the stream is dropped.
    pub tracer: Option<Tracer>,
    /// Accumulated time spent with reads pending.
    read_pending_time: AccumulatedDuration,
    /// Accumulated time spent with writes or flushes pending.
    write_pending_time: AccumulatedDuration,
    /// Copy of the wrapped socket's latest receive timestamp, refreshed on every `poll_read`.
    pub rx_ts: Option<SystemTime>,
}
370
371impl Stream {
372 fn stream(&self) -> &BufStream<RawStreamWrapper> {
373 self.stream.as_ref().expect("stream should always be set")
374 }
375
376 fn stream_mut(&mut self) -> &mut BufStream<RawStreamWrapper> {
377 self.stream.as_mut().expect("stream should always be set")
378 }
379
380 pub fn set_nodelay(&mut self) -> Result<()> {
382 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
383 s.set_nodelay(true)
384 .or_err(ConnectError, "failed to set_nodelay")?;
385 }
386 Ok(())
387 }
388
389 pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
391 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
392 debug!("Setting tcp keepalive");
393 set_tcp_keepalive(s, ka)?;
394 }
395 Ok(())
396 }
397
398 #[cfg(target_os = "linux")]
399 pub fn set_rx_timestamp(&mut self) -> Result<()> {
400 use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
401
402 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
403 let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
404 | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
405 setsockopt(s.as_raw_fd(), sockopt::Timestamping, ×tamp_options)
406 .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
407 self.stream_mut().get_mut().enable_rx_ts(true);
408 }
409
410 Ok(())
411 }
412
413 #[cfg(not(target_os = "linux"))]
414 pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
415 Ok(())
416 }
417
418 pub(crate) fn rewind(&mut self, data: &[u8]) {
420 if !data.is_empty() {
421 self.rewind_read_buf.extend_from_slice(data);
422 }
423 }
424
425 pub(crate) fn set_buffer(&mut self) {
428 use std::mem;
429 let stream = mem::take(&mut self.stream);
432 let stream =
433 stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner()));
434 let _ = mem::replace(&mut self.stream, stream);
435 }
436}
437
438impl From<TcpStream> for Stream {
439 fn from(s: TcpStream) -> Self {
440 Stream {
441 stream: Some(BufStream::with_capacity(
442 0,
443 0,
444 RawStreamWrapper::new(RawStream::Tcp(s)),
445 )),
446 rewind_read_buf: Vec::new(),
447 buffer_write: true,
448 established_ts: SystemTime::now(),
449 proxy_digest: None,
450 socket_digest: None,
451 tracer: None,
452 read_pending_time: AccumulatedDuration::new(),
453 write_pending_time: AccumulatedDuration::new(),
454 rx_ts: None,
455 }
456 }
457}
458
459#[cfg(unix)]
460impl From<UnixStream> for Stream {
461 fn from(s: UnixStream) -> Self {
462 Stream {
463 stream: Some(BufStream::with_capacity(
464 0,
465 0,
466 RawStreamWrapper::new(RawStream::Unix(s)),
467 )),
468 rewind_read_buf: Vec::new(),
469 buffer_write: true,
470 established_ts: SystemTime::now(),
471 proxy_digest: None,
472 socket_digest: None,
473 tracer: None,
474 read_pending_time: AccumulatedDuration::new(),
475 write_pending_time: AccumulatedDuration::new(),
476 rx_ts: None,
477 }
478 }
479}
480
481#[cfg(unix)]
482impl AsRawFd for Stream {
483 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
484 self.stream().get_ref().as_raw_fd()
485 }
486}
487
#[cfg(windows)]
impl AsRawSocket for Stream {
    /// Returns the raw socket handle of the underlying socket.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        let wrapper = self.stream().get_ref();
        AsRawSocket::as_raw_socket(wrapper)
    }
}
494
495#[cfg(unix)]
496impl UniqueID for Stream {
497 fn id(&self) -> UniqueIDType {
498 self.as_raw_fd()
499 }
500}
501
#[cfg(windows)]
impl UniqueID for Stream {
    /// Uses the raw socket handle, converted to `usize`, as the connection identifier.
    fn id(&self) -> usize {
        let socket = AsRawSocket::as_raw_socket(self);
        socket as usize
    }
}
508
// Empty impl: `Stream` overrides nothing and relies on whatever defaults `Ssl` provides.
impl Ssl for Stream {}
510
511#[async_trait]
512impl Peek for Stream {
513 async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
514 use tokio::io::AsyncReadExt;
515 self.read_exact(buf).await?;
516 self.rewind(buf);
518 Ok(true)
519 }
520}
521
522#[async_trait]
523impl Shutdown for Stream {
524 async fn shutdown(&mut self) {
525 AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
526 debug!("Failed to shutdown connection: {:?}", e);
527 });
528 }
529}
530
531impl GetTimingDigest for Stream {
532 fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
533 let mut digest = Vec::with_capacity(2); digest.push(Some(TimingDigest {
535 established_ts: self.established_ts,
536 }));
537 digest
538 }
539
540 fn get_read_pending_time(&self) -> Duration {
541 self.read_pending_time.total
542 }
543
544 fn get_write_pending_time(&self) -> Duration {
545 self.write_pending_time.total
546 }
547}
548
549impl GetProxyDigest for Stream {
550 fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
551 self.proxy_digest.clone()
552 }
553
554 fn set_proxy_digest(&mut self, digest: ProxyDigest) {
555 self.proxy_digest = Some(Arc::new(digest));
556 }
557}
558
559impl GetSocketDigest for Stream {
560 fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
561 self.socket_digest.clone()
562 }
563
564 fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
565 self.socket_digest = Some(Arc::new(socket_digest))
566 }
567}
568
569impl Drop for Stream {
570 fn drop(&mut self) {
571 if let Some(t) = self.tracer.as_ref() {
572 t.0.on_disconnected();
573 }
574 let ret = match &self.stream().get_ref().stream {
576 RawStream::Tcp(s) => s.nodelay().err(),
577 #[cfg(unix)]
578 RawStream::Unix(s) => s.local_addr().err(),
579 };
580 if let Some(e) = ret {
581 match e.kind() {
582 tokio::io::ErrorKind::Other => {
583 if let Some(ecode) = e.raw_os_error() {
584 if ecode == 9 {
585 error!("Crit: socket {:?} is being double closed", self.stream);
587 }
588 }
589 }
590 _ => {
591 debug!("Socket is already broken {:?}", e);
592 }
593 }
594 } else {
595 let _ = self.flush().now_or_never();
599 }
600 debug!("Dropping socket {:?}", self.stream);
601 }
602}
603
604impl AsyncRead for Stream {
605 fn poll_read(
606 mut self: Pin<&mut Self>,
607 cx: &mut Context<'_>,
608 buf: &mut ReadBuf<'_>,
609 ) -> Poll<io::Result<()>> {
610 let result = if !self.rewind_read_buf.is_empty() {
611 let mut data_to_read = self.rewind_read_buf.as_slice();
612 let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
613 let remaining_buf = Vec::from(data_to_read);
615 let _ = std::mem::replace(&mut self.rewind_read_buf, remaining_buf);
616 result
617 } else {
618 Pin::new(&mut self.stream_mut()).poll_read(cx, buf)
619 };
620 self.read_pending_time.poll_time(&result);
621 self.rx_ts = self.stream().get_ref().rx_ts;
622 result
623 }
624}
625
626impl AsyncWrite for Stream {
627 fn poll_write(
628 mut self: Pin<&mut Self>,
629 cx: &mut Context,
630 buf: &[u8],
631 ) -> Poll<io::Result<usize>> {
632 let result = if self.buffer_write {
633 Pin::new(&mut self.stream_mut()).poll_write(cx, buf)
634 } else {
635 Pin::new(&mut self.stream_mut().get_mut()).poll_write(cx, buf)
636 };
637 self.write_pending_time.poll_write_time(&result, buf.len());
638 result
639 }
640
641 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
642 let result = Pin::new(&mut self.stream_mut()).poll_flush(cx);
643 self.write_pending_time.poll_time(&result);
644 result
645 }
646
647 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
648 Pin::new(&mut self.stream_mut()).poll_shutdown(cx)
649 }
650
651 fn poll_write_vectored(
652 mut self: Pin<&mut Self>,
653 cx: &mut Context<'_>,
654 bufs: &[std::io::IoSlice<'_>],
655 ) -> Poll<io::Result<usize>> {
656 let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
657
658 let result = if self.buffer_write {
659 Pin::new(&mut self.stream_mut()).poll_write_vectored(cx, bufs)
660 } else {
661 Pin::new(&mut self.stream_mut().get_mut()).poll_write_vectored(cx, bufs)
662 };
663
664 self.write_pending_time.poll_write_time(&result, total_size);
665 result
666 }
667
668 fn is_write_vectored(&self) -> bool {
669 if self.buffer_write {
670 self.stream().is_write_vectored() } else {
672 self.stream().get_ref().is_write_vectored()
673 }
674 }
675}
676
677pub mod async_write_vec {
678 use bytes::Buf;
679 use futures::ready;
680 use std::future::Future;
681 use std::io::IoSlice;
682 use std::pin::Pin;
683 use std::task::{Context, Poll};
684 use tokio::io;
685 use tokio::io::AsyncWrite;
686
687 #[must_use = "futures do nothing unless you `.await` or poll them"]
694 pub struct WriteVec<'a, W, B> {
695 writer: &'a mut W,
696 buf: &'a mut B,
697 }
698
699 #[must_use = "futures do nothing unless you `.await` or poll them"]
700 pub struct WriteVecAll<'a, W, B> {
701 writer: &'a mut W,
702 buf: &'a mut B,
703 }
704
705 pub trait AsyncWriteVec {
706 fn poll_write_vec<B: Buf>(
707 self: Pin<&mut Self>,
708 _cx: &mut Context<'_>,
709 _buf: &mut B,
710 ) -> Poll<io::Result<usize>>;
711
712 fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
713 where
714 Self: Sized,
715 B: Buf,
716 {
717 WriteVec {
718 writer: self,
719 buf: src,
720 }
721 }
722
723 fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
724 where
725 Self: Sized,
726 B: Buf,
727 {
728 WriteVecAll {
729 writer: self,
730 buf: src,
731 }
732 }
733 }
734
735 impl<W, B> Future for WriteVec<'_, W, B>
736 where
737 W: AsyncWriteVec + Unpin,
738 B: Buf,
739 {
740 type Output = io::Result<usize>;
741
742 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
743 let me = &mut *self;
744 Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
745 }
746 }
747
748 impl<W, B> Future for WriteVecAll<'_, W, B>
749 where
750 W: AsyncWriteVec + Unpin,
751 B: Buf,
752 {
753 type Output = io::Result<()>;
754
755 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
756 let me = &mut *self;
757 while me.buf.has_remaining() {
758 let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
759 if n == 0 {
760 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
761 }
762 }
763 Poll::Ready(Ok(()))
764 }
765 }
766
767 impl<T> AsyncWriteVec for T
769 where
770 T: AsyncWrite,
771 {
772 fn poll_write_vec<B: Buf>(
773 self: Pin<&mut Self>,
774 ctx: &mut Context,
775 buf: &mut B,
776 ) -> Poll<io::Result<usize>> {
777 const MAX_BUFS: usize = 64;
778
779 if !buf.has_remaining() {
780 return Poll::Ready(Ok(0));
781 }
782
783 let n = if self.is_write_vectored() {
784 let mut slices = [IoSlice::new(&[]); MAX_BUFS];
785 let cnt = buf.chunks_vectored(&mut slices);
786 ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
787 } else {
788 ready!(self.poll_write(ctx, buf.chunk()))?
789 };
790
791 buf.advance(n);
792
793 Poll::Ready(Ok(n))
794 }
795 }
796}
797
798pub use async_write_vec::AsyncWriteVec;
799
/// Running total of the time an operation has spent pending.
#[derive(Debug)]
struct AccumulatedDuration {
    /// Sum of all completed pending intervals.
    total: Duration,
    /// Start of the interval currently in progress, if any.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    /// Creates a timer with a zero total and no interval in progress.
    fn new() -> Self {
        Self {
            last_start: None,
            total: Duration::ZERO,
        }
    }

    /// Opens an interval unless one is already in progress.
    fn start(&mut self) {
        self.last_start.get_or_insert_with(Instant::now);
    }

    /// Closes the interval in progress, if any, and adds its length to the total.
    fn stop(&mut self) {
        if let Some(began) = self.last_start.take() {
            self.total += began.elapsed();
        }
    }

    /// Updates the timer from a write result: a complete write of `buf_size` bytes or
    /// an error stops it, while a partial write or `Pending` keeps it running.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop();
        } else {
            self.start();
        }
    }

    /// Updates the timer from a unit result: any ready outcome stops it and `Pending`
    /// keeps it running.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop();
        } else {
            self.start();
        }
    }
}
852
// The whole module is Linux-only, so the individual tests need no extra `cfg`.
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    /// With receive timestamping enabled, a read records a timestamp.
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        // Peer: accept, wait for the go-ahead, then send the message.
        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // Short pause between enabling timestamping and releasing the sender.
        // NOTE(review): this is a blocking sleep inside an async test; it is only
        // 100 microseconds, so it is left as is.
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    /// Without receive timestamping, reads use the standard path and record nothing.
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_none());
    }

    /// Rewound bytes are returned before any data from the socket.
    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // First read: filled entirely from the rewind buffer.
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // Second read: the rest of the rewind buffer.
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // Third read: actual data from the socket.
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    /// Peeked bytes remain visible to later reads.
    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        // Peer: send the message, then close so `read_to_end` can finish.
        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }
}
980}