// fractal_storage_client/stream/count.rs

1use bytes::Bytes;
2use futures::task::Context;
3use futures::task::Poll;
4use futures::Stream;
5use std::error::Error as StdError;
6use std::pin::Pin;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9
/// Thread-safe byte counter.
///
/// The value lives behind an [`Arc`], so cloning a `BytesCount` produces another
/// handle to the *same* counter: one handle may be incremented while another is
/// read, including from a different thread. Every access uses
/// [`Ordering::SeqCst`].
#[derive(Clone, Debug)]
pub struct BytesCount {
    bytes: Arc<AtomicUsize>,
}

impl BytesCount {
    /// Builds a counter whose starting value is `value`.
    pub fn new(value: usize) -> Self {
        let counter = AtomicUsize::new(value);
        Self {
            bytes: Arc::new(counter),
        }
    }

    /// Increments the counter by `value`.
    ///
    /// The underlying `fetch_add` wraps around on overflow.
    pub fn add(&self, value: usize) {
        let _previous = self.bytes.fetch_add(value, Ordering::SeqCst);
    }

    /// Reads the counter's current value.
    pub fn get(&self) -> usize {
        let shared = &self.bytes;
        shared.load(Ordering::SeqCst)
    }
}
35
36/// Stream adaptor that has the ability to measure the amount of bytes that
37/// pass through it.
38pub struct CountBytesStream<E: StdError> {
39    /// Underlying stream
40    stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send + Sync>>,
41    /// Count of bytes that passed through so far
42    count: BytesCount,
43}
44
45impl<E: StdError> CountBytesStream<E> {
46    /// Create new stream from an underlying stream
47    pub fn new(stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send + Sync>>) -> Self {
48        CountBytesStream {
49            stream,
50            count: BytesCount::new(0),
51        }
52    }
53
54    /// Return a clone of the BytesCount instance that can be used to fetch the number of bytes
55    /// at a later point.
56    pub fn bytes_count(&self) -> BytesCount {
57        self.count.clone()
58    }
59}
60
61impl<E: StdError> Stream for CountBytesStream<E> {
62    type Item = Result<Bytes, E>;
63
64    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
65        let result = Pin::new(&mut self.stream).poll_next(cx);
66        match &result {
67            Poll::Ready(Some(Ok(bytes))) => {
68                self.count.add(bytes.len());
69            }
70            _ => (),
71        }
72        result
73    }
74}
75
#[cfg(test)]
#[tokio::test]
async fn can_measure_bytes() {
    use futures::StreamExt;

    // A single five-byte chunk.
    let chunk = Bytes::copy_from_slice(b"hello");
    let source = futures::stream::iter(vec![Ok(chunk.clone())]);
    let mut counted = CountBytesStream::<std::io::Error>::new(Box::pin(source));

    // Nothing has been polled yet.
    assert_eq!(counted.bytes_count().get(), 0);

    // Yielding the chunk adds its length to the total.
    let first = counted.next().await.unwrap();
    assert_eq!(first.unwrap(), chunk);
    assert_eq!(counted.bytes_count().get(), 5);

    // Reaching the end of the stream leaves the total unchanged.
    assert!(counted.next().await.is_none());
    assert_eq!(counted.bytes_count().get(), 5);
}
93
#[cfg(test)]
#[tokio::test]
async fn can_measure_bytes_multiple() {
    use futures::StreamExt;

    let hello = Bytes::copy_from_slice(b"hello");
    let world = Bytes::copy_from_slice(b"world!");
    let source = futures::stream::iter(vec![Ok(hello.clone()), Ok(world.clone())]);
    let mut counted = CountBytesStream::<std::io::Error>::new(Box::pin(source));

    // The handle is taken before polling and must observe every later update.
    let total = counted.bytes_count();
    assert_eq!(total.get(), 0);

    // Each chunk bumps the running total by its own length: 5, then 5 + 6.
    for (expected_chunk, expected_total) in [(hello, 5), (world, 11)] {
        let item = counted.next().await.unwrap();
        assert_eq!(item.unwrap(), expected_chunk);
        assert_eq!(total.get(), expected_total);
    }

    // Exhaustion does not change the total.
    assert!(counted.next().await.is_none());
    assert_eq!(total.get(), 11);
}