use std::io::Write;
use bytes::Bytes;
use flate2::write::GzEncoder;
use flate2::Compression as CompressionLevel;
use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, VARY};
use http::HeaderValue;
use crate::body::RespBody;
use crate::constants::TEXT_EVENT_STREAM;
use crate::error::Result;
use crate::middleware::{DuplicatePolicy, Middleware, Next, Request};
use crate::response::{into_body_bytes, Response};
use crate::router::BoxFuture;
/// Content-coding token for gzip: matched against `Accept-Encoding` and
/// emitted as the `Content-Encoding` value of compressed responses.
const GZIP: &str = "gzip";
/// Default upper bound (8 MiB) on the body size this middleware will compress.
const DEFAULT_MAXIMUM_SIZE: usize = 8 << 20;
/// Response-compression middleware configuration.
///
/// Construct with [`Compression::new`] (or `Default`) and configure through the
/// builder methods; gzip stays disabled until [`Compression::gzip`] is called.
#[derive(Debug, Clone)]
pub struct Compression {
    /// Whether gzip `Content-Encoding` may be applied.
    gzip: bool,
    /// Smallest body size, in bytes, that will be compressed (inclusive).
    minimum_size: usize,
    /// Largest body size, in bytes, that will be compressed (inclusive).
    maximum_size: usize,
}
impl Compression {
    /// Creates the middleware with gzip disabled, no minimum body size, and a
    /// maximum body size of 8 MiB (`DEFAULT_MAXIMUM_SIZE`).
    ///
    /// Call [`Compression::gzip`] to turn compression on; until then responses
    /// pass through this middleware unchanged.
    #[must_use]
    pub fn new() -> Self {
        Self {
            gzip: false,
            minimum_size: 0,
            maximum_size: DEFAULT_MAXIMUM_SIZE,
        }
    }

    /// Enables gzip `Content-Encoding` for clients whose `Accept-Encoding`
    /// allows it.
    #[must_use = "builder methods consume `self` and return the updated middleware"]
    pub fn gzip(mut self) -> Self {
        self.gzip = true;
        self
    }

    /// Sets the smallest body size, in bytes, that will be compressed.
    ///
    /// Bodies shorter than `bytes` are sent uncompressed; a body of exactly
    /// `bytes` is still eligible.
    #[must_use = "builder methods consume `self` and return the updated middleware"]
    pub fn minimum_size(mut self, bytes: usize) -> Self {
        self.minimum_size = bytes;
        self
    }

