use flate2::{Compress, Compression, FlushCompress};
use crate::error::Error;
use crate::websocket_extension::PerMessageDeflateConfig;
/// The four bytes (`00 00 FF FF`) a sync flush leaves at the end of each
/// compressed message. They are stripped before sending and appended back
/// before inflating a received payload.
const DEFLATE_TRAILER: [u8; 4] = [0, 0, 0xff, 0xff];
/// Outgoing half of the permessage-deflate codec.
///
/// Wraps a raw-deflate [`Compress`] stream (no zlib header). Every message is
/// terminated with a sync flush, and the resulting `00 00 FF FF` trailer is
/// stripped from the returned bytes.
pub struct Compressor {
    /// Underlying raw-deflate stream. It persists across messages unless
    /// `reset_after_message` is set.
    compress: Compress,
    /// Compression level (0–9) used whenever the stream is (re)created.
    level: u32,
    /// When true, a fresh stream is created after every message. The value
    /// comes from the negotiated `*_no_context_takeover` flag for the local side.
    reset_after_message: bool,
}

impl Compressor {
    /// Creates a compressor for the local endpoint at level 6.
    ///
    /// `is_client` selects which negotiated flag controls per-message resets.
    /// A client honours `client_no_context_takeover`; a server honours
    /// `server_no_context_takeover`.
    pub fn new(config: &PerMessageDeflateConfig, is_client: bool) -> Self {
        let level = 6u32;
        let reset_after_message = if is_client {
            config.client_no_context_takeover
        } else {
            config.server_no_context_takeover
        };
        Self {
            compress: Compress::new(Compression::new(level), false),
            level,
            reset_after_message,
        }
    }

    /// Sets the compression level, clamped to at most 9.
    ///
    /// The live deflate stream is rebuilt at the new level, so the change
    /// applies to the next message even when context takeover is in use.
    /// Previously it only applied after a per-message reset.
    ///
    /// Rebuilding discards the LZ77 history but is still safe to do between
    /// messages. Each message ends with a sync flush, so the new stream's
    /// blocks begin on a block boundary and simply never reference earlier
    /// data.
    pub fn set_level(&mut self, level: u32) {
        let level = level.min(9);
        if level != self.level {
            self.level = level;
            self.compress = Compress::new(Compression::new(level), false);
        }
    }

    /// Compresses one complete message payload.
    ///
    /// Returns the deflate bytes with the trailing `00 00 FF FF` sync-flush
    /// marker removed (when present).
    ///
    /// # Errors
    /// Returns `Error::invalid_data` if the deflate stream reports an error
    /// while consuming input or while flushing.
    pub fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
        let mut compressed = Vec::new();
        let mut output_buf = [0u8; 8192];

        // Feed all input without flushing; progress is measured through the
        // stream's running total_in/total_out counters.
        let mut input_pos = 0;
        while input_pos < data.len() {
            let before_in = self.compress.total_in();
            let before_out = self.compress.total_out();
            self.compress
                .compress(&data[input_pos..], &mut output_buf, FlushCompress::None)
                .map_err(|e| Error::invalid_data(format!("compression failed: {}", e)))?;
            input_pos += (self.compress.total_in() - before_in) as usize;
            let produced = (self.compress.total_out() - before_out) as usize;
            compressed.extend_from_slice(&output_buf[..produced]);
        }

        // Sync-flush until a call leaves spare room in the buffer. A
        // completely full buffer means more flushed output may still be
        // pending. There is deliberately no fixed round cap: stopping early
        // would silently drop compressed bytes and corrupt the message.
        loop {
            let before_out = self.compress.total_out();
            self.compress
                .compress(&[], &mut output_buf, FlushCompress::Sync)
                .map_err(|e| Error::invalid_data(format!("compression flush failed: {}", e)))?;
            let produced = (self.compress.total_out() - before_out) as usize;
            compressed.extend_from_slice(&output_buf[..produced]);
            if produced < output_buf.len() {
                break;
            }
        }

