1use crate::{BufMut, Error, IoBufs};
4use bytes::{Bytes, BytesMut};
5use commonware_utils::{
6 channel::{fallible::OneshotExt, oneshot},
7 sync::Mutex,
8};
9use std::sync::Arc;
10
/// Default soft limit, in bytes, on data a `Sink` may leave buffered before
/// `send` waits for the `Stream` to drain it (64 KiB).
const DEFAULT_BUFFER_SIZE: usize = 1 << 16;
14
15pub struct Channel {
17 buffer: BytesMut,
19
20 waiter: Option<(usize, oneshot::Sender<Bytes>)>,
24
25 buffer_size: usize,
28
29 drain_waiter: Option<oneshot::Sender<()>>,
32
33 sink_alive: bool,
35
36 stream_alive: bool,
38}
39
40impl Channel {
41 pub fn init() -> (Sink, Stream) {
43 Self::init_with_buffer_size(DEFAULT_BUFFER_SIZE)
44 }
45
46 pub fn init_with_buffer_size(buffer_size: usize) -> (Sink, Stream) {
48 let channel = Arc::new(Mutex::new(Self {
49 buffer: BytesMut::new(),
50 waiter: None,
51 buffer_size,
52 drain_waiter: None,
53 sink_alive: true,
54 stream_alive: true,
55 }));
56 (
57 Sink {
58 channel: channel.clone(),
59 state: SinkState::Open,
60 },
61 Stream {
62 channel,
63 buffer: BytesMut::new(),
64 poisoned: false,
65 },
66 )
67 }
68
69 fn restore_front(&mut self, data: Bytes) {
71 if data.is_empty() {
72 return;
73 }
74
75 let mut restored = BytesMut::with_capacity(data.len() + self.buffer.len());
76 restored.extend_from_slice(&data);
77 restored.extend_from_slice(&self.buffer);
78 self.buffer = restored;
79 }
80
81 fn close_sink(&mut self) {
83 self.sink_alive = false;
84
85 self.waiter.take();
87 }
88}
89
90struct RecvWaiterGuard {
91 channel: Arc<Mutex<Channel>>,
92 active: bool,
93}
94
95impl RecvWaiterGuard {
96 const fn new(channel: Arc<Mutex<Channel>>) -> Self {
97 Self {
98 channel,
99 active: true,
100 }
101 }
102
103 const fn disarm(&mut self) {
104 self.active = false;
105 }
106}
107
108impl Drop for RecvWaiterGuard {
109 fn drop(&mut self) {
110 if !self.active {
111 return;
112 }
113
114 self.channel.lock().waiter.take();
115 }
116}
117
118pub struct Sink {
120 channel: Arc<Mutex<Channel>>,
121 state: SinkState,
122}
123
124enum SinkState {
126 Open,
128 Sending,
130 Closed,
132}
133
134impl Sink {
135 fn close(&mut self) {
136 if matches!(self.state, SinkState::Closed) {
137 return;
138 }
139 self.channel.lock().close_sink();
140 self.state = SinkState::Closed;
141 }
142}
143
144impl crate::Sink for Sink {
145 async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
146 match self.state {
147 SinkState::Open => {}
148 SinkState::Sending => {
149 self.close();
150 return Err(Error::Closed);
151 }
152 SinkState::Closed => return Err(Error::Closed),
153 }
154
155 let drain_recv = {
156 let mut channel = self.channel.lock();
157
158 if !channel.stream_alive {
160 channel.close_sink();
161 self.state = SinkState::Closed;
162 return Err(Error::SendFailed);
163 }
164
165 channel.buffer.put(bufs.into());
166
167 if channel
170 .waiter
171 .as_ref()
172 .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
173 {
174 let (requested, os_send) = channel.waiter.take().unwrap();
176 let send_amount = channel.buffer.len().min(requested.max(channel.buffer_size));
177 let data = channel.buffer.split_to(send_amount).freeze();
178
179 if let Err(data) = os_send.send(data) {
182 channel.restore_front(data);
183 if !channel.stream_alive {
184 channel.close_sink();
185 self.state = SinkState::Closed;
186 return Err(Error::SendFailed);
187 }
188 }
189 }
190
191 if channel.buffer.len() > channel.buffer_size {
194 assert!(channel.drain_waiter.is_none());
195 let (os_send, os_recv) = oneshot::channel();
196 channel.drain_waiter = Some(os_send);
197 os_recv
198 } else {
199 return Ok(());
200 }
201 };
202
203 self.state = SinkState::Sending;
206
207 match drain_recv.await {
209 Ok(()) => {
210 self.state = SinkState::Open;
211 Ok(())
212 }
213 Err(_) => {
214 self.close();
215 Err(Error::SendFailed)
216 }
217 }
218 }
219}
220
221impl Drop for Sink {
222 fn drop(&mut self) {
223 self.close();
224 }
225}
226
227pub struct Stream {
229 channel: Arc<Mutex<Channel>>,
230 buffer: BytesMut,
232 poisoned: bool,
233}
234
235impl crate::Stream for Stream {
236 async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
237 if self.poisoned {
238 return Err(Error::Closed);
239 }
240
241 let os_recv = {
242 let mut channel = self.channel.lock();
243
244 let target = len.max(channel.buffer_size);
246 let pull_amount = channel
247 .buffer
248 .len()
249 .min(target.saturating_sub(self.buffer.len()));
250 if pull_amount > 0 {
251 let data = channel.buffer.split_to(pull_amount);
252 self.buffer.extend_from_slice(&data);
253
254 if channel.buffer.len() <= channel.buffer_size {
256 if let Some(sender) = channel.drain_waiter.take() {
257 sender.send_lossy(());
258 }
259 }
260 }
261
262 if self.buffer.len() >= len {
264 return Ok(IoBufs::from(self.buffer.split_to(len).freeze()));
265 }
266
267 if !channel.sink_alive {
269 self.poisoned = true;
270 return Err(Error::RecvFailed);
271 }
272
273 let remaining = len - self.buffer.len();
275 assert!(channel.waiter.is_none());
276 let (os_send, os_recv) = oneshot::channel();
277 channel.waiter = Some((remaining, os_send));
278 os_recv
279 };
280
281 let mut waiter_guard = RecvWaiterGuard::new(self.channel.clone());
282
283 self.poisoned = true;
285
286 let data = match os_recv.await {
288 Ok(data) => {
289 waiter_guard.disarm();
290 self.poisoned = false;
291 data
292 }
293 Err(_) => {
294 waiter_guard.disarm();
295 return Err(Error::RecvFailed);
296 }
297 };
298 self.buffer.extend_from_slice(&data);
299
300 assert!(self.buffer.len() >= len);
301 Ok(IoBufs::from(self.buffer.split_to(len).freeze()))
302 }
303
304 fn peek(&self, max_len: usize) -> &[u8] {
305 let len = max_len.min(self.buffer.len());
306 &self.buffer[..len]
307 }
308}
309
310impl Drop for Stream {
311 fn drop(&mut self) {
312 let mut channel = self.channel.lock();
313 channel.stream_alive = false;
314
315 channel.drain_waiter.take();
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Sink, Spawner, Stream, Supervisor as _};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let received = stream.recv(data.len()).await.unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b" worl");
            let received = stream.recv(1).await.unwrap();
            assert_eq!(received.coalesce(), b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // The receive is polled first and waits. The send runs after a
            // simulated delay. The runtime's clock is used instead of a
            // blocking std::thread::sleep, which would stall the executor
            // thread.
            let (received, _) = futures::try_join!(stream.recv(data.len()), async {
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::RecvFailed)));
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    context.sleep(Duration::from_millis(50)).await;
                    // Dropping the sink wakes the waiting receive with an error.
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::RecvFailed)));
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Well under the default buffer limit, so this completes at once.
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            // The first recv succeeds. The second waits because only 2 bytes
            // remain.
            let handle = context.child("recv").spawn(|_| async move {
                let _ = stream.recv(5).await;
                let _ = stream.recv(5).await;
            });

            context.sleep(Duration::from_millis(50)).await;

            // Aborting the task drops the stream it owns.
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::SendFailed)));
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::SendFailed)));
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        // Nothing is ever sent, so the timeout branch must win.
        executor.start(|context| async move {
            select! {
                v = stream.recv(5) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => "timeout",
            };
        });
    }

    #[test]
    fn test_peek_empty() {
        let (_sink, stream) = Channel::init();

        assert!(stream.peek(10).is_empty());
    }

    #[test]
    fn test_peek_after_partial_recv() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"hello world".as_slice()).await.unwrap();

            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");

            // The rest was pulled into the local buffer and is visible.
            assert_eq!(stream.peek(100), b" world");

            // peek honours max_len.
            assert_eq!(stream.peek(3), b" wo");

            // peek does not consume.
            assert_eq!(stream.peek(100), b" world");

            let received = stream.recv(6).await.unwrap();
            assert_eq!(received.coalesce(), b" world");

            assert!(stream.peek(100).is_empty());
        });
    }

    #[test]
    fn test_peek_after_recv_wakeup() {
        let (mut sink, mut stream) = Channel::init_with_buffer_size(64);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (tx, rx) = oneshot::channel();

            let recv_handle = context.child("recv").spawn(|_| async move {
                let data = stream.recv(3).await.unwrap();
                tx.send(stream).ok();
                data
            });

            // Let the receiver start waiting before anything is sent.
            context.sleep(Duration::from_millis(10)).await;

            sink.send(b"ABCDEFGHIJ".as_slice()).await.unwrap();

            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");

            // The surplus handed over with the wakeup stays in the local buffer.
            let stream = rx.await.unwrap();
            assert_eq!(stream.peek(100), b"DEFGHIJ");
        });
    }

    #[test]
    fn test_peek_multiple_sends() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"aaa".as_slice()).await.unwrap();
            sink.send(b"bbb".as_slice()).await.unwrap();
            sink.send(b"ccc".as_slice()).await.unwrap();

            let received = stream.recv(4).await.unwrap();
            assert_eq!(received.coalesce(), b"aaab");

            assert_eq!(stream.peek(100), b"bbccc");
        });
    }

    #[test]
    fn test_buffer_size_limit() {
        let (mut sink, mut stream) = Channel::init_with_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // 16 bytes exceed the 10-byte limit, so this send waits until the
            // stream drains.
            let send_handle = context.child("sender").spawn(|_| async move {
                sink.send(b"0123456789ABCDEF".as_slice()).await.unwrap();
                sink
            });

            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"01");

            // The recv pulled up to buffer_size (10) bytes into the local buffer.
            assert_eq!(stream.peek(100), b"23456789");

            let received = stream.recv(8).await.unwrap();
            assert_eq!(received.coalesce(), b"23456789");

            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"AB");

            assert_eq!(stream.peek(100), b"CDEF");

            send_handle.await.unwrap();
        });
    }

    #[test]
    fn test_recv_before_send() {
        let (mut sink, mut stream) = Channel::init_with_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let recv_handle = context
                .child("recv")
                .spawn(|_| async move { stream.recv(3).await.unwrap() });

            // Let the receiver start waiting before anything is sent.
            context.sleep(Duration::from_millis(10)).await;

            sink.send(b"ABCDEFGHIJKLMNOP".as_slice()).await.unwrap();

            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
        });
    }
}