use std::{pin::Pin, task::Poll};
use crate::logger::tracing::error;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_util::{future, ready, Stream};
use http_body::Body;
use super::compression::{decompress, CompressionEncoding};
use crate::{
invocation::Metadata,
triple::codec::{DecodeBuf, Decoder},
};
type BoxBody = http_body::combinators::UnsyncBoxBody<Bytes, crate::status::Status>;
pub struct Decoding<T> {
state: State,
body: BoxBody,
decoder: Box<dyn Decoder<Item = T, Error = crate::status::Status> + Send + 'static>,
buf: BytesMut,
trailers: Option<Metadata>,
compress: Option<CompressionEncoding>,
decompress_buf: BytesMut,
decode_as_grpc: bool,
}
#[derive(PartialEq)]
enum State {
ReadHeader,
ReadHttpBody,
ReadBody { len: usize, is_compressed: bool },
Error,
}
impl<T> Decoding<T> {
pub fn new<B>(
body: B,
decoder: Box<dyn Decoder<Item = T, Error = crate::status::Status> + Send + 'static>,
compress: Option<CompressionEncoding>,
decode_as_grpc: bool,
) -> Self
where
B: Body + Send + 'static,
B::Error: Into<crate::Error>,
{
Self {
state: State::ReadHeader,
body: body
.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
.map_err(|_err| {
crate::status::Status::new(
crate::status::Code::Internal,
"internal decode err".to_string(),
)
})
.boxed_unsync(),
decoder,
buf: BytesMut::with_capacity(super::consts::BUFFER_SIZE),
trailers: None,
compress,
decompress_buf: BytesMut::new(),
decode_as_grpc,
}
}
pub async fn message(&mut self) -> Result<Option<T>, crate::status::Status> {
match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
Some(Ok(res)) => Ok(Some(res)),
Some(Err(err)) => Err(err),
None => Ok(None),
}
}
pub async fn trailer(&mut self) -> Result<Option<Metadata>, crate::status::Status> {
if let Some(t) = self.trailers.take() {
return Ok(Some(t));
}
let trailer = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx)).await;
trailer.map(|data| data.map(Metadata::from_headers))
}
pub fn decode_http(&mut self) -> Result<Option<T>, crate::status::Status> {
if self.state == State::ReadHeader {
self.state = State::ReadHttpBody;
return Ok(None);
}
if let State::ReadHttpBody = self.state {
if self.buf.is_empty() {
return Ok(None);
}
match self.compress {
None => self.decompress_buf = self.buf.clone(),
Some(compress) => {
let len = self.buf.len();
if let Err(err) =
decompress(compress, &mut self.buf, &mut self.decompress_buf, len)
{
return Err(crate::status::Status::new(
crate::status::Code::Internal,
err.to_string(),
));
}
}
}
let len = self.decompress_buf.len();
let decoding_result = self
.decoder
.decode(&mut DecodeBuf::new(&mut self.decompress_buf, len));
return match decoding_result {
Ok(Some(r)) => {
self.state = State::ReadHeader;
Ok(Some(r))
}
Ok(None) => Ok(None),
Err(err) => Err(err),
};
}
Ok(None)
}
pub fn decode_grpc(&mut self) -> Result<Option<T>, crate::status::Status> {
if self.state == State::ReadHeader {
if self.buf.remaining() < super::consts::HEADER_SIZE {
return Ok(None);
}
let is_compressed = match self.buf.get_u8() {
0 => false,
1 => {
if self.compress.is_some() {
true
} else {
return Err(crate::status::Status::new(
crate::status::Code::Internal,
"set compression flag, but no grpc-encoding specified".to_string(),
));
}
}
v => {
return Err(crate::status::Status::new(
crate::status::Code::Internal,
format!(
"receive message with compression flag{}, flag should be 0 or 1",
v
),
))
}
};
let len = self.buf.get_u32() as usize;
self.buf.reserve(len as usize);
self.state = State::ReadBody { len, is_compressed }
}
if let State::ReadBody { len, is_compressed } = self.state {
if self.buf.remaining() < len || self.buf.len() < len {
return Ok(None);
}
let decoding_result = if is_compressed {
self.decompress_buf.clear();
if let Err(err) = decompress(
self.compress.unwrap(),
&mut self.buf,
&mut self.decompress_buf,
len,
) {
return Err(crate::status::Status::new(
crate::status::Code::Internal,
err.to_string(),
));
}
let decompress_len = self.decompress_buf.len();
self.decoder.decode(&mut DecodeBuf::new(
&mut self.decompress_buf,
decompress_len,
))
} else {
self.decoder.decode(&mut DecodeBuf::new(&mut self.buf, len))
};
return match decoding_result {
Ok(Some(r)) => {
self.state = State::ReadHeader;
Ok(Some(r))
}
Ok(None) => Ok(None),
Err(err) => Err(err),
};
}
Ok(None)
}
pub fn decode_chunk(&mut self) -> Result<Option<T>, crate::status::Status> {
if self.decode_as_grpc {
self.decode_grpc()
} else {
self.decode_http()
}
}
}
impl<T> Stream for Decoding<T> {
type Item = Result<T, crate::status::Status>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
loop {
if self.state == State::Error {
return Poll::Ready(None);
}
if let Some(item) = self.decode_chunk()? {
return Poll::Ready(Some(Ok(item)));
}
let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
Some(Ok(d)) => Some(d),
Some(Err(e)) => {
let _ = std::mem::replace(&mut self.state, State::Error);
let err: crate::Error = e.into();
return Poll::Ready(Some(Err(crate::status::Status::new(
crate::status::Code::Internal,
err.to_string(),
))));
}
None => None,
};
if let Some(data) = chunk {
self.buf.put(data)
} else {
break;
}
}
match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
Ok(trailer) => {
self.trailers = trailer.map(Metadata::from_headers);
}
Err(err) => {
error!("poll_trailers, err: {}", err);
}
}
Poll::Ready(None)
}
fn size_hint(&self) -> (usize, Option<usize>) {
(0, None)
}
}