// fractal_storage_client/stream/count.rs
use bytes::Bytes;
use futures::task::Context;
use futures::task::Poll;
use futures::Stream;
use std::error::Error as StdError;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// A cheaply cloneable, thread-safe running total of bytes.
///
/// Cloning a `BytesCount` does not copy the total: every clone shares the same
/// underlying counter (an [`Arc`]-wrapped [`AtomicUsize`]), so one party can
/// hold a handle and read the total while another party increments it.
///
/// `BytesCount::default()` yields a counter starting at zero.
#[derive(Clone, Debug, Default)]
pub struct BytesCount {
    /// Shared running total; all clones point at the same atomic.
    bytes: Arc<AtomicUsize>,
}

17impl BytesCount {
18 pub fn new(value: usize) -> Self {
20 BytesCount {
21 bytes: Arc::new(AtomicUsize::new(value)),
22 }
23 }
24
25 pub fn add(&self, value: usize) {
27 self.bytes.fetch_add(value, Ordering::SeqCst);
28 }
29
30 pub fn get(&self) -> usize {
32 self.bytes.load(Ordering::SeqCst)
33 }
34}
35
36pub struct CountBytesStream<E: StdError> {
39 stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send + Sync>>,
41 count: BytesCount,
43}
44
45impl<E: StdError> CountBytesStream<E> {
46 pub fn new(stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send + Sync>>) -> Self {
48 CountBytesStream {
49 stream,
50 count: BytesCount::new(0),
51 }
52 }
53
54 pub fn bytes_count(&self) -> BytesCount {
57 self.count.clone()
58 }
59}
60
61impl<E: StdError> Stream for CountBytesStream<E> {
62 type Item = Result<Bytes, E>;
63
64 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
65 let result = Pin::new(&mut self.stream).poll_next(cx);
66 match &result {
67 Poll::Ready(Some(Ok(bytes))) => {
68 self.count.add(bytes.len());
69 }
70 _ => (),
71 }
72 result
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// A single chunk is counted once; end-of-stream leaves the total unchanged.
    #[tokio::test]
    async fn can_measure_bytes() {
        let data = Bytes::copy_from_slice(b"hello");
        let stream = futures::stream::iter(vec![Ok(data.clone())]);
        let mut stream = CountBytesStream::<std::io::Error>::new(Box::pin(stream));
        assert_eq!(stream.bytes_count().get(), 0);

        let result = stream.next().await.unwrap();
        assert_eq!(result.unwrap(), data);
        assert_eq!(stream.bytes_count().get(), 5);

        let result = stream.next().await;
        assert!(result.is_none());
        assert_eq!(stream.bytes_count().get(), 5);
    }

    /// A handle taken up front observes the total accumulating across chunks.
    #[tokio::test]
    async fn can_measure_bytes_multiple() {
        let data1 = Bytes::copy_from_slice(b"hello");
        let data2 = Bytes::copy_from_slice(b"world!");
        let stream = futures::stream::iter(vec![Ok(data1.clone()), Ok(data2.clone())]);
        let mut stream = CountBytesStream::<std::io::Error>::new(Box::pin(stream));
        let count = stream.bytes_count();
        assert_eq!(count.get(), 0);

        let result = stream.next().await.unwrap();
        assert_eq!(result.unwrap(), data1);
        assert_eq!(count.get(), 5);

        let result = stream.next().await.unwrap();
        assert_eq!(result.unwrap(), data2);
        assert_eq!(count.get(), 11);

        let result = stream.next().await;
        assert!(result.is_none());
        assert_eq!(count.get(), 11);
    }

    /// `Err` items are forwarded but contribute nothing to the total.
    #[tokio::test]
    async fn errors_are_not_counted() {
        let data = Bytes::copy_from_slice(b"abc");
        let stream = futures::stream::iter(vec![
            Err(std::io::Error::new(std::io::ErrorKind::Other, "boom")),
            Ok(data.clone()),
        ]);
        let mut stream = CountBytesStream::<std::io::Error>::new(Box::pin(stream));
        let count = stream.bytes_count();

        let result = stream.next().await.unwrap();
        assert!(result.is_err());
        assert_eq!(count.get(), 0);

        let result = stream.next().await.unwrap();
        assert_eq!(result.unwrap(), data);
        assert_eq!(count.get(), 3);
    }
}