// nym_test_utils/mocks/async_read_write.rs

1// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::mocks::shared::InnerWrapper;
5use futures::ready;
6use std::fmt::{Display, Formatter};
7use std::io;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::task::{Context, Poll};
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tracing::trace;
14
const INIT_ID: &str = "initialiser";
const RECV_ID: &str = "recipient";

/// Which end of a mock connection a stream represents.
///
/// Used to label log output and to check which end may create the remote handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Side {
    /// The end that is created first (the default stream).
    Initialiser,
    /// The end derived from an initialiser stream.
    Recipient,
}

impl Display for Side {
    /// Writes the side's label (`"initialiser"` or `"recipient"`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Side::Initialiser => INIT_ID,
            Side::Recipient => RECV_ID,
        };
        // `pad` (rather than `write_str`) honours width/alignment/precision flags,
        // exactly like `str`'s own `Display` implementation.
        f.pad(label)
    }
}
32
33// sending buffer of the first stream is the receiving buffer of the second stream
34// and vice versa
35pub fn mock_io_streams() -> (MockIOStream, MockIOStream) {
36    let ch1 = MockIOStream::default();
37    let ch2 = ch1.make_connection();
38
39    (ch1, ch2)
40}
41
42pub struct MockIOStream {
43    // identifier to use for logging purposes
44    id: Arc<AtomicU8>,
45
46    // side of the stream to use for logging purposes
47    side: Side,
48
49    // messages to send
50    tx: InnerWrapper<Vec<u8>>,
51
52    // messages to receive
53    rx: InnerWrapper<Vec<u8>>,
54}
55
56impl Default for MockIOStream {
57    fn default() -> Self {
58        MockIOStream {
59            id: Arc::new(AtomicU8::new(0)),
60            side: Side::Initialiser,
61            tx: Default::default(),
62            rx: Default::default(),
63        }
64    }
65}
66
67impl MockIOStream {
68    #[allow(clippy::panic)]
69    fn make_connection(&self) -> Self {
70        if self.side != Side::Initialiser {
71            panic!("attempted to make invalid connection")
72        }
73        MockIOStream {
74            id: self.id.clone(),
75            side: Side::Recipient,
76            tx: self.rx.cloned_buffer(),
77            rx: self.tx.cloned_buffer(),
78        }
79    }
80
81    pub fn set_id(&self, id: u8) {
82        self.id.store(id, Ordering::Relaxed)
83    }
84
85    // the prefix `try_` is due to the fact that if the mock is cloned at an invalid state,
86    // `assert!` will fail causing panic (which is fine in **test** code)
87    pub fn try_get_remote_handle(&self) -> Self {
88        self.make_connection()
89    }
90
91    // unwrap in test code is fine
92    #[allow(clippy::unwrap_used)]
93    pub fn unchecked_tx_data(&self) -> Vec<u8> {
94        self.tx.buffer.try_lock().unwrap().content.clone()
95    }
96
97    // unwrap in test code is fine
98    #[allow(clippy::unwrap_used)]
99    pub fn unchecked_rx_data(&self) -> Vec<u8> {
100        self.rx.buffer.try_lock().unwrap().content.clone()
101    }
102
103    fn log_read(&self, bytes: usize) {
104        let id = self.id.load(Ordering::Relaxed);
105        if id == 0 {
106            trace!("[{}] read {bytes} bytes from mock stream", self.side)
107        } else {
108            trace!("[{}-{id}] read {bytes} bytes from mock stream", self.side)
109        }
110    }
111
112    fn log_write(&self, bytes: usize) {
113        let id = self.id.load(Ordering::Relaxed);
114
115        if id == 0 {
116            trace!("[{}] wrote {bytes} bytes to mock stream", self.side)
117        } else {
118            trace!("[{}-{id}] wrote {bytes} bytes to mock stream", self.side)
119        }
120    }
121}
122
123impl AsyncRead for MockIOStream {
124    fn poll_read(
125        mut self: Pin<&mut Self>,
126        cx: &mut Context<'_>,
127        buf: &mut ReadBuf<'_>,
128    ) -> Poll<io::Result<()>> {
129        ready!(Pin::new(&mut self.rx).poll_guard_ready(cx));
130
131        let unfilled = buf.remaining();
132
133        // SAFETY: guard is ready
134        #[allow(clippy::unwrap_used)]
135        let guard = self.rx.guard().unwrap();
136
137        let data = guard.take_at_most(unfilled);
138        if data.is_empty() {
139            // nothing to retrieve - store the waiter so that the sender could trigger it
140            guard.waker = Some(cx.waker().clone());
141
142            // drop the guard so that the sender could actually put messages in
143            self.rx.transition_to_idle();
144            return Poll::Pending;
145        }
146
147        self.log_read(data.len());
148        // if let Some(waker) = guard.waker.take() {
149        //     waker.wake();
150        // }
151
152        self.rx.transition_to_idle();
153
154        buf.put_slice(&data);
155        Poll::Ready(Ok(()))
156    }
157}
158
159impl AsyncWrite for MockIOStream {
160    fn poll_write(
161        mut self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        buf: &[u8],
164    ) -> Poll<io::Result<usize>> {
165        // wait until we transition to the locked state
166        ready!(Pin::new(&mut self.tx).poll_guard_ready(cx));
167
168        // SAFETY: guard is ready
169        #[allow(clippy::unwrap_used)]
170        let guard = self.tx.guard().unwrap();
171
172        let len = buf.len();
173        guard.content.extend_from_slice(buf);
174
175        // TODO: if we wanted the behaviour of always reading everything before writing anything extra
176        // if !guard.content.is_empty() {
177        //     // sanity check
178        //     assert!(guard.waker.is_none());
179        //     guard.waker = Some(cx.waker().clone());
180        //     self.tx.transition_to_idle();
181        //     return Poll::Pending;
182        // }
183
184        self.log_write(buf.len());
185
186        Poll::Ready(Ok(len))
187    }
188
189    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190        let Some(guard) = self.tx.guard() else {
191            return Poll::Ready(Err(io::Error::other(
192                "invalid lock state to send/flush messages",
193            )));
194        };
195
196        if let Some(waker) = guard.waker.take() {
197            // notify the receiver if it was waiting for messages
198            waker.wake();
199        }
200
201        // release the guard
202        self.tx.transition_to_idle();
203
204        Poll::Ready(Ok(()))
205    }
206
207    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208        // make sure our guard is always dropped on close
209        self.tx.transition_to_idle();
210
211        Poll::Ready(Ok(()))
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Bytes flushed on either end arrive, unchanged, at the other end.
    #[tokio::test]
    async fn basic() {
        let (mut initialiser, mut recipient) = mock_io_streams();

        // initialiser -> recipient
        let outbound: [u8; 5] = [1, 2, 3, 4, 5];
        initialiser.write_all(&outbound).await.unwrap();
        initialiser.flush().await.unwrap();

        let mut received = [0u8; 5];
        let count = recipient.read(&mut received).await.unwrap();
        assert_eq!(count, 5);
        assert_eq!(&received[..5], &outbound);

        // recipient -> initialiser
        let inbound: [u8; 5] = [6, 7, 8, 9, 10];
        let mut received = [0u8; 5];
        recipient.write_all(&inbound).await.unwrap();
        recipient.flush().await.unwrap();

        let count = initialiser.read(&mut received).await.unwrap();
        assert_eq!(count, 5);
        assert_eq!(&received[..5], &inbound);
    }
}