1use crate::{BufMut, Error, IoBufs, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::{Bytes, BytesMut};
5use commonware_utils::{channel::oneshot, sync::Mutex};
6use std::sync::Arc;
7
/// Read buffer size used by [`Channel::init`] (64 KiB): how many bytes the stream
/// prefers to move out of the shared channel per transfer.
const DEFAULT_READ_BUFFER_SIZE: usize = 1 << 16;
10
11pub struct Channel {
13 buffer: BytesMut,
15
16 waiter: Option<(usize, oneshot::Sender<Bytes>)>,
20
21 read_buffer_size: usize,
23
24 sink_alive: bool,
26
27 stream_alive: bool,
29}
30
31impl Channel {
32 pub fn init() -> (Sink, Stream) {
34 Self::init_with_read_buffer_size(DEFAULT_READ_BUFFER_SIZE)
35 }
36
37 pub fn init_with_read_buffer_size(read_buffer_size: usize) -> (Sink, Stream) {
39 let channel = Arc::new(Mutex::new(Self {
40 buffer: BytesMut::new(),
41 waiter: None,
42 read_buffer_size,
43 sink_alive: true,
44 stream_alive: true,
45 }));
46 (
47 Sink {
48 channel: channel.clone(),
49 },
50 Stream {
51 channel,
52 buffer: BytesMut::new(),
53 },
54 )
55 }
56}
57
58pub struct Sink {
60 channel: Arc<Mutex<Channel>>,
61}
62
63impl SinkTrait for Sink {
64 async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
65 let (os_send, data) = {
66 let mut channel = self.channel.lock();
67
68 if !channel.stream_alive {
70 return Err(Error::Closed);
71 }
72
73 channel.buffer.put(bufs.into());
74
75 if channel
79 .waiter
80 .as_ref()
81 .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
82 {
83 let (requested, os_send) = channel.waiter.take().unwrap();
85 let send_amount = channel
86 .buffer
87 .len()
88 .min(requested.max(channel.read_buffer_size));
89 let data = channel.buffer.split_to(send_amount).freeze();
90 (os_send, data)
91 } else {
92 return Ok(());
93 }
94 };
95
96 os_send.send(data).map_err(|_| Error::SendFailed)?;
98 Ok(())
99 }
100}
101
102impl Drop for Sink {
103 fn drop(&mut self) {
104 let mut channel = self.channel.lock();
105 channel.sink_alive = false;
106
107 channel.waiter.take();
109 }
110}
111
112pub struct Stream {
114 channel: Arc<Mutex<Channel>>,
115 buffer: BytesMut,
117}
118
119impl StreamTrait for Stream {
120 async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
121 let os_recv = {
122 let mut channel = self.channel.lock();
123
124 if !channel.buffer.is_empty() {
126 let target = len.max(channel.read_buffer_size);
127 let pull_amount = channel
128 .buffer
129 .len()
130 .min(target.saturating_sub(self.buffer.len()));
131 if pull_amount > 0 {
132 let data = channel.buffer.split_to(pull_amount);
133 self.buffer.extend_from_slice(&data);
134 }
135 }
136
137 if self.buffer.len() >= len {
139 return Ok(IoBufs::from(self.buffer.split_to(len).freeze()));
140 }
141
142 if !channel.sink_alive {
144 return Err(Error::Closed);
145 }
146
147 let remaining = len - self.buffer.len();
149 assert!(channel.waiter.is_none());
150 let (os_send, os_recv) = oneshot::channel();
151 channel.waiter = Some((remaining, os_send));
152 os_recv
153 };
154
155 let data = os_recv.await.map_err(|_| Error::Closed)?;
157 self.buffer.extend_from_slice(&data);
158
159 assert!(self.buffer.len() >= len);
160 Ok(IoBufs::from(self.buffer.split_to(len).freeze()))
161 }
162
163 fn peek(&self, max_len: usize) -> &[u8] {
164 let len = max_len.min(self.buffer.len());
165 &self.buffer[..len]
166 }
167}
168
169impl Drop for Stream {
170 fn drop(&mut self) {
171 let mut channel = self.channel.lock();
172 channel.stream_alive = false;
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let received = stream.recv(data.len()).await.unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b" worl");
            let received = stream.recv(1).await.unwrap();
            assert_eq!(received.coalesce(), b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Delay the send on simulated time so `recv` has to park and be woken by the
            // sink. A blocking `std::thread::sleep` here would stall the executor thread.
            let (received, _) = futures::try_join!(stream.recv(data.len()), async {
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(5).await;
                let _ = stream.recv(5).await;
            });

            context.sleep(Duration::from_millis(50)).await;

            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(5) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => "timeout",
            };
        });
    }

    #[test]
    fn test_peek_empty() {
        let (_sink, stream) = Channel::init();

        assert!(stream.peek(10).is_empty());
    }

    #[test]
    fn test_peek_after_partial_recv() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"hello world".as_slice()).await.unwrap();

            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");

            assert_eq!(stream.peek(100), b" world");

            assert_eq!(stream.peek(3), b" wo");

            assert_eq!(stream.peek(100), b" world");

            let received = stream.recv(6).await.unwrap();
            assert_eq!(received.coalesce(), b" world");

            assert!(stream.peek(100).is_empty());
        });
    }

    #[test]
    fn test_peek_after_recv_wakeup() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(64);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (tx, rx) = oneshot::channel();
            let recv_handle = context.clone().spawn(|_| async move {
                let data = stream.recv(3).await.unwrap();
                tx.send(stream).ok();
                data
            });

            context.sleep(Duration::from_millis(10)).await;

            sink.send(b"ABCDEFGHIJ".as_slice()).await.unwrap();

            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");

            let stream = rx.await.unwrap();
            assert_eq!(stream.peek(100), b"DEFGHIJ");
        });
    }

    #[test]
    fn test_peek_multiple_sends() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"aaa".as_slice()).await.unwrap();
            sink.send(b"bbb".as_slice()).await.unwrap();
            sink.send(b"ccc".as_slice()).await.unwrap();

            let received = stream.recv(4).await.unwrap();
            assert_eq!(received.coalesce(), b"aaab");

            assert_eq!(stream.peek(100), b"bbccc");
        });
    }

    #[test]
    fn test_read_buffer_size_limit() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"0123456789ABCDEF".as_slice()).await.unwrap();

            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"01");

            assert_eq!(stream.peek(100), b"23456789");

            let received = stream.recv(8).await.unwrap();
            assert_eq!(received.coalesce(), b"23456789");

            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"AB");

            assert_eq!(stream.peek(100), b"CDEF");
        });
    }

    #[test]
    fn test_recv_before_send() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let recv_handle = context
                .clone()
                .spawn(|_| async move { stream.recv(3).await.unwrap() });

            context.sleep(Duration::from_millis(10)).await;

            sink.send(b"ABCDEFGHIJKLMNOP".as_slice()).await.unwrap();

            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
        });
    }
}