async_prost/reader.rs

use byteorder::{ByteOrder, NetworkEndian};
use std::{
    io,
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll},
};

use bytes::{Buf, BytesMut};
use futures_core::{ready, Stream};
use prost::Message;
use tokio::io::{AsyncRead, ReadBuf};

use crate::{AsyncDestination, AsyncFrameDestination, Framed};

const BUFFER_SIZE: usize = 8192;
const LEN_SIZE: usize = 4;

enum FillResult {
    Filled,
    Eof,
}

/// A wrapper around an async reader that produces an asynchronous stream of prost-decoded values
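///
/// # Example
///
/// A minimal sketch of draining messages from a socket. It assumes a prost-generated
/// `MyMessage` type, that `AsyncProstReader` and `AsyncDestination` are re-exported at
/// the crate root, and that a `StreamExt` implementation (e.g. from the `futures`
/// crate) provides `.next()`; none of that is defined in this file.
///
/// ```ignore
/// use async_prost::{AsyncDestination, AsyncProstReader};
/// use futures::StreamExt;
/// use tokio::net::TcpStream;
///
/// let sock = TcpStream::connect("127.0.0.1:4242").await?;
/// // each frame on the wire is a 4-byte network-endian length prefix followed by a
/// // prost-encoded `MyMessage`
/// let mut reader = AsyncProstReader::<_, MyMessage, AsyncDestination>::from(sock);
/// while let Some(msg) = reader.next().await {
///     println!("got {:?}", msg?);
/// }
/// ```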
#[derive(Debug)]
pub struct AsyncProstReader<R, T, D> {
    reader: R,
    pub(crate) buffer: BytesMut,
    into: PhantomData<T>,
    dest: PhantomData<D>,
}
impl<R, T, D> Unpin for AsyncProstReader<R, T, D> where R: Unpin {}

impl<R, T, D> AsyncProstReader<R, T, D> {
    /// create a new reader
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            buffer: BytesMut::with_capacity(BUFFER_SIZE),
            into: PhantomData,
            dest: PhantomData,
        }
    }

    /// gets a reference to the underlying reader
    pub fn get_ref(&self) -> &R {
        &self.reader
    }

    /// gets a mutable reference to the underlying reader
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.reader
    }

    /// returns a reference to the internally buffered data
    pub fn buffer(&self) -> &[u8] {
        &self.buffer[..]
    }

    /// unwrap the `AsyncProstReader`, returning the underlying reader
    pub fn into_inner(self) -> R {
        self.reader
    }
}

impl<R, T, D> Default for AsyncProstReader<R, T, D>
where
    R: Default,
{
    fn default() -> Self {
        Self::from(R::default())
    }
}

impl<R, T, D> From<R> for AsyncProstReader<R, T, D> {
    fn from(reader: R) -> Self {
        Self::new(reader)
    }
}

impl<R, T> Stream for AsyncProstReader<R, T, AsyncDestination>
where
    T: Message + Default,
    R: AsyncRead + Unpin,
{
    type Item = Result<T, io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // we need the 4-byte length prefix plus at least the first byte of the message
        // before we can make progress
        if let FillResult::Eof = ready!(self.as_mut().fill(cx, LEN_SIZE + 1))? {
            return Poll::Ready(None);
        }

        let message_size = NetworkEndian::read_u32(&self.buffer[..LEN_SIZE]) as usize;

        // the buffer is non-empty here, so fill() reports an unexpected EOF as a
        // BrokenPipe error rather than a clean end of stream
        ready!(self.as_mut().fill(cx, message_size + LEN_SIZE))?;

        self.buffer.advance(LEN_SIZE);
        let message = Message::decode(&self.buffer[..message_size])?;
        self.buffer.advance(message_size);
        Poll::Ready(Some(Ok(message)))
    }
}

impl<R, T> Stream for AsyncProstReader<R, T, AsyncFrameDestination>
where
    R: AsyncRead + Unpin,
    T: Framed + Default,
{
    type Item = Result<T, io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // we need the 4-byte length prefix plus at least the first byte of the frame
        // before we can make progress
        if let FillResult::Eof = ready!(self.as_mut().fill(cx, LEN_SIZE + 1))? {
            return Poll::Ready(None);
        }

        // the length word packs the header size into its high byte and the body size
        // into its low 24 bits; e.g. 0x0200_0010 describes a 2-byte header followed by
        // a 16-byte body
        let size = NetworkEndian::read_u32(&self.buffer[..LEN_SIZE]) as usize;
        let header_size = size >> 24;
        let body_size = 0x00ff_ffff & size;
        let message_size = header_size + body_size;

        // the buffer is non-empty here, so fill() reports an unexpected EOF as a
        // BrokenPipe error rather than a clean end of stream
        ready!(self.as_mut().fill(cx, message_size + LEN_SIZE))?;

        self.buffer.advance(LEN_SIZE);
        let message = T::decode(&self.buffer[..message_size], header_size)?;

        self.buffer.advance(message_size);
        Poll::Ready(Some(Ok(message)))
    }
}

impl<R, T, D> AsyncProstReader<R, T, D>
where
    R: AsyncRead + Unpin,
{
    fn fill(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        target_buffer_size: usize,
    ) -> Poll<Result<FillResult, io::Error>> {
        if self.buffer.len() >= target_buffer_size {
            // we already have the bytes we need!
            return Poll::Ready(Ok(FillResult::Filled));
        }

        // make sure we can fit all the data we're about to read; reserve() guarantees
        // room for `missing` bytes beyond the current length, so compute the shortfall
        // from len() rather than capacity()
        if self.buffer.capacity() < target_buffer_size {
            let missing = target_buffer_size - self.buffer.len();
            self.buffer.reserve(missing);
        }

        let had = self.buffer.len();
        // this is the bit we'll be reading into
        let mut rest = self.buffer.split_off(had);
        // SAFETY: we never extend `rest` beyond the capacity reserved above, and the
        // exposed bytes are only written to by the reader before being adopted back
        // into `self.buffer`; we never read the unwritten portion
        let max = rest.capacity();
        unsafe { rest.set_len(max) };

        while self.buffer.len() < target_buffer_size {
            let mut buf = ReadBuf::new(&mut rest[..]);
            ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
            let n = buf.filled().len();
            if n == 0 {
                // a zero-byte read means EOF: clean if we weren't expecting more data,
                // otherwise the peer hung up partway through a message
                if self.buffer.is_empty() {
                    return Poll::Ready(Ok(FillResult::Eof));
                } else {
                    return Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe)));
                }
            }

            // adopt the new bytes
            let read = rest.split_to(n);
            self.buffer.unsplit(read);
        }

        Poll::Ready(Ok(FillResult::Filled))
    }
}
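
// A self-contained sketch of exercising the reader against an in-memory byte slice
// instead of a real socket. It assumes prost's derive macro, tokio's `macros`/`rt`
// features, and a `futures` dev-dependency for `StreamExt`; none of these are
// confirmed by this file alone.
#[cfg(test)]
mod tests {
    use super::AsyncProstReader;
    use crate::AsyncDestination;
    use futures::StreamExt;
    use prost::Message;

    // hypothetical test message, not part of the crate's API
    #[derive(Clone, PartialEq, ::prost::Message)]
    struct Echo {
        #[prost(string, tag = "1")]
        payload: String,
    }

    #[tokio::test]
    async fn decodes_length_prefixed_messages() {
        let msg = Echo {
            payload: "hello".to_string(),
        };

        // frame the message by hand: a 4-byte network-endian length prefix followed
        // by the prost-encoded body, which is what `poll_next` expects
        let mut body = Vec::new();
        msg.encode(&mut body).unwrap();
        let mut wire = (body.len() as u32).to_be_bytes().to_vec();
        wire.extend_from_slice(&body);

        // `&[u8]` implements `AsyncRead`, so it can stand in for a socket here
        let mut reader = AsyncProstReader::<_, Echo, AsyncDestination>::from(&wire[..]);
        assert_eq!(reader.next().await.unwrap().unwrap(), msg);

        // once the slice is exhausted the stream ends cleanly
        assert!(reader.next().await.is_none());
    }
}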