1use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::raw_connect::ProxyDigest;
41use crate::protocols::{
42 GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
43 TimingDigest, UniqueID, UniqueIDType,
44};
45use crate::upstreams::peer::Tracer;
46
/// The concrete OS socket underlying a [`Stream`].
///
/// `Tcp` is available on every platform; `Unix` (Unix-domain socket) only exists
/// on `cfg(unix)` targets, so on Windows this enum has a single variant.
#[derive(Debug)]
enum RawStream {
    /// A tokio TCP socket.
    Tcp(TcpStream),
    /// A tokio Unix-domain stream socket.
    #[cfg(unix)]
    Unix(UnixStream),
}
53
54impl AsyncRead for RawStream {
55 fn poll_read(
56 self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 buf: &mut ReadBuf<'_>,
59 ) -> Poll<io::Result<()>> {
60 unsafe {
62 match &mut Pin::get_unchecked_mut(self) {
63 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
64 #[cfg(unix)]
65 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66 }
67 }
68 }
69}
70
71impl AsyncWrite for RawStream {
72 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
73 unsafe {
75 match &mut Pin::get_unchecked_mut(self) {
76 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
77 #[cfg(unix)]
78 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
79 }
80 }
81 }
82
83 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
84 unsafe {
86 match &mut Pin::get_unchecked_mut(self) {
87 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
88 #[cfg(unix)]
89 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
90 }
91 }
92 }
93
94 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
95 unsafe {
97 match &mut Pin::get_unchecked_mut(self) {
98 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
99 #[cfg(unix)]
100 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
101 }
102 }
103 }
104
105 fn poll_write_vectored(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 bufs: &[std::io::IoSlice<'_>],
109 ) -> Poll<io::Result<usize>> {
110 unsafe {
112 match &mut Pin::get_unchecked_mut(self) {
113 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
114 #[cfg(unix)]
115 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
116 }
117 }
118 }
119
120 fn is_write_vectored(&self) -> bool {
121 match self {
122 RawStream::Tcp(s) => s.is_write_vectored(),
123 #[cfg(unix)]
124 RawStream::Unix(s) => s.is_write_vectored(),
125 }
126 }
127}
128
129#[cfg(unix)]
130impl AsRawFd for RawStream {
131 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
132 match self {
133 RawStream::Tcp(s) => s.as_raw_fd(),
134 RawStream::Unix(s) => s.as_raw_fd(),
135 }
136 }
137}
138
139#[cfg(windows)]
140impl AsRawSocket for RawStream {
141 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
142 match self {
143 RawStream::Tcp(s) => s.as_raw_socket(),
144 }
145 }
146}
147
/// A [`RawStream`] plus the state needed to capture receive timestamps.
#[derive(Debug)]
struct RawStreamWrapper {
    /// The underlying socket.
    pub(crate) stream: RawStream,
    /// Most recent receive timestamp recorded by `poll_read`; stays `None` unless
    /// RX timestamping is enabled (Linux only).
    pub(crate) rx_ts: Option<SystemTime>,
    /// When true, Linux TCP reads use `recvmsg` so the timestamp control message
    /// can be collected.
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    /// Control-message buffer that is cleared and reused on every timestamped
    /// `recvmsg` call, avoiding a fresh allocation per read.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,
}
162
163impl RawStreamWrapper {
164 pub fn new(stream: RawStream) -> Self {
165 RawStreamWrapper {
166 stream,
167 rx_ts: None,
168 #[cfg(target_os = "linux")]
169 enable_rx_ts: false,
170 #[cfg(target_os = "linux")]
171 reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
172 }
173 }
174
175 #[cfg(target_os = "linux")]
176 pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
177 self.enable_rx_ts = enable_rx_ts;
178 }
179}
180
/// Reads from the wrapped [`RawStream`].
///
/// On Linux, when `enable_rx_ts` is set and the socket is TCP, reads are done with
/// `recvmsg` so that an `SCM_TIMESTAMPING` control message can be captured into
/// `rx_ts`. In every other case reads are delegated straight to the inner stream.
impl AsyncRead for RawStreamWrapper {
    /// Non-Linux: plain delegation to the inner stream.
    #[cfg(not(target_os = "linux"))]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Safety: pin projection onto the `stream` field; nothing here moves the
        // wrapper or the stream out of the pinned location.
        unsafe {
            let rs_wrapper = Pin::get_unchecked_mut(self);
            match &mut rs_wrapper.stream {
                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                #[cfg(unix)]
                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
            }
        }
    }

    /// Linux: delegates unless RX timestamping is enabled, in which case TCP reads
    /// go through `recvmsg` and record the software receive timestamp in `rx_ts`.
    #[cfg(target_os = "linux")]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        use futures::ready;
        use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};

        // Fast path: timestamping is off, so behave exactly like the inner stream.
        if !self.enable_rx_ts {
            // Safety: pin projection onto the `stream` field (see non-Linux variant).
            unsafe {
                let rs_wrapper = Pin::get_unchecked_mut(self);
                match &mut rs_wrapper.stream {
                    RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                }
            }
        }

        // Safety: the wrapper is only accessed through this `&mut`; it is never moved.
        let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
        match &mut rs_wrapper.stream {
            RawStream::Tcp(s) => {
                loop {
                    // Returns Pending (after registering the waker) until tokio
                    // reports the socket readable.
                    ready!(s.poll_read_ready(cx))?;
                    // NOTE(review): this views the possibly-uninitialized unfilled part
                    // of `buf` as `&mut [u8]` so it can back an `IoSliceMut`; soundness
                    // relies on recvmsg only writing into it and on `assume_init` below
                    // covering exactly the bytes it reported.
                    let b = unsafe {
                        &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
                            as *mut [u8])
                    };
                    let mut iov = [IoSliceMut::new(b)];
                    // Reset length but keep the allocation for reuse.
                    rs_wrapper.reusable_cmsg_space.clear();

                    // Run the raw syscall under tokio's readiness tracking. tokio
                    // documents that a WouldBlock from the closure clears readiness,
                    // so the next `poll_read_ready` waits again instead of spinning.
                    match s.try_io(Interest::READABLE, || {
                        recvmsg::<SockaddrStorage>(
                            s.as_raw_fd(),
                            &mut iov,
                            Some(&mut rs_wrapper.reusable_cmsg_space),
                            MsgFlags::empty(),
                        )
                        .map_err(|errno| errno.into())
                    }) {
                        Ok(r) => {
                            // Record the first SCM_TIMESTAMPING message's software
                            // timestamp. If none arrived, `rx_ts` keeps its previous value.
                            // NOTE(review): decoding reads the whole scm_timestamping
                            // payload, so `reusable_cmsg_space` must be large enough for it.
                            if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
                                .cmsgs()
                                .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
                            {
                                rs_wrapper.rx_ts =
                                    SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
                            }
                            // Safety: recvmsg wrote `r.bytes` bytes into the unfilled region.
                            unsafe {
                                buf.assume_init(r.bytes);
                            }
                            buf.advance(r.bytes);
                            return Poll::Ready(Ok(()));
                        }
                        // Readiness was stale; loop back to wait for readiness again.
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
            }
            // Timestamping is only implemented for TCP; Unix sockets delegate.
            RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
        }
    }
}
270
271impl AsyncWrite for RawStreamWrapper {
272 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
273 unsafe {
275 match &mut Pin::get_unchecked_mut(self).stream {
276 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
277 #[cfg(unix)]
278 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
279 }
280 }
281 }
282
283 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
284 unsafe {
286 match &mut Pin::get_unchecked_mut(self).stream {
287 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
288 #[cfg(unix)]
289 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
290 }
291 }
292 }
293
294 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
295 unsafe {
297 match &mut Pin::get_unchecked_mut(self).stream {
298 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
299 #[cfg(unix)]
300 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
301 }
302 }
303 }
304
305 fn poll_write_vectored(
306 self: Pin<&mut Self>,
307 cx: &mut Context<'_>,
308 bufs: &[std::io::IoSlice<'_>],
309 ) -> Poll<io::Result<usize>> {
310 unsafe {
312 match &mut Pin::get_unchecked_mut(self).stream {
313 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
314 #[cfg(unix)]
315 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
316 }
317 }
318 }
319
320 fn is_write_vectored(&self) -> bool {
321 self.stream.is_write_vectored()
322 }
323}
324
325#[cfg(unix)]
326impl AsRawFd for RawStreamWrapper {
327 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
328 self.stream.as_raw_fd()
329 }
330}
331
332#[cfg(windows)]
333impl AsRawSocket for RawStreamWrapper {
334 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
335 self.stream.as_raw_socket()
336 }
337}
338
/// Read-buffer capacity installed by `Stream::set_buffer` (64 KiB).
const BUF_READ_SIZE: usize = 64 * 1024;
/// Write-buffer capacity installed by `Stream::set_buffer`.
/// NOTE(review): 1460 presumably matches a typical TCP MSS (1500-byte MTU minus
/// 40 bytes of IPv4 + TCP headers) — confirm intent.
const BUF_WRITE_SIZE: usize = 1460;
347
/// An L4 (TCP or Unix-domain) connection with optional buffering, rewind support,
/// digests, and read/write pending-time accounting.
#[derive(Debug)]
pub struct Stream {
    /// The buffered socket. Always `Some` except transiently inside `set_buffer`.
    stream: Option<BufStream<RawStreamWrapper>>,
    /// Data pushed back by `rewind`, served before new socket data. Used as a stack:
    /// the most recently pushed chunk is read first.
    rewind_read_buf: Vec<Vec<u8>>,
    /// When true, writes go through the `BufStream` write buffer; otherwise they go
    /// directly to the inner wrapper.
    buffer_write: bool,
    proxy_digest: Option<Arc<ProxyDigest>>,
    socket_digest: Option<Arc<SocketDigest>>,
    /// Set to `SystemTime::now()` when the `Stream` is constructed.
    pub established_ts: SystemTime,
    /// Optional tracer, notified via `on_disconnected` when the stream is dropped.
    pub tracer: Option<Tracer>,
    /// Accumulated time reads spent returning `Pending`.
    read_pending_time: AccumulatedDuration,
    /// Accumulated time writes/flushes spent pending or partially complete.
    write_pending_time: AccumulatedDuration,
    /// Receive timestamp copied from the inner wrapper after every read.
    pub rx_ts: Option<SystemTime>,
}
370
371impl Stream {
372 fn stream(&self) -> &BufStream<RawStreamWrapper> {
373 self.stream.as_ref().expect("stream should always be set")
374 }
375
376 fn stream_mut(&mut self) -> &mut BufStream<RawStreamWrapper> {
377 self.stream.as_mut().expect("stream should always be set")
378 }
379
380 pub fn set_nodelay(&mut self) -> Result<()> {
382 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
383 s.set_nodelay(true)
384 .or_err(ConnectError, "failed to set_nodelay")?;
385 }
386 Ok(())
387 }
388
389 pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
391 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
392 debug!("Setting tcp keepalive");
393 set_tcp_keepalive(s, ka)?;
394 }
395 Ok(())
396 }
397
398 #[cfg(target_os = "linux")]
399 pub fn set_rx_timestamp(&mut self) -> Result<()> {
400 use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
401
402 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
403 let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
404 | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
405 setsockopt(s.as_raw_fd(), sockopt::Timestamping, ×tamp_options)
406 .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
407 self.stream_mut().get_mut().enable_rx_ts(true);
408 }
409
410 Ok(())
411 }
412
413 #[cfg(not(target_os = "linux"))]
414 pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
415 Ok(())
416 }
417
418 pub(crate) fn rewind(&mut self, data: &[u8]) {
420 if !data.is_empty() {
421 self.rewind_read_buf.push(data.to_vec());
422 }
423 }
424
425 pub(crate) fn set_buffer(&mut self) {
428 use std::mem;
429 let stream = mem::take(&mut self.stream);
432 let stream =
433 stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner()));
434 let _ = mem::replace(&mut self.stream, stream);
435 }
436}
437
438impl From<TcpStream> for Stream {
439 fn from(s: TcpStream) -> Self {
440 Stream {
441 stream: Some(BufStream::with_capacity(
442 0,
443 0,
444 RawStreamWrapper::new(RawStream::Tcp(s)),
445 )),
446 rewind_read_buf: Vec::new(),
447 buffer_write: true,
448 established_ts: SystemTime::now(),
449 proxy_digest: None,
450 socket_digest: None,
451 tracer: None,
452 read_pending_time: AccumulatedDuration::new(),
453 write_pending_time: AccumulatedDuration::new(),
454 rx_ts: None,
455 }
456 }
457}
458
459#[cfg(unix)]
460impl From<UnixStream> for Stream {
461 fn from(s: UnixStream) -> Self {
462 Stream {
463 stream: Some(BufStream::with_capacity(
464 0,
465 0,
466 RawStreamWrapper::new(RawStream::Unix(s)),
467 )),
468 rewind_read_buf: Vec::new(),
469 buffer_write: true,
470 established_ts: SystemTime::now(),
471 proxy_digest: None,
472 socket_digest: None,
473 tracer: None,
474 read_pending_time: AccumulatedDuration::new(),
475 write_pending_time: AccumulatedDuration::new(),
476 rx_ts: None,
477 }
478 }
479}
480
481#[cfg(unix)]
482impl AsRawFd for Stream {
483 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
484 self.stream().get_ref().as_raw_fd()
485 }
486}
487
488#[cfg(windows)]
489impl AsRawSocket for Stream {
490 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
491 self.stream().get_ref().as_raw_socket()
492 }
493}
494
495#[cfg(unix)]
496impl UniqueID for Stream {
497 fn id(&self) -> UniqueIDType {
498 self.as_raw_fd()
499 }
500}
501
502#[cfg(windows)]
503impl UniqueID for Stream {
504 fn id(&self) -> usize {
505 self.as_raw_socket() as usize
506 }
507}
508
// A plain L4 stream overrides none of `Ssl`'s methods, so the trait's default
// implementations apply.
impl Ssl for Stream {}
510
511#[async_trait]
512impl Peek for Stream {
513 async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
514 use tokio::io::AsyncReadExt;
515 self.read_exact(buf).await?;
516 self.rewind(buf);
518 Ok(true)
519 }
520}
521
522#[async_trait]
523impl Shutdown for Stream {
524 async fn shutdown(&mut self) {
525 AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
526 debug!("Failed to shutdown connection: {:?}", e);
527 });
528 }
529}
530
531impl GetTimingDigest for Stream {
532 fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
533 let mut digest = Vec::with_capacity(2); digest.push(Some(TimingDigest {
535 established_ts: self.established_ts,
536 }));
537 digest
538 }
539
540 fn get_read_pending_time(&self) -> Duration {
541 self.read_pending_time.total
542 }
543
544 fn get_write_pending_time(&self) -> Duration {
545 self.write_pending_time.total
546 }
547}
548
549impl GetProxyDigest for Stream {
550 fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
551 self.proxy_digest.clone()
552 }
553
554 fn set_proxy_digest(&mut self, digest: ProxyDigest) {
555 self.proxy_digest = Some(Arc::new(digest));
556 }
557}
558
559impl GetSocketDigest for Stream {
560 fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
561 self.socket_digest.clone()
562 }
563
564 fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
565 self.socket_digest = Some(Arc::new(socket_digest))
566 }
567}
568
569impl Drop for Stream {
570 fn drop(&mut self) {
571 if let Some(t) = self.tracer.as_ref() {
572 t.0.on_disconnected();
573 }
574 let ret = match &self.stream().get_ref().stream {
576 RawStream::Tcp(s) => s.nodelay().err(),
577 #[cfg(unix)]
578 RawStream::Unix(s) => s.local_addr().err(),
579 };
580 if let Some(e) = ret {
581 match e.kind() {
582 tokio::io::ErrorKind::Other => {
583 if let Some(ecode) = e.raw_os_error() {
584 if ecode == 9 {
585 error!("Crit: socket {:?} is being double closed", self.stream);
587 }
588 }
589 }
590 _ => {
591 debug!("Socket is already broken {:?}", e);
592 }
593 }
594 } else {
595 let _ = self.flush().now_or_never();
599 }
600 debug!("Dropping socket {:?}", self.stream);
601 }
602}
603
604impl AsyncRead for Stream {
605 fn poll_read(
606 mut self: Pin<&mut Self>,
607 cx: &mut Context<'_>,
608 buf: &mut ReadBuf<'_>,
609 ) -> Poll<io::Result<()>> {
610 let result = if !self.rewind_read_buf.is_empty() {
611 let data_to_read = self.rewind_read_buf.pop().unwrap(); let mut data_to_read = data_to_read.as_slice();
613 let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
614 if !data_to_read.is_empty() {
616 let remaining_buf = Vec::from(data_to_read);
617 self.rewind_read_buf.push(remaining_buf);
618 }
619 result
620 } else {
621 Pin::new(&mut self.stream_mut()).poll_read(cx, buf)
622 };
623 self.read_pending_time.poll_time(&result);
624 self.rx_ts = self.stream().get_ref().rx_ts;
625 result
626 }
627}
628
629impl AsyncWrite for Stream {
630 fn poll_write(
631 mut self: Pin<&mut Self>,
632 cx: &mut Context,
633 buf: &[u8],
634 ) -> Poll<io::Result<usize>> {
635 let result = if self.buffer_write {
636 Pin::new(&mut self.stream_mut()).poll_write(cx, buf)
637 } else {
638 Pin::new(&mut self.stream_mut().get_mut()).poll_write(cx, buf)
639 };
640 self.write_pending_time.poll_write_time(&result, buf.len());
641 result
642 }
643
644 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
645 let result = Pin::new(&mut self.stream_mut()).poll_flush(cx);
646 self.write_pending_time.poll_time(&result);
647 result
648 }
649
650 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
651 Pin::new(&mut self.stream_mut()).poll_shutdown(cx)
652 }
653
654 fn poll_write_vectored(
655 mut self: Pin<&mut Self>,
656 cx: &mut Context<'_>,
657 bufs: &[std::io::IoSlice<'_>],
658 ) -> Poll<io::Result<usize>> {
659 let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
660
661 let result = if self.buffer_write {
662 Pin::new(&mut self.stream_mut()).poll_write_vectored(cx, bufs)
663 } else {
664 Pin::new(&mut self.stream_mut().get_mut()).poll_write_vectored(cx, bufs)
665 };
666
667 self.write_pending_time.poll_write_time(&result, total_size);
668 result
669 }
670
671 fn is_write_vectored(&self) -> bool {
672 if self.buffer_write {
673 self.stream().is_write_vectored() } else {
675 self.stream().get_ref().is_write_vectored()
676 }
677 }
678}
679
680pub mod async_write_vec {
681 use bytes::Buf;
682 use futures::ready;
683 use std::future::Future;
684 use std::io::IoSlice;
685 use std::pin::Pin;
686 use std::task::{Context, Poll};
687 use tokio::io;
688 use tokio::io::AsyncWrite;
689
690 #[must_use = "futures do nothing unless you `.await` or poll them"]
697 pub struct WriteVec<'a, W, B> {
698 writer: &'a mut W,
699 buf: &'a mut B,
700 }
701
702 #[must_use = "futures do nothing unless you `.await` or poll them"]
703 pub struct WriteVecAll<'a, W, B> {
704 writer: &'a mut W,
705 buf: &'a mut B,
706 }
707
708 pub trait AsyncWriteVec {
709 fn poll_write_vec<B: Buf>(
710 self: Pin<&mut Self>,
711 _cx: &mut Context<'_>,
712 _buf: &mut B,
713 ) -> Poll<io::Result<usize>>;
714
715 fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
716 where
717 Self: Sized,
718 B: Buf,
719 {
720 WriteVec {
721 writer: self,
722 buf: src,
723 }
724 }
725
726 fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
727 where
728 Self: Sized,
729 B: Buf,
730 {
731 WriteVecAll {
732 writer: self,
733 buf: src,
734 }
735 }
736 }
737
738 impl<W, B> Future for WriteVec<'_, W, B>
739 where
740 W: AsyncWriteVec + Unpin,
741 B: Buf,
742 {
743 type Output = io::Result<usize>;
744
745 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
746 let me = &mut *self;
747 Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
748 }
749 }
750
751 impl<W, B> Future for WriteVecAll<'_, W, B>
752 where
753 W: AsyncWriteVec + Unpin,
754 B: Buf,
755 {
756 type Output = io::Result<()>;
757
758 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
759 let me = &mut *self;
760 while me.buf.has_remaining() {
761 let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
762 if n == 0 {
763 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
764 }
765 }
766 Poll::Ready(Ok(()))
767 }
768 }
769
770 impl<T> AsyncWriteVec for T
772 where
773 T: AsyncWrite,
774 {
775 fn poll_write_vec<B: Buf>(
776 self: Pin<&mut Self>,
777 ctx: &mut Context,
778 buf: &mut B,
779 ) -> Poll<io::Result<usize>> {
780 const MAX_BUFS: usize = 64;
781
782 if !buf.has_remaining() {
783 return Poll::Ready(Ok(0));
784 }
785
786 let n = if self.is_write_vectored() {
787 let mut slices = [IoSlice::new(&[]); MAX_BUFS];
788 let cnt = buf.chunks_vectored(&mut slices);
789 ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
790 } else {
791 ready!(self.poll_write(ctx, buf.chunk()))?
792 };
793
794 buf.advance(n);
795
796 Poll::Ready(Ok(n))
797 }
798 }
799}
800
801pub use async_write_vec::AsyncWriteVec;
802
/// Accumulates wall-clock time spent while an I/O operation is not yet complete.
#[derive(Debug)]
struct AccumulatedDuration {
    /// Sum of all finished pending intervals.
    total: Duration,
    /// Start of the interval currently in progress, if any.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    fn new() -> Self {
        Self {
            total: Duration::ZERO,
            last_start: None,
        }
    }

    /// Begins an interval unless one is already running (the earlier start wins).
    fn start(&mut self) {
        self.last_start = Some(self.last_start.unwrap_or_else(Instant::now));
    }

    /// Ends the running interval, if any, and adds its length to `total`.
    fn stop(&mut self) {
        self.total += self
            .last_start
            .take()
            .map_or(Duration::ZERO, |begun| begun.elapsed());
    }

    /// Write bookkeeping: the interval ends when the whole `buf_size` was written
    /// or the write failed; a partial write or `Pending` keeps (or starts) it.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop()
        } else {
            self.start()
        }
    }

    /// Generic bookkeeping: any `Ready` ends the interval, `Pending` starts it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop()
        } else {
            self.start()
        }
    }
}
855
// Only the RX-timestamp tests depend on Linux-specific behavior; the rewind/peek
// tests are platform-independent and run everywhere.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // Brief pause before signalling the server to write.
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        // Timestamping was never enabled, so no timestamp is recorded.
        assert!(stream.rx_ts.is_none());
    }

    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // The first read is served from the rewind buffer.
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // The remainder of the rewound data comes next.
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // Only now does data come from the socket.
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }

    #[tokio::test]
    async fn test_stream_two_subsequent_peek_calls_before_read() {
        let message = b"abcdefghijklmn";

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        // First peek reads 4 bytes from the socket and rewinds them.
        let mut buffer = vec![0u8; 4];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..4]);

        // Second, shorter peek is served from the rewind buffer.
        let mut buffer = vec![0u8; 2];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..2]);

        // A real read consumes from the most recent rewind first.
        let mut buffer = vec![0u8; 1];
        stream.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[0..1]);

        // A large read returns only what is left of that rewound chunk.
        let mut buffer = vec![0u8; 100];
        let n = stream.read(&mut buffer).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(buffer[..n], message[1..2]);

        // The remaining rewound bytes, then the socket, complete the message.
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[2..]);
    }
}