1use std::{
8 fmt,
9 io::{BufRead, ErrorKind, Read, Write},
10 pin::Pin,
11 sync::Arc,
12 task::{self, Poll},
13 time::Duration,
14};
15
16use crate::{McWaker, Progress};
17
18#[doc(no_inline)]
19pub use futures_lite::io::{
20 AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BoxedReader, BoxedWriter,
21 BufReader, BufWriter, Cursor, ReadHalf, WriteHalf, copy, empty, repeat, sink, split,
22};
23use parking_lot::Mutex;
24use std::io::{Error, Result};
25use zng_time::{DInstant, INSTANT};
26use zng_txt::formatx;
27use zng_unit::{ByteLength, ByteUnits};
28use zng_var::{Var, impl_from_and_into_var, var};
29
/// Measurement state behind [`Measure`].
///
/// Holds the observable metrics variable plus the timestamps used to compute
/// throughput and elapsed time on each read/write.
struct MeasureInner {
    /// Observable metrics, updated on every non-empty read or write.
    metrics: Var<Metrics>,
    /// Instant the measurement started; `total_time` is measured from here.
    start_time: DInstant,
    /// Instant of the last non-empty write, used to compute `write_speed`.
    last_write: DInstant,
    /// Instant of the last non-empty read, used to compute `read_speed`.
    last_read: DInstant,
}
36impl MeasureInner {
37 fn new(read_progress: (ByteLength, ByteLength), write_progress: (ByteLength, ByteLength)) -> Self {
38 let now = INSTANT.now();
39 Self {
40 metrics: var(Metrics {
41 read_progress,
42 read_speed: 0.bytes(),
43 write_progress,
44 write_speed: 0.bytes(),
45 total_time: Duration::ZERO,
46 }),
47 start_time: now,
48 last_write: now,
49 last_read: now,
50 }
51 }
52
53 fn on_read(&mut self, bytes: usize) {
54 if bytes == 0 {
55 return;
56 }
57
58 let bytes = bytes.bytes();
59
60 let now = INSTANT.now();
61 let elapsed = now - self.last_read;
62
63 self.last_read = now;
64 let read_speed = bytes_per_sec(bytes, elapsed);
65
66 let total_time = now - self.start_time;
67
68 self.metrics.modify(move |m| {
69 m.read_progress.0 += bytes;
70 m.read_speed = read_speed;
71 m.total_time = total_time;
72 });
73 }
74
75 fn on_write(&mut self, bytes: usize) {
76 if bytes == 0 {
77 return;
78 }
79
80 let bytes = bytes.bytes();
81
82 let now = INSTANT.now();
83 let elapsed = now - self.last_write;
84
85 self.last_write = now;
86 let write_speed = bytes_per_sec(bytes, elapsed);
87
88 let total_time = now - self.start_time;
89
90 self.metrics.modify(move |m| {
91 m.write_progress.0 += bytes;
92 m.write_speed = write_speed;
93 m.total_time = total_time;
94 });
95 }
96}
97
/// Wraps a reader/writer `task` and measures transfer progress, speed and elapsed time.
///
/// The sync and async read/write trait impls forward to `task` and record the byte
/// counts it reports.
pub struct Measure<T> {
    /// The wrapped reader/writer.
    task: T,
    /// Measurement state.
    inner: MeasureInner,
}
106impl<T> Measure<T> {
107 pub fn new(task: T, total_read: ByteLength, total_write: ByteLength) -> Self {
109 Self::new_ongoing(task, (0.bytes(), total_read), (0.bytes(), total_write))
110 }
111
112 pub fn new_ongoing(task: T, read_progress: (ByteLength, ByteLength), write_progress: (ByteLength, ByteLength)) -> Self {
114 Measure {
115 task,
116 inner: MeasureInner::new(read_progress, write_progress),
117 }
118 }
119
120 pub fn metrics(&mut self) -> Var<Metrics> {
124 self.inner.metrics.read_only()
125 }
126
127 pub fn finish(self) -> (T, Metrics) {
129 let mut metrics = self.inner.metrics.get();
130 metrics.total_time = self.inner.start_time.elapsed();
131 (self.task, metrics)
132 }
133}
134
135fn bytes_per_sec(bytes: ByteLength, elapsed: Duration) -> ByteLength {
136 let bytes_per_sec = bytes.0 as u128 / elapsed.as_nanos() / Duration::from_secs(1).as_nanos();
137 ByteLength(bytes_per_sec as usize)
138}
139
140impl<T: AsyncRead> AsyncRead for Measure<T> {
141 fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
142 let self_ = unsafe { self.get_unchecked_mut() };
144
145 match unsafe { Pin::new_unchecked(&mut self_.task) }.poll_read(cx, buf) {
147 Poll::Ready(Ok(bytes)) => {
148 self_.inner.on_read(bytes);
149 Poll::Ready(Ok(bytes))
150 }
151 p => p,
152 }
153 }
154}
155impl<T: AsyncWrite> AsyncWrite for Measure<T> {
156 fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
157 let self_ = unsafe { self.get_unchecked_mut() };
159
160 match unsafe { Pin::new_unchecked(&mut self_.task) }.poll_write(cx, buf) {
162 Poll::Ready(Ok(bytes)) => {
163 self_.inner.on_write(bytes);
164 Poll::Ready(Ok(bytes))
165 }
166 p => p,
167 }
168 }
169
170 fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<()>> {
171 let self_ = unsafe { self.get_unchecked_mut() };
173
174 unsafe { Pin::new_unchecked(&mut self_.task) }.poll_flush(cx)
176 }
177
178 fn poll_close(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<()>> {
179 let self_ = unsafe { self.get_unchecked_mut() };
181
182 unsafe { Pin::new_unchecked(&mut self_.task) }.poll_flush(cx)
184 }
185}
186impl<T: AsyncBufRead> AsyncBufRead for Measure<T> {
187 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<&[u8]>> {
188 let self_ = unsafe { self.get_unchecked_mut() };
190
191 unsafe { Pin::new_unchecked(&mut self_.task) }.poll_fill_buf(cx)
193 }
194
195 fn consume(self: Pin<&mut Self>, amt: usize) {
196 let self_ = unsafe { self.get_unchecked_mut() };
198 unsafe { Pin::new_unchecked(&mut self_.task) }.consume(amt);
200 self_.inner.on_read(amt);
201 }
202}
203impl<T: Read> Read for Measure<T> {
204 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
205 match self.task.read(buf) {
206 Ok(bytes) => {
207 self.inner.on_read(bytes);
208 Ok(bytes)
209 }
210 r => r,
211 }
212 }
213}
214impl<T: Write> Write for Measure<T> {
215 fn write(&mut self, buf: &[u8]) -> Result<usize> {
216 match self.task.write(buf) {
217 Ok(bytes) => {
218 self.inner.on_write(bytes);
219 Ok(bytes)
220 }
221 r => r,
222 }
223 }
224
225 fn flush(&mut self) -> Result<()> {
226 self.task.flush()
227 }
228}
229impl<T: BufRead> BufRead for Measure<T> {
230 fn fill_buf(&mut self) -> Result<&[u8]> {
231 self.task.fill_buf()
232 }
233
234 fn consume(&mut self, amount: usize) {
235 self.task.consume(amount);
236 self.inner.on_read(amount);
237 }
238}
239
/// Snapshot of transfer metrics collected by [`Measure`].
///
/// Progress tuples are `(done, total)`. A direction whose total is zero is skipped by the
/// `Display` impl and by the conversion to `Progress`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct Metrics {
    /// Bytes read so far, and total bytes expected to be read.
    pub read_progress: (ByteLength, ByteLength),

    /// Read speed per second, estimated from the most recent non-empty read
    /// (its size over the time since the previous read).
    pub read_speed: ByteLength,

    /// Bytes written so far, and total bytes expected to be written.
    pub write_progress: (ByteLength, ByteLength),

    /// Write speed per second, estimated from the most recent non-empty write
    /// (its size over the time since the previous write).
    pub write_speed: ByteLength,

    /// Time elapsed since the measurement started, as of the last update.
    pub total_time: Duration,
}
265impl Metrics {
266 pub fn zero() -> Self {
268 Self {
269 read_progress: (0.bytes(), 0.bytes()),
270 read_speed: 0.bytes(),
271 write_progress: (0.bytes(), 0.bytes()),
272 write_speed: 0.bytes(),
273 total_time: Duration::ZERO,
274 }
275 }
276}
277impl fmt::Display for Metrics {
278 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279 let mut nl = false;
280 if self.read_progress.1 > 0.bytes() {
281 nl = true;
282 if self.read_progress.0 != self.read_progress.1 {
283 write!(f, "↓ {}-{}, {}/s", self.read_progress.0, self.read_progress.1, self.read_speed)?;
284 nl = true;
285 } else {
286 write!(f, "↓ {} . {:?}", self.read_progress.0, self.total_time)?;
287 }
288 }
289 if self.write_progress.1 > 0.bytes() {
290 if nl {
291 writeln!(f)?;
292 }
293 if self.write_progress.0 != self.write_progress.1 {
294 write!(f, "↑ {} - {}, {}/s", self.write_progress.0, self.write_progress.1, self.write_speed)?;
295 } else {
296 write!(f, "↑ {} . {:?}", self.write_progress.0, self.total_time)?;
297 }
298 }
299
300 Ok(())
301 }
302}
// `Metrics` -> `Progress` conversion. NOTE(review): judging by the macro name, this also
// generates the matching into-var conversion; confirm in `zng_var`.
impl_from_and_into_var! {
    fn from(metrics: Metrics) -> Progress {
        // Start indeterminate; a direction only contributes when its total is non-zero.
        let mut status = Progress::indeterminate();
        if metrics.read_progress.1 > 0.bytes() {
            status = Progress::from_n_of(metrics.read_progress.0.0, metrics.read_progress.1.0);
        }
        if metrics.write_progress.1 > 0.bytes() {
            let w_status = Progress::from_n_of(metrics.write_progress.0.0, metrics.write_progress.1.0);
            if status.is_indeterminate() {
                status = w_status;
            } else {
                // Both directions are known: combine the read status with the write fraction.
                // NOTE(review): the exact combination rule is defined by `Progress::and_fct`.
                status = status.and_fct(w_status.fct());
            }
        }
        // The message is the `Display` text; the full metrics are attached as metadata
        // under `METRICS_ID`.
        status.with_msg(formatx!("{metrics}")).with_meta_mut(|mut m| {
            m.set(*METRICS_ID, metrics);
        })
    }
}
322
// State-map key the `Metrics` -> `Progress` conversion uses to attach the full `Metrics`
// to the progress metadata.
zng_state_map::static_id! {
    pub static ref METRICS_ID: zng_state_map::StateId<Metrics>;
}
327
/// Extension methods for errors returned by [`McBufReader`].
pub trait McBufErrorExt {
    /// Returns `true` if this is the error a lazy [`McBufReader`] returns when it needs
    /// more data but no non-lazy clone is left to read the source.
    fn is_only_lazy_left(&self) -> bool;
}
impl McBufErrorExt for std::io::Error {
    fn is_only_lazy_left(&self) -> bool {
        if self.kind() != ErrorKind::Other {
            return false;
        }
        // The error is created with `Error::other(ONLY_NON_LAZY_ERROR_MSG)`; its debug
        // representation includes that message.
        let debug_repr = format!("{self:?}");
        debug_repr.contains(ONLY_NON_LAZY_ERROR_MSG)
    }
}
/// Message of the error returned by lazy readers once no non-lazy reader remains.
const ONLY_NON_LAZY_ERROR_MSG: &str = "no non-lazy readers left to read";
342
/// A reader handle over a shared, buffered [`AsyncRead`] source.
///
/// Every clone reads the stream independently from its own offset into a shared buffer
/// (a clone starts at the offset of the handle it was cloned from). The source is polled
/// under the shared lock. Lazy handles never poll the source themselves: they wait for a
/// non-lazy handle to read more, and fail with an error recognized by
/// [`McBufErrorExt::is_only_lazy_left`] when no non-lazy handle remains.
pub struct McBufReader<S: AsyncRead> {
    /// State shared by all clones.
    inner: Arc<Mutex<McBufInner<S>>>,
    /// This handle's slot in `McBufInner::clones`.
    index: usize,
    /// If this handle is lazy, see [`McBufReader::set_lazy`].
    lazy: bool,
}
/// State shared by all clones of a [`McBufReader`].
struct McBufInner<S: AsyncRead> {
    /// The source; set to `None` after it returns EOF or an error.
    source: Option<S>,
    /// Combines the wakers of non-lazy readers waiting on the source.
    waker: McWaker,
    /// Wakers of lazy readers waiting for a non-lazy reader to get more data.
    lazy_wakers: Vec<task::Waker>,

