use std::{future::Future, io, io::Write, pin::Pin, task::Context, task::Poll};
use flate2::write::{GzDecoder, ZlibDecoder};
use super::Writer;
use crate::http::error::PayloadError;
use crate::http::header::{CONTENT_ENCODING, ContentEncoding, HeaderMap};
use crate::rt::{BlockingResult, spawn_blocking};
use crate::util::{Bytes, Stream};
/// Size threshold, in bytes, for decoding a payload chunk in place.
///
/// Chunks strictly shorter than this are decoded directly on the polling task;
/// chunks of this size or larger are decoded on a blocking thread via
/// `spawn_blocking`.
const INPLACE: usize = 2_049;
/// Stream adapter that decompresses the chunks of a payload stream.
///
/// Built via `Decoder::new` or `Decoder::from_headers`. Only gzip and deflate
/// install a decompressor; for any other encoding, chunks pass through unchanged.
#[derive(derive_more::Debug)]
pub struct Decoder<S> {
    /// Active decompressor.
    ///
    /// `None` in these cases:
    /// - the encoding needs no decoding;
    /// - a chunk is being decoded on a blocking task (the decoder travels with
    ///   that task);
    /// - the decoder was consumed at end of input or by a decode error.
    #[debug(skip)]
    inner: Option<ContentDecoder>,
    /// Underlying stream of (possibly compressed) payload chunks.
    stream: S,
    /// Set when `stream` reports end of input.
    ///
    /// Once `true`, `poll_next` returns `None` without polling `stream` again.
    eof: bool,
    /// In-flight blocking task that decodes a large chunk.
    ///
    /// It resolves to the decoded output (if any) together with the decoder,
    /// which is then handed back to `inner`.
    #[debug(skip)]
    fut: Option<BlockingResult<Result<(Option<Bytes>, ContentDecoder), io::Error>>>,
}
impl<S> Decoder<S>
where
    S: Stream<Item = Result<Bytes, PayloadError>>,
{
    /// Wraps `stream` in a decoder for the given content `encoding`.
    ///
    /// Only `ContentEncoding::Gzip` and `ContentEncoding::Deflate` install a
    /// decompressor. Every other encoding yields a decoder that forwards the
    /// chunks of `stream` untouched.
    #[inline]
    pub fn new(stream: S, encoding: ContentEncoding) -> Decoder<S> {
        let decompressor = match encoding {
            ContentEncoding::Gzip => {
                let gzip = GzDecoder::new(Writer::new());
                Some(ContentDecoder::Gzip(Box::new(gzip)))
            }
            ContentEncoding::Deflate => {
                let zlib = ZlibDecoder::new(Writer::new());
                Some(ContentDecoder::Deflate(Box::new(zlib)))
            }
            _ => None,
        };

        Decoder {
            inner: decompressor,
            stream,
            fut: None,
            eof: false,
        }
    }

    /// Builds a decoder whose encoding comes from the `Content-Encoding` header.
    ///
    /// Falls back to `ContentEncoding::Identity` (pass-through) when the header
    /// is absent or its value cannot be read as a string (`to_str` fails).
    #[inline]
    pub fn from_headers(stream: S, headers: &HeaderMap) -> Decoder<S> {
        let encoding = headers
            .get(&CONTENT_ENCODING)
            .and_then(|value| value.to_str().ok())
            .map_or(ContentEncoding::Identity, ContentEncoding::from);
        Self::new(stream, encoding)
    }
}
impl<S> Stream for Decoder<S>
where
    S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
    type Item = Result<Bytes, PayloadError>;

    /// Yields decoded chunks from the underlying stream.
    ///
    /// How each chunk is decoded:
    /// - Chunks shorter than `INPLACE` bytes are decoded directly on the
    ///   polling task.
    /// - Larger chunks are moved, together with the decoder, into a
    ///   `spawn_blocking` task. The decoder is restored when that task
    ///   completes.
    /// - When the underlying stream ends, the decoder is finalized and any
    ///   trailing output is emitted.
    ///
    /// Errors from the underlying stream are forwarded as-is. A decoding
    /// failure (including a failed blocking task) is reported once and then
    /// terminates the stream, because the decoder has been consumed by then.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            if let Some(ref mut fut) = self.fut {
                let result = match Pin::new(fut).poll(cx) {
                    Poll::Ready(result) => result,
                    Poll::Pending => return Poll::Pending,
                };
                // The blocking task has completed; it must never be polled again.
                self.fut = None;

                match result {
                    Ok(Ok((chunk, decoder))) => {
                        self.inner = Some(decoder);
                        if let Some(chunk) = chunk {
                            return Poll::Ready(Some(Ok(chunk)));
                        }
                    }
                    // The decoder was lost with the task. Stop here rather than
                    // pass later compressed chunks through as if decoded.
                    Ok(Err(err)) => {
                        self.eof = true;
                        return Poll::Ready(Some(Err(err.into())));
                    }
                    Err(err) => {
                        self.eof = true;
                        return Poll::Ready(Some(Err(err.into())));
                    }
                }
            }

            if self.eof {
                return Poll::Ready(None);
            }

            match Pin::new(&mut self.stream).poll_next(cx) {
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                Poll::Ready(Some(Ok(chunk))) => {
                    if let Some(mut decoder) = self.inner.take() {
                        if chunk.len() < INPLACE {
                            match decoder.feed_data(&chunk) {
                                Ok(decoded) => {
                                    self.inner = Some(decoder);
                                    if let Some(decoded) = decoded {
                                        return Poll::Ready(Some(Ok(decoded)));
                                    }
                                }
                                // The decoder state is now unreliable. Terminate
                                // instead of forwarding raw compressed chunks.
                                Err(err) => {
                                    self.eof = true;
                                    return Poll::Ready(Some(Err(err.into())));
                                }
                            }
                        } else {
                            // Large chunk: decode off the polling task. The
                            // decoder moves into the task and comes back with
                            // its output.
                            self.fut = Some(spawn_blocking(move || {
                                let chunk = decoder.feed_data(&chunk)?;
                                Ok((chunk, decoder))
                            }));
                        }
                        continue;
                    }
                    // No decompressor configured: pass the chunk through.
                    return Poll::Ready(Some(Ok(chunk)));
                }
                Poll::Ready(None) => {
                    self.eof = true;
                    return match self.inner.take() {
                        Some(mut decoder) => match decoder.feed_eof() {
                            Ok(Some(res)) => Poll::Ready(Some(Ok(res))),
                            Ok(None) => Poll::Ready(None),
                            Err(err) => Poll::Ready(Some(Err(err.into()))),
                        },
                        None => Poll::Ready(None),
                    };
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Decompressor for one supported content encoding.
///
/// Each variant writes its decompressed output into a `Writer`. That buffer is
/// drained after every `feed_data` / `feed_eof` call.
enum ContentDecoder {
    /// Zlib-wrapped deflate stream, used for `ContentEncoding::Deflate`.
    Deflate(Box<ZlibDecoder<Writer>>),
    /// Gzip stream, used for `ContentEncoding::Gzip`.
    Gzip(Box<GzDecoder<Writer>>),
}
impl ContentDecoder {
fn feed_eof(&mut self) -> io::Result<Option<Bytes>> {
match self {
ContentDecoder::Gzip(decoder) => match decoder.try_finish() {
Ok(()) => {
let b = decoder.get_mut().take();
if b.is_empty() { Ok(None) } else { Ok(Some(b)) }
}
Err(e) => Err(e),
},
ContentDecoder::Deflate(decoder) => match decoder.try_finish() {
Ok(()) => {
let b = decoder.get_mut().take();
if b.is_empty() { Ok(None) } else { Ok(Some(b)) }
}
Err(e) => Err(e),
},
}
}
fn feed_data(&mut self, data: &Bytes) -> io::Result<Option<Bytes>> {
match self {
ContentDecoder::Gzip(decoder) => match decoder.write_all(data) {
Ok(()) => {
decoder.flush()?;
let b = decoder.get_mut().take();
if b.is_empty() { Ok(None) } else { Ok(Some(b)) }
}
Err(e) => Err(e),
},
ContentDecoder::Deflate(decoder) => match decoder.write_all(data) {
Ok(()) => {
decoder.flush()?;
let b = decoder.get_mut().take();
if b.is_empty() { Ok(None) } else { Ok(Some(b)) }
}
Err(e) => Err(e),
},
}
}
}