1use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::raw_connect::ProxyDigest;
41use crate::protocols::{
42 GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
43 TimingDigest, UniqueID, UniqueIDType,
44};
45use crate::upstreams::peer::Tracer;
46
/// The concrete socket underneath a [`Stream`].
///
/// TCP is available on every platform; Unix domain sockets only exist on `cfg(unix)`
/// targets, so every `match` over this enum gates the `Unix` arm the same way.
#[derive(Debug)]
enum RawStream {
    /// A TCP connection.
    Tcp(TcpStream),
    /// A Unix domain socket connection (Unix platforms only).
    #[cfg(unix)]
    Unix(UnixStream),
}
53
54impl AsyncRead for RawStream {
55 fn poll_read(
56 self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 buf: &mut ReadBuf<'_>,
59 ) -> Poll<io::Result<()>> {
60 unsafe {
62 match &mut Pin::get_unchecked_mut(self) {
63 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
64 #[cfg(unix)]
65 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66 }
67 }
68 }
69}
70
71impl AsyncWrite for RawStream {
72 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
73 unsafe {
75 match &mut Pin::get_unchecked_mut(self) {
76 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
77 #[cfg(unix)]
78 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
79 }
80 }
81 }
82
83 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
84 unsafe {
86 match &mut Pin::get_unchecked_mut(self) {
87 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
88 #[cfg(unix)]
89 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
90 }
91 }
92 }
93
94 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
95 unsafe {
97 match &mut Pin::get_unchecked_mut(self) {
98 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
99 #[cfg(unix)]
100 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
101 }
102 }
103 }
104
105 fn poll_write_vectored(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 bufs: &[std::io::IoSlice<'_>],
109 ) -> Poll<io::Result<usize>> {
110 unsafe {
112 match &mut Pin::get_unchecked_mut(self) {
113 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
114 #[cfg(unix)]
115 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
116 }
117 }
118 }
119
120 fn is_write_vectored(&self) -> bool {
121 match self {
122 RawStream::Tcp(s) => s.is_write_vectored(),
123 #[cfg(unix)]
124 RawStream::Unix(s) => s.is_write_vectored(),
125 }
126 }
127}
128
129#[cfg(unix)]
130impl AsRawFd for RawStream {
131 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
132 match self {
133 RawStream::Tcp(s) => s.as_raw_fd(),
134 RawStream::Unix(s) => s.as_raw_fd(),
135 }
136 }
137}
138
139#[cfg(windows)]
140impl AsRawSocket for RawStream {
141 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
142 match self {
143 RawStream::Tcp(s) => s.as_raw_socket(),
144 }
145 }
146}
147
/// Wraps a [`RawStream`] and, on Linux, can capture software receive timestamps on TCP reads.
#[derive(Debug)]
struct RawStreamWrapper {
    /// The wrapped socket.
    pub(crate) stream: RawStream,
    /// Receive timestamp taken from the most recent `ScmTimestampsns` control message
    /// seen by a timestamped read; `None` until one has been observed.
    pub(crate) rx_ts: Option<SystemTime>,
    /// When `true`, TCP reads go through `recvmsg` so the timestamp control message can be parsed.
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    /// Control-message buffer kept across `recvmsg` calls so it is not reallocated per read.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,
}
162
163impl RawStreamWrapper {
164 pub fn new(stream: RawStream) -> Self {
165 RawStreamWrapper {
166 stream,
167 rx_ts: None,
168 #[cfg(target_os = "linux")]
169 enable_rx_ts: false,
170 #[cfg(target_os = "linux")]
171 reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
172 }
173 }
174
175 #[cfg(target_os = "linux")]
176 pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
177 self.enable_rx_ts = enable_rx_ts;
178 }
179}
180
181impl AsyncRead for RawStreamWrapper {
182 #[cfg(not(target_os = "linux"))]
183 fn poll_read(
184 self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 buf: &mut ReadBuf<'_>,
187 ) -> Poll<io::Result<()>> {
188 unsafe {
190 let rs_wrapper = Pin::get_unchecked_mut(self);
191 match &mut rs_wrapper.stream {
192 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
193 #[cfg(unix)]
194 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
195 }
196 }
197 }
198
199 #[cfg(target_os = "linux")]
200 fn poll_read(
201 self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 buf: &mut ReadBuf<'_>,
204 ) -> Poll<io::Result<()>> {
205 use futures::ready;
206 use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};
207
208 if !self.enable_rx_ts {
210 unsafe {
212 let rs_wrapper = Pin::get_unchecked_mut(self);
213 match &mut rs_wrapper.stream {
214 RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
215 RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
216 }
217 }
218 }
219
220 let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
222 match &mut rs_wrapper.stream {
223 RawStream::Tcp(s) => {
224 loop {
225 ready!(s.poll_read_ready(cx))?;
226 let b = unsafe {
228 &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
229 as *mut [u8])
230 };
231 let mut iov = [IoSliceMut::new(b)];
232 rs_wrapper.reusable_cmsg_space.clear();
233
234 match s.try_io(Interest::READABLE, || {
235 recvmsg::<SockaddrStorage>(
236 s.as_raw_fd(),
237 &mut iov,
238 Some(&mut rs_wrapper.reusable_cmsg_space),
239 MsgFlags::empty(),
240 )
241 .map_err(|errno| errno.into())
242 }) {
243 Ok(r) => {
244 if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
245 .cmsgs()
246 .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
247 {
248 rs_wrapper.rx_ts =
251 SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
252 }
253 unsafe {
255 buf.assume_init(r.bytes);
256 }
257 buf.advance(r.bytes);
258 return Poll::Ready(Ok(()));
259 }
260 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
261 Err(e) => return Poll::Ready(Err(e)),
262 }
263 }
264 }
265 RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
267 }
268 }
269}
270
271impl AsyncWrite for RawStreamWrapper {
272 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
273 unsafe {
275 match &mut Pin::get_unchecked_mut(self).stream {
276 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
277 #[cfg(unix)]
278 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
279 }
280 }
281 }
282
283 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
284 unsafe {
286 match &mut Pin::get_unchecked_mut(self).stream {
287 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
288 #[cfg(unix)]
289 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
290 }
291 }
292 }
293
294 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
295 unsafe {
297 match &mut Pin::get_unchecked_mut(self).stream {
298 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
299 #[cfg(unix)]
300 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
301 }
302 }
303 }
304
305 fn poll_write_vectored(
306 self: Pin<&mut Self>,
307 cx: &mut Context<'_>,
308 bufs: &[std::io::IoSlice<'_>],
309 ) -> Poll<io::Result<usize>> {
310 unsafe {
312 match &mut Pin::get_unchecked_mut(self).stream {
313 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
314 #[cfg(unix)]
315 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
316 }
317 }
318 }
319
320 fn is_write_vectored(&self) -> bool {
321 self.stream.is_write_vectored()
322 }
323}
324
325#[cfg(unix)]
326impl AsRawFd for RawStreamWrapper {
327 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
328 self.stream.as_raw_fd()
329 }
330}
331
332#[cfg(windows)]
333impl AsRawSocket for RawStreamWrapper {
334 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
335 self.stream.as_raw_socket()
336 }
337}
338
/// Read-side capacity of each stream's `BufStream` (64 KiB).
const BUF_READ_SIZE: usize = 1 << 16;
/// Write-side capacity of each stream's `BufStream`.
/// NOTE(review): 1460 presumably matches one TCP MSS on a 1500-byte MTU link
/// (1500 minus 40 bytes of IPv4 + TCP headers) — confirm the intent.
const BUF_WRITE_SIZE: usize = 1460;
347
/// A TCP or Unix-socket connection with buffered reads/writes, a rewind buffer for
/// pushed-back bytes, connection digests, and accounting of time spent pending.
#[derive(Debug)]
pub struct Stream {
    /// Read- and write-buffered wrapper around the raw socket.
    stream: BufStream<RawStreamWrapper>,
    /// Bytes pushed back with `rewind`; reads drain this before touching the socket.
    rewind_read_buf: Vec<u8>,
    /// When `false`, writes bypass the `BufStream` write buffer and go straight to the socket.
    buffer_write: bool,
    /// Digest of a proxy (e.g. CONNECT) handshake, if one was recorded via `set_proxy_digest`.
    proxy_digest: Option<Arc<ProxyDigest>>,
    /// Socket-level digest recorded via `set_socket_digest`.
    socket_digest: Option<Arc<SocketDigest>>,
    /// Wall-clock time at which this `Stream` was constructed from a socket.
    pub established_ts: SystemTime,
    /// Optional tracer whose `on_disconnected` hook runs when the stream is dropped.
    pub tracer: Option<Tracer>,
    /// Accumulated time reads (and the rewind path) spent returning `Pending`.
    read_pending_time: AccumulatedDuration,
    /// Accumulated time writes and flushes spent not yet complete.
    write_pending_time: AccumulatedDuration,
    /// Receive timestamp copied from the wrapper after every read (Linux, when
    /// enabled with `set_rx_timestamp`); `None` otherwise.
    pub rx_ts: Option<SystemTime>,
}
369
370impl Stream {
371 pub fn set_nodelay(&mut self) -> Result<()> {
373 if let RawStream::Tcp(s) = &self.stream.get_mut().stream {
374 s.set_nodelay(true)
375 .or_err(ConnectError, "failed to set_nodelay")?;
376 }
377 Ok(())
378 }
379
380 pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
382 if let RawStream::Tcp(s) = &self.stream.get_mut().stream {
383 debug!("Setting tcp keepalive");
384 set_tcp_keepalive(s, ka)?;
385 }
386 Ok(())
387 }
388
389 #[cfg(target_os = "linux")]
390 pub fn set_rx_timestamp(&mut self) -> Result<()> {
391 use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
392
393 if let RawStream::Tcp(s) = &self.stream.get_mut().stream {
394 let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
395 | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
396 setsockopt(s.as_raw_fd(), sockopt::Timestamping, ×tamp_options)
397 .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
398 self.stream.get_mut().enable_rx_ts(true);
399 }
400
401 Ok(())
402 }
403
404 #[cfg(not(target_os = "linux"))]
405 pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
406 Ok(())
407 }
408
409 pub(crate) fn rewind(&mut self, data: &[u8]) {
411 if !data.is_empty() {
412 self.rewind_read_buf.extend_from_slice(data);
413 }
414 }
415}
416
417impl From<TcpStream> for Stream {
418 fn from(s: TcpStream) -> Self {
419 Stream {
420 stream: BufStream::with_capacity(
421 BUF_READ_SIZE,
422 BUF_WRITE_SIZE,
423 RawStreamWrapper::new(RawStream::Tcp(s)),
424 ),
425 rewind_read_buf: Vec::new(),
426 buffer_write: true,
427 established_ts: SystemTime::now(),
428 proxy_digest: None,
429 socket_digest: None,
430 tracer: None,
431 read_pending_time: AccumulatedDuration::new(),
432 write_pending_time: AccumulatedDuration::new(),
433 rx_ts: None,
434 }
435 }
436}
437
438#[cfg(unix)]
439impl From<UnixStream> for Stream {
440 fn from(s: UnixStream) -> Self {
441 Stream {
442 stream: BufStream::with_capacity(
443 BUF_READ_SIZE,
444 BUF_WRITE_SIZE,
445 RawStreamWrapper::new(RawStream::Unix(s)),
446 ),
447 rewind_read_buf: Vec::new(),
448 buffer_write: true,
449 established_ts: SystemTime::now(),
450 proxy_digest: None,
451 socket_digest: None,
452 tracer: None,
453 read_pending_time: AccumulatedDuration::new(),
454 write_pending_time: AccumulatedDuration::new(),
455 rx_ts: None,
456 }
457 }
458}
459
460#[cfg(unix)]
461impl AsRawFd for Stream {
462 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
463 self.stream.get_ref().as_raw_fd()
464 }
465}
466
467#[cfg(windows)]
468impl AsRawSocket for Stream {
469 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
470 self.stream.get_ref().as_raw_socket()
471 }
472}
473
474#[cfg(unix)]
475impl UniqueID for Stream {
476 fn id(&self) -> UniqueIDType {
477 self.as_raw_fd()
478 }
479}
480
481#[cfg(windows)]
482impl UniqueID for Stream {
483 fn id(&self) -> usize {
484 self.as_raw_socket() as usize
485 }
486}
487
488impl Ssl for Stream {}
489
490#[async_trait]
491impl Peek for Stream {
492 async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
493 use tokio::io::AsyncReadExt;
494 self.read_exact(buf).await?;
495 self.rewind(buf);
497 Ok(true)
498 }
499}
500
501#[async_trait]
502impl Shutdown for Stream {
503 async fn shutdown(&mut self) {
504 AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
505 debug!("Failed to shutdown connection: {:?}", e);
506 });
507 }
508}
509
510impl GetTimingDigest for Stream {
511 fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
512 let mut digest = Vec::with_capacity(2); digest.push(Some(TimingDigest {
514 established_ts: self.established_ts,
515 }));
516 digest
517 }
518
519 fn get_read_pending_time(&self) -> Duration {
520 self.read_pending_time.total
521 }
522
523 fn get_write_pending_time(&self) -> Duration {
524 self.write_pending_time.total
525 }
526}
527
528impl GetProxyDigest for Stream {
529 fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
530 self.proxy_digest.clone()
531 }
532
533 fn set_proxy_digest(&mut self, digest: ProxyDigest) {
534 self.proxy_digest = Some(Arc::new(digest));
535 }
536}
537
538impl GetSocketDigest for Stream {
539 fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
540 self.socket_digest.clone()
541 }
542
543 fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
544 self.socket_digest = Some(Arc::new(socket_digest))
545 }
546}
547
548impl Drop for Stream {
549 fn drop(&mut self) {
550 if let Some(t) = self.tracer.as_ref() {
551 t.0.on_disconnected();
552 }
553 let ret = match &self.stream.get_ref().stream {
555 RawStream::Tcp(s) => s.nodelay().err(),
556 #[cfg(unix)]
557 RawStream::Unix(s) => s.local_addr().err(),
558 };
559 if let Some(e) = ret {
560 match e.kind() {
561 tokio::io::ErrorKind::Other => {
562 if let Some(ecode) = e.raw_os_error() {
563 if ecode == 9 {
564 error!("Crit: socket {:?} is being double closed", self.stream);
566 }
567 }
568 }
569 _ => {
570 debug!("Socket is already broken {:?}", e);
571 }
572 }
573 } else {
574 let _ = self.flush().now_or_never();
578 }
579 debug!("Dropping socket {:?}", self.stream);
580 }
581}
582
583impl AsyncRead for Stream {
584 fn poll_read(
585 mut self: Pin<&mut Self>,
586 cx: &mut Context<'_>,
587 buf: &mut ReadBuf<'_>,
588 ) -> Poll<io::Result<()>> {
589 let result = if !self.rewind_read_buf.is_empty() {
590 let mut data_to_read = self.rewind_read_buf.as_slice();
591 let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
592 let remaining_buf = Vec::from(data_to_read);
594 let _ = std::mem::replace(&mut self.rewind_read_buf, remaining_buf);
595 result
596 } else {
597 Pin::new(&mut self.stream).poll_read(cx, buf)
598 };
599 self.read_pending_time.poll_time(&result);
600 self.rx_ts = self.stream.get_ref().rx_ts;
601 result
602 }
603}
604
605impl AsyncWrite for Stream {
606 fn poll_write(
607 mut self: Pin<&mut Self>,
608 cx: &mut Context,
609 buf: &[u8],
610 ) -> Poll<io::Result<usize>> {
611 let result = if self.buffer_write {
612 Pin::new(&mut self.stream).poll_write(cx, buf)
613 } else {
614 Pin::new(&mut self.stream.get_mut()).poll_write(cx, buf)
615 };
616 self.write_pending_time.poll_write_time(&result, buf.len());
617 result
618 }
619
620 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
621 let result = Pin::new(&mut self.stream).poll_flush(cx);
622 self.write_pending_time.poll_time(&result);
623 result
624 }
625
626 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
627 Pin::new(&mut self.stream).poll_shutdown(cx)
628 }
629
630 fn poll_write_vectored(
631 mut self: Pin<&mut Self>,
632 cx: &mut Context<'_>,
633 bufs: &[std::io::IoSlice<'_>],
634 ) -> Poll<io::Result<usize>> {
635 let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
636
637 let result = if self.buffer_write {
638 Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
639 } else {
640 Pin::new(&mut self.stream.get_mut()).poll_write_vectored(cx, bufs)
641 };
642
643 self.write_pending_time.poll_write_time(&result, total_size);
644 result
645 }
646
647 fn is_write_vectored(&self) -> bool {
648 if self.buffer_write {
649 self.stream.is_write_vectored() } else {
651 self.stream.get_ref().is_write_vectored()
652 }
653 }
654}
655
656pub mod async_write_vec {
657 use bytes::Buf;
658 use futures::ready;
659 use std::future::Future;
660 use std::io::IoSlice;
661 use std::pin::Pin;
662 use std::task::{Context, Poll};
663 use tokio::io;
664 use tokio::io::AsyncWrite;
665
666 #[must_use = "futures do nothing unless you `.await` or poll them"]
673 pub struct WriteVec<'a, W, B> {
674 writer: &'a mut W,
675 buf: &'a mut B,
676 }
677
678 #[must_use = "futures do nothing unless you `.await` or poll them"]
679 pub struct WriteVecAll<'a, W, B> {
680 writer: &'a mut W,
681 buf: &'a mut B,
682 }
683
684 pub trait AsyncWriteVec {
685 fn poll_write_vec<B: Buf>(
686 self: Pin<&mut Self>,
687 _cx: &mut Context<'_>,
688 _buf: &mut B,
689 ) -> Poll<io::Result<usize>>;
690
691 fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
692 where
693 Self: Sized,
694 B: Buf,
695 {
696 WriteVec {
697 writer: self,
698 buf: src,
699 }
700 }
701
702 fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
703 where
704 Self: Sized,
705 B: Buf,
706 {
707 WriteVecAll {
708 writer: self,
709 buf: src,
710 }
711 }
712 }
713
714 impl<W, B> Future for WriteVec<'_, W, B>
715 where
716 W: AsyncWriteVec + Unpin,
717 B: Buf,
718 {
719 type Output = io::Result<usize>;
720
721 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
722 let me = &mut *self;
723 Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
724 }
725 }
726
727 impl<W, B> Future for WriteVecAll<'_, W, B>
728 where
729 W: AsyncWriteVec + Unpin,
730 B: Buf,
731 {
732 type Output = io::Result<()>;
733
734 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
735 let me = &mut *self;
736 while me.buf.has_remaining() {
737 let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
738 if n == 0 {
739 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
740 }
741 }
742 Poll::Ready(Ok(()))
743 }
744 }
745
746 impl<T> AsyncWriteVec for T
748 where
749 T: AsyncWrite,
750 {
751 fn poll_write_vec<B: Buf>(
752 self: Pin<&mut Self>,
753 ctx: &mut Context,
754 buf: &mut B,
755 ) -> Poll<io::Result<usize>> {
756 const MAX_BUFS: usize = 64;
757
758 if !buf.has_remaining() {
759 return Poll::Ready(Ok(0));
760 }
761
762 let n = if self.is_write_vectored() {
763 let mut slices = [IoSlice::new(&[]); MAX_BUFS];
764 let cnt = buf.chunks_vectored(&mut slices);
765 ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
766 } else {
767 ready!(self.poll_write(ctx, buf.chunk()))?
768 };
769
770 buf.advance(n);
771
772 Poll::Ready(Ok(n))
773 }
774 }
775}
776
777pub use async_write_vec::AsyncWriteVec;
778
/// Sums the wall-clock time between a poll that leaves an operation unfinished and the
/// poll that completes it, across any number of such intervals.
#[derive(Debug)]
struct AccumulatedDuration {
    /// Sum of all completed intervals.
    total: Duration,
    /// Start of the interval currently being timed, if one is open.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    /// Creates an accumulator with zero total and no open interval.
    fn new() -> Self {
        Self {
            last_start: None,
            total: Duration::ZERO,
        }
    }

