1#[cfg(any(feature = "zstd", feature = "snappy", feature = "lz4"))]
13use crous_core::error::CrousError;
14use crous_core::error::Result;
15use crous_core::wire::CompressionType;
16
/// Common interface for block compression codecs.
///
/// Implementors must be `Send + Sync`. Codecs are stored as
/// `Box<dyn Compressor>` inside `CompressorRegistry` and looked up by the
/// value returned from [`Compressor::compression_type`].
pub trait Compressor: Send + Sync {
    /// Returns the wire-level identifier of this codec.
    fn compression_type(&self) -> CompressionType;

    /// Compresses `input` and returns the encoded bytes.
    ///
    /// # Errors
    /// Returns an error if the underlying codec fails.
    fn compress(&self, input: &[u8]) -> Result<Vec<u8>>;

    /// Decompresses `input` and returns the decoded bytes.
    ///
    /// `max_output` is an upper bound on the decompressed size. How it is
    /// enforced is left to the implementation: among the codecs in this file,
    /// snappy and lz4 compare the size declared in the input against it, zstd
    /// passes it to the library as the output capacity, and the pass-through
    /// codec does not consult it.
    ///
    /// # Errors
    /// Returns an error if the input is malformed or the limit is exceeded.
    fn decompress(&self, input: &[u8], max_output: usize) -> Result<Vec<u8>>;

    /// Returns a short, static, human-readable name for this codec.
    fn name(&self) -> &'static str;
}
44
45pub struct NoCompression;
47
48impl Compressor for NoCompression {
49 fn compression_type(&self) -> CompressionType {
50 CompressionType::None
51 }
52
53 fn compress(&self, input: &[u8]) -> Result<Vec<u8>> {
54 Ok(input.to_vec())
55 }
56
57 fn decompress(&self, input: &[u8], _max_output: usize) -> Result<Vec<u8>> {
58 Ok(input.to_vec())
59 }
60
61 fn name(&self) -> &'static str {
62 "none"
63 }
64}
65
#[cfg(feature = "zstd")]
/// Zstandard codec with a configurable compression level.
pub struct ZstdCompressor {
    /// Compression level handed to `zstd::bulk::compress`.
    pub level: i32,
}

#[cfg(feature = "zstd")]
impl Default for ZstdCompressor {
    /// Builds a compressor that uses level 3.
    fn default() -> Self {
        ZstdCompressor { level: 3 }
    }
}

#[cfg(feature = "zstd")]
impl Compressor for ZstdCompressor {
    /// Always reports [`CompressionType::Zstd`].
    fn compression_type(&self) -> CompressionType {
        CompressionType::Zstd
    }

    /// Compresses `input` in one shot at `self.level`.
    ///
    /// NOTE(review): failures are reported through
    /// `CrousError::DecompressionError` even though this is the compression
    /// path; confirm whether a dedicated variant exists.
    fn compress(&self, input: &[u8]) -> Result<Vec<u8>> {
        match zstd::bulk::compress(input, self.level) {
            Ok(encoded) => Ok(encoded),
            Err(e) => Err(CrousError::DecompressionError(format!(
                "zstd compress: {e}"
            ))),
        }
    }

    /// Decompresses `input` in one shot, passing `max_output` to the library
    /// as the output capacity.
    fn decompress(&self, input: &[u8], max_output: usize) -> Result<Vec<u8>> {
        match zstd::bulk::decompress(input, max_output) {
            Ok(decoded) => Ok(decoded),
            Err(e) => Err(CrousError::DecompressionError(format!(
                "zstd decompress: {e}"
            ))),
        }
    }

    /// Returns the literal name `"zstd"`.
    fn name(&self) -> &'static str {
        "zstd"
    }
}
100
#[cfg(feature = "snappy")]
/// Snappy codec using the raw (unframed) format from the `snap` crate.
pub struct SnappyCompressor;

#[cfg(feature = "snappy")]
impl Compressor for SnappyCompressor {
    /// Always reports [`CompressionType::Snappy`].
    fn compression_type(&self) -> CompressionType {
        CompressionType::Snappy
    }

    /// Compresses `input` with a fresh raw encoder.
    ///
    /// NOTE(review): failures are reported through
    /// `CrousError::DecompressionError` even though this is the compression
    /// path; confirm whether a dedicated variant exists.
    fn compress(&self, input: &[u8]) -> Result<Vec<u8>> {
        let outcome = snap::raw::Encoder::new().compress_vec(input);
        outcome.map_err(|e| CrousError::DecompressionError(format!("snappy compress: {e}")))
    }

    /// Decompresses `input`, rejecting it up front when the length declared
    /// in its header exceeds `max_output`.
    fn decompress(&self, input: &[u8], max_output: usize) -> Result<Vec<u8>> {
        // Read the declared output length first so an oversized payload is
        // refused before any decoding work is done.
        let declared_len = match snap::raw::decompress_len(input) {
            Ok(len) => len,
            Err(e) => {
                return Err(CrousError::DecompressionError(format!(
                    "snappy len: {e}"
                )))
            }
        };
        if declared_len > max_output {
            return Err(CrousError::MemoryLimitExceeded(declared_len, max_output));
        }

        let outcome = snap::raw::Decoder::new().decompress_vec(input);
        outcome.map_err(|e| CrousError::DecompressionError(format!("snappy decompress: {e}")))
    }

