// crush_parallel/engine.rs

1//! Main compression and decompression entry points.
2
3use crate::block::{compress_block, decompress_block_payload};
4use crate::config::{EngineConfiguration, ProgressEvent, ProgressPhase};
5use crate::format::{BlockHeader, BlockIndexEntry, FileFlags, FileFooter, FileHeader, IndexHeader};
6use crate::index::load_index;
7use crush_core::error::{CrushError, Result};
8use rayon::prelude::*;
9use std::io::{Cursor, Read, Seek, SeekFrom, Write};
10use std::path::Path;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14// ---------------------------------------------------------------------------
15// Thread pool helper
16// ---------------------------------------------------------------------------
17
18/// Run `f` inside the thread pool from `config`, or the global rayon pool when
19/// `config.workers == 0`.
20fn with_pool<T: Send>(config: &EngineConfiguration, f: impl FnOnce() -> T + Send) -> T {
21    match &config.thread_pool {
22        Some(pool) => pool.install(f),
23        None => f(),
24    }
25}
26
27// ---------------------------------------------------------------------------
28// compress
29// ---------------------------------------------------------------------------
30
31/// Compress `input` bytes using the parallel block engine.
32///
33/// # Errors
34///
35/// - [`CrushError::Cancelled`] — cancelled via the progress callback.
36/// - [`CrushError::InvalidConfig`] — configuration validation failed.
37pub fn compress(input: &[u8], config: &EngineConfiguration) -> Result<Vec<u8>> {
38    let cancelled = Arc::new(AtomicBool::new(false));
39
40    let block_size = config.block_size as usize;
41    let blocks: Vec<&[u8]> = input.chunks(block_size).collect();
42    let total_blocks = blocks.len() as u64;
43
44    // Compress all blocks in parallel (respects config.workers via dedicated pool)
45    let results: Vec<Result<crate::block::CompressedBlock>> = with_pool(config, || {
46        blocks
47            .par_iter()
48            .enumerate()
49            .map(|(i, chunk)| {
50                if cancelled.load(Ordering::Acquire) {
51                    return Err(CrushError::Cancelled);
52                }
53                compress_block(chunk, i, config)
54            })
55            .collect()
56    });
57
58    // Check for errors / cancellation
59    // block count fits usize: validated against u32::MAX in the index write path
60    #[allow(clippy::cast_possible_truncation)]
61    let mut compressed_blocks = Vec::with_capacity(total_blocks as usize);
62    for r in results {
63        compressed_blocks.push(r?);
64    }
65
66    // Assemble output
67    let mut out = Vec::new();
68
69    let mut flags = FileFlags::default();
70    if config.checksums {
71        flags = flags.with_checksums();
72    }
73
74    let header = FileHeader::new(
75        config.block_size,
76        config.compression_level,
77        flags,
78        input.len() as u64,
79        total_blocks,
80    );
81    out.extend_from_slice(&header.to_bytes());
82
83    let mut index_entries = Vec::with_capacity(compressed_blocks.len());
84    let mut bytes_processed: u64 = 0;
85
86    for (i, block) in compressed_blocks.iter().enumerate() {
87        let block_offset = out.len() as u64;
88        out.extend_from_slice(&block.header.to_bytes());
89        out.extend_from_slice(&block.payload);
90
91        index_entries.push(BlockIndexEntry {
92            block_offset,
93            compressed_size: block.header.compressed_size,
94            uncompressed_size: block.header.uncompressed_size,
95            checksum: block.header.checksum,
96        });
97
98        bytes_processed += u64::from(block.header.uncompressed_size);
99
100        // Invoke progress callback
101        if let Some(cb_arc) = &config.progress {
102            let event = ProgressEvent {
103                bytes_processed,
104                blocks_completed: i as u64 + 1,
105                total_blocks: Some(total_blocks),
106                phase: ProgressPhase::Compressing,
107            };
108            let mut cb = cb_arc.lock().map_err(|_| {
109                CrushError::InvalidConfig("progress callback mutex poisoned".to_owned())
110            })?;
111            if !cb(event) {
112                return Err(CrushError::Cancelled);
113            }
114        }
115    }
116
117    // Write index
118    let index_offset = out.len() as u64;
119    let entry_count = u32::try_from(index_entries.len())
120        .map_err(|_| CrushError::InvalidConfig("too many blocks for index".to_owned()))?;
121    let ih = IndexHeader {
122        entry_count,
123        index_flags: 0,
124    };
125    out.extend_from_slice(&ih.to_bytes());
126    for e in &index_entries {
127        out.extend_from_slice(&e.to_bytes());
128    }
129
130    // Write footer
131    let index_size = u32::try_from(IndexHeader::SIZE + index_entries.len() * BlockIndexEntry::SIZE)
132        .map_err(|_| CrushError::InvalidConfig("index too large for footer".to_owned()))?;
133    let footer = FileFooter::new(index_offset, index_size);
134    out.extend_from_slice(&footer.to_bytes());
135
136    Ok(out)
137}
138
139// ---------------------------------------------------------------------------
140// compress_file
141// ---------------------------------------------------------------------------
142
/// Compress a file at `path` using memory-mapped zero-copy I/O (FR-009).
///
/// This is the preferred entry point for large file inputs.
/// Internally uses `memmap2::MmapOptions` to avoid copying file contents
/// into a heap buffer before compression.
///
/// # Errors
///
/// - `CrushError::Io` — file could not be opened or memory-mapped.
/// - `CrushError::Cancelled` — cancelled via the progress callback.
/// - `CrushError::InvalidConfig` — configuration validation failed.
pub fn compress_file(path: &Path, config: &EngineConfiguration) -> Result<Vec<u8>> {
    let file = std::fs::File::open(path)?;
    // SAFETY: the mapping is read-only and this process does not write to the
    // file while the map is alive.
    // NOTE(review): if *another* process truncates or rewrites the file while
    // the map is live, the mapped contents can change underneath us — confirm
    // callers guarantee exclusive access, or document this as accepted risk.
    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
    compress(&mmap, config)
}
160
161// ---------------------------------------------------------------------------
162// compress_to_writer
163// ---------------------------------------------------------------------------
164
165/// Compress `input` bytes, writing the CRSH output to `writer`.
166///
167/// # Errors
168///
169/// Returns any error from `compress()` or from the underlying `write_all`.
170pub fn compress_to_writer<W: Write>(
171    input: &[u8],
172    mut writer: W,
173    config: &EngineConfiguration,
174) -> Result<u64> {
175    let out = compress(input, config)?;
176    let len = out.len() as u64;
177    writer.write_all(&out)?;
178    Ok(len)
179}
180
181// ---------------------------------------------------------------------------
182// compress_stream
183// ---------------------------------------------------------------------------
184
/// Compress a [`Read`] stream, writing CRSH output to `writer`.
///
/// NOTE(review): despite the name, this is not incremental streaming — the
/// entire input is buffered into memory via `read_to_end` and then handed to
/// [`compress`], so the emitted `FileHeader` carries the *actual*
/// `uncompressed_size` and `block_count` (no `u64::MAX` streaming sentinel is
/// ever written, contrary to what an earlier doc comment claimed). Peak memory
/// usage is proportional to the full stream length plus the compressed output.
///
/// Returns the number of compressed bytes written.
///
/// # Errors
///
/// Returns any error from the reader, from `compress()`, or the underlying writer.
pub fn compress_stream<R: Read, W: Write>(
    mut reader: R,
    mut writer: W,
    config: &EngineConfiguration,
) -> Result<u64> {
    // Buffer the whole stream; block splitting happens inside `compress`.
    let mut input = Vec::new();
    reader.read_to_end(&mut input)?;
    let out = compress(&input, config)?;
    let len = out.len() as u64;
    writer.write_all(&out)?;
    Ok(len)
}
205
206// ---------------------------------------------------------------------------
207// decompress
208// ---------------------------------------------------------------------------
209
210/// Decompress a CRSH-format byte slice.
211///
212/// Reads the file footer first, loads the index, then decompresses all
213/// blocks in parallel.
214///
215/// # Errors
216///
217/// - `VersionMismatch` — file was produced by a different engine version.
218/// - `InvalidFormat` — magic bytes invalid or file truncated.
219/// - `ChecksumMismatch` — a block's CRC32 does not match.
220/// - `ExpansionLimitExceeded` — output would exceed `config.max_decompression_ratio`.
221/// - `IndexCorrupted` — block index is missing or truncated.
222/// - `Cancelled` — cancelled via progress callback.
223pub fn decompress(input: &[u8], config: &EngineConfiguration) -> Result<Vec<u8>> {
224    let mut cursor = Cursor::new(input);
225    decompress_from_reader(&mut cursor, config)
226}
227
228// ---------------------------------------------------------------------------
229// decompress_from_reader
230// ---------------------------------------------------------------------------
231
/// Decompress a CRSH file from a seekable reader.
///
/// Phase 1 (sequential): reads all block headers and payloads into memory.
/// Phase 2 (parallel): decompresses the collected payloads via rayon.
///
/// This two-phase approach avoids sharing `&mut R` across rayon worker threads.
///
/// # Errors
///
/// - [`CrushError::ExpansionLimitExceeded`] — decompressed size would exceed the ratio limit.
/// - [`CrushError::ChecksumMismatch`] — a block's CRC32 does not match.
/// - [`CrushError::InvalidFormat`] — file is truncated or block data is corrupt.
/// - [`CrushError::IndexCorrupted`] — footer or index is missing or invalid.
/// - [`CrushError::Cancelled`] — cancelled via progress callback.
/// - [`CrushError::Io`] — underlying I/O error.
pub fn decompress_from_reader<R: Read + Seek>(
    reader: &mut R,
    config: &EngineConfiguration,
) -> Result<Vec<u8>> {
    // Footer and index are parsed up front; format validation happens inside
    // `load_index`.
    let index = load_index(reader)?;

    // Expansion ratio guard: reject before decompressing anything if the index
    // claims more output than `file_size * max_decompression_ratio`.
    // Precision and sign-loss are acceptable: this is a heuristic ratio check, not a byte counter.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    let limit = {
        let file_size = reader.seek(SeekFrom::End(0))?;
        (file_size as f64 * config.max_decompression_ratio) as u64
    };
    let total_uncompressed = index.total_uncompressed_size();
    if total_uncompressed > limit {
        // NOTE(review): this whole-file check always reports block_index 0
        // even though no single block is at fault — confirm that is intended.
        return Err(CrushError::ExpansionLimitExceeded { block_index: 0 });
    }

    let total_blocks = index.len();
    let checksums_enabled = index.checksums_enabled;

    // Phase 1: Read all block data sequentially (requires exclusive &mut reader).
    let raw_blocks: Vec<(BlockHeader, Vec<u8>)> = index
        .entries
        .iter()
        .enumerate()
        .map(|(i, entry)| -> Result<(BlockHeader, Vec<u8>)> {
            reader.seek(SeekFrom::Start(entry.block_offset))?;
            let mut hdr_buf = [0u8; BlockHeader::SIZE];
            reader.read_exact(&mut hdr_buf).map_err(|e| {
                CrushError::InvalidFormat(format!("block {i} header read error: {e}"))
            })?;
            let header = BlockHeader::from_bytes(&hdr_buf);
            // `compressed_size` comes from the (untrusted) block header, but it
            // is a u32, so the allocation is bounded; a short file fails the
            // read_exact below rather than yielding a partial payload.
            let mut payload = vec![0u8; header.compressed_size as usize];
            reader.read_exact(&mut payload).map_err(|e| {
                CrushError::InvalidFormat(format!("block {i} payload read error: {e}"))
            })?;
            Ok((header, payload))
        })
        .collect::<Result<Vec<_>>>()?;

    // Phase 2: Decompress in parallel — no reader access needed.
    let results: Vec<Result<(usize, Vec<u8>)>> = with_pool(config, || {
        raw_blocks
            .par_iter()
            .enumerate()
            .map(|(i, (header, payload))| {
                let decompressed =
                    decompress_block_payload(header, payload, i as u64, checksums_enabled)?;
                Ok((i, decompressed))
            })
            .collect()
    });

    // Re-assemble in order: each result carries its block index `i`, so the
    // `ordered` slots can be filled regardless of completion order.
    let total_blocks_usize = usize::try_from(total_blocks)
        .map_err(|_| CrushError::InvalidConfig("block count overflows usize".to_owned()))?;
    let mut ordered: Vec<Option<Vec<u8>>> = (0..total_blocks_usize).map(|_| None).collect();
    let mut bytes_processed: u64 = 0;

    for r in results {
        let (i, data) = r?;
        bytes_processed += data.len() as u64;
        ordered[i] = Some(data);
    }

    // Invoke progress callbacks (single-threaded pass after parallel decompress).
    // NOTE(review): `bytes_processed` is already the final total at this point,
    // so every event reports the same byte count rather than a running figure —
    // confirm callers expect that.
    for (i, chunk) in ordered.iter().enumerate() {
        if let (Some(cb_arc), Some(_chunk)) = (&config.progress, chunk) {
            let event = ProgressEvent {
                bytes_processed,
                blocks_completed: i as u64 + 1,
                total_blocks: Some(total_blocks),
                phase: ProgressPhase::Decompressing,
            };
            let mut cb = cb_arc
                .lock()
                .map_err(|_| CrushError::InvalidConfig("progress mutex poisoned".to_owned()))?;
            if !cb(event) {
                return Err(CrushError::Cancelled);
            }
        }
    }

    // Every slot was filled in the loop above, so the Option layer flattens
    // away losslessly; the second flatten concatenates the per-block bytes.
    let output: Vec<u8> = ordered.into_iter().flatten().flatten().collect();

    Ok(output)
}
339
340// ---------------------------------------------------------------------------
341// Tests
342// ---------------------------------------------------------------------------
343
#[cfg(test)]
#[allow(
    clippy::expect_used,
    clippy::unwrap_used,
    clippy::panic,
    clippy::cast_possible_truncation,
    clippy::missing_panics_doc
)]
mod tests {
    use super::*;
    use crate::format::FORMAT_VERSION;
    use std::io::Write as IoWrite;
    use tempfile::NamedTempFile;

