use async_compression::tokio::bufread::{BrotliEncoder, GzipEncoder, ZlibEncoder, ZstdEncoder};
pub use async_compression::Level;
use tokio_util::io::StreamReader;
use crate::http::header::{HeaderMap, ACCEPT_ENCODING, CONTENT_ENCODING, VARY};
use crate::http::{HeaderValue, StatusCode};
use crate::{async_trait, status, Context, Middleware, Next, Result};
/// Middleware that re-encodes the response body using the coding negotiated
/// from the request's `Accept-Encoding` header.
///
/// The wrapped [`Level`] is the compression quality handed to the encoder.
#[derive(Clone, Copy, Debug)]
pub struct Compress(pub Level);
/// A content-coding this middleware recognises in `Accept-Encoding`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Encoding {
    /// The `gzip` coding.
    Gzip,
    /// The `deflate` coding (produced with a zlib encoder).
    Deflate,
    /// The `br` (Brotli) coding.
    Brotli,
    /// The `zstd` (Zstandard) coding.
    Zstd,
    /// The `identity` coding: the body is passed through unchanged.
    Identity,
}
impl Encoding {
    /// Every recognised coding, in the order tokens are matched by [`Encoding::parse`].
    const ALL: [Encoding; 5] = [
        Encoding::Gzip,
        Encoding::Deflate,
        Encoding::Brotli,
        Encoding::Zstd,
        Encoding::Identity,
    ];

    /// The content-coding token for this encoding, as it appears on the wire.
    ///
    /// This is the single source of truth shared by [`Encoding::parse`] and
    /// [`Encoding::to_header_value`], so the two can never drift apart.
    fn name(self) -> &'static str {
        match self {
            Encoding::Gzip => "gzip",
            Encoding::Deflate => "deflate",
            Encoding::Brotli => "br",
            Encoding::Zstd => "zstd",
            Encoding::Identity => "identity",
        }
    }

    /// Parses one content-coding token taken from an `Accept-Encoding` entry.
    ///
    /// Returns `Ok(None)` for the wildcard `*`, and `Ok(Some(_))` for a
    /// recognised token. Matching is ASCII case-insensitive because
    /// content-coding names are case-insensitive (RFC 9110 §8.4.1).
    ///
    /// # Errors
    ///
    /// Returns a `400 Bad Request` status for any unrecognised token.
    fn parse(s: &str) -> Result<Option<Encoding>> {
        if s == "*" {
            return Ok(None);
        }
        match Self::ALL
            .iter()
            .copied()
            .find(|encoding| encoding.name().eq_ignore_ascii_case(s))
        {
            Some(encoding) => Ok(Some(encoding)),
            None => Err(status!(
                StatusCode::BAD_REQUEST,
                format!("unknown encoding: {}", s),
                true
            )),
        }
    }

    /// The `Content-Encoding` header value for this encoding.
    ///
    /// Uses `HeaderValue::from_static`: every token is a fixed, valid ASCII
    /// literal, so no runtime validation or `unwrap` is needed.
    fn to_header_value(self) -> HeaderValue {
        HeaderValue::from_static(self.name())
    }
}
fn select_encoding(headers: &HeaderMap) -> Result<Option<Encoding>> {
let mut preferred_encoding = None;
let mut max_qval = 0.0;
for (encoding, qval) in accept_encodings(headers)? {
if qval > max_qval {
preferred_encoding = encoding;
max_qval = qval;
}
}
Ok(preferred_encoding)
}
/// Parses every `Accept-Encoding` header of a request into `(encoding, q)` pairs,
/// in the order they appear.
///
/// A `None` encoding stands for the wildcard `*`. An entry without a weight gets
/// `q = 1.0`. Entries with an unrecognised coding are skipped, as are empty list
/// elements and entries whose parameter is not a `q` weight.
///
/// # Errors
///
/// Returns `400 Bad Request` when a header value is not visible ASCII, or when a
/// recognised coding carries a `q` value that does not parse as `f32`.
fn accept_encodings(headers: &HeaderMap) -> Result<Vec<(Option<Encoding>, f32)>> {
    let mut accepted = Vec::new();
    for hval in headers.get_all(ACCEPT_ENCODING).iter() {
        let value = hval
            .to_str()
            .map_err(|err| status!(StatusCode::BAD_REQUEST, err, true))?;
        for element in value.split(',') {
            if let Some(pair) = parse_accept_encoding_element(element)? {
                accepted.push(pair);
            }
        }
    }
    Ok(accepted)
}

