1use std::io::Read;
7
8use kimberlite_types::{BoundedSize, CompressionKind};
9
10use crate::StorageError;
11
12const MAX_DECOMPRESSED_SIZE: usize = 1024 * 1024 * 1024; pub trait Codec: Send + Sync {
22 fn kind(&self) -> CompressionKind;
24
25 fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError>;
27
28 fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError>;
30}
31
/// Pass-through codec: a stateless marker type for "no compression".
#[derive(Clone, Copy, Debug)]
pub struct NoneCodec;
35
36impl Codec for NoneCodec {
37 fn kind(&self) -> CompressionKind {
38 CompressionKind::None
39 }
40
41 fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
42 Ok(input.to_vec())
43 }
44
45 fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
46 Ok(input.to_vec())
47 }
48}
49
/// LZ4 codec: a stateless marker type whose `Codec` implementation delegates
/// to `lz4_flex`.
#[derive(Clone, Copy, Debug)]
pub struct Lz4Codec;
53
54impl Codec for Lz4Codec {
55 fn kind(&self) -> CompressionKind {
56 CompressionKind::Lz4
57 }
58
59 fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
60 Ok(lz4_flex::compress_prepend_size(input))
61 }
62
63 fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
64 if input.len() < 4 {
77 return Err(StorageError::DecompressionFailed {
78 codec: "lz4",
79 reason: format!(
80 "input too short: need 4-byte size prefix, got {} bytes",
81 input.len()
82 ),
83 });
84 }
85 let claimed_size_raw = u32::from_le_bytes(input[0..4].try_into().expect("4 bytes"));
90 let _claimed_size: BoundedSize<MAX_DECOMPRESSED_SIZE> =
91 BoundedSize::try_from(claimed_size_raw).map_err(|e| {
92 StorageError::DecompressionFailed {
93 codec: "lz4",
94 reason: format!(
95 "claimed size {} exceeds MAX_DECOMPRESSED_SIZE ({})",
96 e.value, e.max
97 ),
98 }
99 })?;
100 lz4_flex::decompress_size_prepended(input).map_err(|e| StorageError::DecompressionFailed {
101 codec: "lz4",
102 reason: e.to_string(),
103 })
104 }
105}
106
/// Zstandard codec configured with a compression level.
#[derive(Clone, Copy, Debug)]
pub struct ZstdCodec {
    /// Compression level used when encoding.
    pub level: i32,
}

impl ZstdCodec {
    /// Builds a codec that compresses at `level`. The value is stored as-is.
    pub fn new(level: i32) -> Self {
        ZstdCodec { level }
    }
}