    /// Data read from the source that at least one live reader has not consumed yet.
    buf: Vec<u8>,

    /// Offset into `buf` for each reader, indexed by `McBufReader::index`;
    /// `usize::MAX` marks the slot of a dropped reader.
    clones: Vec<usize>,
    /// Number of live non-lazy readers.
    non_lazy_count: usize,

    /// Source state: running, ended, or failed.
    result: ReadState,
}
381impl<S: AsyncRead> McBufReader<S> {
382 pub fn new(source: S) -> Self {
384 let mut clones = Vec::with_capacity(2);
385 clones.push(0);
386 McBufReader {
387 inner: Arc::new(Mutex::new(McBufInner {
388 source: Some(source),
389 waker: McWaker::empty(),
390 lazy_wakers: vec![],
391
392 buf: Vec::with_capacity(10.kilobytes().0),
393
394 clones,
395 non_lazy_count: 1,
396
397 result: ReadState::Running,
398 })),
399 index: 0,
400 lazy: false,
401 }
402 }
403
404 pub fn is_lazy(&self) -> bool {
408 self.lazy
409 }
410
411 pub fn set_lazy(&mut self, lazy: bool) {
415 if self.lazy != lazy {
416 if lazy {
417 self.inner.lock().non_lazy_count -= 1;
418 } else {
419 self.inner.lock().non_lazy_count += 1;
420 }
421 self.lazy = lazy;
422 }
423 }
424}
425impl<S: AsyncRead> Clone for McBufReader<S> {
426 fn clone(&self) -> Self {
427 let mut inner = self.inner.lock();
428
429 let offset = inner.clones[self.index];
430 let index = inner.clones.len();
431 inner.clones.push(offset);
432
433 if !self.lazy {
434 inner.non_lazy_count += 1;
435 }
436
437 Self {
438 inner: self.inner.clone(),
439 index,
440 lazy: self.lazy,
441 }
442 }
443}
444impl<S: AsyncRead> Drop for McBufReader<S> {
445 fn drop(&mut self) {
446 let mut inner = self.inner.lock();
447 inner.clones[self.index] = usize::MAX;
448 if !self.lazy {
449 inner.non_lazy_count -= 1;
450 if inner.non_lazy_count == 0 {
451 for waker in inner.lazy_wakers.drain(..) {
453 waker.wake();
454 }
455 }
456 }
457 }
458}
459impl<S: AsyncRead> AsyncRead for McBufReader<S> {
460 fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
461 let self_ = self.as_ref();
462 let mut inner = self_.inner.lock();
463 let inner = &mut *inner;
464
465 let mut i = inner.clones[self_.index];
467 let mut ready;
468
469 match &inner.result {
470 ReadState::Running => {
471 ready = &inner.buf[i..];
474
475 if ready.is_empty() {
476 if self.lazy {
477 if inner.non_lazy_count == 0 {
478 return Poll::Ready(Err(Error::other(ONLY_NON_LAZY_ERROR_MSG)));
480 } else {
481 inner.lazy_wakers.push(cx.waker().clone());
483
484 return Poll::Pending;
486 }
487 }
488
489 ready = &[];
492
493 let waker = match inner.waker.push(cx.waker().clone()) {
494 Some(w) => w,
495 None => {
496 return Poll::Pending;
498 }
499 };
500
501 let min_i = inner.clones.iter().copied().min().unwrap();
502 if min_i > 0 {
503 inner.buf.copy_within(min_i.., 0);
505 inner.buf.truncate(inner.buf.len() - min_i);
506
507 i -= min_i;
508 for i in &mut inner.clones {
509 *i -= min_i;
510 }
511 }
512
513 let new_start = inner.buf.len();
514
515 inner.buf.resize(inner.buf.len() + buf.len().max(10.kilobytes().0), 0);
516
517 let mut inner_cx = task::Context::from_waker(&waker);
518
519 let source = unsafe { Pin::new_unchecked(inner.source.as_mut().unwrap()) };
521 let result = source.poll_read(&mut inner_cx, &mut inner.buf[new_start..]);
522
523 match result {
524 Poll::Ready(result) => {
525 for waker in inner.lazy_wakers.drain(..) {
527 waker.wake();
528 }
529
530 match result {
531 Ok(0) => {
532 inner.waker.cancel();
533
534 inner.buf.truncate(new_start);
536 inner.result = ReadState::Eof;
537 inner.source = None;
538
539 }
541 Ok(read) => {
542 inner.waker.cancel();
543
544 inner.buf.truncate(new_start + read);
546 ready = &inner.buf[i..];
547
548 }
550 Err(e) => {
551 inner.waker.cancel();
552
553 inner.result = ReadState::Err(CloneableError::new(&e));
555 inner.buf = vec![];
556 inner.source = None;
557
558 return Poll::Ready(Err(e));
559 }
560 }
561 }
562
563 Poll::Pending => {
564 inner.buf.truncate(new_start);
565 return Poll::Pending;
566 }
567 }
568 }
569 }
570 ReadState::Eof => {
571 ready = &inner.buf[i..];
572
573 }
575 ReadState::Err(e) => return Poll::Ready(e.err()),
576 }
577
578 let max_ready = buf.len().min(ready.len());
581 buf[..max_ready].copy_from_slice(&ready[..max_ready]);
582
583 i += max_ready;
584 inner.clones[self_.index] = i;
585
586 Poll::Ready(Ok(max_ready))
587 }
588}
589
/// A cloneable snapshot of an [`std::io::Error`].
///
/// Keeps the raw OS error code if there is one; otherwise keeps the error kind and its
/// display message. Converting back yields an equivalent `Error`.
#[derive(Clone, Debug)]
pub struct CloneableError {
    info: ErrorInfo,
}
/// Captured error data.
#[derive(Clone, Debug)]
enum ErrorInfo {
    /// Raw OS error code, restored with `Error::from_raw_os_error`.
    OsError(i32),
    /// Error kind and display message, restored with `Error::new`.
    Other(ErrorKind, String),
}
impl CloneableError {
    /// Captures `e` so it can be cloned and converted back later.
    pub fn new(e: &Error) -> Self {
        let info = if let Some(code) = e.raw_os_error() {
            ErrorInfo::OsError(code)
        } else {
            ErrorInfo::Other(e.kind(), format!("{e}"))
        };

        Self { info }
    }

