1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use std::{
    pin::Pin,
    task::{Context, Poll},
};

pub use bytes::Bytes;
pub use bytes::{Buf, BytesMut};
use futures_util::{AsyncWrite, Stream};

/* --------------------------------------------- -- --------------------------------------------- */

/// Read cursor over a borrowed [`BytesMut`] holding frame bytes.
///
/// Bytes in front of the cursor count as consumed. The unread remainder is exposed through the
/// [`Buf`] implementation and [`FrameReader::as_slice`]. Advancing never modifies the borrowed
/// buffer; only [`FrameReader::take`] does.
#[derive(Debug)]
pub struct FrameReader<'a> {
    /// Buffer holding the frame bytes.
    inner: &'a mut BytesMut,
    /// Number of bytes of `inner` already consumed; kept `<= inner.len()`.
    read_offset: usize,
}

impl AsRef<[u8]> for FrameReader<'_> {
    /// Returns the unread bytes, same as [`FrameReader::as_slice`].
    fn as_ref(&self) -> &[u8] {
        self.chunk()
    }
}

impl Buf for FrameReader<'_> {
    fn remaining(&self) -> usize {
        self.inner.len() - self.read_offset
    }

    fn chunk(&self) -> &[u8] {
        &self.inner[self.read_offset..]
    }

    /// Consumes `cnt` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `cnt` is greater than [`Buf::remaining`].
    fn advance(&mut self, cnt: usize) {
        // Validate before mutating. A check after `read_offset += cnt` misses a wrapped
        // (release-build) addition, and a caught panic would leave the cursor corrupted.
        let remaining = self.remaining();
        assert!(
            cnt <= remaining,
            "cannot advance past the end of the frame: cnt = {}, remaining = {}",
            cnt,
            remaining
        );
        self.read_offset += cnt;
    }
}

impl<'a> FrameReader<'a> {
    /// Creates a reader over `inner` with nothing consumed yet.
    pub fn new(inner: &'a mut BytesMut) -> Self {
        Self { inner, read_offset: 0 }
    }

    /// Returns the unread bytes.
    pub fn as_slice(&self) -> &[u8] {
        self.chunk()
    }

    /// Moves the unread bytes out of the reader.
    ///
    /// Afterwards the reader and the borrowed buffer are both empty, and the cursor is reset
    /// to zero. The already-consumed prefix is never part of the returned buffer.
    ///
    /// NOTE: on an owned `FrameReader` value, `reader.take()` resolves to [`Buf::take`], which
    /// expects a `limit` argument. Call this through a `&mut FrameReader`, or write
    /// `FrameReader::take(&mut reader)`.
    pub fn take(&mut self) -> BytesMut {
        // Pick the strategy from the buffer as the caller left it, before touching it.
        let likely_reused = self.inner.capacity() > self.inner.len() * 2;

        // Drop the consumed prefix so both strategies below return exactly the unread bytes.
        // Then reset our own cursor to match.
        let consumed = std::mem::take(&mut self.read_offset);
        self.inner.advance(consumed);

        if likely_reused {
            // Plenty of spare capacity: assume the buffer is actively reused. Split the unread
            // bytes off and leave the now-empty `inner` holding the spare capacity for the next
            // writes. The returned `BytesMut` shares its allocation with `inner`, so a consumer
            // that converts it into a `Vec<u8>` may end up deep-copying it.
            self.inner.split()
        } else {
            // The buffer was most likely grown on demand by write operations and won't be
            // reused. Move the whole allocation out to avoid a copy; `inner` is left empty.
            std::mem::take(&mut *self.inner)
        }
    }

    /// Returns how many bytes have been consumed so far.
    pub fn advanced(&self) -> usize {
        self.read_offset
    }

    /// Consumes `cnt` bytes; inherent shortcut for [`Buf::advance`].
    ///
    /// # Panics
    ///
    /// Panics if `cnt` is greater than the number of unread bytes.
    pub fn advance(&mut self, cnt: usize) {
        <Self as Buf>::advance(self, cnt);
    }

    /// Returns `true` once every byte has been consumed.
    pub fn is_empty(&self) -> bool {
        self.read_offset == self.inner.len()
    }
}

/* --------------------------------------------- -- --------------------------------------------- */

/// Sink for outgoing frames.
///
/// Only [`poll_write`](AsyncFrameWrite::poll_write) must be implemented. The other hooks default
/// to no-ops that succeed immediately.
pub trait AsyncFrameWrite: Send + 'static {
    /// Hook invoked before a frame is written; can be used to deal with writing cancellation.
    ///
    /// `len` is presumably the byte length of the upcoming frame (TODO confirm with callers).
    /// The default implementation ignores it and returns `Ok(())`.
    fn begin_write_frame(self: Pin<&mut Self>, _len: usize) -> std::io::Result<()> {
        Ok(())
    }

    /// Writes frame bytes from `buf` to the underlying transport.
    ///
    /// A single frame may take several calls. Each call must advance `buf` by the number of
    /// bytes it consumed.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut FrameReader<'_>,
    ) -> Poll<std::io::Result<()>>;

    /// Flushes the underlying transport. The default implementation is immediately ready.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the underlying transport. The default implementation is immediately ready.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

/// Futures adaptor for [`AsyncWriteFrame`]
impl<T> AsyncFrameWrite for T
where
    T: AsyncWrite + Send + 'static,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut FrameReader,
    ) -> Poll<std::io::Result<()>> {
        match self.poll_write(cx, buf.as_ref())? {
            Poll::Ready(x) => {
                buf.advance(x);
                Poll::Ready(Ok(()))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.poll_close(cx)
    }
}

/* --------------------------------------------- -- --------------------------------------------- */

pub trait AsyncFrameRead: Send + Sync + 'static {
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<Bytes>>;
}

impl<T: Stream<Item = std::io::Result<Bytes>> + Sync + Send + 'static> AsyncFrameRead for T {
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<Bytes>> {
        self.poll_next(cx).map(|x| x.unwrap_or_else(|| Err(std::io::ErrorKind::BrokenPipe.into())))
    }
}