rama_net/stream/layer/tracker/
bytes.rs

//! Provides [`BytesRWTracker`], which wraps an [`AsyncRead`] and/or [`AsyncWrite`]
//! in order to track the number of bytes read and/or written.
//!
//! Use [`BytesRWTracker::handle`] to get a [`BytesRWTrackerHandle`]. A handle is
//! required to query the number of bytes read and/or written once the [`BytesRWTracker`]
//! has been consumed by a protocol consumer — for example, when you wish to track
//! the bytes read and/or written for a TCP stream that is owned by a TLS stream.
//!
//! [`AsyncRead`]: crate::stream::AsyncRead
//! [`AsyncWrite`]: crate::stream::AsyncWrite
11
12use std::{
13    fmt, io,
14    pin::Pin,
15    sync::{
16        Arc,
17        atomic::{AtomicUsize, Ordering},
18    },
19    task::{Context, Poll},
20};
21
22use pin_project_lite::pin_project;
23use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24
pin_project! {
    /// A wrapper around an [`AsyncRead`] and/or [`AsyncWrite`] stream that counts
    /// the number of bytes read from and/or written to it.
    ///
    /// Both counters live behind an [`Arc`], so a [`BytesRWTrackerHandle`] obtained
    /// through [`BytesRWTracker::handle`] can still report the totals after this
    /// tracker has been moved into, and consumed by, a protocol consumer.
    ///
    /// [`AsyncRead`]: crate::stream::AsyncRead
    /// [`AsyncWrite`]: crate::stream::AsyncWrite
    pub struct BytesRWTracker<S> {
        // Running total of bytes appended to the caller's buffer by successful `poll_read` calls.
        read: Arc<AtomicUsize>,
        // Running total of bytes reported by successful `poll_write` / `poll_write_vectored` calls.
        written: Arc<AtomicUsize>,
        // The wrapped stream; structurally pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        stream: S,
    }
}
42
43impl<S: fmt::Debug> fmt::Debug for BytesRWTracker<S> {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        f.debug_struct("BytesRWTracker")
46            .field("read", &self.read)
47            .field("written", &self.written)
48            .field("stream", &self.stream)
49            .finish()
50    }
51}
52
53impl<S> BytesRWTracker<S> {
54    /// Create a new [`BytesRWTracker`] that wraps the
55    /// given [`AsyncRead`] and/or [`AsyncWrite`].
56    ///
57    /// [`AsyncRead`]: crate::stream::AsyncRead
58    /// [`AsyncWrite`]: crate::stream::AsyncWrite
59    pub fn new(stream: S) -> Self {
60        Self {
61            read: Arc::new(AtomicUsize::new(0)),
62            written: Arc::new(AtomicUsize::new(0)),
63            stream,
64        }
65    }
66
67    /// Get the number of bytes read (so far).
68    pub fn read(&self) -> usize {
69        self.read.load(Ordering::Acquire)
70    }
71
72    /// Get the number of bytes written (so far).
73    pub fn written(&self) -> usize {
74        self.written.load(Ordering::Acquire)
75    }
76
77    /// Get a [`BytesRWTrackerHandle`] that can be used to get the number of bytes
78    /// read and/or written even though the tracker is consumed by a protocol
79    /// consumer in a later stage.
80    pub fn handle(&self) -> BytesRWTrackerHandle {
81        BytesRWTrackerHandle {
82            read: self.read.clone(),
83            written: self.written.clone(),
84        }
85    }
86
87    /// Get the inner [`AsyncRead`] and/or [`AsyncWrite`] stream.
88    /// Dropping the tracking info and capabilities for this stream.
89    ///
90    /// Any previously obtained [`BytesRWTrackerHandle`] will no longer
91    /// be updated but will still report the number of bytes read and/or
92    /// written up to the point where this method was called.
93    ///
94    /// [`AsyncRead`]: crate::stream::AsyncRead
95    /// [`AsyncWrite`]: crate::stream::AsyncWrite
96    pub fn into_inner(self) -> S {
97        self.stream
98    }
99}
100
101impl<S> AsyncRead for BytesRWTracker<S>
102where
103    S: AsyncRead,
104{
105    fn poll_read(
106        mut self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        buf: &mut ReadBuf<'_>,
109    ) -> Poll<io::Result<()>> {
110        let this = self.as_mut().project();
111        let size = buf.filled().len();
112        let res: Poll<Result<(), io::Error>> = this.stream.poll_read(cx, buf);
113        if let Poll::Ready(Ok(_)) = res {
114            let new_size = buf.filled().len();
115            match new_size.cmp(&size) {
116                std::cmp::Ordering::Greater => {
117                    let bytes_read = new_size - size;
118                    this.read.fetch_add(bytes_read, Ordering::AcqRel);
119                }
120                std::cmp::Ordering::Less => {
121                    tracing::error!(
122                        "BytesRWTracker: poll_read returned Ok(()) with filled buffer smaller then before"
123                    );
124                }
125                std::cmp::Ordering::Equal => {
126                    tracing::trace!("BytesRWTracker: poll_read returned Ok(()) with nothing read");
127                }
128            }
129        }
130        res
131    }
132}
133
134impl<S> AsyncWrite for BytesRWTracker<S>
135where
136    S: AsyncWrite,
137{
138    fn poll_write(
139        mut self: Pin<&mut Self>,
140        cx: &mut Context<'_>,
141        buf: &[u8],
142    ) -> Poll<Result<usize, io::Error>> {
143        let this = self.as_mut().project();
144        let res: Poll<Result<usize, io::Error>> = this.stream.poll_write(cx, buf);
145        if let Poll::Ready(Ok(bytes_written)) = res {
146            this.written.fetch_add(bytes_written, Ordering::AcqRel);
147        }
148        res
149    }
150
151    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
152        self.project().stream.poll_flush(cx)
153    }
154
155    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
156        self.project().stream.poll_shutdown(cx)
157    }
158
159    fn poll_write_vectored(
160        mut self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        bufs: &[io::IoSlice<'_>],
163    ) -> Poll<Result<usize, io::Error>> {
164        let this = self.as_mut().project();
165        let res: Poll<Result<usize, io::Error>> = this.stream.poll_write_vectored(cx, bufs);
166        if let Poll::Ready(Ok(bytes_written)) = res {
167            this.written.fetch_add(bytes_written, Ordering::AcqRel);
168        }
169        res
170    }
171
172    fn is_write_vectored(&self) -> bool {
173        self.stream.is_write_vectored()
174    }
175}
176
/// A cloneable view onto the byte counters of a [`BytesRWTracker`].
///
/// The handle shares its counters with the tracker it was obtained from, so it
/// can still report the number of bytes read and/or written after that tracker
/// has been consumed by a protocol consumer.
#[derive(Clone, Debug)]
pub struct BytesRWTrackerHandle {
    read: Arc<AtomicUsize>,
    written: Arc<AtomicUsize>,
}

