ragc_core/
zstd_pool.rs

// ZSTD compression helpers.
//
// Compression uses a thread-local ZSTD context to avoid per-segment allocation overhead.
// This matches how C++ AGC uses ZSTD: one ZSTD_CCtx per thread, reused for all segments.
// Decompression is not pooled; it goes through `zstd::decode_all` on every call.
//
// Note: ZSTD_resetCCtx_internal (the memset) still runs on every compression, but we
// avoid the malloc/free overhead of creating and destroying contexts.

9use anyhow::Result;
10use ragc_common::types::{Contig, PackedBlock};
11use std::cell::RefCell;
12
13// Thread-local ZSTD compression context - matches C++ AGC's per-thread ZSTD_CCtx
14thread_local! {
15    static ZSTD_CCTX: RefCell<zstd_safe::CCtx<'static>> = RefCell::new(zstd_safe::CCtx::create());
16}
17
18/// Compress a segment using ZSTD at the specified level
19///
20/// Uses thread-local context reuse like C++ AGC's `ZSTD_compressCCtx()`.
21/// This avoids per-segment context allocation/deallocation overhead.
22///
23/// C++ AGC pattern:
24///   ZSTD_CCtx* zstd_cctx = ZSTD_createCCtx();  // once per thread
25///   ZSTD_compressCCtx(zstd_cctx, ...);         // reused for all segments
26///   ZSTD_freeCCtx(zstd_cctx);                  // at thread exit
27///
28/// This implementation follows the same pattern using thread_local.
29pub fn compress_segment_pooled(data: &Contig, level: i32) -> Result<PackedBlock> {
30    ZSTD_CCTX.with(|cctx| {
31        let mut cctx = cctx.borrow_mut();
32
33        // Pre-allocate output buffer (ZSTD_compressBound equivalent)
34        let max_compressed_size = zstd_safe::compress_bound(data.len());
35        let mut output = vec![0u8; max_compressed_size];
36
37        // Use ZSTD_compressCCtx - same as C++ AGC
38        match cctx.compress(&mut output, data, level) {
39            Ok(compressed_size) => {
40                output.truncate(compressed_size);
41                Ok(output)
42            }
43            Err(code) => {
44                let msg = zstd_safe::get_error_name(code);
45                Err(anyhow::anyhow!("ZSTD compression failed: {}", msg))
46            }
47        }
48    })
49}
50
51/// Decompress a segment using ZSTD
52///
53/// Note: Decompression is less critical for pooling since:
54/// 1. Decompression contexts are smaller than compression contexts
55/// 2. Decompression happens less frequently in the hot path
56/// 3. Current implementation is adequate for now
57///
58/// Future optimization: Add thread-local decoder pool if profiling shows benefit
59pub fn decompress_segment_pooled(compressed: &PackedBlock) -> Result<Contig> {
60    // For now, use the existing decode_all
61    // Decompression contexts are smaller and less critical to pool
62    zstd::decode_all(compressed.as_slice())
63        .map_err(|e| anyhow::anyhow!("Failed to decompress segment with ZSTD: {e}"))
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Compress `original` at `level`, decompress it, and check the bytes survive unchanged.
    /// The compressed block is returned so callers can make further assertions on it.
    fn assert_roundtrip(original: &Contig, level: i32) -> PackedBlock {
        let packed = compress_segment_pooled(original, level).unwrap();
        let unpacked = decompress_segment_pooled(&packed).unwrap();
        assert_eq!(original, &unpacked);
        packed
    }

    /// The short repeating pattern 0,1,2,3 (three times) used by most tests.
    fn small_pattern() -> Contig {
        [0u8, 1, 2, 3].repeat(3)
    }

    #[test]
    fn test_pooled_compress_decompress_roundtrip() {
        assert_roundtrip(&small_pattern(), 11);
    }

    #[test]
    fn test_pooled_multiple_compressions() {
        // Repeated calls on one thread reuse the same thread-local context.
        (0..10).for_each(|_| {
            assert_roundtrip(&small_pattern(), 11);
        });
    }

    #[test]
    fn test_pooled_different_levels() {
        let original = small_pattern();
        for level in [1, 3, 9, 17, 19] {
            assert_roundtrip(&original, level);
        }
    }

    #[test]
    fn test_pooled_large_data() {
        let original: Contig = (0..10_000u32).map(|i| (i % 4) as u8).collect();

        let packed = assert_roundtrip(&original, 17);

        assert!(packed.len() < original.len());
    }
}