1use tokio::sync::mpsc::{channel, Receiver, Sender};
11
/// Number of chunks that `stream()` buffers before writers start waiting for
/// the reader to drain.
const DEFAULT_CAPACITY: usize = 16;
14
15#[derive(Clone)]
19pub struct StreamWriter {
20 tx: Sender<Vec<u8>>,
21}
22
23pub struct StreamHandle {
27 rx: Receiver<Vec<u8>>,
28}
29
30pub fn stream() -> (StreamWriter, StreamHandle) {
32 stream_with_capacity(DEFAULT_CAPACITY)
33}
34
35pub fn stream_with_capacity(capacity: usize) -> (StreamWriter, StreamHandle) {
37 let (tx, rx) = channel(capacity.max(1));
38 (StreamWriter { tx }, StreamHandle { rx })
39}
40
41impl StreamWriter {
42 pub async fn write(&self, chunk: Vec<u8>) -> Result<(), Vec<u8>> {
45 self.tx.send(chunk).await.map_err(|err| err.0)
46 }
47}
48
49impl StreamHandle {
50 pub async fn read(&mut self) -> Option<Vec<u8>> {
53 self.rx.recv().await
54 }
55}
56
57impl std::fmt::Debug for StreamWriter {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("StreamWriter").finish_non_exhaustive()
60 }
61}
62
63impl std::fmt::Debug for StreamHandle {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("StreamHandle").finish_non_exhaustive()
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// How long a pending write is given before it is treated as blocked.
    const BLOCK_PROBE: Duration = Duration::from_millis(20);

    #[tokio::test]
    async fn writer_applies_backpressure_until_the_reader_drains() {
        let (w, mut r) = stream_with_capacity(1);
        w.write(b"1".to_vec()).await.unwrap();

        // The only buffer slot is taken, so a second write must stay pending.
        let blocked = tokio::time::timeout(BLOCK_PROBE, w.write(b"2".to_vec())).await;
        assert!(blocked.is_err(), "write must block while the buffer is full");

        // Draining one chunk frees the slot for the next write.
        assert_eq!(r.read().await, Some(b"1".to_vec()));
        w.write(b"2".to_vec()).await.unwrap();
        assert_eq!(r.read().await, Some(b"2".to_vec()));
    }

    #[tokio::test]
    async fn zero_capacity_behaves_like_a_single_slot() {
        // A zero capacity must not panic; it is raised to one slot.
        let (w, mut r) = stream_with_capacity(0);
        w.write(b"1".to_vec()).await.unwrap();

        let blocked = tokio::time::timeout(BLOCK_PROBE, w.write(b"2".to_vec())).await;
        assert!(blocked.is_err(), "a zero capacity must leave exactly one slot");

        assert_eq!(r.read().await, Some(b"1".to_vec()));
    }

    #[tokio::test]
    async fn read_returns_none_after_the_writer_drops() {
        let (w, mut r) = stream();
        w.write(b"x".to_vec()).await.unwrap();
        drop(w);
        // Chunks buffered before the drop are still delivered.
        assert_eq!(r.read().await, Some(b"x".to_vec()));
        assert_eq!(r.read().await, None);
    }

    #[tokio::test]
    async fn write_fails_once_the_reader_is_gone() {
        let (w, r) = stream();
        drop(r);
        assert_eq!(w.write(b"x".to_vec()).await, Err(b"x".to_vec()));
    }

    #[tokio::test]
    async fn a_cloned_writer_fans_into_the_same_stream() {
        use std::collections::HashSet;

        let (w, mut r) = stream();
        let w2 = w.clone();
        w.write(b"1".to_vec()).await.unwrap();
        w2.write(b"2".to_vec()).await.unwrap();
        let got: HashSet<Vec<u8>> = [r.read().await.unwrap(), r.read().await.unwrap()]
            .into_iter()
            .collect();
        assert_eq!(got, HashSet::from([b"1".to_vec(), b"2".to_vec()]));

        // Dropping one clone must not end the stream while another is alive.
        drop(w);
        w2.write(b"3".to_vec()).await.unwrap();
        assert_eq!(r.read().await, Some(b"3".to_vec()));

        drop(w2);
        assert_eq!(r.read().await, None);
    }

    #[test]
    fn handles_format_for_debug() {
        let (w, r) = stream();
        assert!(format!("{w:?}").contains("StreamWriter"));
        assert!(format!("{r:?}").contains("StreamHandle"));
    }
}