// iroh_blobs/util/io.rs

//! Utilities for working with tokio I/O.

use std::{io, pin::Pin, task::Poll};

use iroh_io::AsyncStreamReader;
use tokio::io::AsyncWrite;

/// A reader that tracks the number of bytes read.
///
/// Wraps an inner reader and keeps a running, saturating count of the bytes
/// returned through its [`AsyncStreamReader`] implementation.
#[derive(Debug)]
pub struct TrackingReader<R> {
    /// The wrapped reader.
    inner: R,
    /// Total number of bytes read through this wrapper so far.
    read: u64,
}

impl<R> TrackingReader<R> {
    /// Wrap a reader in a tracking reader, starting the byte count at zero.
    pub fn new(inner: R) -> Self {
        Self { inner, read: 0 }
    }

    /// Get the number of bytes read so far.
    // The allow silences the lint in builds where no caller uses this accessor.
    #[allow(dead_code)]
    pub fn bytes_read(&self) -> u64 {
        self.read
    }

    /// Consume the tracking reader, returning the inner reader together with
    /// the total number of bytes read through it.
    pub fn into_parts(self) -> (R, u64) {
        (self.inner, self.read)
    }
}

33impl<R> AsyncStreamReader for TrackingReader<R>
34where
35    R: AsyncStreamReader,
36{
37    async fn read_bytes(&mut self, len: usize) -> io::Result<bytes::Bytes> {
38        let bytes = self.inner.read_bytes(len).await?;
39        self.read = self.read.saturating_add(bytes.len() as u64);
40        Ok(bytes)
41    }
42
43    async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
44        let res = self.inner.read::<L>().await?;
45        self.read = self.read.saturating_add(L as u64);
46        Ok(res)
47    }
48}
49
/// A writer that tracks the number of bytes written.
///
/// Wraps an inner writer and keeps a running, saturating count of the bytes
/// the inner writer reports as written through this wrapper's
/// [`AsyncWrite`] implementation.
#[derive(Debug)]
pub struct TrackingWriter<W> {
    /// The wrapped writer.
    inner: W,
    /// Total number of bytes reported written through this wrapper so far.
    written: u64,
}

impl<W> TrackingWriter<W> {
    /// Wrap a writer in a tracking writer, starting the byte count at zero.
    pub fn new(inner: W) -> Self {
        Self { inner, written: 0 }
    }

    /// Get the number of bytes written so far.
    // The allow silences the lint in builds where no caller uses this accessor.
    #[allow(dead_code)]
    pub fn bytes_written(&self) -> u64 {
        self.written
    }

    /// Consume the tracking writer, returning the inner writer together with
    /// the total number of bytes written through it.
    pub fn into_parts(self) -> (W, u64) {
        (self.inner, self.written)
    }
}

75impl<W: AsyncWrite + Unpin> AsyncWrite for TrackingWriter<W> {
76    fn poll_write(
77        mut self: Pin<&mut Self>,
78        cx: &mut std::task::Context<'_>,
79        buf: &[u8],
80    ) -> Poll<io::Result<usize>> {
81        let this = &mut *self;
82        let res = Pin::new(&mut this.inner).poll_write(cx, buf);
83        if let Poll::Ready(Ok(size)) = res {
84            this.written = this.written.saturating_add(size as u64);
85        }
86        res
87    }
88
89    fn poll_flush(
90        mut self: Pin<&mut Self>,
91        cx: &mut std::task::Context<'_>,
92    ) -> Poll<io::Result<()>> {
93        Pin::new(&mut self.inner).poll_flush(cx)
94    }
95
96    fn poll_shutdown(
97        mut self: Pin<&mut Self>,
98        cx: &mut std::task::Context<'_>,
99    ) -> Poll<io::Result<()>> {
100        Pin::new(&mut self.inner).poll_shutdown(cx)
101    }
102}