    /// Returns `Err` with a fresh `Error` rebuilt from this snapshot.
    pub fn err<T>(&self) -> Result<T> {
        Err(self.clone().into())
    }
}
impl From<CloneableError> for Error {
    fn from(e: CloneableError) -> Self {
        match e.info {
            ErrorInfo::OsError(code) => Error::from_raw_os_error(code),
            ErrorInfo::Other(kind, msg) => Error::new(kind, msg),
        }
    }
}
634
/// Wraps a `source` reader and limits how many bytes can be read from it.
///
/// `on_limit` creates the error that is returned when the limit blocks further reading.
pub struct ReadLimited<S> {
    /// The wrapped reader.
    source: S,
    /// Remaining number of bytes that may still be read.
    limit: usize,
    /// Constructs the error reported when the limit is hit.
    on_limit: fn() -> std::io::Error,
}
643impl<S> ReadLimited<S> {
644 pub fn new(source: S, limit: ByteLength, on_limit: fn() -> std::io::Error) -> Self {
648 Self {
649 source,
650 limit: limit.0,
651 on_limit,
652 }
653 }
654
655 pub fn new_default_err(source: S, limit: ByteLength) -> Self {
657 Self::new(source, limit, || {
658 std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "source exceeded read limit")
659 })
660 }
661}
662impl<S> AsyncRead for ReadLimited<S>
663where
664 S: AsyncRead,
665{
666 fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, mut buf: &mut [u8]) -> Poll<Result<usize>> {
667 let self_ = unsafe { self.get_unchecked_mut() };
669
670 if self_.limit == 0 {
671 let err = (self_.on_limit)();
672 return Poll::Ready(Err(err));
673 }
674
675 if buf.len() > self_.limit {
676 buf = &mut buf[..self_.limit];
677 }
678
679 match unsafe { Pin::new_unchecked(&mut self_.source) }.poll_read(cx, buf) {
681 Poll::Ready(Ok(n)) => {
682 self_.limit = self_.limit.saturating_sub(n);
683 Poll::Ready(Ok(n))
684 }
685 r => r,
686 }
687 }
688}
689impl<S> AsyncBufRead for ReadLimited<S>
690where
691 S: AsyncBufRead,
692{
693 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<&[u8]>> {
694 let self_ = unsafe { self.get_unchecked_mut() };
696
697 if self_.limit == 0 {
698 let err = (self_.on_limit)();
699 return Poll::Ready(Err(err));
700 }
701
702 unsafe { Pin::new_unchecked(&mut self_.source) }.poll_fill_buf(cx)
704 }
705
706 fn consume(self: Pin<&mut Self>, amt: usize) {
707 let self_ = unsafe { self.get_unchecked_mut() };
709 unsafe { Pin::new_unchecked(&mut self_.source) }.consume(amt);
711 self_.limit = self_.limit.saturating_sub(amt);
712 }
713}
714impl<S> Read for ReadLimited<S>
715where
716 S: Read,
717{
718 fn read(&mut self, mut buf: &mut [u8]) -> Result<usize> {
719 if self.limit == 0 {
720 let err = (self.on_limit)();
721 return Err(err);
722 }
723
724 if buf.len() > self.limit {
725 buf = &mut buf[..self.limit];
726 }
727
728 match self.source.read(buf) {
729 Ok(n) => {
730 self.limit = self.limit.saturating_sub(n);
731 Ok(n)
732 }
733 r => r,
734 }
735 }
736}
737impl<S> BufRead for ReadLimited<S>
738where
739 S: BufRead,
740{
741 fn fill_buf(&mut self) -> Result<&[u8]> {
742 if self.limit == 0 {
743 let err = (self.on_limit)();
744 return Err(err);
745 }
746
747 self.source.fill_buf()
748 }
749
750 fn consume(&mut self, amount: usize) {
751 self.source.consume(amount);
752 self.limit = self.limit.saturating_sub(amount);
753 }
754}
755
/// State of the shared source of a [`McBufReader`].
enum ReadState {
    /// The source is still being read.
    Running,
    /// The source returned `Ok(0)`; readers drain what is left in the buffer.
    Eof,
    /// The source returned an error; every reader gets a copy of it.
    Err(CloneableError),
}
761
// Tests for `McBufReader`, driven by the synthetic `Data` source defined below.
#[cfg(test)]
mod tests {
    use super::*;
    // Alias the crate as `task` (shadowing the glob-imported `std::task`) so the crate's
    // executor API reads as `task::run`, `task::all!`, `task::deadline`, etc.
    use crate as task;
    use zng_unit::TimeUnits;

