1use async_trait::async_trait;
18use futures::FutureExt;
19use log::{debug, error};
20
21use pingora_error::{ErrorType::*, OrErr, Result};
22#[cfg(target_os = "linux")]
23use std::io::IoSliceMut;
24#[cfg(unix)]
25use std::os::unix::io::AsRawFd;
26#[cfg(windows)]
27use std::os::windows::io::AsRawSocket;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll};
31use std::time::{Duration, Instant, SystemTime};
32#[cfg(target_os = "linux")]
33use tokio::io::Interest;
34use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
35use tokio::net::TcpStream;
36#[cfg(unix)]
37use tokio::net::UnixStream;
38
39use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
40use crate::protocols::l4::virt;
41use crate::protocols::raw_connect::ProxyDigest;
42use crate::protocols::{
43 GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
44 TimingDigest, UniqueID, UniqueIDType,
45};
46use crate::upstreams::peer::Tracer;
47
48#[derive(Debug)]
49enum RawStream {
50 Tcp(TcpStream),
51 #[cfg(unix)]
52 Unix(UnixStream),
53 Virtual(virt::VirtualSocketStream),
54}
55
56impl AsyncRead for RawStream {
57 fn poll_read(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &mut ReadBuf<'_>,
61 ) -> Poll<io::Result<()>> {
62 unsafe {
64 match &mut Pin::get_unchecked_mut(self) {
65 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
66 #[cfg(unix)]
67 RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
68 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
69 }
70 }
71 }
72}
73
74impl AsyncWrite for RawStream {
75 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
76 unsafe {
78 match &mut Pin::get_unchecked_mut(self) {
79 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
80 #[cfg(unix)]
81 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
82 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
83 }
84 }
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
88 unsafe {
90 match &mut Pin::get_unchecked_mut(self) {
91 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
92 #[cfg(unix)]
93 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
94 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
95 }
96 }
97 }
98
99 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
100 unsafe {
102 match &mut Pin::get_unchecked_mut(self) {
103 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
104 #[cfg(unix)]
105 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
106 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
107 }
108 }
109 }
110
111 fn poll_write_vectored(
112 self: Pin<&mut Self>,
113 cx: &mut Context<'_>,
114 bufs: &[std::io::IoSlice<'_>],
115 ) -> Poll<io::Result<usize>> {
116 unsafe {
118 match &mut Pin::get_unchecked_mut(self) {
119 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
120 #[cfg(unix)]
121 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
122 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
123 }
124 }
125 }
126
127 fn is_write_vectored(&self) -> bool {
128 match self {
129 RawStream::Tcp(s) => s.is_write_vectored(),
130 #[cfg(unix)]
131 RawStream::Unix(s) => s.is_write_vectored(),
132 RawStream::Virtual(s) => s.is_write_vectored(),
133 }
134 }
135}
136
137#[cfg(unix)]
138impl AsRawFd for RawStream {
139 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
140 match self {
141 RawStream::Tcp(s) => s.as_raw_fd(),
142 RawStream::Unix(s) => s.as_raw_fd(),
143 RawStream::Virtual(_) => -1, }
145 }
146}
147
#[cfg(windows)]
impl AsRawSocket for RawStream {
    /// The OS socket handle of the wrapped stream.
    ///
    /// `RawStream::Virtual` is not platform-gated, so it must be handled here
    /// too. A virtual socket has no OS handle, so it reports `!0` (all bits set,
    /// the Winsock `INVALID_SOCKET` value). This mirrors the `-1` returned by
    /// the Unix `AsRawFd` impl.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        match self {
            RawStream::Tcp(s) => s.as_raw_socket(),
            RawStream::Virtual(_) => !0,
        }
    }
}
156
/// A [`RawStream`] plus the state needed to capture receive timestamps.
#[derive(Debug)]
struct RawStreamWrapper {
    pub(crate) stream: RawStream,
    // Timestamp taken from the most recent read that carried an SCM_TIMESTAMPING
    // control message. Only the Linux `poll_read` path ever sets it, and it is
    // not cleared by reads that carry no timestamp.
    pub(crate) rx_ts: Option<SystemTime>,
    // When true, Linux TCP reads go through `recvmsg` so the timestamp can be read.
    #[cfg(target_os = "linux")]
    pub(crate) enable_rx_ts: bool,
    // Ancillary-data buffer handed to `recvmsg`; kept here to reuse the allocation.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,
}
171
172impl RawStreamWrapper {
173 pub fn new(stream: RawStream) -> Self {
174 RawStreamWrapper {
175 stream,
176 rx_ts: None,
177 #[cfg(target_os = "linux")]
178 enable_rx_ts: false,
179 #[cfg(target_os = "linux")]
180 reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec),
181 }
182 }
183
184 #[cfg(target_os = "linux")]
185 pub fn enable_rx_ts(&mut self, enable_rx_ts: bool) {
186 self.enable_rx_ts = enable_rx_ts;
187 }
188}
189
impl AsyncRead for RawStreamWrapper {
    /// Non-Linux: RX timestamping is unsupported, so forward the read unchanged.
    #[cfg(not(target_os = "linux"))]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        unsafe {
            // SAFETY: pin projection onto the `stream` field; the wrapper is never
            // moved out of through this reference.
            let rs_wrapper = Pin::get_unchecked_mut(self);
            match &mut rs_wrapper.stream {
                RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                #[cfg(unix)]
                RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf),
                RawStream::Virtual(s) => Pin::new_unchecked(s).poll_read(cx, buf),
            }
        }
    }

    /// Linux behaviour:
    ///
    /// - RX timestamping enabled, TCP stream: reads go through `recvmsg`. Any
    ///   SCM_TIMESTAMPING control message's software timestamp is stored in
    ///   `rx_ts`.
    /// - RX timestamping disabled, or Unix/virtual streams: the read is
    ///   forwarded unchanged.
    #[cfg(target_os = "linux")]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        use futures::ready;
        use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage};

        // Fast path: timestamping is off, so just forward.
        if !self.enable_rx_ts {
            unsafe {
                // SAFETY: pin projection onto `stream`; nothing is moved out.
                let rs_wrapper = Pin::get_unchecked_mut(self);
                match &mut rs_wrapper.stream {
                    RawStream::Tcp(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Unix(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                    RawStream::Virtual(s) => return Pin::new_unchecked(s).poll_read(cx, buf),
                }
            }
        }

        // SAFETY: as above, this only projects to fields and never moves the wrapper.
        let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) };
        match &mut rs_wrapper.stream {
            RawStream::Tcp(s) => {
                loop {
                    // Wait for tokio's reactor to report the socket readable.
                    ready!(s.poll_read_ready(cx))?;
                    // Expose the unfilled part of `buf` as a byte slice for recvmsg.
                    // NOTE(review): this treats possibly-uninitialized memory as
                    // `[u8]`; it relies on recvmsg only writing (never reading) it.
                    let b = unsafe {
                        &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
                            as *mut [u8])
                    };
                    let mut iov = [IoSliceMut::new(b)];
                    rs_wrapper.reusable_cmsg_space.clear();

                    // `try_io` clears the readiness flag when the closure reports
                    // WouldBlock, so `continue` waits for the next readiness event.
                    match s.try_io(Interest::READABLE, || {
                        recvmsg::<SockaddrStorage>(
                            s.as_raw_fd(),
                            &mut iov,
                            Some(&mut rs_wrapper.reusable_cmsg_space),
                            MsgFlags::empty(),
                        )
                        .map_err(|errno| errno.into())
                    }) {
                        Ok(r) => {
                            // NOTE(review): MSG_CTRUNC is not checked, so a truncated
                            // control message is decoded as-is.
                            if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r
                                .cmsgs()
                                .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_)))
                            {
                                // `system` is the software timestamp, measured from
                                // the UNIX epoch. `checked_add` yields None on overflow.
                                rs_wrapper.rx_ts =
                                    SystemTime::UNIX_EPOCH.checked_add(rtime.system.into());
                            }
                            unsafe {
                                // SAFETY: recvmsg wrote `r.bytes` bytes at the start
                                // of the unfilled region.
                                buf.assume_init(r.bytes);
                            }
                            buf.advance(r.bytes);
                            return Poll::Ready(Ok(()));
                        }
                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
            }
            // Timestamps are only collected for TCP; other kinds read normally.
            RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
            RawStream::Virtual(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) },
        }
    }
}
282
283impl AsyncWrite for RawStreamWrapper {
284 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
285 unsafe {
287 match &mut Pin::get_unchecked_mut(self).stream {
288 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf),
289 #[cfg(unix)]
290 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf),
291 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write(cx, buf),
292 }
293 }
294 }
295
296 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
297 unsafe {
299 match &mut Pin::get_unchecked_mut(self).stream {
300 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx),
301 #[cfg(unix)]
302 RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx),
303 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_flush(cx),
304 }
305 }
306 }
307
308 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
309 unsafe {
311 match &mut Pin::get_unchecked_mut(self).stream {
312 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx),
313 #[cfg(unix)]
314 RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx),
315 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_shutdown(cx),
316 }
317 }
318 }
319
320 fn poll_write_vectored(
321 self: Pin<&mut Self>,
322 cx: &mut Context<'_>,
323 bufs: &[std::io::IoSlice<'_>],
324 ) -> Poll<io::Result<usize>> {
325 unsafe {
327 match &mut Pin::get_unchecked_mut(self).stream {
328 RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
329 #[cfg(unix)]
330 RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
331 RawStream::Virtual(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs),
332 }
333 }
334 }
335
336 fn is_write_vectored(&self) -> bool {
337 self.stream.is_write_vectored()
338 }
339}
340
341#[cfg(unix)]
342impl AsRawFd for RawStreamWrapper {
343 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
344 self.stream.as_raw_fd()
345 }
346}
347
#[cfg(windows)]
impl AsRawSocket for RawStreamWrapper {
    /// The socket handle of the wrapped stream (see `RawStream`'s impl).
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        AsRawSocket::as_raw_socket(&self.stream)
    }
}
354
// Read-buffer capacity (64 KiB) installed by `Stream::set_buffer`; presumably
// sized so one syscall can satisfy many small reads from upper layers.
const BUF_READ_SIZE: usize = 64 * 1024;
// Write-buffer capacity installed by `Stream::set_buffer`. Writes accumulate
// here until the buffer fills or the stream is flushed.
// NOTE(review): 1460 presumably matches a typical TCP MSS (1500-byte MTU minus
// 40 bytes of IPv4 + TCP headers); confirm intent.
const BUF_WRITE_SIZE: usize = 1460;
363
/// A transport-layer (L4) connection: TCP, a Unix domain socket (unix only),
/// or a virtual socket, wrapped in a [`BufStream`].
///
/// Besides plain I/O it keeps bytes pushed back via `rewind`, optional
/// proxy/socket digests, and the time reads and writes spent pending.
#[derive(Debug)]
pub struct Stream {
    // Always `Some`; `set_buffer` only takes it momentarily to swap buffers.
    stream: Option<BufStream<RawStreamWrapper>>,
    // Stack of pushed-back chunks; `poll_read` drains the last pushed chunk first.
    rewind_read_buf: Vec<Vec<u8>>,
    // When true, writes go through the `BufStream` write buffer; when false they
    // go straight to the raw stream. All constructors in this file set it to true.
    buffer_write: bool,
    proxy_digest: Option<Arc<ProxyDigest>>,
    socket_digest: Option<Arc<SocketDigest>>,
    /// Wall-clock time at which this `Stream` value was constructed.
    pub established_ts: SystemTime,
    /// Optional hook; its `on_disconnected` is called when the stream is dropped.
    pub tracer: Option<Tracer>,
    // Time `poll_read` spent returning `Pending`.
    read_pending_time: AccumulatedDuration,
    // Time writes/flushes spent pending or only partially complete.
    write_pending_time: AccumulatedDuration,
    /// Receive timestamp copied from the inner wrapper after each read; only
    /// populated once `set_rx_timestamp` has enabled it (Linux TCP only).
    pub rx_ts: Option<SystemTime>,
}
386
387impl Stream {
388 fn stream(&self) -> &BufStream<RawStreamWrapper> {
389 self.stream.as_ref().expect("stream should always be set")
390 }
391
392 fn stream_mut(&mut self) -> &mut BufStream<RawStreamWrapper> {
393 self.stream.as_mut().expect("stream should always be set")
394 }
395
396 pub fn set_nodelay(&mut self) -> Result<()> {
398 match &self.stream_mut().get_mut().stream {
399 RawStream::Tcp(s) => {
400 s.set_nodelay(true)
401 .or_err(ConnectError, "failed to set_nodelay")?;
402 }
403 RawStream::Virtual(s) => {
404 s.set_socket_option(virt::VirtualSockOpt::NoDelay)
405 .or_err(ConnectError, "failed to set_nodelay on virtual socket")?;
406 }
407 _ => (),
408 }
409 Ok(())
410 }
411
412 pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> {
414 match &self.stream_mut().get_mut().stream {
415 RawStream::Tcp(s) => {
416 debug!("Setting tcp keepalive");
417 set_tcp_keepalive(s, ka)?;
418 }
419 RawStream::Virtual(s) => {
420 s.set_socket_option(virt::VirtualSockOpt::KeepAlive(ka.clone()))
421 .or_err(ConnectError, "failed to set_keepalive on virtual socket")?;
422 }
423 _ => (),
424 }
425 Ok(())
426 }
427
428 #[cfg(target_os = "linux")]
429 pub fn set_rx_timestamp(&mut self) -> Result<()> {
430 use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag};
431
432 if let RawStream::Tcp(s) = &self.stream_mut().get_mut().stream {
433 let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE
434 | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE;
435 setsockopt(s.as_raw_fd(), sockopt::Timestamping, ×tamp_options)
436 .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE")?;
437 self.stream_mut().get_mut().enable_rx_ts(true);
438 }
439
440 Ok(())
441 }
442
443 #[cfg(not(target_os = "linux"))]
444 pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
445 Ok(())
446 }
447
448 pub(crate) fn rewind(&mut self, data: &[u8]) {
450 if !data.is_empty() {
451 self.rewind_read_buf.push(data.to_vec());
452 }
453 }
454
455 pub(crate) fn set_buffer(&mut self) {
458 use std::mem;
459 let stream = mem::take(&mut self.stream);
462 let stream =
463 stream.map(|s| BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, s.into_inner()));
464 let _ = mem::replace(&mut self.stream, stream);
465 }
466}
467
468impl From<TcpStream> for Stream {
469 fn from(s: TcpStream) -> Self {
470 Stream {
471 stream: Some(BufStream::with_capacity(
472 0,
473 0,
474 RawStreamWrapper::new(RawStream::Tcp(s)),
475 )),
476 rewind_read_buf: Vec::new(),
477 buffer_write: true,
478 established_ts: SystemTime::now(),
479 proxy_digest: None,
480 socket_digest: None,
481 tracer: None,
482 read_pending_time: AccumulatedDuration::new(),
483 write_pending_time: AccumulatedDuration::new(),
484 rx_ts: None,
485 }
486 }
487}
488
489impl From<virt::VirtualSocketStream> for Stream {
490 fn from(s: virt::VirtualSocketStream) -> Self {
491 Stream {
492 stream: Some(BufStream::with_capacity(
493 0,
494 0,
495 RawStreamWrapper::new(RawStream::Virtual(s)),
496 )),
497 rewind_read_buf: Vec::new(),
498 buffer_write: true,
499 established_ts: SystemTime::now(),
500 proxy_digest: None,
501 socket_digest: None,
502 tracer: None,
503 read_pending_time: AccumulatedDuration::new(),
504 write_pending_time: AccumulatedDuration::new(),
505 rx_ts: None,
506 }
507 }
508}
509
510#[cfg(unix)]
511impl From<UnixStream> for Stream {
512 fn from(s: UnixStream) -> Self {
513 Stream {
514 stream: Some(BufStream::with_capacity(
515 0,
516 0,
517 RawStreamWrapper::new(RawStream::Unix(s)),
518 )),
519 rewind_read_buf: Vec::new(),
520 buffer_write: true,
521 established_ts: SystemTime::now(),
522 proxy_digest: None,
523 socket_digest: None,
524 tracer: None,
525 read_pending_time: AccumulatedDuration::new(),
526 write_pending_time: AccumulatedDuration::new(),
527 rx_ts: None,
528 }
529 }
530}
531
532#[cfg(unix)]
533impl AsRawFd for Stream {
534 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
535 self.stream().get_ref().as_raw_fd()
536 }
537}
538
#[cfg(windows)]
impl AsRawSocket for Stream {
    /// The OS socket handle of the underlying stream.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        let wrapper = self.stream().get_ref();
        wrapper.as_raw_socket()
    }
}
545
546#[cfg(unix)]
547impl UniqueID for Stream {
548 fn id(&self) -> UniqueIDType {
549 self.as_raw_fd()
550 }
551}
552
#[cfg(windows)]
impl UniqueID for Stream {
    /// The raw socket handle, cast to the ID type, doubles as the connection ID.
    // Uses the `UniqueIDType` alias like the Unix impl; it must be `usize` here
    // for the original `-> usize` signature to satisfy the trait.
    fn id(&self) -> UniqueIDType {
        AsRawSocket::as_raw_socket(self) as UniqueIDType
    }
}
559
// A plain L4 stream relies entirely on the `Ssl` trait's default methods.
// NOTE(review): presumably those defaults report "no TLS session"; confirm.
impl Ssl for Stream {}
561
562#[async_trait]
563impl Peek for Stream {
564 async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
565 use tokio::io::AsyncReadExt;
566 self.read_exact(buf).await?;
567 self.rewind(buf);
569 Ok(true)
570 }
571}
572
573#[async_trait]
574impl Shutdown for Stream {
575 async fn shutdown(&mut self) {
576 AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
577 debug!("Failed to shutdown connection: {:?}", e);
578 });
579 }
580}
581
582impl GetTimingDigest for Stream {
583 fn get_timing_digest(&self) -> Vec<Option<TimingDigest>> {
584 let mut digest = Vec::with_capacity(2); digest.push(Some(TimingDigest {
586 established_ts: self.established_ts,
587 }));
588 digest
589 }
590
591 fn get_read_pending_time(&self) -> Duration {
592 self.read_pending_time.total
593 }
594
595 fn get_write_pending_time(&self) -> Duration {
596 self.write_pending_time.total
597 }
598}
599
600impl GetProxyDigest for Stream {
601 fn get_proxy_digest(&self) -> Option<Arc<ProxyDigest>> {
602 self.proxy_digest.clone()
603 }
604
605 fn set_proxy_digest(&mut self, digest: ProxyDigest) {
606 self.proxy_digest = Some(Arc::new(digest));
607 }
608}
609
610impl GetSocketDigest for Stream {
611 fn get_socket_digest(&self) -> Option<Arc<SocketDigest>> {
612 self.socket_digest.clone()
613 }
614
615 fn set_socket_digest(&mut self, socket_digest: SocketDigest) {
616 self.socket_digest = Some(Arc::new(socket_digest))
617 }
618}
619
620impl Drop for Stream {
621 fn drop(&mut self) {
622 if let Some(t) = self.tracer.as_ref() {
623 t.0.on_disconnected();
624 }
625 let ret = match &self.stream().get_ref().stream {
627 RawStream::Tcp(s) => s.nodelay().err(),
628 #[cfg(unix)]
629 RawStream::Unix(s) => s.local_addr().err(),
630 RawStream::Virtual(_) => {
631 None
633 }
634 };
635 if let Some(e) = ret {
636 match e.kind() {
637 tokio::io::ErrorKind::Other => {
638 if let Some(ecode) = e.raw_os_error() {
639 if ecode == 9 {
640 error!("Crit: socket {:?} is being double closed", self.stream);
642 }
643 }
644 }
645 _ => {
646 debug!("Socket is already broken {:?}", e);
647 }
648 }
649 } else {
650 let _ = self.flush().now_or_never();
654 }
655 debug!("Dropping socket {:?}", self.stream);
656 }
657}
658
659impl AsyncRead for Stream {
660 fn poll_read(
661 mut self: Pin<&mut Self>,
662 cx: &mut Context<'_>,
663 buf: &mut ReadBuf<'_>,
664 ) -> Poll<io::Result<()>> {
665 let result = if !self.rewind_read_buf.is_empty() {
666 let data_to_read = self.rewind_read_buf.pop().unwrap(); let mut data_to_read = data_to_read.as_slice();
668 let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
669 if !data_to_read.is_empty() {
671 let remaining_buf = Vec::from(data_to_read);
672 self.rewind_read_buf.push(remaining_buf);
673 }
674 result
675 } else {
676 Pin::new(&mut self.stream_mut()).poll_read(cx, buf)
677 };
678 self.read_pending_time.poll_time(&result);
679 self.rx_ts = self.stream().get_ref().rx_ts;
680 result
681 }
682}
683
684impl AsyncWrite for Stream {
685 fn poll_write(
686 mut self: Pin<&mut Self>,
687 cx: &mut Context,
688 buf: &[u8],
689 ) -> Poll<io::Result<usize>> {
690 let result = if self.buffer_write {
691 Pin::new(&mut self.stream_mut()).poll_write(cx, buf)
692 } else {
693 Pin::new(&mut self.stream_mut().get_mut()).poll_write(cx, buf)
694 };
695 self.write_pending_time.poll_write_time(&result, buf.len());
696 result
697 }
698
699 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
700 let result = Pin::new(&mut self.stream_mut()).poll_flush(cx);
701 self.write_pending_time.poll_time(&result);
702 result
703 }
704
705 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
706 Pin::new(&mut self.stream_mut()).poll_shutdown(cx)
707 }
708
709 fn poll_write_vectored(
710 mut self: Pin<&mut Self>,
711 cx: &mut Context<'_>,
712 bufs: &[std::io::IoSlice<'_>],
713 ) -> Poll<io::Result<usize>> {
714 let total_size = bufs.iter().fold(0, |acc, s| acc + s.len());
715
716 let result = if self.buffer_write {
717 Pin::new(&mut self.stream_mut()).poll_write_vectored(cx, bufs)
718 } else {
719 Pin::new(&mut self.stream_mut().get_mut()).poll_write_vectored(cx, bufs)
720 };
721
722 self.write_pending_time.poll_write_time(&result, total_size);
723 result
724 }
725
726 fn is_write_vectored(&self) -> bool {
727 if self.buffer_write {
728 self.stream().is_write_vectored() } else {
730 self.stream().get_ref().is_write_vectored()
731 }
732 }
733}
734
735pub mod async_write_vec {
736 use bytes::Buf;
737 use futures::ready;
738 use std::future::Future;
739 use std::io::IoSlice;
740 use std::pin::Pin;
741 use std::task::{Context, Poll};
742 use tokio::io;
743 use tokio::io::AsyncWrite;
744
745 #[must_use = "futures do nothing unless you `.await` or poll them"]
752 pub struct WriteVec<'a, W, B> {
753 writer: &'a mut W,
754 buf: &'a mut B,
755 }
756
757 #[must_use = "futures do nothing unless you `.await` or poll them"]
758 pub struct WriteVecAll<'a, W, B> {
759 writer: &'a mut W,
760 buf: &'a mut B,
761 }
762
763 pub trait AsyncWriteVec {
764 fn poll_write_vec<B: Buf>(
765 self: Pin<&mut Self>,
766 _cx: &mut Context<'_>,
767 _buf: &mut B,
768 ) -> Poll<io::Result<usize>>;
769
770 fn write_vec<'a, B>(&'a mut self, src: &'a mut B) -> WriteVec<'a, Self, B>
771 where
772 Self: Sized,
773 B: Buf,
774 {
775 WriteVec {
776 writer: self,
777 buf: src,
778 }
779 }
780
781 fn write_vec_all<'a, B>(&'a mut self, src: &'a mut B) -> WriteVecAll<'a, Self, B>
782 where
783 Self: Sized,
784 B: Buf,
785 {
786 WriteVecAll {
787 writer: self,
788 buf: src,
789 }
790 }
791 }
792
793 impl<W, B> Future for WriteVec<'_, W, B>
794 where
795 W: AsyncWriteVec + Unpin,
796 B: Buf,
797 {
798 type Output = io::Result<usize>;
799
800 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<usize>> {
801 let me = &mut *self;
802 Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf)
803 }
804 }
805
806 impl<W, B> Future for WriteVecAll<'_, W, B>
807 where
808 W: AsyncWriteVec + Unpin,
809 B: Buf,
810 {
811 type Output = io::Result<()>;
812
813 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
814 let me = &mut *self;
815 while me.buf.has_remaining() {
816 let n = ready!(Pin::new(&mut *me.writer).poll_write_vec(ctx, me.buf))?;
817 if n == 0 {
818 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
819 }
820 }
821 Poll::Ready(Ok(()))
822 }
823 }
824
825 impl<T> AsyncWriteVec for T
827 where
828 T: AsyncWrite,
829 {
830 fn poll_write_vec<B: Buf>(
831 self: Pin<&mut Self>,
832 ctx: &mut Context,
833 buf: &mut B,
834 ) -> Poll<io::Result<usize>> {
835 const MAX_BUFS: usize = 64;
836
837 if !buf.has_remaining() {
838 return Poll::Ready(Ok(0));
839 }
840
841 let n = if self.is_write_vectored() {
842 let mut slices = [IoSlice::new(&[]); MAX_BUFS];
843 let cnt = buf.chunks_vectored(&mut slices);
844 ready!(self.poll_write_vectored(ctx, &slices[..cnt]))?
845 } else {
846 ready!(self.poll_write(ctx, buf.chunk()))?
847 };
848
849 buf.advance(n);
850
851 Poll::Ready(Ok(n))
852 }
853 }
854}
855
856pub use async_write_vec::AsyncWriteVec;
857
/// Accumulates wall-clock time spent in a "pending" state across many polls.
#[derive(Debug)]
struct AccumulatedDuration {
    // Sum of all completed pending intervals.
    total: Duration,
    // When the current pending interval began, if one is open.
    last_start: Option<Instant>,
}