impl BytesRWTrackerHandle {
    /// Returns how many bytes have been read through the tracked stream so far.
    pub fn read(&self) -> usize {
        Self::snapshot(&self.read)
    }

    /// Returns how many bytes have been written through the tracked stream so far.
    pub fn written(&self) -> usize {
        Self::snapshot(&self.written)
    }

    /// Loads the current value of one of the shared counters.
    fn snapshot(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::Acquire)
    }
}
197
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_test::io::Builder;

    #[tokio::test]
    async fn test_read_tracker() {
        let stream = Builder::new()
            .read(b"foo")
            .read(b"bar")
            .read(b"baz")
            .build();

        let mut tracker = BytesRWTracker::new(stream);
        let mut buf = [0u8; 3];

        assert_eq!(tracker.read(), 0);
        assert_eq!(tracker.written(), 0);
        tracker.read_exact(&mut buf).await.unwrap();
        // The tracker must pass data through untouched, not just count it.
        assert_eq!(&buf, b"foo");
        assert_eq!(tracker.read(), 3);
        assert_eq!(tracker.written(), 0);
        tracker.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"bar");
        assert_eq!(tracker.read(), 6);
        assert_eq!(tracker.written(), 0);
        tracker.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"baz");
        assert_eq!(tracker.read(), 9);
        assert_eq!(tracker.written(), 0);
    }

    #[tokio::test]
    async fn test_written_tracker() {
        let stream = Builder::new()
            .write(b"foo")
            .write(b"bar")
            .write(b"baz")
            .build();

        let mut tracker = BytesRWTracker::new(stream);

        assert_eq!(tracker.read(), 0);
        assert_eq!(tracker.written(), 0);
        tracker.write_all(b"foo").await.unwrap();
        assert_eq!(tracker.read(), 0);
        assert_eq!(tracker.written(), 3);
        tracker.write_all(b"bar").await.unwrap();
        assert_eq!(tracker.read(), 0);
        assert_eq!(tracker.written(), 6);
        tracker.write_all(b"baz").await.unwrap();
        assert_eq!(tracker.read(), 0);
        assert_eq!(tracker.written(), 9);
    }

    #[tokio::test]
    async fn test_rw_tracker() {
        let stream = Builder::new()
            .read(b"foo")
            .write(b"foo")
            .read(b"bar")
            .write(b"bar")
            .read(b"baz")
            .write(b"baz")
            .build();

        let mut tracker = BytesRWTracker::new(stream);
        let mut buf = [0u8; 3];

        assert_eq!(tracker.read(), 0);
        assert_eq!(tracker.written(), 0);
        tracker.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"foo");
        assert_eq!(tracker.read(), 3);
        assert_eq!(tracker.written(), 0);
        tracker.write_all(b"foo").await.unwrap();
        assert_eq!(tracker.read(), 3);
        assert_eq!(tracker.written(), 3);
        tracker.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"bar");
        assert_eq!(tracker.read(), 6);
        assert_eq!(tracker.written(), 3);
        tracker.write_all(b"bar").await.unwrap();
        assert_eq!(tracker.read(), 6);
        assert_eq!(tracker.written(), 6);
        tracker.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"baz");
        assert_eq!(tracker.read(), 9);
        assert_eq!(tracker.written(), 6);
        tracker.write_all(b"baz").await.unwrap();
        assert_eq!(tracker.read(), 9);
        assert_eq!(tracker.written(), 9);
    }

    #[tokio::test]
    async fn test_rw_handle_tracker() {
        let stream = Builder::new()
            .read(b"foo")
            .write(b"foo")
            .read(b"bar")
            .write(b"bar")
            .read(b"baz")
            .write(b"baz")
            .build();

        let tracker = BytesRWTracker::new(stream);
        let handle = tracker.handle();

        assert_eq!(handle.read(), 0);
        assert_eq!(handle.written(), 0);

        // One slot per checkpoint: with a smaller capacity a receiver that falls
        // behind would observe `RecvError::Lagged` and fail the `unwrap()` below.
        const CHECKPOINTS: usize = 6;

        let (action_tx, mut action_rx) = tokio::sync::mpsc::channel(1);
        let (check_tx, mut check_rx) = tokio::sync::broadcast::channel(CHECKPOINTS);
        let check_rx_2 = check_tx.subscribe();

        let task_1 = tokio::spawn(async move {
            let mut tracker = tracker;
            let mut buf = [0u8; 3];

            action_rx.recv().await;
            tracker.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"foo");
            check_tx.send(()).unwrap();

            action_rx.recv().await;
            tracker.write_all(b"foo").await.unwrap();
            check_tx.send(()).unwrap();

            action_rx.recv().await;
            tracker.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"bar");
            check_tx.send(()).unwrap();

            action_rx.recv().await;
            tracker.write_all(b"bar").await.unwrap();
            check_tx.send(()).unwrap();

            action_rx.recv().await;
            tracker.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"baz");
            check_tx.send(()).unwrap();

            action_rx.recv().await;
            tracker.write_all(b"baz").await.unwrap();
            check_tx.send(()).unwrap();
        });

        let task_2 = {
            let handle = handle.clone();
            let mut check_rx = check_rx_2;
            tokio::spawn(async move {
                check_rx.recv().await.unwrap();

                assert_eq!(handle.read(), 3);
                assert_eq!(handle.written(), 0);

                check_rx.recv().await.unwrap();

                assert_eq!(handle.read(), 3);
                assert_eq!(handle.written(), 3);

                check_rx.recv().await.unwrap();

                assert_eq!(handle.read(), 6);
                assert_eq!(handle.written(), 3);

                check_rx.recv().await.unwrap();

                assert_eq!(handle.read(), 6);
                assert_eq!(handle.written(), 6);

                check_rx.recv().await.unwrap();

                assert_eq!(handle.read(), 9);
                assert_eq!(handle.written(), 6);

                check_rx.recv().await.unwrap();

                assert_eq!(handle.read(), 9);
                assert_eq!(handle.written(), 9);
            })
        };

        assert_eq!(handle.read(), 0);
        assert_eq!(handle.written(), 0);

        action_tx.send(()).await.unwrap();
        check_rx.recv().await.unwrap();

        assert_eq!(handle.read(), 3);
        assert_eq!(handle.written(), 0);

        action_tx.send(()).await.unwrap();
        check_rx.recv().await.unwrap();

        assert_eq!(handle.read(), 3);
        assert_eq!(handle.written(), 3);

        action_tx.send(()).await.unwrap();
        check_rx.recv().await.unwrap();

        assert_eq!(handle.read(), 6);
        assert_eq!(handle.written(), 3);

        action_tx.send(()).await.unwrap();
        check_rx.recv().await.unwrap();

        assert_eq!(handle.read(), 6);
        assert_eq!(handle.written(), 6);

        action_tx.send(()).await.unwrap();
        check_rx.recv().await.unwrap();

        assert_eq!(handle.read(), 9);
        assert_eq!(handle.written(), 6);

        action_tx.send(()).await.unwrap();
        check_rx.recv().await.unwrap();

        assert_eq!(handle.read(), 9);
        assert_eq!(handle.written(), 9);

        let (t1, t2) = futures_lite::future::zip(task_1, task_2).await;
        t1.unwrap();
        t2.unwrap();
    }

    #[tokio::test]
    async fn test_handle_after_into_inner() {
        let stream = Builder::new().read(b"foo").write(b"bar").build();

        let mut tracker = BytesRWTracker::new(stream);
        let handle = tracker.handle();
        let mut buf = [0u8; 3];

        tracker.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"foo");
        tracker.write_all(b"bar").await.unwrap();

        let _stream = tracker.into_inner();

        // The handle keeps reporting the totals reached before `into_inner`.
        assert_eq!(handle.read(), 3);
        assert_eq!(handle.written(), 3);
    }
}