// commonware_runtime/mocks.rs
use crate::{Error, Sink as SinkTrait, Stream as StreamTrait};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::channel::oneshot;
use std::sync::{Arc, Mutex};

8pub struct Channel {
10 buffer: BytesMut,
12
13 waiter: Option<(usize, oneshot::Sender<Bytes>)>,
17
18 sink_alive: bool,
20
21 stream_alive: bool,
23}
24
25impl Channel {
26 pub fn init() -> (Sink, Stream) {
28 let channel = Arc::new(Mutex::new(Self {
29 buffer: BytesMut::new(),
30 waiter: None,
31 sink_alive: true,
32 stream_alive: true,
33 }));
34 (
35 Sink {
36 channel: channel.clone(),
37 },
38 Stream { channel },
39 )
40 }
41}
42
43pub struct Sink {
45 channel: Arc<Mutex<Channel>>,
46}
47
48impl SinkTrait for Sink {
49 async fn send(&mut self, buf: impl Buf + Send) -> Result<(), Error> {
50 let (os_send, data) = {
51 let mut channel = self.channel.lock().unwrap();
52
53 if !channel.stream_alive {
55 return Err(Error::Closed);
56 }
57
58 channel.buffer.put(buf);
59
60 if channel
64 .waiter
65 .as_ref()
66 .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
67 {
68 let (requested, os_send) = channel.waiter.take().unwrap();
69 let data = channel.buffer.copy_to_bytes(requested);
70 (os_send, data)
71 } else {
72 return Ok(());
73 }
74 };
75
76 os_send.send(data).map_err(|_| Error::SendFailed)?;
78 Ok(())
79 }
80}
81
82impl Drop for Sink {
83 fn drop(&mut self) {
84 let mut channel = self.channel.lock().unwrap();
85 channel.sink_alive = false;
86
87 channel.waiter.take();
89 }
90}
91
92pub struct Stream {
94 channel: Arc<Mutex<Channel>>,
95}
96
97impl StreamTrait for Stream {
98 async fn recv(&mut self, mut buf: impl BufMut + Send) -> Result<(), Error> {
99 let os_recv = {
100 let mut channel = self.channel.lock().unwrap();
101
102 if channel.buffer.len() >= buf.remaining_mut() {
105 let b = channel.buffer.copy_to_bytes(buf.remaining_mut());
106 buf.put_slice(&b);
107 return Ok(());
108 }
109
110 if !channel.sink_alive {
113 return Err(Error::Closed);
114 }
115
116 assert!(channel.waiter.is_none());
118 let (os_send, os_recv) = oneshot::channel();
119 channel.waiter = Some((buf.remaining_mut(), os_send));
120 os_recv
121 };
122
123 let data = os_recv.await.map_err(|_| Error::Closed)?;
126 assert_eq!(data.len(), buf.remaining_mut());
127 buf.put_slice(&data);
128 Ok(())
129 }
130}
131
132impl Drop for Stream {
133 fn drop(&mut self) {
134 let mut channel = self.channel.lock().unwrap();
135 channel.stream_alive = false;
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let mut buf = vec![0u8; data.len()];
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let mut buf = [0u8; 5];
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], b"hello");
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], b" worl");
            let mut buf = [0u8; 1];
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut buf = vec![0; data.len()];
            let (_, _) = futures::try_join!(stream.recv(&mut buf[..]), async {
                // Delay via the runtime clock so `recv` registers its waiter
                // before `send` runs. `std::thread::sleep` would block the
                // executor thread instead of yielding.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(&buf[..], data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let mut buf = [0u8; 5];
                    let result = stream.recv(&mut buf[..]).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut buf = [0u8; 5];
            let result = stream.recv(&mut buf[..]).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            let handle = context.clone().spawn(|_| async move {
                let mut buf = [0u8; 5];
                let _ = stream.recv(&mut buf[..]).await;
                let _ = stream.recv(&mut buf[..]).await;
            });

            context.sleep(Duration::from_millis(50)).await;

            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut buf = [0u8; 5];
            select! {
                v = stream.recv(&mut buf[..]) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => {
                    "timeout"
                },
            };
        });
    }
}