    // Three clones read the whole source concurrently; each must receive every byte.
    #[test]
    pub fn mc_buf_reader_parallel() {
        let data = Data::new(60.kilobytes().0);

        // Reference output: read a copy of the source directly.
        let mut expected = vec![0; data.len];
        let _ = data.clone().blocking_read(&mut expected[..]);

        let mut a = McBufReader::new(data);
        let mut b = a.clone();
        let mut c = a.clone();

        let (a, b, c) = async_test(async move {
            let a = task::run(async move {
                let mut buf = vec![];
                a.read_to_end(&mut buf).await.unwrap();
                buf
            });
            let b = task::run(async move {
                let mut buf: Vec<u8> = vec![];
                b.read_to_end(&mut buf).await.unwrap();
                buf
            });
            let c = task::run(async move {
                let mut buf: Vec<u8> = vec![];
                c.read_to_end(&mut buf).await.unwrap();
                buf
            });

            task::all!(a, b, c).await
        });

        crate::assert_vec_eq!(expected, a);
        crate::assert_vec_eq!(expected, b);
        crate::assert_vec_eq!(expected, c);
    }

    // A single, never-cloned reader behaves like the plain source.
    #[test]
    pub fn mc_buf_reader_single() {
        let data = Data::new(60.kilobytes().0);

        let mut expected = vec![0; data.len];
        let _ = data.clone().blocking_read(&mut expected[..]);

        let mut a = McBufReader::new(data);

        let a = async_test(async move {
            let a = task::run(async move {
                let mut buf = vec![];
                a.read_to_end(&mut buf).await.unwrap();
                buf
            });

            a.await
        });

        crate::assert_vec_eq!(expected, a);
    }