    /// Returns the literal name `"snappy"`.
    fn name(&self) -> &'static str {
        "snappy"
    }
}
137
#[cfg(feature = "lz4")]
/// LZ4 codec (via `lz4_flex`) whose output is prefixed with the uncompressed size.
pub struct Lz4Compressor;

#[cfg(feature = "lz4")]
impl Compressor for Lz4Compressor {
    /// Always reports [`CompressionType::Lz4`].
    fn compression_type(&self) -> CompressionType {
        CompressionType::Lz4
    }

    /// Compresses `input`, prepending its uncompressed size; this never fails.
    fn compress(&self, input: &[u8]) -> Result<Vec<u8>> {
        let framed = lz4_flex::compress_prepend_size(input);
        Ok(framed)
    }

    /// Decompresses size-prefixed `input`.
    ///
    /// The first four bytes are read as a little-endian `u32` holding the
    /// declared output size. Input shorter than that is rejected, as is any
    /// declared size greater than `max_output`.
    fn decompress(&self, input: &[u8], max_output: usize) -> Result<Vec<u8>> {
        // Decode the size prefix by hand so the limit can be checked before
        // the library is asked to decode anything.
        let declared_size = match input {
            [b0, b1, b2, b3, ..] => u32::from_le_bytes([*b0, *b1, *b2, *b3]) as usize,
            _ => {
                return Err(CrousError::DecompressionError(
                    "lz4: input too short for size prefix".into(),
                ))
            }
        };
        if declared_size > max_output {
            return Err(CrousError::MemoryLimitExceeded(declared_size, max_output));
        }

        match lz4_flex::decompress_size_prepended(input) {
            Ok(decoded) => Ok(decoded),
            Err(e) => Err(CrousError::DecompressionError(format!(
                "lz4 decompress: {e}"
            ))),
        }
    }

    /// Returns the literal name `"lz4"`.
    fn name(&self) -> &'static str {
        "lz4"
    }
}
173
174pub struct AdaptiveSelector {
180 pub ratio_threshold: f64,
183 pub sample_size: usize,
185}
186
187impl Default for AdaptiveSelector {
188 fn default() -> Self {
189 Self {
190 ratio_threshold: 0.90,
191 sample_size: 64 * 1024, }
193 }
194}
195
196impl AdaptiveSelector {
197 pub fn select(&self, data: &[u8], registry: &CompressorRegistry) -> CompressionType {
200 let sample = if data.len() > self.sample_size {
201 &data[..self.sample_size]
202 } else {
203 data
204 };
205
206 let mut best_type = CompressionType::None;
207 let mut best_ratio = 1.0f64;
208
209 for comp in ®istry.compressors {
210 if comp.compression_type() == CompressionType::None {
211 continue;
212 }
213 if let Ok(compressed) = comp.compress(sample) {
214 let ratio = compressed.len() as f64 / sample.len() as f64;
215 if ratio < best_ratio && ratio < self.ratio_threshold {
216 best_ratio = ratio;
217 best_type = comp.compression_type();
218 }
219 }
220 }
221 best_type
222 }
223}
224
225pub struct CompressorRegistry {
227 compressors: Vec<Box<dyn Compressor>>,
228}
229
230impl CompressorRegistry {
231 pub fn new() -> Self {
233 Self {
234 compressors: vec![Box::new(NoCompression)],
235 }
236 }
237
238 pub fn with_defaults() -> Self {
240 #[allow(unused_mut)]
241 let mut reg = Self::new();
242 #[cfg(feature = "zstd")]
243 reg.register(Box::new(ZstdCompressor::default()));
244 #[cfg(feature = "lz4")]
245 reg.register(Box::new(Lz4Compressor));
246 #[cfg(feature = "snappy")]
247 reg.register(Box::new(SnappyCompressor));
248 reg
249 }
250
251 pub fn register(&mut self, compressor: Box<dyn Compressor>) {
253 self.compressors.push(compressor);
254 }
255
256 pub fn find(&self, comp_type: CompressionType) -> Option<&dyn Compressor> {
258 self.compressors
259 .iter()
260 .find(|c| c.compression_type() == comp_type)
261 .map(|c| c.as_ref())
262 }
263}
264
265impl Default for CompressorRegistry {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// The pass-through codec must hand back exactly the bytes it was given.
    #[test]
    fn no_compression_roundtrip() {
        let original = b"hello world, this is a test";
        let codec = NoCompression;

        let packed = codec.compress(original).unwrap();
        let unpacked = codec.decompress(&packed, 1024).unwrap();

        assert_eq!(unpacked.as_slice(), &original[..]);
    }

    /// A freshly created registry always knows the pass-through codec.
    #[test]
    fn registry_find() {
        let registry = CompressorRegistry::new();
        let found = registry.find(CompressionType::None);
        assert!(found.is_some());
    }
}