use crate::{CompressionAlgorithm, Levels};
use futures_lite::io::BufReader;
use trillium_client::{
ClientHandler, Conn, ConnExt,
KnownHeaderName::{AcceptEncoding, ContentEncoding},
Result,
};
// Value offered for the `Accept-Encoding` request header by `Compression::run`;
// lists the content codings named here: zstd, br (brotli) and gzip.
const ACCEPT_ENCODING: &str = "zstd, br, gzip";
/// Client handler configuration for compressing request bodies and decoding
/// compressed response bodies.
///
/// The derived `Default` leaves `default_encoding` unset (`None`).
#[derive(Clone, Copy, Debug, Default)]
pub struct Compression {
// Fallback algorithm for request bodies, consulted only when the conn itself
// carries no `CompressionAlgorithm` in its state. `None` means request bodies
// are left untouched unless the conn opts in.
default_encoding: Option<CompressionAlgorithm>,
}
impl Compression {
    /// Creates a handler with no default request encoding.
    ///
    /// Equivalent to [`Compression::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the algorithm used to compress request bodies when the conn carries
    /// no [`CompressionAlgorithm`] of its own in its state.
    ///
    /// This is a by-value builder: it consumes `self` and returns the updated
    /// handler, so discarding the return value loses the configuration.
    #[must_use]
    pub fn with_default_encoding(mut self, encoding: CompressionAlgorithm) -> Self {
        self.default_encoding = Some(encoding);
        self
    }
}
impl ClientHandler for Compression {
    /// Prepares the outgoing request.
    ///
    /// Always offers [`ACCEPT_ENCODING`] for the `Accept-Encoding` request header
    /// through `try_insert` (presumably a no-op when the header is already set —
    /// confirm against the trillium headers API). The request body is then
    /// compressed when all of the following hold:
    ///
    /// * an algorithm is available, taken from the conn state first and from
    ///   `default_encoding` otherwise,
    /// * that algorithm is not `Identity`,
    /// * the request has no `Content-Encoding` header yet, and
    /// * the request actually has a body.
    ///
    /// `Content-Encoding` is only set when the encoder reports that it encoded
    /// the body.
    async fn run(&self, conn: &mut Conn) -> Result<()> {
        conn.request_headers_mut()
            .try_insert(AcceptEncoding, ACCEPT_ENCODING);

        // A per-conn choice takes precedence over the handler-wide default.
        let per_conn = conn.state::<CompressionAlgorithm>().copied();
        let algorithm = match per_conn.or(self.default_encoding) {
            Some(algorithm) if algorithm != CompressionAlgorithm::Identity => algorithm,
            _ => return Ok(()),
        };

        // Leave requests alone when a content encoding was already declared.
        let already_encoded = conn.request_headers().get_str(ContentEncoding).is_some();
        if already_encoded {
            return Ok(());
        }

        if let Some(plain_body) = conn.take_request_body() {
            let (new_body, was_encoded) = algorithm.encode(plain_body, Levels::default()).await;
            // The returned body is put back whether or not it was encoded.
            conn.set_request_body(new_body);
            if was_encoded {
                conn.request_headers_mut()
                    .insert(ContentEncoding, algorithm.as_str());
            }
        }

        Ok(())
    }

    /// Decodes the response body according to its `Content-Encoding` header.
    ///
    /// Nothing happens when the header is missing, does not parse as a
    /// [`CompressionAlgorithm`], names `Identity`, or when there is no response
    /// body. Otherwise the body is replaced by a streaming decoder and the
    /// `Content-Encoding` header is removed.
    async fn after_response(&self, conn: &mut Conn) -> Result<()> {
        let declared = conn
            .response_headers()
            .get_str(ContentEncoding)
            .and_then(|value| value.parse::<CompressionAlgorithm>().ok());

        let algorithm = match declared {
            Some(algorithm) if algorithm != CompressionAlgorithm::Identity => algorithm,
            _ => return Ok(()),
        };

        if let Some(raw_body) = conn.take_response_body() {
            let reader = BufReader::new(raw_body);
            let decoder = algorithm.decode_streaming(reader);
            conn.set_response_body(decoder)
                .response_headers_mut()
                .remove(ContentEncoding);
        }

        Ok(())
    }
}