    // Clones read one after another: the first one pulls everything from the source,
    // and the rest must be served entirely from the shared buffer.
    #[test]
    pub fn mc_buf_reader_sequential() {
        let data = Data::new(60.kilobytes().0);

        let mut expected = vec![0; data.len];
        let _ = data.clone().blocking_read(&mut expected[..]);

        let mut clones = vec![McBufReader::new(data)];
        for _ in 0..5 {
            clones.push(clones[0].clone());
        }

        let r = async_test(async move {
            let mut r = vec![];

            for mut clone in clones {
                let mut buf = vec![];
                clone.read_to_end(&mut buf).await.unwrap();
                r.push(buf);
            }

            r
        });

        for r in r {
            crate::assert_vec_eq!(expected, r);
        }
    }

    // A clone taken after its origin reached EOF starts at the origin's offset, so it reads nothing.
    #[test]
    pub fn mc_buf_reader_completed() {
        let data = Data::new(60.kilobytes().0);
        let mut buf = Vec::with_capacity(data.len);
        let mut a = McBufReader::new(data);

        let r = async_test(async move {
            a.read_to_end(&mut buf).await.unwrap();

            let mut b = a.clone();
            buf.clear();

            b.read_to_end(&mut buf).await.unwrap();
            buf.len()
        });

        assert_eq!(0, r);
    }