impl AccumulatedDuration {
    /// No time accumulated and no interval open.
    fn new() -> Self {
        Self {
            total: Duration::ZERO,
            last_start: None,
        }
    }

    /// Open an interval unless one is already open; an open interval is never reset.
    fn start(&mut self) {
        self.last_start.get_or_insert_with(Instant::now);
    }

    /// Close the open interval (if any) and add its length to `total`.
    fn stop(&mut self) {
        self.total += self
            .last_start
            .take()
            .map_or(Duration::ZERO, |began| began.elapsed());
    }

    /// Write accounting: the interval closes on error, or when all `buf_size`
    /// bytes were written. A pending or partial write keeps it open.
    fn poll_write_time(&mut self, result: &Poll<io::Result<usize>>, buf_size: usize) {
        let finished = match result {
            Poll::Ready(Ok(written)) => *written == buf_size,
            Poll::Ready(Err(_)) => true,
            Poll::Pending => false,
        };
        if finished {
            self.stop();
        } else {
            self.start();
        }
    }

    /// Generic accounting: any ready result closes the interval; pending opens it.
    fn poll_time(&mut self, result: &Poll<io::Result<()>>) {
        if result.is_ready() {
            self.stop();
        } else {
            self.start();
        }
    }
}
910
// Socket-level tests for `Stream`. They bind real loopback TCP listeners.
// NOTE(review): the whole module is Linux-only, so the rewind/peek tests, which
// do not need Linux, are skipped on other platforms.
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;
    use tokio::sync::Notify;

    // With RX timestamping enabled, a TCP read should populate `rx_ts`.
    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        // Server: waits for the client's signal before writing the message.
        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        stream.set_rx_timestamp().unwrap();
        // NOTE(review): a blocking std sleep inside an async test stalls the
        // runtime thread; the purpose of the 100us delay is not evident here.
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_some());
    }

    // Without `set_rx_timestamp`, reads take the plain path and `rx_ts` stays None.
    #[cfg(target_os = "linux")]
    #[tokio::test]
    async fn test_rx_timestamp_standard_path() {
        let message = "hello world".as_bytes();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        std::thread::sleep(Duration::from_micros(100));
        notify.notify_one();

        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert!(stream.rx_ts.is_none());
    }

    // Rewound bytes are served before socket data, split across reads as needed.
    #[tokio::test]
    async fn test_stream_rewind() {
        let message = b"hello world";
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
        });

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        let rewind_test = b"this is Sparta!";
        stream.rewind(rewind_test);

        // First read: the first `message.len()` bytes of the rewound chunk.
        let mut buffer = vec![0u8; message.len()];
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, rewind_test[..message.len()]);

        // Second read: the rest of the rewound chunk (shorter than the buffer).
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, rewind_test.len() - message.len());
        assert_eq!(buffer[..n], rewind_test[message.len()..]);

        // Only now let the server write; this read comes from the socket.
        notify.notify_one();
        let n = stream.read(buffer.as_mut_slice()).await.unwrap();
        assert_eq!(n, message.len());
        assert_eq!(buffer, message);
    }

    // A peek must not consume data: a later read_to_end still sees everything.
    #[tokio::test]
    async fn test_stream_peek() {
        let message = b"hello world";
        dbg!("try peek");
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        // Notifying before anyone waits stores a permit, so the server's
        // `notified()` completes immediately once it gets there.
        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
        let mut buffer = vec![0u8; 5];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..5]);
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message);
    }

    // Nested peeks stack rewound chunks; reads must still yield bytes in order.
    #[tokio::test]
    async fn test_stream_two_subsequent_peek_calls_before_read() {
        let message = b"abcdefghijklmn";

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            notify2.notified().await;
            stream.write_all(message).await.unwrap();
            drop(stream);
        });

        notify.notify_one();

        let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();

        // Peek 4 bytes (read from the socket, then rewound).
        let mut buffer = vec![0u8; 4];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..4]);

        // Peek 2 bytes, served from the rewound chunk.
        let mut buffer = vec![0u8; 2];
        assert!(stream.try_peek(&mut buffer).await.unwrap());
        assert_eq!(buffer, message[0..2]);

        // Consume 1 byte.
        let mut buffer = vec![0u8; 1];
        stream.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[0..1]);

        // A single read only drains the top rewound chunk, which has 1 byte left
        // of "ab".
        let mut buffer = vec![0u8; 100];
        let n = stream.read(&mut buffer).await.unwrap();
        assert_eq!(n, 1);
        assert_eq!(buffer[..n], message[1..2]);

        // The remainder: "cd" from the older rewound chunk, then the socket.
        let mut buffer = vec![];
        stream.read_to_end(&mut buffer).await.unwrap();
        assert_eq!(buffer, message[2..]);
    }
}