1use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::l4::virt;
41use crate::protocols::raw_connect::ProxyDigest;
42use crate::protocols::{
43 GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
44 TimingDigest, UniqueID, UniqueIDType,
45};
46use crate::upstreams::peer::Tracer;
47
/// The concrete transport underneath a [`Stream`].
#[derive(Debug)]
enum RawStream {
    /// A TCP socket.
    Tcp(TcpStream),
    /// A Unix domain socket (only compiled on Unix platforms).
    #[cfg(unix)]
    Unix(UnixStream),
    /// A virtual socket stream; it has no OS descriptor (`as_raw_fd` reports -1 for it).
    Virtual(virt::VirtualSocketStream),
}
55
56impl AsyncRead for RawStream {
57 fn poll_read(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &mut ReadBuf<'_>,
61 ) -> Poll<io::Result<()>> {
62 unsafe {
64 match &mut Pin::get_unchecked_mut(self) {
65 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66 #[cfg(unix)]
67 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
68 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
69 }
70 }
71 }
72}
73
74impl AsyncWrite for RawStream {
75 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
76 unsafe {
78 match &mut Pin::get_unchecked_mut(self) {
79 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
80 #[cfg(unix)]
81 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
82 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
83 }
84 }
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
88 unsafe {
90 match &mut Pin::get_unchecked_mut(self) {
91 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
92 #[cfg(unix)]
93 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
94 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
95 }
96 }
97 }
98
99 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
100 unsafe {
102 match &mut Pin::get_unchecked_mut(self) {
103 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
104 #[cfg(unix)]
105 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
106 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
107 }
108 }
109 }
110
111 fn poll_write_vectored(
112 self: Pin<&mut Self>,
113 cx: &mut Context<'_>,
114 bufs: &[std::io::IoSlice<'_>],
115 ) -> Poll<io::Result<usize>> {
116 unsafe {
118 match &mut Pin::get_unchecked_mut(self) {
119 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
120 #[cfg(unix)]
121 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
122 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
123 }
124 }
125 }
126
127 fn is_write_vectored(&self) -> bool {
128 match self {
129 RawStream::Tcp(s) => s.is_write_vectored(),
130 #[cfg(unix)]
131 RawStream::Unix(s) => s.is_write_vectored(),
132 RawStream::Virtual(s) => s.is_write_vectored(),
133 }
134 }
135}
136
137#[cfg(unix)]
138impl AsRawFd for RawStream {
139 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
140 match self {
141 RawStream::Tcp(s) => s.as_raw_fd(),
142 RawStream::Unix(s) => s.as_raw_fd(),
143 RawStream::Virtual(_) => -1, }
145 }
146}
147
148#[cfg(windows)]
149impl AsRawSocket for RawStream {
150 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
151 match self {
152 RawStream::Tcp(s) => s.as_raw_socket(),
153 RawStream::Virtual(_) => !0,
155 }
156 }
157}
158
/// A [`RawStream`] plus the state used to capture kernel receive timestamps (Linux only).
#[derive(Debug)]
struct RawStreamWrapper {
    /// The underlying transport.
    pub(crate) stream: RawStream,
    /// Timestamp taken from the most recent read that carried one. Reads without a
    /// timestamp leave the previous value in place.
    pub(crate) rx_ts: Option<SystemTime>,
    /// When true, TCP reads go through `recvmsg` so the timestamp control message is collected.
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    /// Control-message buffer reused across `recvmsg` calls (cleared, not reallocated, per read).
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,
}
173
174impl RawStreamWrapper {
175 pub fn new(stream: RawStream) -> Self {
176 RawStreamWrapper {
177 stream,
178 rx_ts: None,
179 #[cfg(target_os = "linux")]
180 enable_rx_ts: false,
181 #[cfg(target_os = "linux")]
182 reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
183 }
184 }
185
186 #[cfg(target_os = "linux")]
187 pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
188 self.enable_rx_ts = enable_rx_ts;
189 }
190}
191
192impl AsyncRead for RawStreamWrapper {
193 #[cfg(not(target_os = "linux"))]
194 fn poll_read(
195 self: Pin<&mut Self>,
196 cx: &mut Context<'_>,
197 buf: &mut ReadBuf<'_>,
198 ) -> Poll<io::Result<()>> {
199 unsafe {
201 let rs_wrapper = Pin::get_unchecked_mut(self);
202 match &mut rs_wrapper.stream {
203 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
204 #[cfg(unix)]
205 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
206 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
207 }
208 }
209 }
210
211 #[cfg(target_os = "linux")]
212 fn poll_read(
213 self: Pin<&mut Self>,
214 cx: &mut Context<'_>,
215 buf: &mut ReadBuf<'_>,
216 ) -> Poll<io::Result<()>> {
217 use futures::ready;
218 use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};
219
220 if !self.enable_rx_ts {
222 unsafe {
224 let rs_wrapper = Pin::get_unchecked_mut(self);
225 match &mut rs_wrapper.stream {
226 RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
227 RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
228 RawStream::Virtual(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
229 }
230 }
231 }
232
233 let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
235 match &mut rs_wrapper.stream {
236 RawStream::Tcp(s) => {
237 loop {
238 ready!(s.poll_read_ready(cx))?;
239 let b = unsafe {
241 &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
242 as *mut [u8])
243 };
244 let mut iov = [IoSliceMut::new(b)];
245 rs_wrapper.reusable_cmsg_space.clear();
246
247 match s.try_io(Interest::READABLE, || {
248 recvmsg::<SockaddrStorage>(
249 s.as_raw_fd(),
250 &mut iov,
251 Some(&mut rs_wrapper.reusable_cmsg_space),
252 MsgFlags::empty(),
253 )
254 .map_err(|errno| errno.into())
255 }) {
256 Ok(r) => {
257 if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
258 .cmsgs()
259 .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
260 {
261 rs_wrapper.rx_ts =
264 SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
265 }
266 unsafe {
268 buf.assume_init(r.bytes);
269 }
270 buf.advance(r.bytes);
271 return Poll::Ready(Ok(()));
272 }
273 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
274 Err(e) => return Poll::Ready(Err(e)),
275 }
276 }
277 }
278 RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
280 RawStream::Virtual(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
281 }
282 }
283}
284
285impl AsyncWrite for RawStreamWrapper {
286 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
287 unsafe {
289 match &mut Pin::get_unchecked_mut(self).stream {
290 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
291 #[cfg(unix)]
292 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
293 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
294 }
295 }
296 }
297
298 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
299 unsafe {
301 match &mut Pin::get_unchecked_mut(self).stream {
302 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
303 #[cfg(unix)]
304 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
305 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
306 }
307 }
308 }
309
310 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
311 unsafe {
313 match &mut Pin::get_unchecked_mut(self).stream {
314 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
315 #[cfg(unix)]
316 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
317 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
318 }
319 }
320 }
321
322 fn poll_write_vectored(
323 self: Pin<&mut Self>,
324 cx: &mut Context<'_>,
325 bufs: &[std::io::IoSlice<'_>],
326 ) -> Poll<io::Result<usize>> {
327 unsafe {
329 match &mut Pin::get_unchecked_mut(self).stream {
330 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
331 #[cfg(unix)]
332 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
333 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
334 }
335 }
336 }
337
338 fn is_write_vectored(&self) -> bool {
339 self.stream.is_write_vectored()
340 }
341}
342
343#[cfg(unix)]
344impl AsRawFd for RawStreamWrapper {
345 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
346 self.stream.as_raw_fd()
347 }
348}
349
350#[cfg(windows)]
351impl AsRawSocket for RawStreamWrapper {
352 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
353 self.stream.as_raw_socket()
354 }
355}
356
/// Read-buffer capacity (64 KiB) installed by `Stream::set_buffer`.
const BUF_READ_SIZE: usize = 64 * 1024;
/// Write-buffer capacity installed by `Stream::set_buffer`.
// NOTE(review): 1460 presumably matches a typical Ethernet TCP MSS (1500 - 40) — confirm.
const BUF_WRITE_SIZE: usize = 1460;
365
/// A connection over TCP, a Unix domain socket, or a virtual socket, with optional buffering,
/// rewind (push-back) support, and per-connection metadata.
#[derive(Debug)]
pub struct Stream {
    // Always `Some` in normal use; `stream()` / `stream_mut()` panic if it is `None`.
    stream: Option<BufStream<RawStreamWrapper>>,
    // Chunks pushed back by `rewind`; served before socket data, most recently pushed first.
    rewind_read_buf: Vec<Vec<u8>>,
    // When true, writes go through the `BufStream` write buffer; otherwise straight to the socket.
    buffer_write: bool,
    proxy_digest: Option<Arc<ProxyDigest>>,
    socket_digest: Option<Arc<SocketDigest>>,
    // Set to `SystemTime::now()` by the `From` constructors.
    pub established_ts: SystemTime,
    // Optional hook; `on_disconnected` is invoked on it when the stream is dropped.
    pub tracer: Option<Tracer>,
    // Accumulated time reads / writes spent pending.
    read_pending_time: AccumulatedDuration,
    write_pending_time: AccumulatedDuration,
    // Latest kernel RX timestamp copied from the transport after each read (Linux TCP only).
    pub rx_ts: Option<SystemTime>,
}
388
389impl Stream {
390 fn stream(&self) -> &BufStream<RawStreamWrapper> {
391 self.stream.as_ref().expect("stream should always be set")
392 }
393
394 fn stream_mut(&mut self) -> &mut BufStream<RawStreamWrapper> {
395 self.stream.as_mut().expect("stream should always be set")
396 }
397
398 pub fn set_nodelay(&mut self) -> Result<()> {
400 match &self.stream_mut().get_mut().stream {
401 RawStream::Tcp(s) => {
402 s.set_nodelay(true)
403 .or_err(ConnectError, "failed to set_nodelay")?;
404 }
405 RawStream::Virtual(s) => {
406 s.set_socket_option(virt::VirtualSockOpt::NoDelay)
407 .or_err(ConnectError, "failed to set_nodelay on virtual socket")?;
408 }
409 _ => (),
410 }
411 Ok(())
412 }
413
414 pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
416 match &self.stream_mut().get_mut().stream {
417 RawStream::Tcp(s) => {
418 debug!("Setting tcp keepalive");
419 set_tcp_keepalive(s, ka)?;
420 }
421 RawStream::Virtual(s) => {
422 s.set_socket_option(virt::VirtualSockOpt::KeepAlive(ka.clone()))
423 .or_err(ConnectError, "failed to set_keepalive on virtual socket")?;
424 }
425 _ => (),
426 }
427 Ok(())
428 }
429
430 #[cfg(target_os = "linux")]
431 pub fn set_rx_timestamp(&mut self) -> Result<()> {
432 use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
433
434 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
435 let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
436 | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
437 setsockopt(s.as_raw_fd(), sockopt::Timestamping, ×tamp_options)
438 .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
439 self.stream_mut().get_mut().enable_rx_ts(true);
440 }
441
442 Ok(())
443 }
444
445 #[cfg(not(target_os = "linux"))]
446 pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
447 Ok(())
448 }
449
450 pub(crate) fn rewind(&mut self, data: &[u8]) {
452 if !data.is_empty() {
453 self.rewind_read_buf.push(data.to_vec());
454 }
455 }
456
457 pub(crate) fn set_buffer(&mut self) {
460 use std::mem;
461 let stream = mem::take(&mut self.stream);
464 let stream =
465 stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner()));
466 let _ = mem::replace(&mut self.stream, stream);
467 }
468}
469
470impl From<TcpStream> for Stream {
471 fn from(s: TcpStream) -> Self {
472 Stream {
473 stream: Some(BufStream::with_capacity(
474 0,
475 0,
476 RawStreamWrapper::new(RawStream::Tcp(s)),
477 )),
478 rewind_read_buf: Vec::new(),
479 buffer_write: true,
480 established_ts: SystemTime::now(),
481 proxy_digest: None,
482 socket_digest: None,
483 tracer: None,
484 read_pending_time: AccumulatedDuration::new(),
485 write_pending_time: AccumulatedDuration::new(),
486 rx_ts: None,
487 }
488 }
489}
490
491impl From<virt::VirtualSocketStream> for Stream {
492 fn from(s: virt::VirtualSocketStream) -> Self {
493 Stream {
494 stream: Some(BufStream::with_capacity(
495 0,
496 0,
497 RawStreamWrapper::new(RawStream::Virtual(s)),
498 )),
499 rewind_read_buf: Vec::new(),
500 buffer_write: true,
501 established_ts: SystemTime::now(),
502 proxy_digest: None,
503 socket_digest: None,
504 tracer: None,
505 read_pending_time: AccumulatedDuration::new(),
506 write_pending_time: AccumulatedDuration::new(),
507 rx_ts: None,
508 }
509 }
510}
511
512#[cfg(unix)]
513impl From<UnixStream> for Stream {
514 fn from(s: UnixStream) -> Self {
515 Stream {
516 stream: Some(BufStream::with_capacity(
517 0,
518 0,
519 RawStreamWrapper::new(RawStream::Unix(s)),
520 )),
521 rewind_read_buf: Vec::new(),
522 buffer_write: true,
523 established_ts: SystemTime::now(),
524 proxy_digest: None,
525 socket_digest: None,
526 tracer: None,
527 read_pending_time: AccumulatedDuration::new(),
528 write_pending_time: AccumulatedDuration::new(),
529 rx_ts: None,
530 }
531 }
532}
533
534#[cfg(unix)]
535impl AsRawFd for Stream {
536 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
537 self.stream().get_ref().as_raw_fd()
538 }
539}
540
541#[cfg(windows)]
542impl AsRawSocket for Stream {
543 fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
544 self.stream().get_ref().as_raw_socket()
545 }
546}
547
548#[cfg(unix)]
549impl UniqueID for Stream {
550 fn id(&self) -> UniqueIDType {
551 self.as_raw_fd()
552 }
553}
554
555#[cfg(windows)]
556impl UniqueID for Stream {
557 fn id(&self) -> usize {
558 self.as_raw_socket() as usize
559 }
560}
561
// `Stream` overrides none of the `Ssl` methods; the trait's default implementations apply.
impl Ssl for Stream {}
563
564#[async_trait]
565impl Peek for Stream {
566 async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
567 use tokio::io::AsyncReadExt;
568 self.read_exact(buf).await?;
569 self.rewind(buf);
571 Ok(true)
572 }
573}
574
575#[async_trait]
576impl Shutdown for Stream {
577 async fn shutdown(&mut self) {
578 AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
579 debug!("Failed to shutdown connection: {:?}", e);
580 });
581 }
582}
583
584impl GetTimingDigest for Stream {
585 fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
586 let mut digest = Vec::with_capacity(2); digest.push(Some(TimingDigest {
588 established_ts: self.established_ts,
589 }));
590 digest
591 }
592
593 fn get_read_pending_time(&self) -> Duration {
594 self.read_pending_time.total
595 }
596
597 fn get_write_pending_time(&self) -> Duration {
598 self.write_pending_time.total
599 }
600}
601
602impl GetProxyDigest for Stream {
603 fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
604 self.proxy_digest.clone()
605 }
606
607 fn set_proxy_digest(&mut self, digest: ProxyDigest) {
608 self.proxy_digest = Some(Arc::new(digest));
609 }
610}
611
612impl GetSocketDigest for Stream {
613 fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
614 self.socket_digest.clone()
615 }
616
617 fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
618 self.socket_digest = Some(Arc::new(socket_digest))
619 }
620}
621
622impl Drop for Stream {
623 fn drop(&mut self) {
624 if let Some(t) = self.tracer.as_ref() {
625 t.0.on_disconnected();
626 }
627 let ret = match &self.stream().get_ref().stream {
629 RawStream::Tcp(s) => s.nodelay().err(),
630 #[cfg(unix)]
631 RawStream::Unix(s) => s.local_addr().err(),
632 RawStream::Virtual(_) => {
633 None
635 }
636 };
637 if let Some(e) = ret {
638 match e.kind() {
639 tokio::io::ErrorKind::Other => {
640 if let Some(ecode) = e.raw_os_error() {
641 if ecode == 9 {
642 error!("Crit: socket {:?} is being double closed", self.stream);
644 }
645 }
646 }
647 _ => {
648 debug!("Socket is already broken {:?}", e);
649 }
650 }
651 } else {
652 let _ = self.flush().now_or_never();
656 }
657 debug!("Dropping socket {:?}", self.stream);
658 }
659}
660
661impl AsyncRead for Stream {
662 fn poll_read(
663 mut self: Pin<&mut Self>,
664 cx: &mut Context<'_>,
665 buf: &mut ReadBuf<'_>,
666 ) -> Poll<io::Result<()>> {
667 let result = if !self.rewind_read_buf.is_empty() {
668 let data_to_read = self.rewind_read_buf.pop().unwrap(); let mut data_to_read = data_to_read.as_slice();
670 let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
671 if !data_to_read.is_empty() {
673 let remaining_buf = Vec::from(data_to_read);
674 self.rewind_read_buf.push(remaining_buf);
675 }
676 result
677 } else {
678 Pin::new(&mut self.stream_mut()).poll_read(cx, buf)
679 };
680 self.read_pending_time.poll_time(&result);
681 self.rx_ts = self.stream().get_ref().rx_ts;
682 result
683 }
684}
685
686impl AsyncWrite for Stream {
687 fn poll_write(
688 mut self: Pin<&mut Self>,
689 cx: &mut Context,
690 buf: &[u8],
691 ) -> Poll<io::Result<usize>> {
692 let result = if self.buffer_write {
693 Pin::new(&mut self.stream_mut()).poll_write(cx, buf)
694 } else {
695 Pin::new(&mut self.stream_mut().get_mut()).poll_write(cx, buf)
696 };
697 self.write_pending_time.poll_write_time(&result, buf.len());
698 result
699 }
700
701 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
702 let result = Pin::new(&mut self.stream_mut()).poll_flush(cx);
703 self.write_pending_time.poll_time(&result);
704 result
705 }
706
707 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
708 Pin::new(&mut self.stream_mut()).poll_shutdown(cx)
709 }
710
711 fn poll_write_vectored(
712 mut self: Pin<&mut Self>,
713 cx: &mut Context<'_>,
714 bufs: &[std::io::IoSlice<'_>],
715 ) -> Poll<io::Result<usize>> {
716 let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
717
718 let result = if self.buffer_write {
719 Pin::new(&mut self.stream_mut()).poll_write_vectored(cx, bufs)
720 } else {
721 Pin::new(&mut self.stream_mut().get_mut()).poll_write_vectored(cx, bufs)
722 };
723
724 self.write_pending_time.poll_write_time(&result, total_size);
725 result
726 }
727
728 fn is_write_vectored(&self) -> bool {
729 if self.buffer_write {
730 self.stream().is_write_vectored() } else {
732 self.stream().get_ref().is_write_vectored()
733 }
734 }
735}
736
737pub mod async_write_vec {
738 use bytes::Buf;
739 use futures::ready;
740 use std::future::Future;
741 use std::io::IoSlice;
742 use std::pin::Pin;
743 use std::task::{Context, Poll};
744 use tokio::io;
745 use tokio::io::AsyncWrite;
746
747 #[must_use = "futures do nothing unless you `.await` or poll them"]
754 pub struct WriteVec<'a, W, B> {
755 writer: &'a mut W,
756 buf: &'a mut B,
757 }
758
759 #[must_use = "futures do nothing unless you `.await` or poll them"]
760 pub struct WriteVecAll<'a, W, B> {
761 writer: &'a mut W,
762 buf: &'a mut B,
763 }
764
765 pub trait AsyncWriteVec {
766 fn poll_write_vec<B: Buf>(
767 self: Pin<&mut Self>,
768 _cx: &mut Context<'_>,
769 _buf: &mut B,
770 ) -> Poll<io::Result<usize>>;
771
772 fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
773 where
774 Self: Sized,
775 B: Buf,
776 {
777 WriteVec {
778 writer: self,
779 buf: src,
780 }
781 }
782
783 fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
784 where
785 Self: Sized,
786 B: Buf,
787 {
788 WriteVecAll {
789 writer: self,
790 buf: src,
791 }
792 }
793 }
794
795 impl<W, B> Future for WriteVec<'_, W, B>
796 where
797 W: AsyncWriteVec + Unpin,
798 B: Buf,
799 {
800 type Output = io::Result<usize>;
801
802 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
803 let me = &mut *self;
804 Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
805 }
806 }
807
808 impl<W, B> Future for WriteVecAll<'_, W, B>
809 where
810 W: AsyncWriteVec + Unpin,
811 B: Buf,
812 {
813 type Output = io::Result<()>;
814
815 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
816 let me = &mut *self;
817 while me.buf.has_remaining() {
818 let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
819 if n == 0 {
820 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
821 }
822 }
823 Poll::Ready(Ok(()))
824 }
825 }
826
827 impl<T> AsyncWriteVec for T
829 where
830 T: AsyncWrite,
831 {
832 fn poll_write_vec<B: Buf>(
833 self: Pin<&mut Self>,
834 ctx: &mut Context,
835 buf: &mut B,
836 ) -> Poll<io::Result<usize>> {
837 const MAX_BUFS: usize = 64;
838
839 if !buf.has_remaining() {
840 return Poll::Ready(Ok(0));
841 }
842
843 let n = if self.is_write_vectored() {
844 let mut slices = [IoSlice::new(&[]); MAX_BUFS];
845 let cnt = buf.chunks_vectored(&mut slices);
846 ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
847 } else {
848 ready!(self.poll_write(ctx, buf.chunk()))?
849 };
850
851 buf.advance(n);
852
853 Poll::Ready(Ok(n))
854 }
855 }
856}
857
858pub use async_write_vec::AsyncWriteVec;
859
/// Running total of the time spent between a pending poll and the poll that finished it.
#[derive(Debug)]
struct AccumulatedDuration {
    total: Duration,
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    fn new() -> Self {
        Self {
            total: Duration::ZERO,
            last_start: None,
        }
    }

