1use std::{pin::Pin, task::{Context, Poll}, io, ops::RangeInclusive, fmt};
2
3use futures::{Future, ready};
4use cs_utils::{random_number, random_str, futures::wait_random, traits::Random};
5use tokio::{io::{duplex, AsyncRead, AsyncWrite, ReadBuf, DuplexStream}, sync::watch};
6
7use crate::Channel;
8
9pub struct ChannelMock<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static = DuplexStream> {
10 id: u16,
11 label: String,
12 channel: Pin<Box<TAsyncDuplex>>,
13 options: ChannelMockOptions,
14 read_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
15 write_delay_future: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
16 on_close: watch::Receiver<bool>,
17 on_close_sender: watch::Sender<bool>,
18}
19
20impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> ChannelMock<TAsyncDuplex> {
21 pub fn new(
22 channel: Box<TAsyncDuplex>,
23 options: ChannelMockOptions,
24 ) -> Box<dyn Channel> {
25 let (on_close_sender, on_close) = watch::channel(false);
26
27 return Box::new(
28 ChannelMock {
29 id: options.id,
30 label: options.label.clone(),
31 channel: Pin::new(channel),
32 options,
33 read_delay_future: None,
34 write_delay_future: None,
35 on_close,
36 on_close_sender,
37 },
38 );
39 }
40}
41
/// Configuration for a [`ChannelMock`]; build it from `Default` or `Random`
/// and adjust it with the `with_*` methods.
#[derive(Clone, Debug, PartialEq)]
pub struct ChannelMockOptions {
    /// Channel identifier reported by the mock.
    id: u16,
    /// Channel label reported by the mock.
    label: String,
    /// Inclusive range passed to `wait_random` for the per-operation delay
    /// (units are whatever `wait_random` uses; presumably milliseconds, TODO confirm).
    latency_range: RangeInclusive<u64>,
    /// Buffer size reported by the mock; `channel_mock_pair` also uses the first
    /// options' value as the duplex buffer capacity.
    buffer_size: u32,
}
49
50impl ChannelMockOptions {
51 pub fn with_id(
52 self,
53 id: u16,
54 ) -> ChannelMockOptions {
55 return ChannelMockOptions {
56 id,
57 ..self
58 };
59 }
60
61 pub fn with_label(
62 self,
63 label: impl AsRef<str> + ToString,
64 ) -> ChannelMockOptions {
65 return ChannelMockOptions {
66 label: label.to_string(),
67 ..self
68 };
69 }
70
71 pub fn with_latency(
72 self,
73 latency_range: RangeInclusive<u64>,
74 ) -> ChannelMockOptions {
75 return ChannelMockOptions {
76 latency_range,
77 ..self
78 };
79 }
80
81 pub fn with_buffer_size(
82 self,
83 buffer_size: u32,
84 ) -> ChannelMockOptions {
85 return ChannelMockOptions {
86 buffer_size,
87 ..self
88 };
89 }
90}
91
92impl Random for ChannelMockOptions {
93 fn random() -> Self {
94 let min = random_number(0..5);
95 let max = random_number(5..=50);
96
97 return ChannelMockOptions::default()
98 .with_latency(min..=max);
99 }
100}
101
102impl Default for ChannelMockOptions {
103 fn default() -> ChannelMockOptions {
104 return ChannelMockOptions {
105 id: random_number(0..=u16::MAX),
106 label: format!("channel-mock-{}", random_str(8)),
107 latency_range: (0..=0),
108 buffer_size: 4_096,
109 };
110 }
111}
112
113pub fn channel_mock_pair(
114 options1: ChannelMockOptions,
115 options2: ChannelMockOptions,
116) -> (Box<dyn Channel>, Box<dyn Channel>) {
117 let (channel1, channel2) = duplex(options1.buffer_size as usize);
118
119 return (
120 ChannelMock::new(Box::new(channel1), options1),
121 ChannelMock::new(Box::new(channel2), options2),
122 );
123}
124
125impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> Channel for ChannelMock<TAsyncDuplex> {
126 fn id(&self) -> u16 {
127 return self.id;
128 }
129
130 fn label(&self) -> &String {
131 return &self.label;
132 }
133
134 fn is_closed(&self) -> bool {
135 return *self.on_close.borrow();
136 }
137
138 fn on_close(&self) -> watch::Receiver<bool> {
139 return watch::Receiver::clone(&self.on_close);
140 }
141
142 fn buffer_size(&self) -> u32 {
143 return self.options.buffer_size;
144 }
145}
146
147impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> fmt::Debug for ChannelMock<TAsyncDuplex> {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 return self.debug("ChannelMock", f);
150 }
151}
152
153impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for ChannelMock<TAsyncDuplex> {
154 fn poll_read(
155 mut self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 buf: &mut ReadBuf<'_>,
158 ) -> Poll<io::Result<()>> {
159 if let Some(read_delay_future) = self.read_delay_future.as_mut() {
161 ready!(read_delay_future.as_mut().poll(cx));
162
163 self.read_delay_future.take();
164 }
165
166 let result = ready!(self.channel.as_mut().poll_read(cx, buf));
168
169 self.read_delay_future = Some(Box::pin(wait_random(self.options.latency_range.clone())));
170
171 return Poll::Ready(result);
172 }
173}
174
175impl<TAsyncDuplex: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for ChannelMock<TAsyncDuplex> {
176 fn poll_write(
177 mut self: Pin<&mut Self>,
178 cx: &mut Context<'_>,
179 buf: &[u8],
180 ) -> Poll<io::Result<usize>> {
181 if let Some(write_delay_future) = self.write_delay_future.as_mut() {
183 ready!(write_delay_future.as_mut().poll(cx));
184
185 self.write_delay_future.take();
186 }
187
188 let result = ready!(self.channel.as_mut().poll_write(cx, buf));
189
190 self.write_delay_future = Some(Box::pin(wait_random(self.options.latency_range.clone())));
191
192 return Poll::Ready(result);
193 }
194
195 fn poll_flush(
196 mut self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 ) -> Poll<io::Result<()>> {
199 return self.channel.as_mut()
200 .poll_flush(cx);
201 }
202
203 fn poll_shutdown(
204 mut self: Pin<&mut Self>,
205 cx: &mut Context<'_>,
206 ) -> Poll<io::Result<()>> {
207 if let Some(read_delay_future) = self.read_delay_future.as_mut() {
209 ready!(read_delay_future.as_mut().poll(cx));
210
211 self.read_delay_future.take();
212 }
213
214 let result = ready!(self.channel.as_mut().poll_shutdown(cx));
215
216 let _err = self.on_close_sender.send(true);
217
218 return Poll::Ready(result);
219 }
220}
221
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use cs_utils::{random_number, traits::Random};

    use crate::test::{test_async_stream, test_framed_stream, TestOptions, TestStreamMessage};
    use crate::utils::create_framed_stream;

    use super::channel_mock_pair;

    /// Raw bytes written on one mock arrive intact on its peer, for a range of payload sizes.
    #[rstest]
    #[case(128)]
    #[case(256)]
    #[case(512)]
    #[case(1_024)]
    #[case(2_048)]
    #[case(4_096)]
    #[case(8_192)]
    #[case(16_384)]
    #[case(32_768)]
    #[case(65_536)]
    #[tokio::test]
    async fn transfers_binary_data(#[case] test_data_len: usize) {
        let (local, remote) = channel_mock_pair(Random::random(), Random::random());
        let options = TestOptions::random().with_data_len(test_data_len);

        test_async_stream(local, remote, options).await;
    }

    /// Framed messages sent over a mock pair arrive intact, for a range of message counts.
    #[rstest]
    #[case(random_number(6..=8))]
    #[case(random_number(12..=16))]
    #[case(random_number(25..=32))]
    #[case(random_number(53..=64))]
    #[case(random_number(100..=128))]
    #[case(random_number(200..=256))]
    #[tokio::test]
    async fn transfers_stream_data(#[case] items_count: usize) {
        let (local, remote) = channel_mock_pair(Random::random(), Random::random());

        let local = create_framed_stream::<TestStreamMessage, _>(local);
        let remote = create_framed_stream::<TestStreamMessage, _>(remote);
        let options = TestOptions::random().with_data_len(items_count);

        test_framed_stream(local, remote, options).await;
    }
}