connection_utils/mocks/
channel_mock.rs

1use std::{pin::Pin, task::{Context, Poll}, io, ops::RangeInclusive, fmt};
2
3use futures::{Future, ready};
4use cs_utils::{random_number, random_str, futures::wait_random, traits::Random};
5use tokio::{io::{duplex, AsyncRead, AsyncWrite, ReadBuf, DuplexStream}, sync::watch};
6
7use crate::Channel;
8
9pub struct ChannelMock<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static = DuplexStream> {
10    id: u16,
11    label: String,
12    channel: Pin<Box<TAsyncDuplex>>,
13    options: ChannelMockOptions,
14    read_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
15    write_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
16    on_close: watch::Receiver<bool>,
17    on_close_sender: watch::Sender<bool>,
18}
19
20impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> ChannelMock<TAsyncDuplex> {
21    pub fn new(
22        channel: Box<TAsyncDuplex>,
23        options: ChannelMockOptions,
24    ) -> Box<dyn Channel> {
25        let (on_close_sender, on_close) = watch::channel(false);
26
27        return Box::new(
28            ChannelMock {
29                id: options.id,
30                label: options.label.clone(),
31                channel: Pin::new(channel),
32                options,
33                read_delay_future: None,
34                write_delay_future: None,
35                on_close,
36                on_close_sender,
37            },
38        );
39    }
40}
41
/// Configuration for a `ChannelMock`.
///
/// Build one with `ChannelMockOptions::default()` (or `Random::random()`)
/// and adjust it with the `with_*` builder methods.
#[derive(Clone, Debug, PartialEq)]
pub struct ChannelMockOptions {
    /// Identifier reported by the channel.
    id: u16,
    /// Human-readable label reported by the channel.
    label: String,
    /// Inclusive range handed to `wait_random` to pick the delay inserted
    /// after each read/write (presumably milliseconds -- TODO confirm
    /// against `cs_utils::futures::wait_random`).
    latency_range: RangeInclusive<u64>,
    /// Buffer size reported by the channel; also used by `channel_mock_pair`
    /// to size the underlying duplex stream.
    buffer_size: u32,
}
49
50impl ChannelMockOptions {
51    pub fn with_id(
52        self,
53        id: u16,
54    ) -> ChannelMockOptions {
55        return ChannelMockOptions {
56            id,
57            ..self
58        };
59    }
60
61    pub fn with_label(
62        self,
63        label: impl AsRef<str> + ToString,
64    ) -> ChannelMockOptions {
65        return ChannelMockOptions {
66            label: label.to_string(),
67            ..self
68        };
69    }
70
71    pub fn with_latency(
72        self,
73        latency_range: RangeInclusive<u64>,
74    ) -> ChannelMockOptions {
75        return ChannelMockOptions {
76            latency_range,
77            ..self
78        };
79    }
80
81    pub fn with_buffer_size(
82        self,
83        buffer_size: u32,
84    ) -> ChannelMockOptions {
85        return ChannelMockOptions {
86            buffer_size,
87            ..self
88        };
89    }
90}
91
92impl Random for ChannelMockOptions {
93    fn random() -> Self {
94        let min = random_number(0..5);
95        let max = random_number(5..=50);
96
97        return ChannelMockOptions::default()
98            .with_latency(min..=max);
99    }
100}
101
102impl Default for ChannelMockOptions {
103    fn default() -> ChannelMockOptions {
104        return ChannelMockOptions {
105            id: random_number(0..=u16::MAX),
106            label: format!("channel-mock-{}", random_str(8)),
107            latency_range: (0..=0),
108            buffer_size: 4_096,
109        };
110    }
111}
112
113pub fn channel_mock_pair(
114    options1: ChannelMockOptions,
115    options2: ChannelMockOptions,
116) -> (Box<dyn Channel>, Box<dyn Channel>) {
117    let (channel1, channel2) = duplex(options1.buffer_size as usize);
118
119    return (
120        ChannelMock::new(Box::new(channel1), options1),
121        ChannelMock::new(Box::new(channel2), options2),
122    );
123}
124
125impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> Channel for ChannelMock<TAsyncDuplex> {
126    fn id(&self) -> u16 {
127        return self.id;
128    }
129
130    fn label(&self) ->  &String {
131        return &self.label;
132    }
133
134    fn is_closed(&self) -> bool {
135        return *self.on_close.borrow();
136    }
137
138    fn on_close(&self) -> watch::Receiver<bool> {
139        return watch::Receiver::clone(&self.on_close);
140    }
141
142    fn buffer_size(&self) -> u32 {
143        return self.options.buffer_size;
144    }
145}
146
147impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> fmt::Debug for ChannelMock<TAsyncDuplex> {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        return self.debug("ChannelMock", f);
150    }
151}
152
153impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for ChannelMock<TAsyncDuplex> {
154    fn poll_read(
155        mut self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157        buf: &mut ReadBuf<'_>,
158    ) -> Poll<io::Result<()>> {
159        // if delay future present, wait until it completes
160        if let Some(read_delay_future) = self.read_delay_future.as_mut() {
161            ready!(read_delay_future.as_mut().poll(cx));
162
163            self.read_delay_future.take();
164        }
165
166        // otherwise run the read future to completion
167        let result = ready!(self.channel.as_mut().poll_read(cx, buf));
168
169        self.read_delay_future = Some(Box::pin(wait_random(self.options.latency_range.clone())));
170
171        return Poll::Ready(result);
172    }
173}
174
175impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for ChannelMock<TAsyncDuplex> {
176    fn poll_write(
177        mut self: Pin<&mut Self>,
178        cx: &mut Context<'_>,
179        buf: &[u8],
180    ) -> Poll<io::Result<usize>> {
181        // if delay future present, wait until it completes
182        if let Some(write_delay_future) = self.write_delay_future.as_mut() {
183            ready!(write_delay_future.as_mut().poll(cx));
184
185            self.write_delay_future.take();
186        }
187
188        let result = ready!(self.channel.as_mut().poll_write(cx, buf));
189
190        self.write_delay_future = Some(Box::pin(wait_random(self.options.latency_range.clone())));
191
192        return Poll::Ready(result);
193    }
194
195    fn poll_flush(
196        mut self: Pin<&mut Self>,
197        cx: &mut Context<'_>,
198    ) -> Poll<io::Result<()>> {
199        return self.channel.as_mut()
200            .poll_flush(cx);
201    }
202
203    fn poll_shutdown(
204        mut self: Pin<&mut Self>,
205        cx: &mut Context<'_>,
206    ) -> Poll<io::Result<()>> {
207        // if delay future present, wait until it completes
208        if let Some(read_delay_future) = self.read_delay_future.as_mut() {
209            ready!(read_delay_future.as_mut().poll(cx));
210
211            self.read_delay_future.take();
212        }
213
214        let result = ready!(self.channel.as_mut().poll_shutdown(cx));
215
216        let _err = self.on_close_sender.send(true);
217
218        return Poll::Ready(result);
219    }
220}
221
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use cs_utils::{traits::Random, random_number};

    use crate::utils::create_framed_stream;
    use crate::test::{TestStreamMessage, test_async_stream, test_framed_stream, TestOptions};

    use super::channel_mock_pair;

    /// Raw bytes of various sizes pass through a pair of mocks with
    /// randomized latency.
    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[case(32_768)]
    #[case(65_536)]
    #[tokio::test]
    async fn transfers_binary_data(
        #[case] test_data_len: usize,
    ) {
        let (local, remote) = channel_mock_pair(Random::random(), Random::random());

        let test_options = TestOptions::random()
            .with_data_len(test_data_len);

        test_async_stream(local, remote, test_options).await;
    }

    /// Framed messages of various counts pass through a pair of mocks with
    /// randomized latency.
    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn transfers_stream_data(
        #[case] items_count: usize,
    ) {
        let (local, remote) = channel_mock_pair(Random::random(), Random::random());

        let local_framed = create_framed_stream::<TestStreamMessage, _>(local);
        let remote_framed = create_framed_stream::<TestStreamMessage, _>(remote);

        let test_options = TestOptions::random()
            .with_data_len(items_count);

        test_framed_stream(local_framed, remote_framed, test_options).await;
    }
}