    /// Opens a timing window unless one is already open.
    fn start(&mut self) {
        self.last_start.get_or_insert_with(Instant::now);
    }

    /// Closes the open timing window, if any, and adds its length to `total`.
    fn stop(&mut self) {
        let elapsed = self
            .last_start
            .take()
            .map_or(Duration::ZERO, |since| since.elapsed());
        self.total += elapsed;
    }

    /// A write counts as pending until it either fails or writes all `buf_size` bytes;
    /// partial writes keep the window open.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        match result {
            Poll::Ready(Ok(written)) if *written != buf_size => self.start(),
            Poll::Ready(_) => self.stop(),
            Poll::Pending => self.start(),
        }
    }

    /// Any ready result, success or error, closes the window; `Pending` opens it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop()
        } else {
            self.start()
        }
    }
}
912
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    // With RX timestamping enabled, a TCP read must record a kernel timestamp.
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // NOTE(review): this blocks the test runtime's thread; its purpose is not stated here.
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    // Without RX timestamping, reads take the plain path and record no timestamp.
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_none());
    }

    // Rewound data is fully served before any socket data.
    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // The rewound chunk is longer than the buffer, so the first read returns its prefix.
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // The second read returns only the leftover of the rewound chunk.
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // Only now does data come from the socket.
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    // A peek returns the bytes without consuming them.
    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }

    // Nested peeks and partial reads keep the byte order intact.
    #[tokio::test]
    async fn test_stream_two_subsequent_peek_calls_before_read() {
        let message = b"abcdefghijklmn";

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        // First peek pulls "abcd" from the socket and pushes it back.
        let mut buffer = vec![0u8; 4];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..4]);

        // Second peek is served from the rewound "abcd"; the rewind stack becomes ["cd", "ab"].
        let mut buffer = vec![0u8; 2];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..2]);

        // Reading one byte consumes "a" and leaves ["cd", "b"].
        let mut buffer = vec![0u8; 1];
        stream.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[0..1]);

        // Each read serves at most one rewound chunk, so this returns just "b".
        let mut buffer = vec![0u8; 100];
        let n = stream.read(&mut buffer).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(buffer[..n], message[1..2]);

        // The rest: "cd" from the rewind stack, then the remaining socket data.
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[2..]);
    }
}