1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//! Miscellaneous utility methods.

use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncReadExt as _};

/// Extension methods layered on top of any [`AsyncRead`] that is also [`Unpin`].
#[async_trait]
pub trait AsyncReadExt: AsyncRead + Unpin {
    /// Reads exactly `N` bytes and returns them as an array, "all or nothing" style.
    ///
    /// This mirrors `read_exact`, with one difference: if the reader is already at EOF
    /// before a single byte has been read, the result is `Ok(None)` instead of an error.
    /// When `N` is 0 no read is attempted and `Ok(Some([]))` is returned.
    ///
    /// # Errors
    ///
    /// - [`std::io::ErrorKind::UnexpectedEof`] if EOF is reached after at least one byte
    ///   but before all `N` bytes have arrived.
    /// - Any error returned by the underlying `read`, propagated as-is.
    async fn try_read_exact<const N: usize>(&mut self) -> std::io::Result<Option<[u8; N]>> {
        let mut out = [0_u8; N];
        let mut filled = 0_usize;
        while filled < N {
            // A zero-length read into a non-empty slice signals EOF; whether that means
            // "nothing to read" or a truncated record depends on whether any byte has
            // been consumed yet.
            match self.read(&mut out[filled..]).await? {
                0 if filled == 0 => return Ok(None),
                0 => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)),
                n => filled += n,
            }
        }
        Ok(Some(out))
    }
}

/// Blanket implementation: every unpinned [`AsyncRead`] (sized or not) gets the extension
/// methods of [`AsyncReadExt`] automatically.
impl<R> AsyncReadExt for R
where
    R: ?Sized + AsyncRead + Unpin,
{
}

#[cfg(test)]
mod test {
    use super::AsyncReadExt;

    #[tokio::test]
    async fn try_read_exact_success() {
        let bytes = b"test";
        let read_bytes = (&mut &bytes[..]).try_read_exact::<4>().await.unwrap();
        assert_eq!(Some(bytes), read_bytes.as_ref());
    }

    #[tokio::test]
    async fn try_read_exact_long_success() {
        let bytes = b"testing long string";
        let mut slice = &bytes[..];
        assert_eq!(
            Some(b"test"),
            (&mut slice).try_read_exact::<4>().await.unwrap().as_ref()
        );
        assert_eq!(
            Some(b"ing "),
            (&mut slice).try_read_exact::<4>().await.unwrap().as_ref()
        );
    }

    #[tokio::test]
    async fn try_read_exact_none() {
        let bytes = b"";
        let read_bytes = (&mut &bytes[..]).try_read_exact::<4>().await.unwrap();
        assert_eq!(None, read_bytes);
    }

    #[tokio::test]
    async fn try_read_exact_unexpected_eof() {
        let bytes = b"tt";
        let read_bytes = (&mut &bytes[..]).try_read_exact::<4>().await;
        assert_eq!(
            read_bytes.unwrap_err().kind(),
            std::io::ErrorKind::UnexpectedEof
        );
    }

    /// `&[u8]` hands over all of its bytes in a single `read`, so the tests above never
    /// exercise the accumulation loop. Chaining two slices forces the data to arrive over
    /// several `read` calls.
    #[tokio::test]
    async fn try_read_exact_across_partial_reads() {
        let mut reader = tokio::io::AsyncReadExt::chain(&b"te"[..], &b"st"[..]);
        assert_eq!(Some(*b"test"), reader.try_read_exact::<4>().await.unwrap());
        // The reader is now exhausted, so the next call sees EOF before any byte.
        assert_eq!(None, reader.try_read_exact::<4>().await.unwrap());
    }

    /// EOF after bytes were gathered over more than one `read` call must still be reported
    /// as `UnexpectedEof`.
    #[tokio::test]
    async fn try_read_exact_unexpected_eof_across_partial_reads() {
        let mut reader = tokio::io::AsyncReadExt::chain(&b"te"[..], &b"s"[..]);
        let err = reader.try_read_exact::<4>().await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    /// A zero-length request performs no read at all and therefore never reports EOF.
    #[tokio::test]
    async fn try_read_exact_zero_length() {
        let read_bytes = (&mut &b""[..]).try_read_exact::<0>().await.unwrap();
        assert_eq!(Some([]), read_bytes);
    }
}