multiplexed_connection/
multiplexed_connection.rs

1use std::pin::Pin;
2
3use connection_utils::Disconnected;
4use cs_trace::{Tracer, create_trace};
5use tokio::io::{AsyncRead, AsyncWrite};
6
7mod rpc;
8
9pub mod disconnected;
10pub mod connected;
11
12pub struct MultiplexedConnection<TAsyncDuplex: AsyncRead + AsyncWrite + Send + 'static> {
13    trace: Box<dyn Tracer>,
14    stream: Option<Pin<Box<TAsyncDuplex>>>,
15}
16
17impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + 'static> MultiplexedConnection<TAsyncDuplex> {
18    pub fn new(
19        stream: TAsyncDuplex,
20    ) -> Box<dyn Disconnected> {
21        let trace = create_trace!("rpc-connection");
22
23        return Box::new(
24            MultiplexedConnection {
25                trace,
26                stream: Some(Box::pin(stream)),
27            }
28        );
29    }
30}
31
#[cfg(test)]
mod tests {
    use connection_utils::test::test_async_stream;
    use cs_utils::futures::wait;
    use cs_utils::{random_str, random_str_rg};
    use rstest::rstest;
    use tokio::io::duplex;
    use tokio::try_join;

    use super::MultiplexedConnection;

    /// Capacity, in bytes, of the in-memory pipe joining the two peers.
    const DUPLEX_BUFFER_SIZE: usize = 4096;

    /// Number of random strings concatenated to form each test payload.
    const TEST_DATA_CHUNKS: usize = 7;

    /// Which end of the established channel is passed first to
    /// `test_async_stream` (assumed to be the sending side — TODO confirm
    /// against `connection_utils::test`).
    #[derive(Clone, Copy, Debug)]
    enum DataSource {
        /// The channel opened on the connecting side via `channel()`.
        LocalChannel,
        /// The channel delivered to the listening side via `on_remote_channel()`.
        RemoteChannel,
    }

    /// Builds a payload by concatenating `TEST_DATA_CHUNKS` strings produced by
    /// `random_str_rg(str_min_size..=str_max_size)` (presumably each with a
    /// random length in that range).
    fn make_test_data(str_min_size: usize, str_max_size: usize) -> String {
        (0..TEST_DATA_CHUNKS)
            .map(|_| random_str_rg(str_min_size..=str_max_size))
            .collect()
    }

    /// Connects two `MultiplexedConnection`s over an in-memory duplex pipe (one
    /// side `connect()`s, the other `listen()`s), opens a channel from the
    /// connecting side, waits for it to arrive on the listening side, checks both
    /// ends carry the same label, and then runs `test_async_stream` on the pair
    /// with `source` passed first.
    async fn run_channel_stream_test(
        str_min_size: usize,
        str_max_size: usize,
        source: DataSource,
    ) {
        let (duplex1, duplex2) = duplex(DUPLEX_BUFFER_SIZE);

        let channel_label = format!("channel-label-{}", random_str(4));
        let channel_label1 = channel_label.clone();
        let channel_label2 = channel_label;

        let (local_channel, remote_channel) = try_join!(
            tokio::spawn(async move {
                let mut connection1 = MultiplexedConnection::new(duplex1)
                    .connect().await
                    .expect("Error while connecting.");

                // Short pause before opening the channel, presumably to let the
                // listening side finish its setup.
                wait(50).await;

                let channel = connection1
                    .channel(channel_label1.clone()).await
                    .unwrap();

                assert_eq!(
                    channel.label(),
                    &channel_label1,
                    "Channel labels must match.",
                );

                channel
            }),
            tokio::spawn(async move {
                let mut connection2 = MultiplexedConnection::new(duplex2)
                    .listen().await
                    .expect("Error while listening.");

                let mut on_remote_channel = connection2.on_remote_channel().unwrap();

                // NOTE(review): any `try_recv` error (including a closed sender)
                // is retried, so this loop never ends if the channel never arrives.
                let channel = loop {
                    if let Ok(channel) = on_remote_channel.try_recv() {
                        break channel;
                    }

                    wait(50).await;
                };

                assert_eq!(
                    channel.label(),
                    &channel_label2,
                    "Channel labels must match.",
                );

                channel
            }),
        ).unwrap();

        let test_data = make_test_data(str_min_size, str_max_size);

        match source {
            DataSource::LocalChannel => {
                test_async_stream(local_channel, remote_channel, test_data).await;
            }
            DataSource::RemoteChannel => {
                test_async_stream(remote_channel, local_channel, test_data).await;
            }
        }
    }

    #[rstest]
    #[case::size_8_32(8, 32)]
    #[case::size_128_512(128, 512)]
    #[case::size_2048_4096(2048, 4096)]
    #[case::size_4096_8192(4096, 8192)]
    #[case::size_8192_16384(8192, 16384)]
    #[tokio::test]
    async fn sends_data_from_local_channel(
        #[case] str_min_size: usize,
        #[case] str_max_size: usize,
    ) {
        run_channel_stream_test(str_min_size, str_max_size, DataSource::LocalChannel).await;
    }

    #[rstest]
    #[case::size_8_32(8, 32)]
    #[case::size_128_512(128, 512)]
    #[case::size_2048_4096(2048, 4096)]
    #[case::size_4096_8192(4096, 8192)]
    #[case::size_8192_16384(8192, 16384)]
    #[tokio::test]
    async fn sends_data_from_remote_channel(
        #[case] str_min_size: usize,
        #[case] str_max_size: usize,
    ) {
        run_channel_stream_test(str_min_size, str_max_size, DataSource::RemoteChannel).await;
    }
}