ragc_core/
zstd_pool.rs

// ZSTD Context Pooling - Matching C++ AGC Design
//
// C++ AGC creates ONE ZSTD_CCtx per thread and reuses it for all compressions.
// This module implements the exact same approach using zstd-safe bindings.
//
// CRITICAL FIX: Previous implementation used zstd::stream::copy_encode which:
// 1. Created a new encoder context for each call (massive overhead!)
// 2. Did internal buffering we don't need
// 3. Cloned the output buffer (defeating pooling!)
//
// New implementation matches C++ AGC line-by-line:
// - Thread-local ZSTD_CCtx reused for all compressions (like C++ line 176)
// - Pre-allocated output buffer (like C++ new uint8_t[a_size])
// - Direct compression into buffer (like ZSTD_compressCCtx)
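//
// For contrast, the old per-call pattern looked roughly like this (a reconstruction from the
// description above, not the actual previous code):
//
//     let mut packed = Vec::new();
//     zstd::stream::copy_encode(data.as_slice(), &mut packed, level)?; // new ZSTD_CCtx on every call
//
// whereas the pooled path below creates the context once per thread and then reuses it.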

use anyhow::Result;
use ragc_common::types::{Contig, PackedBlock};
use std::cell::RefCell;

thread_local! {
    /// Thread-local ZSTD compression context - EXACTLY like C++ AGC
    /// C++ AGC: ZSTD_CCtx* zstd_ctx (reused per thread)
    /// Rust: zstd::bulk::Compressor<'static> (same concept)
    ///
    /// Note: We store (Compressor, level) to detect level changes.
    /// In practice, compression level rarely changes within a thread.
    static ZSTD_ENCODER: RefCell<Option<(zstd::bulk::Compressor<'static>, i32)>> = const { RefCell::new(None) };
}

/// Compress a segment using thread-local ZSTD context (MATCHING C++ AGC)
///
/// C++ AGC equivalent (segment.h:172-189):
/// ```cpp
/// size_t a_size = ZSTD_compressBound(data.size());
/// uint8_t *packed = new uint8_t[a_size+1u];
/// uint32_t packed_size = ZSTD_compressCCtx(zstd_ctx, packed, a_size, data.data(), data.size(), level);
/// vector<uint8_t> v_packed(packed, packed + packed_size + 1);
/// delete[] packed;
/// ```
///
/// Our Rust implementation does the same but with thread-local context pooling.
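///
/// A minimal usage sketch (marked `ignore` because the import path `ragc_core::zstd_pool`
/// and `Contig` being a plain `Vec<u8>` alias are assumptions, not confirmed here):
/// ```ignore
/// use ragc_core::zstd_pool::{compress_segment_pooled, decompress_segment_pooled};
///
/// let contig: Vec<u8> = vec![0, 1, 2, 3, 0, 1, 2, 3];
/// let packed = compress_segment_pooled(&contig, 17).unwrap(); // reuses this thread's context
/// assert_eq!(decompress_segment_pooled(&packed).unwrap(), contig); // lossless round-trip
/// ```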
pub fn compress_segment_pooled(data: &Contig, level: i32) -> Result<PackedBlock> {
    ZSTD_ENCODER.with(|encoder_cell| {
        let mut encoder_opt = encoder_cell.borrow_mut();

        // (Re)create the encoder if this thread has none yet, or if the requested
        // compression level changed. Checking first and borrowing afterwards avoids
        // holding a `&mut` into the Option across its re-initialization (a single
        // match that both hands back the cached encoder and re-initializes it runs
        // into borrow-checker limits).
        let needs_new = match encoder_opt.as_ref() {
            Some((_, cached_level)) => *cached_level != level,
            None => true,
        };
        if needs_new {
            let new_encoder = zstd::bulk::Compressor::new(level)
                .map_err(|e| anyhow::anyhow!("Failed to create ZSTD encoder: {e}"))?;
            *encoder_opt = Some((new_encoder, level));
        }
        let (encoder, _) = encoder_opt.as_mut().expect("encoder was just initialized");

        // Compress directly - bulk::Compressor handles the output buffer internally
        // and returns an owned Vec (no clone needed!)
        encoder
            .compress(data.as_slice())
            .map_err(|e| anyhow::anyhow!("ZSTD compression failed: {e}"))
    })
}

/// Decompress a segment using ZSTD
///
/// Note: Decompression is less critical for pooling since:
/// 1. Decompression contexts are smaller than compression contexts
/// 2. Decompression happens less frequently in the hot path
/// 3. Current implementation is adequate for now
///
/// Future optimization: Add thread-local decoder pool if profiling shows benefit
pub fn decompress_segment_pooled(compressed: &PackedBlock) -> Result<Contig> {
    // For now, use the existing decode_all
    // Decompression contexts are smaller and less critical to pool
    zstd::decode_all(compressed.as_slice())
        .map_err(|e| anyhow::anyhow!("Failed to decompress segment with ZSTD: {e}"))
}
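
// Sketch of the "future optimization" mentioned above: a thread-local decompressor pool that
// mirrors ZSTD_ENCODER. This is an illustrative addition, not part of the current API; the
// function name `decompress_segment_pooled_bulk` and the explicit `max_size` capacity hint
// (the bulk decompress API needs an upper bound for the output buffer) are assumptions.
#[allow(dead_code)]
fn decompress_segment_pooled_bulk(compressed: &PackedBlock, max_size: usize) -> Result<Contig> {
    thread_local! {
        /// Thread-local ZSTD decompression context, reused like ZSTD_ENCODER above
        static ZSTD_DECODER: RefCell<Option<zstd::bulk::Decompressor<'static>>> =
            const { RefCell::new(None) };
    }

    ZSTD_DECODER.with(|decoder_cell| {
        let mut decoder_opt = decoder_cell.borrow_mut();
        if decoder_opt.is_none() {
            let decoder = zstd::bulk::Decompressor::new()
                .map_err(|e| anyhow::anyhow!("Failed to create ZSTD decoder: {e}"))?;
            *decoder_opt = Some(decoder);
        }
        decoder_opt
            .as_mut()
            .expect("decoder was just initialized")
            .decompress(compressed.as_slice(), max_size)
            .map_err(|e| anyhow::anyhow!("ZSTD decompression failed: {e}"))
    })
}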

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pooled_compress_decompress_roundtrip() {
        let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];

        let compressed = compress_segment_pooled(&original, 11).unwrap();
        let decompressed = decompress_segment_pooled(&compressed).unwrap();

        assert_eq!(original, decompressed);
    }

    #[test]
    fn test_pooled_multiple_compressions() {
        // Test that context reuse works correctly
        for _ in 0..10 {
            let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];
            let compressed = compress_segment_pooled(&original, 11).unwrap();
            let decompressed = decompress_segment_pooled(&compressed).unwrap();
            assert_eq!(original, decompressed);
        }
    }

    #[test]
    fn test_pooled_different_levels() {
        let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];

        for level in [1, 3, 9, 17, 19].iter() {
            let compressed = compress_segment_pooled(&original, *level).unwrap();
            let decompressed = decompress_segment_pooled(&compressed).unwrap();
            assert_eq!(original, decompressed);
        }
    }

    #[test]
    fn test_pooled_large_data() {
        let mut original = Vec::new();
        for i in 0..10000 {
            original.push((i % 4) as u8);
        }

        let compressed = compress_segment_pooled(&original, 17).unwrap();
        let decompressed = decompress_segment_pooled(&compressed).unwrap();

        assert_eq!(original, decompressed);
        assert!(compressed.len() < original.len());
    }
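
    // Illustrative addition (not in the original test suite): since the whole point of this
    // module is per-thread context reuse, drive the pooled compressor from several threads
    // at once and check the round-trip on each of them.
    #[test]
    fn test_pooled_multi_threaded() {
        std::thread::scope(|s| {
            for _ in 0..4 {
                s.spawn(|| {
                    // Each spawned thread lazily creates, then reuses, its own encoder
                    for _ in 0..50 {
                        let original = vec![0u8, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];
                        let compressed = compress_segment_pooled(&original, 11).unwrap();
                        let decompressed = decompress_segment_pooled(&compressed).unwrap();
                        assert_eq!(original, decompressed);
                    }
                });
            }
        });
    }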
}