/// Parses one comma-separated `Accept-Encoding` element such as `gzip`,
/// `br;q=0.8` or `deflate ; Q=0`.
///
/// The grammar is `codings [ OWS ";" OWS "q=" qvalue ]` (RFC 9110 §12.5.3).
/// Optional whitespace around `;` and `=` is tolerated, and the parameter name
/// `q` is matched case-insensitively.
///
/// Returns `Ok(None)` for elements to ignore. These are empty elements,
/// unrecognised codings, and a parameter that is not a `q` weight.
///
/// # Errors
///
/// Returns `400 Bad Request` when the `q` value of a recognised coding does
/// not parse as `f32`.
fn parse_accept_encoding_element(element: &str) -> Result<Option<(Option<Encoding>, f32)>> {
    let mut parts = element.splitn(2, ';');
    // `splitn` always yields at least one item, so the fallback is never used.
    let name = parts.next().unwrap_or("").trim();
    let encoding = match Encoding::parse(name) {
        Ok(encoding) => encoding,
        // Unknown codings (and empty elements) are ignored rather than rejected.
        Err(_) => return Ok(None),
    };
    let qval = match parts.next() {
        // No weight: the coding is fully acceptable.
        None => 1.0,
        Some(weight) => {
            let mut key_value = weight.splitn(2, '=');
            let key = key_value.next().unwrap_or("").trim();
            let value = match key_value.next() {
                Some(value) if key.eq_ignore_ascii_case("q") => value.trim(),
                // Not a weight parameter: skip the entry, as it is malformed.
                _ => return Ok(None),
            };
            value
                .parse::<f32>()
                .map_err(|err| status!(StatusCode::BAD_REQUEST, err, true))?
        }
    };
    Ok(Some((encoding, qval)))
}
impl Default for Compress {
fn default() -> Self {
Self(Level::Default)
}
}
#[async_trait(?Send)]
impl<'a, S> Middleware<'a, S> for Compress {
#[allow(clippy::trivially_copy_pass_by_ref)]
#[inline]
async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
next.await?;
let level = self.0;
let best_encoding = select_encoding(&ctx.req.headers)?;
let body = std::mem::take(&mut ctx.resp.body);
let content_encoding = match best_encoding {
None | Some(Encoding::Gzip) => {
ctx.resp
.write_reader(GzipEncoder::with_quality(StreamReader::new(body), level));
Encoding::Gzip.to_header_value()
}
Some(Encoding::Deflate) => {
ctx.resp
.write_reader(ZlibEncoder::with_quality(StreamReader::new(body), level));
Encoding::Deflate.to_header_value()
}
Some(Encoding::Brotli) => {
ctx.resp
.write_reader(BrotliEncoder::with_quality(StreamReader::new(body), level));
Encoding::Brotli.to_header_value()
}
Some(Encoding::Zstd) => {
ctx.resp
.write_reader(ZstdEncoder::with_quality(StreamReader::new(body), level));
Encoding::Zstd.to_header_value()
}
Some(Encoding::Identity) => {
ctx.resp.body = body;
Encoding::Identity.to_header_value()
}
};
ctx.resp.headers.append(CONTENT_ENCODING, content_encoding);
Ok(())
}
}
#[cfg(all(test, feature = "tcp", feature = "file"))]
mod tests {
    use std::io;
    use std::pin::Pin;
    use std::task::{self, Poll};

    use bytes::Bytes;
    use futures::Stream;
    use tokio::task::spawn;

    use crate::body::DispositionType::*;
    use crate::compress::{Compress, Level};
    use crate::http::header::ACCEPT_ENCODING;
    use crate::http::StatusCode;
    use crate::preload::*;
    use crate::{async_trait, App, Context, Middleware, Next};

    /// Stream adapter that counts the bytes passing through it. When the inner
    /// stream ends, it asserts that exactly `assert_counter` bytes were seen.
    struct Consumer<S> {
        counter: usize,
        stream: S,
        assert_counter: usize,
    }

    impl<S> Stream for Consumer<S>
    where
        S: 'static + Send + Unpin + Stream<Item = io::Result<Bytes>>,
    {
        type Item = io::Result<Bytes>;

        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match Pin::new(&mut self.stream).poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => {
                    self.counter += bytes.len();
                    Poll::Ready(Some(Ok(bytes)))
                }
                Poll::Ready(None) => {
                    assert_eq!(self.assert_counter, self.counter);
                    Poll::Ready(None)
                }
                poll => poll,
            }
        }
    }

    /// Middleware that wraps the downstream response body in a [`Consumer`]
    /// expecting exactly `self.0` bytes.
    struct Assert(usize);

    #[async_trait(?Send)]
    impl<'a, S> Middleware<'a, S> for Assert {
        async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> crate::Result {
            next.await?;
            let body = std::mem::take(&mut ctx.resp.body);
            ctx.resp.write_stream(Consumer {
                counter: 0,
                stream: body,
                assert_counter: self.0,
            });
            Ok(())
        }
    }

    /// Endpoint serving the welcome page inline.
    async fn end(ctx: &mut Context) -> crate::Result {
        ctx.write_file("../assets/welcome.html", Inline).await
    }

    #[tokio::test]
    async fn compress() -> Result<(), Box<dyn std::error::Error>> {
        // The outer `Assert` wraps the body produced by `Compress` (gzip bytes).
        // The inner one wraps the raw file body.
        let app = App::new()
            .gate(Assert(202)) // size of the gzip output at `Level::Fastest`
            .gate(Compress(Level::Fastest))
            .gate(Assert(236)) // size of the uncompressed welcome page
            .end(end);
        let (addr, server) = app.run()?;
        spawn(server);
        let client = reqwest::Client::builder().gzip(true).build()?;
        let resp = client
            .get(&format!("http://{}", addr))
            .header(ACCEPT_ENCODING, "gzip")
            .send()
            .await?;
        assert_eq!(StatusCode::OK, resp.status());
        // The client transparently decodes gzip, so it sees the original length.
        assert_eq!(236, resp.text().await?.len());
        Ok(())
    }
}