connection_utils/channels/
transport_channel.rs

1use std::{pin::Pin, task::{Context, Poll, Waker}, io, fmt};
2
3use futures::ready;
4use tokio::{io::{AsyncRead, AsyncWrite, ReadBuf}, sync::watch};
5
6use crate::Channel;
7
8pub struct TransportChannel<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> {
9    id: u16,
10    label: String,
11    channel: Pin<Box<TAsyncDuplex>>,
12    is_closed: bool,
13    is_read_closed: bool,
14    is_shutdown_requested: bool,
15    read_waker: Option<Waker>,
16    self_closed: watch::Receiver<bool>,
17    remote_closed: watch::Receiver<bool>,
18    local_closed: watch::Sender<bool>,
19    buffer_size: u32,
20}
21
22impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> TransportChannel<TAsyncDuplex> {
23    pub fn new_pair(
24        id: u16,
25        label: impl AsRef<str> + ToString,
26        channels: (Box<TAsyncDuplex>, Box<TAsyncDuplex>),
27        buffer_size: u32,
28    ) -> (Box<dyn Channel>, Box<dyn Channel>) {
29        let (channel1, channel2) = channels;
30
31        let (local_closed1, remote_closed1) = watch::channel(false);
32        let (local_closed2, remote_closed2) = watch::channel(false);
33    
34        let label = label.to_string();
35        let label1 = format!("{label}-1");
36        let label2 = format!("{label}-2");
37
38        let self_closed1 = remote_closed1.clone();
39        let self_closed2 = remote_closed2.clone();
40    
41        let channel1 = Box::new(
42            TransportChannel {
43                id,
44                label: label1,
45                channel: Pin::new(channel1),
46                is_closed: false,
47                is_read_closed: false,
48                is_shutdown_requested: false,
49                read_waker: None,
50                self_closed: self_closed1,
51                remote_closed: remote_closed2,
52                local_closed: local_closed1,
53                buffer_size,
54            },
55        );
56    
57        let channel2 = Box::new(
58            TransportChannel {
59                id,
60                label: label2,
61                channel: Pin::new(channel2),
62                is_closed: false,
63                is_read_closed: false,
64                is_shutdown_requested: false,
65                read_waker: None,
66                self_closed: self_closed2,
67                remote_closed: remote_closed1,
68                local_closed: local_closed2,
69                buffer_size
70            },
71        );
72    
73        return (channel1, channel2)
74    }
75
76    fn is_remote_closed(&self) -> bool {
77        return *self.remote_closed.borrow();
78    }
79}
80
81impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> Channel for TransportChannel<TAsyncDuplex> {
82    fn id(&self) -> u16 {
83        return self.id;
84    }
85    
86    fn label(&self) ->  &String {
87        return &self.label;
88    }
89
90    fn is_closed(&self) ->  bool {
91        return self.is_closed;
92    }
93
94    fn on_close(&self) -> watch::Receiver<bool> {
95        return self.self_closed.clone();
96    }
97
98    fn buffer_size(&self) -> u32 {
99        return self.buffer_size;
100    }
101} 
102
103impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> AsyncRead for TransportChannel<TAsyncDuplex> {
104    fn poll_read(
105        mut self: Pin<&mut Self>,
106        cx: &mut Context<'_>,
107        buf: &mut ReadBuf<'_>,
108    ) -> Poll<io::Result<()>> {
109        if self.is_shutdown_requested && !self.is_closed {
110            let result = ready!(self.as_mut().poll_shutdown(cx));
111
112            self.is_read_closed = true;
113
114            return Poll::Ready(result);
115        }
116
117        // fully closed for reads, return EOF
118        if self.is_closed && self.is_read_closed {
119            return Poll::Ready(Ok(()));
120        }
121        
122        let filled_before = buf.filled().len();
123        
124        // poll underlying channel
125        let result = self.channel.as_mut().poll_read(cx, buf);
126
127        let bytes_read = buf.filled().len() - filled_before;
128
129        // allow for the last read after `shutdown`
130        if self.is_closed && !self.is_read_closed {
131            self.is_read_closed = true;
132
133            return Poll::Ready(Ok(()));
134        }
135
136        // save or remove read waker
137        if result.is_pending() {
138            // save read waker in case we need to shutdown the channel
139            // but someone called `read_to_end` before the channel was closed
140            self.read_waker.replace(cx.waker().clone());
141        } else {
142            // remove read waker
143            self.read_waker.take();
144
145            if self.is_remote_closed() {
146                self.is_shutdown_requested = true;
147
148                // if received EOF, shutdown immediatelly
149                if bytes_read == 0 {
150                    return self.poll_shutdown(cx);
151                }
152            }
153        }
154
155        return result;
156    }
157}
158
159impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> AsyncWrite for TransportChannel<TAsyncDuplex> {
160    fn poll_write(
161        mut self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        buf: &[u8],
164    ) -> Poll<io::Result<usize>> {
165        if self.is_remote_closed() {
166            return Poll::Ready(Ok(0));
167        }
168
169        let result = self.channel.as_mut()
170            .poll_write(cx, buf);
171
172        return result;
173    }
174
175    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176        return self.channel.as_mut()
177            .poll_flush(cx);
178    }
179
180    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181        if self.is_closed {
182            return Poll::Ready(Ok(()));
183        }
184
185        // wait until shut
186        let result = ready!(self.channel.as_mut().poll_shutdown(cx));
187
188        self.is_closed = true;
189
190        // notify remote part about shutdown
191        let _res = self.local_closed.send(true);
192
193        // in some cases, if `read_to_end` was called and yielded a `Poll::Pending` result,
194        // we need wake the `poll_read` again, otherwise the `read_to_end` might never return
195        if let Some(waker) = self.read_waker.take() {
196            waker.wake();
197        }
198
199
200        return Poll::Ready(result);
201    }
202}
203
204impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + ?Sized + 'static> fmt::Debug for TransportChannel<TAsyncDuplex> {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        return self.debug("TransportChannel", f);
207    }
208}
209
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use futures::{SinkExt, StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use cs_utils::{traits::Random, futures::wait_random, test::random_vec, random_number, random_str_rg};

    use super::TransportChannel;
    use crate::create_framed_stream;
    use crate::mocks::{channel_mock_pair, ChannelMockOptions};
    use crate::test::{test_framed_stream, test_async_stream, TestOptions, TestStreamMessage};

    /// Creates a connected pair of `TransportChannel`s over randomly configured
    /// in-memory mock transports (id `1`, buffer size `4_096`).
    fn transport_pair() -> (Box<dyn crate::Channel>, Box<dyn crate::Channel>) {
        let (channel1, channel2) = channel_mock_pair(
            ChannelMockOptions::random(),
            ChannelMockOptions::random(),
        );

        TransportChannel::new_pair(
            1,
            "in-memory-channel-1",
            (Box::new(channel1), Box::new(channel2)),
            4_096,
        )
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[case(32_768)]
    #[tokio::test]
    async fn transfers_binary_data(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;
    }

    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn transfers_stream_data(
        #[case] items_count: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
        let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);

        test_framed_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(items_count),
        ).await;
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_to_end_if_self_shutdown(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        let (channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        wait_random(25..=50).await;

        let test_data = random_str_rg(8..=32);

        // `write_all`, not `write`: a partial write would silently shrink the payload.
        channel2.write_all(test_data.as_bytes()).await.unwrap();

        let (mut source, mut sink) = tokio::io::split(channel1);

        tokio::join!(
            Box::pin(async move {
                wait_random(0..=5).await;

                let mut buf = vec![];

                let bytes_read = source.read_to_end(&mut buf).await
                    .expect("Cannot read to end.");

                assert_eq!(
                    bytes_read,
                    test_data.len(),
                    "Closed channel must read {} bytes.",
                    test_data.len(),
                );
            }),
            Box::pin(async move {
                wait_random(0..=5).await;

                sink.shutdown().await.unwrap();
            }),
        );

        assert!(!channel2.is_closed(), "Channel2 must not be closed.");
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_if_self_shutdown(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        let (channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        wait_random(25..=50).await;

        let test_data = random_str_rg(8..=32);

        channel2.write_all(test_data.as_bytes()).await.unwrap();

        let (mut source, mut sink) = tokio::io::split(channel1);

        tokio::join!(
            Box::pin(async move {
                wait_random(0..=5).await;

                let mut buf = [0; 1024];

                let bytes_read = source.read(&mut buf).await
                    .expect("Cannot read to end.");

                assert_eq!(
                    bytes_read,
                    test_data.len(),
                    "Closed channel must read {} bytes.",
                    test_data.len(),
                );
            }),
            Box::pin(async move {
                wait_random(0..=5).await;

                sink.shutdown().await.unwrap();
            }),
        );

        assert!(!channel2.is_closed(), "Channel2 must not be closed.");
    }

    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn closes_stream_if_self_is_closed(
        #[case] items_count: u32,
    ) {
        let (channel1, channel2) = transport_pair();

        let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
        let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);

        let (channel1, mut channel2) = test_framed_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(10),
        ).await;

        let (mut sink, mut source) = channel1.split();

        let messages_to_send = random_vec::<TestStreamMessage>(items_count);

        tokio::join!(
            Box::pin(async move {
                // Drain until the stream terminates; termination is what is under test.
                while source.next().await.is_some() {}
            }),
            Box::pin(async move {
                for message in messages_to_send {
                    channel2.send(message).await.unwrap();
                }

                sink.close().await.unwrap();
            }),
        );
    }

    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn closes_stream_if_remote_counterpart_is_closed(
        #[case] items_count: u32,
    ) {
        let (channel1, channel2) = transport_pair();

        let channel1 = create_framed_stream::<TestStreamMessage, _>(channel1);
        let channel2 = create_framed_stream::<TestStreamMessage, _>(channel2);

        let (mut channel1, mut channel2) = test_framed_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(10),
        ).await;

        let messages_to_send = random_vec::<TestStreamMessage>(items_count);

        tokio::join!(
            Box::pin(async move {
                // Drain until the stream terminates; termination is what is under test.
                while channel1.next().await.is_some() {}

                assert!(channel1.get_ref().is_closed(), "Channel must be closed.");
            }),
            Box::pin(async move {
                for message in messages_to_send {
                    channel2.send(message).await.unwrap();
                }

                channel2.close().await.unwrap();
            }),
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_to_end_if_remote_counterpart_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        let (mut channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        let test_data = random_str_rg(8..=32);

        channel2.write_all(test_data.as_bytes()).await.unwrap();

        tokio::join!(
            Box::pin(async move {
                wait_random(0..=5).await;

                let mut buf = vec![];

                let bytes_read = channel1.read_to_end(&mut buf).await
                    .expect("Cannot read to end.");

                assert_eq!(
                    bytes_read,
                    test_data.len(),
                    "Closed channel must read {} bytes.",
                    test_data.len(),
                );

                assert!(
                    channel1.is_closed(),
                    "Channel must be closed after remote counterpart is closed.",
                );
            }),
            Box::pin(async move {
                wait_random(0..=5).await;

                channel2.shutdown().await.unwrap();
            }),
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn reads_if_remote_counterpart_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        let (mut channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        let test_data = random_str_rg(8..=32);

        channel2.write_all(test_data.as_bytes()).await.unwrap();

        channel2.shutdown().await.unwrap();

        assert!(
            channel2.is_closed(),
            "Channel2 must be closed.",
        );

        wait_random(3..=5).await;

        let mut buf = [0; 1024];

        let bytes_read = channel1.read(&mut buf).await
            .expect("Cannot read to end.");

        assert_eq!(
            bytes_read,
            test_data.len(),
            "Closed channel must read {} bytes.",
            test_data.len(),
        );

        let bytes_read = channel1.read(&mut buf).await
            .expect("Cannot read to end.");

        assert_eq!(
            bytes_read,
            0,
            "Closed channel must read 0 bytes.",
        );

        assert!(
            channel1.is_closed(),
            "Channel must be closed after remote counterpart is closed.",
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn fails_to_write_if_remote_counterpart_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        let (mut channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        channel2.shutdown().await.unwrap();

        assert!(
            channel2.write(b"anything").await.is_err(),
            "Must fail to write to closed channel.",
        );

        assert!(
            channel2.is_closed(),
            "Channel2 must be closed.",
        );

        wait_random(3..=5).await;

        // Plain `write` on purpose: the `Ok(0)` result is exactly what is asserted.
        let test_data = random_str_rg(24..=32);
        let bytes_written = channel1.write(test_data.as_bytes()).await
            .expect("Cannot write to channel.");

        assert_eq!(
            bytes_written,
            0,
            "Must write 0 bytes if remote channel is closed.",
        );
    }

    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[tokio::test]
    async fn fails_to_write_if_self_is_closed(
        #[case] test_data_size: usize,
    ) {
        let (channel1, channel2) = transport_pair();

        let (channel1, mut channel2) = test_async_stream(
            channel1,
            channel2,
            TestOptions::random()
                .with_data_len(test_data_size),
        ).await;

        let (mut source, mut sink) = tokio::io::split(channel1);

        let test_data = random_str_rg(24..=32);

        channel2.write_all(test_data.as_bytes()).await
            .expect("Cannot write data.");

        sink.shutdown().await.unwrap();

        assert!(
            sink.write(b"something").await.is_err(),
            "Must fail to write to closed channel.",
        );

        let mut buf = vec![];
        let bytes_received = source.read_to_end(&mut buf).await
            .expect("Cannot read data.");

        assert_eq!(
            bytes_received,
            test_data.len(),
            "Must be able to read to end if channel is closed.",
        );

        let channel1 = source.unsplit(sink);

        assert!(
            channel1.is_closed(),
            "Channel must be closed.",
        );
    }
}