impl Default for ZstdCodec {
    /// Builds a codec at level 3.
    fn default() -> Self {
        Self::new(3)
    }
}
126
127impl Codec for ZstdCodec {
128 fn kind(&self) -> CompressionKind {
129 CompressionKind::Zstd
130 }
131
132 fn compress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
133 zstd::encode_all(input, self.level).map_err(|e| StorageError::CompressionFailed {
134 codec: "zstd",
135 reason: e.to_string(),
136 })
137 }
138
139 fn decompress(&self, input: &[u8]) -> Result<Vec<u8>, StorageError> {
140 let decoder = zstd::Decoder::new(input).map_err(|e| StorageError::DecompressionFailed {
142 codec: "zstd",
143 reason: format!("failed to create decoder: {e}"),
144 })?;
145
146 let mut output = Vec::new();
147 let mut limited_reader = decoder.take(MAX_DECOMPRESSED_SIZE as u64);
148
149 let bytes_read = std::io::copy(&mut limited_reader, &mut output).map_err(|e| {
150 StorageError::DecompressionFailed {
151 codec: "zstd",
152 reason: format!("decompression failed: {e}"),
153 }
154 })?;
155
156 if bytes_read == MAX_DECOMPRESSED_SIZE as u64 {
158 let mut probe = [0u8; 1];
159 let mut decoder_inner = limited_reader.into_inner();
160 if decoder_inner
161 .read(&mut probe)
162 .map_err(|e| StorageError::DecompressionFailed {
163 codec: "zstd",
164 reason: format!("probe read failed: {e}"),
165 })?
166 > 0
167 {
168 return Err(StorageError::DecompressionFailed {
169 codec: "zstd",
170 reason: format!(
171 "decompressed size exceeds MAX_DECOMPRESSED_SIZE ({MAX_DECOMPRESSED_SIZE} bytes)"
172 ),
173 });
174 }
175 }
176
177 Ok(output)
178 }
179}
180
181#[derive(Debug)]
183pub struct CodecRegistry {
184 lz4: Lz4Codec,
185 zstd: ZstdCodec,
186 none: NoneCodec,
187}
188
189impl CodecRegistry {
190 pub fn new() -> Self {
192 Self {
193 lz4: Lz4Codec,
194 zstd: ZstdCodec::default(),
195 none: NoneCodec,
196 }
197 }
198
199 pub fn with_zstd_level(level: i32) -> Self {
201 Self {
202 lz4: Lz4Codec,
203 zstd: ZstdCodec::new(level),
204 none: NoneCodec,
205 }
206 }
207
208 pub fn get(&self, kind: CompressionKind) -> &dyn Codec {
210 match kind {
211 CompressionKind::None => &self.none,
212 CompressionKind::Lz4 => &self.lz4,
213 CompressionKind::Zstd => &self.zstd,
214 }
215 }
216
217 pub fn compress(&self, kind: CompressionKind, data: &[u8]) -> Result<Vec<u8>, StorageError> {
219 self.get(kind).compress(data)
220 }
221
222 pub fn decompress(&self, kind: CompressionKind, data: &[u8]) -> Result<Vec<u8>, StorageError> {
224 self.get(kind).decompress(data)
225 }
226}
227
228impl Default for CodecRegistry {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Every compression kind the registry dispatches on.
    const ALL_KINDS: [CompressionKind; 3] = [
        CompressionKind::None,
        CompressionKind::Lz4,
        CompressionKind::Zstd,
    ];

    /// Compresses then decompresses `data` with `codec` and checks the bytes survive.
    fn assert_roundtrip(codec: &dyn Codec, data: &[u8]) {
        let packed = codec.compress(data).unwrap();
        let restored = codec.decompress(&packed).unwrap();
        assert_eq!(data, restored.as_slice());
    }

    /// Checks that `codec` shrinks a buffer of 10 000 identical bytes.
    fn assert_shrinks_repetitive(codec: &dyn Codec) {
        let data = vec![42u8; 10_000];
        let packed = codec.compress(&data).unwrap();
        assert!(packed.len() < data.len());
    }

    #[test]
    fn none_codec_roundtrip() {
        assert_roundtrip(&NoneCodec, b"hello world");
    }

    #[test]
    fn lz4_codec_roundtrip() {
        assert_roundtrip(&Lz4Codec, b"hello world hello world hello world");
    }

    #[test]
    fn lz4_rejects_oversized_size_prefix() {
        // Size prefix of 0xFFFF_FFFF followed by a single payload byte.
        let bomb = [0xFF, 0xFF, 0xFF, 0xFF, 0x00];
        let outcome = Lz4Codec.decompress(&bomb);
        let err = outcome.expect_err("oversized size prefix must be rejected");
        match err {
            StorageError::DecompressionFailed {
                codec: name,
                reason,
            } => {
                assert_eq!(name, "lz4");
                assert!(
                    reason.contains("exceeds MAX_DECOMPRESSED_SIZE"),
                    "expected size-prefix guard error, got: {reason}"
                );
            }
            other => panic!("expected DecompressionFailed, got {other:?}"),
        }
    }

    #[test]
    fn lz4_rejects_short_input() {
        // Both inputs are shorter than the 4-byte size prefix.
        let short_inputs: [&[u8]; 2] = [&[], &[0x00, 0x00, 0x00]];
        for short in short_inputs {
            assert!(Lz4Codec.decompress(short).is_err());
        }
    }

    #[test]
    fn zstd_codec_roundtrip() {
        assert_roundtrip(
            &ZstdCodec::default(),
            b"hello world hello world hello world",
        );
    }

    #[test]
    fn lz4_compresses_repetitive_data() {
        assert_shrinks_repetitive(&Lz4Codec);
    }

    #[test]
    fn zstd_compresses_repetitive_data() {
        assert_shrinks_repetitive(&ZstdCodec::default());
    }

    #[test]
    fn codec_registry_lookup() {
        let registry = CodecRegistry::new();
        for kind in ALL_KINDS {
            assert_eq!(registry.get(kind).kind(), kind);
        }
    }

    #[test]
    fn codec_registry_roundtrip() {
        let registry = CodecRegistry::new();
        let data: &[u8] = b"test data for codec registry roundtrip";

        for kind in ALL_KINDS {
            let packed = registry.compress(kind, data).unwrap();
            let restored = registry.decompress(kind, &packed).unwrap();
            assert_eq!(data, restored.as_slice(), "roundtrip failed for {kind}");
        }
    }

    #[test]
    fn empty_data_roundtrip() {
        let registry = CodecRegistry::new();
        let data: &[u8] = b"";

        for kind in ALL_KINDS {
            let packed = registry.compress(kind, data).unwrap();
            let restored = registry.decompress(kind, &packed).unwrap();
            assert_eq!(
                data,
                restored.as_slice(),
                "empty roundtrip failed for {kind}"
            );
        }
    }

    #[test]
    fn zstd_rejects_decompression_bomb() {
        // One MiB past the limit, all zeros so it compresses to almost nothing.
        let bomb_size = MAX_DECOMPRESSED_SIZE + 1024 * 1024;
        let payload = vec![0u8; bomb_size];

        let codec = ZstdCodec::default();
        let packed = codec.compress(&payload).unwrap();

        // Sanity check that the input really is a "bomb": tiny on the wire.
        assert!(
            packed.len() < bomb_size / 100,
            "compressed size {} should be <1% of original {}",
            packed.len(),
            bomb_size
        );

        let err = match codec.decompress(&packed) {
            Err(e) => e,
            Ok(_) => panic!("decompression bomb should be rejected"),
        };
        match err {
            StorageError::DecompressionFailed {
                codec: name,
                reason,
            } => {
                assert_eq!(name, "zstd");
                assert!(
                    reason.contains("exceeds MAX_DECOMPRESSED_SIZE"),
                    "error should mention size limit: {reason}"
                );
            }
            other => panic!("wrong error type: {other:?}"),
        }
    }

    #[test]
    fn zstd_allows_large_but_under_limit_data() {
        let size = MAX_DECOMPRESSED_SIZE / 2;
        let payload = vec![42u8; size];

        let codec = ZstdCodec::default();
        let packed = codec.compress(&payload).unwrap();
        let restored = codec.decompress(&packed).unwrap();

        assert_eq!(restored.len(), size);
        assert_eq!(restored, payload);
    }

    #[cfg(test)]
    mod proptest_codec {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Few cases: the oversized test below allocates more than
            // MAX_DECOMPRESSED_SIZE bytes per case.
            #![proptest_config(ProptestConfig::with_cases(8))]

            #[test]
            fn zstd_roundtrip_under_limit(data in prop::collection::vec(any::<u8>(), 0..1024*1024)) {
                let codec = ZstdCodec::default();
                let packed = codec.compress(&data).unwrap();
                let restored = codec.decompress(&packed).unwrap();
                assert_eq!(data, restored);
            }

            #[test]
            fn zstd_rejects_oversized_payloads(
                byte in any::<u8>(),
                multiplier in 1u32..10
            ) {
                // Between 10 MiB and 90 MiB over the limit, filled with one byte value.
                let overshoot = multiplier as usize * 10 * 1024 * 1024;
                let payload = vec![byte; MAX_DECOMPRESSED_SIZE + overshoot];

                let codec = ZstdCodec::default();
                let packed = codec.compress(&payload).unwrap();
                let outcome = codec.decompress(&packed);

                assert!(outcome.is_err(), "oversized payload should be rejected");
            }
        }
    }
}