    // A source error must reach every concurrent clone with its original kind.
    #[test]
    pub fn mc_buf_reader_error() {
        let mut data = Data::new(20.kilobytes().0);
        data.set_error();

        let mut expected = vec![0; data.len];
        let _ = data.clone().blocking_read(&mut expected[..]);

        let mut a = McBufReader::new(data);
        let mut b = a.clone();

        let (a, b) = async_test(async move {
            let a = task::run(async move {
                let mut buf = vec![];
                a.read_to_end(&mut buf).await.unwrap_err()
            });
            let b = task::run(async move {
                let mut buf: Vec<u8> = vec![];
                b.read_to_end(&mut buf).await.unwrap_err()
            });

            task::all!(a, b).await
        });

        assert_eq!(ErrorKind::InvalidData, a.kind());
        assert_eq!(ErrorKind::InvalidData, b.kind());
    }

    // A clone taken after the source failed gets the recorded error replayed.
    #[test]
    pub fn mc_buf_reader_error_completed() {
        let mut data = Data::new(20.kilobytes().0);
        data.set_error();

        let mut buf = Vec::with_capacity(data.len);
        let mut a = McBufReader::new(data);

        let (a, b) = async_test(async move {
            let a_err = a.read_to_end(&mut buf).await.unwrap_err();

            let mut b = a.clone();
            buf.clear();

            let b_err = b.read_to_end(&mut buf).await.unwrap_err();

            (a_err, b_err)
        });

        assert_eq!(ErrorKind::InvalidData, a.kind());
        assert_eq!(ErrorKind::InvalidData, b.kind());
    }

