// oxirs_embed/embed_compression/mod.rs
1//! Embedding Compression: Quantization and Product Quantization for efficient
2//! storage and approximate nearest-neighbor search.
3//!
4//! Provides:
5//! - `QuantizedEmbedding` — 8-bit scalar quantization via min-max
6//! - `EmbeddingQuantizer` — batch quantization for 4-bit and 8-bit
7//! - `ProductQuantizer` — product quantization with per-subspace k-means
8
9// ─────────────────────────────────────────────
10// QuantizedEmbedding (scalar 8-bit)
11// ─────────────────────────────────────────────
12
/// A single embedding quantized to 8-bit precision via min-max scaling.
#[derive(Debug, Clone)]
pub struct QuantizedEmbedding {
    /// Original embedding dimensionality.
    pub original_dim: usize,
    /// Quantized values as u8 in [0, 255].
    pub quantized_data: Vec<u8>,
    /// Scale factor: (max - min) / 255.
    pub scale: f32,
    /// Zero point (minimum value of original embedding).
    pub zero_point: f32,
}

impl QuantizedEmbedding {
    /// Quantize a floating-point embedding to 8-bit using min-max scaling.
    ///
    /// Formula: v_q = round((v - min) / (max - min) * 255)
    pub fn quantize(embedding: &[f32]) -> Self {
        if embedding.is_empty() {
            return Self {
                original_dim: 0,
                quantized_data: Vec::new(),
                scale: 0.0,
                zero_point: 0.0,
            };
        }

        // Single pass to find the value range.
        let mut min_val = f32::INFINITY;
        let mut max_val = f32::NEG_INFINITY;
        for &v in embedding {
            min_val = min_val.min(v);
            max_val = max_val.max(v);
        }
        let range = max_val - min_val;

        // A (near-)zero range means all values are identical; scale 0 makes
        // dequantize return zero_point for every element.
        let degenerate = range < 1e-10;
        let scale = if degenerate { 0.0_f32 } else { range / 255.0 };

        let quantized_data: Vec<u8> = embedding
            .iter()
            .map(|&v| {
                if degenerate {
                    0_u8
                } else {
                    ((v - min_val) / range * 255.0).round().clamp(0.0, 255.0) as u8
                }
            })
            .collect();

        Self {
            original_dim: embedding.len(),
            quantized_data,
            scale,
            zero_point: min_val,
        }
    }

    /// Dequantize back to f32.
    ///
    /// Formula: v = v_q * scale + zero_point
    pub fn dequantize(&self) -> Vec<f32> {
        let mut restored = Vec::with_capacity(self.quantized_data.len());
        for &q in &self.quantized_data {
            restored.push(f32::from(q) * self.scale + self.zero_point);
        }
        restored
    }

