1use crate::block::{compress_block, decompress_block_payload};
4use crate::config::{EngineConfiguration, ProgressEvent, ProgressPhase};
5use crate::format::{BlockHeader, BlockIndexEntry, FileFlags, FileFooter, FileHeader, IndexHeader};
6use crate::index::load_index;
7use crush_core::error::{CrushError, Result};
8use rayon::prelude::*;
9use std::io::{Cursor, Read, Seek, SeekFrom, Write};
10use std::path::Path;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14fn with_pool<T: Send>(config: &EngineConfiguration, f: impl FnOnce() -> T + Send) -> T {
21 match &config.thread_pool {
22 Some(pool) => pool.install(f),
23 None => f(),
24 }
25}
26
27pub fn compress(input: &[u8], config: &EngineConfiguration) -> Result<Vec<u8>> {
38 let cancelled = Arc::new(AtomicBool::new(false));
39
40 let block_size = config.block_size as usize;
41 let blocks: Vec<&[u8]> = input.chunks(block_size).collect();
42 let total_blocks = blocks.len() as u64;
43
44 let results: Vec<Result<crate::block::CompressedBlock>> = with_pool(config, || {
46 blocks
47 .par_iter()
48 .enumerate()
49 .map(|(i, chunk)| {
50 if cancelled.load(Ordering::Acquire) {
51 return Err(CrushError::Cancelled);
52 }
53 compress_block(chunk, i, config)
54 })
55 .collect()
56 });
57
58 #[allow(clippy::cast_possible_truncation)]
61 let mut compressed_blocks = Vec::with_capacity(total_blocks as usize);
62 for r in results {
63 compressed_blocks.push(r?);
64 }
65
66 let mut out = Vec::new();
68
69 let mut flags = FileFlags::default();
70 if config.checksums {
71 flags = flags.with_checksums();
72 }
73
74 let header = FileHeader::new(
75 config.block_size,
76 config.compression_level,
77 flags,
78 input.len() as u64,
79 total_blocks,
80 );
81 out.extend_from_slice(&header.to_bytes());
82
83 let mut index_entries = Vec::with_capacity(compressed_blocks.len());
84 let mut bytes_processed: u64 = 0;
85
86 for (i, block) in compressed_blocks.iter().enumerate() {
87 let block_offset = out.len() as u64;
88 out.extend_from_slice(&block.header.to_bytes());
89 out.extend_from_slice(&block.payload);
90
91 index_entries.push(BlockIndexEntry {
92 block_offset,
93 compressed_size: block.header.compressed_size,
94 uncompressed_size: block.header.uncompressed_size,
95 checksum: block.header.checksum,
96 });
97
98 bytes_processed += u64::from(block.header.uncompressed_size);
99
100 if let Some(cb_arc) = &config.progress {
102 let event = ProgressEvent {
103 bytes_processed,
104 blocks_completed: i as u64 + 1,
105 total_blocks: Some(total_blocks),
106 phase: ProgressPhase::Compressing,
107 };
108 let mut cb = cb_arc.lock().map_err(|_| {
109 CrushError::InvalidConfig("progress callback mutex poisoned".to_owned())
110 })?;
111 if !cb(event) {
112 return Err(CrushError::Cancelled);
113 }
114 }
115 }
116
117 let index_offset = out.len() as u64;
119 let entry_count = u32::try_from(index_entries.len())
120 .map_err(|_| CrushError::InvalidConfig("too many blocks for index".to_owned()))?;
121 let ih = IndexHeader {
122 entry_count,
123 index_flags: 0,
124 };
125 out.extend_from_slice(&ih.to_bytes());
126 for e in &index_entries {
127 out.extend_from_slice(&e.to_bytes());
128 }
129
130 let index_size = u32::try_from(IndexHeader::SIZE + index_entries.len() * BlockIndexEntry::SIZE)
132 .map_err(|_| CrushError::InvalidConfig("index too large for footer".to_owned()))?;
133 let footer = FileFooter::new(index_offset, index_size);
134 out.extend_from_slice(&footer.to_bytes());
135
136 Ok(out)
137}
138
/// Compresses the file at `path` by memory-mapping it and delegating to
/// [`compress`].
///
/// # Errors
///
/// Returns an error if the file cannot be opened or mapped, or if
/// compression itself fails.
pub fn compress_file(path: &Path, config: &EngineConfiguration) -> Result<Vec<u8>> {
    let file = std::fs::File::open(path)?;
    // SAFETY: a file-backed mapping is only sound while no other process
    // truncates or rewrites the file for the duration of the call.
    // NOTE(review): concurrent modification of the mapped file would be
    // undefined behavior — confirm this precondition is stated at the
    // public API boundary.
    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
    compress(&mmap, config)
}
160
161pub fn compress_to_writer<W: Write>(
171 input: &[u8],
172 mut writer: W,
173 config: &EngineConfiguration,
174) -> Result<u64> {
175 let out = compress(input, config)?;
176 let len = out.len() as u64;
177 writer.write_all(&out)?;
178 Ok(len)
179}
180
181pub fn compress_stream<R: Read, W: Write>(
194 mut reader: R,
195 mut writer: W,
196 config: &EngineConfiguration,
197) -> Result<u64> {
198 let mut input = Vec::new();
199 reader.read_to_end(&mut input)?;
200 let out = compress(&input, config)?;
201 let len = out.len() as u64;
202 writer.write_all(&out)?;
203 Ok(len)
204}
205
206pub fn decompress(input: &[u8], config: &EngineConfiguration) -> Result<Vec<u8>> {
224 let mut cursor = Cursor::new(input);
225 decompress_from_reader(&mut cursor, config)
226}
227
228pub fn decompress_from_reader<R: Read + Seek>(
248 reader: &mut R,
249 config: &EngineConfiguration,
250) -> Result<Vec<u8>> {
251 let index = load_index(reader)?;
252
253 #[allow(
256 clippy::cast_precision_loss,
257 clippy::cast_possible_truncation,
258 clippy::cast_sign_loss
259 )]
260 let limit = {
261 let file_size = reader.seek(SeekFrom::End(0))?;
262 (file_size as f64 * config.max_decompression_ratio) as u64
263 };
264 let total_uncompressed = index.total_uncompressed_size();
265 if total_uncompressed > limit {
266 return Err(CrushError::ExpansionLimitExceeded { block_index: 0 });
267 }
268
269 let total_blocks = index.len();
270 let checksums_enabled = index.checksums_enabled;
271
272 let raw_blocks: Vec<(BlockHeader, Vec<u8>)> = index
274 .entries
275 .iter()
276 .enumerate()
277 .map(|(i, entry)| -> Result<(BlockHeader, Vec<u8>)> {
278 reader.seek(SeekFrom::Start(entry.block_offset))?;
279 let mut hdr_buf = [0u8; BlockHeader::SIZE];
280 reader.read_exact(&mut hdr_buf).map_err(|e| {
281 CrushError::InvalidFormat(format!("block {i} header read error: {e}"))
282 })?;
283 let header = BlockHeader::from_bytes(&hdr_buf);
284 let mut payload = vec![0u8; header.compressed_size as usize];
285 reader.read_exact(&mut payload).map_err(|e| {
286 CrushError::InvalidFormat(format!("block {i} payload read error: {e}"))
287 })?;
288 Ok((header, payload))
289 })
290 .collect::<Result<Vec<_>>>()?;
291
292 let results: Vec<Result<(usize, Vec<u8>)>> = with_pool(config, || {
294 raw_blocks
295 .par_iter()
296 .enumerate()
297 .map(|(i, (header, payload))| {
298 let decompressed =
299 decompress_block_payload(header, payload, i as u64, checksums_enabled)?;
300 Ok((i, decompressed))
301 })
302 .collect()
303 });
304
305 let total_blocks_usize = usize::try_from(total_blocks)
307 .map_err(|_| CrushError::InvalidConfig("block count overflows usize".to_owned()))?;
308 let mut ordered: Vec<Option<Vec<u8>>> = (0..total_blocks_usize).map(|_| None).collect();
309 let mut bytes_processed: u64 = 0;
310
311 for r in results {
312 let (i, data) = r?;
313 bytes_processed += data.len() as u64;
314 ordered[i] = Some(data);
315 }
316
317 for (i, chunk) in ordered.iter().enumerate() {
319 if let (Some(cb_arc), Some(_chunk)) = (&config.progress, chunk) {
320 let event = ProgressEvent {
321 bytes_processed,
322 blocks_completed: i as u64 + 1,
323 total_blocks: Some(total_blocks),
324 phase: ProgressPhase::Decompressing,
325 };
326 let mut cb = cb_arc
327 .lock()
328 .map_err(|_| CrushError::InvalidConfig("progress mutex poisoned".to_owned()))?;
329 if !cb(event) {
330 return Err(CrushError::Cancelled);
331 }
332 }
333 }
334
335 let output: Vec<u8> = ordered.into_iter().flatten().flatten().collect();
336
337 Ok(output)
338}
339
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    clippy::panic,
    clippy::cast_possible_truncation,
    clippy::missing_panics_doc
)]
mod tests {
    use super::*;
    use crate::format::FORMAT_VERSION;
    use std::io::Write as IoWrite;
    use tempfile::NamedTempFile;

