1use std::fs;
14use std::io::{Read as _, Write as _};
15use std::path::Path;
16
17use crate::QuantError;
18use crate::pack;
19
/// Four-byte signature written at the very start of every compressed-store file.
const MAGIC: &[u8; 4] = &[b'L', b'M', b'C', b'Q'];
/// On-disk format version emitted by `save`; `load` rejects any other value.
const FORMAT_VERSION: u8 = 0x01;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
25pub enum QuantType {
26 Mse = 0,
27 Prod = 1,
28}
29
/// One compressed embedding, keyed by the file it was computed from.
#[derive(Debug, Clone)]
pub struct CompressedEntry {
    /// Identifier of the source file; used as the lookup key by the store.
    pub file: String,
    /// Caller-supplied hash stored alongside the entry (written as `u64`).
    pub hash: u64,
    /// Norm stored alongside the quantized indices (written as `f32`).
    pub norm: f32,
    /// Bit-packed quantization indices.
    pub packed_indices: Vec<u8>,
    /// Residual norm; only serialized for `QuantType::Prod` stores.
    pub residual_norm: Option<f32>,
    /// QJL bits (one per coordinate, packed); only serialized for
    /// `QuantType::Prod` stores.
    pub qjl_bits: Option<Vec<u8>>,
}
42
43#[derive(Debug, Clone)]
45pub struct CompressedEmbeddingStore {
46 pub dimension: usize,
47 pub bits: u8,
48 pub quant_type: QuantType,
49 pub rotation_seed: u64,
50 pub qjl_seed: Option<u64>,
51 pub entries: Vec<CompressedEntry>,
52}
53
54impl CompressedEmbeddingStore {
55 pub fn new(
57 dimension: usize,
58 bits: u8,
59 quant_type: QuantType,
60 rotation_seed: u64,
61 qjl_seed: Option<u64>,
62 ) -> Self {
63 Self {
64 dimension,
65 bits,
66 quant_type,
67 rotation_seed,
68 qjl_seed,
69 entries: Vec::new(),
70 }
71 }
72
73 fn mse_bits(&self) -> u8 {
75 match self.quant_type {
76 QuantType::Mse => self.bits,
77 QuantType::Prod => self.bits - 1,
78 }
79 }
80
81 pub fn save(&self, path: &Path) -> Result<(), QuantError> {
83 let mut buf = Vec::new();
84
85 buf.write_all(MAGIC)?;
87 buf.write_all(&[FORMAT_VERSION])?;
88 buf.write_all(&(self.dimension as u32).to_le_bytes())?;
89 buf.write_all(&(self.entries.len() as u32).to_le_bytes())?;
90 buf.write_all(&[self.bits])?;
91 buf.write_all(&[self.quant_type as u8])?;
92 buf.write_all(&self.rotation_seed.to_le_bytes())?;
93 if self.quant_type == QuantType::Prod {
94 buf.write_all(&self.qjl_seed.unwrap_or(0).to_le_bytes())?;
95 }
96
97 let mse_bits = self.mse_bits();
98 let indices_size = pack::packed_byte_size(self.dimension, mse_bits);
99 let qjl_size = self.dimension.div_ceil(8);
100
101 for entry in &self.entries {
103 let file_bytes = entry.file.as_bytes();
104 buf.write_all(&(file_bytes.len() as u16).to_le_bytes())?;
105 buf.write_all(file_bytes)?;
106 buf.write_all(&entry.hash.to_le_bytes())?;
107 buf.write_all(&entry.norm.to_le_bytes())?;
108
109 assert_eq!(entry.packed_indices.len(), indices_size);
110 buf.write_all(&entry.packed_indices)?;
111
112 if self.quant_type == QuantType::Prod {
113 let rn = entry.residual_norm.unwrap_or(0.0);
114 buf.write_all(&rn.to_le_bytes())?;
115
116 let default_qjl = vec![0u8; qjl_size];
117 let qjl = entry.qjl_bits.as_deref().unwrap_or(&default_qjl);
118 assert_eq!(qjl.len(), qjl_size);
119 buf.write_all(qjl)?;
120 }
121 }
122
123 if let Some(parent) = path.parent() {
124 fs::create_dir_all(parent)?;
125 }
126 fs::write(path, buf)?;
127 Ok(())
128 }
129
130 pub fn load(path: &Path) -> Result<Self, QuantError> {
132 let data = fs::read(path)?;
133 let mut cursor = &data[..];
134
135 let mut magic = [0u8; 4];
137 cursor.read_exact(&mut magic)?;
138 if &magic != MAGIC {
139 return Err(QuantError::Format("invalid magic bytes".into()));
140 }
141
142 let mut ver = [0u8; 1];
144 cursor.read_exact(&mut ver)?;
145 if ver[0] != FORMAT_VERSION {
146 return Err(QuantError::Format(format!(
147 "unsupported version: {}",
148 ver[0]
149 )));
150 }
151
152 let mut dim_bytes = [0u8; 4];
154 cursor.read_exact(&mut dim_bytes)?;
155 let dimension = u32::from_le_bytes(dim_bytes) as usize;
156
157 let mut count_bytes = [0u8; 4];
159 cursor.read_exact(&mut count_bytes)?;
160 let count = u32::from_le_bytes(count_bytes) as usize;
161
162 let mut bits_byte = [0u8; 1];
164 cursor.read_exact(&mut bits_byte)?;
165 let bits = bits_byte[0];
166
167 let mut qt_byte = [0u8; 1];
169 cursor.read_exact(&mut qt_byte)?;
170 let quant_type = match qt_byte[0] {
171 0 => QuantType::Mse,
172 1 => QuantType::Prod,
173 v => return Err(QuantError::Format(format!("unknown quant type: {v}"))),
174 };
175
176 let mut seed_bytes = [0u8; 8];
178 cursor.read_exact(&mut seed_bytes)?;
179 let rotation_seed = u64::from_le_bytes(seed_bytes);
180
181 let qjl_seed = if quant_type == QuantType::Prod {
183 let mut qjl_seed_bytes = [0u8; 8];
184 cursor.read_exact(&mut qjl_seed_bytes)?;
185 Some(u64::from_le_bytes(qjl_seed_bytes))
186 } else {
187 None
188 };
189
190 let mse_bits = match quant_type {
191 QuantType::Mse => bits,
192 QuantType::Prod => bits - 1,
193 };
194 let indices_size = pack::packed_byte_size(dimension, mse_bits);
195 let qjl_size = dimension.div_ceil(8);
196
197 let mut entries = Vec::with_capacity(count);
199 for _ in 0..count {
200 let mut file_len_bytes = [0u8; 2];
202 cursor.read_exact(&mut file_len_bytes)?;
203 let file_len = u16::from_le_bytes(file_len_bytes) as usize;
204 let mut file_bytes = vec![0u8; file_len];
205 cursor.read_exact(&mut file_bytes)?;
206 let file = String::from_utf8(file_bytes)
207 .map_err(|e| QuantError::Format(format!("invalid UTF-8: {e}")))?;
208
209 let mut hash_bytes = [0u8; 8];
211 cursor.read_exact(&mut hash_bytes)?;
212 let hash = u64::from_le_bytes(hash_bytes);
213
214 let mut norm_bytes = [0u8; 4];
216 cursor.read_exact(&mut norm_bytes)?;
217 let norm = f32::from_le_bytes(norm_bytes);
218
219 let mut packed_indices = vec![0u8; indices_size];
221 cursor.read_exact(&mut packed_indices)?;
222
223 let (residual_norm, qjl_bits) = if quant_type == QuantType::Prod {
225 let mut rn_bytes = [0u8; 4];
226 cursor.read_exact(&mut rn_bytes)?;
227 let rn = f32::from_le_bytes(rn_bytes);
228
229 let mut qjl = vec![0u8; qjl_size];
230 cursor.read_exact(&mut qjl)?;
231
232 (Some(rn), Some(qjl))
233 } else {
234 (None, None)
235 };
236
237 entries.push(CompressedEntry {
238 file,
239 hash,
240 norm,
241 packed_indices,
242 residual_norm,
243 qjl_bits,
244 });
245 }
246
247 Ok(Self {
248 dimension,
249 bits,
250 quant_type,
251 rotation_seed,
252 qjl_seed,
253 entries,
254 })
255 }
256
257 pub fn get(&self, file: &str) -> Option<&CompressedEntry> {
259 self.entries.iter().find(|e| e.file == file)
260 }
261
262 pub fn upsert(&mut self, entry: CompressedEntry) {
264 if let Some(existing) = self.entries.iter_mut().find(|e| e.file == entry.file) {
265 *existing = entry;
266 } else {
267 self.entries.push(entry);
268 }
269 }
270
271 pub fn remove(&mut self, file: &str) -> bool {
273 let len = self.entries.len();
274 self.entries.retain(|e| e.file != file);
275 self.entries.len() < len
276 }
277
278 pub fn data_size(&self) -> usize {
280 let mse_bits = self.mse_bits();
281 let indices_size = pack::packed_byte_size(self.dimension, mse_bits);
282 let per_entry = 2 + 8 + 4 + indices_size + if self.quant_type == QuantType::Prod {
284 4 + self.dimension.div_ceil(8) } else {
286 0
287 };
288 self.entries.len() * per_entry
289 }
290
291 pub fn uncompressed_size(&self) -> usize {
293 self.entries.len() * self.dimension * 4
294 }
295
296 pub fn compression_ratio(&self) -> f32 {
298 let compressed = self.data_size();
299 if compressed == 0 {
300 return 0.0;
301 }
302 self.uncompressed_size() as f32 / compressed as f32
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-entry store (dimension 64, 2 bits) with deterministic
    /// payloads; `Prod` entries also carry residual norms and QJL bits.
    fn make_test_store(quant_type: QuantType) -> CompressedEmbeddingStore {
        let dim = 64;
        let bits: u8 = 2;
        let is_prod = quant_type == QuantType::Prod;
        let mse_bits = match quant_type {
            QuantType::Mse => bits,
            QuantType::Prod => bits - 1,
        };
        let indices_size = pack::packed_byte_size(dim, mse_bits);
        // One QJL bit per coordinate, sized the same way as the store itself.
        let qjl_size = dim.div_ceil(8);

        let mut store =
            CompressedEmbeddingStore::new(dim, bits, quant_type, 42, is_prod.then_some(99));

        store.upsert(CompressedEntry {
            file: "feedback_test.md".into(),
            hash: 12345,
            norm: 1.5,
            packed_indices: vec![0xAB; indices_size],
            residual_norm: is_prod.then_some(0.3),
            qjl_bits: is_prod.then(|| vec![0xCD; qjl_size]),
        });

        store.upsert(CompressedEntry {
            file: "user_prefs.md".into(),
            hash: 67890,
            norm: 2.0,
            packed_indices: vec![0x12; indices_size],
            residual_norm: is_prod.then_some(0.1),
            qjl_bits: is_prod.then(|| vec![0x34; qjl_size]),
        });

        store
    }

    /// Saves `store` into a fresh temp dir and loads it back.
    fn roundtrip(store: &CompressedEmbeddingStore) -> CompressedEmbeddingStore {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("test.lmcq");
        store.save(&path).unwrap();
        CompressedEmbeddingStore::load(&path).unwrap()
    }

    #[test]
    fn mse_store_roundtrip() {
        let store = make_test_store(QuantType::Mse);
        let loaded = roundtrip(&store);

        assert_eq!(loaded.dimension, store.dimension);
        assert_eq!(loaded.bits, store.bits);
        assert_eq!(loaded.quant_type, QuantType::Mse);
        assert_eq!(loaded.rotation_seed, 42);
        assert_eq!(loaded.qjl_seed, None);
        assert_eq!(loaded.entries.len(), 2);
        assert_eq!(loaded.entries[0].file, "feedback_test.md");
        assert_eq!(loaded.entries[0].norm, 1.5);
        assert_eq!(
            loaded.entries[0].packed_indices,
            store.entries[0].packed_indices
        );
        assert_eq!(loaded.entries[1].hash, 67890);
        assert!(loaded.entries[1].residual_norm.is_none());
        assert!(loaded.entries[1].qjl_bits.is_none());
    }

    #[test]
    fn prod_store_roundtrip() {
        let store = make_test_store(QuantType::Prod);
        let loaded = roundtrip(&store);

        assert_eq!(loaded.quant_type, QuantType::Prod);
        assert_eq!(loaded.qjl_seed, Some(99));
        assert_eq!(loaded.entries.len(), 2);
        for (got, want) in loaded.entries.iter().zip(&store.entries) {
            assert_eq!(got.file, want.file);
            assert_eq!(got.hash, want.hash);
            assert_eq!(got.packed_indices, want.packed_indices);
            assert_eq!(got.residual_norm, want.residual_norm);
            assert_eq!(got.qjl_bits, want.qjl_bits);
        }
        assert_eq!(loaded.entries[0].residual_norm, Some(0.3));
    }

    #[test]
    fn load_rejects_bad_magic() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("bad.lmcq");
        std::fs::write(&path, b"NOPE\x01").unwrap();
        assert!(matches!(
            CompressedEmbeddingStore::load(&path),
            Err(QuantError::Format(_))
        ));
    }

    #[test]
    fn load_rejects_truncated_file() {
        let store = make_test_store(QuantType::Prod);
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("trunc.lmcq");
        store.save(&path).unwrap();

        let mut bytes = std::fs::read(&path).unwrap();
        bytes.pop();
        std::fs::write(&path, &bytes).unwrap();

        assert!(CompressedEmbeddingStore::load(&path).is_err());
    }

    #[test]
    fn compression_ratio_positive() {
        let store = make_test_store(QuantType::Mse);
        let ratio = store.compression_ratio();
        assert!(ratio > 1.0, "compression ratio should be > 1, got {ratio}");
    }

    #[test]
    fn compression_ratio_empty_is_zero() {
        let store = CompressedEmbeddingStore::new(64, 2, QuantType::Mse, 0, None);
        assert_eq!(store.data_size(), 0);
        assert_eq!(store.compression_ratio(), 0.0);
    }

    #[test]
    fn upsert_replaces() {
        let mut store = make_test_store(QuantType::Mse);
        let old_hash = store.entries[0].hash;

        store.upsert(CompressedEntry {
            file: "feedback_test.md".into(),
            hash: 99999,
            norm: 3.0,
            packed_indices: store.entries[0].packed_indices.clone(),
            residual_norm: None,
            qjl_bits: None,
        });

        assert_eq!(store.entries.len(), 2);
        let replaced = store.get("feedback_test.md").unwrap();
        assert_ne!(replaced.hash, old_hash);
        assert_eq!(replaced.hash, 99999);
    }

    #[test]
    fn get_missing_is_none() {
        let store = make_test_store(QuantType::Mse);
        assert!(store.get("nonexistent.md").is_none());
    }

    #[test]
    fn remove_entry() {
        let mut store = make_test_store(QuantType::Mse);
        assert!(store.remove("feedback_test.md"));
        assert_eq!(store.entries.len(), 1);
        assert!(!store.remove("nonexistent.md"));
    }
}
438}