    /// Approximate storage size in bytes (header overhead not included).
    pub fn approx_size_bytes(&self) -> usize {
        let payload = self.quantized_data.len(); // 1 byte per element
        let metadata = 2 * std::mem::size_of::<f32>() // scale + zero_point
            + std::mem::size_of::<usize>(); // original_dim
        payload + metadata
    }
}
88
89// ─────────────────────────────────────────────
90// EmbeddingQuantizer
91// ─────────────────────────────────────────────
92
93/// Batch quantizer supporting 4-bit and 8-bit precision.
94#[derive(Debug, Clone)]
95pub struct EmbeddingQuantizer {
96    /// Quantization bit width (4 or 8).
97    pub bits: u8,
98}
99
100impl EmbeddingQuantizer {
101    /// Create a new quantizer with the specified bit width.
102    ///
103    /// `bits` should be 4 or 8; other values are accepted but treated as 8.
104    pub fn new(bits: u8) -> Self {
105        Self { bits }
106    }
107
108    /// Quantize a batch of embeddings.
109    pub fn quantize_batch(&self, embeddings: &[Vec<f32>]) -> Vec<QuantizedEmbedding> {
110        embeddings.iter().map(|e| self.quantize_single(e)).collect()
111    }
112
113    /// Dequantize a batch of quantized embeddings.
114    pub fn dequantize_batch(&self, quantized: &[QuantizedEmbedding]) -> Vec<Vec<f32>> {
115        quantized.iter().map(|q| q.dequantize()).collect()
116    }
117
118    /// Compute compression ratio: original_bytes / quantized_bytes.
119    pub fn compression_ratio(&self, original: &[Vec<f32>]) -> f64 {
120        if original.is_empty() {
121            return 1.0;
122        }
123        let original_bytes: usize = original.iter().map(|v| v.len() * 4).sum(); // f32 = 4 bytes
124        let quantized = self.quantize_batch(original);
125        let quantized_bytes: usize = quantized.iter().map(|q| q.approx_size_bytes()).sum();
126        if quantized_bytes == 0 {
127            return 1.0;
128        }
129        original_bytes as f64 / quantized_bytes as f64
130    }
131
132    // ── Private ───────────────────────────────
133
134    fn quantize_single(&self, embedding: &[f32]) -> QuantizedEmbedding {
135        if self.bits <= 4 {
136            self.quantize_4bit(embedding)
137        } else {
138            QuantizedEmbedding::quantize(embedding)
139        }
140    }
141
142    /// 4-bit quantization: each value is stored in its own byte (4-bit precision, 1 byte/value).
143    /// Uses a scale of (range / 15) so values span [0, 15].
144    fn quantize_4bit(&self, embedding: &[f32]) -> QuantizedEmbedding {
145        let dim = embedding.len();
146        if dim == 0 {
147            return QuantizedEmbedding {
148                original_dim: 0,
149                quantized_data: vec![],
150                scale: 0.0,
151                zero_point: 0.0,
152            };
153        }
154
155        let min_val = embedding.iter().cloned().fold(f32::INFINITY, f32::min);
156        let max_val = embedding.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
157        let range = max_val - min_val;
158
159        // When range=0 all values are identical; use scale=0 so dequantize gives zero_point.
160        let (scale, zero_point) = if range < 1e-10 {
161            (0.0_f32, min_val)
162        } else {
163            (range / 15.0, min_val)
164        };
165
166        // Store one quantized nibble per byte for compatibility with the shared dequantize().
167        let quantized_data: Vec<u8> = embedding
168            .iter()
169            .map(|&v| {
170                if range < 1e-10 {
171                    0_u8
172                } else {
173                    ((v - min_val) / range * 15.0).round().clamp(0.0, 15.0) as u8
174                }
175            })
176            .collect();
177
178        QuantizedEmbedding {
179            original_dim: dim,
180            quantized_data,
181            scale,
182            zero_point,
183        }
184    }
185}
186
187// ─────────────────────────────────────────────
188// ProductQuantizer
189// ─────────────────────────────────────────────
190
191/// Product quantizer: divides the embedding into subspaces and learns a codebook
192/// per subspace for efficient approximate nearest-neighbor search.
193#[derive(Debug, Clone)]
194pub struct ProductQuantizer {
195    /// Number of subspaces.
196    pub subspace_count: usize,
197    /// Number of codes per subspace (e.g. 256 for 8-bit codes).
198    pub codebook_size: usize,
199    /// Codebooks: \[subspace\]\[code\]\[subvec_dim\].
200    pub codebooks: Vec<Vec<Vec<f32>>>,
201    /// Dimension of each sub-vector (embedding_dim / subspace_count).
202    pub subvec_dim: usize,
203}
204
205impl ProductQuantizer {
206    /// Create an untrained product quantizer.
207    pub fn new(subspace_count: usize, codebook_size: usize) -> Self {
208        Self {
209            subspace_count,
210            codebook_size,
211            codebooks: Vec::new(),
212            subvec_dim: 0,
213        }
214    }
215
216    /// Train the product quantizer on a set of embeddings.
217    ///
218    /// Uses a simplified k-means: randomly selects `codebook_size` distinct
219    /// embeddings as initial centroids, then runs a few iterations.
220    pub fn train(&mut self, embeddings: &[Vec<f32>]) {
221        if embeddings.is_empty() || self.subspace_count == 0 {
222            return;
223        }
224        let dim = embeddings[0].len();
225        self.subvec_dim = dim / self.subspace_count;
226        if self.subvec_dim == 0 {
227            self.subvec_dim = 1;
228        }
229
230        self.codebooks = (0..self.subspace_count)
231            .map(|s| {
232                let start = s * self.subvec_dim;
233                let end = ((s + 1) * self.subvec_dim).min(dim);
234
235                // Collect sub-vectors for this subspace
236                let subvecs: Vec<Vec<f32>> =
237                    embeddings.iter().map(|e| e[start..end].to_vec()).collect();
238
239                // Initialize centroids from distinct data points
240                let n_codes = self.codebook_size.min(subvecs.len());
241                let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(n_codes);
242
243                // LCG-based deterministic selection
244                let mut lcg_state: u64 = (s as u64 + 1).wrapping_mul(6_364_136_223_846_793_005);
245                let mut used = std::collections::HashSet::new();
246                while centroids.len() < n_codes {
247                    lcg_state = lcg_state
248                        .wrapping_mul(6_364_136_223_846_793_005)
249                        .wrapping_add(1_442_695_040_888_963_407);
250                    let idx = (lcg_state >> 33) as usize % subvecs.len();
251                    if used.insert(idx) {
252                        centroids.push(subvecs[idx].clone());
253                    }
254                }
255
256                // Run simplified k-means for 5 iterations
257                for _ in 0..5 {
258                    let assignments: Vec<usize> = subvecs
259                        .iter()
260                        .map(|sv| nearest_centroid(sv, &centroids))
261                        .collect();
262
263                    let sub_dim = end - start;
264                    let mut new_centroids = vec![vec![0.0_f32; sub_dim]; n_codes];
265                    let mut counts = vec![0usize; n_codes];
266
267                    for (sv, &c) in subvecs.iter().zip(assignments.iter()) {
268                        for (i, &v) in sv.iter().enumerate() {
269                            if i < new_centroids[c].len() {
270                                new_centroids[c][i] += v;
271                            }
272                        }
273                        counts[c] += 1;
274                    }
275
276                    for (c, count) in counts.iter().enumerate() {
277                        if *count > 0 {
278                            let n = *count as f32;
279                            new_centroids[c].iter_mut().for_each(|x| *x /= n);
280                            centroids[c] = new_centroids[c].clone();
281                        }
282                    }
283                }
284
285                centroids
286            })
287            .collect();
288    }
289
290    /// Encode an embedding as a vector of codebook indices (one per subspace).
291    pub fn encode(&self, embedding: &[f32]) -> Vec<u8> {
292        if self.codebooks.is_empty() || self.subvec_dim == 0 {
293            return vec![0; self.subspace_count];
294        }
295        let dim = embedding.len();
296        (0..self.subspace_count)
297            .map(|s| {
298                let start = s * self.subvec_dim;
299                let end = ((s + 1) * self.subvec_dim).min(dim);
300                let subvec = &embedding[start..end];
301                let code = nearest_centroid(subvec, &self.codebooks[s]);
302                code.min(255) as u8
303            })
304            .collect()
305    }
306
307    /// Decode a code vector back to an approximate embedding.
308    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
309        if self.codebooks.is_empty() {
310            return vec![];
311        }
312        let mut result = Vec::new();
313        for (s, &code) in codes.iter().enumerate().take(self.subspace_count) {
314            if s >= self.codebooks.len() {
315                break;
316            }
317            let c_idx = (code as usize).min(self.codebooks[s].len().saturating_sub(1));
318            result.extend_from_slice(&self.codebooks[s][c_idx]);
319        }
320        result
321    }
322
323    /// Compute approximate L2 distance between two encoded vectors using codebook lookups.
324    pub fn approx_distance(&self, codes1: &[u8], codes2: &[u8]) -> f32 {
325        if self.codebooks.is_empty() {
326            return 0.0;
327        }
328        let mut total = 0.0_f32;
329        for s in 0..self.subspace_count.min(codes1.len()).min(codes2.len()) {
330            if s >= self.codebooks.len() {
331                break;
332            }
333            let c1 = (codes1[s] as usize).min(self.codebooks[s].len().saturating_sub(1));
334            let c2 = (codes2[s] as usize).min(self.codebooks[s].len().saturating_sub(1));
335            let v1 = &self.codebooks[s][c1];
336            let v2 = &self.codebooks[s][c2];
337            let sq_dist: f32 = v1
338                .iter()
339                .zip(v2.iter())
340                .map(|(a, b)| (a - b) * (a - b))
341                .sum();
342            total += sq_dist;
343        }
344        total
345    }
346
347    /// Check whether the quantizer has been trained.
348    pub fn is_trained(&self) -> bool {
349        !self.codebooks.is_empty()
350    }
351}
352
353// ─────────────────────────────────────────────
354// Helpers
355// ─────────────────────────────────────────────
356
/// Find the index of the nearest centroid to `query` by squared L2 distance.
///
/// Ties go to the earliest centroid; an empty `centroids` slice yields 0.
fn nearest_centroid(query: &[f32], centroids: &[Vec<f32>]) -> usize {
    // (index, squared distance) of the best candidate so far. A plain loop
    // with a strict `<` comparison keeps the first-wins tie-breaking.
    let mut best = (0_usize, f32::INFINITY);
    for (idx, centroid) in centroids.iter().enumerate() {
        let dist: f32 = query
            .iter()
            .zip(centroid)
            .map(|(a, b)| {
                let diff = a - b;
                diff * diff
            })
            .sum();
        if dist < best.1 {
            best = (idx, dist);
        }
    }
    best.0
}
374
375// ─────────────────────────────────────────────
376// Tests
377// ─────────────────────────────────────────────
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random embedding in [-1, 1] from an LCG seed.
    fn sample_embedding(seed: u32, dim: usize) -> Vec<f32> {
        let mut v = Vec::with_capacity(dim);
        let mut s = seed;
        for _ in 0..dim {
            s = s.wrapping_mul(1664525).wrapping_add(1013904223);
            v.push((s as f32 / u32::MAX as f32) * 2.0 - 1.0);
        }
        v
    }

    /// Batch of `n` deterministic embeddings.
    fn sample_batch(n: usize, dim: usize, base_seed: u32) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| sample_embedding(base_seed + i as u32, dim))
            .collect()
    }

    // ── QuantizedEmbedding ────────────────────

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let emb = sample_embedding(1, 16);
        let q = QuantizedEmbedding::quantize(&emb);
        let deq = q.dequantize();
        assert_eq!(deq.len(), emb.len());
        // Max reconstruction error for 8-bit quantization ≤ range/255
        let range = emb.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
            - emb.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_err = range / 255.0 + 1e-5;
        for (orig, rec) in emb.iter().zip(deq.iter()) {
            assert!(
                (orig - rec).abs() <= max_err + 1e-4,
                "reconstruction error too large: {} vs {} (max_err={})",
                orig,
                rec,
                max_err
            );
        }
    }

    #[test]
    fn test_quantize_output_in_range() {
        let emb = sample_embedding(2, 8);
        let q = QuantizedEmbedding::quantize(&emb);
        assert_eq!(q.quantized_data.len(), 8);
        // all values in [0, 255] — always true for u8, just verify non-empty
        assert!(!q.quantized_data.is_empty());
    }

    #[test]
    fn test_quantize_empty_embedding() {
        let q = QuantizedEmbedding::quantize(&[]);
        assert_eq!(q.original_dim, 0);
        assert!(q.quantized_data.is_empty());
    }

    #[test]
    fn test_quantize_constant_embedding() {
        let val = 3.125_f32;
        let emb = vec![val; 8];
        let q = QuantizedEmbedding::quantize(&emb);
        let deq = q.dequantize();
        for &v in &deq {
            assert!(
                (v - val).abs() < 0.5,
                "constant embedding should dequantize close to {val}, got {v}"
            );
        }
    }

    #[test]
    fn test_approx_size_bytes() {
        let emb = sample_embedding(3, 64);
        let q = QuantizedEmbedding::quantize(&emb);
        let sz = q.approx_size_bytes();
        assert!(sz > 0);
        // Should be much smaller than original f32 size (64 * 4 = 256 bytes)
        assert!(sz < 64 * 4, "quantized size should be smaller than f32");
    }

    // ── EmbeddingQuantizer ────────────────────

    #[test]
    fn test_quantizer_8bit_creation() {
        let q = EmbeddingQuantizer::new(8);
        assert_eq!(q.bits, 8);
    }

    #[test]
    fn test_quantizer_4bit_creation() {
        let q = EmbeddingQuantizer::new(4);
        assert_eq!(q.bits, 4);
    }

    #[test]
    fn test_quantize_batch_count() {
        let q = EmbeddingQuantizer::new(8);
        let batch = sample_batch(10, 16, 100);
        let out = q.quantize_batch(&batch);
        assert_eq!(out.len(), 10);
    }

    #[test]
    fn test_dequantize_batch_count() {
        let q = EmbeddingQuantizer::new(8);
        let batch = sample_batch(5, 16, 200);
        let quantized = q.quantize_batch(&batch);
        let deq = q.dequantize_batch(&quantized);
        assert_eq!(deq.len(), 5);
        assert_eq!(deq[0].len(), 16);
    }

    #[test]
    fn test_compression_ratio_8bit() {
        let q = EmbeddingQuantizer::new(8);
        let batch = sample_batch(10, 64, 300);
        let ratio = q.compression_ratio(&batch);
        // 8-bit should give ratio close to 4x (f32 → u8) minus overhead
        assert!(
            ratio > 1.0,
            "8-bit quantization should compress: ratio={ratio}"
        );
    }

    #[test]
    fn test_compression_ratio_4bit() {
        let q = EmbeddingQuantizer::new(4);
        let batch = sample_batch(10, 64, 400);
        let ratio = q.compression_ratio(&batch);
        assert!(
            ratio > 1.0,
            "4-bit quantization should compress: ratio={ratio}"
        );
    }

    #[test]
    fn test_compression_ratio_empty() {
        let q = EmbeddingQuantizer::new(8);
        let ratio = q.compression_ratio(&[]);
        assert_eq!(ratio, 1.0);
    }

    #[test]
    fn test_4bit_quantize_dequantize() {
        let q = EmbeddingQuantizer::new(4);
        let batch = sample_batch(3, 16, 500);
        let quantized = q.quantize_batch(&batch);
        let deq = q.dequantize_batch(&quantized);
        // 4-bit reconstruction error ≤ range/15
        for (orig, rec) in batch.iter().zip(deq.iter()) {
            let range = orig.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
                - orig.iter().cloned().fold(f32::INFINITY, f32::min);
            let max_err = range / 15.0 + 1e-3;
            for (o, r) in orig.iter().zip(rec.iter()) {
                assert!(
                    (o - r).abs() <= max_err + 0.1,
                    "4-bit error too large: {o} vs {r}"
                );
            }
        }
    }

    // ── ProductQuantizer ──────────────────────

    #[test]
    fn test_pq_creation() {
        let pq = ProductQuantizer::new(4, 16);
        assert_eq!(pq.subspace_count, 4);
        assert_eq!(pq.codebook_size, 16);
        assert!(!pq.is_trained());
    }

    #[test]
    fn test_pq_train() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(50, 16, 1000);
        pq.train(&batch);
        assert!(pq.is_trained());
        assert_eq!(pq.codebooks.len(), 4);
    }

    #[test]
    fn test_pq_encode_length() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(30, 16, 1100);
        pq.train(&batch);
        let codes = pq.encode(&batch[0]);
        assert_eq!(codes.len(), 4);
    }

    #[test]
    fn test_pq_decode_length() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(30, 16, 1200);
        pq.train(&batch);
        let codes = pq.encode(&batch[0]);
        let decoded = pq.decode(&codes);
        assert!(!decoded.is_empty());
    }

    #[test]
    fn test_pq_approx_distance_same_code() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(30, 16, 1300);
        pq.train(&batch);
        let codes = pq.encode(&batch[0]);
        let dist = pq.approx_distance(&codes, &codes);
        assert!(
            dist.abs() < 1e-6,
            "distance to self should be ~0, got {dist}"
        );
    }

    #[test]
    fn test_pq_approx_distance_different_codes() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(40, 16, 1400);
        pq.train(&batch);
        let c0 = pq.encode(&batch[0]);
        let c1 = pq.encode(&batch[20]);
        let dist = pq.approx_distance(&c0, &c1);
        assert!(dist >= 0.0, "distance should be non-negative");
        assert!(dist.is_finite(), "distance should be finite");
    }

    #[test]
    fn test_pq_encode_before_train_returns_zeros() {
        let pq = ProductQuantizer::new(4, 8);
        let emb = sample_embedding(1, 16);
        let codes = pq.encode(&emb);
        assert!(codes.iter().all(|&c| c == 0));
    }

    #[test]
    fn test_pq_codebook_size_capped_by_data() {
        let mut pq = ProductQuantizer::new(2, 256); // more codes than data
        let batch = sample_batch(10, 8, 2000); // only 10 embeddings
        pq.train(&batch);
        // n_codes = codebook_size.min(#embeddings) = 10, so each codebook has
        // exactly 10 entries. The previous `<= 256` bound was always true and
        // never exercised the cap this test claims to check.
        for cb in &pq.codebooks {
            assert_eq!(cb.len(), 10);
        }
    }

    #[test]
    fn test_pq_reconstruction_quality() {
        let mut pq = ProductQuantizer::new(2, 8);
        let batch = sample_batch(50, 8, 3000);
        pq.train(&batch);
        // Reconstruction of a training vector should be somewhat close
        let orig = &batch[0];
        let codes = pq.encode(orig);
        let decoded = pq.decode(&codes);
        // Just check that decoded is non-empty and finite
        assert!(!decoded.is_empty());
        assert!(decoded.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_pq_train_empty_no_panic() {
        let mut pq = ProductQuantizer::new(4, 8);
        pq.train(&[]); // should not panic
        assert!(!pq.is_trained());
    }
}