    /// Opens an interval now, unless one is already open (the earlier start is kept).
    fn start(&mut self) {
        let _ = self.last_start.get_or_insert_with(Instant::now);
    }

    /// Closes the open interval, if any, adding its length to `total`.
    fn stop(&mut self) {
        self.total += self
            .last_start
            .take()
            .map_or(Duration::ZERO, |began| began.elapsed());
    }

    /// Timing rule for writes: errors and writes that consumed the whole `buf_size` close
    /// the interval; `Pending` and partial writes keep (or open) it.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop();
        } else {
            self.start();
        }
    }

    /// Timing rule for reads/flushes: any `Ready` closes the interval, `Pending` opens it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop();
        } else {
            self.start();
        }
    }
}
831
// Rewind/peek tests are platform-independent and run everywhere; the receive-timestamp
// tests depend on Linux-only socket options and are gated individually.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    /// With timestamping enabled, a read over TCP must populate `rx_ts`.
    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // NOTE(review): short blocking pause before letting the peer write; the original
        // rationale is not visible here.
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    /// Without timestamping enabled, reads take the standard path and `rx_ts` stays `None`.
    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_none());
    }

    /// Rewound bytes are served (across multiple reads) before any socket data.
    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // The first read is limited by the buffer size and comes from the rewind buffer.
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // The second read drains the rest of the rewind buffer.
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // Only now does data come from the socket.
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    /// Peeked bytes are returned again by subsequent reads.
    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }
}