Skip to main content

mnemonist_quant/
store.rs

//! Compressed embedding store using TurboQuant quantization.
//!
//! Binary format (LMCQ). All multi-byte integers and floats are little-endian:
//! ```text
//! Header: [magic: 4B "LMCQ"] [version: u8] [dimension: u32] [count: u32]
//!         [bits: u8] [quant_type: u8 (0=mse, 1=prod)] [rotation_seed: u64]
//!         [qjl_seed: u64 (only if prod)]
//! Entry:  [file_len: u16] [file: file_len bytes of UTF-8] [hash: u64]
//!         [norm: f32] [packed_indices: ceil(dim * bits_mse / 8) bytes]
//!         (if prod: [residual_norm: f32] [qjl_bits: ceil(dim / 8) bytes])
//! ```
//!
//! `bits_mse` is `bits` for mse and `bits - 1` for prod.
12
13use std::fs;
14use std::io::{Read as _, Write as _};
15use std::path::Path;
16
17use crate::QuantError;
18use crate::pack;
19
/// Four-byte signature at the very start of every LMCQ file.
const MAGIC: &[u8; 4] = b"LMCQ";

/// On-disk format revision written by `save`; `load` rejects any other value.
const FORMAT_VERSION: u8 = 1;
22
23/// Type of quantization used.
24#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
25pub enum QuantType {
26    Mse = 0,
27    Prod = 1,
28}
29
/// One compressed embedding, keyed by the file it was computed from.
#[derive(Clone, Debug)]
pub struct CompressedEntry {
    /// Lookup key; stored on disk as UTF-8 behind a `u16` length prefix.
    pub file: String,
    /// Opaque 64-bit hash persisted alongside the entry
    /// (presumably a hash of the file contents; this module never interprets it).
    pub hash: u64,
    /// Norm value persisted with the entry
    /// (presumably the norm of the original embedding; confirm against the quantizer).
    pub norm: f32,
    /// Bit-packed quantization indices; `save` expects exactly
    /// `pack::packed_byte_size(dimension, mse_bits)` bytes.
    pub packed_indices: Vec<u8>,
    /// Only present for QuantType::Prod.
    pub residual_norm: Option<f32>,
    /// Only present for QuantType::Prod; `ceil(dimension / 8)` bytes when set.
    pub qjl_bits: Option<Vec<u8>>,
}
42
43/// Compressed embedding store using TurboQuant.
44#[derive(Debug, Clone)]
45pub struct CompressedEmbeddingStore {
46    pub dimension: usize,
47    pub bits: u8,
48    pub quant_type: QuantType,
49    pub rotation_seed: u64,
50    pub qjl_seed: Option<u64>,
51    pub entries: Vec<CompressedEntry>,
52}
53
54impl CompressedEmbeddingStore {
55    /// Create a new empty store.
56    pub fn new(
57        dimension: usize,
58        bits: u8,
59        quant_type: QuantType,
60        rotation_seed: u64,
61        qjl_seed: Option<u64>,
62    ) -> Self {
63        Self {
64            dimension,
65            bits,
66            quant_type,
67            rotation_seed,
68            qjl_seed,
69            entries: Vec::new(),
70        }
71    }
72
73    /// The effective MSE bit-width (bits for MSE, bits-1 for Prod).
74    fn mse_bits(&self) -> u8 {
75        match self.quant_type {
76            QuantType::Mse => self.bits,
77            QuantType::Prod => self.bits - 1,
78        }
79    }
80
81    /// Save to a binary file.
82    pub fn save(&self, path: &Path) -> Result<(), QuantError> {
83        let mut buf = Vec::new();
84
85        // Header
86        buf.write_all(MAGIC)?;
87        buf.write_all(&[FORMAT_VERSION])?;
88        buf.write_all(&(self.dimension as u32).to_le_bytes())?;
89        buf.write_all(&(self.entries.len() as u32).to_le_bytes())?;
90        buf.write_all(&[self.bits])?;
91        buf.write_all(&[self.quant_type as u8])?;
92        buf.write_all(&self.rotation_seed.to_le_bytes())?;
93        if self.quant_type == QuantType::Prod {
94            buf.write_all(&self.qjl_seed.unwrap_or(0).to_le_bytes())?;
95        }
96
97        let mse_bits = self.mse_bits();
98        let indices_size = pack::packed_byte_size(self.dimension, mse_bits);
99        let qjl_size = self.dimension.div_ceil(8);
100
101        // Entries
102        for entry in &self.entries {
103            let file_bytes = entry.file.as_bytes();
104            buf.write_all(&(file_bytes.len() as u16).to_le_bytes())?;
105            buf.write_all(file_bytes)?;
106            buf.write_all(&entry.hash.to_le_bytes())?;
107            buf.write_all(&entry.norm.to_le_bytes())?;
108
109            assert_eq!(entry.packed_indices.len(), indices_size);
110            buf.write_all(&entry.packed_indices)?;
111
112            if self.quant_type == QuantType::Prod {
113                let rn = entry.residual_norm.unwrap_or(0.0);
114                buf.write_all(&rn.to_le_bytes())?;
115
116                let default_qjl = vec![0u8; qjl_size];
117                let qjl = entry.qjl_bits.as_deref().unwrap_or(&default_qjl);
118                assert_eq!(qjl.len(), qjl_size);
119                buf.write_all(qjl)?;
120            }
121        }
122
123        if let Some(parent) = path.parent() {
124            fs::create_dir_all(parent)?;
125        }
126        fs::write(path, buf)?;
127        Ok(())
128    }
129
130    /// Load from a binary file.
131    pub fn load(path: &Path) -> Result<Self, QuantError> {
132        let data = fs::read(path)?;
133        let mut cursor = &data[..];
134
135        // Magic
136        let mut magic = [0u8; 4];
137        cursor.read_exact(&mut magic)?;
138        if &magic != MAGIC {
139            return Err(QuantError::Format("invalid magic bytes".into()));
140        }
141
142        // Version
143        let mut ver = [0u8; 1];
144        cursor.read_exact(&mut ver)?;
145        if ver[0] != FORMAT_VERSION {
146            return Err(QuantError::Format(format!(
147                "unsupported version: {}",
148                ver[0]
149            )));
150        }
151
152        // Dimension
153        let mut dim_bytes = [0u8; 4];
154        cursor.read_exact(&mut dim_bytes)?;
155        let dimension = u32::from_le_bytes(dim_bytes) as usize;
156
157        // Count
158        let mut count_bytes = [0u8; 4];
159        cursor.read_exact(&mut count_bytes)?;
160        let count = u32::from_le_bytes(count_bytes) as usize;
161
162        // Bits
163        let mut bits_byte = [0u8; 1];
164        cursor.read_exact(&mut bits_byte)?;
165        let bits = bits_byte[0];
166
167        // Quant type
168        let mut qt_byte = [0u8; 1];
169        cursor.read_exact(&mut qt_byte)?;
170        let quant_type = match qt_byte[0] {
171            0 => QuantType::Mse,
172            1 => QuantType::Prod,
173            v => return Err(QuantError::Format(format!("unknown quant type: {v}"))),
174        };
175
176        // Rotation seed
177        let mut seed_bytes = [0u8; 8];
178        cursor.read_exact(&mut seed_bytes)?;
179        let rotation_seed = u64::from_le_bytes(seed_bytes);
180
181        // QJL seed (only for prod)
182        let qjl_seed = if quant_type == QuantType::Prod {
183            let mut qjl_seed_bytes = [0u8; 8];
184            cursor.read_exact(&mut qjl_seed_bytes)?;
185            Some(u64::from_le_bytes(qjl_seed_bytes))
186        } else {
187            None
188        };
189
190        let mse_bits = match quant_type {
191            QuantType::Mse => bits,
192            QuantType::Prod => bits - 1,
193        };
194        let indices_size = pack::packed_byte_size(dimension, mse_bits);
195        let qjl_size = dimension.div_ceil(8);
196
197        // Entries
198        let mut entries = Vec::with_capacity(count);
199        for _ in 0..count {
200            // File name
201            let mut file_len_bytes = [0u8; 2];
202            cursor.read_exact(&mut file_len_bytes)?;
203            let file_len = u16::from_le_bytes(file_len_bytes) as usize;
204            let mut file_bytes = vec![0u8; file_len];
205            cursor.read_exact(&mut file_bytes)?;
206            let file = String::from_utf8(file_bytes)
207                .map_err(|e| QuantError::Format(format!("invalid UTF-8: {e}")))?;
208
209            // Hash
210            let mut hash_bytes = [0u8; 8];
211            cursor.read_exact(&mut hash_bytes)?;
212            let hash = u64::from_le_bytes(hash_bytes);
213
214            // Norm
215            let mut norm_bytes = [0u8; 4];
216            cursor.read_exact(&mut norm_bytes)?;
217            let norm = f32::from_le_bytes(norm_bytes);
218
219            // Packed indices
220            let mut packed_indices = vec![0u8; indices_size];
221            cursor.read_exact(&mut packed_indices)?;
222
223            // Prod-specific fields
224            let (residual_norm, qjl_bits) = if quant_type == QuantType::Prod {
225                let mut rn_bytes = [0u8; 4];
226                cursor.read_exact(&mut rn_bytes)?;
227                let rn = f32::from_le_bytes(rn_bytes);
228
229                let mut qjl = vec![0u8; qjl_size];
230                cursor.read_exact(&mut qjl)?;
231
232                (Some(rn), Some(qjl))
233            } else {
234                (None, None)
235            };
236
237            entries.push(CompressedEntry {
238                file,
239                hash,
240                norm,
241                packed_indices,
242                residual_norm,
243                qjl_bits,
244            });
245        }
246
247        Ok(Self {
248            dimension,
249            bits,
250            quant_type,
251            rotation_seed,
252            qjl_seed,
253            entries,
254        })
255    }
256
257    /// Find an entry by filename.
258    pub fn get(&self, file: &str) -> Option<&CompressedEntry> {
259        self.entries.iter().find(|e| e.file == file)
260    }
261
262    /// Insert or update an entry.
263    pub fn upsert(&mut self, entry: CompressedEntry) {
264        if let Some(existing) = self.entries.iter_mut().find(|e| e.file == entry.file) {
265            *existing = entry;
266        } else {
267            self.entries.push(entry);
268        }
269    }
270
271    /// Remove an entry by filename.
272    pub fn remove(&mut self, file: &str) -> bool {
273        let len = self.entries.len();
274        self.entries.retain(|e| e.file != file);
275        self.entries.len() < len
276    }
277
278    /// Storage size in bytes for the entry data (excluding header).
279    pub fn data_size(&self) -> usize {
280        let mse_bits = self.mse_bits();
281        let indices_size = pack::packed_byte_size(self.dimension, mse_bits);
282        let per_entry = 2 + 8 + 4 + indices_size // file_len + hash + norm + indices
283            + if self.quant_type == QuantType::Prod {
284                4 + self.dimension.div_ceil(8) // residual_norm + qjl_bits
285            } else {
286                0
287            };
288        self.entries.len() * per_entry
289    }
290
291    /// Equivalent uncompressed size (f32 embeddings) for comparison.
292    pub fn uncompressed_size(&self) -> usize {
293        self.entries.len() * self.dimension * 4
294    }
295
296    /// Compression ratio (uncompressed / compressed).
297    pub fn compression_ratio(&self) -> f32 {
298        let compressed = self.data_size();
299        if compressed == 0 {
300            return 0.0;
301        }
302        self.uncompressed_size() as f32 / compressed as f32
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a two-entry store (dimension 64, 2 bits) for the given scheme.
    ///
    /// Buffers are filled with fixed byte patterns so a round-trip can be
    /// checked byte for byte. Prod-only fields are set only for `Prod`.
    fn make_test_store(quant_type: QuantType) -> CompressedEmbeddingStore {
        let dim: usize = 64;
        let bits: u8 = 2;
        let is_prod = quant_type == QuantType::Prod;
        let mse_bits = if is_prod { bits - 1 } else { bits };
        let indices_size = pack::packed_byte_size(dim, mse_bits);
        // Same rounding helper the store itself uses.
        let qjl_size = dim.div_ceil(8);

        let mut store =
            CompressedEmbeddingStore::new(dim, bits, quant_type, 42, is_prod.then_some(99));

        store.upsert(CompressedEntry {
            file: "feedback_test.md".into(),
            hash: 12345,
            norm: 1.5,
            packed_indices: vec![0xAB; indices_size],
            residual_norm: is_prod.then_some(0.3),
            qjl_bits: is_prod.then(|| vec![0xCD; qjl_size]),
        });

        store.upsert(CompressedEntry {
            file: "user_prefs.md".into(),
            hash: 67890,
            norm: 2.0,
            packed_indices: vec![0x12; indices_size],
            residual_norm: is_prod.then_some(0.1),
            qjl_bits: is_prod.then(|| vec![0x34; qjl_size]),
        });

        store
    }

    /// Save `store` into a fresh temporary directory and load it back.
    fn roundtrip(store: &CompressedEmbeddingStore) -> CompressedEmbeddingStore {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("test.lmcq");
        store.save(&path).unwrap();
        CompressedEmbeddingStore::load(&path).unwrap()
    }

    /// Assert that every entry survived the round-trip unchanged, in order.
    fn assert_entries_match(loaded: &CompressedEmbeddingStore, store: &CompressedEmbeddingStore) {
        assert_eq!(loaded.entries.len(), store.entries.len());
        for (got, want) in loaded.entries.iter().zip(&store.entries) {
            assert_eq!(got.file, want.file);
            assert_eq!(got.hash, want.hash);
            assert_eq!(got.norm, want.norm);
            assert_eq!(got.packed_indices, want.packed_indices);
            assert_eq!(got.residual_norm, want.residual_norm);
            assert_eq!(got.qjl_bits, want.qjl_bits);
        }
    }

    #[test]
    fn mse_store_roundtrip() {
        let store = make_test_store(QuantType::Mse);
        let loaded = roundtrip(&store);

        assert_eq!(loaded.dimension, store.dimension);
        assert_eq!(loaded.bits, store.bits);
        assert_eq!(loaded.quant_type, QuantType::Mse);
        assert_eq!(loaded.rotation_seed, 42);
        assert_eq!(loaded.qjl_seed, None);
        assert_eq!(loaded.entries.len(), 2);
        assert_eq!(loaded.entries[0].file, "feedback_test.md");
        assert_eq!(loaded.entries[0].norm, 1.5);
        assert_entries_match(&loaded, &store);
    }

    #[test]
    fn prod_store_roundtrip() {
        let store = make_test_store(QuantType::Prod);
        let loaded = roundtrip(&store);

        assert_eq!(loaded.dimension, store.dimension);
        assert_eq!(loaded.bits, store.bits);
        assert_eq!(loaded.quant_type, QuantType::Prod);
        assert_eq!(loaded.rotation_seed, 42);
        assert_eq!(loaded.qjl_seed, Some(99));
        assert_eq!(loaded.entries[0].residual_norm, Some(0.3));
        assert!(loaded.entries[0].qjl_bits.is_some());
        assert_entries_match(&loaded, &store);
    }

    #[test]
    fn compression_ratio_positive() {
        let store = make_test_store(QuantType::Mse);
        let ratio = store.compression_ratio();
        assert!(ratio > 1.0, "compression ratio should be > 1, got {ratio}");
    }

    #[test]
    fn upsert_replaces() {
        let mut store = make_test_store(QuantType::Mse);
        let old_hash = store.entries[0].hash;

        store.upsert(CompressedEntry {
            file: "feedback_test.md".into(),
            hash: 99999,
            norm: 3.0,
            packed_indices: store.entries[0].packed_indices.clone(),
            residual_norm: None,
            qjl_bits: None,
        });

        assert_eq!(store.entries.len(), 2);
        let replaced = store.get("feedback_test.md").unwrap();
        assert_ne!(replaced.hash, old_hash);
        assert_eq!(replaced.hash, 99999);
        assert_eq!(replaced.norm, 3.0);
    }

    #[test]
    fn remove_entry() {
        let mut store = make_test_store(QuantType::Mse);
        assert!(store.remove("feedback_test.md"));
        assert_eq!(store.entries.len(), 1);
        assert!(store.get("feedback_test.md").is_none());
        assert!(!store.remove("nonexistent.md"));
    }
}