    /// Shared config: small blocks keep tests fast while still producing
    /// multiple blocks from the ~100-500 KB inputs used below.
    fn default_config() -> EngineConfiguration {
        EngineConfiguration::builder()
            .block_size(65_536) // small blocks for fast tests
            .build()
            .expect("config")
    }

    #[test]
    fn test_compress_roundtrip_small() {
        let data: Vec<u8> = b"hello world"
            .iter()
            .cycle()
            .take(200_000)
            .copied()
            .collect();
        let config = default_config();
        let compressed = compress(&data, &config).expect("compress");
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
    }

    #[test]
    fn test_compress_incompressible_stored() {
        // Use a tiny expansion ratio so that even well-compressed output triggers stored.
        // With ratio 0.001, DEFLATE would need to compress a 65536-byte block to < 66 bytes
        // to avoid stored — impossible in practice. This deterministically tests stored-block
        // write/read paths without depending on data entropy.
        let data: Vec<u8> = b"hello world!"
            .iter()
            .cycle()
            .take(200_000)
            .copied()
            .collect();
        let config = EngineConfiguration::builder()
            .block_size(65_536)
            .max_expansion_ratio(0.001) // essentially forces stored for all blocks
            .build()
            .expect("config");
        let compressed = compress(&data, &config).expect("compress");
        // Round-trip must still work
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
        // Verify at least one stored block exists
        let mut cursor = Cursor::new(&compressed);
        let index = load_index(&mut cursor).expect("load_index");
        // Seek to first block header and check stored flag
        cursor
            .seek(SeekFrom::Start(index.entries[0].block_offset))
            .expect("seek");
        let mut hdr = [0u8; BlockHeader::SIZE];
        cursor.read_exact(&mut hdr).expect("read hdr");
        let header = BlockHeader::from_bytes(&hdr);
        assert!(
            header.flags.stored(),
            "expected stored flag on incompressible data"
        );
    }

