rama_net/stream/layer/tracker/
bytes.rs1use std::{
13 fmt, io,
14 pin::Pin,
15 sync::{
16 Arc,
17 atomic::{AtomicUsize, Ordering},
18 },
19 task::{Context, Poll},
20};
21
22use pin_project_lite::pin_project;
23use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24
25pin_project! {
26 pub struct BytesRWTracker<S> {
36 read: Arc<AtomicUsize>,
37 written: Arc<AtomicUsize>,
38 #[pin]
39 stream: S,
40 }
41}
42
43impl<S: fmt::Debug> fmt::Debug for BytesRWTracker<S> {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 f.debug_struct("BytesRWTracker")
46 .field("read", &self.read)
47 .field("written", &self.written)
48 .field("stream", &self.stream)
49 .finish()
50 }
51}
52
53impl<S> BytesRWTracker<S> {
54 pub fn new(stream: S) -> Self {
60 Self {
61 read: Arc::new(AtomicUsize::new(0)),
62 written: Arc::new(AtomicUsize::new(0)),
63 stream,
64 }
65 }
66
67 pub fn read(&self) -> usize {
69 self.read.load(Ordering::Acquire)
70 }
71
72 pub fn written(&self) -> usize {
74 self.written.load(Ordering::Acquire)
75 }
76
77 pub fn handle(&self) -> BytesRWTrackerHandle {
81 BytesRWTrackerHandle {
82 read: self.read.clone(),
83 written: self.written.clone(),
84 }
85 }
86
87 pub fn into_inner(self) -> S {
97 self.stream
98 }
99}
100
101impl<S> AsyncRead for BytesRWTracker<S>
102where
103 S: AsyncRead,
104{
105 fn poll_read(
106 mut self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 buf: &mut ReadBuf<'_>,
109 ) -> Poll<io::Result<()>> {
110 let this = self.as_mut().project();
111 let size = buf.filled().len();
112 let res: Poll<Result<(), io::Error>> = this.stream.poll_read(cx, buf);
113 if let Poll::Ready(Ok(_)) = res {
114 let new_size = buf.filled().len();
115 match new_size.cmp(&size) {
116 std::cmp::Ordering::Greater => {
117 let bytes_read = new_size - size;
118 this.read.fetch_add(bytes_read, Ordering::AcqRel);
119 }
120 std::cmp::Ordering::Less => {
121 tracing::error!(
122 "BytesRWTracker: poll_read returned Ok(()) with filled buffer smaller then before"
123 );
124 }
125 std::cmp::Ordering::Equal => {
126 tracing::trace!("BytesRWTracker: poll_read returned Ok(()) with nothing read");
127 }
128 }
129 }
130 res
131 }
132}
133
134impl<S> AsyncWrite for BytesRWTracker<S>
135where
136 S: AsyncWrite,
137{
138 fn poll_write(
139 mut self: Pin<&mut Self>,
140 cx: &mut Context<'_>,
141 buf: &[u8],
142 ) -> Poll<Result<usize, io::Error>> {
143 let this = self.as_mut().project();
144 let res: Poll<Result<usize, io::Error>> = this.stream.poll_write(cx, buf);
145 if let Poll::Ready(Ok(bytes_written)) = res {
146 this.written.fetch_add(bytes_written, Ordering::AcqRel);
147 }
148 res
149 }
150
151 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
152 self.project().stream.poll_flush(cx)
153 }
154
155 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
156 self.project().stream.poll_shutdown(cx)
157 }
158
159 fn poll_write_vectored(
160 mut self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 bufs: &[io::IoSlice<'_>],
163 ) -> Poll<Result<usize, io::Error>> {
164 let this = self.as_mut().project();
165 let res: Poll<Result<usize, io::Error>> = this.stream.poll_write_vectored(cx, bufs);
166 if let Poll::Ready(Ok(bytes_written)) = res {
167 this.written.fetch_add(bytes_written, Ordering::AcqRel);
168 }
169 res
170 }
171
172 fn is_write_vectored(&self) -> bool {
173 self.stream.is_write_vectored()
174 }
175}
176
/// A cloneable view onto the counters of a [`BytesRWTracker`].
///
/// Obtained via [`BytesRWTracker::handle`]. It shares the tracker's counters,
/// so it reflects the tracker's latest totals even while the tracker is owned
/// elsewhere.
#[derive(Clone, Debug)]
pub struct BytesRWTrackerHandle {
    /// Shared counter of bytes read through the tracker.
    read: Arc<AtomicUsize>,
    /// Shared counter of bytes written through the tracker.
    written: Arc<AtomicUsize>,
}
185
186impl BytesRWTrackerHandle {
187 pub fn read(&self) -> usize {
189 self.read.load(Ordering::Acquire)
190 }
191
192 pub fn written(&self) -> usize {
194 self.written.load(Ordering::Acquire)
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_test::io::Builder;

    /// The three 3-byte payloads every test reads and/or writes, in order.
    const CHUNKS: [&[u8; 3]; 3] = [b"foo", b"bar", b"baz"];

    /// Expected `(read, written)` totals after each of the six alternating
    /// read/write steps in `test_rw_handle_tracker`.
    const RW_STEPS: [(usize, usize); 6] = [(3, 0), (3, 3), (6, 3), (6, 6), (9, 6), (9, 9)];

    #[tokio::test]
    async fn test_read_tracker() {
        let mut builder = Builder::new();
        for chunk in CHUNKS {
            builder.read(chunk);
        }
        let mut tracker = BytesRWTracker::new(builder.build());
        let mut buf = [0u8; 3];

        assert_eq!((tracker.read(), tracker.written()), (0, 0));
        for step in 1..=CHUNKS.len() {
            tracker.read_exact(&mut buf).await.unwrap();
            assert_eq!((tracker.read(), tracker.written()), (3 * step, 0));
        }
    }

    #[tokio::test]
    async fn test_written_tracker() {
        let mut builder = Builder::new();
        for chunk in CHUNKS {
            builder.write(chunk);
        }
        let mut tracker = BytesRWTracker::new(builder.build());

        assert_eq!((tracker.read(), tracker.written()), (0, 0));
        for (step, chunk) in (1usize..).zip(CHUNKS) {
            tracker.write_all(chunk).await.unwrap();
            assert_eq!((tracker.read(), tracker.written()), (0, 3 * step));
        }
    }

    #[tokio::test]
    async fn test_rw_tracker() {
        let mut builder = Builder::new();
        for chunk in CHUNKS {
            builder.read(chunk).write(chunk);
        }
        let mut tracker = BytesRWTracker::new(builder.build());
        let mut buf = [0u8; 3];

        assert_eq!((tracker.read(), tracker.written()), (0, 0));
        for (step, chunk) in (1usize..).zip(CHUNKS) {
            tracker.read_exact(&mut buf).await.unwrap();
            assert_eq!(
                (tracker.read(), tracker.written()),
                (3 * step, 3 * (step - 1))
            );
            tracker.write_all(chunk).await.unwrap();
            assert_eq!((tracker.read(), tracker.written()), (3 * step, 3 * step));
        }
    }

    #[tokio::test]
    async fn test_rw_handle_tracker() {
        let mut builder = Builder::new();
        for chunk in CHUNKS {
            builder.read(chunk).write(chunk);
        }
        let tracker = BytesRWTracker::new(builder.build());
        let handle = tracker.handle();

        assert_eq!((handle.read(), handle.written()), (0, 0));

        let (action_tx, mut action_rx) = tokio::sync::mpsc::channel(1);
        let (check_tx, mut check_rx) = tokio::sync::broadcast::channel(1);
        let mut observer_rx = check_tx.subscribe();

        // Worker: performs one read or write per received action, then
        // broadcasts a check signal.
        let task_1 = tokio::spawn(async move {
            let mut tracker = tracker;
            let mut buf = [0u8; 3];
            for chunk in CHUNKS {
                action_rx.recv().await;
                tracker.read_exact(&mut buf).await.unwrap();
                check_tx.send(()).unwrap();

                action_rx.recv().await;
                tracker.write_all(chunk).await.unwrap();
                check_tx.send(()).unwrap();
            }
        });

        // Observer: verifies the totals through a cloned handle in another task.
        let task_2 = {
            let handle = handle.clone();
            tokio::spawn(async move {
                for expected in RW_STEPS {
                    observer_rx.recv().await.unwrap();
                    assert_eq!((handle.read(), handle.written()), expected);
                }
            })
        };

        assert_eq!((handle.read(), handle.written()), (0, 0));

        for expected in RW_STEPS {
            action_tx.send(()).await.unwrap();
            check_rx.recv().await.unwrap();
            assert_eq!((handle.read(), handle.written()), expected);
        }

        let (t1, t2) = futures_lite::future::zip(task_1, task_2).await;
        t1.unwrap();
        t2.unwrap();
    }
}