connection_utils/connections/transport_connection.rs

1use std::{sync::Arc, collections::HashMap, pin::Pin, fmt};
2
3use cs_utils::random_number;
4use cs_trace::{Tracer, child};
5use tokio_util::codec::Framed;
6use anyhow::{Result, anyhow, bail};
7use serde::{Serialize, Deserialize};
8use futures::{StreamExt, stream::{SplitStream, SplitSink}, SinkExt};
9use tokio::{sync::{mpsc::{Sender, self}, Mutex, Notify, watch, RwLock}, io::{split, duplex, WriteHalf, AsyncReadExt, ReadHalf, AsyncWriteExt}};
10
11use crate::{Channel, create_framed_stream, TransportChannel, codecs::GenericCodec};
12
13type TChannels = Arc<Mutex<HashMap<u16, Arc<Mutex<(WriteHalf<Box<dyn Channel>>, watch::Receiver<bool>)>>>>>;
14type TChannelId = u16;
15
16#[derive(Serialize, Deserialize, Debug)]
17pub enum ControlMessage {
18    Data(TChannelId, Vec<u8>),
19    OpenChannel(TChannelId, String, u32, bool),
20    Close(TChannelId),
21    Error(TChannelId, String),
22}
23
24impl fmt::Display for ControlMessage {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        match self {
27            ControlMessage::Data(id, data) => {
28                return f.debug_tuple("ControlMessage::Data")
29                    .field(id)
30                    .field(&data.len())
31                    .finish();
32            },
33            ControlMessage::OpenChannel(id, label, buffer_size, is_response) => {
34                return f.debug_tuple("ControlMessage::OpenChannel")
35                    .field(id)
36                    .field(label)
37                    .field(buffer_size)
38                    .field(is_response)
39                    .finish();
40            },
41            ControlMessage::Close(id) => {
42                return f.debug_tuple("ControlMessage::Close")
43                    .field(id)
44                    .finish();
45            },
46            ControlMessage::Error(id, message) => {
47                return f.debug_tuple("ControlMessage::Error")
48                    .field(id)
49                    .field(message)
50                    .finish();
51            },
52        };
53    }
54}
55
56/// Send error to the remote side so the remote side can close the channel.
57async fn send_error(
58    trace: &Box<dyn Tracer>,
59    id: TChannelId,
60    message: String,
61    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
62) {
63    trace.error(
64        &format!("Channel {id} read error: {:?}", message),
65    );
66
67    let result = message_sender.lock().await
68        .send(ControlMessage::Error(id, message)).await;
69
70    if let Err(error) = result {
71        trace.error(
72            &format!("Failed to send channel error to the remote side: {:?}", error),
73        );
74    }
75}
76
77/// Read from a local binary channel and send the read data to the remote side
78/// over the control channel stream.
79async fn forward_channel_data(
80    trace: Box<dyn Tracer>,
81    id: u16,
82    mut reader: ReadHalf<Box<dyn Channel>>,
83    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
84    on_close: watch::Receiver<bool>,
85    channels: TChannels,
86    buffer_size: u32,
87) {
88    let mut buf = vec![];
89    buf.resize(buffer_size as usize, 0);
90
91    loop {
92        let is_closed = *on_close.borrow();
93
94        let bytes_read = match reader.read(buf.as_mut_slice()).await {
95            Ok(number) => number,
96            Err(error) => {
97                send_error(&trace, id, format!("{error}"), message_sender).await;
98                return;
99            },
100        };
101
102        let data = (&buf[..bytes_read]).to_vec();
103        let result = {
104            message_sender
105                .lock().await
106                .send(ControlMessage::Data(id, data)).await
107        };
108
109        if let Err(error) = result {
110            send_error(&trace, id, format!("{}", error), message_sender).await;
111            return;
112        };
113
114        if bytes_read == 0 {
115            trace.warn(
116                &format!("got EOF, sending channel close message"),
117            );
118
119            let close_message_result = {
120                message_sender.lock().await
121                    .send(ControlMessage::Close(id)).await
122            };
123
124            if let Err(error) = close_channel(id, channels).await {
125                trace.error(
126                    &format!("failed to close local channel: {:?}", error),
127                );
128            };
129
130            trace.info(
131                &format!("channel is closed by EOF"),
132            );
133
134            if let Err(error) = close_message_result {
135                send_error(&trace, id, format!("{}", error), message_sender).await;
136                return;
137            };
138
139            return;
140        }
141
142        if is_closed {
143            trace.info("channel is closed by notification");
144
145            return;
146        }
147    }
148}
149
150/// Send channel data from the control stream to the local binary channel.
151async fn send_channel_data(
152    trace: Box<dyn Tracer>,
153    id: u16,
154    data: Vec<u8>,
155    channels: TChannels,
156    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
157) {
158    let channel = {
159        let lock = channels.lock().await;
160
161        let channel = match lock.get(&id) {
162            Some(writer) => writer,
163            None => {
164                send_error(&trace, id, format!("No channel with ID {:?} found.", id), message_sender).await;
165
166                return;
167            },
168        };
169
170        Arc::clone(channel)
171    };
172
173    let (writer, on_close) = &mut *channel.lock().await;
174
175    let is_closed = *on_close.borrow();
176
177    if data.len() == 0 && is_closed {
178        trace.warn(
179            &format!("channel {id} already closed, skip writing"),
180        );
181
182        return;
183    }
184
185    if let Err(error) = writer.write_all(&data[..]).await {
186        send_error(&trace, id, format!("{}", error), Arc::clone(&message_sender)).await;
187    }
188}
189
190/// Close a local binary channel and remove its reference from the channels map.
191async fn close_channel(
192    id: u16,
193    channels: TChannels,
194) -> Result<()> {
195    let mut lock = channels.lock().await;
196
197    let channel = {
198        let channel = match lock.remove(&id) {
199            Some(writer) => writer,
200            None => bail!("No channel found with ID {}.", id),
201        };
202
203        channel
204    };
205
206    let (writer, on_close) = &mut *channel.lock().await;
207
208    if *on_close.borrow() {
209        // already closed
210        return Ok(());
211    }
212
213    writer.shutdown().await?;
214
215    return Ok(());
216}
217
218/// Add a local binary channel and start data forwarding job.
219async fn add_local_channel(
220    trace: Box<dyn Tracer>,
221    id: u16,
222    label: String,
223    buffer_size: u32,
224    channels: TChannels,
225    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
226) -> Result<Box<dyn Channel>> {
227    let (duplex1, duplex2) = duplex(buffer_size as usize);
228    let (channel1, channel2) = TransportChannel::new_pair(
229        id,
230        label.clone(),
231        (Box::new(duplex1), Box::new(duplex2)),
232        buffer_size,
233    );
234
235    let on_close1 = channel1.on_close();
236    let on_close2 = channel1.on_close();
237
238    let (reader, writer) = split(channel1);
239
240    let trace2 = &trace;
241    let trace2 = child!(trace2, "forward-channel-data");
242
243    tokio::spawn(forward_channel_data(
244        trace2,
245        id,
246        reader,
247        Arc::clone(&message_sender),
248        on_close1,
249        Arc::clone(&channels),
250        buffer_size,
251    ));
252
253    channels
254        .lock().await
255        .insert(id, Arc::new(Mutex::new((writer, on_close2))));
256
257    trace.info(
258        &format!("local channel opened: {}, {}", id, label),
259    );
260
261    return Ok(channel2);
262}
263
264/// Open a transport channel either from a local `channel()` request or a remote one.
265async fn open_channel(
266    trace: Box<dyn Tracer>,
267    id: u16,
268    label: String,
269    buffer_size: u32,
270    is_response: bool,
271    channels: TChannels,
272    open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
273    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
274    on_remote_channel: mpsc::Sender<Box<dyn Channel>>,
275) -> Result<()> {
276    if is_response {
277        let read_lock = open_channel_requests.read().await;
278        let notify = match read_lock.get(&id) {
279            None => bail!("No open channel notifier found."),
280            Some(notify) => notify,
281        };
282
283        notify.notify_waiters();
284
285        return Ok(());
286    }
287
288    let trace1 = &trace;
289    let trace1 = child!(trace1, "add-local-channel");
290
291    trace.trace("sending open channel response");
292
293    let channel = add_local_channel(
294        trace1,
295        id,
296        label.clone(),
297        buffer_size,
298        channels,
299        Arc::clone(&message_sender),
300    ).await?;
301
302    {
303        message_sender
304            .lock().await
305            .send(ControlMessage::OpenChannel(id, label.clone(), buffer_size, true)).await?;
306    }
307
308    trace.trace("sent");
309
310    on_remote_channel
311        .send(channel).await
312        .map_err(|error| {
313            return anyhow!("{}", error);
314        })?;
315
316
317
318    return Ok(());
319}
320
321/// Receive the control stream messages and invoke an appropriate handler for a message.
322async fn handle_control_messages(
323    trace: Box<dyn Tracer>,
324    mut stream_source: SplitStream<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>>,
325    channels: TChannels,
326    open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
327    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
328    on_remote_channel: mpsc::Sender<Box<dyn Channel>>,
329) -> Result<()> {
330    while let Some(maybe_message) = stream_source.next().await {
331        let message = maybe_message?;
332
333        trace.warn(
334            &format!("got control message: {}", message),
335        );
336
337        match message {
338            ControlMessage::Data(id, data) => {
339                let trace = &trace;
340                let trace = child!(trace, "send-channel-data");
341
342                tokio::spawn(send_channel_data(
343                    trace,
344                    id,
345                    data,
346                    Arc::clone(&channels),
347                    Arc::clone(&message_sender),
348                ));
349            },
350            ControlMessage::OpenChannel(id, label, buffer_size, is_response) => {
351                let trace = &trace;
352                let trace = child!(trace, "open-channel");
353
354                open_channel(
355                    trace,
356                    id,
357                    label,
358                    buffer_size,
359                    is_response,
360                    Arc::clone(&channels),
361                    Arc::clone(&open_channel_requests),
362                    Arc::clone(&message_sender),
363                    Sender::clone(&on_remote_channel),
364                ).await?;
365            },
366            ControlMessage::Close(id) => {
367                tokio::spawn(close_channel(id, Arc::clone(&channels)));
368            },
369            ControlMessage::Error(id, message) => {
370                trace.error(
371                    &format!("remote channel {id} error: {:?}", message),
372                );
373
374                tokio::spawn(close_channel(id, Arc::clone(&channels)));
375            },
376        };
377    }
378
379    return Ok(());
380}
381
382pub struct TransportConnection {
383    trace: Box<dyn Tracer>,
384    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
385    open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
386    channels: TChannels,
387    on_remote_channel: Option<mpsc::Receiver<Box<dyn Channel>>>,
388}
389
390impl TransportConnection {
391    pub fn new(
392        trace: &Box<dyn Tracer>,
393        channel: Box<dyn Channel>,
394    ) -> Box<TransportConnection> {
395        let trace = child!(trace, "transport-channel");
396
397        let stream = create_framed_stream(channel);
398
399        let (channel_sink, channel_source) = stream.split();
400
401        let message_sender = Arc::new(Mutex::new(channel_sink));
402
403        let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
404        let channels = Arc::new(Mutex::new(HashMap::new()));
405        
406        let (on_remote_channel_sender, on_remote_channel) = mpsc::channel(25);
407
408        let trace2 = &trace;
409        let trace2 = child!(trace2, "control-messages-handler");
410        
411        tokio::spawn(handle_control_messages(
412            trace2,
413            channel_source,
414            Arc::clone(&channels),
415            Arc::clone(&open_channel_requests),
416            Arc::clone(&message_sender),
417            on_remote_channel_sender,
418        ));
419
420        return Box::new(TransportConnection {
421            trace,
422            message_sender,
423            open_channel_requests,
424            channels,
425            on_remote_channel: Some(on_remote_channel),
426        });
427    }
428
429    pub fn on_remote_channel(&mut self) -> Result<mpsc::Receiver<Box<dyn Channel>>> {
430        match self.on_remote_channel.take() {
431            Some(on_remote_channel) => return Ok(on_remote_channel),
432            None => bail!("No on_remote_channel found."),
433        };
434    }
435
436    pub fn off_remote_channel(
437        &mut self,
438        on_channel: mpsc::Receiver<Box<dyn Channel>>,
439    ) -> Result<()> {
440        if let Some(_) = self.on_remote_channel {
441            bail!("on_remote_channel already set.");
442        }
443
444        self.on_remote_channel.replace(on_channel);
445        return Ok(());
446    }
447
448    pub async fn channel(
449        &mut self,
450        label: impl AsRef<str> + ToString,
451        buffer_size: u32,
452    ) -> Result<Box<dyn Channel>> {
453        let id = random_number(0..=u16::MAX);
454        let label = label.to_string();
455
456        self.trace.trace(
457            &format!("creating channel, ID: {}, label: {}", id, label),
458        );
459
460        let notify = Arc::new(Notify::new());
461
462        {
463            self.open_channel_requests
464                .write().await
465                .insert(id, Arc::clone(&notify));
466        }
467
468        self.trace.trace(
469            &format!("sending open channel request"),
470        );
471
472        {
473            self.message_sender
474                .lock().await
475                .send(ControlMessage::OpenChannel(id, label.clone(), buffer_size, false)).await?;
476        }
477
478        self.trace.trace(
479            &format!("open channel request sent"),
480        );
481
482        notify.notified().await;
483
484        self.trace.trace(
485            &format!("got open channel response"),
486        );
487
488        let trace2 = &self.trace;
489        let trace2 = child!(trace2, "add-local-channel");
490
491        let channel = add_local_channel(
492            trace2,
493            id,
494            label,
495            buffer_size,
496            Arc::clone(&self.channels),
497            Arc::clone(&self.message_sender),
498        ).await?;
499
500        self.trace.trace(
501            &format!("channel created: {}, {}", channel.id(), channel.label()),
502        );
503
504        return Ok(channel);
505    }
506}
507
508#[cfg(test)]
509mod tests {
510    use std::{collections::HashMap, sync::Arc};
511
512    use rstest::rstest;
513    use futures::StreamExt;
514    use cs_trace::create_trace;
515    use cs_utils::{random_str, random_str_rg, random_number, traits::Random, futures::wait_random};
516    use tokio::{sync::{Mutex, mpsc, watch, RwLock, Notify}, io::{split, AsyncWriteExt, AsyncReadExt}};
517    
518    use crate::test::TestOptions;
519    use crate::create_framed_stream;
520    use crate::{TransportChannel, TransportConnection};
521    use crate::{connections::transport_connection::{send_channel_data, forward_channel_data, close_channel, add_local_channel, open_channel, ControlMessage}, mocks::{ChannelMockOptions, channel_mock_pair}};
522
523    mod send_channel_data {
524        use super::*;
525        
526        #[rstest]
527        #[case(128)]
528        #[case(256)]
529        #[case(512)]
530        #[case(1_024)]
531        #[case(2_048)]
532        #[case(4_096)]
533        #[tokio::test]
534        async fn sends_data_to_channel(
535            #[case] data_len: usize,
536        ) {
537            let trace = create_trace!("test");
538
539            let buffer_size = 4_096;
540            let id = random_number(0..=u16::MAX);
541            let data = random_str(data_len).as_bytes().to_vec();
542
543            let options1 = ChannelMockOptions::random()
544                .with_buffer_size(buffer_size);
545            let options2 = ChannelMockOptions::random()
546                .with_buffer_size(buffer_size);
547
548            let (channel1, mut channel2) = channel_mock_pair(options1, options2);
549
550            let on_close = channel1.on_close();
551
552            let (_reader, writer) = split(channel1);
553
554            let mut channels = HashMap::new();
555            channels.insert(id, Arc::new(Mutex::new((writer, on_close))));
556
557            let channels = Arc::new(Mutex::new(channels));
558
559            let data_to_send = data.clone();
560            let data_to_receive = data.clone();
561
562            let options1 = ChannelMockOptions::random()
563                .with_buffer_size(buffer_size);
564            let options2 = ChannelMockOptions::random()
565                .with_buffer_size(buffer_size);
566
567            let (stream1, _stream2) = TransportChannel::new_pair(
568                id,
569                "control-stream",
570                channel_mock_pair(options1, options2),
571                buffer_size,
572            );
573
574            let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
575            let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
576
577            tokio::join!(
578                Box::pin(async move {
579                    wait_random(1..=5).await;
580                    
581                    send_channel_data(
582                        trace,
583                        id,
584                        data_to_send,
585                        channels,
586                        control_channel_sender,
587                    ).await;
588                }),
589                Box::pin(async move {
590                    wait_random(1..=5).await;
591
592                    let mut buf = [0; 4_096];
593
594                    let bytes_read = channel2
595                        .read(&mut buf).await
596                        .unwrap();
597
598                    let received_data = &buf[..bytes_read];
599
600                    assert_eq!(
601                        received_data,
602                        &data_to_receive[..],
603                        "Must receive correct data.",
604                    );
605                }),
606            );
607        }
608
609        #[rstest]
610        #[tokio::test]
611        async fn does_not_send_if_already_closed() {
612            let trace = create_trace!("test");
613
614            let buffer_size = 4_096;
615            let id = random_number(0..=u16::MAX);
616
617            let options1 = ChannelMockOptions::random()
618                .with_buffer_size(buffer_size);
619            let options2 = ChannelMockOptions::random()
620                .with_buffer_size(buffer_size);
621
622            let (mut channel1, mut channel2) = channel_mock_pair(options1, options2);
623
624            let on_close = channel1.on_close();
625
626            channel1
627                .shutdown().await
628                .unwrap();
629
630            let (_reader, writer) = split(channel1);
631
632            let mut channels = HashMap::new();
633            channels.insert(id, Arc::new(Mutex::new((writer, on_close))));
634
635            let channels = Arc::new(Mutex::new(channels));
636
637            let options1 = ChannelMockOptions::random()
638                .with_buffer_size(buffer_size);
639            let options2 = ChannelMockOptions::random()
640                .with_buffer_size(buffer_size);
641
642            let (stream1, _stream2) = TransportChannel::new_pair(
643                id,
644                "control-stream",
645                channel_mock_pair(options1, options2),
646                buffer_size,
647            );
648
649            let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
650            let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
651
652            tokio::join!(
653                Box::pin(async move {
654                    wait_random(1..=5).await;
655                    
656                    send_channel_data(
657                        trace,
658                        id,
659                        vec![],
660                        channels,
661                        control_channel_sender,
662                    ).await;
663                }),
664                Box::pin(async move {
665                    wait_random(1..=5).await;
666
667                    let mut buf = [0; 4_096];
668
669                    let bytes_read = channel2
670                        .read(&mut buf).await
671                        .unwrap();
672
673                    assert_eq!(
674                        bytes_read,
675                        0,
676                        "Must 0 bytes.",
677                    );
678                }),
679            );
680        }
681
682        #[rstest]
683        #[case(128)]
684        #[case(256)]
685        #[case(512)]
686        #[case(1_024)]
687        #[case(2_048)]
688        #[case(4_096)]
689        #[tokio::test]
690        async fn fails_if_no_channel_found(
691            #[case] data_len: usize,
692        ) {
693            let trace = create_trace!("test");
694
695            let buffer_size = 4_096;
696            let id = random_number(0..=u16::MAX);
697            let data = random_str(data_len).as_bytes().to_vec();
698
699            let options1 = ChannelMockOptions::random()
700                .with_buffer_size(buffer_size);
701            let options2 = ChannelMockOptions::random()
702                .with_buffer_size(buffer_size);
703
704            let (channel1, _channel2) = channel_mock_pair(options1, options2);
705
706            let on_close = channel1.on_close();
707
708            let (_reader, writer) = split(channel1);
709
710            let mut channels = HashMap::new();
711
712            let another_id = {
713                let mut another_id = random_number(0..=u16::MAX);
714
715                while another_id == id {
716                    another_id = random_number(0..=u16::MAX);
717                }
718
719                another_id
720            };
721
722            channels.insert(another_id, Arc::new(Mutex::new((writer, on_close))));
723
724            let channels = Arc::new(Mutex::new(channels));
725
726            let options1 = ChannelMockOptions::random()
727                .with_buffer_size(buffer_size);
728            let options2 = ChannelMockOptions::random()
729                .with_buffer_size(buffer_size);
730
731            let (stream1, stream2) = TransportChannel::new_pair(
732                id,
733                "control-stream",
734                channel_mock_pair(options1, options2),
735                buffer_size,
736            );
737
738            let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
739            let (_stream2_tx, mut stream2_rx) = create_framed_stream(stream2).split();
740
741            tokio::try_join!(
742                tokio::spawn(send_channel_data(
743                    trace,
744                    id,
745                    data,
746                    channels,
747                    Arc::new(Mutex::new(stream1_tx),
748                ))),
749                tokio::spawn(async move {
750                    let message = stream2_rx.next().await.unwrap().unwrap();
751
752                    match message {
753                        ControlMessage::Error(received_id, error_message) => {
754                            assert_eq!(
755                                received_id,
756                                id,
757                                "Must receive error with correct id.",
758                            );
759
760                            assert!(
761                                error_message.len() > 3,
762                                "Received error message must be not empty.",
763                            );
764                        },
765                        unexpected @ _ => panic!("Unexpected message: {:?}.", unexpected),
766                    };
767                }),
768            ).unwrap();
769        }
770    }
771
772    mod handle_channel_reads {
773        use crate::TransportChannel;
774
775        use super::*;
776        
777        #[rstest]
778        #[case(512)]
779        #[case(1_024)]
780        #[case(2_048)]
781        #[case(4_096)]
782        #[case(8_192)]
783        #[case(16_384)]
784        #[tokio::test]
785        async fn reads_from_a_local_channel(
786            #[case] data_len: usize,
787        ) {
788            let trace = cs_trace::create_trace!("test");
789
790            let buffer_size: u32 = 4_096;
791
792            let id = random_number(0..=u16::MAX);
793            let data = random_str(data_len)
794                .as_bytes().to_vec();
795
796            let options1 = ChannelMockOptions::random()
797                .with_buffer_size(buffer_size);
798            let options2 = ChannelMockOptions::random()
799                .with_buffer_size(buffer_size);
800
801            let (channel1, mut channel2) = TransportChannel::new_pair(
802                id,
803                "transport-channel",
804                channel_mock_pair(options1, options2),
805                buffer_size,
806            );
807
808            let on_close = channel1.on_close();
809
810            let (reader, _writer) = split(channel1);
811
812            let options1 = ChannelMockOptions::random()
813                .with_buffer_size(buffer_size);
814            let options2 = ChannelMockOptions::random()
815                .with_buffer_size(buffer_size);
816
817            let (stream1, stream2) = TransportChannel::new_pair(
818                id,
819                "control-stream",
820                channel_mock_pair(options1, options2),
821                buffer_size,
822            );
823
824            let stream1 = create_framed_stream(stream1);
825            let stream2 = create_framed_stream(stream2);
826
827            let (stream1_tx, _stream1_rx) = stream1.split();
828            let (_stream2_tx, mut control_channel_receiver) = stream2.split();
829
830            let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
831
832            let data_to_send = data.clone();
833            let data_to_receive = data.clone();
834
835            let channels = Arc::new(Mutex::new(HashMap::new()));
836
837            let channels1 = Arc::clone(&channels);
838            let channels2 = Arc::clone(&channels);
839
840            tokio::join!(
841                Box::pin(async move {
842                    wait_random(1..=5).await;
843                    
844                    forward_channel_data(
845                        trace,
846                        id,
847                        reader,
848                        control_channel_sender,
849                        on_close,
850                        channels1,
851                        buffer_size,
852                    ).await;
853                }),
854                Box::pin(async move {
855                    wait_random(1..=5).await;
856
857                    let mut total_written = 0;
858                    while total_written < data_to_send.len() {
859                        let written = channel2
860                            .write(&data_to_send[total_written..]).await
861                            .unwrap();
862
863                        total_written += written;
864                    }
865
866                    assert!(
867                        !channels2.lock().await.contains_key(&id),
868                        "Channel must be deleted.",
869                    );
870                }),
871                Box::pin(async move {
872                    wait_random(1..=5).await;
873
874                    let mut received_data = vec![];
875
876                    while let Some(maybe_message) = control_channel_receiver.next().await {
877                        let message = maybe_message.unwrap();
878                        let (received_id, data) = match message {
879                            ControlMessage::Data(id, data) => (id, data),
880                            ControlMessage::Close(received_id) => {
881                                assert_eq!(
882                                    received_id,
883                                    id,
884                                    "Message must have correct channel ID.",
885                                );
886
887                                break;
888                            },
889                            other @ _ => panic!("Unexpected message: {:?}", other),
890                        };
891
892                        assert_eq!(
893                            received_id,
894                            id,
895                            "Message must have correct channel ID.",
896                        );
897
898                        received_data.extend_from_slice(&data[..]);
899                    }
900
901                    assert_eq!(
902                        received_data,
903                        data_to_receive,
904                        "Must receive correct data.",
905                    );
906                }),
907            );
908        }
909    }
910    
911    mod close_channel {
912        use crate::TransportChannel;
913
914        use super::*;
915        
916        #[rstest]
917        #[case(())]
918        #[case(())]
919        #[case(())]
920        #[case(())]
921        #[case(())]
922        #[case(())]
923        #[tokio::test]
924        async fn shutsdown_a_channel_and_removes_reference(
925            #[case] _case_num: (),
926        ) {
927            let buffer_size = 4_096;
928
929            let id = random_number(0..=u16::MAX);
930            let options1 = ChannelMockOptions::random()
931                .with_buffer_size(buffer_size);
932            let options2 = ChannelMockOptions::random()
933                .with_buffer_size(buffer_size);
934
935            let (channel1, _channel2) = TransportChannel::new_pair(
936                id,
937                "transport-channel",
938                channel_mock_pair(options1, options2),
939                buffer_size,
940            );
941
942            let on_close = channel1.on_close();
943
944            let (_reader, writer) = split(channel1);
945
946            let mut channels = HashMap::new();
947            channels.insert(id, Arc::new(Mutex::new((writer, watch::Receiver::clone(&on_close)))));
948
949            let channels = Arc::new(Mutex::new(channels));
950
951            wait_random(1..=5).await;
952
953            close_channel(
954                id,
955                Arc::clone(&channels),
956            ).await.unwrap();
957
958            {
959                assert!(
960                    !(channels.lock().await.contains_key(&id)),
961                    "Must remove channel reference from the map.",
962                );
963            }
964
965            assert!(
966                *on_close.borrow(),
967                "Must close the channel.",
968            );
969        }
970
971        #[rstest]
972        #[case(())]
973        #[case(())]
974        #[case(())]
975        #[case(())]
976        #[case(())]
977        #[case(())]
978        #[tokio::test]
979        async fn does_not_fails_if_channel_allready_closed(
980            #[case] _case_num: (),
981        ) {
982            let buffer_size = 4_096;
983
984            let id = random_number(0..=u16::MAX);
985            let options1 = ChannelMockOptions::random()
986                .with_buffer_size(buffer_size);
987            let options2 = ChannelMockOptions::random()
988                .with_buffer_size(buffer_size);
989
990            let (channel1, _channel2) = TransportChannel::new_pair(
991                id,
992                "transport-channel",
993                channel_mock_pair(options1, options2),
994                buffer_size,
995            );
996
997            let on_close = channel1.on_close();
998
999            let (_reader, mut writer) = split(channel1);
1000
1001            let mut channels = HashMap::new();
1002
1003            writer.shutdown().await
1004                .unwrap();
1005
1006            assert!(
1007                *on_close.borrow(),
1008                "Must close the channel.",
1009            );
1010
1011            channels.insert(id, Arc::new(Mutex::new((writer, watch::Receiver::clone(&on_close)))));
1012
1013            assert!(
1014                channels.contains_key(&id),
1015                "Must contain channel before test",
1016            );
1017
1018            let channels = Arc::new(Mutex::new(channels));
1019
1020            wait_random(1..=5).await;
1021
1022            close_channel(
1023                id,
1024                Arc::clone(&channels),
1025            ).await.unwrap();
1026
1027            {
1028                assert!(
1029                    !(channels.lock().await.contains_key(&id)),
1030                    "Must remove channel reference from the map.",
1031                );
1032            }
1033
1034            assert!(
1035                *on_close.borrow(),
1036                "Must close the channel.",
1037            );
1038        }
1039
1040        #[rstest]
1041        #[case(())]
1042        #[case(())]
1043        #[case(())]
1044        #[case(())]
1045        #[case(())]
1046        #[case(())]
1047        #[tokio::test]
1048        #[should_panic]
1049        async fn fails_if_no_channel_found(
1050            #[case] _case_num: (),
1051        ) {
1052            let id = random_number(0..=u16::MAX);
1053    
1054            let channels = HashMap::new();
1055            let channels = Arc::new(Mutex::new(channels));
1056
1057            wait_random(1..=5).await;
1058
1059            close_channel(
1060                id,
1061                Arc::clone(&channels),
1062            ).await.unwrap();
1063        }
1064    }
1065
1066    mod add_local_channel {
1067        use cs_trace::create_trace;
1068
1069        use crate::create_framed_stream;
1070
1071        use super::*;
1072        
1073        #[tokio::test]
1074        async fn adds_channel_to_channels_map() {
1075            let trace = create_trace!("test");
1076
1077            let buffer_size: u32 = 4_096;
1078            let id = random_number(0..=u16::MAX);
1079            let label = random_str_rg(8..=16);
1080
1081            let channels = HashMap::new();
1082            let channels = Arc::new(Mutex::new(channels));
1083
1084            {
1085                assert!(
1086                    !(channels.lock().await).contains_key(&id),
1087                    "Must not contain channel before the test.",
1088                );
1089            }
1090
1091            let options1 = ChannelMockOptions::random()
1092                .with_buffer_size(buffer_size);
1093            let options2 = ChannelMockOptions::random()
1094                .with_buffer_size(buffer_size);
1095
1096            let (stream1, _stream2) = TransportChannel::new_pair(
1097                id,
1098                "control-stream",
1099                channel_mock_pair(options1, options2),
1100                buffer_size,
1101            );
1102
1103            let stream1 = create_framed_stream(stream1);
1104
1105            let (stream1_tx, _stream1_rx) = stream1.split();
1106
1107            let control_sender = Arc::new(Mutex::new(stream1_tx));
1108
1109            add_local_channel(
1110                trace,
1111                id,
1112                label,
1113                buffer_size,
1114                Arc::clone(&channels),
1115                control_sender,
1116            ).await.unwrap();
1117
1118            {
1119                assert!(
1120                    (channels.lock().await).contains_key(&id),
1121                    "Must add channel to the map.",
1122                );
1123            }
1124        }
1125    }
1126
1127    mod open_channel {
1128        use cs_trace::create_trace;
1129
1130        use crate::create_framed_stream;
1131
1132        use super::*;
1133        
1134        #[tokio::test]
1135        async fn notifies_pending_channel_open_requests() {
1136            let trace = create_trace!("test");
1137
1138            let buffer_size: u32 = 4_096;
1139            let id = random_number(0..=u16::MAX);
1140            let label = random_str_rg(8..=16);
1141
1142            let is_response = true;
1143
1144            let channels = Arc::new(Mutex::new(HashMap::new()));
1145            let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
1146
1147            // let (control_sender, _control_receiver) = mpsc::channel(buffer_size as usize);
1148            let (on_remote_channel, _on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
1149
1150            let channel_open_notification = Arc::new(Notify::new());
1151
1152            {
1153                open_channel_requests.write().await
1154                    .insert(id, Arc::clone(&channel_open_notification));
1155            }
1156
1157            let channels1 = Arc::clone(&channels);
1158
1159            let options1 = ChannelMockOptions::random()
1160                .with_buffer_size(buffer_size);
1161            let options2 = ChannelMockOptions::random()
1162                .with_buffer_size(buffer_size);
1163
1164            let (stream1, _stream2) = TransportChannel::new_pair(
1165                id,
1166                "control-stream",
1167                channel_mock_pair(options1, options2),
1168                buffer_size,
1169            );
1170
1171            let stream1 = create_framed_stream(stream1);
1172            // let stream2 = create_framed_stream(stream2);
1173
1174            let (stream1_tx, _stream1_rx) = stream1.split();
1175            // let (_stream2_tx, mut control_channel_receiver) = stream2.split();
1176
1177            let control_sender = Arc::new(Mutex::new(stream1_tx));
1178
1179            tokio::join!(
1180                Box::pin(async move {
1181                    open_channel(
1182                        trace,
1183                        id,
1184                        label,
1185                        buffer_size,
1186                        is_response,
1187                        channels1,
1188                        Arc::clone(&open_channel_requests),
1189                        control_sender,
1190                        on_remote_channel,
1191                    ).await.unwrap();
1192                }),
1193                Box::pin(channel_open_notification.notified()),
1194            );
1195
1196            assert!(
1197                !(channels.lock().await.contains_key(&id)),
1198                "Must not add channel into the map.",
1199            );
1200        }
1201
1202        #[tokio::test]
1203        async fn fails_if_no_channel_notification_found() {
1204            let trace = create_trace!("test");
1205
1206            let buffer_size: u32 = 4_096;
1207            let id = random_number(0..=u16::MAX);
1208            let label = random_str_rg(8..=16);
1209
1210            let is_response = true;
1211
1212            let channels = Arc::new(Mutex::new(HashMap::new()));
1213            let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
1214
1215            // let (control_sender, _control_receiver) = mpsc::channel(buffer_size as usize);
1216            let (on_remote_channel, _on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
1217
1218            let options1 = ChannelMockOptions::random()
1219                .with_buffer_size(buffer_size);
1220            let options2 = ChannelMockOptions::random()
1221                .with_buffer_size(buffer_size);
1222
1223            let (stream1, _stream2) = TransportChannel::new_pair(
1224                id,
1225                "control-stream",
1226                channel_mock_pair(options1, options2),
1227                buffer_size,
1228            );
1229
1230            let stream1 = create_framed_stream(stream1);
1231            let (stream1_tx, _stream1_rx) = stream1.split();
1232
1233            let control_sender = Arc::new(Mutex::new(stream1_tx));
1234
1235            let result = open_channel(
1236                trace,
1237                id,
1238                label,
1239                buffer_size,
1240                is_response,
1241                Arc::clone(&channels),
1242                Arc::clone(&open_channel_requests),
1243                control_sender,
1244                on_remote_channel,
1245            ).await;
1246
1247            assert!(
1248                result.is_err(),
1249                "Must fail if no channel notification present.",
1250            );
1251
1252            assert!(
1253                !(channels.lock().await.contains_key(&id)),
1254                "Must not add channel into the map.",
1255            );
1256        }
1257
1258        #[tokio::test]
1259        async fn responds_to_channel_open_request() {
1260            let trace = create_trace!("test");
1261
1262            let buffer_size: u32 = 4_096;
1263            let id = random_number(0..=u16::MAX);
1264            let label = random_str_rg(8..=16);
1265
1266            let is_response = false;
1267
1268            let channels = Arc::new(Mutex::new(HashMap::new()));
1269            let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
1270
1271            let (on_remote_channel, mut on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
1272
1273            let options1 = ChannelMockOptions::random()
1274                .with_buffer_size(buffer_size);
1275            let options2 = ChannelMockOptions::random()
1276                .with_buffer_size(buffer_size);
1277
1278            let (stream1, stream2) = TransportChannel::new_pair(
1279                id,
1280                "control-stream",
1281                channel_mock_pair(options1, options2),
1282                buffer_size,
1283            );
1284
1285            let stream1 = create_framed_stream(stream1);
1286            let stream2 = create_framed_stream(stream2);
1287
1288            let (stream1_tx, _stream1_rx) = stream1.split();
1289            let (_stream2_tx, mut control_receiver) = stream2.split();
1290
1291            let control_sender = Arc::new(Mutex::new(stream1_tx));
1292
1293            let channels1 = Arc::clone(&channels);
1294            let channels2 = Arc::clone(&channels);
1295
1296            let label1 = label.clone();
1297            let label2 = label.clone();
1298
1299            tokio::join!(
1300                Box::pin(async move {
1301                    wait_random(1..=5).await;
1302
1303                    open_channel(
1304                        trace,
1305                        id,
1306                        label1,
1307                        buffer_size,
1308                        is_response,
1309                        channels1,
1310                        Arc::clone(&open_channel_requests),
1311                        control_sender,
1312                        on_remote_channel,
1313                    ).await.unwrap();
1314                }),
1315                Box::pin(async move {
1316                    let message = control_receiver.next().await.expect("Stream closed.").unwrap();
1317                    match message {
1318                        ControlMessage::OpenChannel(recv_id, recv_label, recv_buffer_size, recv_is_response) => {
1319                            assert_eq!(
1320                                recv_id,
1321                                id,
1322                                "Must receive correct channel ID.",
1323                            );
1324
1325                            assert_eq!(
1326                                recv_label,
1327                                label2,
1328                                "Must receive correct channel label.",
1329                            );
1330
1331                            assert_eq!(
1332                                recv_buffer_size,
1333                                buffer_size,
1334                                "Must receive correct channel buffer_size.",
1335                            );
1336
1337                            assert!(
1338                                recv_is_response,
1339                                "Must send a response.",
1340                            );
1341                        },
1342                        unexpected @ _ => panic!("Got unexpected control message: {:?}", unexpected),  
1343                    };
1344                }),
1345            );
1346
1347            let _channel = on_remote_channel_receiver
1348                .recv().await
1349                .expect("Must send `on_remote_channel` notification.");
1350
1351            assert!(
1352                (channels2.lock().await.contains_key(&id)),
1353                "Must add channel into the map.",
1354            );
1355        }
1356    }
1357
    mod data_transfer {
        use futures::future;
        use cs_trace::{create_trace, child};

        use super::*;
        use crate::{test::test_stream, Channel};

        /// Open a channel pair (local, remote) on a connection.
        ///
        /// Runs two tasks concurrently:
        /// - `local_connection` creates a new channel, always labelled `"local-channel1"`;
        /// - `remote_connection` waits on its `on_remote_channel` listener for the
        ///   incoming channel.
        ///
        /// The listener is handed back with `off_remote_channel`. Presumably this
        /// lets a later call take it again, since `transfers_data_in_parallel`
        /// calls this helper repeatedly on the same connections.
        ///
        /// Both connections are moved in and returned next to their channel, so
        /// callers keep ownership of them (and can reuse them) after the handshake.
        ///
        /// Returns `[(local_connection, local_channel), (remote_connection, remote_channel)]`.
        ///
        /// # Panics
        ///
        /// Panics in any of these cases:
        /// - the local channel cannot be created;
        /// - `on_remote_channel` fails;
        /// - the listener closes without yielding a channel;
        /// - the listener cannot be handed back.
        async fn open_channel(
            mut local_connection: Box<TransportConnection>,
            mut remote_connection: Box<TransportConnection>,
            buffer_size: u32,
        ) -> [(Box<TransportConnection>, Box<dyn Channel>); 2] {
            let (local, remote) = tokio::join!(
                Box::pin(async move {
                    let local_channel = local_connection.channel("local-channel1", buffer_size).await
                        .expect("Cannot create a channel.");

                    return (local_connection, local_channel);
                }),
                Box::pin(async move {
                    let mut on_remote_channel = remote_connection
                        .on_remote_channel().unwrap();

                    let remote_channel = on_remote_channel
                        .recv().await
                        .expect("Cannot receive a remote channel.");

                    // Return the listener to the connection once one channel has been received.
                    remote_connection.off_remote_channel(on_remote_channel)
                        .expect("Cannot set remote channel listener.");

                    return (remote_connection, remote_channel);
                }),
            );

            return [local, remote];
        }

        /// Opens one channel over a mock transport pair and checks, via
        /// `test_stream`, that `data_len` bytes round-trip between its two ends.
        /// The channel buffer (2 048) is smaller than several of the cases.
        #[rstest]
        #[case(512)]
        #[case(1_024)]
        #[case(2_048)]
        #[case(4_096)]
        #[case(8_192)]
        #[case(16_384)]
        #[tokio::test]
        async fn transfers_data(
            #[case] data_len: usize,
        ) {
            let trace = create_trace!("test");

            let buffer_size: u32 = 2_048;

            let (channel1, channel2) = TransportChannel::new_pair(
                random_number(0..=u16::MAX),
                "transport-channels",
                channel_mock_pair(ChannelMockOptions::random(), ChannelMockOptions::random()),
                buffer_size,
            );

            // Child traces label the two sides of the connection in the trace output.
            let trace1 = &trace;
            let trace1 = child!(trace1, "local");
            
            let trace2 = &trace;
            let trace2 = child!(trace2, "remote");

            let local_connection = TransportConnection::new(&trace1, channel1);
            let remote_connection = TransportConnection::new(&trace2, channel2);

            // The connections are bound to `_`-prefixed names (not plain `_`), so they
            // are not dropped until the end of the test. Presumably dropping them early
            // would tear down the channels — NOTE(review): confirm.
            let [
                (_local_connection, local_channel),
                (_remote_connection, remote_channel),
            ] = open_channel(
                local_connection,
                remote_connection,
                buffer_size,
            ).await;

            test_stream(
                local_channel,
                remote_channel,
                TestOptions::random()
                    .with_data_len(data_len),
            ).await;

        }

        /// Opens a random number (5..=10) of channels over the same connection
        /// pair and runs a `test_stream` transfer on each in its own spawned task,
        /// so the transfers overlap.
        #[rstest]
        #[case(512)]
        #[case(1_024)]
        #[case(2_048)]
        #[case(4_096)]
        #[case(8_192)]
        #[case(16_384)]
        #[tokio::test]
        async fn transfers_data_in_parallel(
            #[case] data_len: usize,
        ) {
            let trace = create_trace!("test");

            let buffer_size: u32 = 2_048;

            let (channel1, channel2) = TransportChannel::new_pair(
                random_number(0..=u16::MAX),
                "transport-channels",
                channel_mock_pair(ChannelMockOptions::random(), ChannelMockOptions::random()),
                buffer_size,
            );

            let trace1 = &trace;
            let trace1 = child!(trace1, "local");
            
            let trace2 = &trace;
            let trace2 = child!(trace2, "remote");

            let mut local_connection = TransportConnection::new(&trace1, channel1);
            let mut remote_connection = TransportConnection::new(&trace2, channel2);

            let mut tasks = vec![];

            for _ in 0..random_number(5..=10) {
                // `open_channel` takes the connections by value and hands them back;
                // rebind them so the next iteration (and the test's end) still owns them.
                let [
                    (local_connection1, local_channel),
                    (remote_connection1, remote_channel),
                ] = open_channel(
                    local_connection,
                    remote_connection,
                    buffer_size,
                ).await;

                local_connection = local_connection1;
                remote_connection = remote_connection1;

                tasks.push(
                    tokio::spawn(test_stream(
                        local_channel,
                        remote_channel,
                        TestOptions::random()
                            .with_data_len(data_len),
                    )),
                );

                // Random pause so the spawned transfers start at staggered times.
                wait_random(0..=50).await;
            }

            // A panic inside any spawned transfer surfaces here as a `JoinError`.
            future::try_join_all(tasks).await
                .unwrap();
        }
    }
1507}