    /// Sets the largest body size, in bytes, that will be compressed.
    ///
    /// Bodies longer than `bytes` are sent uncompressed; a body of exactly
    /// `bytes` is still eligible. If the minimum exceeds the maximum, no
    /// response is ever compressed.
    #[must_use = "builder methods consume `self` and return the updated middleware"]
    pub fn maximum_size(mut self, bytes: usize) -> Self {
        self.maximum_size = bytes;
        self
    }
}
impl Default for Compression {
fn default() -> Self {
Self::new()
}
}
impl Middleware for Compression {
fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
let gzip_enabled = self.gzip;
let minimum_size = self.minimum_size;
let maximum_size = self.maximum_size;
let accepts_gzip = request
.headers()
.get(ACCEPT_ENCODING)
.and_then(|value| value.to_str().ok())
.map(|value| value.to_ascii_lowercase().contains(GZIP))
.unwrap_or(false);
Box::pin(async move {
let mut response = next.run(request).await?;
if gzip_enabled && !is_event_stream(&response) {
append_vary_accept_encoding(response.headers_mut());
}
if !gzip_enabled
|| !accepts_gzip
|| response.headers().contains_key(CONTENT_ENCODING)
|| is_event_stream(&response)
{
return Ok(response);
}
if let Some(length) = content_length(response.headers()) {
if length > maximum_size || length < minimum_size {
return Ok(response);
}
}
let (mut parts, bytes) = into_body_bytes(response).await;
if bytes.len() < minimum_size || bytes.len() > maximum_size {
return Ok(Response::from_parts(parts, RespBody::new(bytes)));
}
match gzip(&bytes) {
Ok(compressed) => {
parts
.headers
.insert(CONTENT_ENCODING, HeaderValue::from_static(GZIP));
if let Ok(length) = HeaderValue::from_str(&compressed.len().to_string()) {
parts.headers.insert(CONTENT_LENGTH, length);
}
Ok(Response::from_parts(
parts,
RespBody::new(Bytes::from(compressed)),
))
}
Err(_) => Ok(Response::from_parts(parts, RespBody::new(bytes))),
}
})
}
fn name(&self) -> &'static str {
"Compression"
}
fn duplicate_policy(&self) -> DuplicatePolicy {
DuplicatePolicy::Reject
}
}
/// Appends `Accept-Encoding` to the `Vary` header unless it is already covered.
///
/// `Vary` is a comma-separated, case-insensitive list of field names that may
/// be spread across several header lines, so every line is checked token by
/// token. A token of `*` already covers every request field, so nothing is
/// added in that case.
fn append_vary_accept_encoding(headers: &mut http::HeaderMap) {
    let already_covered = headers
        .get_all(VARY)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .any(|field| field == "*" || field.eq_ignore_ascii_case("accept-encoding"));
    if !already_covered {
        headers.append(VARY, HeaderValue::from_static("Accept-Encoding"));
    }
}
/// Reads the `Content-Length` header as a byte count.
///
/// Returns `None` when:
/// - the header is absent;
/// - it cannot be viewed as a string (`to_str` fails);
/// - after trimming surrounding whitespace, it does not parse as a `usize`.
fn content_length(headers: &http::HeaderMap) -> Option<usize> {
    let declared = headers.get(CONTENT_LENGTH)?.to_str().ok()?;
    declared.trim().parse().ok()
}
/// Returns `true` when the response's `Content-Type` begins with
/// `TEXT_EVENT_STREAM`, ignoring ASCII case and leading whitespace.
///
/// Media type and subtype names are case-insensitive, so e.g.
/// `Text/Event-Stream` must also be recognised. Otherwise the compression
/// path would buffer the whole body (`into_body_bytes`), holding back events
/// until the stream ends.
///
/// A missing `Content-Type`, or one that cannot be viewed as a string, yields
/// `false`.
fn is_event_stream(response: &Response) -> bool {
    response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.trim_start().get(..TEXT_EVENT_STREAM.len()))
        .map(|prefix| prefix.eq_ignore_ascii_case(TEXT_EVENT_STREAM))
        .unwrap_or(false)
}
/// Gzip-compresses `data` in memory at flate2's default compression level.
///
/// # Errors
/// Returns any I/O error the encoder reports while writing the input or
/// finishing the stream.
fn gzip(data: &[u8]) -> std::io::Result<Vec<u8>> {
    // Capacity guess: about half the input plus a little slack.
    let sink = Vec::with_capacity(data.len() / 2 + 16);
    let mut encoder = GzEncoder::new(sink, CompressionLevel::default());
    encoder.write_all(data).and_then(|()| encoder.finish())
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::read::GzDecoder;
    use std::io::Read;

    /// Builds an empty-bodied response carrying the given `Content-Type`.
    fn response_with_content_type(value: &'static str) -> Response {
        let mut response = http::Response::new(RespBody::new(Bytes::new()));
        response
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(value));
        response
    }

    #[test]
    fn event_stream_is_detected_and_bypasses_compression() {
        assert!(is_event_stream(&response_with_content_type(
            TEXT_EVENT_STREAM
        )));
        assert!(!is_event_stream(&response_with_content_type(
            "application/json"
        )));
        assert!(!is_event_stream(&http::Response::new(RespBody::new(
            Bytes::new()
        ))));
    }

    #[test]
    fn content_length_parses_only_valid_values() {
        let mut headers = http::HeaderMap::new();
        assert_eq!(content_length(&headers), None);
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("1024"));
        assert_eq!(content_length(&headers), Some(1024));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("not-a-number"));
        assert_eq!(content_length(&headers), None);
    }

    /// Compresses a repetitive payload, checks that it shrank, then decodes it
    /// to prove the output is a valid gzip stream of the same bytes.
    #[test]
    fn gzip_round_trips_through_flate2() {
        let original = b"hello world, this is a test that compresses well. ".repeat(20);
        let compressed = gzip(&original).expect("gzip must succeed");
        assert!(compressed.len() < original.len());
        let mut decoded = Vec::new();
        GzDecoder::new(compressed.as_slice())
            .read_to_end(&mut decoded)
            .expect("gzip output must decode");
        assert_eq!(decoded, original);
    }

    #[test]
    fn builder_defaults_and_setters() {
        let defaults = Compression::new();
        assert!(!defaults.gzip);
        assert_eq!(defaults.minimum_size, 0);
        assert_eq!(defaults.maximum_size, DEFAULT_MAXIMUM_SIZE);

        let configured = Compression::new()
            .gzip()
            .minimum_size(10)
            .maximum_size(20);
        assert!(configured.gzip);
        assert_eq!(configured.minimum_size, 10);
        assert_eq!(configured.maximum_size, 20);
    }

    #[test]
    fn vary_accept_encoding_is_added_once() {
        let mut headers = http::HeaderMap::new();
        append_vary_accept_encoding(&mut headers);
        append_vary_accept_encoding(&mut headers);
        assert_eq!(headers.get_all(VARY).iter().count(), 1);
        assert_eq!(
            headers.get(VARY).and_then(|value| value.to_str().ok()),
            Some("Accept-Encoding")
        );

        let mut existing = http::HeaderMap::new();
        existing.insert(VARY, HeaderValue::from_static("Cookie, accept-encoding"));
        append_vary_accept_encoding(&mut existing);
        assert_eq!(existing.get_all(VARY).iter().count(), 1);
    }
}