use noflate::deflate::{Decoder, Encoder};
use crate::error::Error;
use crate::websocket_extension::PerMessageDeflateConfig;
/// Tail left by a DEFLATE sync flush (an empty stored block's LEN/NLEN).
/// RFC 7692 strips it from outgoing payloads and re-appends it before inflating.
const DEFLATE_TRAILER: [u8; 4] = [0, 0, 0xff, 0xff];
/// Number of compressed bytes handed to the decoder per step, so the output
/// size limit is checked incrementally rather than only once at the end.
const DECOMPRESS_FEED_CHUNK: usize = 8 * 1024;
pub struct Compressor {
encoder: Encoder,
reset_after_message: bool,
}
impl Compressor {
pub fn new(config: &PerMessageDeflateConfig, is_client: bool) -> Self {
let reset_after_message = if is_client {
config.client_no_context_takeover
} else {
config.server_no_context_takeover
};
Self {
encoder: Encoder::new(),
reset_after_message,
}
}
pub fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
self.encoder
.feed(data)
.map_err(|e| Error::invalid_data(format!("compression failed: {}", e)))?;
self.encoder
.sync_flush()
.map_err(|e| Error::invalid_data(format!("compression flush failed: {}", e)))?;
let mut out = self.encoder.output().to_vec();
self.encoder.advance(out.len());
if out.ends_with(&DEFLATE_TRAILER) {
out.truncate(out.len() - 4);
}
if self.reset_after_message {
self.encoder.reset_history();
}
Ok(out)
}
}
pub struct Decompressor {
decoder: Decoder,
reset_after_message: bool,
}
impl Decompressor {
pub fn new(config: &PerMessageDeflateConfig, is_client: bool) -> Self {
let reset_after_message = if is_client {
config.server_no_context_takeover
} else {
config.client_no_context_takeover
};
Self {
decoder: Decoder::new(),
reset_after_message,
}
}
pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Vec<u8>, Error> {
let mut decompressed = Vec::new();
let feed_chain = data
.chunks(DECOMPRESS_FEED_CHUNK)
.chain(core::iter::once(DEFLATE_TRAILER.as_slice()));
for chunk in feed_chain {
self.decoder
.feed(chunk)
.map_err(|e| Error::invalid_data(format!("decompression failed: {}", e)))?;
let produced = self.decoder.output();
if decompressed.len().saturating_add(produced.len()) > max_size {
self.reset_if_needed();
return Err(Error::invalid_data(format!(
"decompressed size exceeds maximum limit of {} bytes",
max_size
)));
}
let produced_len = produced.len();
decompressed.extend_from_slice(produced);
self.decoder.advance(produced_len);
}
if self.reset_after_message {
self.decoder = Decoder::new();
}
Ok(decompressed)
}
fn reset_if_needed(&mut self) {
if self.reset_after_message {
self.decoder = Decoder::new();
}
}
}
/// Negotiated permessage-deflate state for one connection endpoint: an
/// outgoing [`Compressor`] paired with an incoming [`Decompressor`].
pub struct PerMessageDeflate {
    compressor: Compressor,
    decompressor: Decompressor,
    /// The configuration both halves were built from.
    config: PerMessageDeflateConfig,
}

impl PerMessageDeflate {
    /// Creates the extension state for the client end of a connection.
    pub fn new_client(config: PerMessageDeflateConfig) -> Self {
        Self::for_role(config, true)
    }

    /// Creates the extension state for the server end of a connection.
    pub fn new_server(config: PerMessageDeflateConfig) -> Self {
        Self::for_role(config, false)
    }

    /// Shared constructor that forwards `is_client` to both halves.
    fn for_role(config: PerMessageDeflateConfig, is_client: bool) -> Self {
        let compressor = Compressor::new(&config, is_client);
        let decompressor = Decompressor::new(&config, is_client);
        Self {
            compressor,
            decompressor,
            config,
        }
    }

    /// Returns the configuration this state was created with.
    pub fn config(&self) -> &PerMessageDeflateConfig {
        &self.config
    }

    /// Compresses an outgoing message payload; delegates to the compressor.
    pub fn compress(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
        self.compressor.compress(data)
    }

    /// Decompresses an incoming message payload, capping the output at
    /// `max_size` bytes; delegates to the decompressor.
    pub fn decompress(&mut self, data: &[u8], max_size: usize) -> Result<Vec<u8>, Error> {
        self.decompressor.decompress(data, max_size)
    }

    /// Returns `true` when `data` is at least `threshold` bytes long.
    pub fn should_compress(&self, data: &[u8], threshold: usize) -> bool {
        threshold <= data.len()
    }
}