volo_grpc/codec/
decode.rs

1use std::{
2    fmt,
3    marker::PhantomData,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::{Buf, BufMut, BytesMut};
9use futures::{Stream, future};
10use futures_util::ready;
11use http::StatusCode;
12use http_body::Body;
13use pilota::pb::Message;
14use tracing::{debug, trace};
15
16use super::{BUFFER_SIZE, DefaultDecoder, PREFIX_LEN};
17use crate::{
18    Status,
19    body::BoxBody,
20    codec::{
21        Decoder,
22        compression::{CompressionEncoding, decompress},
23    },
24    metadata::MetadataMap,
25    status::Code,
26};
27
28/// Streaming Received Request and Received Response.
29///
30/// Provides an interface for receiving messages and trailers.
31pub struct RecvStream<T> {
32    body: BoxBody,
33    decoder: DefaultDecoder<T>,
34    trailers: Option<MetadataMap>,
35    buf: BytesMut,
36    state: State,
37    kind: Kind,
38    compression_encoding: Option<CompressionEncoding>,
39    decompress_buf: BytesMut,
40}
41
42impl<T> Unpin for RecvStream<T> {}
43
44#[derive(Debug, Clone)]
45enum State {
46    Header,
47    Body(Option<CompressionEncoding>, usize),
48    Error,
49}
50
51#[derive(Debug, PartialEq, Eq)]
52pub enum Kind {
53    Request,
54    Response(StatusCode),
55}
56
57impl<T> RecvStream<T> {
58    pub fn new(
59        body: BoxBody,
60        kind: Kind,
61        compression_encoding: Option<CompressionEncoding>,
62    ) -> Self {
63        RecvStream {
64            body,
65            decoder: DefaultDecoder(PhantomData),
66            trailers: None,
67            buf: BytesMut::with_capacity(BUFFER_SIZE),
68            state: State::Header,
69            kind,
70            compression_encoding,
71            decompress_buf: BytesMut::new(),
72        }
73    }
74}
75
76impl<T: Message + Default> RecvStream<T> {
77    /// Get the next message from the stream.
78    async fn message(&mut self) -> Result<Option<T>, Status> {
79        match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
80            Some(Ok(m)) => Ok(Some(m)),
81            Some(Err(e)) => Err(e),
82            None => Ok(None),
83        }
84    }
85
86    /// Get the trailers from the stream.
87    pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
88        if let Some(trailers) = self.trailers.take() {
89            return Ok(Some(trailers));
90        }
91
92        // Ensure read body to the end in case of memory leak.
93        // Related issue: https://github.com/hyperium/h2/issues/631.
94        while self.message().await?.is_some() {}
95
96        if let Some(trailers) = self.trailers.take() {
97            return Ok(Some(trailers));
98        }
99
100        let maybe_trailer = future::poll_fn(|cx| Pin::new(&mut self.body).poll_frame(cx)).await;
101
102        match maybe_trailer {
103            Some(Ok(frame)) => match frame.into_trailers() {
104                Ok(headers) => Ok(Some(MetadataMap::from_headers(headers))),
105                Err(_frame) => {
106                    // **unreachable** because the `frame` cannot be `Frame::Data` here
107                    debug!("[VOLO] unexpected data from stream");
108                    Err(Status::new(
109                        Code::Internal,
110                        "Unexpected data from stream.".to_string(),
111                    ))
112                }
113            },
114            Some(Err(err)) => Err(Status::from_error(Box::new(err))),
115            None => Ok(None),
116        }
117    }
118
119    #[allow(clippy::result_large_err)]
120    fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
121        if let State::Header = self.state {
122            // data is not enough to decode header, return and keep reading
123            if self.buf.remaining() < PREFIX_LEN {
124                return Ok(None);
125            }
126            trace!("[VOLO-GRPC] streaming received buf: {:?}", self.buf);
127
128            let compression_encoding = match self.buf.get_u8() {
129                0 => None,
130                1 => {
131                    if self.compression_encoding.is_some() {
132                        self.compression_encoding
133                    } else {
134                        return Err(Status::new(
135                            Code::Internal,
136                            "protocol error: received message with compressed-flag but no \
137                             grpc-encoding was specified"
138                                .to_string(),
139                        ));
140                    }
141                }
142                flag => {
143                    let message = format!(
144                        "protocol error: received message with invalid compression flag: {flag} \
145                         (valid flags are 0 and 1), while sending request"
146                    );
147                    // https://grpc.github.io/grpc/core/md_doc_compression.html
148                    return Err(Status::new(Code::Internal, message));
149                }
150            };
151            let len = self.buf.get_u32() as usize;
152            self.buf.reserve(len);
153
154            self.state = State::Body(compression_encoding, len);
155        }
156
157        if let State::Body(compression_encoding, len) = &self.state {
158            // data is not enough to decode body, return and keep reading
159            if self.buf.remaining() < *len || self.buf.len() < *len {
160                return Ok(None);
161            }
162            trace!("[VOLO-GRPC] streaming reading body: {:?}", self.buf);
163            let mut buf = self.buf.split_to(*len);
164            let decode_result = if let Some(encoding) = compression_encoding {
165                self.decompress_buf.clear();
166                if let Err(err) = decompress(*encoding, &mut buf, &mut self.decompress_buf) {
167                    let message = if let Kind::Response(status) = self.kind {
168                        format!(
169                            "Error decompressing: {err}, while receiving response with status: \
170                             {status}"
171                        )
172                    } else {
173                        format!("Error decompressing: {err}, while sending request")
174                    };
175                    return Err(Status::new(Code::Internal, message));
176                }
177                DefaultDecoder::<T>::decode(&mut self.decoder, self.decompress_buf.split().freeze())
178            } else {
179                DefaultDecoder::<T>::decode(&mut self.decoder, buf.freeze())
180            };
181
182            return match decode_result {
183                Ok(Some(msg)) => {
184                    self.state = State::Header;
185                    Ok(Some(msg))
186                }
187                Ok(None) => Ok(None),
188                Err(e) => Err(e),
189            };
190        }
191
192        Ok(None)
193    }
194}
195
196impl<T: Message + Default> Stream for RecvStream<T> {
197    type Item = Result<T, Status>;
198
199    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
200        let trailer_frame = loop {
201            if let State::Error = &self.state {
202                return Poll::Ready(None);
203            }
204            if let Some(item) = self.decode_chunk()? {
205                return Poll::Ready(Some(Ok(item)));
206            }
207
208            match ready!(Pin::new(&mut self.body).poll_frame(cx)) {
209                Some(Ok(frame)) => match frame.into_data() {
210                    Ok(data) => self.buf.put(data),
211                    Err(trailer) => {
212                        break Some(trailer);
213                    }
214                },
215                Some(Err(e)) => {
216                    let err: crate::BoxError = e.into();
217                    let status = Status::from_error(err);
218                    if self.kind == Kind::Request && status.code() == Code::Cancelled {
219                        return Poll::Ready(None);
220                    }
221                    debug!("[VOLO] decoder inner stream error: {:?}", status);
222                    let _ = std::mem::replace(&mut self.state, State::Error);
223                    return Poll::Ready(Some(Err(status)));
224                }
225                None => {
226                    if self.buf.has_remaining() {
227                        debug!("[VOLO] unexpected EOF decoding stream");
228                        return Poll::Ready(Some(Err(Status::new(
229                            Code::Internal,
230                            "Unexpected EOF decoding stream.".to_string(),
231                        ))));
232                    } else {
233                        break None;
234                    }
235                }
236            }
237        };
238
239        if let Kind::Response(status) = self.kind {
240            let trailer = match trailer_frame.map(|frame| frame.into_trailers()) {
241                Some(Ok(trailer)) => Some(trailer),
242                Some(Err(_frame)) => {
243                    // **unreachable** because the `frame` cannot be `Frame::Data` here
244                    debug!("[VOLO] unexpected data from stream");
245                    return Poll::Ready(Some(Err(Status::new(
246                        Code::Internal,
247                        "Unexpected data from stream.".to_string(),
248                    ))));
249                }
250                None => None,
251            };
252
253            if let Err(e) = Status::infer_grpc_status(trailer.as_ref(), status) {
254                return if let Some(e) = e {
255                    Some(Err(e)).into()
256                } else {
257                    Poll::Ready(None)
258                };
259            } else {
260                self.trailers = trailer.map(MetadataMap::from_headers);
261            }
262        }
263
264        Poll::Ready(None)
265    }
266}
267
268impl<T> fmt::Debug for RecvStream<T> {
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        f.debug_struct("RecvStream").finish()
271    }
272}