    #[test]
    fn test_compress_output_valid_crsh_format() {
        let data: Vec<u8> = b"test".iter().cycle().take(100_000).copied().collect();
        let config = default_config();
        let compressed = compress(&data, &config).expect("compress");
        // Validate FileHeader
        let hdr_bytes: [u8; FileHeader::SIZE] = compressed[..FileHeader::SIZE]
            .try_into()
            .expect("hdr bytes");
        let hdr = FileHeader::from_bytes(&hdr_bytes).expect("parse header");
        assert_eq!(hdr.magic, crate::format::CRSH_MAGIC);
        assert_eq!(hdr.format_version, FORMAT_VERSION);
        // Validate footer
        let footer_bytes: [u8; FileFooter::SIZE] = compressed
            [compressed.len() - FileFooter::SIZE..]
            .try_into()
            .expect("footer bytes");
        let footer = FileFooter::from_bytes(&footer_bytes).expect("parse footer");
        assert_eq!(footer.magic, crate::format::CRSH_MAGIC);
        // Load index and verify entry count matches header block_count
        let mut cursor = Cursor::new(&compressed);
        let index = load_index(&mut cursor).expect("load_index");
        assert_eq!(index.len(), hdr.block_count);
    }

    #[test]
    fn test_progress_callback_invoked_per_block() {
        use std::sync::{Arc, Mutex};
        let data: Vec<u8> = b"abc".iter().cycle().take(300_000).copied().collect();
        let count = Arc::new(Mutex::new(0u64));
        let count_clone = count.clone();
        let cb: crate::config::ProgressCallback = Box::new(move |_event| {
            let mut c = count_clone.lock().expect("lock");
            *c += 1;
            true
        });
        let config = EngineConfiguration::builder()
            .block_size(65_536)
            .progress(Arc::new(Mutex::new(cb)))
            .build()
            .expect("config");
        compress(&data, &config).expect("compress");
        let final_count = *count.lock().expect("lock");
        // ceil(300_000 / 65_536) = 5 blocks, and `compress` fires exactly one
        // event per block, so the count must be exactly 5. (The previous
        // `>= 1` assertion could not detect missed blocks.)
        assert_eq!(final_count, 5, "expected one progress event per block");
    }

