dubbo/triple/
decode.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use std::{pin::Pin, task::Poll};
19
20use crate::logger::tracing::error;
21use bytes::{Buf, BufMut, Bytes, BytesMut};
22use futures_util::{future, ready, Stream};
23use http_body::Body;
24
25use super::compression::{decompress, CompressionEncoding};
26use crate::{
27    invocation::Metadata,
28    triple::codec::{DecodeBuf, Decoder},
29};
30
31type BoxBody = http_body::combinators::UnsyncBoxBody<Bytes, crate::status::Status>;
32
33pub struct Decoding<T> {
34    state: State,
35    body: BoxBody,
36    decoder: Box<dyn Decoder<Item = T, Error = crate::status::Status> + Send + 'static>,
37    buf: BytesMut,
38    trailers: Option<Metadata>,
39    compress: Option<CompressionEncoding>,
40    decompress_buf: BytesMut,
41    decode_as_grpc: bool,
42}
43
/// Progress of the decoder through the incoming body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// gRPC mode: waiting for a frame header. HTTP mode: about to start
    /// collecting a new message body.
    ReadHeader,
    /// HTTP mode: accumulating body bytes until the decoder yields a message.
    ReadHttpBody,
    /// gRPC mode: header parsed; waiting for `len` payload bytes.
    /// `is_compressed` mirrors the frame's compression flag.
    ReadBody { len: usize, is_compressed: bool },
    /// Terminal state: once entered, the stream yields no more items.
    Error,
}
51
52impl<T> Decoding<T> {
53    pub fn new<B>(
54        body: B,
55        decoder: Box<dyn Decoder<Item = T, Error = crate::status::Status> + Send + 'static>,
56        compress: Option<CompressionEncoding>,
57        decode_as_grpc: bool,
58    ) -> Self
59    where
60        B: Body + Send + 'static,
61        B::Error: Into<crate::Error>,
62    {
63        //Determine whether to use the gRPC mode to handle request data
64        Self {
65            state: State::ReadHeader,
66            body: body
67                .map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
68                .map_err(|_err| {
69                    crate::status::Status::new(
70                        crate::status::Code::Internal,
71                        "internal decode err".to_string(),
72                    )
73                })
74                .boxed_unsync(),
75            decoder,
76            buf: BytesMut::with_capacity(super::consts::BUFFER_SIZE),
77            trailers: None,
78            compress,
79            decompress_buf: BytesMut::new(),
80            decode_as_grpc,
81        }
82    }
83
84    pub async fn message(&mut self) -> Result<Option<T>, crate::status::Status> {
85        match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
86            Some(Ok(res)) => Ok(Some(res)),
87            Some(Err(err)) => Err(err),
88            None => Ok(None),
89        }
90    }
91
92    pub async fn trailer(&mut self) -> Result<Option<Metadata>, crate::status::Status> {
93        if let Some(t) = self.trailers.take() {
94            return Ok(Some(t));
95        }
96        // while self.message().await?.is_some() {}
97
98        let trailer = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx)).await;
99        trailer.map(|data| data.map(Metadata::from_headers))
100    }
101
102    pub fn decode_http(&mut self) -> Result<Option<T>, crate::status::Status> {
103        if self.state == State::ReadHeader {
104            self.state = State::ReadHttpBody;
105            return Ok(None);
106        }
107        if let State::ReadHttpBody = self.state {
108            if self.buf.is_empty() {
109                return Ok(None);
110            }
111            match self.compress {
112                None => self.decompress_buf = self.buf.clone(),
113                Some(compress) => {
114                    let len = self.buf.len();
115                    if let Err(err) =
116                        decompress(compress, &mut self.buf, &mut self.decompress_buf, len)
117                    {
118                        return Err(crate::status::Status::new(
119                            crate::status::Code::Internal,
120                            err.to_string(),
121                        ));
122                    }
123                }
124            }
125            let len = self.decompress_buf.len();
126            let decoding_result = self
127                .decoder
128                .decode(&mut DecodeBuf::new(&mut self.decompress_buf, len));
129
130            return match decoding_result {
131                Ok(Some(r)) => {
132                    self.state = State::ReadHeader;
133                    Ok(Some(r))
134                }
135                Ok(None) => Ok(None),
136                Err(err) => Err(err),
137            };
138        }
139        Ok(None)
140    }
141
142    pub fn decode_grpc(&mut self) -> Result<Option<T>, crate::status::Status> {
143        if self.state == State::ReadHeader {
144            // buffer is full
145            if self.buf.remaining() < super::consts::HEADER_SIZE {
146                return Ok(None);
147            }
148
149            let is_compressed = match self.buf.get_u8() {
150                0 => false,
151                1 => {
152                    if self.compress.is_some() {
153                        true
154                    } else {
155                        return Err(crate::status::Status::new(
156                            crate::status::Code::Internal,
157                            "set compression flag, but no grpc-encoding specified".to_string(),
158                        ));
159                    }
160                }
161                v => {
162                    return Err(crate::status::Status::new(
163                        crate::status::Code::Internal,
164                        format!(
165                            "receive message with compression flag{}, flag should be 0 or 1",
166                            v
167                        ),
168                    ))
169                }
170            };
171            let len = self.buf.get_u32() as usize;
172            self.buf.reserve(len as usize);
173
174            self.state = State::ReadBody { len, is_compressed }
175        }
176
177        if let State::ReadBody { len, is_compressed } = self.state {
178            if self.buf.remaining() < len || self.buf.len() < len {
179                return Ok(None);
180            }
181
182            let decoding_result = if is_compressed {
183                self.decompress_buf.clear();
184                if let Err(err) = decompress(
185                    self.compress.unwrap(),
186                    &mut self.buf,
187                    &mut self.decompress_buf,
188                    len,
189                ) {
190                    return Err(crate::status::Status::new(
191                        crate::status::Code::Internal,
192                        err.to_string(),
193                    ));
194                }
195
196                let decompress_len = self.decompress_buf.len();
197                self.decoder.decode(&mut DecodeBuf::new(
198                    &mut self.decompress_buf,
199                    decompress_len,
200                ))
201            } else {
202                self.decoder.decode(&mut DecodeBuf::new(&mut self.buf, len))
203            };
204
205            return match decoding_result {
206                Ok(Some(r)) => {
207                    self.state = State::ReadHeader;
208                    Ok(Some(r))
209                }
210                Ok(None) => Ok(None),
211                Err(err) => Err(err),
212            };
213        }
214
215        Ok(None)
216    }
217
218    pub fn decode_chunk(&mut self) -> Result<Option<T>, crate::status::Status> {
219        if self.decode_as_grpc {
220            self.decode_grpc()
221        } else {
222            self.decode_http()
223        }
224    }
225}
226
227impl<T> Stream for Decoding<T> {
228    type Item = Result<T, crate::status::Status>;
229
230    fn poll_next(
231        mut self: std::pin::Pin<&mut Self>,
232        cx: &mut std::task::Context<'_>,
233    ) -> std::task::Poll<Option<Self::Item>> {
234        loop {
235            if self.state == State::Error {
236                return Poll::Ready(None);
237            }
238
239            if let Some(item) = self.decode_chunk()? {
240                return Poll::Ready(Some(Ok(item)));
241            }
242
243            let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
244                Some(Ok(d)) => Some(d),
245                Some(Err(e)) => {
246                    let _ = std::mem::replace(&mut self.state, State::Error);
247                    let err: crate::Error = e.into();
248                    return Poll::Ready(Some(Err(crate::status::Status::new(
249                        crate::status::Code::Internal,
250                        err.to_string(),
251                    ))));
252                }
253                None => None,
254            };
255
256            if let Some(data) = chunk {
257                self.buf.put(data)
258            } else {
259                break;
260            }
261        }
262
263        match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
264            Ok(trailer) => {
265                self.trailers = trailer.map(Metadata::from_headers);
266            }
267            Err(err) => {
268                error!("poll_trailers, err: {}", err);
269            }
270        }
271
272        Poll::Ready(None)
273    }
274
275    fn size_hint(&self) -> (usize, Option<usize>) {
276        (0, None)
277    }
278}