#![cfg(feature = "zstd")]
use std::io::Write;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use bytes::Bytes;
use futures_util::Stream;
use futures_util::TryStreamExt;
use http_body::Body;
use http_body::Frame;
use http_body_util::BodyExt;
use pin_project_lite::pin_project;
use zstd::stream::Encoder;
use crate::body::TakoBody;
use crate::types::BoxError;
/// Wraps `body` in a streaming zstd compressor and returns it as a [`TakoBody`].
///
/// The data frames of `body` are fed chunk by chunk into a [`ZstdStream`] at
/// compression `level`. Each compressed chunk is re-wrapped as a data [`Frame`].
/// Errors produced by `body` are forwarded as-is.
pub fn stream_zstd<B>(body: B, level: i32) -> TakoBody
where
    B: Body<Data = Bytes, Error = BoxError> + Send + 'static,
{
    let compressed = ZstdStream::new(body.into_data_stream(), level).map_ok(Frame::data);
    TakoBody::from_try_stream(compressed)
}
pin_project! {
    /// Stream adapter that zstd-compresses the `Bytes` chunks of an inner stream.
    ///
    /// Used by [`stream_zstd`] to compress an HTTP body while it streams.
    pub struct ZstdStream<S> {
        // Upstream source of uncompressed chunks (structurally pinned).
        #[pin]
        inner: S,
        // Compressor writing into an in-memory Vec; `None` once it has been finished.
        encoder: Option<Encoder<'static, Vec<u8>>>,
        // Compressed bytes waiting to be yielded, starting at offset `pos`.
        buffer: Vec<u8>,
        pos: usize,
        // Set once the inner stream has reported end-of-stream.
        done: bool,
    }
}
impl<S> ZstdStream<S> {
    /// Creates a compressor over `stream` that uses zstd compression `level`.
    ///
    /// # Panics
    ///
    /// Panics with `"zstd encoder"` if the zstd encoder cannot be constructed.
    fn new(stream: S, level: i32) -> Self {
        let encoder = Encoder::new(Vec::new(), level).expect("zstd encoder");
        Self {
            inner: stream,
            encoder: Some(encoder),
            buffer: Vec::new(),
            pos: 0,
            done: false,
        }
    }
}
impl<S> Stream for ZstdStream<S>
where
    S: Stream<Item = Result<Bytes, BoxError>>,
{
    type Item = Result<Bytes, BoxError>;

    /// Polls the upstream stream and yields zstd-compressed chunks.
    ///
    /// Each upstream chunk is written to the encoder and flushed. The compressed
    /// bytes that this produces are yielded as one item. When upstream ends, the
    /// encoder is finished and its trailing bytes are yielded before the stream
    /// terminates. Upstream errors and encoder I/O errors are yielded as `Err` items.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            // Hand out compressed bytes left by the previous step before doing more work.
            if *this.pos < this.buffer.len() {
                let start = *this.pos;
                let pending = std::mem::take(this.buffer);
                *this.pos = 0;
                // Move the Vec into `Bytes` instead of copying it; `slice` is O(1).
                return Poll::Ready(Some(Ok(Bytes::from(pending).slice(start..))));
            }
            if *this.done && this.encoder.is_none() {
                return Poll::Ready(None);
            }
            match this.inner.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(data))) => {
                    // An empty chunk adds nothing to compress; skip the flush it would trigger.
                    if data.is_empty() {
                        continue;
                    }
                    if let Some(enc) = this.encoder.as_mut() {
                        if let Err(e) = enc.write_all(&data).and_then(|()| enc.flush()) {
                            return Poll::Ready(Some(Err(e.into())));
                        }
                        // The encoder appends to its sink Vec. Without draining it, the
                        // sink keeps every byte produced so far, which would be re-emitted
                        // on each chunk and grow without bound. Taking it after the flush
                        // yields each compressed byte exactly once.
                        let produced = std::mem::take(enc.get_mut());
                        if !produced.is_empty() {
                            *this.buffer = produced;
                            *this.pos = 0;
                        }
                    }
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => {
                    *this.done = true;
                    match this.encoder.take() {
                        // The sink was drained after every flush, so `finish` returns
                        // only the bytes written since then (tail of the frame).
                        Some(enc) => match enc.finish() {
                            Ok(tail) => {
                                *this.buffer = tail;
                                *this.pos = 0;
                            }
                            Err(e) => return Poll::Ready(Some(Err(e.into()))),
                        },
                        None => return Poll::Ready(None),
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}