r_extcap/controls/asynchronous/
util.rs

1//! Random assortment of utility methods.
2
3use async_trait::async_trait;
4use tokio::io::{AsyncRead, AsyncReadExt as _};
5
6/// Extension trait for [`AsyncRead`].
7#[async_trait]
8pub trait AsyncReadExt: AsyncRead + Unpin {
9    /// Reads the exact number of bytes, like `read_exact`, but returns `None` if it gets EOF at
10    /// the start of the read. In other words, this is the "all or nothing" version of `read`.
11    async fn try_read_exact<const N: usize>(&mut self) -> std::io::Result<Option<[u8; N]>> {
12        let mut buf = [0_u8; N];
13        let mut count = 0_usize;
14        while count < N {
15            let read_bytes = self.read(&mut buf[count..]).await?;
16            if read_bytes == 0 {
17                if count == 0 {
18                    return Ok(None);
19                } else {
20                    return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
21                }
22            }
23            count += read_bytes;
24        }
25        Ok(Some(buf))
26    }
27}
28
29impl<R: ?Sized + AsyncRead + Unpin> AsyncReadExt for R {}
30
#[cfg(test)]
mod test {
    use super::AsyncReadExt;
    // Brings tokio's `chain` combinator into scope without clashing with our trait's name.
    use tokio::io::AsyncReadExt as _;

    #[tokio::test]
    async fn try_read_exact_success() {
        let bytes = b"test";
        let read_bytes = (&mut &bytes[..]).try_read_exact::<4>().await.unwrap();
        assert_eq!(Some(bytes), read_bytes.as_ref());
    }

    #[tokio::test]
    async fn try_read_exact_long_success() {
        let bytes = b"testing long string";
        let mut slice = &bytes[..];
        assert_eq!(
            Some(b"test"),
            (&mut slice).try_read_exact::<4>().await.unwrap().as_ref()
        );
        assert_eq!(
            Some(b"ing "),
            (&mut slice).try_read_exact::<4>().await.unwrap().as_ref()
        );
    }

    #[tokio::test]
    async fn try_read_exact_none() {
        let bytes = b"";
        let read_bytes = (&mut &bytes[..]).try_read_exact::<4>().await.unwrap();
        assert_eq!(None, read_bytes);
    }

    #[tokio::test]
    async fn try_read_exact_unexpected_eof() {
        let bytes = b"tt";
        let read_bytes = (&mut &bytes[..]).try_read_exact::<4>().await;
        assert_eq!(
            read_bytes.unwrap_err().kind(),
            std::io::ErrorKind::UnexpectedEof
        );
    }

    /// A plain `&[u8]` satisfies the whole request in one `read`, so the tests above never
    /// loop. A `Chain` hands back each half from a separate `read` call, which exercises the
    /// accumulation path of `try_read_exact`.
    #[tokio::test]
    async fn try_read_exact_across_multiple_reads() {
        let mut reader = (&b"te"[..]).chain(&b"st"[..]);
        assert_eq!(Some(*b"test"), reader.try_read_exact::<4>().await.unwrap());
    }

    /// EOF discovered only after several successful partial reads must still be reported as
    /// `UnexpectedEof`, not as `None`.
    #[tokio::test]
    async fn try_read_exact_unexpected_eof_across_multiple_reads() {
        let mut reader = (&b"t"[..]).chain(&b"t"[..]);
        assert_eq!(
            reader.try_read_exact::<4>().await.unwrap_err().kind(),
            std::io::ErrorKind::UnexpectedEof
        );
    }
}