rpc_it/
transport.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6pub use bytes::Bytes;
7pub use bytes::{Buf, BytesMut};
8use futures_util::{AsyncWrite, Stream};
9
/* --------------------------------------------- -- --------------------------------------------- */

12pub struct FrameReader<'a> {
13    inner: &'a mut BytesMut,
14    read_offset: usize,
15}
16
17impl AsRef<[u8]> for FrameReader<'_> {
18    fn as_ref(&self) -> &[u8] {
19        self.chunk()
20    }
21}
22
23impl<'a> Buf for FrameReader<'a> {
24    fn remaining(&self) -> usize {
25        self.inner.len() - self.read_offset
26    }
27
28    fn chunk(&self) -> &[u8] {
29        &self.inner[self.read_offset..]
30    }
31
32    fn advance(&mut self, cnt: usize) {
33        self.read_offset += cnt;
34        assert!(self.read_offset <= self.inner.len());
35    }
36}
37
38impl<'a> FrameReader<'a> {
39    pub fn new(inner: &'a mut BytesMut) -> Self {
40        Self { inner, read_offset: 0 }
41    }
42
43    pub fn as_slice(&self) -> &[u8] {
44        self.chunk()
45    }
46
47    pub fn take(&mut self) -> BytesMut {
48        let read_offset = std::mem::take(&mut self.read_offset);
49        if self.inner.capacity() > self.inner.len() * 2 {
50            // NOTE: In this case, assumes that the buffer is actively reused.
51            // - In this case, if the consumer wants to retrieve `Vec<u8>` from output BytesMut,
52            //   it may deeply clone the underlying buffer since the buffer ownership is currently
53            //   shared.
54            self.inner.split_off(read_offset)
55        } else {
56            // Buffer maybe automatically expanded over write operation, so we assume that the
57            // buffer won't be reused. In this case, we can just take the whole buffer, and take
58            // the ownership of the buffer to minimize copy.
59            std::mem::take(&mut self.inner)
60        }
61    }
62
63    pub fn advanced(&self) -> usize {
64        self.read_offset
65    }
66
67    pub fn advance(&mut self, cnt: usize) {
68        <Self as Buf>::advance(self, cnt);
69    }
70
71    pub fn is_empty(&self) -> bool {
72        self.read_offset == self.inner.len()
73    }
74}

/* --------------------------------------------- -- --------------------------------------------- */

78pub trait AsyncFrameWrite: Send + 'static {
79    /// Called before writing a frame. This can be used to deal with writing cancellation.
80    fn begin_write_frame(self: Pin<&mut Self>, len: usize) -> std::io::Result<()> {
81        let _ = (len,);
82        Ok(())
83    }
84
85    /// Write a frame to the underlying transport. It can be called multiple times to write a single
86    /// frame. In this case, the input buffer should be advanced accordingly.
87    fn poll_write(
88        self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90        buf: &mut FrameReader,
91    ) -> Poll<std::io::Result<()>>;
92
93    /// Flush the underlying transport.
94    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
95        let _ = (cx,);
96        Poll::Ready(Ok(()))
97    }
98
99    /// Close the underlying transport.
100    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
101        let _ = (cx,);
102        Poll::Ready(Ok(()))
103    }
104}
105
106/// Futures adaptor for [`AsyncWriteFrame`]
107impl<T> AsyncFrameWrite for T
108where
109    T: AsyncWrite + Send + 'static,
110{
111    fn poll_write(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &mut FrameReader,
115    ) -> Poll<std::io::Result<()>> {
116        match self.poll_write(cx, buf.as_ref())? {
117            Poll::Ready(x) => {
118                buf.advance(x);
119                Poll::Ready(Ok(()))
120            }
121            Poll::Pending => Poll::Pending,
122        }
123    }
124
125    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
126        self.poll_flush(cx)
127    }
128
129    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
130        self.poll_close(cx)
131    }
132}

/* --------------------------------------------- -- --------------------------------------------- */

136pub trait AsyncFrameRead: Send + Sync + 'static {
137    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<Bytes>>;
138}
139
140impl<T: Stream<Item = std::io::Result<Bytes>> + Sync + Send + 'static> AsyncFrameRead for T {
141    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<Bytes>> {
142        self.poll_next(cx).map(|x| x.unwrap_or_else(|| Err(std::io::ErrorKind::BrokenPipe.into())))
143    }
144}