1use std::{pin::Pin, task::Poll};
19
20use crate::logger::tracing::error;
21use bytes::{Buf, BufMut, Bytes, BytesMut};
22use futures_util::{future, ready, Stream};
23use http_body::Body;
24
25use super::compression::{decompress, CompressionEncoding};
26use crate::{
27 invocation::Metadata,
28 triple::codec::{DecodeBuf, Decoder},
29};
30
/// Boxed, type-erased body with `Bytes` data frames and `Status` errors; `Decoding::new`
/// builds one from the caller's body via `boxed_unsync`.
type BoxBody = http_body::combinators::UnsyncBoxBody<Bytes, crate::status::Status>;
32
/// Decodes `T` messages from an HTTP body, either as gRPC length-prefixed frames or as a
/// single unframed body, optionally decompressing payloads first.
pub struct Decoding<T> {
    /// Current position in the message-parsing state machine.
    state: State,
    /// The wrapped body; its data and errors were normalised to `Bytes` / `Status` in `new`.
    body: BoxBody,
    /// Turns the bytes of one message into a `T`.
    decoder: Box<dyn Decoder<Item = T, Error = crate::status::Status> + Send + 'static>,
    /// Raw body bytes accumulated by `poll_next` and read by the decode methods.
    buf: BytesMut,
    /// Trailers captured by `poll_next` once the body ends; handed out by `trailer()`.
    trailers: Option<Metadata>,
    /// Encoding used to decompress payloads. `None` means no decompression is performed
    /// (and a gRPC frame flagged as compressed is rejected).
    compress: Option<CompressionEncoding>,
    /// Scratch buffer holding the (possibly decompressed) bytes given to the decoder.
    decompress_buf: BytesMut,
    /// `true` to parse gRPC frames, `false` to treat the body as one plain HTTP message.
    decode_as_grpc: bool,
}
43
/// Position of [`Decoding`] within the incoming byte stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Waiting for the start of a message: the gRPC frame header, or (in plain HTTP
    /// mode) the beginning of a new body.
    ReadHeader,
    /// Plain HTTP mode: collecting body bytes that together form one unframed message.
    ReadHttpBody,
    /// gRPC mode: the header was parsed and `len` payload bytes are awaited; the
    /// payload must be decompressed first when `is_compressed` is set.
    ReadBody { len: usize, is_compressed: bool },
    /// A previous failure ended the stream; `poll_next` yields no further items.
    Error,
}
51
52impl<T> Decoding<T> {
53 pub fn new<B>(
54 body: B,
55 decoder: Box<dyn Decoder<Item = T, Error = crate::status::Status> + Send + 'static>,
56 compress: Option<CompressionEncoding>,
57 decode_as_grpc: bool,
58 ) -> Self
59 where
60 B: Body + Send + 'static,
61 B::Error: Into<crate::Error>,
62 {
63 Self {
65 state: State::ReadHeader,
66 body: body
67 .map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
68 .map_err(|_err| {
69 crate::status::Status::new(
70 crate::status::Code::Internal,
71 "internal decode err".to_string(),
72 )
73 })
74 .boxed_unsync(),
75 decoder,
76 buf: BytesMut::with_capacity(super::consts::BUFFER_SIZE),
77 trailers: None,
78 compress,
79 decompress_buf: BytesMut::new(),
80 decode_as_grpc,
81 }
82 }
83
84 pub async fn message(&mut self) -> Result<Option<T>, crate::status::Status> {
85 match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
86 Some(Ok(res)) => Ok(Some(res)),
87 Some(Err(err)) => Err(err),
88 None => Ok(None),
89 }
90 }
91
92 pub async fn trailer(&mut self) -> Result<Option<Metadata>, crate::status::Status> {
93 if let Some(t) = self.trailers.take() {
94 return Ok(Some(t));
95 }
96 let trailer = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx)).await;
99 trailer.map(|data| data.map(Metadata::from_headers))
100 }
101
102 pub fn decode_http(&mut self) -> Result<Option<T>, crate::status::Status> {
103 if self.state == State::ReadHeader {
104 self.state = State::ReadHttpBody;
105 return Ok(None);
106 }
107 if let State::ReadHttpBody = self.state {
108 if self.buf.is_empty() {
109 return Ok(None);
110 }
111 match self.compress {
112 None => self.decompress_buf = self.buf.clone(),
113 Some(compress) => {
114 let len = self.buf.len();
115 if let Err(err) =
116 decompress(compress, &mut self.buf, &mut self.decompress_buf, len)
117 {
118 return Err(crate::status::Status::new(
119 crate::status::Code::Internal,
120 err.to_string(),
121 ));
122 }
123 }
124 }
125 let len = self.decompress_buf.len();
126 let decoding_result = self
127 .decoder
128 .decode(&mut DecodeBuf::new(&mut self.decompress_buf, len));
129
130 return match decoding_result {
131 Ok(Some(r)) => {
132 self.state = State::ReadHeader;
133 Ok(Some(r))
134 }
135 Ok(None) => Ok(None),
136 Err(err) => Err(err),
137 };
138 }
139 Ok(None)
140 }
141
142 pub fn decode_grpc(&mut self) -> Result<Option<T>, crate::status::Status> {
143 if self.state == State::ReadHeader {
144 if self.buf.remaining() < super::consts::HEADER_SIZE {
146 return Ok(None);
147 }
148
149 let is_compressed = match self.buf.get_u8() {
150 0 => false,
151 1 => {
152 if self.compress.is_some() {
153 true
154 } else {
155 return Err(crate::status::Status::new(
156 crate::status::Code::Internal,
157 "set compression flag, but no grpc-encoding specified".to_string(),
158 ));
159 }
160 }
161 v => {
162 return Err(crate::status::Status::new(
163 crate::status::Code::Internal,
164 format!(
165 "receive message with compression flag{}, flag should be 0 or 1",
166 v
167 ),
168 ))
169 }
170 };
171 let len = self.buf.get_u32() as usize;
172 self.buf.reserve(len as usize);
173
174 self.state = State::ReadBody { len, is_compressed }
175 }
176
177 if let State::ReadBody { len, is_compressed } = self.state {
178 if self.buf.remaining() < len || self.buf.len() < len {
179 return Ok(None);
180 }
181
182 let decoding_result = if is_compressed {
183 self.decompress_buf.clear();
184 if let Err(err) = decompress(
185 self.compress.unwrap(),
186 &mut self.buf,
187 &mut self.decompress_buf,
188 len,
189 ) {
190 return Err(crate::status::Status::new(
191 crate::status::Code::Internal,
192 err.to_string(),
193 ));
194 }
195
196 let decompress_len = self.decompress_buf.len();
197 self.decoder.decode(&mut DecodeBuf::new(
198 &mut self.decompress_buf,
199 decompress_len,
200 ))
201 } else {
202 self.decoder.decode(&mut DecodeBuf::new(&mut self.buf, len))
203 };
204
205 return match decoding_result {
206 Ok(Some(r)) => {
207 self.state = State::ReadHeader;
208 Ok(Some(r))
209 }
210 Ok(None) => Ok(None),
211 Err(err) => Err(err),
212 };
213 }
214
215 Ok(None)
216 }
217
218 pub fn decode_chunk(&mut self) -> Result<Option<T>, crate::status::Status> {
219 if self.decode_as_grpc {
220 self.decode_grpc()
221 } else {
222 self.decode_http()
223 }
224 }
225}
226
227impl<T> Stream for Decoding<T> {
228 type Item = Result<T, crate::status::Status>;
229
230 fn poll_next(
231 mut self: std::pin::Pin<&mut Self>,
232 cx: &mut std::task::Context<'_>,
233 ) -> std::task::Poll<Option<Self::Item>> {
234 loop {
235 if self.state == State::Error {
236 return Poll::Ready(None);
237 }
238
239 if let Some(item) = self.decode_chunk()? {
240 return Poll::Ready(Some(Ok(item)));
241 }
242
243 let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
244 Some(Ok(d)) => Some(d),
245 Some(Err(e)) => {
246 let _ = std::mem::replace(&mut self.state, State::Error);
247 let err: crate::Error = e.into();
248 return Poll::Ready(Some(Err(crate::status::Status::new(
249 crate::status::Code::Internal,
250 err.to_string(),
251 ))));
252 }
253 None => None,
254 };
255
256 if let Some(data) = chunk {
257 self.buf.put(data)
258 } else {
259 break;
260 }
261 }
262
263 match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
264 Ok(trailer) => {
265 self.trailers = trailer.map(Metadata::from_headers);
266 }
267 Err(err) => {
268 error!("poll_trailers, err: {}", err);
269 }
270 }
271
272 Poll::Ready(None)
273 }
274
275 fn size_hint(&self) -> (usize, Option<usize>) {
276 (0, None)
277 }
278}