commonware_runtime/
mocks.rs1use crate::{Error, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::Bytes;
5use commonware_utils::StableBuf;
6use futures::channel::oneshot;
7use std::{
8 collections::VecDeque,
9 sync::{Arc, Mutex},
10};
11
12pub struct Channel {
14 buffer: VecDeque<u8>,
16
17 waiter: Option<(usize, oneshot::Sender<Bytes>)>,
21
22 sink_alive: bool,
24
25 stream_alive: bool,
27}
28
29impl Channel {
30 pub fn init() -> (Sink, Stream) {
32 let channel = Arc::new(Mutex::new(Channel {
33 buffer: VecDeque::new(),
34 waiter: None,
35 sink_alive: true,
36 stream_alive: true,
37 }));
38 (
39 Sink {
40 channel: channel.clone(),
41 },
42 Stream { channel },
43 )
44 }
45}
46
/// Writing half of a mock channel created by [`Channel::init`].
pub struct Sink {
    /// State shared with the paired [`Stream`].
    channel: Arc<Mutex<Channel>>,
}
51
52impl SinkTrait for Sink {
53 async fn send(&mut self, msg: impl Into<StableBuf> + Send) -> Result<(), Error> {
54 let msg = msg.into();
55 let (os_send, data) = {
56 let mut channel = self.channel.lock().unwrap();
57
58 if !channel.stream_alive {
60 return Err(Error::Closed);
61 }
62
63 channel.buffer.extend(msg.as_ref());
65
66 if channel
70 .waiter
71 .as_ref()
72 .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
73 {
74 let (requested, os_send) = channel.waiter.take().unwrap();
75 let data: Vec<u8> = channel.buffer.drain(0..requested).collect();
76 (os_send, Bytes::from(data))
77 } else {
78 return Ok(());
79 }
80 };
81
82 os_send.send(data).map_err(|_| Error::SendFailed)?;
84 Ok(())
85 }
86}
87
88impl Drop for Sink {
89 fn drop(&mut self) {
90 let mut channel = self.channel.lock().unwrap();
91 channel.sink_alive = false;
92
93 channel.waiter.take();
95 }
96}
97
/// Reading half of a mock channel created by [`Channel::init`].
pub struct Stream {
    /// State shared with the paired [`Sink`].
    channel: Arc<Mutex<Channel>>,
}
102
103impl StreamTrait for Stream {
104 async fn recv(&mut self, buf: impl Into<StableBuf> + Send) -> Result<StableBuf, Error> {
105 let mut buf = buf.into();
106 let os_recv = {
107 let mut channel = self.channel.lock().unwrap();
108
109 if channel.buffer.len() >= buf.len() {
112 let b: Vec<u8> = channel.buffer.drain(0..buf.len()).collect();
113 buf.put_slice(&b);
114 return Ok(buf);
115 }
116
117 if !channel.sink_alive {
120 return Err(Error::Closed);
121 }
122
123 assert!(channel.waiter.is_none());
125 let (os_send, os_recv) = oneshot::channel();
126 channel.waiter = Some((buf.len(), os_send));
127 os_recv
128 };
129
130 let data = os_recv.await.map_err(|_| Error::Closed)?;
133 assert_eq!(data.len(), buf.len());
134 buf.put_slice(&data);
135 Ok(buf)
136 }
137}
138
139impl Drop for Stream {
140 fn drop(&mut self) {
141 let mut channel = self.channel.lock().unwrap();
142 channel.stream_alive = false;
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world".to_vec();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.clone()).await.unwrap();
            let buf = stream.recv(vec![0; data.len()]).await.unwrap();
            assert_eq!(buf.as_ref(), data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello".to_vec();
        let data2 = b" world".to_vec();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data).await.unwrap();
            sink.send(data2).await.unwrap();
            let buf = stream.recv(vec![0; 5]).await.unwrap();
            assert_eq!(buf.as_ref(), b"hello");
            let buf = stream.recv(buf).await.unwrap();
            assert_eq!(buf.as_ref(), b" worl");
            let buf = stream.recv(vec![0; 1]).await.unwrap();
            assert_eq!(buf.as_ref(), b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (buf, _) = futures::try_join!(stream.recv(vec![0; data.len()]), async {
                // Yield through the runtime clock rather than `std::thread::sleep`,
                // which would block the executor thread. Either way `recv` is
                // polled first, so it is parked on its waiter when we send.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.to_vec()).await
            })
            .unwrap();
            assert_eq!(buf.as_ref(), data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(vec![0; 5]).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    // Let `recv` register its waiter before the sink goes away.
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(vec![0; 5]).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            assert!(sink.send(b"7 bytes".to_vec()).await.is_ok());

            // The first recv is served from the buffer. The second blocks, because
            // only 2 bytes remain, until the task is aborted.
            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(vec![0; 5]).await;
                let _ = stream.recv(vec![0; 5]).await;
            });

            // Give the spawned task time to reach the blocking recv.
            context.sleep(Duration::from_millis(50)).await;

            // Aborting the task drops the stream it owns.
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            let result = sink.send(b"hello world".to_vec()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".to_vec()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        // Keep the sink alive so `recv` waits instead of failing with `Closed`.
        let (_sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(vec![0;5]) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => {
                    "timeout"
                },
            };
        });
    }
}