Skip to main content

oxirs_chat/memory_optimization/
compression.rs

1//! Compression for cached embeddings and model data
2
3use anyhow::{anyhow, Result};
4use serde::{Deserialize, Serialize};
5use std::io::{Read, Write};
6
7/// Compression algorithm
8#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
9pub enum CompressionAlgorithm {
10    /// Zstandard (fast, good compression)
11    #[default]
12    Zstd,
13    /// LZ4 (very fast, moderate compression)
14    Lz4,
15    /// Gzip (slower, better compression)
16    Gzip,
17    /// No compression
18    None,
19}
20
21/// Compressor for data
22pub struct Compressor {
23    algorithm: CompressionAlgorithm,
24}
25
26impl Compressor {
27    pub fn new(algorithm: CompressionAlgorithm) -> Self {
28        Self { algorithm }
29    }
30
31    /// Compress data
32    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
33        match self.algorithm {
34            CompressionAlgorithm::Zstd => self.compress_zstd(data),
35            CompressionAlgorithm::Lz4 => self.compress_lz4(data),
36            CompressionAlgorithm::Gzip => self.compress_gzip(data),
37            CompressionAlgorithm::None => Ok(data.to_vec()),
38        }
39    }
40
41    /// Decompress data
42    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
43        match self.algorithm {
44            CompressionAlgorithm::Zstd => self.decompress_zstd(data),
45            CompressionAlgorithm::Lz4 => self.decompress_lz4(data),
46            CompressionAlgorithm::Gzip => self.decompress_gzip(data),
47            CompressionAlgorithm::None => Ok(data.to_vec()),
48        }
49    }
50
51    fn compress_zstd(&self, data: &[u8]) -> Result<Vec<u8>> {
52        oxiarc_zstd::encode_all(data, 3).map_err(|e| anyhow!("Zstd compression failed: {}", e))
53    }
54
55    fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>> {
56        oxiarc_zstd::decode_all(data).map_err(|e| anyhow!("Zstd decompression failed: {}", e))
57    }
58
59    fn compress_lz4(&self, data: &[u8]) -> Result<Vec<u8>> {
60        oxiarc_lz4::compress(data).map_err(|e| anyhow!("LZ4 compression failed: {}", e))
61    }
62
63    fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>> {
64        oxiarc_lz4::decompress(data, 100 * 1024 * 1024)
65            .map_err(|e| anyhow!("LZ4 decompression failed: {}", e))
66    }
67
68    fn compress_gzip(&self, data: &[u8]) -> Result<Vec<u8>> {
69        use flate2::write::GzEncoder;
70        use flate2::Compression;
71
72        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
73        encoder
74            .write_all(data)
75            .map_err(|e| anyhow!("Gzip compression failed: {}", e))?;
76        encoder
77            .finish()
78            .map_err(|e| anyhow!("Gzip compression failed: {}", e))
79    }
80
81    fn decompress_gzip(&self, data: &[u8]) -> Result<Vec<u8>> {
82        use flate2::read::GzDecoder;
83
84        let mut decoder = GzDecoder::new(data);
85        let mut decompressed = Vec::new();
86        decoder
87            .read_to_end(&mut decompressed)
88            .map_err(|e| anyhow!("Gzip decompression failed: {}", e))?;
89        Ok(decompressed)
90    }
91
92    /// Calculate compression ratio
93    pub fn compression_ratio(&self, original: &[u8], compressed: &[u8]) -> f64 {
94        if compressed.is_empty() {
95            return 0.0;
96        }
97        original.len() as f64 / compressed.len() as f64
98    }
99}
100
101/// Compressed embedding storage
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct CompressedEmbedding {
104    pub compressed_data: Vec<u8>,
105    pub original_size: usize,
106    pub algorithm: CompressionAlgorithm,
107}
108
109impl CompressedEmbedding {
110    /// Compress embeddings (f32 array)
111    pub fn from_embeddings(embeddings: &[f32], algorithm: CompressionAlgorithm) -> Result<Self> {
112        let compressor = Compressor::new(algorithm);
113
114        // Convert f32 to bytes
115        let bytes: Vec<u8> = embeddings.iter().flat_map(|f| f.to_le_bytes()).collect();
116
117        let compressed_data = compressor.compress(&bytes)?;
118
119        Ok(Self {
120            compressed_data,
121            original_size: bytes.len(),
122            algorithm,
123        })
124    }
125
126    /// Decompress to embeddings
127    pub fn to_embeddings(&self) -> Result<Vec<f32>> {
128        let compressor = Compressor::new(self.algorithm);
129        let bytes = compressor.decompress(&self.compressed_data)?;
130
131        // Convert bytes back to f32
132        let embeddings: Vec<f32> = bytes
133            .chunks_exact(4)
134            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
135            .collect();
136
137        Ok(embeddings)
138    }
139
140    /// Get compression ratio
141    pub fn compression_ratio(&self) -> f64 {
142        self.original_size as f64 / self.compressed_data.len() as f64
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Highly repetitive input: 1000 copies of the same byte.
    fn repetitive_payload() -> Vec<u8> {
        vec![42u8; 1000]
    }

    /// Compresses the repetitive payload with `algorithm`, checks that the
    /// output is smaller than the input, and checks that decompression
    /// restores the input exactly.
    fn assert_shrinks_and_round_trips(algorithm: CompressionAlgorithm) {
        let compressor = Compressor::new(algorithm);
        let input = repetitive_payload();

        let packed = compressor.compress(&input).expect("should succeed");
        assert!(packed.len() < input.len());

        let unpacked = compressor.decompress(&packed).expect("should succeed");
        assert_eq!(unpacked, input);
    }

    #[test]
    fn test_zstd_compression() {
        assert_shrinks_and_round_trips(CompressionAlgorithm::Zstd);
    }

    #[test]
    fn test_lz4_compression() {
        assert_shrinks_and_round_trips(CompressionAlgorithm::Lz4);
    }

    #[test]
    fn test_gzip_compression() {
        assert_shrinks_and_round_trips(CompressionAlgorithm::Gzip);
    }

    #[test]
    fn test_compressed_embedding() {
        let source = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let stored = CompressedEmbedding::from_embeddings(&source, CompressionAlgorithm::Zstd)
            .expect("should succeed");
        let restored = stored.to_embeddings().expect("should succeed");

        assert_eq!(restored.len(), source.len());
        for (expected, actual) in source.iter().zip(restored.iter()) {
            assert!((expected - actual).abs() < 0.001);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let compressor = Compressor::new(CompressionAlgorithm::Zstd);
        let input = repetitive_payload();

        let packed = compressor.compress(&input).expect("should succeed");

        // Repetitive input should end up smaller than it started.
        assert!(compressor.compression_ratio(&input, &packed) > 1.0);
    }

    #[test]
    fn test_no_compression() {
        let compressor = Compressor::new(CompressionAlgorithm::None);
        let input: Vec<u8> = vec![1, 2, 3, 4, 5];

        // The pass-through algorithm must leave bytes untouched both ways.
        let packed = compressor.compress(&input).expect("should succeed");
        assert_eq!(packed, input);

        let unpacked = compressor.decompress(&packed).expect("should succeed");
        assert_eq!(unpacked, input);
    }
}