apache_dubbo/triple/
decode.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use std::{pin::Pin, task::Poll};
19
20use bytes::{Buf, BufMut, Bytes, BytesMut};
21use futures_util::Stream;
22use futures_util::{future, ready};
23use http_body::Body;
24
25use super::compression::{decompress, CompressionEncoding};
26use crate::invocation::Metadata;
27use crate::triple::codec::{DecodeBuf, Decoder};
28
29type BoxBody = http_body::combinators::UnsyncBoxBody<Bytes, crate::status::Status>;
30
31pub struct Decoding<T> {
32    state: State,
33    body: BoxBody,
34    decoder: Box<dyn Decoder<Item = T, Error = crate::status::Status> + Send + 'static>,
35    buf: BytesMut,
36    trailers: Option<Metadata>,
37    compress: Option<CompressionEncoding>,
38    decompress_buf: BytesMut,
39}
40
/// Progress of the decoder through the current length-prefixed frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Waiting until enough bytes are buffered to read the frame header
    /// (compression flag followed by the payload length).
    ReadHeader,
    /// Header consumed; waiting for `len` payload bytes, which must be
    /// decompressed first when `is_compressed` is set.
    ReadBody { len: usize, is_compressed: bool },
    /// A fatal error has been reported; the stream yields nothing further.
    Error,
}
47
48impl<T> Decoding<T> {
49    pub fn new<B, D>(body: B, decoder: D, compress: Option<CompressionEncoding>) -> Self
50    where
51        B: Body + Send + 'static,
52        B::Error: Into<crate::Error>,
53        D: Decoder<Item = T, Error = crate::status::Status> + Send + 'static,
54    {
55        Self {
56            state: State::ReadHeader,
57            body: body
58                .map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
59                .map_err(|_err| {
60                    crate::status::Status::new(
61                        crate::status::Code::Internal,
62                        "internal decode err".to_string(),
63                    )
64                })
65                .boxed_unsync(),
66            decoder: Box::new(decoder),
67            buf: BytesMut::with_capacity(super::consts::BUFFER_SIZE),
68            trailers: None,
69            compress,
70            decompress_buf: BytesMut::new(),
71        }
72    }
73
74    pub async fn message(&mut self) -> Result<Option<T>, crate::status::Status> {
75        match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
76            Some(Ok(res)) => Ok(Some(res)),
77            Some(Err(err)) => Err(err),
78            None => Ok(None),
79        }
80    }
81
82    pub async fn trailer(&mut self) -> Result<Option<Metadata>, crate::status::Status> {
83        if let Some(t) = self.trailers.take() {
84            return Ok(Some(t));
85        }
86        // while self.message().await?.is_some() {}
87
88        let trailer = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx)).await;
89        trailer.map(|data| data.map(Metadata::from_headers))
90    }
91
92    pub fn decode_chunk(&mut self) -> Result<Option<T>, crate::status::Status> {
93        if self.state == State::ReadHeader {
94            // buffer is full
95            if self.buf.remaining() < super::consts::HEADER_SIZE {
96                return Ok(None);
97            }
98
99            let is_compressed = match self.buf.get_u8() {
100                0 => false,
101                1 => {
102                    if self.compress.is_some() {
103                        true
104                    } else {
105                        return Err(crate::status::Status::new(
106                            crate::status::Code::Internal,
107                            "set compression flag, but no grpc-encoding specified".to_string(),
108                        ));
109                    }
110                }
111                v => {
112                    return Err(crate::status::Status::new(
113                        crate::status::Code::Internal,
114                        format!(
115                            "receive message with compression flag{}, flag should be 0 or 1",
116                            v
117                        ),
118                    ))
119                }
120            };
121            let len = self.buf.get_u32() as usize;
122            self.buf.reserve(len as usize);
123
124            self.state = State::ReadBody { len, is_compressed }
125        }
126
127        if let State::ReadBody { len, is_compressed } = self.state {
128            if self.buf.remaining() < len || self.buf.len() < len {
129                return Ok(None);
130            }
131
132            let decoding_result = if is_compressed {
133                self.decompress_buf.clear();
134                if let Err(err) = decompress(
135                    self.compress.unwrap(),
136                    &mut self.buf,
137                    &mut self.decompress_buf,
138                    len,
139                ) {
140                    return Err(crate::status::Status::new(
141                        crate::status::Code::Internal,
142                        err.to_string(),
143                    ));
144                }
145
146                let decompress_len = self.decompress_buf.len();
147                self.decoder.decode(&mut DecodeBuf::new(
148                    &mut self.decompress_buf,
149                    decompress_len,
150                ))
151            } else {
152                self.decoder.decode(&mut DecodeBuf::new(&mut self.buf, len))
153            };
154
155            return match decoding_result {
156                Ok(Some(r)) => {
157                    self.state = State::ReadHeader;
158                    Ok(Some(r))
159                }
160                Ok(None) => Ok(None),
161                Err(err) => Err(err),
162            };
163        }
164
165        Ok(None)
166    }
167}
168
169impl<T> Stream for Decoding<T> {
170    type Item = Result<T, crate::status::Status>;
171
172    fn poll_next(
173        mut self: std::pin::Pin<&mut Self>,
174        cx: &mut std::task::Context<'_>,
175    ) -> std::task::Poll<Option<Self::Item>> {
176        loop {
177            if self.state == State::Error {
178                return Poll::Ready(None);
179            }
180
181            if let Some(item) = self.decode_chunk()? {
182                return Poll::Ready(Some(Ok(item)));
183            }
184
185            let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
186                Some(Ok(d)) => Some(d),
187                Some(Err(e)) => {
188                    let _ = std::mem::replace(&mut self.state, State::Error);
189                    let err: crate::Error = e.into();
190                    return Poll::Ready(Some(Err(crate::status::Status::new(
191                        crate::status::Code::Internal,
192                        err.to_string(),
193                    ))));
194                }
195                None => None,
196            };
197
198            if let Some(data) = chunk {
199                self.buf.put(data)
200            } else {
201                break;
202            }
203        }
204
205        match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
206            Ok(trailer) => {
207                self.trailers = trailer.map(Metadata::from_headers);
208            }
209            Err(err) => {
210                tracing::error!("poll_trailers, err: {}", err);
211            }
212        }
213
214        Poll::Ready(None)
215    }
216
217    fn size_hint(&self) -> (usize, Option<usize>) {
218        (0, None)
219    }
220}