1use iroh_io::AsyncStreamReader;
4use std::{io, pin::Pin, task::Poll};
5use tokio::io::AsyncWrite;
6
/// A wrapper around a reader `R` that carries a running count of bytes read.
///
/// The counter starts at zero and is only reported by the accessors below;
/// nothing in this inherent impl modifies it.
#[derive(Debug)]
pub struct TrackingReader<R> {
    /// The wrapped reader.
    inner: R,
    /// Total number of bytes recorded as read so far.
    read: u64,
}

impl<R> TrackingReader<R> {
    /// Wraps `inner`, starting the byte counter at zero.
    pub fn new(inner: R) -> Self {
        TrackingReader { inner, read: 0 }
    }

    /// Returns the number of bytes recorded as read so far.
    // NOTE(review): the lint is silenced presumably because this accessor is
    // unused in some build configurations — confirm before removing.
    #[allow(dead_code)]
    pub fn bytes_read(&self) -> u64 {
        self.read
    }

    /// Consumes the wrapper, returning the inner reader together with the
    /// final byte count.
    pub fn into_parts(self) -> (R, u64) {
        let Self { inner, read } = self;
        (inner, read)
    }
}
31
32impl<R> AsyncStreamReader for TrackingReader<R>
33where
34    R: AsyncStreamReader,
35{
36    async fn read_bytes(&mut self, len: usize) -> io::Result<bytes::Bytes> {
37        let bytes = self.inner.read_bytes(len).await?;
38        self.read = self.read.saturating_add(bytes.len() as u64);
39        Ok(bytes)
40    }
41
42    async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
43        let res = self.inner.read::<L>().await?;
44        self.read = self.read.saturating_add(L as u64);
45        Ok(res)
46    }
47}
48
/// A wrapper around a writer `W` that carries a running count of bytes
/// written.
///
/// The counter starts at zero and is only reported by the accessors below;
/// nothing in this inherent impl modifies it.
#[derive(Debug)]
pub struct TrackingWriter<W> {
    /// The wrapped writer.
    inner: W,
    /// Total number of bytes recorded as written so far.
    written: u64,
}

impl<W> TrackingWriter<W> {
    /// Wraps `inner`, starting the byte counter at zero.
    pub fn new(inner: W) -> Self {
        TrackingWriter { inner, written: 0 }
    }

    /// Returns the number of bytes recorded as written so far.
    // NOTE(review): the lint is silenced presumably because this accessor is
    // unused in some build configurations — confirm before removing.
    #[allow(dead_code)]
    pub fn bytes_written(&self) -> u64 {
        self.written
    }

    /// Consumes the wrapper, returning the inner writer together with the
    /// final byte count.
    pub fn into_parts(self) -> (W, u64) {
        let Self { inner, written } = self;
        (inner, written)
    }
}
73
74impl<W: AsyncWrite + Unpin> AsyncWrite for TrackingWriter<W> {
75    fn poll_write(
76        mut self: Pin<&mut Self>,
77        cx: &mut std::task::Context<'_>,
78        buf: &[u8],
79    ) -> Poll<io::Result<usize>> {
80        let this = &mut *self;
81        let res = Pin::new(&mut this.inner).poll_write(cx, buf);
82        if let Poll::Ready(Ok(size)) = res {
83            this.written = this.written.saturating_add(size as u64);
84        }
85        res
86    }
87
88    fn poll_flush(
89        mut self: Pin<&mut Self>,
90        cx: &mut std::task::Context<'_>,
91    ) -> Poll<io::Result<()>> {
92        Pin::new(&mut self.inner).poll_flush(cx)
93    }
94
95    fn poll_shutdown(
96        mut self: Pin<&mut Self>,
97        cx: &mut std::task::Context<'_>,
98    ) -> Poll<io::Result<()>> {
99        Pin::new(&mut self.inner).poll_shutdown(cx)
100    }
101}