use std::io::Read;
use std::io::Write;
use bytes::Bytes;
use flate2::read::DeflateDecoder;
use flate2::write::DeflateEncoder;
use flate2::Compression;
use super::PayloadCodec;
use crate::error::{MqttError, Result};
/// Default cap on the number of bytes `decode` may produce (10 MiB, i.e. `10 << 20`).
///
/// Bounds memory use when inflating untrusted payloads; override it per codec with
/// `DeflateCodec::with_max_decompressed_size`.
const DEFAULT_MAX_DECOMPRESSED_SIZE: usize = 10 << 20;
/// [`PayloadCodec`] that compresses payloads with DEFLATE using `flate2`.
///
/// Construct one with [`DeflateCodec::new`] (or `Default`) and tune it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct DeflateCodec {
    /// Compression level handed to the `flate2` encoder.
    level: Compression,
    /// Payload-size threshold reported through `min_size_threshold`.
    min_size: usize,
    /// Maximum number of bytes `decode` will produce before returning an error.
    max_decompressed_size: usize,
}
impl Default for DeflateCodec {
fn default() -> Self {
Self::new()
}
}
impl DeflateCodec {
    /// Creates a codec with these settings:
    /// - flate2's default compression level;
    /// - a 128-byte minimum-size threshold;
    /// - a 10 MiB (`DEFAULT_MAX_DECOMPRESSED_SIZE`) cap on decoded output.
    #[must_use]
    pub fn new() -> Self {
        let level = Compression::default();
        let min_size = 128;
        let max_decompressed_size = DEFAULT_MAX_DECOMPRESSED_SIZE;
        Self {
            level,
            min_size,
            max_decompressed_size,
        }
    }

    /// Sets the compression level, passed straight to `flate2::Compression::new`.
    ///
    /// NOTE(review): the value is not validated here; flate2 documents 0-9 as the
    /// typical scale, and out-of-range values depend on the selected flate2 backend.
    #[must_use]
    pub fn with_level(self, level: u32) -> Self {
        Self {
            level: Compression::new(level),
            ..self
        }
    }

    /// Sets the size threshold reported by `min_size_threshold`.
    ///
    /// The trait's `should_encode` uses this threshold to skip compressing small payloads.
    #[must_use]
    pub fn with_min_size(self, size: usize) -> Self {
        Self {
            min_size: size,
            ..self
        }
    }