    // Same as `mc_buf_reader_parallel`, but the source returns `Pending` on every other poll.
    #[test]
    pub fn mc_buf_reader_parallel_with_delay1() {
        let mut data = Data::new(60.kilobytes().0);
        data.enable_pending();

        let mut expected = vec![0; data.len];
        let _ = data.clone().blocking_read(&mut expected[..]);

        let mut a = McBufReader::new(data);
        let mut b = a.clone();
        let mut c = a.clone();

        let (a, b, c) = async_test(async move {
            let a = task::run(async move {
                let mut buf = vec![];
                a.read_to_end(&mut buf).await.unwrap();
                buf
            });
            let b = task::run(async move {
                let mut buf: Vec<u8> = vec![];
                b.read_to_end(&mut buf).await.unwrap();
                buf
            });
            let c = task::run(async move {
                let mut buf: Vec<u8> = vec![];
                c.read_to_end(&mut buf).await.unwrap();
                buf
            });

            task::all!(a, b, c).await
        });

        crate::assert_vec_eq!(expected, a);
        crate::assert_vec_eq!(expected, b);
        crate::assert_vec_eq!(expected, c);
    }

    // Pending source plus a reader (`b`) that starts late, so the others run ahead of it.
    #[test]
    pub fn mc_buf_reader_parallel_with_delay2() {
        let mut data = Data::new(60.kilobytes().0);
        data.enable_pending();

        let mut expected = vec![0; data.len];
        let _ = data.clone().blocking_read(&mut expected[..]);

        let mut a = McBufReader::new(data);
        let mut b = a.clone();
        let mut c = a.clone();

        let (a, b, c) = async_test(async move {
            let a = task::run(async move {
                let mut buf = vec![];
                a.read_to_end(&mut buf).await.unwrap();
                buf
            });
            let b = task::run(async move {
                let mut buf: Vec<u8> = vec![];
                task::deadline(5.ms()).await;
                b.read_to_end(&mut buf).await.unwrap();
                buf
            });
            let c = task::run(async move {
                let mut buf: Vec<u8> = vec![];
                c.read_to_end(&mut buf).await.unwrap();
                buf
            });

            task::all!(a, b, c).await
        });

        crate::assert_vec_eq!(expected, a);
        crate::assert_vec_eq!(expected, b);
        crate::assert_vec_eq!(expected, c);
    }

