1use std::{pin::Pin, task::{Context, Poll, Waker}, io, fmt};
2
3use futures::ready;
4use tokio::{io::{AsyncRead, AsyncWrite, ReadBuf}, sync::watch};
5
6use crate::Channel;
7
8pub struct TransportChannel<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> {
9 id: u16,
10 label: String,
11 channel: Pin<Box<TAsyncDuplex>>,
12 is_closed: bool,
13 is_read_closed: bool,
14 is_shutdown_requested: bool,
15 read_waker: Option<Waker>,
16 self_closed: watch::Receiver<bool>,
17 remote_closed: watch::Receiver<bool>,
18 local_closed: watch::Sender<bool>,
19 buffer_size: u32,
20}
21
22impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> TransportChannel<TAsyncDuplex> {
23 pub fn new_pair(
24 id: u16,
25 label: impl AsRef<str> + ToString,
26 channels: (Box<TAsyncDuplex>, Box<TAsyncDuplex>),
27 buffer_size: u32,
28 ) -> (Box<dyn Channel>, Box<dyn Channel>) {
29 let (channel1, channel2) = channels;
30
31 let (local_closed1, remote_closed1) = watch::channel(false);
32 let (local_closed2, remote_closed2) = watch::channel(false);
33
34 let label = label.to_string();
35 let label1 = format!("{label}-1");
36 let label2 = format!("{label}-2");
37
38 let self_closed1 = remote_closed1.clone();
39 let self_closed2 = remote_closed2.clone();
40
41 let channel1 = Box::new(
42 TransportChannel {
43 id,
44 label: label1,
45 channel: Pin::new(channel1),
46 is_closed: false,
47 is_read_closed: false,
48 is_shutdown_requested: false,
49 read_waker: None,
50 self_closed: self_closed1,
51 remote_closed: remote_closed2,
52 local_closed: local_closed1,
53 buffer_size,
54 },
55 );
56
57 let channel2 = Box::new(
58 TransportChannel {
59 id,
60 label: label2,
61 channel: Pin::new(channel2),
62 is_closed: false,
63 is_read_closed: false,
64 is_shutdown_requested: false,
65 read_waker: None,
66 self_closed: self_closed2,
67 remote_closed: remote_closed1,
68 local_closed: local_closed2,
69 buffer_size
70 },
71 );
72
73 return (channel1, channel2)
74 }
75
76 fn is_remote_closed(&self) -> bool {
77 return *self.remote_closed.borrow();
78 }
79}
80
81impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> Channel for TransportChannel<TAsyncDuplex> {
82 fn id(&self) -> u16 {
83 return self.id;
84 }
85
86 fn label(&self) -> &String {
87 return &self.label;
88 }
89
90 fn is_closed(&self) -> bool {
91 return self.is_closed;
92 }
93
94 fn on_close(&self) -> watch::Receiver<bool> {
95 return self.self_closed.clone();
96 }
97
98 fn buffer_size(&self) -> u32 {
99 return self.buffer_size;
100 }
101}
102
103impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> AsyncRead for TransportChannel<TAsyncDuplex> {
104 fn poll_read(
105 mut self: Pin<&mut Self>,
106 cx: &mut Context<'_>,
107 buf: &mut ReadBuf<'_>,
108 ) -> Poll<io::Result<()>> {
109 if self.is_shutdown_requested && !self.is_closed {
110 let result = ready!(self.as_mut().poll_shutdown(cx));
111
112 self.is_read_closed = true;
113
114 return Poll::Ready(result);
115 }
116
117 if self.is_closed && self.is_read_closed {
119 return Poll::Ready(Ok(()));
120 }
121
122 let filled_before = buf.filled().len();
123
124 let result = self.channel.as_mut().poll_read(cx, buf);
126
127 let bytes_read = buf.filled().len() - filled_before;
128
129 if self.is_closed && !self.is_read_closed {
131 self.is_read_closed = true;
132
133 return Poll::Ready(Ok(()));
134 }
135
136 if result.is_pending() {
138 self.read_waker.replace(cx.waker().clone());
141 } else {
142 self.read_waker.take();
144
145 if self.is_remote_closed() {
146 self.is_shutdown_requested = true;
147
148 if bytes_read == 0 {
150 return self.poll_shutdown(cx);
151 }
152 }
153 }
154
155 return result;
156 }
157}
158
159impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> AsyncWrite for TransportChannel<TAsyncDuplex> {
160 fn poll_write(
161 mut self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 buf: &[u8],
164 ) -> Poll<io::Result<usize>> {
165 if self.is_remote_closed() {
166 return Poll::Ready(Ok(0));
167 }
168
169 let result = self.channel.as_mut()
170 .poll_write(cx, buf);
171
172 return result;
173 }
174
175 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176 return self.channel.as_mut()
177 .poll_flush(cx);
178 }
179
180 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181 if self.is_closed {
182 return Poll::Ready(Ok(()));
183 }
184
185 let result = ready!(self.channel.as_mut().poll_shutdown(cx));
187
188 self.is_closed = true;
189
190 let _res = self.local_closed.send(true);
192
193 if let Some(waker) = self.read_waker.take() {
196 waker.wake();
197 }
198
199
200 return Poll::Ready(result);
201 }
202}
203
204impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> fmt::Debug for TransportChannel<TAsyncDuplex> {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 return self.debug("TransportChannel", f);
207 }
208}
209
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use futures::{SinkExt, StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use cs_utils::{traits::Random, futures::wait_random, test::random_vec, random_number, random_str_rg};

    use super::TransportChannel;
    use crate::create_framed_stream;
    use crate::mocks::{channel_mock_pair, ChannelMockOptions};
    use crate::test::{test_framed_stream, test_async_stream, TestOptions, TestStreamMessage};

    /// Creates a connected `TransportChannel` pair on top of randomly configured mock channels.
    fn create_channel_pair() -> (Box<dyn crate::Channel>, Box<dyn crate::Channel>) {
        let (channel1, channel2) = channel_mock_pair(
            ChannelMockOptions::random(),
            ChannelMockOptions::random(),
        );

        TransportChannel::new_pair(
            1,
            "in-memory-channel-1",
            (Box::new(channel1), Box::new(channel2)),
            4_096,
        )
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[case(32_768)]
    #[tokio::test]
    async fn transfers_binary_data(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;
    }

    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn transfers_stream_data(
        #[case] items_count: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
        let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);

        test_framed_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(items_count),
        ).await;
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_to_end_if_self_shutdown(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let (channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        wait_random(25..=50).await;

        let test_data = random_str_rg(8..=32);

        // `write_all`, not `write`: a short write would silently send fewer bytes
        // than the assertion below expects.
        channel2.write_all(test_data.as_bytes()).await.unwrap();

        let (mut source, mut sink) = tokio::io::split(channel1);

        tokio::join!(
            Box::pin(async move {
                wait_random(0..=5).await;

                let mut buf = vec![];

                let bytes_read = source.read_to_end(&mut buf).await
                    .expect("Cannot read to end.");

                assert_eq!(
                    bytes_read,
                    test_data.len(),
                    "Closed channel must read {} bytes.",
                    test_data.len(),
                );
            }),
            Box::pin(async move {
                wait_random(0..=5).await;

                sink.shutdown().await.unwrap();
            }),
        );

        assert!(!channel2.is_closed(), "Channel2 must not be closed.");
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_if_self_shutdown(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let (channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        wait_random(25..=50).await;

        let test_data = random_str_rg(8..=32);

        channel2.write_all(test_data.as_bytes()).await.unwrap();

        let (mut source, mut sink) = tokio::io::split(channel1);

        tokio::join!(
            Box::pin(async move {
                wait_random(0..=5).await;

                let mut buf = [0; 1024];

                let bytes_read = source.read(&mut buf).await
                    .expect("Cannot read to end.");

                assert_eq!(
                    bytes_read,
                    test_data.len(),
                    "Closed channel must read {} bytes.",
                    test_data.len(),
                );
            }),
            Box::pin(async move {
                wait_random(0..=5).await;

                sink.shutdown().await.unwrap();
            }),
        );

        assert!(!channel2.is_closed(), "Channel2 must not be closed.");
    }

    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn closes_stream_if_self_is_closed(
        #[case] items_count: u32,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
        let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);

        let (channel1, mut channel2) = test_framed_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(10),
        ).await;

        let (mut sink, mut source) = channel1.split();

        let messages_to_send = random_vec::<TestStreamMessage>(items_count);

        tokio::join!(
            Box::pin(async move {
                // The test passes as long as the stream terminates after `sink.close()`.
                while source.next().await.is_some() {}
            }),
            Box::pin(async move {
                for message in messages_to_send {
                    channel2.send(message).await.unwrap();
                }

                sink.close().await.unwrap();
            }),
        );
    }

    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn closes_stream_if_remote_counterpart_is_closed(
        #[case] items_count: u32,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
        let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);

        let (mut channel1, mut channel2) = test_framed_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(10),
        ).await;

        let messages_to_send = random_vec::<TestStreamMessage>(items_count);

        tokio::join!(
            Box::pin(async move {
                // Drain until the stream ends because the remote counterpart closed.
                while channel1.next().await.is_some() {}

                assert!(channel1.get_ref().is_closed(), "Channel must be closed.");
            }),
            Box::pin(async move {
                for message in messages_to_send {
                    channel2.send(message).await.unwrap();
                }

                channel2.close().await.unwrap();
            }),
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_to_end_if_remote_counterpart_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let (mut channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        let test_data = random_str_rg(8..=32);

        channel2.write_all(test_data.as_bytes()).await.unwrap();

        tokio::join!(
            Box::pin(async move {
                wait_random(0..=5).await;

                let mut buf = vec![];

                let bytes_read = channel1.read_to_end(&mut buf).await
                    .expect("Cannot read to end.");

                assert_eq!(
                    bytes_read,
                    test_data.len(),
                    "Closed channel must read {} bytes.",
                    test_data.len(),
                );

                assert!(
                    channel1.is_closed(),
                    "Channel must be closed after remote counterpart is closed.",
                );
            }),
            Box::pin(async move {
                wait_random(0..=5).await;

                channel2.shutdown().await.unwrap();
            }),
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_if_remote_counterpart_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let (mut channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        let test_data = random_str_rg(8..=32);

        channel2.write_all(test_data.as_bytes()).await.unwrap();

        channel2.shutdown().await.unwrap();

        assert!(
            channel2.is_closed(),
            "Channel2 must be closed.",
        );

        wait_random(3..=5).await;

        let mut buf = [0; 1024];

        let bytes_read = channel1.read(&mut buf).await
            .expect("Cannot read to end.");

        assert_eq!(
            bytes_read,
            test_data.len(),
            "Closed channel must read {} bytes.",
            test_data.len(),
        );

        let bytes_read = channel1.read(&mut buf).await
            .expect("Cannot read to end.");

        assert_eq!(
            bytes_read,
            0,
            "Closed channel must read 0 bytes.",
        );

        assert!(
            channel1.is_closed(),
            "Channel must be closed after remote counterpart is closed.",
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn fails_to_write_if_remote_counterpart_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let (mut channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        channel2.shutdown().await.unwrap();

        // A single `write` is intentional here: the test probes one write call.
        assert!(
            channel2.write(b"anything").await.is_err(),
            "Must fail to write to closed channel.",
        );

        assert!(
            channel2.is_closed(),
            "Channel2 must be closed.",
        );

        wait_random(3..=5).await;

        let test_data = random_str_rg(24..=32);
        let bytes_written = channel1.write(test_data.as_bytes()).await
            .expect("Cannot write to channel.");

        assert_eq!(
            bytes_written,
            0,
            "Must write 0 bytes if remote channel is closed.",
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn fails_to_write_if_self_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = create_channel_pair();

        let (channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        let (mut source, mut sink) = tokio::io::split(channel1);

        let test_data = random_str_rg(24..=32);

        channel2.write_all(test_data.as_bytes()).await
            .expect("Cannot write data.");

        sink.shutdown().await.unwrap();

        assert!(
            sink.write(b"something").await.is_err(),
            "Must fail to write to closed channel.",
        );

        let mut buf = vec![];
        let bytes_received = source.read_to_end(&mut buf).await
            .expect("Cannot read data.");

        assert_eq!(
            bytes_received,
            test_data.len(),
            "Must be able to read to end if channel is closed.",
        );

        let channel1 = source.unsplit(sink);

        assert!(
            channel1.is_closed(),
            "Channel must be closed.",
        );
    }
}