use std::{
    io,
    marker::PhantomData,
    mem::MaybeUninit,
    pin::Pin,
    task::{Context, Poll},
};

use byteorder::{ByteOrder, NetworkEndian};
use bytes::{Buf, BufMut, BytesMut};
use futures_core::{ready, Stream};
use prost::Message;
use tokio::io::{AsyncRead, ReadBuf};

use crate::{AsyncDestination, AsyncFrameDestination, Framed};
15
/// Initial capacity, in bytes, of a reader's internal buffer.
const BUFFER_SIZE: usize = 8 * 1024;
/// Width, in bytes, of the big-endian `u32` length prefix in front of every frame.
const LEN_SIZE: usize = std::mem::size_of::<u32>();
18
/// Outcome of a successful `AsyncProstReader::fill` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FillResult {
    /// The buffer now holds at least the requested number of bytes.
    Filled,
    /// The reader reached end-of-stream while the buffer was empty,
    /// i.e. cleanly at a frame boundary.
    Eof,
}
23
24#[derive(Debug)]
26pub struct AsyncProstReader<R, T, D> {
27 reader: R,
28 pub(crate) buffer: BytesMut,
29 into: PhantomData<T>,
30 dest: PhantomData<D>,
31}
32impl<R, T, D> Unpin for AsyncProstReader<R, T, D> where R: Unpin {}
33
34impl<R, T, D> AsyncProstReader<R, T, D> {
35 pub fn new(reader: R) -> Self {
37 Self {
38 reader,
39 buffer: BytesMut::with_capacity(BUFFER_SIZE),
40 into: PhantomData,
41 dest: PhantomData,
42 }
43 }
44
45 pub fn get_ref(&self) -> &R {
47 &self.reader
48 }
49
50 pub fn get_mut(&mut self) -> &mut R {
52 &mut self.reader
53 }
54
55 pub fn buffer(&self) -> &[u8] {
57 &self.buffer[..]
58 }
59
60 pub fn into_inner(self) -> R {
62 self.reader
63 }
64}
65
66impl<R, T, D> Default for AsyncProstReader<R, T, D>
67where
68 R: Default,
69{
70 fn default() -> Self {
71 Self::from(R::default())
72 }
73}
74
75impl<R, T, D> From<R> for AsyncProstReader<R, T, D> {
76 fn from(reader: R) -> Self {
77 Self::new(reader)
78 }
79}
80
81impl<R, T> Stream for AsyncProstReader<R, T, AsyncDestination>
82where
83 T: Message + Default,
84 R: AsyncRead + Unpin,
85{
86 type Item = Result<T, io::Error>;
87
88 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
89 if let FillResult::Eof = ready!(self.as_mut().fill(cx, 5))? {
91 return Poll::Ready(None);
92 }
93
94 let message_size = NetworkEndian::read_u32(&self.buffer[..LEN_SIZE]) as usize;
95
96 ready!(self.as_mut().fill(cx, message_size + LEN_SIZE))?;
98
99 self.buffer.advance(LEN_SIZE);
100 let message =
101 Message::decode(&self.buffer[..message_size]).map_err(prost::DecodeError::from)?;
102 self.buffer.advance(message_size);
103 Poll::Ready(Some(Ok(message)))
104 }
105}
106
107impl<R, T> Stream for AsyncProstReader<R, T, AsyncFrameDestination>
108where
109 R: AsyncRead + Unpin,
110 T: Framed + Default,
111{
112 type Item = Result<T, io::Error>;
113
114 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
115 if let FillResult::Eof = ready!(self.as_mut().fill(cx, LEN_SIZE + 1))? {
117 return Poll::Ready(None);
118 }
119
120 let size = NetworkEndian::read_u32(&self.buffer[..LEN_SIZE]) as usize;
121 let header_size = size >> 24;
122 let body_size = 0x00ffffff & size;
123 let message_size = header_size + body_size;
124
125 ready!(self.as_mut().fill(cx, message_size + LEN_SIZE))?;
127
128 self.buffer.advance(LEN_SIZE);
129 let message = T::decode(&self.buffer[..message_size], header_size)?;
130
131 self.buffer.advance(message_size);
132 Poll::Ready(Some(Ok(message)))
133 }
134}
135
136impl<R, T, D> AsyncProstReader<R, T, D>
137where
138 R: AsyncRead + Unpin,
139{
140 fn fill(
141 mut self: Pin<&mut Self>,
142 cx: &mut Context,
143 target_buffer_size: usize,
144 ) -> Poll<Result<FillResult, io::Error>> {
145 if self.buffer.len() >= target_buffer_size {
146 return Poll::Ready(Ok(FillResult::Filled));
148 }
149
150 if self.buffer.capacity() < target_buffer_size {
152 let missing = target_buffer_size - self.buffer.capacity();
153 self.buffer.reserve(missing);
154 }
155
156 let had = self.buffer.len();
157 let mut rest = self.buffer.split_off(had);
159 let max = rest.capacity();
162 unsafe { rest.set_len(max) };
163
164 while self.buffer.len() < target_buffer_size {
165 let mut buf = ReadBuf::new(&mut rest[..]);
166 ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
167 let n = buf.filled().len();
168 if n == 0 {
169 if self.buffer.is_empty() {
170 return Poll::Ready(Ok(FillResult::Eof));
171 } else {
172 return Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe)));
173 }
174 }
175
176 let read = rest.split_to(n);
178 self.buffer.unsplit(read);
179 }
180
181 Poll::Ready(Ok(FillResult::Filled))
182 }
183}