    /// Baseline configuration shared by most tests: 64 KiB blocks,
    /// defaults for everything else.
    fn default_config() -> EngineConfiguration {
        EngineConfiguration::builder()
            .block_size(65_536)
            .build()
            .expect("config")
    }

    /// Compressing then decompressing repetitive data is lossless.
    #[test]
    fn test_compress_roundtrip_small() {
        let data: Vec<u8> = b"hello world"
            .iter()
            .cycle()
            .take(200_000)
            .copied()
            .collect();
        let config = default_config();
        let compressed = compress(&data, &config).expect("compress");
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
    }

    /// A tiny `max_expansion_ratio` forces blocks down the "stored"
    /// (uncompressed) path; the round trip must still be lossless and the
    /// block header must carry the stored flag.
    #[test]
    fn test_compress_incompressible_stored() {
        let data: Vec<u8> = b"hello world!"
            .iter()
            .cycle()
            .take(200_000)
            .copied()
            .collect();
        let config = EngineConfiguration::builder()
            .block_size(65_536)
            .max_expansion_ratio(0.001)
            .build()
            .expect("config");
        let compressed = compress(&data, &config).expect("compress");
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
        // Seek to the first block via the index and inspect its header.
        let mut cursor = Cursor::new(&compressed);
        let index = load_index(&mut cursor).expect("load_index");
        cursor
            .seek(SeekFrom::Start(index.entries[0].block_offset))
            .expect("seek");
        let mut hdr = [0u8; BlockHeader::SIZE];
        cursor.read_exact(&mut hdr).expect("read hdr");
        let header = BlockHeader::from_bytes(&hdr);
        assert!(
            header.flags.stored(),
            "expected stored flag on incompressible data"
        );
    }

