1use std::{sync::Arc, collections::HashMap, pin::Pin, fmt};
2
3use cs_utils::random_number;
4use cs_trace::{Tracer, child};
5use tokio_util::codec::Framed;
6use anyhow::{Result, anyhow, bail};
7use serde::{Serialize, Deserialize};
8use futures::{StreamExt, stream::{SplitStream, SplitSink}, SinkExt};
9use tokio::{sync::{mpsc::{Sender, self}, Mutex, Notify, watch, RwLock}, io::{split, duplex, WriteHalf, AsyncReadExt, ReadHalf, AsyncWriteExt}};
10
11use crate::{Channel, create_framed_stream, TransportChannel, codecs::GenericCodec};
12
13type TChannels = Arc<Mutex<HashMap<u16, Arc<Mutex<(WriteHalf<Box<dyn Channel>>, watch::Receiver<bool>)>>>>>;
14type TChannelId = u16;
15
/// Frames exchanged over the shared control stream of a `TransportConnection`.
///
/// NOTE(review): serde encodes enum variants by index in many binary formats —
/// presumably reordering these variants breaks wire compatibility; confirm against
/// `GenericCodec` before changing the order.
#[derive(Serialize, Deserialize, Debug)]
pub enum ControlMessage {
    /// A chunk of bytes read from the sender's local end of channel `id`; the
    /// receiver writes it into its local end of the same channel.
    Data(TChannelId, Vec<u8>),
    /// Channel open handshake: (id, label, buffer size, is_response).
    /// `false` is the request sent by `TransportConnection::channel`; `true` is the
    /// reply sent back once the remote side has created its end.
    OpenChannel(TChannelId, String, u32, bool),
    /// Sent when the sender's local end of channel `id` reached EOF; the receiver
    /// closes its local end.
    Close(TChannelId),
    /// Error report for channel `id`; the receiver logs it and closes its local end.
    Error(TChannelId, String),
}
23
24impl fmt::Display for ControlMessage {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 match self {
27 ControlMessage::Data(id, data) => {
28 return f.debug_tuple("ControlMessage::Data")
29 .field(id)
30 .field(&data.len())
31 .finish();
32 },
33 ControlMessage::OpenChannel(id, label, buffer_size, is_response) => {
34 return f.debug_tuple("ControlMessage::OpenChannel")
35 .field(id)
36 .field(label)
37 .field(buffer_size)
38 .field(is_response)
39 .finish();
40 },
41 ControlMessage::Close(id) => {
42 return f.debug_tuple("ControlMessage::Close")
43 .field(id)
44 .finish();
45 },
46 ControlMessage::Error(id, message) => {
47 return f.debug_tuple("ControlMessage::Error")
48 .field(id)
49 .field(message)
50 .finish();
51 },
52 };
53 }
54}
55
56async fn send_error(
58 trace: &Box<dyn Tracer>,
59 id: TChannelId,
60 message: String,
61 message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
62) {
63 trace.error(
64 &format!("Channel {id} read error: {:?}", message),
65 );
66
67 let result = message_sender.lock().await
68 .send(ControlMessage::Error(id, message)).await;
69
70 if let Err(error) = result {
71 trace.error(
72 &format!("Failed to send channel error to the remote side: {:?}", error),
73 );
74 }
75}
76
77async fn forward_channel_data(
80 trace: Box<dyn Tracer>,
81 id: u16,
82 mut reader: ReadHalf<Box<dyn Channel>>,
83 message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
84 on_close: watch::Receiver<bool>,
85 channels: TChannels,
86 buffer_size: u32,
87) {
88 let mut buf = vec![];
89 buf.resize(buffer_size as usize, 0);
90
91 loop {
92 let is_closed = *on_close.borrow();
93
94 let bytes_read = match reader.read(buf.as_mut_slice()).await {
95 Ok(number) => number,
96 Err(error) => {
97 send_error(&trace, id, format!("{error}"), message_sender).await;
98 return;
99 },
100 };
101
102 let data = (&buf[..bytes_read]).to_vec();
103 let result = {
104 message_sender
105 .lock().await
106 .send(ControlMessage::Data(id, data)).await
107 };
108
109 if let Err(error) = result {
110 send_error(&trace, id, format!("{}", error), message_sender).await;
111 return;
112 };
113
114 if bytes_read == 0 {
115 trace.warn(
116 &format!("got EOF, sending channel close message"),
117 );
118
119 let close_message_result = {
120 message_sender.lock().await
121 .send(ControlMessage::Close(id)).await
122 };
123
124 if let Err(error) = close_channel(id, channels).await {
125 trace.error(
126 &format!("failed to close local channel: {:?}", error),
127 );
128 };
129
130 trace.info(
131 &format!("channel is closed by EOF"),
132 );
133
134 if let Err(error) = close_message_result {
135 send_error(&trace, id, format!("{}", error), message_sender).await;
136 return;
137 };
138
139 return;
140 }
141
142 if is_closed {
143 trace.info("channel is closed by notification");
144
145 return;
146 }
147 }
148}
149
150async fn send_channel_data(
152 trace: Box<dyn Tracer>,
153 id: u16,
154 data: Vec<u8>,
155 channels: TChannels,
156 message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
157) {
158 let channel = {
159 let lock = channels.lock().await;
160
161 let channel = match lock.get(&id) {
162 Some(writer) => writer,
163 None => {
164 send_error(&trace, id, format!("No channel with ID {:?} found.", id), message_sender).await;
165
166 return;
167 },
168 };
169
170 Arc::clone(channel)
171 };
172
173 let (writer, on_close) = &mut *channel.lock().await;
174
175 let is_closed = *on_close.borrow();
176
177 if data.len() == 0 && is_closed {
178 trace.warn(
179 &format!("channel {id} already closed, skip writing"),
180 );
181
182 return;
183 }
184
185 if let Err(error) = writer.write_all(&data[..]).await {
186 send_error(&trace, id, format!("{}", error), Arc::clone(&message_sender)).await;
187 }
188}
189
190async fn close_channel(
192 id: u16,
193 channels: TChannels,
194) -> Result<()> {
195 let mut lock = channels.lock().await;
196
197 let channel = {
198 let channel = match lock.remove(&id) {
199 Some(writer) => writer,
200 None => bail!("No channel found with ID {}.", id),
201 };
202
203 channel
204 };
205
206 let (writer, on_close) = &mut *channel.lock().await;
207
208 if *on_close.borrow() {
209 return Ok(());
211 }
212
213 writer.shutdown().await?;
214
215 return Ok(());
216}
217
218async fn add_local_channel(
220 trace: Box<dyn Tracer>,
221 id: u16,
222 label: String,
223 buffer_size: u32,
224 channels: TChannels,
225 message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
226) -> Result<Box<dyn Channel>> {
227 let (duplex1, duplex2) = duplex(buffer_size as usize);
228 let (channel1, channel2) = TransportChannel::new_pair(
229 id,
230 label.clone(),
231 (Box::new(duplex1), Box::new(duplex2)),
232 buffer_size,
233 );
234
235 let on_close1 = channel1.on_close();
236 let on_close2 = channel1.on_close();
237
238 let (reader, writer) = split(channel1);
239
240 let trace2 = &trace;
241 let trace2 = child!(trace2, "forward-channel-data");
242
243 tokio::spawn(forward_channel_data(
244 trace2,
245 id,
246 reader,
247 Arc::clone(&message_sender),
248 on_close1,
249 Arc::clone(&channels),
250 buffer_size,
251 ));
252
253 channels
254 .lock().await
255 .insert(id, Arc::new(Mutex::new((writer, on_close2))));
256
257 trace.info(
258 &format!("local channel opened: {}, {}", id, label),
259 );
260
261 return Ok(channel2);
262}
263
264async fn open_channel(
266 trace: Box<dyn Tracer>,
267 id: u16,
268 label: String,
269 buffer_size: u32,
270 is_response: bool,
271 channels: TChannels,
272 open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
273 message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
274 on_remote_channel: mpsc::Sender<Box<dyn Channel>>,
275) -> Result<()> {
276 if is_response {
277 let read_lock = open_channel_requests.read().await;
278 let notify = match read_lock.get(&id) {
279 None => bail!("No open channel notifier found."),
280 Some(notify) => notify,
281 };
282
283 notify.notify_waiters();
284
285 return Ok(());
286 }
287
288 let trace1 = &trace;
289 let trace1 = child!(trace1, "add-local-channel");
290
291 trace.trace("sending open channel response");
292
293 let channel = add_local_channel(
294 trace1,
295 id,
296 label.clone(),
297 buffer_size,
298 channels,
299 Arc::clone(&message_sender),
300 ).await?;
301
302 {
303 message_sender
304 .lock().await
305 .send(ControlMessage::OpenChannel(id, label.clone(), buffer_size, true)).await?;
306 }
307
308 trace.trace("sent");
309
310 on_remote_channel
311 .send(channel).await
312 .map_err(|error| {
313 return anyhow!("{}", error);
314 })?;
315
316
317
318 return Ok(());
319}
320
321async fn handle_control_messages(
323 trace: Box<dyn Tracer>,
324 mut stream_source: SplitStream<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>>,
325 channels: TChannels,
326 open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
327 message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
328 on_remote_channel: mpsc::Sender<Box<dyn Channel>>,
329) -> Result<()> {
330 while let Some(maybe_message) = stream_source.next().await {
331 let message = maybe_message?;
332
333 trace.warn(
334 &format!("got control message: {}", message),
335 );
336
337 match message {
338 ControlMessage::Data(id, data) => {
339 let trace = &trace;
340 let trace = child!(trace, "send-channel-data");
341
342 tokio::spawn(send_channel_data(
343 trace,
344 id,
345 data,
346 Arc::clone(&channels),
347 Arc::clone(&message_sender),
348 ));
349 },
350 ControlMessage::OpenChannel(id, label, buffer_size, is_response) => {
351 let trace = &trace;
352 let trace = child!(trace, "open-channel");
353
354 open_channel(
355 trace,
356 id,
357 label,
358 buffer_size,
359 is_response,
360 Arc::clone(&channels),
361 Arc::clone(&open_channel_requests),
362 Arc::clone(&message_sender),
363 Sender::clone(&on_remote_channel),
364 ).await?;
365 },
366 ControlMessage::Close(id) => {
367 tokio::spawn(close_channel(id, Arc::clone(&channels)));
368 },
369 ControlMessage::Error(id, message) => {
370 trace.error(
371 &format!("remote channel {id} error: {:?}", message),
372 );
373
374 tokio::spawn(close_channel(id, Arc::clone(&channels)));
375 },
376 };
377 }
378
379 return Ok(());
380}
381
/// Multiplexes many logical channels over a single underlying `Channel`, using
/// `ControlMessage` frames on a shared control stream.
pub struct TransportConnection {
    trace: Box<dyn Tracer>,
    /// Sink half of the control stream; every outgoing control frame goes through this lock.
    message_sender: Arc<Mutex<SplitSink<Framed<Pin<Box<dyn Channel>>, GenericCodec<ControlMessage>>, ControlMessage>>>,
    /// Locally initiated open requests awaiting the remote response, keyed by channel ID.
    open_channel_requests: Arc<RwLock<HashMap<u16, Arc<Notify>>>>,
    /// Local channel ends, shared with the control-message handler task.
    channels: TChannels,
    /// Receiver for channels opened by the remote side; `None` while handed out via `on_remote_channel()`.
    on_remote_channel: Option<mpsc::Receiver<Box<dyn Channel>>>,
}
389
390impl TransportConnection {
391 pub fn new(
392 trace: &Box<dyn Tracer>,
393 channel: Box<dyn Channel>,
394 ) -> Box<TransportConnection> {
395 let trace = child!(trace, "transport-channel");
396
397 let stream = create_framed_stream(channel);
398
399 let (channel_sink, channel_source) = stream.split();
400
401 let message_sender = Arc::new(Mutex::new(channel_sink));
402
403 let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
404 let channels = Arc::new(Mutex::new(HashMap::new()));
405
406 let (on_remote_channel_sender, on_remote_channel) = mpsc::channel(25);
407
408 let trace2 = &trace;
409 let trace2 = child!(trace2, "control-messages-handler");
410
411 tokio::spawn(handle_control_messages(
412 trace2,
413 channel_source,
414 Arc::clone(&channels),
415 Arc::clone(&open_channel_requests),
416 Arc::clone(&message_sender),
417 on_remote_channel_sender,
418 ));
419
420 return Box::new(TransportConnection {
421 trace,
422 message_sender,
423 open_channel_requests,
424 channels,
425 on_remote_channel: Some(on_remote_channel),
426 });
427 }
428
429 pub fn on_remote_channel(&mut self) -> Result<mpsc::Receiver<Box<dyn Channel>>> {
430 match self.on_remote_channel.take() {
431 Some(on_remote_channel) => return Ok(on_remote_channel),
432 None => bail!("No on_remote_channel found."),
433 };
434 }
435
436 pub fn off_remote_channel(
437 &mut self,
438 on_channel: mpsc::Receiver<Box<dyn Channel>>,
439 ) -> Result<()> {
440 if let Some(_) = self.on_remote_channel {
441 bail!("on_remote_channel already set.");
442 }
443
444 self.on_remote_channel.replace(on_channel);
445 return Ok(());
446 }
447
448 pub async fn channel(
449 &mut self,
450 label: impl AsRef<str> + ToString,
451 buffer_size: u32,
452 ) -> Result<Box<dyn Channel>> {
453 let id = random_number(0..=u16::MAX);
454 let label = label.to_string();
455
456 self.trace.trace(
457 &format!("creating channel, ID: {}, label: {}", id, label),
458 );
459
460 let notify = Arc::new(Notify::new());
461
462 {
463 self.open_channel_requests
464 .write().await
465 .insert(id, Arc::clone(¬ify));
466 }
467
468 self.trace.trace(
469 &format!("sending open channel request"),
470 );
471
472 {
473 self.message_sender
474 .lock().await
475 .send(ControlMessage::OpenChannel(id, label.clone(), buffer_size, false)).await?;
476 }
477
478 self.trace.trace(
479 &format!("open channel request sent"),
480 );
481
482 notify.notified().await;
483
484 self.trace.trace(
485 &format!("got open channel response"),
486 );
487
488 let trace2 = &self.trace;
489 let trace2 = child!(trace2, "add-local-channel");
490
491 let channel = add_local_channel(
492 trace2,
493 id,
494 label,
495 buffer_size,
496 Arc::clone(&self.channels),
497 Arc::clone(&self.message_sender),
498 ).await?;
499
500 self.trace.trace(
501 &format!("channel created: {}, {}", channel.id(), channel.label()),
502 );
503
504 return Ok(channel);
505 }
506}
507
508#[cfg(test)]
509mod tests {
510 use std::{collections::HashMap, sync::Arc};
511
512 use rstest::rstest;
513 use futures::StreamExt;
514 use cs_trace::create_trace;
515 use cs_utils::{random_str, random_str_rg, random_number, traits::Random, futures::wait_random};
516 use tokio::{sync::{Mutex, mpsc, watch, RwLock, Notify}, io::{split, AsyncWriteExt, AsyncReadExt}};
517
518 use crate::test::TestOptions;
519 use crate::create_framed_stream;
520 use crate::{TransportChannel, TransportConnection};
521 use crate::{connections::transport_connection::{send_channel_data, forward_channel_data, close_channel, add_local_channel, open_channel, ControlMessage}, mocks::{ChannelMockOptions, channel_mock_pair}};
522
523 mod send_channel_data {
524 use super::*;
525
526 #[rstest]
527 #[case(128)]
528 #[case(256)]
529 #[case(512)]
530 #[case(1_024)]
531 #[case(2_048)]
532 #[case(4_096)]
533 #[tokio::test]
534 async fn sends_data_to_channel(
535 #[case] data_len: usize,
536 ) {
537 let trace = create_trace!("test");
538
539 let buffer_size = 4_096;
540 let id = random_number(0..=u16::MAX);
541 let data = random_str(data_len).as_bytes().to_vec();
542
543 let options1 = ChannelMockOptions::random()
544 .with_buffer_size(buffer_size);
545 let options2 = ChannelMockOptions::random()
546 .with_buffer_size(buffer_size);
547
548 let (channel1, mut channel2) = channel_mock_pair(options1, options2);
549
550 let on_close = channel1.on_close();
551
552 let (_reader, writer) = split(channel1);
553
554 let mut channels = HashMap::new();
555 channels.insert(id, Arc::new(Mutex::new((writer, on_close))));
556
557 let channels = Arc::new(Mutex::new(channels));
558
559 let data_to_send = data.clone();
560 let data_to_receive = data.clone();
561
562 let options1 = ChannelMockOptions::random()
563 .with_buffer_size(buffer_size);
564 let options2 = ChannelMockOptions::random()
565 .with_buffer_size(buffer_size);
566
567 let (stream1, _stream2) = TransportChannel::new_pair(
568 id,
569 "control-stream",
570 channel_mock_pair(options1, options2),
571 buffer_size,
572 );
573
574 let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
575 let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
576
577 tokio::join!(
578 Box::pin(async move {
579 wait_random(1..=5).await;
580
581 send_channel_data(
582 trace,
583 id,
584 data_to_send,
585 channels,
586 control_channel_sender,
587 ).await;
588 }),
589 Box::pin(async move {
590 wait_random(1..=5).await;
591
592 let mut buf = [0; 4_096];
593
594 let bytes_read = channel2
595 .read(&mut buf).await
596 .unwrap();
597
598 let received_data = &buf[..bytes_read];
599
600 assert_eq!(
601 received_data,
602 &data_to_receive[..],
603 "Must receive correct data.",
604 );
605 }),
606 );
607 }
608
609 #[rstest]
610 #[tokio::test]
611 async fn does_not_send_if_already_closed() {
612 let trace = create_trace!("test");
613
614 let buffer_size = 4_096;
615 let id = random_number(0..=u16::MAX);
616
617 let options1 = ChannelMockOptions::random()
618 .with_buffer_size(buffer_size);
619 let options2 = ChannelMockOptions::random()
620 .with_buffer_size(buffer_size);
621
622 let (mut channel1, mut channel2) = channel_mock_pair(options1, options2);
623
624 let on_close = channel1.on_close();
625
626 channel1
627 .shutdown().await
628 .unwrap();
629
630 let (_reader, writer) = split(channel1);
631
632 let mut channels = HashMap::new();
633 channels.insert(id, Arc::new(Mutex::new((writer, on_close))));
634
635 let channels = Arc::new(Mutex::new(channels));
636
637 let options1 = ChannelMockOptions::random()
638 .with_buffer_size(buffer_size);
639 let options2 = ChannelMockOptions::random()
640 .with_buffer_size(buffer_size);
641
642 let (stream1, _stream2) = TransportChannel::new_pair(
643 id,
644 "control-stream",
645 channel_mock_pair(options1, options2),
646 buffer_size,
647 );
648
649 let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
650 let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
651
652 tokio::join!(
653 Box::pin(async move {
654 wait_random(1..=5).await;
655
656 send_channel_data(
657 trace,
658 id,
659 vec![],
660 channels,
661 control_channel_sender,
662 ).await;
663 }),
664 Box::pin(async move {
665 wait_random(1..=5).await;
666
667 let mut buf = [0; 4_096];
668
669 let bytes_read = channel2
670 .read(&mut buf).await
671 .unwrap();
672
673 assert_eq!(
674 bytes_read,
675 0,
676 "Must 0 bytes.",
677 );
678 }),
679 );
680 }
681
682 #[rstest]
683 #[case(128)]
684 #[case(256)]
685 #[case(512)]
686 #[case(1_024)]
687 #[case(2_048)]
688 #[case(4_096)]
689 #[tokio::test]
690 async fn fails_if_no_channel_found(
691 #[case] data_len: usize,
692 ) {
693 let trace = create_trace!("test");
694
695 let buffer_size = 4_096;
696 let id = random_number(0..=u16::MAX);
697 let data = random_str(data_len).as_bytes().to_vec();
698
699 let options1 = ChannelMockOptions::random()
700 .with_buffer_size(buffer_size);
701 let options2 = ChannelMockOptions::random()
702 .with_buffer_size(buffer_size);
703
704 let (channel1, _channel2) = channel_mock_pair(options1, options2);
705
706 let on_close = channel1.on_close();
707
708 let (_reader, writer) = split(channel1);
709
710 let mut channels = HashMap::new();
711
712 let another_id = {
713 let mut another_id = random_number(0..=u16::MAX);
714
715 while another_id == id {
716 another_id = random_number(0..=u16::MAX);
717 }
718
719 another_id
720 };
721
722 channels.insert(another_id, Arc::new(Mutex::new((writer, on_close))));
723
724 let channels = Arc::new(Mutex::new(channels));
725
726 let options1 = ChannelMockOptions::random()
727 .with_buffer_size(buffer_size);
728 let options2 = ChannelMockOptions::random()
729 .with_buffer_size(buffer_size);
730
731 let (stream1, stream2) = TransportChannel::new_pair(
732 id,
733 "control-stream",
734 channel_mock_pair(options1, options2),
735 buffer_size,
736 );
737
738 let (stream1_tx, _stream1_rx) = create_framed_stream(stream1).split();
739 let (_stream2_tx, mut stream2_rx) = create_framed_stream(stream2).split();
740
741 tokio::try_join!(
742 tokio::spawn(send_channel_data(
743 trace,
744 id,
745 data,
746 channels,
747 Arc::new(Mutex::new(stream1_tx),
748 ))),
749 tokio::spawn(async move {
750 let message = stream2_rx.next().await.unwrap().unwrap();
751
752 match message {
753 ControlMessage::Error(received_id, error_message) => {
754 assert_eq!(
755 received_id,
756 id,
757 "Must receive error with correct id.",
758 );
759
760 assert!(
761 error_message.len() > 3,
762 "Received error message must be not empty.",
763 );
764 },
765 unexpected @ _ => panic!("Unexpected message: {:?}.", unexpected),
766 };
767 }),
768 ).unwrap();
769 }
770 }
771
772 mod handle_channel_reads {
773 use crate::TransportChannel;
774
775 use super::*;
776
777 #[rstest]
778 #[case(512)]
779 #[case(1_024)]
780 #[case(2_048)]
781 #[case(4_096)]
782 #[case(8_192)]
783 #[case(16_384)]
784 #[tokio::test]
785 async fn reads_from_a_local_channel(
786 #[case] data_len: usize,
787 ) {
788 let trace = cs_trace::create_trace!("test");
789
790 let buffer_size: u32 = 4_096;
791
792 let id = random_number(0..=u16::MAX);
793 let data = random_str(data_len)
794 .as_bytes().to_vec();
795
796 let options1 = ChannelMockOptions::random()
797 .with_buffer_size(buffer_size);
798 let options2 = ChannelMockOptions::random()
799 .with_buffer_size(buffer_size);
800
801 let (channel1, mut channel2) = TransportChannel::new_pair(
802 id,
803 "transport-channel",
804 channel_mock_pair(options1, options2),
805 buffer_size,
806 );
807
808 let on_close = channel1.on_close();
809
810 let (reader, _writer) = split(channel1);
811
812 let options1 = ChannelMockOptions::random()
813 .with_buffer_size(buffer_size);
814 let options2 = ChannelMockOptions::random()
815 .with_buffer_size(buffer_size);
816
817 let (stream1, stream2) = TransportChannel::new_pair(
818 id,
819 "control-stream",
820 channel_mock_pair(options1, options2),
821 buffer_size,
822 );
823
824 let stream1 = create_framed_stream(stream1);
825 let stream2 = create_framed_stream(stream2);
826
827 let (stream1_tx, _stream1_rx) = stream1.split();
828 let (_stream2_tx, mut control_channel_receiver) = stream2.split();
829
830 let control_channel_sender = Arc::new(Mutex::new(stream1_tx));
831
832 let data_to_send = data.clone();
833 let data_to_receive = data.clone();
834
835 let channels = Arc::new(Mutex::new(HashMap::new()));
836
837 let channels1 = Arc::clone(&channels);
838 let channels2 = Arc::clone(&channels);
839
840 tokio::join!(
841 Box::pin(async move {
842 wait_random(1..=5).await;
843
844 forward_channel_data(
845 trace,
846 id,
847 reader,
848 control_channel_sender,
849 on_close,
850 channels1,
851 buffer_size,
852 ).await;
853 }),
854 Box::pin(async move {
855 wait_random(1..=5).await;
856
857 let mut total_written = 0;
858 while total_written < data_to_send.len() {
859 let written = channel2
860 .write(&data_to_send[total_written..]).await
861 .unwrap();
862
863 total_written += written;
864 }
865
866 assert!(
867 !channels2.lock().await.contains_key(&id),
868 "Channel must be deleted.",
869 );
870 }),
871 Box::pin(async move {
872 wait_random(1..=5).await;
873
874 let mut received_data = vec![];
875
876 while let Some(maybe_message) = control_channel_receiver.next().await {
877 let message = maybe_message.unwrap();
878 let (received_id, data) = match message {
879 ControlMessage::Data(id, data) => (id, data),
880 ControlMessage::Close(received_id) => {
881 assert_eq!(
882 received_id,
883 id,
884 "Message must have correct channel ID.",
885 );
886
887 break;
888 },
889 other @ _ => panic!("Unexpected message: {:?}", other),
890 };
891
892 assert_eq!(
893 received_id,
894 id,
895 "Message must have correct channel ID.",
896 );
897
898 received_data.extend_from_slice(&data[..]);
899 }
900
901 assert_eq!(
902 received_data,
903 data_to_receive,
904 "Must receive correct data.",
905 );
906 }),
907 );
908 }
909 }
910
911 mod close_channel {
912 use crate::TransportChannel;
913
914 use super::*;
915
916 #[rstest]
917 #[case(())]
918 #[case(())]
919 #[case(())]
920 #[case(())]
921 #[case(())]
922 #[case(())]
923 #[tokio::test]
924 async fn shutsdown_a_channel_and_removes_reference(
925 #[case] _case_num: (),
926 ) {
927 let buffer_size = 4_096;
928
929 let id = random_number(0..=u16::MAX);
930 let options1 = ChannelMockOptions::random()
931 .with_buffer_size(buffer_size);
932 let options2 = ChannelMockOptions::random()
933 .with_buffer_size(buffer_size);
934
935 let (channel1, _channel2) = TransportChannel::new_pair(
936 id,
937 "transport-channel",
938 channel_mock_pair(options1, options2),
939 buffer_size,
940 );
941
942 let on_close = channel1.on_close();
943
944 let (_reader, writer) = split(channel1);
945
946 let mut channels = HashMap::new();
947 channels.insert(id, Arc::new(Mutex::new((writer, watch::Receiver::clone(&on_close)))));
948
949 let channels = Arc::new(Mutex::new(channels));
950
951 wait_random(1..=5).await;
952
953 close_channel(
954 id,
955 Arc::clone(&channels),
956 ).await.unwrap();
957
958 {
959 assert!(
960 !(channels.lock().await.contains_key(&id)),
961 "Must remove channel reference from the map.",
962 );
963 }
964
965 assert!(
966 *on_close.borrow(),
967 "Must close the channel.",
968 );
969 }
970
971 #[rstest]
972 #[case(())]
973 #[case(())]
974 #[case(())]
975 #[case(())]
976 #[case(())]
977 #[case(())]
978 #[tokio::test]
979 async fn does_not_fails_if_channel_allready_closed(
980 #[case] _case_num: (),
981 ) {
982 let buffer_size = 4_096;
983
984 let id = random_number(0..=u16::MAX);
985 let options1 = ChannelMockOptions::random()
986 .with_buffer_size(buffer_size);
987 let options2 = ChannelMockOptions::random()
988 .with_buffer_size(buffer_size);
989
990 let (channel1, _channel2) = TransportChannel::new_pair(
991 id,
992 "transport-channel",
993 channel_mock_pair(options1, options2),
994 buffer_size,
995 );
996
997 let on_close = channel1.on_close();
998
999 let (_reader, mut writer) = split(channel1);
1000
1001 let mut channels = HashMap::new();
1002
1003 writer.shutdown().await
1004 .unwrap();
1005
1006 assert!(
1007 *on_close.borrow(),
1008 "Must close the channel.",
1009 );
1010
1011 channels.insert(id, Arc::new(Mutex::new((writer, watch::Receiver::clone(&on_close)))));
1012
1013 assert!(
1014 channels.contains_key(&id),
1015 "Must contain channel before test",
1016 );
1017
1018 let channels = Arc::new(Mutex::new(channels));
1019
1020 wait_random(1..=5).await;
1021
1022 close_channel(
1023 id,
1024 Arc::clone(&channels),
1025 ).await.unwrap();
1026
1027 {
1028 assert!(
1029 !(channels.lock().await.contains_key(&id)),
1030 "Must remove channel reference from the map.",
1031 );
1032 }
1033
1034 assert!(
1035 *on_close.borrow(),
1036 "Must close the channel.",
1037 );
1038 }
1039
1040 #[rstest]
1041 #[case(())]
1042 #[case(())]
1043 #[case(())]
1044 #[case(())]
1045 #[case(())]
1046 #[case(())]
1047 #[tokio::test]
1048 #[should_panic]
1049 async fn fails_if_no_channel_found(
1050 #[case] _case_num: (),
1051 ) {
1052 let id = random_number(0..=u16::MAX);
1053
1054 let channels = HashMap::new();
1055 let channels = Arc::new(Mutex::new(channels));
1056
1057 wait_random(1..=5).await;
1058
1059 close_channel(
1060 id,
1061 Arc::clone(&channels),
1062 ).await.unwrap();
1063 }
1064 }
1065
1066 mod add_local_channel {
1067 use cs_trace::create_trace;
1068
1069 use crate::create_framed_stream;
1070
1071 use super::*;
1072
1073 #[tokio::test]
1074 async fn adds_channel_to_channels_map() {
1075 let trace = create_trace!("test");
1076
1077 let buffer_size: u32 = 4_096;
1078 let id = random_number(0..=u16::MAX);
1079 let label = random_str_rg(8..=16);
1080
1081 let channels = HashMap::new();
1082 let channels = Arc::new(Mutex::new(channels));
1083
1084 {
1085 assert!(
1086 !(channels.lock().await).contains_key(&id),
1087 "Must not contain channel before the test.",
1088 );
1089 }
1090
1091 let options1 = ChannelMockOptions::random()
1092 .with_buffer_size(buffer_size);
1093 let options2 = ChannelMockOptions::random()
1094 .with_buffer_size(buffer_size);
1095
1096 let (stream1, _stream2) = TransportChannel::new_pair(
1097 id,
1098 "control-stream",
1099 channel_mock_pair(options1, options2),
1100 buffer_size,
1101 );
1102
1103 let stream1 = create_framed_stream(stream1);
1104
1105 let (stream1_tx, _stream1_rx) = stream1.split();
1106
1107 let control_sender = Arc::new(Mutex::new(stream1_tx));
1108
1109 add_local_channel(
1110 trace,
1111 id,
1112 label,
1113 buffer_size,
1114 Arc::clone(&channels),
1115 control_sender,
1116 ).await.unwrap();
1117
1118 {
1119 assert!(
1120 (channels.lock().await).contains_key(&id),
1121 "Must add channel to the map.",
1122 );
1123 }
1124 }
1125 }
1126
1127 mod open_channel {
1128 use cs_trace::create_trace;
1129
1130 use crate::create_framed_stream;
1131
1132 use super::*;
1133
1134 #[tokio::test]
1135 async fn notifies_pending_channel_open_requests() {
1136 let trace = create_trace!("test");
1137
1138 let buffer_size: u32 = 4_096;
1139 let id = random_number(0..=u16::MAX);
1140 let label = random_str_rg(8..=16);
1141
1142 let is_response = true;
1143
1144 let channels = Arc::new(Mutex::new(HashMap::new()));
1145 let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
1146
1147 let (on_remote_channel, _on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
1149
1150 let channel_open_notification = Arc::new(Notify::new());
1151
1152 {
1153 open_channel_requests.write().await
1154 .insert(id, Arc::clone(&channel_open_notification));
1155 }
1156
1157 let channels1 = Arc::clone(&channels);
1158
1159 let options1 = ChannelMockOptions::random()
1160 .with_buffer_size(buffer_size);
1161 let options2 = ChannelMockOptions::random()
1162 .with_buffer_size(buffer_size);
1163
1164 let (stream1, _stream2) = TransportChannel::new_pair(
1165 id,
1166 "control-stream",
1167 channel_mock_pair(options1, options2),
1168 buffer_size,
1169 );
1170
1171 let stream1 = create_framed_stream(stream1);
1172 let (stream1_tx, _stream1_rx) = stream1.split();
1175 let control_sender = Arc::new(Mutex::new(stream1_tx));
1178
1179 tokio::join!(
1180 Box::pin(async move {
1181 open_channel(
1182 trace,
1183 id,
1184 label,
1185 buffer_size,
1186 is_response,
1187 channels1,
1188 Arc::clone(&open_channel_requests),
1189 control_sender,
1190 on_remote_channel,
1191 ).await.unwrap();
1192 }),
1193 Box::pin(channel_open_notification.notified()),
1194 );
1195
1196 assert!(
1197 !(channels.lock().await.contains_key(&id)),
1198 "Must not add channel into the map.",
1199 );
1200 }
1201
1202 #[tokio::test]
1203 async fn fails_if_no_channel_notification_found() {
1204 let trace = create_trace!("test");
1205
1206 let buffer_size: u32 = 4_096;
1207 let id = random_number(0..=u16::MAX);
1208 let label = random_str_rg(8..=16);
1209
1210 let is_response = true;
1211
1212 let channels = Arc::new(Mutex::new(HashMap::new()));
1213 let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
1214
1215 let (on_remote_channel, _on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
1217
1218 let options1 = ChannelMockOptions::random()
1219 .with_buffer_size(buffer_size);
1220 let options2 = ChannelMockOptions::random()
1221 .with_buffer_size(buffer_size);
1222
1223 let (stream1, _stream2) = TransportChannel::new_pair(
1224 id,
1225 "control-stream",
1226 channel_mock_pair(options1, options2),
1227 buffer_size,
1228 );
1229
1230 let stream1 = create_framed_stream(stream1);
1231 let (stream1_tx, _stream1_rx) = stream1.split();
1232
1233 let control_sender = Arc::new(Mutex::new(stream1_tx));
1234
1235 let result = open_channel(
1236 trace,
1237 id,
1238 label,
1239 buffer_size,
1240 is_response,
1241 Arc::clone(&channels),
1242 Arc::clone(&open_channel_requests),
1243 control_sender,
1244 on_remote_channel,
1245 ).await;
1246
1247 assert!(
1248 result.is_err(),
1249 "Must fail if no channel notification present.",
1250 );
1251
1252 assert!(
1253 !(channels.lock().await.contains_key(&id)),
1254 "Must not add channel into the map.",
1255 );
1256 }
1257
1258 #[tokio::test]
1259 async fn responds_to_channel_open_request() {
1260 let trace = create_trace!("test");
1261
1262 let buffer_size: u32 = 4_096;
1263 let id = random_number(0..=u16::MAX);
1264 let label = random_str_rg(8..=16);
1265
1266 let is_response = false;
1267
1268 let channels = Arc::new(Mutex::new(HashMap::new()));
1269 let open_channel_requests = Arc::new(RwLock::new(HashMap::new()));
1270
1271 let (on_remote_channel, mut on_remote_channel_receiver) = mpsc::channel(buffer_size as usize);
1272
1273 let options1 = ChannelMockOptions::random()
1274 .with_buffer_size(buffer_size);
1275 let options2 = ChannelMockOptions::random()
1276 .with_buffer_size(buffer_size);
1277
1278 let (stream1, stream2) = TransportChannel::new_pair(
1279 id,
1280 "control-stream",
1281 channel_mock_pair(options1, options2),
1282 buffer_size,
1283 );
1284
1285 let stream1 = create_framed_stream(stream1);
1286 let stream2 = create_framed_stream(stream2);
1287
1288 let (stream1_tx, _stream1_rx) = stream1.split();
1289 let (_stream2_tx, mut control_receiver) = stream2.split();
1290
1291 let control_sender = Arc::new(Mutex::new(stream1_tx));
1292
1293 let channels1 = Arc::clone(&channels);
1294 let channels2 = Arc::clone(&channels);
1295
1296 let label1 = label.clone();
1297 let label2 = label.clone();
1298
1299 tokio::join!(
1300 Box::pin(async move {
1301 wait_random(1..=5).await;
1302
1303 open_channel(
1304 trace,
1305 id,
1306 label1,
1307 buffer_size,
1308 is_response,
1309 channels1,
1310 Arc::clone(&open_channel_requests),
1311 control_sender,
1312 on_remote_channel,
1313 ).await.unwrap();
1314 }),
1315 Box::pin(async move {
1316 let message = control_receiver.next().await.expect("Stream closed.").unwrap();
1317 match message {
1318 ControlMessage::OpenChannel(recv_id, recv_label, recv_buffer_size, recv_is_response) => {
1319 assert_eq!(
1320 recv_id,
1321 id,
1322 "Must receive correct channel ID.",
1323 );
1324
1325 assert_eq!(
1326 recv_label,
1327 label2,
1328 "Must receive correct channel label.",
1329 );
1330
1331 assert_eq!(
1332 recv_buffer_size,
1333 buffer_size,
1334 "Must receive correct channel buffer_size.",
1335 );
1336
1337 assert!(
1338 recv_is_response,
1339 "Must send a response.",
1340 );
1341 },
1342 unexpected @ _ => panic!("Got unexpected control message: {:?}", unexpected),
1343 };
1344 }),
1345 );
1346
1347 let _channel = on_remote_channel_receiver
1348 .recv().await
1349 .expect("Must send `on_remote_channel` notification.");
1350
1351 assert!(
1352 (channels2.lock().await.contains_key(&id)),
1353 "Must add channel into the map.",
1354 );
1355 }
1356 }
1357
1358 mod data_transfer {
1359 use futures::future;
1360 use cs_trace::{create_trace, child};
1361
1362 use super::*;
1363 use crate::{test::test_stream, Channel};
1364
1365 async fn open_channel(
1367 mut local_connection: Box<TransportConnection>,
1368 mut remote_connection: Box<TransportConnection>,
1369 buffer_size: u32,
1370 ) -> [(Box<TransportConnection>, Box<dyn Channel>); 2] {
1371 let (local, remote) = tokio::join!(
1372 Box::pin(async move {
1373 let local_channel = local_connection.channel("local-channel1", buffer_size).await
1374 .expect("Cannot create a channel.");
1375
1376 return (local_connection, local_channel);
1377 }),
1378 Box::pin(async move {
1379 let mut on_remote_channel = remote_connection
1380 .on_remote_channel().unwrap();
1381
1382 let remote_channel = on_remote_channel
1383 .recv().await
1384 .expect("Cannot receive a remote channel.");
1385
1386 remote_connection.off_remote_channel(on_remote_channel)
1387 .expect("Cannot set remote channel listener.");
1388
1389 return (remote_connection, remote_channel);
1390 }),
1391 );
1392
1393 return [local, remote];
1394 }
1395
1396 #[rstest]
1397 #[case(512)]
1398 #[case(1_024)]
1399 #[case(2_048)]
1400 #[case(4_096)]
1401 #[case(8_192)]
1402 #[case(16_384)]
1403 #[tokio::test]
1404 async fn transfers_data(
1405 #[case] data_len: usize,
1406 ) {
1407 let trace = create_trace!("test");
1408
1409 let buffer_size: u32 = 2_048;
1410
1411 let (channel1, channel2) = TransportChannel::new_pair(
1412 random_number(0..=u16::MAX),
1413 "transport-channels",
1414 channel_mock_pair(ChannelMockOptions::random(), ChannelMockOptions::random()),
1415 buffer_size,
1416 );
1417
1418 let trace1 = &trace;
1419 let trace1 = child!(trace1, "local");
1420
1421 let trace2 = &trace;
1422 let trace2 = child!(trace2, "remote");
1423
1424 let local_connection = TransportConnection::new(&trace1, channel1);
1425 let remote_connection = TransportConnection::new(&trace2, channel2);
1426
1427 let [
1428 (_local_connection, local_channel),
1429 (_remote_connection, remote_channel),
1430 ] = open_channel(
1431 local_connection,
1432 remote_connection,
1433 buffer_size,
1434 ).await;
1435
1436 test_stream(
1437 local_channel,
1438 remote_channel,
1439 TestOptions::random()
1440 .with_data_len(data_len),
1441 ).await;
1442
1443 }
1444
1445 #[rstest]
1446 #[case(512)]
1447 #[case(1_024)]
1448 #[case(2_048)]
1449 #[case(4_096)]
1450 #[case(8_192)]
1451 #[case(16_384)]
1452 #[tokio::test]
1453 async fn transfers_data_in_parallel(
1454 #[case] data_len: usize,
1455 ) {
1456 let trace = create_trace!("test");
1457
1458 let buffer_size: u32 = 2_048;
1459
1460 let (channel1, channel2) = TransportChannel::new_pair(
1461 random_number(0..=u16::MAX),
1462 "transport-channels",
1463 channel_mock_pair(ChannelMockOptions::random(), ChannelMockOptions::random()),
1464 buffer_size,
1465 );
1466
1467 let trace1 = &trace;
1468 let trace1 = child!(trace1, "local");
1469
1470 let trace2 = &trace;
1471 let trace2 = child!(trace2, "remote");
1472
1473 let mut local_connection = TransportConnection::new(&trace1, channel1);
1474 let mut remote_connection = TransportConnection::new(&trace2, channel2);
1475
1476 let mut tasks = vec![];
1477
1478 for _ in 0..random_number(5..=10) {
1479 let [
1480 (local_connection1, local_channel),
1481 (remote_connection1, remote_channel),
1482 ] = open_channel(
1483 local_connection,
1484 remote_connection,
1485 buffer_size,
1486 ).await;
1487
1488 local_connection = local_connection1;
1489 remote_connection = remote_connection1;
1490
1491 tasks.push(
1492 tokio::spawn(test_stream(
1493 local_channel,
1494 remote_channel,
1495 TestOptions::random()
1496 .with_data_len(data_len),
1497 )),
1498 );
1499
1500 wait_random(0..=50).await;
1501 }
1502
1503 future::try_join_all(tasks).await
1504 .unwrap();
1505 }
1506 }
1507}