nym_test_utils/mocks/
async_read_write.rs1use crate::mocks::shared::InnerWrapper;
5use futures::ready;
6use std::fmt::{Display, Formatter};
7use std::io;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::task::{Context, Poll};
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tracing::trace;
14
/// Label used in trace output for the [`Side::Initialiser`] end of a mock stream pair.
const INIT_ID: &str = "initialiser";
/// Label used in trace output for the [`Side::Recipient`] end of a mock stream pair.
const RECV_ID: &str = "recipient";

/// Which end of a `MockIOStream` pair a handle represents.
///
/// Equality on this fieldless enum is total, so it derives `Eq` and `Hash` alongside
/// `PartialEq`, which also allows it to be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Side {
    /// The end created first (`MockIOStream::default()` produces this side).
    Initialiser,
    /// The end derived from an initialiser via `make_connection`.
    Recipient,
}

impl Display for Side {
    /// Writes the side's label (`"initialiser"` or `"recipient"`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Side::Initialiser => INIT_ID,
            Side::Recipient => RECV_ID,
        };
        // `pad` (rather than `write_str`) honours width/fill/alignment flags, exactly like
        // `str`'s own `Display` implementation.
        f.pad(label)
    }
}
32
33pub fn mock_io_streams() -> (MockIOStream, MockIOStream) {
36 let ch1 = MockIOStream::default();
37 let ch2 = ch1.make_connection();
38
39 (ch1, ch2)
40}
41
42pub struct MockIOStream {
43 id: Arc<AtomicU8>,
45
46 side: Side,
48
49 tx: InnerWrapper<Vec<u8>>,
51
52 rx: InnerWrapper<Vec<u8>>,
54}
55
56impl Default for MockIOStream {
57 fn default() -> Self {
58 MockIOStream {
59 id: Arc::new(AtomicU8::new(0)),
60 side: Side::Initialiser,
61 tx: Default::default(),
62 rx: Default::default(),
63 }
64 }
65}
66
67impl MockIOStream {
68 #[allow(clippy::panic)]
69 fn make_connection(&self) -> Self {
70 if self.side != Side::Initialiser {
71 panic!("attempted to make invalid connection")
72 }
73 MockIOStream {
74 id: self.id.clone(),
75 side: Side::Recipient,
76 tx: self.rx.cloned_buffer(),
77 rx: self.tx.cloned_buffer(),
78 }
79 }
80
81 pub fn set_id(&self, id: u8) {
82 self.id.store(id, Ordering::Relaxed)
83 }
84
85 pub fn try_get_remote_handle(&self) -> Self {
88 self.make_connection()
89 }
90
91 #[allow(clippy::unwrap_used)]
93 pub fn unchecked_tx_data(&self) -> Vec<u8> {
94 self.tx.buffer.try_lock().unwrap().content.clone()
95 }
96
97 #[allow(clippy::unwrap_used)]
99 pub fn unchecked_rx_data(&self) -> Vec<u8> {
100 self.rx.buffer.try_lock().unwrap().content.clone()
101 }
102
103 fn log_read(&self, bytes: usize) {
104 let id = self.id.load(Ordering::Relaxed);
105 if id == 0 {
106 trace!("[{}] read {bytes} bytes from mock stream", self.side)
107 } else {
108 trace!("[{}-{id}] read {bytes} bytes from mock stream", self.side)
109 }
110 }
111
112 fn log_write(&self, bytes: usize) {
113 let id = self.id.load(Ordering::Relaxed);
114
115 if id == 0 {
116 trace!("[{}] wrote {bytes} bytes to mock stream", self.side)
117 } else {
118 trace!("[{}-{id}] wrote {bytes} bytes to mock stream", self.side)
119 }
120 }
121}
122
123impl AsyncRead for MockIOStream {
124 fn poll_read(
125 mut self: Pin<&mut Self>,
126 cx: &mut Context<'_>,
127 buf: &mut ReadBuf<'_>,
128 ) -> Poll<io::Result<()>> {
129 ready!(Pin::new(&mut self.rx).poll_guard_ready(cx));
130
131 let unfilled = buf.remaining();
132
133 #[allow(clippy::unwrap_used)]
135 let guard = self.rx.guard().unwrap();
136
137 let data = guard.take_at_most(unfilled);
138 if data.is_empty() {
139 guard.waker = Some(cx.waker().clone());
141
142 self.rx.transition_to_idle();
144 return Poll::Pending;
145 }
146
147 self.log_read(data.len());
148 self.rx.transition_to_idle();
153
154 buf.put_slice(&data);
155 Poll::Ready(Ok(()))
156 }
157}
158
159impl AsyncWrite for MockIOStream {
160 fn poll_write(
161 mut self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 buf: &[u8],
164 ) -> Poll<io::Result<usize>> {
165 ready!(Pin::new(&mut self.tx).poll_guard_ready(cx));
167
168 #[allow(clippy::unwrap_used)]
170 let guard = self.tx.guard().unwrap();
171
172 let len = buf.len();
173 guard.content.extend_from_slice(buf);
174
175 self.log_write(buf.len());
185
186 Poll::Ready(Ok(len))
187 }
188
189 fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190 let Some(guard) = self.tx.guard() else {
191 return Poll::Ready(Err(io::Error::other(
192 "invalid lock state to send/flush messages",
193 )));
194 };
195
196 if let Some(waker) = guard.waker.take() {
197 waker.wake();
199 }
200
201 self.tx.transition_to_idle();
203
204 Poll::Ready(Ok(()))
205 }
206
207 fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208 self.tx.transition_to_idle();
210
211 Poll::Ready(Ok(()))
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Sends a 5-byte payload from each end to the other and checks it arrives intact.
    #[tokio::test]
    async fn basic() {
        let (mut initialiser, mut recipient) = mock_io_streams();

        // initialiser -> recipient
        let outbound = [1u8, 2, 3, 4, 5];
        initialiser.write_all(&outbound).await.unwrap();
        initialiser.flush().await.unwrap();

        let mut inbound = [0u8; 5];
        let count = recipient.read(&mut inbound).await.unwrap();
        assert_eq!(count, 5);
        assert_eq!(inbound, outbound);

        // recipient -> initialiser
        let reply = [6u8, 7, 8, 9, 10];
        recipient.write_all(&reply).await.unwrap();
        recipient.flush().await.unwrap();

        let mut inbound = [0u8; 5];
        let count = initialiser.read(&mut inbound).await.unwrap();
        assert_eq!(count, 5);
        assert_eq!(inbound, reply);
    }
}