    /// The emitted bytes parse as a well-formed CRSH container: magic and
    /// version in the header, magic in the footer, loadable index whose
    /// entry count matches the header's block count.
    #[test]
    fn test_compress_output_valid_crsh_format() {
        let data: Vec<u8> = b"test".iter().cycle().take(100_000).copied().collect();
        let config = default_config();
        let compressed = compress(&data, &config).expect("compress");
        let hdr_bytes: [u8; FileHeader::SIZE] = compressed[..FileHeader::SIZE]
            .try_into()
            .expect("hdr bytes");
        let hdr = FileHeader::from_bytes(&hdr_bytes).expect("parse header");
        assert_eq!(hdr.magic, crate::format::CRSH_MAGIC);
        assert_eq!(hdr.format_version, FORMAT_VERSION);
        let footer_bytes: [u8; FileFooter::SIZE] = compressed
            [compressed.len() - FileFooter::SIZE..]
            .try_into()
            .expect("footer bytes");
        let footer = FileFooter::from_bytes(&footer_bytes).expect("parse footer");
        assert_eq!(footer.magic, crate::format::CRSH_MAGIC);
        let mut cursor = Cursor::new(&compressed);
        let index = load_index(&mut cursor).expect("load_index");
        assert_eq!(index.len(), hdr.block_count);
    }

    /// The progress callback fires at least once during compression.
    #[test]
    fn test_progress_callback_invoked_per_block() {
        use std::sync::{Arc, Mutex};
        let data: Vec<u8> = b"abc".iter().cycle().take(300_000).copied().collect();
        let count = Arc::new(Mutex::new(0u64));
        let count_clone = count.clone();
        let cb: crate::config::ProgressCallback = Box::new(move |_event| {
            let mut c = count_clone.lock().expect("lock");
            *c += 1;
            true
        });
        let config = EngineConfiguration::builder()
            .block_size(65_536)
            .progress(Arc::new(Mutex::new(cb)))
            .build()
            .expect("config");
        compress(&data, &config).expect("compress");
        let final_count = *count.lock().expect("lock");
        assert!(final_count >= 1, "progress callback was not invoked");
    }

    /// A callback that returns `false` on its first event cancels the
    /// operation with `CrushError::Cancelled`.
    #[test]
    fn test_cancel_halts_at_block_boundary() {
        use std::sync::{Arc, Mutex};
        let data: Vec<u8> = b"xyz".iter().cycle().take(1_000_000).copied().collect();
        let cb: crate::config::ProgressCallback = Box::new(|_event| false);
        let config = EngineConfiguration::builder()
            .block_size(65_536)
            .progress(Arc::new(Mutex::new(cb)))
            .build()
            .expect("config");
        let result = compress(&data, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().is_cancelled());
    }

