distant_net/common/transport.rs

1use std::time::Duration;
2use std::{fmt, io};
3
4use async_trait::async_trait;
5
6mod framed;
7pub use framed::*;
8
9mod inmemory;
10pub use inmemory::*;
11
12mod tcp;
13pub use tcp::*;
14
15#[cfg(test)]
16mod test;
17
18#[cfg(test)]
19pub use test::*;
20
21#[cfg(unix)]
22mod unix;
23
24#[cfg(unix)]
25pub use unix::*;
26
27#[cfg(windows)]
28mod windows;
29
30pub use tokio::io::{Interest, Ready};
31#[cfg(windows)]
32pub use windows::*;
33
/// Pause applied by the looping helpers of [`TransportExt`] (for example `read_exact`) after an
/// attempt fails with a retryable error such as `WouldBlock`, so that retrying does not peg the
/// CPU.
const SLEEP_DURATION: Duration = Duration::from_millis(1);
36
37/// Interface representing a connection that is reconnectable.
38#[async_trait]
39pub trait Reconnectable {
40    /// Attempts to reconnect an already-established connection.
41    async fn reconnect(&mut self) -> io::Result<()>;
42}
43
44/// Interface representing a transport of raw bytes into and out of the system.
45#[async_trait]
46pub trait Transport: Reconnectable + fmt::Debug + Send + Sync {
47    /// Tries to read data from the transport into the provided buffer, returning how many bytes
48    /// were read.
49    ///
50    /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
51    /// is not ready to read data.
52    ///
53    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
54    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize>;
55
56    /// Try to write a buffer to the transport, returning how many bytes were written.
57    ///
58    /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
59    /// is not ready to write data.
60    ///
61    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
62    fn try_write(&self, buf: &[u8]) -> io::Result<usize>;
63
64    /// Waits for the transport to be ready based on the given interest, returning the ready
65    /// status.
66    async fn ready(&self, interest: Interest) -> io::Result<Ready>;
67}
68
69#[async_trait]
70impl Transport for Box<dyn Transport> {
71    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
72        Transport::try_read(AsRef::as_ref(self), buf)
73    }
74
75    fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
76        Transport::try_write(AsRef::as_ref(self), buf)
77    }
78
79    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
80        Transport::ready(AsRef::as_ref(self), interest).await
81    }
82}
83
84#[async_trait]
85impl Reconnectable for Box<dyn Transport> {
86    async fn reconnect(&mut self) -> io::Result<()> {
87        Reconnectable::reconnect(AsMut::as_mut(self)).await
88    }
89}
90
91#[async_trait]
92pub trait TransportExt {
93    /// Waits for the transport to be readable to follow up with `try_read`.
94    async fn readable(&self) -> io::Result<()>;
95
96    /// Waits for the transport to be writeable to follow up with `try_write`.
97    async fn writeable(&self) -> io::Result<()>;
98
99    /// Waits for the transport to be either readable or writeable.
100    async fn readable_or_writeable(&self) -> io::Result<()>;
101
102    /// Reads exactly `n` bytes where `n` is the length of `buf` by continuing to call [`try_read`]
103    /// until completed. Calls to [`readable`] are made to ensure the transport is ready. Returns
104    /// the total bytes read.
105    ///
106    /// [`try_read`]: Transport::try_read
107    /// [`readable`]: Transport::readable
108    async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize>;
109
110    /// Reads all bytes until EOF in this source, placing them into `buf`.
111    ///
112    /// All bytes read from this source will be appended to the specified buffer `buf`. This
113    /// function will continuously call [`try_read`] to append more data to `buf` until
114    /// [`try_read`] returns either [`Ok(0)`] or an error that is neither [`Interrupted`] or
115    /// [`WouldBlock`].
116    ///
117    /// If successful, this function will return the total number of bytes read.
118    ///
119    /// ### Errors
120    ///
121    /// If this function encounters an error of the kind [`Interrupted`] or [`WouldBlock`], then
122    /// the error is ignored and the operation will continue.
123    ///
124    /// If any other read error is encountered then this function immediately returns. Any bytes
125    /// which have already been read will be appended to `buf`.
126    ///
127    /// [`Ok(0)`]: Ok
128    /// [`try_read`]: Transport::try_read
129    /// [`readable`]: Transport::readable
130    async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize>;
131
132    /// Reads all bytes until EOF in this source, placing them into `buf`.
133    ///
134    /// If successful, this function will return the total number of bytes read.
135    ///
136    /// ### Errors
137    ///
138    /// If the data in this stream is *not* valid UTF-8 then an error is returned and `buf` is
139    /// unchanged.
140    ///
141    /// See [`read_to_end`] for other error semantics.
142    ///
143    /// [`Ok(0)`]: Ok
144    /// [`try_read`]: Transport::try_read
145    /// [`readable`]: Transport::readable
146    /// [`read_to_end`]: TransportExt::read_to_end
147    async fn read_to_string(&self, buf: &mut String) -> io::Result<usize>;
148
149    /// Writes all of `buf` by continuing to call [`try_write`] until completed. Calls to
150    /// [`writeable`] are made to ensure the transport is ready.
151    ///
152    /// [`try_write`]: Transport::try_write
153    /// [`writable`]: Transport::writable
154    async fn write_all(&self, buf: &[u8]) -> io::Result<()>;
155}
156
157#[async_trait]
158impl<T: Transport> TransportExt for T {
159    async fn readable(&self) -> io::Result<()> {
160        self.ready(Interest::READABLE).await?;
161        Ok(())
162    }
163
164    async fn writeable(&self) -> io::Result<()> {
165        self.ready(Interest::WRITABLE).await?;
166        Ok(())
167    }
168
169    async fn readable_or_writeable(&self) -> io::Result<()> {
170        self.ready(Interest::READABLE | Interest::WRITABLE).await?;
171        Ok(())
172    }
173
174    async fn read_exact(&self, buf: &mut [u8]) -> io::Result<usize> {
175        let mut i = 0;
176
177        while i < buf.len() {
178            self.readable().await?;
179
180            match self.try_read(&mut buf[i..]) {
181                // If we get 0 bytes read, this usually means that the underlying reader
182                // has closed, so we will return an EOF error to reflect that
183                //
184                // NOTE: `try_read` can also return 0 if the buf len is zero, but because we check
185                //       that our index is < len, the situation where we call try_read with a buf
186                //       of len 0 will never happen
187                Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
188
189                Ok(n) => i += n,
190
191                // Because we are using `try_read`, it can be possible for it to return
192                // WouldBlock; so, if we encounter that then we just wait for next readable
193                Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
194                    // NOTE: We sleep for a little bit before trying again to avoid pegging CPU
195                    tokio::time::sleep(SLEEP_DURATION).await
196                }
197
198                Err(x) => return Err(x),
199            }
200        }
201
202        Ok(i)
203    }
204
205    async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
206        let mut i = 0;
207        let mut tmp = [0u8; 1024];
208
209        loop {
210            self.readable().await?;
211
212            match self.try_read(&mut tmp) {
213                Ok(0) => return Ok(i),
214                Ok(n) => {
215                    buf.extend_from_slice(&tmp[..n]);
216                    i += n;
217                }
218                Err(x)
219                    if x.kind() == io::ErrorKind::WouldBlock
220                        || x.kind() == io::ErrorKind::Interrupted =>
221                {
222                    // NOTE: We sleep for a little bit before trying again to avoid pegging CPU
223                    tokio::time::sleep(SLEEP_DURATION).await
224                }
225
226                Err(x) => return Err(x),
227            }
228        }
229    }
230
231    async fn read_to_string(&self, buf: &mut String) -> io::Result<usize> {
232        let mut tmp = Vec::new();
233        let n = self.read_to_end(&mut tmp).await?;
234        buf.push_str(
235            &String::from_utf8(tmp).map_err(|x| io::Error::new(io::ErrorKind::InvalidData, x))?,
236        );
237        Ok(n)
238    }
239
240    async fn write_all(&self, buf: &[u8]) -> io::Result<()> {
241        let mut i = 0;
242
243        while i < buf.len() {
244            self.writeable().await?;
245
246            match self.try_write(&buf[i..]) {
247                // If we get 0 bytes written, this usually means that the underlying writer
248                // has closed, so we will return a write zero error to reflect that
249                //
250                // NOTE: `try_write` can also return 0 if the buf len is zero, but because we check
251                //       that our index is < len, the situation where we call try_write with a buf
252                //       of len 0 will never happen
253                Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
254
255                Ok(n) => i += n,
256
257                // Because we are using `try_write`, it can be possible for it to return
258                // WouldBlock; so, if we encounter that then we just wait for next writeable
259                Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
260                    // NOTE: We sleep for a little bit before trying again to avoid pegging CPU
261                    tokio::time::sleep(SLEEP_DURATION).await
262                }
263
264                Err(x) => return Err(x),
265            }
266        }
267
268        Ok(())
269    }
270}
271
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};

    use test_log::test;

    use super::*;

    // NOTE: The stateful mocks below keep their call counters in `static` atomics declared inside
    //       each closure. Every closure therefore gets its own independent counter, without the
    //       `unsafe` that `static mut` would require.

    #[test(tokio::test)]
    async fn read_exact_should_fail_if_try_read_encounters_error_other_than_would_block() {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 1];
        assert_eq!(
            transport.read_exact(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn read_exact_should_fail_if_try_read_returns_0_before_necessary_bytes_read() {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Ok(0)),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 1];
        assert_eq!(
            transport.read_exact(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test(tokio::test)]
    async fn read_exact_should_continue_to_call_try_read_until_buffer_is_filled() {
        // Configure `try_read` to hand out one successive letter per call
        let transport = TestTransport {
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                buf[0] = b'a' + CNT.fetch_add(1, Ordering::SeqCst);
                Ok(1)
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 3];
        assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
        assert_eq!(&buf, b"abc");
    }

    #[test(tokio::test)]
    async fn read_exact_should_continue_to_call_try_read_while_it_returns_would_block() {
        // Configure `try_read` to alternate between reading a byte and WouldBlock
        let transport = TestTransport {
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                let cnt = CNT.fetch_add(1, Ordering::SeqCst);

                // A letter is written on every call, but it only counts when `Ok` is returned.
                // The letter written alongside WouldBlock is overwritten by the next call.
                buf[0] = b'a' + cnt;
                if cnt % 2 == 0 {
                    Ok(1)
                } else {
                    Err(io::Error::from(io::ErrorKind::WouldBlock))
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 3];
        assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 3);
        assert_eq!(&buf, b"ace");
    }

    #[test(tokio::test)]
    async fn read_exact_should_return_0_if_given_a_buffer_of_0_len() {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = [0; 0];
        assert_eq!(transport.read_exact(&mut buf).await.unwrap(), 0);
    }

    #[test(tokio::test)]
    async fn read_to_end_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
    ) {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport
                .read_to_end(&mut Vec::new())
                .await
                .unwrap_err()
                .kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn read_to_end_should_read_until_0_bytes_returned_from_try_read() {
        // Configure `try_read` to produce data once and then signal EOF forever after
        let transport = TestTransport {
            f_try_read: Box::new(|buf| {
                static DONE: AtomicBool = AtomicBool::new(false);
                if !DONE.swap(true, Ordering::SeqCst) {
                    buf[..5].copy_from_slice(b"hello");
                    Ok(5)
                } else {
                    Ok(0)
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = Vec::new();
        assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 5);
        assert_eq!(buf, b"hello");
    }

    #[test(tokio::test)]
    async fn read_to_end_should_continue_reading_when_interrupt_or_would_block_encountered() {
        // Configure `try_read` to interleave data with WouldBlock and Interrupted errors
        let transport = TestTransport {
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                match CNT.fetch_add(1, Ordering::SeqCst) {
                    0 => {
                        buf[..6].copy_from_slice(b"hello ");
                        Ok(6)
                    }
                    1 => Err(io::Error::from(io::ErrorKind::WouldBlock)),
                    2 => {
                        buf[..5].copy_from_slice(b"world");
                        Ok(5)
                    }
                    3 => Err(io::Error::from(io::ErrorKind::Interrupted)),
                    4 => {
                        buf[..6].copy_from_slice(b", test");
                        Ok(6)
                    }
                    _ => Ok(0),
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = Vec::new();
        assert_eq!(transport.read_to_end(&mut buf).await.unwrap(), 17);
        assert_eq!(buf, b"hello world, test");
    }

    #[test(tokio::test)]
    async fn read_to_string_should_fail_if_try_read_encounters_error_other_than_would_block_and_interrupt(
    ) {
        let transport = TestTransport {
            f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport
                .read_to_string(&mut String::new())
                .await
                .unwrap_err()
                .kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn read_to_string_should_fail_if_non_utf8_characters_read() {
        // Configure `try_read` to produce one chunk of invalid UTF-8 followed by EOF
        let transport = TestTransport {
            f_try_read: Box::new(|buf| {
                static DONE: AtomicBool = AtomicBool::new(false);
                if !DONE.swap(true, Ordering::SeqCst) {
                    buf[..4].copy_from_slice(&[0, 159, 146, 150]);
                    Ok(4)
                } else {
                    Ok(0)
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = String::new();
        assert_eq!(
            transport.read_to_string(&mut buf).await.unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
        assert!(buf.is_empty(), "buf should be left untouched on invalid UTF-8");
    }

    #[test(tokio::test)]
    async fn read_to_string_should_read_until_0_bytes_returned_from_try_read() {
        // Configure `try_read` to produce data once and then signal EOF forever after
        let transport = TestTransport {
            f_try_read: Box::new(|buf| {
                static DONE: AtomicBool = AtomicBool::new(false);
                if !DONE.swap(true, Ordering::SeqCst) {
                    buf[..5].copy_from_slice(b"hello");
                    Ok(5)
                } else {
                    Ok(0)
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = String::new();
        assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 5);
        assert_eq!(buf, "hello");
    }

    #[test(tokio::test)]
    async fn read_to_string_should_continue_reading_when_interrupt_or_would_block_encountered() {
        // Configure `try_read` to interleave data with WouldBlock and Interrupted errors
        let transport = TestTransport {
            f_try_read: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                match CNT.fetch_add(1, Ordering::SeqCst) {
                    0 => {
                        buf[..6].copy_from_slice(b"hello ");
                        Ok(6)
                    }
                    1 => Err(io::Error::from(io::ErrorKind::WouldBlock)),
                    2 => {
                        buf[..5].copy_from_slice(b"world");
                        Ok(5)
                    }
                    3 => Err(io::Error::from(io::ErrorKind::Interrupted)),
                    4 => {
                        buf[..6].copy_from_slice(b", test");
                        Ok(6)
                    }
                    _ => Ok(0),
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::READABLE)),
            ..Default::default()
        };

        let mut buf = String::new();
        assert_eq!(transport.read_to_string(&mut buf).await.unwrap(), 17);
        assert_eq!(buf, "hello world, test");
    }

    #[test(tokio::test)]
    async fn write_all_should_fail_if_try_write_encounters_error_other_than_would_block() {
        let transport = TestTransport {
            f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport.write_all(b"abc").await.unwrap_err().kind(),
            io::ErrorKind::NotConnected
        );
    }

    #[test(tokio::test)]
    async fn write_all_should_fail_if_try_write_returns_0_before_all_bytes_written() {
        let transport = TestTransport {
            f_try_write: Box::new(|_| Ok(0)),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        assert_eq!(
            transport.write_all(b"abc").await.unwrap_err().kind(),
            io::ErrorKind::WriteZero
        );
    }

    #[test(tokio::test)]
    async fn write_all_should_continue_to_call_try_write_until_all_bytes_written() {
        // Configure `try_write` to accept exactly one byte per call, checking that the bytes
        // arrive in order
        let transport = TestTransport {
            f_try_write: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                assert_eq!(buf[0], b'a' + CNT.fetch_add(1, Ordering::SeqCst));
                Ok(1)
            }),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        transport.write_all(b"abc").await.unwrap();
    }

    #[test(tokio::test)]
    async fn write_all_should_continue_to_call_try_write_while_it_returns_would_block() {
        // Configure `try_write` to alternate between writing a byte and WouldBlock
        let transport = TestTransport {
            f_try_write: Box::new(|buf| {
                static CNT: AtomicU8 = AtomicU8::new(0);
                let cnt = CNT.fetch_add(1, Ordering::SeqCst);
                if cnt % 2 == 0 {
                    // Even calls accept one byte. Odd calls consume nothing, so the expected
                    // letter advances by two each time ("a", "c", "e").
                    assert_eq!(buf[0], b'a' + cnt);
                    Ok(1)
                } else {
                    Err(io::Error::from(io::ErrorKind::WouldBlock))
                }
            }),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        transport.write_all(b"ace").await.unwrap();
    }

    #[test(tokio::test)]
    async fn write_all_should_return_immediately_if_given_buffer_of_0_len() {
        let transport = TestTransport {
            f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
            f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
            ..Default::default()
        };

        // No error takes place as we never call try_write
        let buf = [0; 0];
        transport.write_all(&buf).await.unwrap();
    }
}