oxirs_chat/memory_optimization/
compression.rs1use anyhow::{anyhow, Result};
4use serde::{Deserialize, Serialize};
5use std::io::{Read, Write};
6
7#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
9pub enum CompressionAlgorithm {
10 #[default]
12 Zstd,
13 Lz4,
15 Gzip,
17 None,
19}
20
21pub struct Compressor {
23 algorithm: CompressionAlgorithm,
24}
25
26impl Compressor {
27 pub fn new(algorithm: CompressionAlgorithm) -> Self {
28 Self { algorithm }
29 }
30
31 pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
33 match self.algorithm {
34 CompressionAlgorithm::Zstd => self.compress_zstd(data),
35 CompressionAlgorithm::Lz4 => self.compress_lz4(data),
36 CompressionAlgorithm::Gzip => self.compress_gzip(data),
37 CompressionAlgorithm::None => Ok(data.to_vec()),
38 }
39 }
40
41 pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
43 match self.algorithm {
44 CompressionAlgorithm::Zstd => self.decompress_zstd(data),
45 CompressionAlgorithm::Lz4 => self.decompress_lz4(data),
46 CompressionAlgorithm::Gzip => self.decompress_gzip(data),
47 CompressionAlgorithm::None => Ok(data.to_vec()),
48 }
49 }
50
51 fn compress_zstd(&self, data: &[u8]) -> Result<Vec<u8>> {
52 oxiarc_zstd::encode_all(data, 3).map_err(|e| anyhow!("Zstd compression failed: {}", e))
53 }
54
55 fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>> {
56 oxiarc_zstd::decode_all(data).map_err(|e| anyhow!("Zstd decompression failed: {}", e))
57 }
58
59 fn compress_lz4(&self, data: &[u8]) -> Result<Vec<u8>> {
60 oxiarc_lz4::compress(data).map_err(|e| anyhow!("LZ4 compression failed: {}", e))
61 }
62
63 fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>> {
64 oxiarc_lz4::decompress(data, 100 * 1024 * 1024)
65 .map_err(|e| anyhow!("LZ4 decompression failed: {}", e))
66 }
67
68 fn compress_gzip(&self, data: &[u8]) -> Result<Vec<u8>> {
69 use flate2::write::GzEncoder;
70 use flate2::Compression;
71
72 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
73 encoder
74 .write_all(data)
75 .map_err(|e| anyhow!("Gzip compression failed: {}", e))?;
76 encoder
77 .finish()
78 .map_err(|e| anyhow!("Gzip compression failed: {}", e))
79 }
80
81 fn decompress_gzip(&self, data: &[u8]) -> Result<Vec<u8>> {
82 use flate2::read::GzDecoder;
83
84 let mut decoder = GzDecoder::new(data);
85 let mut decompressed = Vec::new();
86 decoder
87 .read_to_end(&mut decompressed)
88 .map_err(|e| anyhow!("Gzip decompression failed: {}", e))?;
89 Ok(decompressed)
90 }
91
92 pub fn compression_ratio(&self, original: &[u8], compressed: &[u8]) -> f64 {
94 if compressed.is_empty() {
95 return 0.0;
96 }
97 original.len() as f64 / compressed.len() as f64
98 }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct CompressedEmbedding {
104 pub compressed_data: Vec<u8>,
105 pub original_size: usize,
106 pub algorithm: CompressionAlgorithm,
107}
108
109impl CompressedEmbedding {
110 pub fn from_embeddings(embeddings: &[f32], algorithm: CompressionAlgorithm) -> Result<Self> {
112 let compressor = Compressor::new(algorithm);
113
114 let bytes: Vec<u8> = embeddings.iter().flat_map(|f| f.to_le_bytes()).collect();
116
117 let compressed_data = compressor.compress(&bytes)?;
118
119 Ok(Self {
120 compressed_data,
121 original_size: bytes.len(),
122 algorithm,
123 })
124 }
125
126 pub fn to_embeddings(&self) -> Result<Vec<f32>> {
128 let compressor = Compressor::new(self.algorithm);
129 let bytes = compressor.decompress(&self.compressed_data)?;
130
131 let embeddings: Vec<f32> = bytes
133 .chunks_exact(4)
134 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
135 .collect();
136
137 Ok(embeddings)
138 }
139
140 pub fn compression_ratio(&self) -> f64 {
142 self.original_size as f64 / self.compressed_data.len() as f64
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses 1000 identical bytes with `algorithm`, checks that the output
    /// is smaller than the input, and checks that decompression restores it.
    fn assert_shrinking_round_trip(algorithm: CompressionAlgorithm) {
        let compressor = Compressor::new(algorithm);
        let input = vec![42u8; 1000];

        let packed = compressor.compress(&input).expect("should succeed");
        assert!(packed.len() < input.len());

        let restored = compressor.decompress(&packed).expect("should succeed");
        assert_eq!(restored, input);
    }

    #[test]
    fn test_zstd_compression() {
        assert_shrinking_round_trip(CompressionAlgorithm::Zstd);
    }

    #[test]
    fn test_lz4_compression() {
        assert_shrinking_round_trip(CompressionAlgorithm::Lz4);
    }

    #[test]
    fn test_gzip_compression() {
        assert_shrinking_round_trip(CompressionAlgorithm::Gzip);
    }

    /// Embeddings survive a compress/decompress cycle through Zstd.
    #[test]
    fn test_compressed_embedding() {
        let original = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let packed = CompressedEmbedding::from_embeddings(&original, CompressionAlgorithm::Zstd)
            .expect("should succeed");
        let restored = packed.to_embeddings().expect("should succeed");

        assert_eq!(restored.len(), original.len());
        assert!(original
            .iter()
            .zip(restored.iter())
            .all(|(a, b)| (a - b).abs() < 0.001));
    }

    /// Highly repetitive input must compress to a ratio above 1.
    #[test]
    fn test_compression_ratio() {
        let compressor = Compressor::new(CompressionAlgorithm::Zstd);
        let input = vec![42u8; 1000];

        let packed = compressor.compress(&input).expect("should succeed");

        assert!(compressor.compression_ratio(&input, &packed) > 1.0);
    }

    /// The `None` algorithm passes bytes through unchanged in both directions.
    #[test]
    fn test_no_compression() {
        let compressor = Compressor::new(CompressionAlgorithm::None);
        let input = vec![1u8, 2, 3, 4, 5];

        let packed = compressor.compress(&input).expect("should succeed");
        assert_eq!(packed, input);

        let restored = compressor.decompress(&packed).expect("should succeed");
        assert_eq!(restored, input);
    }
}