1use crate::{BufMut, Error, IoBufs, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::{Bytes, BytesMut};
5use commonware_utils::channel::oneshot;
6use std::sync::{Arc, Mutex};
7
/// Read buffer size (64 KiB) that `Channel::init` passes to
/// `Channel::init_with_read_buffer_size`.
const DEFAULT_READ_BUFFER_SIZE: usize = 64 * 1024;
10
/// Shared state behind a mock [`Sink`]/[`Stream`] pair.
///
/// Both halves hold the same `Arc<Mutex<Channel>>` and coordinate through it.
pub struct Channel {
    /// Bytes written by the sink that have not yet been moved to the stream.
    buffer: BytesMut,

    /// Set while the stream is waiting inside `recv`: the number of bytes it still
    /// needs, plus the oneshot sender the sink uses to hand those bytes over.
    waiter: Option<(usize, oneshot::Sender<Bytes>)>,

    /// Target number of bytes moved to the stream in one go; a larger request
    /// takes precedence over this value.
    read_buffer_size: usize,

    /// `true` until the [`Sink`] is dropped.
    sink_alive: bool,

    /// `true` until the [`Stream`] is dropped.
    stream_alive: bool,
}
30
31impl Channel {
32 pub fn init() -> (Sink, Stream) {
34 Self::init_with_read_buffer_size(DEFAULT_READ_BUFFER_SIZE)
35 }
36
37 pub fn init_with_read_buffer_size(read_buffer_size: usize) -> (Sink, Stream) {
39 let channel = Arc::new(Mutex::new(Self {
40 buffer: BytesMut::new(),
41 waiter: None,
42 read_buffer_size,
43 sink_alive: true,
44 stream_alive: true,
45 }));
46 (
47 Sink {
48 channel: channel.clone(),
49 },
50 Stream {
51 channel,
52 buffer: BytesMut::new(),
53 },
54 )
55 }
56}
57
/// Sending half of a mock channel created by [`Channel::init`].
pub struct Sink {
    /// State shared with the paired [`Stream`].
    channel: Arc<Mutex<Channel>>,
}
62
63impl SinkTrait for Sink {
64 async fn send(&mut self, buf: impl Into<IoBufs> + Send) -> Result<(), Error> {
65 let (os_send, data) = {
66 let mut channel = self.channel.lock().unwrap();
67
68 if !channel.stream_alive {
70 return Err(Error::Closed);
71 }
72
73 channel.buffer.put(buf.into());
74
75 if channel
79 .waiter
80 .as_ref()
81 .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
82 {
83 let (requested, os_send) = channel.waiter.take().unwrap();
85 let send_amount = channel
86 .buffer
87 .len()
88 .min(requested.max(channel.read_buffer_size));
89 let data = channel.buffer.split_to(send_amount).freeze();
90 (os_send, data)
91 } else {
92 return Ok(());
93 }
94 };
95
96 os_send.send(data).map_err(|_| Error::SendFailed)?;
98 Ok(())
99 }
100}
101
102impl Drop for Sink {
103 fn drop(&mut self) {
104 let mut channel = self.channel.lock().unwrap();
105 channel.sink_alive = false;
106
107 channel.waiter.take();
109 }
110}
111
/// Receiving half of a mock channel created by [`Channel::init`].
pub struct Stream {
    /// State shared with the paired [`Sink`].
    channel: Arc<Mutex<Channel>>,
    /// Bytes already moved out of the shared channel but not yet returned by `recv`;
    /// this is what `peek` exposes.
    buffer: BytesMut,
}
118
119impl StreamTrait for Stream {
120 async fn recv(&mut self, len: u64) -> Result<IoBufs, Error> {
121 let len = len as usize;
122
123 let os_recv = {
124 let mut channel = self.channel.lock().unwrap();
125
126 if !channel.buffer.is_empty() {
128 let target = len.max(channel.read_buffer_size);
129 let pull_amount = channel
130 .buffer
131 .len()
132 .min(target.saturating_sub(self.buffer.len()));
133 if pull_amount > 0 {
134 let data = channel.buffer.split_to(pull_amount);
135 self.buffer.extend_from_slice(&data);
136 }
137 }
138
139 if self.buffer.len() >= len {
141 return Ok(IoBufs::from(self.buffer.split_to(len).freeze()));
142 }
143
144 if !channel.sink_alive {
146 return Err(Error::Closed);
147 }
148
149 let remaining = len - self.buffer.len();
151 assert!(channel.waiter.is_none());
152 let (os_send, os_recv) = oneshot::channel();
153 channel.waiter = Some((remaining, os_send));
154 os_recv
155 };
156
157 let data = os_recv.await.map_err(|_| Error::Closed)?;
159 self.buffer.extend_from_slice(&data);
160
161 assert!(self.buffer.len() >= len);
162 Ok(IoBufs::from(self.buffer.split_to(len).freeze()))
163 }
164
165 fn peek(&self, max_len: u64) -> &[u8] {
166 let len = (max_len as usize).min(self.buffer.len());
167 &self.buffer[..len]
168 }
169}
170
171impl Drop for Stream {
172 fn drop(&mut self) {
173 let mut channel = self.channel.lock().unwrap();
174 channel.stream_alive = false;
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    /// A single send followed by a recv of the same length round-trips the bytes.
    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let received = stream.recv(data.len() as u64).await.unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    /// Reads may split and span the boundaries of individual sends.
    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b" worl");
            let received = stream.recv(1).await.unwrap();
            assert_eq!(received.coalesce(), b"d");
        });
    }

    /// A recv that starts before the matching send completes once the data arrives.
    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (received, _) = futures::try_join!(stream.recv(data.len() as u64), async {
                // Use the runtime clock rather than blocking the executor thread.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    /// Dropping the sink while a recv is pending fails that recv with `Closed`.
    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    // Wait for the stream to start waiting, then drop the sink.
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    /// A recv issued after the sink was dropped fails with `Closed`.
    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// Sending after the stream was dropped (by aborting its task) fails with `Closed`.
    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Send some bytes.
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            // Spawn a task whose first recv succeeds and whose second recv waits.
            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(5).await;
                let _ = stream.recv(5).await;
            });

            // Give the spawned task a moment to start.
            context.sleep(Duration::from_millis(50)).await;

            // Drop the stream by aborting the task that owns it.
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            // The stream is gone, so sending must fail.
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// Sending when the stream was dropped up front fails with `Closed`.
    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// With nothing to read, recv keeps waiting and the timeout branch wins.
    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(5) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => "timeout",
            };
        });
    }

    /// Peeking a fresh stream yields an empty slice.
    #[test]
    fn test_peek_empty() {
        let (_sink, stream) = Channel::init();

        assert!(stream.peek(10).is_empty());
    }

    /// After a partial recv, peek exposes the leftover bytes without consuming them.
    #[test]
    fn test_peek_after_partial_recv() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send more data than the first recv consumes.
            sink.send(b"hello world".as_slice()).await.unwrap();

            // Consume only part of it.
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");

            // Peek shows the remainder.
            assert_eq!(stream.peek(100), b" world");

            // A smaller max_len truncates the view.
            assert_eq!(stream.peek(3), b" wo");

            // Peeking does not consume.
            assert_eq!(stream.peek(100), b" world");

            // Consume the rest.
            let received = stream.recv(6).await.unwrap();
            assert_eq!(received.coalesce(), b" world");

            // Nothing is left to peek.
            assert!(stream.peek(100).is_empty());
        });
    }

    /// Bytes delivered to a waiting recv beyond its request remain visible to peek.
    #[test]
    fn test_peek_after_recv_wakeup() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(64);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Start a recv that has to wait, and get the stream back once it is done.
            let (tx, rx) = oneshot::channel();
            let recv_handle = context.clone().spawn(|_| async move {
                let data = stream.recv(3).await.unwrap();
                tx.send(stream).ok();
                data
            });

            // Let the recv task register as a waiter.
            context.sleep(Duration::from_millis(10)).await;

            // Send more than was requested.
            sink.send(b"ABCDEFGHIJ".as_slice()).await.unwrap();

            // The recv gets exactly its 3 bytes.
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");

            // The rest was delivered along with them and is now buffered locally.
            let stream = rx.await.unwrap();
            assert_eq!(stream.peek(100), b"DEFGHIJ");
        });
    }

    /// Several sends accumulate; a recv across them leaves the rest peekable.
    #[test]
    fn test_peek_multiple_sends() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"aaa".as_slice()).await.unwrap();
            sink.send(b"bbb".as_slice()).await.unwrap();
            sink.send(b"ccc".as_slice()).await.unwrap();

            // Consume part of the accumulated data.
            let received = stream.recv(4).await.unwrap();
            assert_eq!(received.coalesce(), b"aaab");

            // The remainder is buffered locally.
            assert_eq!(stream.peek(100), b"bbccc");
        });
    }

    /// A recv pulls at most `read_buffer_size` bytes into the local buffer.
    #[test]
    fn test_read_buffer_size_limit() {
        // Use a small read buffer size.
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send more data than the read buffer size.
            sink.send(b"0123456789ABCDEF".as_slice()).await.unwrap();

            // A small recv pulls only 10 bytes locally and returns 2 of them.
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"01");

            // 8 bytes remain buffered locally (not all 14 unread bytes).
            assert_eq!(stream.peek(100), b"23456789");

            // Consuming the buffered bytes tops the local buffer up to 10 again.
            let received = stream.recv(8).await.unwrap();
            assert_eq!(received.coalesce(), b"23456789");

            // The next recv pulls the rest from the channel.
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"AB");

            assert_eq!(stream.peek(100), b"CDEF");
        });
    }

    /// A recv that waits before the send still receives exactly what it asked for.
    #[test]
    fn test_recv_before_send() {
        // Use a small read buffer size.
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Start the recv before anything is sent, so it has to wait.
            let recv_handle = context
                .clone()
                .spawn(|_| async move { stream.recv(3).await.unwrap() });

            // Let the recv task register as a waiter.
            context.sleep(Duration::from_millis(10)).await;

            // Send more than the read buffer size.
            sink.send(b"ABCDEFGHIJKLMNOP".as_slice()).await.unwrap();

            // The recv gets its 3 bytes.
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
        });
    }
}