Skip to main content

trueno_rag/
compressed.rs

1//! Compressed Index Serialization (GH-2)
2//!
3//! Provides LZ4/ZSTD compression for BM25 and vector index storage.
4//! Reduces storage footprint by 5-10x for typical RAG indices.
5//!
6//! Compression algorithm is shared via `batuta_common::compression`.
7
8use crate::{BM25Index, Result};
9use serde::{de::DeserializeOwned, Serialize};
10
11// Note: VectorStore compression can be added in the future
12// by implementing Serialize/Deserialize for VectorStore
13
14pub use batuta_common::compression::Compression;
15
16/// Serialize an index to compressed bytes
17///
18/// # Errors
19/// Returns error if serialization or compression fails
20pub fn serialize_compressed<T: Serialize>(index: &T, compression: Compression) -> Result<Vec<u8>> {
21    let bytes = bincode::serialize(index).map_err(|e| {
22        crate::Error::SerializationError(format!("Bincode serialization failed: {e}"))
23    })?;
24    Ok(compression.compress(&bytes)?)
25}
26
27/// Deserialize an index from compressed bytes
28///
29/// # Errors
30/// Returns error if decompression or deserialization fails
31pub fn deserialize_compressed<T: DeserializeOwned>(
32    data: &[u8],
33    compression: Compression,
34) -> Result<T> {
35    let decompressed = compression.decompress(data)?;
36    bincode::deserialize(&decompressed).map_err(|e| {
37        crate::Error::SerializationError(format!("Bincode deserialization failed: {e}"))
38    })
39}
40
41impl BM25Index {
42    /// Serialize to compressed bytes using specified compression
43    ///
44    /// # Errors
45    /// Returns error if serialization or compression fails
46    pub fn to_compressed_bytes(&self, compression: Compression) -> Result<Vec<u8>> {
47        serialize_compressed(self, compression)
48    }
49
50    /// Deserialize from compressed bytes
51    ///
52    /// # Errors
53    /// Returns error if decompression or deserialization fails
54    pub fn from_compressed_bytes(data: &[u8], compression: Compression) -> Result<Self> {
55        deserialize_compressed(data, compression)
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{index::SparseIndex, Chunk, DocumentId};

    /// Build a chunk for `content` under a fresh `DocumentId`, passing `0` and
    /// `content.len()` as the trailing arguments (presumably the start/end
    /// offsets of the chunk; confirm against `Chunk::new`).
    fn create_test_chunk(content: &str) -> Chunk {
        let text = content.to_string();
        Chunk::new(DocumentId::new(), text, 0, content.len())
    }

    /// Build a BM25 index containing one chunk per text, added in order.
    fn index_of<S: AsRef<str>>(texts: impl IntoIterator<Item = S>) -> BM25Index {
        let mut built = BM25Index::new();
        for text in texts {
            built.add(&create_test_chunk(text.as_ref()));
        }
        built
    }

    // ------------------------------------------------------------
    // Compression codec tests (exercise the re-exported `Compression`)
    // ------------------------------------------------------------

    #[test]
    fn test_compression_as_str() {
        let lz4_name = Compression::Lz4.as_str();
        let zstd_name = Compression::Zstd.as_str();
        assert_eq!(lz4_name, "lz4");
        assert_eq!(zstd_name, "zstd");
    }

    #[test]
    fn test_compression_default() {
        let chosen = Compression::default();
        assert_eq!(chosen, Compression::Lz4);
    }

    #[test]
    fn test_lz4_compress_decompress() {
        let payload = b"hello world hello world hello world".to_vec();
        let packed = Compression::Lz4.compress(&payload).unwrap();
        let unpacked = Compression::Lz4.decompress(&packed).unwrap();
        assert_eq!(unpacked, payload);
    }

    #[test]
    fn test_zstd_compress_decompress() {
        let payload = b"hello world hello world hello world".to_vec();
        let packed = Compression::Zstd.compress(&payload).unwrap();
        let unpacked = Compression::Zstd.decompress(&packed).unwrap();
        assert_eq!(unpacked, payload);
    }

    #[test]
    fn test_empty_data_compression() {
        let nothing: Vec<u8> = Vec::new();

        // LZ4: empty in, empty out, in both directions.
        let lz4_packed = Compression::Lz4.compress(&nothing).unwrap();
        assert!(lz4_packed.is_empty());
        let lz4_unpacked = Compression::Lz4.decompress(&lz4_packed).unwrap();
        assert!(lz4_unpacked.is_empty());

        // ZSTD: same expectation.
        let zstd_packed = Compression::Zstd.compress(&nothing).unwrap();
        assert!(zstd_packed.is_empty());
        let zstd_unpacked = Compression::Zstd.decompress(&zstd_packed).unwrap();
        assert!(zstd_unpacked.is_empty());
    }

    #[test]
    fn test_lz4_compresses_repeated_data() {
        let zeros = vec![0u8; 10000];
        let packed = Compression::Lz4.compress(&zeros).unwrap();
        // A run of zeros must shrink to less than a tenth of its size.
        let ceiling = zeros.len() / 10;
        assert!(packed.len() < ceiling);
    }

    #[test]
    fn test_zstd_compresses_repeated_data() {
        let zeros = vec![0u8; 10000];
        let packed = Compression::Zstd.compress(&zeros).unwrap();
        // A run of zeros must shrink to less than a tenth of its size.
        let ceiling = zeros.len() / 10;
        assert!(packed.len() < ceiling);
    }

    // ------------------------------------------------------------
    // BM25Index round-trip tests
    // ------------------------------------------------------------

    #[test]
    fn test_bm25_lz4_roundtrip() {
        let index = index_of([
            "machine learning is great",
            "deep learning neural networks",
            "natural language processing",
        ]);

        let packed = index.to_compressed_bytes(Compression::Lz4).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Lz4).unwrap();

        // Same number of entries, and the same number of hits for a query.
        assert_eq!(index.len(), restored.len());
        let hits_before = index.search("machine learning", 10);
        let hits_after = restored.search("machine learning", 10);
        assert_eq!(hits_before.len(), hits_after.len());
    }

    #[test]
    fn test_bm25_zstd_roundtrip() {
        let index = index_of(["rust programming language", "systems programming with rust"]);

        let packed = index.to_compressed_bytes(Compression::Zstd).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Zstd).unwrap();

        assert_eq!(index.len(), restored.len());
    }

    #[test]
    fn test_bm25_compression_reduces_size() {
        // One hundred near-identical documents give the codecs something to squeeze.
        let index = index_of((0..100).map(|i| {
            format!("document number {i} about machine learning and artificial intelligence")
        }));

        let plain = bincode::serialize(&index).unwrap();
        let lz4_packed = index.to_compressed_bytes(Compression::Lz4).unwrap();
        let zstd_packed = index.to_compressed_bytes(Compression::Zstd).unwrap();

        // Each codec must beat the plain bincode encoding.
        assert!(lz4_packed.len() < plain.len());
        assert!(zstd_packed.len() < plain.len());

        // ZSTD is expected to be at least as compact as LZ4 here.
        assert!(zstd_packed.len() <= lz4_packed.len());
    }

    #[test]
    fn test_bm25_empty_index_compression() {
        let index = BM25Index::new();

        let packed = index.to_compressed_bytes(Compression::Lz4).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Lz4).unwrap();

        assert!(restored.is_empty());
    }

    #[test]
    fn test_bm25_preserved_search_behavior() {
        let index = index_of([
            "python programming language scripting",
            "javascript web development frontend",
            "rust systems programming performance",
        ]);

        // Round-trip through LZ4.
        let packed = index.to_compressed_bytes(Compression::Lz4).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Lz4).unwrap();

        // The same query must give the same ranked ids with matching scores.
        let query = "programming language";
        let hits_before = index.search(query, 3);
        let hits_after = restored.search(query, 3);

        assert_eq!(hits_before.len(), hits_after.len());
        for (before, after) in hits_before.iter().zip(hits_after.iter()) {
            assert_eq!(before.0, after.0);
            // Scores are floats; compare with a small tolerance.
            assert!((before.1 - after.1).abs() < 1e-5);
        }
    }
}