// distant_net/common/transport/inmemory.rs
use std::io;
use std::sync::{Mutex, MutexGuard};

use async_trait::async_trait;
use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
use tokio::sync::mpsc::{self};

use super::{Interest, Ready, Reconnectable, Transport};
9
#[derive(Debug)]
pub struct InmemoryTransport {
    // Sender used by `try_write` to push outgoing byte messages.
    tx: mpsc::Sender<Vec<u8>>,
    // Receiver of incoming byte messages; wrapped in a mutex so read
    // operations can work through `&self`.
    rx: Mutex<mpsc::Receiver<Vec<u8>>>,

    // Holds bytes already pulled off `rx` but not yet consumed by `try_read`
    // (leftover from an oversized message, or a message rescued while probing
    // channel status in `is_rx_closed`).
    buf: Mutex<Option<Vec<u8>>>,
}
19
impl InmemoryTransport {
    /// Creates a transport that writes outgoing bytes to `tx` and reads
    /// incoming bytes from `rx`.
    pub fn new(tx: mpsc::Sender<Vec<u8>>, rx: mpsc::Receiver<Vec<u8>>) -> Self {
        Self {
            tx,
            rx: Mutex::new(rx),
            buf: Mutex::new(None),
        }
    }

    /// Creates a transport with channel capacity `buffer` along with the far
    /// ends of its channels: a sender that feeds the transport's reads and a
    /// receiver that yields whatever the transport writes.
    pub fn make(buffer: usize) -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
        let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
        let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);

        (
            incoming_tx,
            outgoing_rx,
            Self::new(outgoing_tx, incoming_rx),
        )
    }

    /// Creates two transports wired to each other: what one writes, the other
    /// reads, with capacity `buffer` in each direction.
    pub fn pair(buffer: usize) -> (Self, Self) {
        let (tx, rx, transport) = Self::make(buffer);
        (transport, Self::new(tx, rx))
    }

    /// Re-links `self` and `other` over fresh channels of capacity `buffer`.
    /// Any data still in the old channels or in either read buffer is
    /// discarded.
    pub fn link(&mut self, other: &mut InmemoryTransport, buffer: usize) {
        let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
        let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);

        self.buf = Mutex::new(None);
        self.tx = outgoing_tx;
        self.rx = Mutex::new(incoming_rx);

        other.buf = Mutex::new(None);
        other.tx = incoming_tx;
        other.rx = Mutex::new(outgoing_rx);
    }

    /// Returns true only when the internal rx channel is disconnected AND
    /// fully drained.
    ///
    /// Side effect: probing the channel may pull a ready message off of it;
    /// to avoid losing data, that message is appended onto the internal `buf`
    /// so a later `try_read` still sees it.
    fn is_rx_closed(&self) -> bool {
        // NOTE: the rx guard is a match-scrutinee temporary, so it stays
        // locked for the entire match — including while `buf` is locked below
        // (lock order here is rx -> buf).
        match self.rx.lock().unwrap().try_recv() {
            Ok(mut data) => {
                let mut buf_lock = self.buf.lock().unwrap();

                // Append newly-received bytes after any existing leftover
                let data = match buf_lock.take() {
                    Some(mut existing) => {
                        existing.append(&mut data);
                        existing
                    }
                    None => data,
                };

                *buf_lock = Some(data);

                false
            }
            Err(TryRecvError::Empty) => false,
            Err(TryRecvError::Disconnected) => true,
        }
    }
}
108
109#[async_trait]
110impl Reconnectable for InmemoryTransport {
111 async fn reconnect(&mut self) -> io::Result<()> {
117 if self.tx.is_closed() || self.is_rx_closed() {
118 Err(io::Error::from(io::ErrorKind::ConnectionRefused))
119 } else {
120 Ok(())
121 }
122 }
123}
124
125#[async_trait]
126impl Transport for InmemoryTransport {
127 fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
128 let mut buf_lock = self.buf.lock().unwrap();
131
132 if let Some(data) = buf_lock.take() {
134 return Ok(copy_and_store(buf_lock, data, buf));
135 }
136
137 match self.rx.lock().unwrap().try_recv() {
138 Ok(data) => Ok(copy_and_store(buf_lock, data, buf)),
139 Err(TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
140 Err(TryRecvError::Disconnected) => Ok(0),
141 }
142 }
143
144 fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
145 match self.tx.try_send(buf.to_vec()) {
146 Ok(()) => Ok(buf.len()),
147 Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
148 Err(TrySendError::Closed(_)) => Ok(0),
149 }
150 }
151
152 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
153 let mut status = Ready::EMPTY;
154
155 if interest.is_readable() {
156 status |= if self.is_rx_closed() && self.buf.lock().unwrap().is_none() {
161 Ready::READ_CLOSED
162 } else {
163 Ready::READABLE
164 };
165 }
166
167 if interest.is_writable() {
168 status |= if self.tx.is_closed() {
169 Ready::WRITE_CLOSED
170 } else {
171 Ready::WRITABLE
172 };
173 }
174
175 Ok(status)
176 }
177}
178
/// Copies as much of `data` as fits into `out`, returning the number of bytes
/// written. Any bytes that did not fit are stashed back into the guarded
/// buffer slot so a subsequent read can pick them up.
fn copy_and_store(
    mut buf_lock: MutexGuard<Option<Vec<u8>>>,
    mut data: Vec<u8>,
    out: &mut [u8],
) -> usize {
    let capacity = out.len();

    if data.len() <= capacity {
        // Everything fits: copy it all and leave the buffer slot untouched
        let written = data.len();
        out[..written].copy_from_slice(&data);
        written
    } else {
        // Too big: keep the overflow for later, then fill `out` completely
        let leftover = data.split_off(capacity);
        out.copy_from_slice(&data);
        *buf_lock = Some(leftover);
        capacity
    }
}
200
#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;
    use crate::common::TransportExt;

    #[test]
    fn is_rx_closed_should_properly_reflect_if_internal_rx_channel_is_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Not closed while the channel is empty but the sender is alive
        assert!(!transport.is_rx_closed());

        read_tx.try_send(b"some bytes".to_vec()).unwrap();

        // Not closed when data is available; the probed message gets buffered
        assert!(!transport.is_rx_closed());
        assert_eq!(
            transport.buf.lock().unwrap().as_deref().unwrap(),
            b"some bytes"
        );

        read_tx.try_send(b"more".to_vec()).unwrap();
        drop(read_tx);

        // Still not closed: undelivered data remains after the sender dropped,
        // and it is appended onto the existing buffered bytes
        assert!(!transport.is_rx_closed());
        assert_eq!(
            transport.buf.lock().unwrap().as_deref().unwrap(),
            b"some bytesmore"
        );

        // Closed once the sender is gone and the channel is drained; the
        // buffered bytes are left intact
        assert!(transport.is_rx_closed());
        assert_eq!(
            transport.buf.lock().unwrap().as_deref().unwrap(),
            b"some bytesmore"
        );
    }

    #[test]
    fn try_read_should_succeed_if_able_to_read_entire_data_through_channel() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        read_tx.try_send(b"some bytes".to_vec()).unwrap();

        let mut buf = [0; 10];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 10);
        assert_eq!(&buf[..10], b"some bytes");
    }

    #[test]
    fn try_read_should_succeed_if_reading_cached_data_from_previous_read() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        read_tx.try_send(b"some bytes".to_vec()).unwrap();

        // First read is smaller than the message; the rest is cached
        let mut buf = [0; 5];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 5);
        assert_eq!(&buf[..5], b"some ");

        // A new message queued behind cached data must not jump the line
        read_tx.try_send(b"more".to_vec()).unwrap();

        let mut buf = [0; 2];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 2);
        assert_eq!(&buf[..2], b"by");

        // Cached data is exhausted before the channel is touched again
        let mut buf = [0; 5];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 3);
        assert_eq!(&buf[..3], b"tes");

        let mut buf = [0; 5];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 4);
        assert_eq!(&buf[..4], b"more");
    }

    #[test]
    fn try_read_should_fail_with_would_block_if_channel_is_empty() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        assert_eq!(
            transport.try_read(&mut [0; 5]).unwrap_err().kind(),
            io::ErrorKind::WouldBlock
        );
    }

    #[test]
    fn try_read_should_succeed_with_zero_bytes_read_if_channel_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        drop(read_tx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        assert_eq!(transport.try_read(&mut [0; 5]).unwrap(), 0);
    }

    #[test]
    fn try_write_should_succeed_if_able_to_send_data_through_channel() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        let value = b"some bytes";
        assert_eq!(transport.try_write(value).unwrap(), value.len());
    }

    #[test]
    fn try_write_should_fail_with_would_block_if_channel_capacity_has_been_reached() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Capacity is 1, so the first write fills the channel
        transport
            .try_write(b"some bytes")
            .expect("Failed to fill channel");

        assert_eq!(
            transport.try_write(b"some bytes").unwrap_err().kind(),
            io::ErrorKind::WouldBlock
        );
    }

    #[test]
    fn try_write_should_succeed_with_zero_bytes_written_if_channel_closed() {
        let (write_tx, write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        drop(write_rx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        assert_eq!(transport.try_write(b"some bytes").unwrap(), 0);
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_read_channel_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        // Sender is dropped immediately, closing the read side
        let (_, read_rx) = mpsc::channel(1);
        let mut transport = InmemoryTransport::new(write_tx, read_rx);

        assert_eq!(
            transport.reconnect().await.unwrap_err().kind(),
            io::ErrorKind::ConnectionRefused
        );
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_write_channel_closed() {
        // Receiver is dropped immediately, closing the write side
        let (write_tx, _) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);
        let mut transport = InmemoryTransport::new(write_tx, read_rx);

        assert_eq!(
            transport.reconnect().await.unwrap_err().kind(),
            io::ErrorKind::ConnectionRefused
        );
    }

    #[test(tokio::test)]
    async fn reconnect_should_succeed_if_both_channels_open() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);
        let mut transport = InmemoryTransport::new(write_tx, read_rx);

        transport.reconnect().await.unwrap();
    }

    #[test(tokio::test)]
    async fn ready_should_report_read_closed_if_channel_closed_and_internal_buf_empty() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        drop(read_tx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_readable_if_channel_not_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(!ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_readable_if_internal_buf_not_empty() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        drop(read_tx);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Even with the channel closed, buffered bytes keep reads open
        *transport.buf.lock().unwrap() = Some(vec![1]);

        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(!ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_writable_if_channel_not_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::WRITABLE).await.unwrap();
        assert!(ready.is_writable());
        assert!(!ready.is_write_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_write_closed_if_channel_closed() {
        let (write_tx, write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        drop(write_rx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::WRITABLE).await.unwrap();
        assert!(ready.is_writable());
        assert!(ready.is_write_closed());
    }

    #[test(tokio::test)]
    async fn make_should_return_sender_that_sends_data_to_transport() {
        let (tx, _, transport) = InmemoryTransport::make(3);

        tx.send(b"test msg 1".to_vec()).await.unwrap();
        tx.send(b"test msg 2".to_vec()).await.unwrap();
        tx.send(b"test msg 3".to_vec()).await.unwrap();

        let mut buf = [0; 256];
        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(&buf[..len], b"test msg 1");

        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(&buf[..len], b"test msg 2");

        // Dropping the sender must not lose the message still in the channel
        drop(tx);

        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(&buf[..len], b"test msg 3");

        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(len, 0, "Unexpectedly got more data");
    }

    #[test(tokio::test)]
    async fn make_should_return_receiver_that_receives_data_from_transport() {
        let (_, mut rx, transport) = InmemoryTransport::make(3);

        transport.write_all(b"test msg 1").await.unwrap();
        transport.write_all(b"test msg 2").await.unwrap();
        transport.write_all(b"test msg 3").await.unwrap();

        assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));

        assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));

        // Dropping the transport must not lose the message still in flight
        drop(transport);

        assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));

        assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
    }
}