    // Synthetic source that yields `len` bytes with values 0, 1, 2, ... (wrapping at 255).
    #[derive(Clone)]
    struct Data {
        // Value of the next byte to produce.
        b: u8,
        // Number of bytes left to produce.
        len: usize,
        // If set, returned by a read that starts with no bytes left.
        error: Option<CloneableError>,
        // When non-zero, every other async poll returns `Pending` and schedules a wake after this delay.
        delay: Duration,
        // Toggled on each async poll while `delay` is non-zero.
        pending: bool,
    }
    impl Data {
        pub fn new(len: usize) -> Self {
            Self {
                b: 0,
                len,
                error: None,
                delay: 0.ms(),
                pending: false,
            }
        }
        // Fills `buf` with up to `len` bytes. If no bytes were left at the start of the call
        // and an error is set, returns the error instead of `Ok(0)`.
        pub fn blocking_read(&mut self, buf: &mut [u8]) -> Result<usize> {
            let len = self.len;
            for b in buf.iter_mut().take(len) {
                *b = self.b;
                self.len -= 1;
                self.b = self.b.wrapping_add(1);
            }

            if len == 0
                && let Some(e) = &self.error
            {
                return e.err();
            }

            Ok(buf.len().min(len))
        }
        // Makes the source end with an `InvalidData` error instead of EOF.
        pub fn set_error(&mut self) {
            self.error = Some(CloneableError::new(&Error::new(ErrorKind::InvalidData, "test error")));
        }

        // Makes every other async poll return `Pending` for 3ms.
        pub fn enable_pending(&mut self) {
            self.delay = 3.ms();
        }
    }
    impl AsyncRead for Data {
        fn poll_read(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
            if self.delay > Duration::ZERO {
                self.pending = !self.pending;
                if self.pending {
                    // Simulate a slow source: wake the caller after `delay`.
                    let waker = cx.waker().clone();
                    let delay = self.delay;
                    task::spawn(async move {
                        task::deadline(delay).await;
                        waker.wake();
                    });
                    return Poll::Pending;
                }
            }

            let r = self.as_mut().blocking_read(buf);
            Poll::Ready(r)
        }
    }

    // Runs `test` to completion on the crate's `block_on`; fails (via `unwrap`) if the
    // 5 second deadline elapses.
    #[track_caller]
    fn async_test<F>(test: F) -> F::Output
    where
        F: Future,
    {
        task::block_on(task::with_deadline(test, 5.secs())).unwrap()
    }

    // Asserts two indexable collections are equal. On failure it reports the length mismatch
    // and/or the first differing index instead of printing both (potentially large) collections.
    #[macro_export]
    macro_rules! assert_vec_eq {
        ($a:expr, $b: expr) => {
            match (&$a, &$b) {
                (ref a, ref b) => {
                    let len_not_eq = a.len() != b.len();
                    let mut data_not_eq = None;
                    for (i, (a, b)) in a.iter().zip(b.iter()).enumerate() {
                        if a != b {
                            data_not_eq = Some(i);
                            break;
                        }
                    }

                    if len_not_eq || data_not_eq.is_some() {
                        use std::fmt::*;

                        let mut error = format!("`{}` != `{}`", stringify!($a), stringify!($b));
                        if len_not_eq {
                            let _ = write!(&mut error, "\n lengths not equal: {} != {}", a.len(), b.len());
                        }
                        if let Some(i) = data_not_eq {
                            let _ = write!(&mut error, "\n data not equal at index {}: {} != {:?}", i, a[i], b[i]);
                        }
                        panic!("{error}")
                    }
                }
            }
        };
    }
}