// ragc_core/zstd_pool.rs
//
// ZSTD Context Pooling - Matching C++ AGC Design
//
// C++ AGC creates ONE ZSTD_CCtx per thread and reuses it for all compressions.
// This module implements the exact same approach using zstd-safe bindings.
//
// CRITICAL FIX: Previous implementation used zstd::stream::copy_encode which:
// 1. Created a new encoder context for each call (massive overhead!)
// 2. Did internal buffering we don't need
// 3. Cloned the output buffer (defeating pooling!)
//
// New implementation matches C++ AGC line-by-line:
// - Thread-local ZSTD_CCtx reused for all compressions (like C++ line 176)
// - Pre-allocated output buffer (like C++ new uint8_t[a_size])
// - Direct compression into buffer (like ZSTD_compressCCtx)
16use anyhow::Result;
17use ragc_common::types::{Contig, PackedBlock};
18use std::cell::RefCell;
19
thread_local! {
    /// Thread-local ZSTD compression context - EXACTLY like C++ AGC
    /// C++ AGC: ZSTD_CCtx* zstd_ctx (reused per thread)
    /// Rust: zstd::bulk::Compressor<'static> (same concept)
    ///
    /// Note: We store (Compressor, level) to detect level changes.
    /// In practice, compression level rarely changes within a thread.
    ///
    /// Starts as `None`; `compress_segment_pooled` lazily installs the
    /// compressor on first use and replaces it when the requested level
    /// differs from the cached one.
    static ZSTD_ENCODER: RefCell<Option<(zstd::bulk::Compressor<'static>, i32)>> = const { RefCell::new(None) };
}
29
30/// Compress a segment using thread-local ZSTD context (MATCHING C++ AGC)
31///
32/// C++ AGC equivalent (segment.h:172-189):
33/// ```cpp
34/// size_t a_size = ZSTD_compressBound(data.size());
35/// uint8_t *packed = new uint8_t[a_size+1u];
36/// uint32_t packed_size = ZSTD_compressCCtx(zstd_ctx, packed, a_size, data.data(), data.size(), level);
37/// vector<uint8_t> v_packed(packed, packed + packed_size + 1);
38/// delete[] packed;
39/// ```
40///
41/// Our Rust implementation does the same but with thread-local context pooling.
42pub fn compress_segment_pooled(data: &Contig, level: i32) -> Result<PackedBlock> {
43 ZSTD_ENCODER.with(|encoder_cell| {
44 let mut encoder_opt = encoder_cell.borrow_mut();
45
46 // Get or create encoder for this level
47 let encoder = match encoder_opt.as_mut() {
48 Some((enc, cached_level)) if *cached_level == level => enc,
49 _ => {
50 // Create new encoder for this level
51 let new_encoder = zstd::bulk::Compressor::new(level)
52 .map_err(|e| anyhow::anyhow!("Failed to create ZSTD encoder: {e}"))?;
53 *encoder_opt = Some((new_encoder, level));
54 &mut encoder_opt.as_mut().unwrap().0
55 }
56 };
57
58 // Compress directly - bulk::Compressor handles buffer internally
59 // and returns owned Vec (no clone needed!)
60 encoder
61 .compress(data.as_slice())
62 .map_err(|e| anyhow::anyhow!("ZSTD compression failed: {e}"))
63 })
64}
65
66/// Decompress a segment using ZSTD
67///
68/// Note: Decompression is less critical for pooling since:
69/// 1. Decompression contexts are smaller than compression contexts
70/// 2. Decompression happens less frequently in the hot path
71/// 3. Current implementation is adequate for now
72///
73/// Future optimization: Add thread-local decoder pool if profiling shows benefit
74pub fn decompress_segment_pooled(compressed: &PackedBlock) -> Result<Contig> {
75 // For now, use the existing decode_all
76 // Decompression contexts are smaller and less critical to pool
77 zstd::decode_all(compressed.as_slice())
78 .map_err(|e| anyhow::anyhow!("Failed to decompress segment with ZSTD: {e}"))
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Compress `original` with the pooled encoder, decompress the result,
    /// and assert the data survives the roundtrip unchanged.
    /// Returns the compressed block so callers can make extra assertions.
    fn assert_roundtrip(original: &Contig, level: i32) -> PackedBlock {
        let compressed = compress_segment_pooled(original, level).unwrap();
        let decompressed = decompress_segment_pooled(&compressed).unwrap();
        assert_eq!(*original, decompressed);
        compressed
    }

    #[test]
    fn test_pooled_compress_decompress_roundtrip() {
        let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];
        assert_roundtrip(&original, 11);
    }

    #[test]
    fn test_pooled_multiple_compressions() {
        // Test that thread-local context reuse works correctly across calls.
        for _ in 0..10 {
            let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];
            assert_roundtrip(&original, 11);
        }
    }

    #[test]
    fn test_pooled_different_levels() {
        let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];

        // Iterate by value (arrays are IntoIterator) — no .iter() + deref.
        // Exercises the encoder-rebuild path on every level change.
        for level in [1, 3, 9, 17, 19] {
            assert_roundtrip(&original, level);
        }
    }

    #[test]
    fn test_pooled_large_data() {
        // Highly repetitive 10 000-byte input: must roundtrip AND shrink.
        let original: Contig = (0..10_000).map(|i| (i % 4) as u8).collect();

        let compressed = assert_roundtrip(&original, 17);
        assert!(compressed.len() < original.len());
    }
}
130}