1use anyhow::Result;
10use ragc_common::types::{Contig, PackedBlock};
11use std::cell::RefCell;
12
13thread_local! {
15 static ZSTD_CCTX: RefCell<zstd_safe::CCtx<'static>> = RefCell::new(zstd_safe::CCtx::create());
16}
17
18pub fn compress_segment_pooled(data: &Contig, level: i32) -> Result<PackedBlock> {
30 ZSTD_CCTX.with(|cctx| {
31 let mut cctx = cctx.borrow_mut();
32
33 let max_compressed_size = zstd_safe::compress_bound(data.len());
35 let mut output = vec![0u8; max_compressed_size];
36
37 match cctx.compress(&mut output, data, level) {
39 Ok(compressed_size) => {
40 output.truncate(compressed_size);
41 Ok(output)
42 }
43 Err(code) => {
44 let msg = zstd_safe::get_error_name(code);
45 Err(anyhow::anyhow!("ZSTD compression failed: {}", msg))
46 }
47 }
48 })
49}
50
51pub fn decompress_segment_pooled(compressed: &PackedBlock) -> Result<Contig> {
60 zstd::decode_all(compressed.as_slice())
63 .map_err(|e| anyhow::anyhow!("Failed to decompress segment with ZSTD: {e}"))
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses `original` at `level`, decompresses the result, asserts that
    /// the bytes survive the round trip unchanged, and returns the compressed
    /// block so callers can inspect it further.
    fn assert_roundtrip(original: &Contig, level: i32) -> PackedBlock {
        let compressed = compress_segment_pooled(original, level).unwrap();
        let decompressed = decompress_segment_pooled(&compressed).unwrap();
        assert_eq!(*original, decompressed);
        compressed
    }

    #[test]
    fn test_pooled_compress_decompress_roundtrip() {
        let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];
        assert_roundtrip(&original, 11);
    }

    #[test]
    fn test_pooled_empty_input() {
        let original: Contig = Vec::new();
        assert_roundtrip(&original, 3);
    }

    #[test]
    fn test_pooled_multiple_compressions() {
        // Vary size, content, and level between calls. Any state left behind
        // in the reused thread-local context would then corrupt a later call,
        // which identical inputs every iteration could not reveal.
        for round in 0..10usize {
            let original: Contig = (0..(round + 1) * 97)
                .map(|i| ((i * (round + 1)) % 4) as u8)
                .collect();
            let level = [1, 11, 19][round % 3];
            assert_roundtrip(&original, level);
        }
    }

    #[test]
    fn test_pooled_different_levels() {
        let original = vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3];
        for level in [1, 3, 9, 17, 19] {
            assert_roundtrip(&original, level);
        }
    }

    #[test]
    fn test_pooled_large_data() {
        let original: Contig = (0..10_000).map(|i| (i % 4) as u8).collect();
        let compressed = assert_roundtrip(&original, 17);
        assert!(compressed.len() < original.len());
    }

    #[test]
    fn test_pooled_decompress_rejects_garbage() {
        // Bytes that start with neither a zstd nor a skippable-frame magic number.
        let garbage: PackedBlock = b"not a zstd frame".to_vec();
        assert!(decompress_segment_pooled(&garbage).is_err());
    }

    #[test]
    fn test_pooled_independent_threads() {
        // Each spawned thread lazily creates its own thread-local context.
        let handles: Vec<_> = (0..4u32)
            .map(|t| {
                std::thread::spawn(move || {
                    let original: Contig =
                        (0..1_000u32).map(|i| ((i + t) % 4) as u8).collect();
                    assert_roundtrip(&original, 3);
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }
}