use super::parser::{ParseResult, ParsedData, Parser};
use crate::exception::Error;
use async_trait::async_trait;
use brotli::Decompressor;
use bytes::Bytes;
use flate2::read::{DeflateDecoder, GzDecoder, ZlibDecoder};
use http::HeaderMap;
use std::io::Read;
use std::sync::Arc;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionEncoding {
Gzip,
Brotli,
Deflate,
}
impl CompressionEncoding {
pub fn from_header(encoding: &str) -> Option<Self> {
match encoding.to_lowercase().as_str() {
"gzip" => Some(Self::Gzip),
"br" => Some(Self::Brotli),
"deflate" => Some(Self::Deflate),
_ => None,
}
}
const DEFAULT_MAX_OUTPUT_SIZE: u64 = 100 * 1024 * 1024;
fn decompress(&self, data: &[u8]) -> ParseResult<Vec<u8>> {
self.decompress_with_limit(data, Self::DEFAULT_MAX_OUTPUT_SIZE)
}
fn decompress_with_limit(&self, data: &[u8], max_output_size: u64) -> ParseResult<Vec<u8>> {
match self {
Self::Gzip => {
let decoder = GzDecoder::new(data);
let mut limited = decoder.take(max_output_size);
let mut decompressed = Vec::new();
limited
.read_to_end(&mut decompressed)
.map_err(|e| Error::ParseError(format!("Gzip decompression error: {}", e)))?;
Ok(decompressed)
}
Self::Brotli => {
let decoder = Decompressor::new(data, 4096);
let mut limited = decoder.take(max_output_size);
let mut decompressed = Vec::new();
limited
.read_to_end(&mut decompressed)
.map_err(|e| Error::ParseError(format!("Brotli decompression error: {}", e)))?;
Ok(decompressed)
}
Self::Deflate => {
let decoder = DeflateDecoder::new(data);
let mut limited = decoder.take(max_output_size);
let mut decompressed = Vec::new();
limited.read_to_end(&mut decompressed).map_err(|e| {
Error::ParseError(format!("Deflate decompression error: {}", e))
})?;
Ok(decompressed)
}
}
}
}
/// A [`Parser`] wrapper that undoes HTTP content coding before delegating.
pub struct CompressedParser {
    /// The parser that receives the (possibly decompressed) body.
    inner: Arc<dyn Parser>,
}

impl CompressedParser {
    /// Wraps `inner` so that it is fed decompressed bodies.
    pub fn new(inner: Arc<dyn Parser>) -> Self {
        CompressedParser { inner }
    }

