distant_net/common/transport/inmemory.rs

use std::io;
use std::sync::{Mutex, MutexGuard};

use async_trait::async_trait;
use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
use tokio::sync::mpsc::{self};

use super::{Interest, Ready, Reconnectable, Transport};

/// Represents a [`Transport`] composed of two in-memory channels
#[derive(Debug)]
pub struct InmemoryTransport {
    tx: mpsc::Sender<Vec<u8>>,
    rx: Mutex<mpsc::Receiver<Vec<u8>>>,

    /// Internal storage used when we get more data from a `try_read` than can be returned
    buf: Mutex<Option<Vec<u8>>>,
}

impl InmemoryTransport {
    /// Creates a new transport where `tx` is used to send data out of the transport during
    /// [`try_write`] and `rx` is used to receive data into the transport during [`try_read`].
    ///
    /// [`try_read`]: Transport::try_read
    /// [`try_write`]: Transport::try_write
    pub fn new(tx: mpsc::Sender<Vec<u8>>, rx: mpsc::Receiver<Vec<u8>>) -> Self {
        Self {
            tx,
            rx: Mutex::new(rx),
            buf: Mutex::new(None),
        }
    }

    /// Returns `(incoming_tx, outgoing_rx, transport)` where `incoming_tx` is used to send data
    /// to the transport where it will be consumed during [`try_read`] and `outgoing_rx` is used
    /// to receive data from the transport when it is written using [`try_write`].
    ///
    /// [`try_read`]: Transport::try_read
    /// [`try_write`]: Transport::try_write
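    ///
    /// ### Example
    ///
    /// A minimal usage sketch; the `distant_net::common` import path is an assumption about how
    /// the crate re-exports this module, so adjust it to the actual re-exports if they differ:
    ///
    /// ```ignore
    /// use distant_net::common::{InmemoryTransport, Transport};
    ///
    /// let (incoming_tx, mut outgoing_rx, transport) = InmemoryTransport::make(1);
    ///
    /// // Data sent over `incoming_tx` is consumed by `try_read`
    /// incoming_tx.try_send(b"hello".to_vec()).unwrap();
    /// let mut buf = [0u8; 5];
    /// assert_eq!(transport.try_read(&mut buf).unwrap(), 5);
    /// assert_eq!(&buf, b"hello");
    ///
    /// // Data written via `try_write` arrives on `outgoing_rx`
    /// transport.try_write(b"world").unwrap();
    /// assert_eq!(outgoing_rx.try_recv().unwrap(), b"world".to_vec());
    /// ```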
    pub fn make(buffer: usize) -> (mpsc::Sender<Vec<u8>>, mpsc::Receiver<Vec<u8>>, Self) {
        let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
        let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);

        (
            incoming_tx,
            outgoing_rx,
            Self::new(outgoing_tx, incoming_rx),
        )
    }

    /// Returns a pair of transports that are connected such that one sends to the other and
    /// vice versa
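    ///
    /// ### Example
    ///
    /// A minimal sketch (same assumed import path as in [`make`](Self::make)):
    ///
    /// ```ignore
    /// use distant_net::common::{InmemoryTransport, Transport};
    ///
    /// let (t1, t2) = InmemoryTransport::pair(1);
    ///
    /// // Bytes written to one side become readable on the other
    /// t1.try_write(b"ping").unwrap();
    /// let mut buf = [0u8; 4];
    /// assert_eq!(t2.try_read(&mut buf).unwrap(), 4);
    /// assert_eq!(&buf, b"ping");
    /// ```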
    pub fn pair(buffer: usize) -> (Self, Self) {
        let (tx, rx, transport) = Self::make(buffer);
        (transport, Self::new(tx, rx))
    }

    /// Links two independent [`InmemoryTransport`] together by dropping their internal channels
    /// and generating new ones of `buffer` capacity to connect these transports.
    ///
    /// ### Note
    ///
    /// This will drop any pre-existing data in the internal storage to avoid corruption.
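    ///
    /// ### Example
    ///
    /// A minimal sketch (same assumed import path as in [`make`](Self::make)):
    ///
    /// ```ignore
    /// use distant_net::common::{InmemoryTransport, Transport};
    ///
    /// // Two transports that start out unconnected to each other
    /// let (mut t1, _) = InmemoryTransport::pair(1);
    /// let (mut t2, _) = InmemoryTransport::pair(1);
    ///
    /// // After linking, writes on one side become readable on the other
    /// t1.link(&mut t2, 1);
    /// t1.try_write(b"hi").unwrap();
    /// let mut buf = [0u8; 2];
    /// assert_eq!(t2.try_read(&mut buf).unwrap(), 2);
    /// assert_eq!(&buf, b"hi");
    /// ```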
    pub fn link(&mut self, other: &mut InmemoryTransport, buffer: usize) {
        let (incoming_tx, incoming_rx) = mpsc::channel(buffer);
        let (outgoing_tx, outgoing_rx) = mpsc::channel(buffer);

        self.buf = Mutex::new(None);
        self.tx = outgoing_tx;
        self.rx = Mutex::new(incoming_rx);

        other.buf = Mutex::new(None);
        other.tx = incoming_tx;
        other.rx = Mutex::new(outgoing_rx);
    }

    /// Returns true if the read channel is closed, meaning it will no longer receive more data.
    /// This does not factor in data remaining in the internal buffer, so it may return true
    /// while the transport still has buffered data available.
    ///
    /// NOTE: Because there is no `is_closed` on the receiver, we have to actually try to
    ///       read from the receiver to see if it is disconnected, adding any received data
    ///       to our internal buffer if it is not disconnected and has data available.
    ///
    /// Track <https://github.com/tokio-rs/tokio/issues/4638> for a future `is_closed` on rx.
    fn is_rx_closed(&self) -> bool {
        match self.rx.lock().unwrap().try_recv() {
            Ok(mut data) => {
                let mut buf_lock = self.buf.lock().unwrap();

                let data = match buf_lock.take() {
                    Some(mut existing) => {
                        existing.append(&mut data);
                        existing
                    }
                    None => data,
                };

                *buf_lock = Some(data);

                false
            }
            Err(TryRecvError::Empty) => false,
            Err(TryRecvError::Disconnected) => true,
        }
    }
}

#[async_trait]
impl Reconnectable for InmemoryTransport {
    /// Once the underlying channels have closed, there is no way for this transport to
    /// re-establish those channels; therefore, reconnecting will fail with
    /// [`ErrorKind::ConnectionRefused`] if either underlying channel has closed.
    ///
    /// [`ErrorKind::ConnectionRefused`]: io::ErrorKind::ConnectionRefused
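    ///
    /// ### Example
    ///
    /// A sketch of the failure mode (same assumed import path as in
    /// [`make`](InmemoryTransport::make)):
    ///
    /// ```ignore
    /// use distant_net::common::{InmemoryTransport, Reconnectable};
    ///
    /// # async fn demo() {
    /// let (mut t1, t2) = InmemoryTransport::pair(1);
    ///
    /// // Dropping the other side closes both underlying channels,
    /// // so reconnecting can only fail
    /// drop(t2);
    /// assert_eq!(
    ///     t1.reconnect().await.unwrap_err().kind(),
    ///     std::io::ErrorKind::ConnectionRefused
    /// );
    /// # }
    /// ```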
    async fn reconnect(&mut self) -> io::Result<()> {
        if self.tx.is_closed() || self.is_rx_closed() {
            Err(io::Error::from(io::ErrorKind::ConnectionRefused))
        } else {
            Ok(())
        }
    }
}

#[async_trait]
impl Transport for InmemoryTransport {
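    /// Attempts to read data received over the incoming channel, draining any bytes left over in
    /// the internal buffer first. Returns `Ok(0)` only once the channel has closed and the
    /// internal buffer is empty (mirroring the EOF convention of [`std::io::Read`]), and fails
    /// with [`ErrorKind::WouldBlock`] when no data is available yet.
    ///
    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock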
    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
        // Lock our internal storage for the lifetime of this call to ensure that nothing else
        // mutates it, keeping data read and stored in order
        let mut buf_lock = self.buf.lock().unwrap();

        // Check if we have data in our internal buffer, and if so feed it into the outgoing buf
        if let Some(data) = buf_lock.take() {
            return Ok(copy_and_store(buf_lock, data, buf));
        }

        match self.rx.lock().unwrap().try_recv() {
            Ok(data) => Ok(copy_and_store(buf_lock, data, buf)),
            Err(TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
            Err(TryRecvError::Disconnected) => Ok(0),
        }
    }
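    /// Attempts to send `buf` as a single message over the outgoing channel. Returns `Ok(0)` if
    /// the channel has closed and fails with [`ErrorKind::WouldBlock`] if the channel is at
    /// capacity.
    ///
    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock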
    fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
        match self.tx.try_send(buf.to_vec()) {
            Ok(()) => Ok(buf.len()),
            Err(TrySendError::Full(_)) => Err(io::Error::from(io::ErrorKind::WouldBlock)),
            Err(TrySendError::Closed(_)) => Ok(0),
        }
    }

    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
        let mut status = Ready::EMPTY;

        if interest.is_readable() {
            // TODO: Replace `self.is_rx_closed()` with `self.rx.is_closed()` once the tokio issue
            //       is resolved that adds `is_closed` to the `mpsc::Receiver`
            //
            // See https://github.com/tokio-rs/tokio/issues/4638
            status |= if self.is_rx_closed() && self.buf.lock().unwrap().is_none() {
                Ready::READ_CLOSED
            } else {
                Ready::READABLE
            };
        }

        if interest.is_writable() {
            status |= if self.tx.is_closed() {
                Ready::WRITE_CLOSED
            } else {
                Ready::WRITABLE
            };
        }

        Ok(status)
    }
}

/// Copies `data` into `out`, storing any overflow from `data` into the storage pointed to by the
/// mutex `buf_lock`
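///
/// For example, copying 10 bytes of `data` into a 4-byte `out` writes the first 4 bytes into
/// `out`, stores the remaining 6 bytes behind `buf_lock`, and returns `4`.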
fn copy_and_store(
    mut buf_lock: MutexGuard<Option<Vec<u8>>>,
    mut data: Vec<u8>,
    out: &mut [u8],
) -> usize {
    // NOTE: We can get data that is larger than the destination buf; so,
    //       we store as much as we can and queue up the rest in our temporary
    //       storage for future retrievals
    if data.len() > out.len() {
        let n = out.len();
        out.copy_from_slice(&data[..n]);
        *buf_lock = Some(data.split_off(n));
        n
    } else {
        let n = data.len();
        out[..n].copy_from_slice(&data);
        n
    }
}

#[cfg(test)]
mod tests {
    use test_log::test;

    use super::*;
    use crate::common::TransportExt;

    #[test]
    fn is_rx_closed_should_properly_reflect_if_internal_rx_channel_is_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Not closed when the channel is empty
        assert!(!transport.is_rx_closed());

        read_tx.try_send(b"some bytes".to_vec()).unwrap();

        // Not closed when the channel has data (will queue up data)
        assert!(!transport.is_rx_closed());
        assert_eq!(
            transport.buf.lock().unwrap().as_deref().unwrap(),
            b"some bytes"
        );

        // Queue up one more set of bytes and then close the channel
        read_tx.try_send(b"more".to_vec()).unwrap();
        drop(read_tx);

        // Not closed when channel has closed but has something remaining in the queue
        assert!(!transport.is_rx_closed());
        assert_eq!(
            transport.buf.lock().unwrap().as_deref().unwrap(),
            b"some bytesmore"
        );

        // Closed once there is nothing left in the channel and it has closed
        assert!(transport.is_rx_closed());
        assert_eq!(
            transport.buf.lock().unwrap().as_deref().unwrap(),
            b"some bytesmore"
        );
    }

    #[test]
    fn try_read_should_succeed_if_able_to_read_entire_data_through_channel() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Queue up some data to be read
        read_tx.try_send(b"some bytes".to_vec()).unwrap();

        let mut buf = [0; 10];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 10);
        assert_eq!(&buf[..10], b"some bytes");
    }

    #[test]
    fn try_read_should_succeed_if_reading_cached_data_from_previous_read() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Queue up some data to be read
        read_tx.try_send(b"some bytes".to_vec()).unwrap();

        let mut buf = [0; 5];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 5);
        assert_eq!(&buf[..5], b"some ");

        // Queue up some new data to be read (previous data already consumed)
        read_tx.try_send(b"more".to_vec()).unwrap();

        let mut buf = [0; 2];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 2);
        assert_eq!(&buf[..2], b"by");

        // The transport does not merge buffered bytes with the next channel recv(), so this read
        // only returns what was left over from the previous read
        let mut buf = [0; 5];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 3);
        assert_eq!(&buf[..3], b"tes");

        let mut buf = [0; 5];
        assert_eq!(transport.try_read(&mut buf).unwrap(), 4);
        assert_eq!(&buf[..4], b"more");
    }

    #[test]
    fn try_read_should_fail_with_would_block_if_channel_is_empty() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        assert_eq!(
            transport.try_read(&mut [0; 5]).unwrap_err().kind(),
            io::ErrorKind::WouldBlock
        );
    }

    #[test]
    fn try_read_should_succeed_with_zero_bytes_read_if_channel_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        // Drop to close the read channel
        drop(read_tx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        assert_eq!(transport.try_read(&mut [0; 5]).unwrap(), 0);
    }

    #[test]
    fn try_write_should_succeed_if_able_to_send_data_through_channel() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        let value = b"some bytes";
        assert_eq!(transport.try_write(value).unwrap(), value.len());
    }

    #[test]
    fn try_write_should_fail_with_would_block_if_channel_capacity_has_been_reached() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Fill up the channel
        transport
            .try_write(b"some bytes")
            .expect("Failed to fill channel");

        assert_eq!(
            transport.try_write(b"some bytes").unwrap_err().kind(),
            io::ErrorKind::WouldBlock
        );
    }

    #[test]
    fn try_write_should_succeed_with_zero_bytes_written_if_channel_closed() {
        let (write_tx, write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        // Drop to close the write channel
        drop(write_rx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        assert_eq!(transport.try_write(b"some bytes").unwrap(), 0);
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_read_channel_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_, read_rx) = mpsc::channel(1);
        let mut transport = InmemoryTransport::new(write_tx, read_rx);

        assert_eq!(
            transport.reconnect().await.unwrap_err().kind(),
            io::ErrorKind::ConnectionRefused
        );
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_write_channel_closed() {
        let (write_tx, _) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);
        let mut transport = InmemoryTransport::new(write_tx, read_rx);

        assert_eq!(
            transport.reconnect().await.unwrap_err().kind(),
            io::ErrorKind::ConnectionRefused
        );
    }

    #[test(tokio::test)]
    async fn reconnect_should_succeed_if_both_channels_open() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);
        let mut transport = InmemoryTransport::new(write_tx, read_rx);

        transport.reconnect().await.unwrap();
    }

    #[test(tokio::test)]
    async fn ready_should_report_read_closed_if_channel_closed_and_internal_buf_empty() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        // Drop to close the read channel
        drop(read_tx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_readable_if_channel_not_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(!ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_readable_if_internal_buf_not_empty() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (read_tx, read_rx) = mpsc::channel(1);

        // Drop to close the read channel
        drop(read_tx);

        let transport = InmemoryTransport::new(write_tx, read_rx);

        // Assign some data to our buffer to ensure that we test this condition
        *transport.buf.lock().unwrap() = Some(vec![1]);

        let ready = transport.ready(Interest::READABLE).await.unwrap();
        assert!(ready.is_readable());
        assert!(!ready.is_read_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_writable_if_channel_not_closed() {
        let (write_tx, _write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::WRITABLE).await.unwrap();
        assert!(ready.is_writable());
        assert!(!ready.is_write_closed());
    }

    #[test(tokio::test)]
    async fn ready_should_report_write_closed_if_channel_closed() {
        let (write_tx, write_rx) = mpsc::channel(1);
        let (_read_tx, read_rx) = mpsc::channel(1);

        // Drop to close the write channel
        drop(write_rx);

        let transport = InmemoryTransport::new(write_tx, read_rx);
        let ready = transport.ready(Interest::WRITABLE).await.unwrap();
        assert!(ready.is_writable());
        assert!(ready.is_write_closed());
    }

    #[test(tokio::test)]
    async fn make_should_return_sender_that_sends_data_to_transport() {
        let (tx, _, transport) = InmemoryTransport::make(3);

        tx.send(b"test msg 1".to_vec()).await.unwrap();
        tx.send(b"test msg 2".to_vec()).await.unwrap();
        tx.send(b"test msg 3".to_vec()).await.unwrap();

        // Should get data matching a single message
        let mut buf = [0; 256];
        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(&buf[..len], b"test msg 1");

        // The next call should get the second message
        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(&buf[..len], b"test msg 2");

        // When the last of the senders is dropped, we should still get
        // the rest of the data that was sent before receiving
        // an indicator that there is no more data
        drop(tx);

        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(&buf[..len], b"test msg 3");

        let len = transport.try_read(&mut buf).unwrap();
        assert_eq!(len, 0, "Unexpectedly got more data");
    }

    #[test(tokio::test)]
    async fn make_should_return_receiver_that_receives_data_from_transport() {
        let (_, mut rx, transport) = InmemoryTransport::make(3);

        transport.write_all(b"test msg 1").await.unwrap();
        transport.write_all(b"test msg 2").await.unwrap();
        transport.write_all(b"test msg 3").await.unwrap();

        // Should get data matching a single message
        assert_eq!(rx.recv().await, Some(b"test msg 1".to_vec()));

        // The next call should get the second message
        assert_eq!(rx.recv().await, Some(b"test msg 2".to_vec()));

        // When the transport is dropped, we should still get
        // the rest of the data that was sent before receiving
        // an indicator that there is no more data
        drop(transport);

        assert_eq!(rx.recv().await, Some(b"test msg 3".to_vec()));

        assert_eq!(rx.recv().await, None, "Unexpectedly got more data");
    }
}