use crate::service::{Layer, Service};
use bytes::buf::Writer;
use bytes::{BufMut, Bytes, BytesMut};
use flate2::write::GzDecoder;
use http::header::{Entry, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH};
use http::{HeaderValue, Request, Response};
use http_body::{Body, Frame, SizeHint};
use once_cell::sync::Lazy;
use pin_project::pin_project;
use std::error::Error;
use std::io::Write;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
static GZIP: Lazy<HeaderValue> = Lazy::new(|| HeaderValue::from_static("gzip"));
pub struct GzipLayer;
impl<S> Layer<S> for GzipLayer {
type Service = GzipService<S>;
fn layer(self, inner: S) -> GzipService<S> {
GzipService { inner }
}
}
pub struct GzipService<S> {
inner: S,
}
impl<S, B1, B2> Service<Request<B1>> for GzipService<S>
where
S: Service<Request<B1>, Response = Response<B2>>,
B2: Body<Data = Bytes>,
B2::Error: Into<Box<dyn Error + Sync + Send>>,
{
type Response = Response<DecodedBody<B2>>;
type Error = S::Error;
async fn call(&self, mut req: Request<B1>) -> Result<Self::Response, Self::Error> {
if let Entry::Vacant(e) = req.headers_mut().entry(ACCEPT_ENCODING) {
e.insert(GZIP.clone());
}
let response = self.inner.call(req).await?;
let (mut parts, body) = response.into_parts();
let state = match parts.headers.get(CONTENT_ENCODING) {
Some(v) if v == *GZIP => {
parts.headers.remove(CONTENT_ENCODING);
parts.headers.remove(CONTENT_LENGTH);
State::Decompressing(GzDecoder::new(BytesMut::new().writer()))
}
_ => State::Done,
};
let body = DecodedBody { body, state };
Ok(Response::from_parts(parts, body))
}
}
#[allow(clippy::large_enum_variant)]
enum State {
Decompressing(GzDecoder<Writer<BytesMut>>),
Last(Frame<Bytes>),
Done,
}
#[pin_project]
pub struct DecodedBody<B> {
#[pin]
body: B,
state: State,
}
impl<B> Body for DecodedBody<B>
where
B: Body<Data = Bytes>,
B::Error: Into<Box<dyn Error + Sync + Send>>,
{
type Data = Bytes;
type Error = Box<dyn Error + Sync + Send>;
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let mut this = self.project();
loop {
match mem::replace(this.state, State::Done) {
State::Decompressing(mut decoder) => match this.body.as_mut().poll_frame(cx) {
Poll::Ready(Some(Ok(frame))) => match frame.data_ref() {
Some(data) => {
decoder.write_all(data)?;
if this.body.is_end_stream() {
decoder.try_finish()?;
if decoder.get_ref().get_ref().is_empty() {
return Poll::Ready(None);
} else {
let buf = decoder.get_mut().get_mut().split().freeze();
return Poll::Ready(Some(Ok(Frame::data(buf))));
}
} else {
decoder.flush()?;
let buf = decoder.get_mut().get_mut().split().freeze();
*this.state = State::Decompressing(decoder);
if !buf.is_empty() {
return Poll::Ready(Some(Ok(Frame::data(buf))));
}
}
}
None => {
decoder.try_finish()?;
if decoder.get_ref().get_ref().is_empty() {
return Poll::Ready(Some(Ok(frame)));
} else {
*this.state = State::Last(frame);
let buf = decoder.get_mut().get_mut().split().freeze();
return Poll::Ready(Some(Ok(Frame::data(buf))));
}
}
},
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
Poll::Ready(None) => {
decoder.try_finish()?;
if decoder.get_ref().get_ref().is_empty() {
return Poll::Ready(None);
} else {
let buf = decoder.get_mut().get_mut().split().freeze();
return Poll::Ready(Some(Ok(Frame::data(buf))));
}
}
Poll::Pending => {
*this.state = State::Decompressing(decoder);
return Poll::Pending;
}
},
State::Last(frame) => return Poll::Ready(Some(Ok(frame))),
State::Done => return this.body.as_mut().poll_frame(cx).map_err(Into::into),
}
}
}
fn is_end_stream(&self) -> bool {
match self.state {
State::Decompressing(_) | State::Done => self.body.is_end_stream(),
State::Last(_) => false,
}
}
fn size_hint(&self) -> SizeHint {
match &self.state {
State::Decompressing(_) => SizeHint::new(),
State::Last(frame) => match frame.data_ref() {
Some(data) => SizeHint::with_exact(data.len() as u64),
None => SizeHint::with_exact(0),
},
State::Done => self.body.size_hint(),
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::service;
use flate2::write::GzEncoder;
use flate2::Compression;
use http_body_util::{BodyExt, Full};
use std::io::Write;
#[tokio::test]
async fn uncompressed() {
let body = "hello world";
let service = GzipLayer.layer(service::service_fn(|req: Request<()>| async move {
assert_eq!(req.headers().get(ACCEPT_ENCODING).unwrap(), "gzip");
let response = Response::builder()
.header(CONTENT_LENGTH, body.len().to_string())
.body(Full::new(Bytes::from(body)))
.unwrap();
Ok::<_, hyper::Error>(response)
}));
let response = service.call(Request::new(())).await.unwrap();
assert_eq!(
response.headers().get(CONTENT_LENGTH).unwrap(),
&*body.len().to_string(),
);
assert_eq!(response.headers().get(CONTENT_ENCODING), None);
let actual = response.into_body().collect().await.unwrap();
assert_eq!(actual.to_bytes(), body.as_bytes());
}
#[tokio::test]
async fn compressed() {
let body = "hello world";
let service = GzipLayer.layer(service::service_fn(|req: Request<()>| async move {
assert_eq!(req.headers().get(ACCEPT_ENCODING).unwrap(), "gzip");
let mut writer = GzEncoder::new(vec![], Compression::default());
writer.write_all(body.as_bytes()).unwrap();
let body = writer.finish().unwrap();
let response = Response::builder()
.header(CONTENT_LENGTH, body.len().to_string())
.header(CONTENT_ENCODING, "gzip")
.body(Full::new(Bytes::from(body)))
.unwrap();
Ok::<_, hyper::Error>(response)
}));
let response = service.call(Request::new(())).await.unwrap();
assert_eq!(response.headers().get(CONTENT_LENGTH), None);
assert_eq!(response.headers().get(CONTENT_ENCODING), None);
let actual = response.into_body().collect().await.unwrap();
assert_eq!(actual.to_bytes(), body.as_bytes());
}
#[tokio::test]
async fn custom_accept_encoding() {
let body = "hello world";
let encoding = "br";
let service = GzipLayer.layer(service::service_fn(|req: Request<()>| async move {
assert_eq!(req.headers().get(ACCEPT_ENCODING).unwrap(), encoding);
let response = Response::builder()
.header(CONTENT_LENGTH, body.len().to_string())
.header(CONTENT_ENCODING, encoding)
.body(Full::new(Bytes::from(body)))
.unwrap();
Ok::<_, hyper::Error>(response)
}));
let request = Request::builder()
.header(ACCEPT_ENCODING, encoding)
.body(())
.unwrap();
let response = service.call(request).await.unwrap();
assert_eq!(
response.headers().get(CONTENT_LENGTH).unwrap(),
&*body.len().to_string(),
);
assert_eq!(response.headers().get(CONTENT_ENCODING).unwrap(), encoding);
let actual = response.into_body().collect().await.unwrap();
assert_eq!(actual.to_bytes(), body.as_bytes());
}
}