    /// Decompresses `body` when `content_encoding` names a supported coding.
    ///
    /// If `content_encoding` is `None`, or is a value that
    /// [`CompressionEncoding::from_header`] does not recognise, `body` is returned
    /// untouched.
    ///
    /// # Errors
    /// Propagates the decompression error when the body cannot be decoded.
    pub fn decompress_if_needed(
        &self,
        content_encoding: Option<&str>,
        body: Bytes,
    ) -> ParseResult<Bytes> {
        let Some(encoding) = content_encoding.and_then(CompressionEncoding::from_header) else {
            return Ok(body);
        };
        encoding.decompress(&body).map(Bytes::from)
    }
}
#[async_trait]
impl Parser for CompressedParser {
fn media_types(&self) -> Vec<String> {
self.inner.media_types()
}
async fn parse(
&self,
content_type: Option<&str>,
body: Bytes,
headers: &HeaderMap,
) -> ParseResult<ParsedData> {
let encoding = headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("identity");
let decompressed = match encoding {
"gzip" => CompressionEncoding::Gzip
.decompress(&body)
.map(Bytes::from)?,
"deflate" => CompressionEncoding::Deflate
.decompress(&body)
.map(Bytes::from)?,
"br" => CompressionEncoding::Brotli
.decompress(&body)
.map(Bytes::from)?,
_ => body,
};
self.inner.parse(content_type, decompressed, headers).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsers::json::JSONParser;
    use flate2::Compression;
    use flate2::write::{DeflateEncoder, GzEncoder};
    use std::io::Write;

    /// Gzip-compresses `data` at the default level.
    fn gzip_bytes(data: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Compresses `data` as a raw deflate stream at the default level.
    fn deflate_bytes(data: &[u8]) -> Vec<u8> {
        let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Brotli-compresses `data` with a 4096-byte buffer, quality 11 and window 22.
    fn brotli_bytes(data: &[u8]) -> Vec<u8> {
        let mut compressed = Vec::new();
        {
            // The writer finishes the stream when it is dropped at the end of this scope.
            let mut encoder = brotli::CompressorWriter::new(&mut compressed, 4096, 11, 22);
            encoder.write_all(data).unwrap();
        }
        compressed
    }

    /// Builds a `CompressedParser` around the JSON parser.
    fn json_parser() -> CompressedParser {
        CompressedParser::new(Arc::new(JSONParser::new()))
    }

    /// Parses `body` as `application/json` through a JSON-backed `CompressedParser`.
    ///
    /// When `encoding` is given, it is sent as the `content-encoding` header.
    async fn parse_json(encoding: Option<&str>, body: Vec<u8>) -> ParsedData {
        let mut headers = HeaderMap::new();
        if let Some(encoding) = encoding {
            headers.insert("content-encoding", encoding.parse().unwrap());
        }
        json_parser()
            .parse(Some("application/json"), Bytes::from(body), &headers)
            .await
            .unwrap()
    }

    #[test]
    fn test_compression_encoding_from_header() {
        let cases = [
            ("gzip", Some(CompressionEncoding::Gzip)),
            ("GZIP", Some(CompressionEncoding::Gzip)),
            ("br", Some(CompressionEncoding::Brotli)),
            ("deflate", Some(CompressionEncoding::Deflate)),
            ("unknown", None),
        ];
        for (header, expected) in cases {
            assert_eq!(CompressionEncoding::from_header(header), expected);
        }
    }

    #[tokio::test]
    async fn test_compressed_parser_gzip() {
        let body = gzip_bytes(br#"{"name":"John","age":30}"#);
        let ParsedData::Json(value) = parse_json(Some("gzip"), body).await else {
            panic!("Expected JSON data");
        };
        assert_eq!(value["name"], "John");
        assert_eq!(value["age"], 30);
    }

    #[tokio::test]
    async fn test_compressed_parser_deflate() {
        let body = deflate_bytes(br#"{"name":"Alice","city":"NYC"}"#);
        let ParsedData::Json(value) = parse_json(Some("deflate"), body).await else {
            panic!("Expected JSON data");
        };
        assert_eq!(value["name"], "Alice");
        assert_eq!(value["city"], "NYC");
    }

    #[tokio::test]
    async fn test_compressed_parser_brotli() {
        let body = brotli_bytes(br#"{"product":"widget","price":19.99}"#);
        let ParsedData::Json(value) = parse_json(Some("br"), body).await else {
            panic!("Expected JSON data");
        };
        assert_eq!(value["product"], "widget");
        assert_eq!(value["price"], 19.99);
    }

    #[tokio::test]
    async fn test_compressed_parser_uncompressed() {
        let body = br#"{"uncompressed":true}"#.to_vec();
        let ParsedData::Json(value) = parse_json(None, body).await else {
            panic!("Expected JSON data");
        };
        assert!(value["uncompressed"].as_bool().unwrap());
    }

    #[test]
    fn test_compressed_parser_media_types() {
        let media_types = json_parser().media_types();
        assert!(media_types.contains(&"application/json".to_string()));
    }

    #[test]
    fn test_gzip_decompression() {
        let original = b"Hello, World!";
        let decompressed = CompressionEncoding::Gzip
            .decompress(&gzip_bytes(original))
            .unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_deflate_decompression() {
        let original = b"Test data for deflate";
        let decompressed = CompressionEncoding::Deflate
            .decompress(&deflate_bytes(original))
            .unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_brotli_decompression() {
        let original = b"Brotli compressed content";
        let decompressed = CompressionEncoding::Brotli
            .decompress(&brotli_bytes(original))
            .unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_invalid_gzip_data() {
        assert!(CompressionEncoding::Gzip.decompress(b"not gzip data").is_err());
    }

    #[test]
    fn test_invalid_brotli_data() {
        assert!(CompressionEncoding::Brotli.decompress(b"not brotli data").is_err());
    }

    #[test]
    fn test_decompress_if_needed_gzip() {
        let original = b"Hello, World!";
        let result = json_parser()
            .decompress_if_needed(Some("gzip"), Bytes::from(gzip_bytes(original)))
            .unwrap();
        assert_eq!(result.as_ref(), original);
    }

    #[test]
    fn test_decompress_if_needed_deflate() {
        let original = b"Test data for deflate";
        let result = json_parser()
            .decompress_if_needed(Some("deflate"), Bytes::from(deflate_bytes(original)))
            .unwrap();
        assert_eq!(result.as_ref(), original);
    }

    #[test]
    fn test_decompress_if_needed_brotli() {
        let original = b"Brotli compressed content";
        let result = json_parser()
            .decompress_if_needed(Some("br"), Bytes::from(brotli_bytes(original)))
            .unwrap();
        assert_eq!(result.as_ref(), original);
    }

    #[test]
    fn test_decompress_if_needed_no_encoding() {
        let original = b"Uncompressed data";
        let result = json_parser()
            .decompress_if_needed(None, Bytes::from(original.as_slice()))
            .unwrap();
        assert_eq!(result.as_ref(), original);
    }

    #[test]
    fn test_decompress_if_needed_unknown_encoding() {
        let original = b"Unknown encoding";
        let result = json_parser()
            .decompress_if_needed(Some("unknown"), Bytes::from(original.as_slice()))
            .unwrap();
        assert_eq!(result.as_ref(), original);
    }

    #[test]
    fn test_decompress_if_needed_case_insensitive() {
        let original = b"Case test";
        let compressed = gzip_bytes(original);
        let parser = json_parser();
        for header in ["GZIP", "GzIp"] {
            let result = parser
                .decompress_if_needed(Some(header), Bytes::from(compressed.clone()))
                .unwrap();
            assert_eq!(result.as_ref(), original);
        }
    }
}