1use iroh_io::AsyncStreamReader;
4use std::{io, pin::Pin, task::Poll};
5use tokio::io::AsyncWrite;
6
/// A reader wrapper that keeps a running count of the bytes read through it.
///
/// The count only changes when data is read via the wrapper's
/// [`AsyncStreamReader`] implementation.
#[derive(Debug)]
pub struct TrackingReader<R> {
    /// The wrapped reader that actually produces the data.
    inner: R,
    /// Total number of bytes handed out so far.
    read: u64,
}

impl<R> TrackingReader<R> {
    /// Wraps `inner`, starting the byte counter at zero.
    pub fn new(inner: R) -> Self {
        TrackingReader { read: 0, inner }
    }

    /// Returns how many bytes have been read through this wrapper so far.
    // Part of the public surface even where no caller uses it yet; silence the lint.
    #[allow(dead_code)]
    pub fn bytes_read(&self) -> u64 {
        let Self { read, .. } = self;
        *read
    }

    /// Consumes the wrapper, returning the inner reader together with the
    /// final byte count.
    pub fn into_parts(self) -> (R, u64) {
        let Self { inner, read } = self;
        (inner, read)
    }
}
31
32impl<R> AsyncStreamReader for TrackingReader<R>
33where
34 R: AsyncStreamReader,
35{
36 async fn read_bytes(&mut self, len: usize) -> io::Result<bytes::Bytes> {
37 let bytes = self.inner.read_bytes(len).await?;
38 self.read = self.read.saturating_add(bytes.len() as u64);
39 Ok(bytes)
40 }
41
42 async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
43 let res = self.inner.read::<L>().await?;
44 self.read = self.read.saturating_add(L as u64);
45 Ok(res)
46 }
47}
48
/// A writer wrapper that keeps a running count of the bytes the inner writer
/// reports as written.
///
/// The count only changes when data is written via the wrapper's
/// `AsyncWrite` implementation.
#[derive(Debug)]
pub struct TrackingWriter<W> {
    /// The wrapped writer that actually consumes the data.
    inner: W,
    /// Total number of bytes the inner writer has accepted so far.
    written: u64,
}

impl<W> TrackingWriter<W> {
    /// Wraps `inner`, starting the byte counter at zero.
    pub fn new(inner: W) -> Self {
        TrackingWriter { written: 0, inner }
    }

    /// Returns how many bytes have been written through this wrapper so far.
    // Part of the public surface even where no caller uses it yet; silence the lint.
    #[allow(dead_code)]
    pub fn bytes_written(&self) -> u64 {
        let Self { written, .. } = self;
        *written
    }

    /// Consumes the wrapper, returning the inner writer together with the
    /// final byte count.
    pub fn into_parts(self) -> (W, u64) {
        let Self { inner, written } = self;
        (inner, written)
    }
}
73
74impl<W: AsyncWrite + Unpin> AsyncWrite for TrackingWriter<W> {
75 fn poll_write(
76 mut self: Pin<&mut Self>,
77 cx: &mut std::task::Context<'_>,
78 buf: &[u8],
79 ) -> Poll<io::Result<usize>> {
80 let this = &mut *self;
81 let res = Pin::new(&mut this.inner).poll_write(cx, buf);
82 if let Poll::Ready(Ok(size)) = res {
83 this.written = this.written.saturating_add(size as u64);
84 }
85 res
86 }
87
88 fn poll_flush(
89 mut self: Pin<&mut Self>,
90 cx: &mut std::task::Context<'_>,
91 ) -> Poll<io::Result<()>> {
92 Pin::new(&mut self.inner).poll_flush(cx)
93 }
94
95 fn poll_shutdown(
96 mut self: Pin<&mut Self>,
97 cx: &mut std::task::Context<'_>,
98 ) -> Poll<io::Result<()>> {
99 Pin::new(&mut self.inner).poll_shutdown(cx)
100 }
101}