        if compressed.ends_with(&DEFLATE_TRAILER) {
            compressed.truncate(compressed.len() - DEFLATE_TRAILER.len());
        }
        if self.reset_after_message {
            self.compress = Compress::new(Compression::new(self.level), false);
        }
        Ok(compressed)
    }
}
pub struct Decompressor {
decompress: flate2::Decompress,
reset_after_message: bool,
}
const DECOMPRESS_CHUNK_SIZE: usize = 8192;
impl Decompressor {
pub fn new(config: &PerMessageDeflateConfig, is_client: bool) -> Self {
let reset_after_message = if is_client {
config.server_no_context_takeover
} else {
config.client_no_context_takeover
};
Self {
decompress: flate2::Decompress::new(false), reset_after_message,
}
}
pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Vec<u8>, Error> {
let mut input = data.to_vec();
input.extend_from_slice(&DEFLATE_TRAILER);
let mut decompressed = Vec::new();
let mut output_buf = [0u8; DECOMPRESS_CHUNK_SIZE];
let mut input_pos = 0;
let max_iterations = (max_size / DECOMPRESS_CHUNK_SIZE) + 100;
let mut iterations = 0;
loop {
iterations += 1;
if iterations > max_iterations {
if self.reset_after_message {
self.decompress.reset(false);
}
return Err(Error::invalid_data(
"decompression exceeded iteration limit",
));
}
let before_in = self.decompress.total_in();
let before_out = self.decompress.total_out();
let status = self
.decompress
.decompress(
&input[input_pos..],
&mut output_buf,
flate2::FlushDecompress::Sync,
)
.map_err(|e| Error::invalid_data(format!("decompression failed: {}", e)))?;
let consumed = (self.decompress.total_in() - before_in) as usize;
let produced = (self.decompress.total_out() - before_out) as usize;
input_pos += consumed;
if produced > 0 {
if decompressed.len() + produced > max_size {
if self.reset_after_message {
self.decompress.reset(false);
}
return Err(Error::invalid_data(format!(
"decompressed size exceeds maximum limit of {} bytes",
max_size
)));
}
decompressed.extend_from_slice(&output_buf[..produced]);
}
match status {
flate2::Status::StreamEnd => break,
flate2::Status::Ok | flate2::Status::BufError => {
if input_pos >= input.len() && produced == 0 {
break;
}
}
}
}
if self.reset_after_message {
self.decompress.reset(false);
}
Ok(decompressed)
}
pub fn reset(&mut self) {
self.decompress.reset(false);
}
}
/// Bidirectional permessage-deflate codec.
///
/// Pairs a [`Compressor`] for outgoing messages with a [`Decompressor`] for
/// incoming ones, and keeps the negotiated configuration they were built from.
pub struct PerMessageDeflate {
    compressor: Compressor,
    decompressor: Decompressor,
    config: PerMessageDeflateConfig,
}

impl PerMessageDeflate {
    /// Builds the codec for the client end of a connection.
    pub fn new_client(config: PerMessageDeflateConfig) -> Self {
        Self::for_role(config, true)
    }

    /// Builds the codec for the server end of a connection.
    pub fn new_server(config: PerMessageDeflateConfig) -> Self {
        Self::for_role(config, false)
    }

    /// Shared constructor. `is_client` tells each half which negotiated
    /// parameters apply to it.
    fn for_role(config: PerMessageDeflateConfig, is_client: bool) -> Self {
        let compressor = Compressor::new(&config, is_client);
        let decompressor = Decompressor::new(&config, is_client);
        Self {
            compressor,
            decompressor,
            config,
        }
    }

    /// The negotiated configuration this codec was built with.
    pub fn config(&self) -> &PerMessageDeflateConfig {
        &self.config
    }

    /// Compresses an outgoing message payload; see [`Compressor::compress`].
    pub fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
        let out = self.compressor.compress(data);
        out
    }

    /// Decompresses an incoming message payload, capped at `max_size` bytes.
    /// See [`Decompressor::decompress`].
    pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Vec<u8>, Error> {
        let out = self.decompressor.decompress(data, max_size);
        out
    }

    /// Forwards a new compression level to the outgoing compressor.
    pub fn set_compression_level(&mut self, level: u32) {
        self.compressor.set_level(level)
    }

    /// Returns whether a payload is large enough to be worth compressing,
    /// i.e. it is at least `threshold` bytes long.
    pub fn should_compress(&self, data: &[u8], threshold: usize) -> bool {
        threshold <= data.len()
    }
}