    #[test]
    fn test_cancel_halts_at_block_boundary() {
        use std::sync::{Arc, Mutex};
        let data: Vec<u8> = b"xyz".iter().cycle().take(1_000_000).copied().collect();
        let cb: crate::config::ProgressCallback = Box::new(|_event| false); // always cancel
        let config = EngineConfiguration::builder()
            .block_size(65_536)
            .progress(Arc::new(Mutex::new(cb)))
            .build()
            .expect("config");
        let result = compress(&data, &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().is_cancelled());
    }

    #[test]
    fn test_compress_file_roundtrip() {
        let data: Vec<u8> = b"file data".iter().cycle().take(200_000).copied().collect();
        let mut tmp = NamedTempFile::new().expect("temp file");
        tmp.write_all(&data).expect("write");
        let config = default_config();
        let compressed = compress_file(tmp.path(), &config).expect("compress_file");
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
    }

    #[test]
    fn test_decompress_roundtrip() {
        let data: Vec<u8> = b"decompress me"
            .iter()
            .cycle()
            .take(500_000)
            .copied()
            .collect();
        let config = default_config();
        let compressed = compress(&data, &config).expect("compress");
        let recovered = decompress(&compressed, &config).expect("decompress");
        assert_eq!(data, recovered);
    }

    #[test]
    fn test_decompress_corrupt_block_detected() {
        let data: Vec<u8> = b"corrupt test"
            .iter()
            .cycle()
            .take(200_000)
            .copied()
            .collect();
        let config = default_config();
        let mut compressed = compress(&data, &config).expect("compress");

        // Find first block header offset (right after FileHeader)
        let mut cursor = Cursor::new(&compressed);
        let index = load_index(&mut cursor).expect("load_index");
        let block0_offset = index.entries[0].block_offset as usize;

        // Corrupt a byte in the payload (after BlockHeader)
        let payload_start = block0_offset + BlockHeader::SIZE;
        if payload_start < compressed.len() {
            compressed[payload_start] ^= 0xFF;
        }

        let result = decompress(&compressed, &config);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, CrushError::ChecksumMismatch { block_index: 0, .. })
                || matches!(err, CrushError::InvalidFormat(_)),
            "expected checksum or format error, got {err:?}"
        );
    }

    #[test]
    fn test_version_mismatch_rejected() {
        let data: Vec<u8> = b"version test"
            .iter()
            .cycle()
            .take(100_000)
            .copied()
            .collect();
        let config = default_config();
        let mut compressed = compress(&data, &config).expect("compress");

        // Overwrite format_version in the footer (bytes [16..20]) with a bogus
        // value. The footer checksum (bytes [12..16]) is deliberately left
        // stale, so decompression may surface either VersionMismatch or
        // IndexCorrupted depending on which check runs first — any error is
        // acceptable for this test.
        let footer_start = compressed.len() - FileFooter::SIZE;
        compressed[footer_start + 16..footer_start + 20].copy_from_slice(&9999u32.to_le_bytes());
        let result = decompress(&compressed, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_expansion_limit_exceeded() {
        let data: Vec<u8> = b"test data".iter().cycle().take(100_000).copied().collect();
        let compress_config = default_config();
        let compressed = compress(&data, &compress_config).expect("compress");

        // Set an absurdly tight decompression ratio
        let decompress_config = EngineConfiguration::builder()
            .block_size(65_536)
            .max_decompression_ratio(0.000_001)
            .build()
            .expect("config");
        let result = decompress(&compressed, &decompress_config);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            CrushError::ExpansionLimitExceeded { .. }
        ));
    }

    #[test]
    fn test_truncated_footer_rejected() {
        let data: Vec<u8> = b"truncated".iter().cycle().take(100_000).copied().collect();
        let config = default_config();
        let mut compressed = compress(&data, &config).expect("compress");
        // Strip the trailing FileFooter entirely
        compressed.truncate(compressed.len() - FileFooter::SIZE);
        let result = decompress(&compressed, &config);
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------------
    // Property-based round-trip test (T070)
    // ---------------------------------------------------------------------------

    proptest::proptest! {
        #![proptest_config(proptest::prelude::ProptestConfig::with_cases(50))]

        #[test]
        fn proptest_compress_decompress_roundtrip(
            data in proptest::collection::vec(proptest::prelude::any::<u8>(), 0..200_000),
            block_kb in proptest::prelude::prop_oneof![
                proptest::prelude::Just(64usize),
                proptest::prelude::Just(256),
                proptest::prelude::Just(1024)
            ],
            level in 0u8..=9,
        ) {
            let block_size = u32::try_from(block_kb * 1024).unwrap();
            let config = EngineConfiguration::builder()
                .block_size(block_size)
                .compression_level(level)
                .build()
                .unwrap();
            let compressed = compress(&data, &config).unwrap();
            let recovered = decompress(&compressed, &config).unwrap();
            proptest::prop_assert_eq!(data, recovered);
        }
    }
}