    /// Sets the maximum number of bytes `decode` may produce; larger outputs are rejected.
    #[must_use]
    pub fn with_max_decompressed_size(self, size: usize) -> Self {
        Self {
            max_decompressed_size: size,
            ..self
        }
    }
}
impl PayloadCodec for DeflateCodec {
    fn name(&self) -> &'static str {
        "deflate"
    }

    fn content_type(&self) -> &'static str {
        "application/x-deflate"
    }

    /// Compresses `payload` with flate2's `DeflateEncoder` at the configured level.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::ProtocolError`] if writing to or finishing the encoder fails.
    fn encode(&self, payload: &[u8]) -> Result<Bytes> {
        let mut encoder = DeflateEncoder::new(Vec::new(), self.level);
        encoder
            .write_all(payload)
            .map_err(|e| MqttError::ProtocolError(format!("deflate encode failed: {e}")))?;
        let compressed = encoder
            .finish()
            .map_err(|e| MqttError::ProtocolError(format!("deflate finish failed: {e}")))?;
        Ok(Bytes::from(compressed))
    }

    /// Decompresses `payload`, producing at most `max_decompressed_size` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::ProtocolError`] in two cases:
    /// - the stream cannot be decoded;
    /// - the decoded output would exceed the configured limit.
    fn decode(&self, payload: &[u8]) -> Result<Bytes> {
        let limit = self.max_decompressed_size;
        // Read at most one byte past the limit. Seeing `limit + 1` bytes proves the output
        // is oversized without inflating the rest of the stream (decompression-bomb guard).
        // Saturate so a limit of `usize::MAX` cannot overflow. Plain `limit + 1` panics in
        // debug builds. In release builds it wraps to `take(0)`, which silently returns
        // empty output.
        let read_cap = u64::try_from(limit).map_or(u64::MAX, |l| l.saturating_add(1));
        let mut decoder = DeflateDecoder::new(payload).take(read_cap);
        let mut decompressed = Vec::new();
        // NOTE(review): flate2's read-side decoder may report Ok for a truncated stream;
        // confirm whether truncated payloads should be treated as errors.
        decoder
            .read_to_end(&mut decompressed)
            .map_err(|e| MqttError::ProtocolError(format!("deflate decode failed: {e}")))?;
        if decompressed.len() > limit {
            return Err(MqttError::ProtocolError(format!(
                "deflate decompressed size {} exceeds limit {limit}",
                decompressed.len()
            )));
        }
        Ok(Bytes::from(decompressed))
    }

    /// Returns the configured minimum payload size for compression.
    fn min_size_threshold(&self) -> usize {
        self.min_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deflate_roundtrip() {
        let codec = DeflateCodec::new();
        let original = b"Hello, World! This is a test payload for compression.";
        let encoded = codec.encode(original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        assert_eq!(&decoded[..], original);
    }

    #[test]
    fn test_deflate_compression_ratio() {
        let codec = DeflateCodec::new();
        let original: Vec<u8> = std::iter::repeat_n(b'A', 1000).collect();
        let encoded = codec.encode(&original).unwrap();
        assert!(encoded.len() < original.len());
    }

    #[test]
    fn test_deflate_with_level() {
        let codec = DeflateCodec::new().with_level(9);
        let original = b"Hello, World! This is a test payload for compression.";
        let encoded = codec.encode(original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        assert_eq!(&decoded[..], original);
    }

    #[test]
    fn test_deflate_content_type() {
        let codec = DeflateCodec::new();
        assert_eq!(codec.content_type(), "application/x-deflate");
    }

    #[test]
    fn test_deflate_name() {
        let codec = DeflateCodec::new();
        assert_eq!(codec.name(), "deflate");
    }

    #[test]
    fn test_deflate_min_size_threshold() {
        let codec = DeflateCodec::new().with_min_size(256);
        assert_eq!(codec.min_size_threshold(), 256);
    }

    #[test]
    fn test_deflate_should_encode() {
        let codec = DeflateCodec::new().with_min_size(100);
        let small_payload = vec![0u8; 50];
        let large_payload = vec![0u8; 150];
        assert!(!codec.should_encode(&small_payload));
        assert!(codec.should_encode(&large_payload));
    }

    #[test]
    fn test_deflate_invalid_data() {
        let codec = DeflateCodec::new();
        let invalid = b"\xff\xfe\xfd\xfc";
        let result = codec.decode(invalid);
        assert!(result.is_err());
    }

    #[test]
    fn test_deflate_empty_roundtrip() {
        let codec = DeflateCodec::new();
        let encoded = codec.encode(b"").unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    // The size limit guards against decompression bombs: output one byte over the
    // limit must be rejected.
    #[test]
    fn test_deflate_decode_rejects_output_over_limit() {
        let original = vec![b'A'; 1000];
        let encoded = DeflateCodec::new().encode(&original).unwrap();
        let codec = DeflateCodec::new().with_max_decompressed_size(999);
        assert!(codec.decode(&encoded).is_err());
    }

    // Output exactly at the limit is still accepted (the check is `>`, not `>=`).
    #[test]
    fn test_deflate_decode_accepts_output_at_limit() {
        let original = vec![b'A'; 1000];
        let codec = DeflateCodec::new().with_max_decompressed_size(1000);
        let encoded = codec.encode(&original).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        assert_eq!(&decoded[..], &original[..]);
    }
}