    /// `compress_file` (mmap path) round-trips against `decompress`.
    #[test]
    fn test_compress_file_roundtrip() {
        let data: Vec<u8> = b"file data".iter().cycle().take(200_000).copied().collect();
        let mut tmp = NamedTempFile::new().expect("temp file");
        tmp.write_all(&data).expect("write");
        let config = default_config();
        let compressed = compress_file(tmp.path(), &config).expect("compress_file");
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
    }

    /// Multi-block round trip (500 KB across 64 KiB blocks).
    #[test]
    fn test_decompress_roundtrip() {
        let data: Vec<u8> = b"decompress me"
            .iter()
            .cycle()
            .take(500_000)
            .copied()
            .collect();
        let config = default_config();
        let compressed = compress(&data, &config).expect("compress");
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
    }

    /// Flipping a payload byte is detected as a checksum mismatch or a
    /// format error — never silently decoded.
    #[test]
    fn test_decompress_corrupt_block_detected() {
        let data: Vec<u8> = b"corrupt test"
            .iter()
            .cycle()
            .take(200_000)
            .copied()
            .collect();
        let config = default_config();
        let mut compressed = compress(&data, &config).expect("compress");

        let mut cursor = Cursor::new(&compressed);
        let index = load_index(&mut cursor).expect("load_index");
        let block0_offset = index.entries[0].block_offset as usize;

        // Corrupt the first byte of block 0's payload (just past its header).
        let payload_start = block0_offset + BlockHeader::SIZE;
        if payload_start < compressed.len() {
            compressed[payload_start] ^= 0xFF;
        }

        let result = decompress(&compressed, &config);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, CrushError::ChecksumMismatch { block_index: 0, .. })
                || matches!(err, CrushError::InvalidFormat(_)),
            "expected checksum or format error, got {err:?}"
        );
    }

    /// An unsupported format version in the footer is rejected.
    #[test]
    fn test_version_mismatch_rejected() {
        let data: Vec<u8> = b"version test"
            .iter()
            .cycle()
            .take(100_000)
            .copied()
            .collect();
        let config = default_config();
        let mut compressed = compress(&data, &config).expect("compress");

        // Overwrite footer bytes 16..20 with a bogus version.
        // NOTE(review): offset 16 is assumed to be the footer's version
        // field — confirm against the `FileFooter` layout.
        let footer_start = compressed.len() - FileFooter::SIZE;
        compressed[footer_start + 16..footer_start + 20].copy_from_slice(&9999u32.to_le_bytes());
        let result = decompress(&compressed, &config);
        assert!(result.is_err());
    }

    /// An absurdly small `max_decompression_ratio` triggers the
    /// expansion-limit (zip-bomb) guard.
    #[test]
    fn test_expansion_limit_exceeded() {
        let data: Vec<u8> = b"test data".iter().cycle().take(100_000).copied().collect();
        let compress_config = default_config();
        let compressed = compress(&data, &compress_config).expect("compress");

        let decompress_config = EngineConfiguration::builder()
            .block_size(65_536)
            .max_decompression_ratio(0.000_001)
            .build()
            .expect("config");
        let result = decompress(&compressed, &decompress_config);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            CrushError::ExpansionLimitExceeded { .. }
        ));
    }

    /// Chopping the footer off makes the archive unreadable.
    #[test]
    fn test_truncated_footer_rejected() {
        let data: Vec<u8> = b"truncated".iter().cycle().take(100_000).copied().collect();
        let config = default_config();
        let mut compressed = compress(&data, &config).expect("compress");
        compressed.truncate(compressed.len() - FileFooter::SIZE);
        let result = decompress(&compressed, &config);
        assert!(result.is_err());
    }

    proptest::proptest! {
        #![proptest_config(proptest::prelude::ProptestConfig::with_cases(50))]

        /// Round trip holds for arbitrary data across a spread of block
        /// sizes and compression levels.
        #[test]
        fn proptest_compress_decompress_roundtrip(
            data in proptest::collection::vec(proptest::prelude::any::<u8>(), 0..200_000),
            block_kb in proptest::prelude::prop_oneof![
                proptest::prelude::Just(64usize),
                proptest::prelude::Just(256),
                proptest::prelude::Just(1024)
            ],
            level in 0u8..=9,
        ) {
            let block_size = u32::try_from(block_kb * 1024).unwrap();
            let config = EngineConfiguration::builder()
                .block_size(block_size)
                .compression_level(level)
                .build()
                .unwrap();
            let compressed = compress(&data, &config).unwrap();
            let recovered = decompress(&compressed, &config).unwrap();
            proptest::prop_assert_eq!(data, recovered);
        }
    }
}