1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use std::{io::IoSlice, pin::Pin};

/* ---------------------------------------------------------------------------------------------- */
/*                                        TRAIT DEFINITIONS                                       */
/* ---------------------------------------------------------------------------------------------- */
///
/// Sink for framed messages, layered on top of an underlying transport.
///
/// A single message is written by polling [`poll_start_write`](Self::poll_start_write) once,
/// then [`poll_write`](Self::poll_write) until every byte has been accepted, and finally
/// [`poll_end_write`](Self::poll_end_write). [`crate::util::write_vectored_all`] drives exactly
/// this sequence. Every hook except `poll_write` has a default implementation that is
/// immediately ready with `Ok(())`.
///
pub trait AsyncFrameWrite: Send + Sync + Unpin {
    /// Hook polled before the body of a message is written.
    ///
    /// The default implementation does nothing and is immediately ready.
    fn poll_start_write(
        self: Pin<&mut Self>,
        _ctx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::task::Poll::Ready(std::io::Result::Ok(()))
    }

    /// Attempts to write bytes from `bufs`, taken in order, to the transport.
    ///
    /// On success, resolves to the number of bytes accepted from the front of `bufs`. Callers
    /// such as [`crate::util::write_vectored_all`] poll this repeatedly, dropping the accepted
    /// prefix each time, until the whole message has been written.
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut std::task::Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>>;

    /// Hook polled after the last byte of a single message has been written.
    ///
    /// The default implementation does nothing and is immediately ready.
    fn poll_end_write(
        self: Pin<&mut Self>,
        _ctx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::task::Poll::Ready(std::io::Result::Ok(()))
    }

    /// Flushes the underlying transport.
    ///
    /// The default implementation does nothing and is immediately ready.
    fn poll_flush(
        self: Pin<&mut Self>,
        _ctx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::task::Poll::Ready(std::io::Result::Ok(()))
    }

    /// Shuts down the underlying transport.
    ///
    /// The default implementation does nothing and is immediately ready.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        _ctx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::task::Poll::Ready(std::io::Result::Ok(()))
    }
}

///
/// Source of framed messages, read from an underlying transport.
///
pub trait AsyncFrameRead: Send + Sync + Unpin {
    /// Attempts to read bytes from the transport into the front of `buf`.
    ///
    /// On success, resolves to the number of bytes stored at the start of `buf`; this should
    /// not exceed `buf.len()`. [`crate::util::read_all`] treats `Ok(0)` as end of stream.
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>>;
}

/* ---------------------------------------------------------------------------------------------- */
/*                                            UTILITIES                                           */
/* ---------------------------------------------------------------------------------------------- */
pub mod util {
    //! Helpers that drive [`AsyncFrameWrite`] and [`AsyncFrameRead`] implementations until an
    //! entire message has been transferred.

    use std::{future::poll_fn, io::IoSlice, pin::Pin};

    use crate::{AsyncFrameRead, AsyncFrameWrite};

    /// Writes every byte of `bufs` to `this` as a single message.
    ///
    /// `poll_start_write` is awaited once before the first write and `poll_end_write` once
    /// after the last one. In between, `poll_write` is awaited repeatedly until all slices
    /// have been accepted. `bufs` is used as scratch space: its slices are advanced in place
    /// as data is accepted, so its contents are unspecified afterwards.
    ///
    /// Returns the total number of bytes written, i.e. the sum of the slice lengths.
    ///
    /// # Errors
    ///
    /// - Any error from `poll_start_write`, `poll_write` or `poll_end_write` is propagated.
    ///   `poll_end_write` is not polled if an earlier step fails.
    /// - [`std::io::ErrorKind::WriteZero`] if `poll_write` accepts `0` bytes while data
    ///   remains. Without this check the loop would never terminate.
    /// - [`std::io::ErrorKind::Other`] if `poll_write` claims more bytes than were supplied.
    pub async fn write_vectored_all(
        this: &mut dyn AsyncFrameWrite,
        mut bufs: &'_ mut [IoSlice<'_>],
    ) -> std::io::Result<usize> {
        let mut total_written = 0;

        poll_fn(|cx| Pin::new(&mut *this).poll_start_write(cx)).await?;

        // Strip leading empty slices. From here on `bufs` is either empty or starts with a
        // non-empty slice, so a `0` return below can only mean "no progress was made".
        bufs = advance_slices(bufs, 0)?;

        while !bufs.is_empty() {
            let n = poll_fn(|cx| Pin::new(&mut *this).poll_write(cx, bufs)).await?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }

            total_written += n;
            bufs = advance_slices(bufs, n)?;
        }

        poll_fn(|cx| Pin::new(&mut *this).poll_end_write(cx)).await?;
        Ok(total_written)
    }

    /// Drops the first `n` bytes from the front of `bufs` and returns the remaining slices.
    ///
    /// Fully consumed slices are removed, along with any empty slices that directly follow
    /// them. The first partially consumed slice is shortened in place. On return, the list
    /// is either empty or starts with a non-empty slice.
    ///
    /// This is a stand-in for `IoSlice::advance_slices`, which is only stable since Rust 1.81.
    ///
    /// # Errors
    ///
    /// Returns [`std::io::ErrorKind::Other`] if `n` exceeds the total length of `bufs`.
    fn advance_slices<'b, 'a>(
        bufs: &'b mut [IoSlice<'a>],
        mut n: usize,
    ) -> std::io::Result<&'b mut [IoSlice<'a>]> {
        let mut consumed = 0;
        for buf in bufs.iter() {
            if buf.len() > n {
                break;
            }
            n -= buf.len();
            consumed += 1;
        }
        let bufs = &mut bufs[consumed..];

        if n > 0 {
            let first = match bufs.first_mut() {
                Some(first) => first,
                None => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "writer reported more bytes than were supplied",
                    ))
                }
            };
            // SAFETY: the scan above stopped at `first` because `first.len() > n`, so the
            // range `n..first.len()` lies inside the bytes `first` already borrows for `'a`.
            // Rebuilding the slice by hand is needed because `IoSlice` only derefs to a slice
            // tied to the borrow of the `IoSlice` itself, not to `'a`.
            let rest: &'a [u8] =
                unsafe { std::slice::from_raw_parts(first.as_ptr().add(n), first.len() - n) };
            *first = IoSlice::new(rest);
        }

        Ok(bufs)
    }

    /// Reads from `this` until `buf` is completely filled.
    ///
    /// Returns `buf.len()` on success. An empty `buf` returns `0` without polling `this`.
    ///
    /// # Errors
    ///
    /// - Any error from `poll_read` is propagated. `buf` may then be partially filled.
    /// - [`std::io::ErrorKind::UnexpectedEof`] if `poll_read` returns `0` before `buf` is full.
    /// - [`std::io::ErrorKind::Other`] if `poll_read` claims more bytes than `buf` can hold.
    pub async fn read_all(
        this: &mut dyn AsyncFrameRead,
        mut buf: &'_ mut [u8],
    ) -> std::io::Result<usize> {
        let mut total_read = 0;

        while !buf.is_empty() {
            let n = poll_fn(|cx| Pin::new(&mut *this).poll_read(cx, buf)).await?;

            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "unexpected EOF",
                ));
            }
            if n > buf.len() {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "reader reported more bytes than the buffer holds",
                ));
            }

            total_read += n;
            buf = &mut buf[n..];
        }

        Ok(total_read)
    }
}