// oxirs_vec/quantized_cache.rs
//! Quantized embedding cache with scalar quantization, product quantization,
//! and asymmetric distance computation.
//!
//! `QuantizedEmbeddingCache` stores compressed int8 codes for each vector,
//! supporting two compression schemes:
//!
//! * **Scalar Quantization (SQ)** – maps each fp32 scalar to an int8 value
//!   using per-dimension or global min/max ranges.
//! * **Product Quantization (PQ)** – splits the vector into sub-spaces and
//!   quantizes each sub-space to a centroid index, enabling very high
//!   compression ratios.
//!
//! Distance is computed **asymmetrically**: the query is kept in fp32 while
//! the database codes are decompressed on-the-fly, giving better accuracy
//! than comparing compressed codes directly.
//!
//! Compression ratio and distance accuracy metrics are tracked automatically.
//!
//! # Pure Rust Policy
//!
//! No unsafe code, no C/Fortran FFI, no CUDA runtime calls.

23use anyhow::{anyhow, Result};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26
27// ── quantization scheme ────────────────────────────────────────────────────
28
29/// Which compression scheme to use in the cache.
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31pub enum QuantizationScheme {
32    /// Scalar quantization: one int8 per dimension.
33    Scalar,
34    /// Product quantization: one centroid index per sub-space.
35    Product,
36}
37
38// ── scalar quantization helpers ────────────────────────────────────────────
39
40/// Per-dimension parameters for scalar quantization.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ScalarDimParams {
43    /// Minimum fp32 value seen during training.
44    pub min: f32,
45    /// Maximum fp32 value seen during training.
46    pub max: f32,
47    /// Precomputed scale = 255.0 / (max - min).
48    pub scale: f32,
49}
50
51impl ScalarDimParams {
52    fn new(min: f32, max: f32) -> Self {
53        let range = max - min;
54        let scale = if range > 1e-9 { 255.0 / range } else { 1.0 };
55        Self { min, max, scale }
56    }
57
58    /// Quantize a single fp32 scalar to u8.
59    #[inline]
60    pub fn quantize(&self, v: f32) -> u8 {
61        ((v - self.min) * self.scale).clamp(0.0, 255.0) as u8
62    }
63
64    /// Dequantize a u8 back to fp32.
65    #[inline]
66    pub fn dequantize(&self, code: u8) -> f32 {
67        self.min + (code as f32) / self.scale
68    }
69}
70
71// ── product quantization helpers ───────────────────────────────────────────
72
73/// A single PQ codebook: `n_centroids` centroids, each of dimension `sub_dim`.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct PqCodebook {
76    /// Number of centroids.
77    pub n_centroids: usize,
78    /// Dimension of each centroid (= total_dim / n_subspaces).
79    pub sub_dim: usize,
80    /// Flattened centroids: `n_centroids × sub_dim` fp32 values.
81    pub centroids: Vec<f32>,
82}
83
84impl PqCodebook {
85    /// Build a codebook from training sub-vectors using a simple k-means variant.
86    fn train(sub_vectors: &[Vec<f32>], n_centroids: usize, max_iters: usize) -> Self {
87        let sub_dim = if sub_vectors.is_empty() {
88            0
89        } else {
90            sub_vectors[0].len()
91        };
92
93        if sub_vectors.is_empty() || n_centroids == 0 || sub_dim == 0 {
94            return Self {
95                n_centroids,
96                sub_dim,
97                centroids: Vec::new(),
98            };
99        }
100
101        let actual_k = n_centroids.min(sub_vectors.len());
102
103        // Initialise centroids from the first `actual_k` training vectors
104        let mut centroids: Vec<Vec<f32>> = sub_vectors.iter().take(actual_k).cloned().collect();
105
106        for _ in 0..max_iters {
107            // Assignment step
108            let mut assignments: Vec<usize> = Vec::with_capacity(sub_vectors.len());
109            for sv in sub_vectors {
110                let best = centroids
111                    .iter()
112                    .enumerate()
113                    .map(|(i, c)| (i, euclidean_sq_slice(sv, c)))
114                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
115                    .map(|(i, _)| i)
116                    .unwrap_or(0);
117                assignments.push(best);
118            }
119
120            // Update step
121            let mut new_centroids = vec![vec![0.0_f32; sub_dim]; actual_k];
122            let mut counts = vec![0usize; actual_k];
123            for (sv, &asgn) in sub_vectors.iter().zip(&assignments) {
124                for (d, &v) in sv.iter().enumerate() {
125                    new_centroids[asgn][d] += v;
126                }
127                counts[asgn] += 1;
128            }
129            for (c, count) in new_centroids.iter_mut().zip(&counts) {
130                if *count > 0 {
131                    for v in c.iter_mut() {
132                        *v /= *count as f32;
133                    }
134                }
135            }
136            centroids = new_centroids;
137        }
138
139        let flat: Vec<f32> = centroids.into_iter().flatten().collect();
140        Self {
141            n_centroids: actual_k,
142            sub_dim,
143            centroids: flat,
144        }
145    }
146
147    /// Find the nearest centroid index for a sub-vector.
148    pub fn encode(&self, sub_vec: &[f32]) -> u8 {
149        let best = (0..self.n_centroids)
150            .map(|i| {
151                let offset = i * self.sub_dim;
152                let centroid = &self.centroids[offset..offset + self.sub_dim];
153                (i, euclidean_sq_slice(sub_vec, centroid))
154            })
155            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
156            .map(|(i, _)| i)
157            .unwrap_or(0);
158        (best & 0xFF) as u8
159    }
160
161    /// Decode a centroid index back to a sub-vector slice.
162    pub fn decode(&self, code: u8) -> &[f32] {
163        let i = (code as usize).min(self.n_centroids.saturating_sub(1));
164        let offset = i * self.sub_dim;
165        &self.centroids[offset..offset + self.sub_dim]
166    }
167}
168
169// ── cache config ───────────────────────────────────────────────────────────
170
171/// Configuration for `QuantizedEmbeddingCache`.
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct QuantizedCacheConfig {
174    /// Which quantization scheme to use.
175    pub scheme: QuantizationScheme,
176    /// For SQ: number of bits (currently fixed at 8).
177    pub sq_bits: u8,
178    /// For PQ: number of sub-spaces.
179    pub pq_n_subspaces: usize,
180    /// For PQ: number of centroids per codebook.
181    pub pq_n_centroids: usize,
182    /// For PQ: number of k-means iterations during training.
183    pub pq_max_iters: usize,
184    /// Whether to normalize vectors before quantization.
185    pub normalize: bool,
186    /// Maximum number of training samples to use.
187    pub max_training_samples: usize,
188}
189
190impl Default for QuantizedCacheConfig {
191    fn default() -> Self {
192        Self {
193            scheme: QuantizationScheme::Scalar,
194            sq_bits: 8,
195            pq_n_subspaces: 8,
196            pq_n_centroids: 256,
197            pq_max_iters: 25,
198            normalize: false,
199            max_training_samples: 10_000,
200        }
201    }
202}
203
204// ── metrics ────────────────────────────────────────────────────────────────
205
206/// Compression ratio and distance accuracy metrics.
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct CacheMetrics {
209    /// Number of vectors stored.
210    pub vector_count: usize,
211    /// Dimensionality of stored vectors.
212    pub dimensions: usize,
213    /// Bytes used for compressed codes.
214    pub compressed_bytes: usize,
215    /// Bytes that fp32 vectors would occupy.
216    pub uncompressed_bytes: usize,
217    /// `uncompressed_bytes / compressed_bytes`.
218    pub compression_ratio: f64,
219    /// Mean absolute error between original and reconstructed vectors
220    /// (measured on the stored vectors, sampled during `train`).
221    pub mean_reconstruction_error: f32,
222    /// Number of queries served.
223    pub queries_served: u64,
224    /// Cumulative distance computations.
225    pub distance_computations: u64,
226}
227
228impl Default for CacheMetrics {
229    fn default() -> Self {
230        Self {
231            vector_count: 0,
232            dimensions: 0,
233            compressed_bytes: 0,
234            uncompressed_bytes: 0,
235            compression_ratio: 0.0,
236            mean_reconstruction_error: 0.0,
237            queries_served: 0,
238            distance_computations: 0,
239        }
240    }
241}
242
// ── internal storage ───────────────────────────────────────────────────────

/// One stored vector in compressed form, addressed by a string ID held in
/// the cache's lookup maps.
#[derive(Debug, Clone)]
struct CompressedCode {
    /// One byte per code slot — SQ: one per dimension; PQ: one per sub-space.
    codes: Vec<u8>,
    /// Arbitrary user-supplied key/value metadata.
    metadata: HashMap<String, String>,
}
253
254// ── main struct ────────────────────────────────────────────────────────────
255
256/// Quantized embedding cache with scalar or product quantization and asymmetric
257/// distance computation for compressed similarity search.
258pub struct QuantizedEmbeddingCache {
259    config: QuantizedCacheConfig,
260    dimensions: usize,
261    // SQ parameters
262    sq_params: Vec<ScalarDimParams>,
263    // PQ codebooks (one per sub-space)
264    pq_codebooks: Vec<PqCodebook>,
265    // Stored compressed codes
266    codes: Vec<CompressedCode>,
267    id_to_idx: HashMap<String, usize>,
268    idx_to_id: Vec<String>,
269    // Metrics
270    metrics: CacheMetrics,
271}
272
273impl QuantizedEmbeddingCache {
274    /// Create a new, untrained cache.
275    pub fn new(config: QuantizedCacheConfig, dimensions: usize) -> Self {
276        Self {
277            config,
278            dimensions,
279            sq_params: Vec::new(),
280            pq_codebooks: Vec::new(),
281            codes: Vec::new(),
282            id_to_idx: HashMap::new(),
283            idx_to_id: Vec::new(),
284            metrics: CacheMetrics {
285                dimensions,
286                ..Default::default()
287            },
288        }
289    }
290
291    // ── training ──────────────────────────────────────────────────────────
292
293    /// Train quantization parameters from `training_vectors`.
294    ///
295    /// Must be called before any calls to `add` or `search`.
296    pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
297        if training_vectors.is_empty() {
298            return Err(anyhow!("No training vectors provided"));
299        }
300        let dim = training_vectors[0].len();
301        if dim != self.dimensions {
302            return Err(anyhow!(
303                "Training vector dim {} ≠ cache dim {}",
304                dim,
305                self.dimensions
306            ));
307        }
308
309        let limit = training_vectors.len().min(self.config.max_training_samples);
310        let raw_samples = &training_vectors[..limit];
311
312        // When normalization is enabled, normalize training samples so that
313        // the quantizer learns the min/max range of normalized vectors.
314        let normalized_storage: Vec<Vec<f32>>;
315        let samples: &[Vec<f32>] = if self.config.normalize {
316            normalized_storage = raw_samples.iter().map(|v| normalize_vec(v)).collect();
317            &normalized_storage
318        } else {
319            raw_samples
320        };
321
322        match self.config.scheme {
323            QuantizationScheme::Scalar => self.train_scalar(samples)?,
324            QuantizationScheme::Product => self.train_product(samples)?,
325        }
326
327        // Measure reconstruction error on training samples
328        let error = self.measure_reconstruction_error(samples);
329        self.metrics.mean_reconstruction_error = error;
330
331        Ok(())
332    }
333
334    fn train_scalar(&mut self, samples: &[Vec<f32>]) -> Result<()> {
335        let mut dim_mins = vec![f32::INFINITY; self.dimensions];
336        let mut dim_maxs = vec![f32::NEG_INFINITY; self.dimensions];
337
338        for v in samples {
339            for (d, &val) in v.iter().enumerate() {
340                dim_mins[d] = dim_mins[d].min(val);
341                dim_maxs[d] = dim_maxs[d].max(val);
342            }
343        }
344
345        self.sq_params = dim_mins
346            .into_iter()
347            .zip(dim_maxs)
348            .map(|(mn, mx)| ScalarDimParams::new(mn, mx))
349            .collect();
350
351        Ok(())
352    }
353
354    fn train_product(&mut self, samples: &[Vec<f32>]) -> Result<()> {
355        let n_sub = self.config.pq_n_subspaces;
356        if self.dimensions % n_sub != 0 {
357            return Err(anyhow!(
358                "dimensions ({}) must be divisible by pq_n_subspaces ({})",
359                self.dimensions,
360                n_sub
361            ));
362        }
363        let sub_dim = self.dimensions / n_sub;
364
365        self.pq_codebooks = (0..n_sub)
366            .map(|s| {
367                let sub_vecs: Vec<Vec<f32>> = samples
368                    .iter()
369                    .map(|v| v[s * sub_dim..(s + 1) * sub_dim].to_vec())
370                    .collect();
371                PqCodebook::train(
372                    &sub_vecs,
373                    self.config.pq_n_centroids,
374                    self.config.pq_max_iters,
375                )
376            })
377            .collect();
378
379        Ok(())
380    }
381
382    fn measure_reconstruction_error(&self, samples: &[Vec<f32>]) -> f32 {
383        let limit = samples.len().min(200);
384        let mut total = 0.0_f32;
385        for v in &samples[..limit] {
386            let normalized = if self.config.normalize {
387                normalize_vec(v)
388            } else {
389                v.clone()
390            };
391            let codes = self.encode_vector(&normalized);
392            let reconstructed = self.decode_codes(&codes);
393            let err: f32 = normalized
394                .iter()
395                .zip(&reconstructed)
396                .map(|(&a, &b)| (a - b).abs())
397                .sum::<f32>()
398                / self.dimensions as f32;
399            total += err;
400        }
401        total / limit as f32
402    }
403
404    // ── insert / retrieve ──────────────────────────────────────────────────
405
406    /// Compress and store a vector by `id`.
407    pub fn add(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
408        self.add_with_metadata(id, vector, HashMap::new())
409    }
410
411    /// Compress and store a vector with metadata.
412    pub fn add_with_metadata(
413        &mut self,
414        id: String,
415        vector: Vec<f32>,
416        metadata: HashMap<String, String>,
417    ) -> Result<()> {
418        if self.is_untrained() {
419            return Err(anyhow!("Cache not trained; call train() first"));
420        }
421        if vector.len() != self.dimensions {
422            return Err(anyhow!(
423                "Vector dim {} ≠ cache dim {}",
424                vector.len(),
425                self.dimensions
426            ));
427        }
428        if self.id_to_idx.contains_key(&id) {
429            return Err(anyhow!("ID '{}' already in cache", id));
430        }
431
432        let normalized = if self.config.normalize {
433            normalize_vec(&vector)
434        } else {
435            vector
436        };
437        let codes = self.encode_vector(&normalized);
438        let idx = self.codes.len();
439
440        self.codes.push(CompressedCode { codes, metadata });
441        self.id_to_idx.insert(id.clone(), idx);
442        self.idx_to_id.push(id);
443
444        // Update metrics
445        let code_len = self.code_length();
446        self.metrics.vector_count += 1;
447        self.metrics.compressed_bytes += code_len;
448        self.metrics.uncompressed_bytes += self.dimensions * 4;
449        self.metrics.compression_ratio =
450            self.metrics.uncompressed_bytes as f64 / self.metrics.compressed_bytes.max(1) as f64;
451
452        Ok(())
453    }
454
455    /// Retrieve the decompressed (reconstructed) vector for `id`.
456    pub fn get(&self, id: &str) -> Option<Vec<f32>> {
457        let idx = *self.id_to_idx.get(id)?;
458        Some(self.decode_codes(&self.codes[idx].codes))
459    }
460
461    /// Number of vectors in the cache.
462    pub fn len(&self) -> usize {
463        self.codes.len()
464    }
465
466    /// Returns `true` if no vectors are stored.
467    pub fn is_empty(&self) -> bool {
468        self.codes.is_empty()
469    }
470
471    // ── asymmetric search ──────────────────────────────────────────────────
472
473    /// Find the `k` nearest cached vectors to `query` using asymmetric distance.
474    ///
475    /// The query is kept in fp32; each database code is decompressed on-the-fly
476    /// and Euclidean distance is computed.
477    pub fn search(&mut self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
478        if self.is_untrained() {
479            return Err(anyhow!("Cache not trained"));
480        }
481        if query.len() != self.dimensions {
482            return Err(anyhow!(
483                "Query dim {} ≠ cache dim {}",
484                query.len(),
485                self.dimensions
486            ));
487        }
488
489        let normalized_query = if self.config.normalize {
490            normalize_vec(query)
491        } else {
492            query.to_vec()
493        };
494
495        let mut distances: Vec<(usize, f32)> = self
496            .codes
497            .iter()
498            .enumerate()
499            .map(|(i, code)| {
500                let reconstructed = self.decode_codes(&code.codes);
501                let dist = euclidean_sq_slice(&normalized_query, &reconstructed).sqrt();
502                (i, dist)
503            })
504            .collect();
505
506        self.metrics.distance_computations += self.codes.len() as u64;
507        self.metrics.queries_served += 1;
508
509        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
510        distances.truncate(k);
511
512        Ok(distances
513            .into_iter()
514            .map(|(i, d)| (self.idx_to_id[i].clone(), d))
515            .collect())
516    }
517
518    /// Return a snapshot of current metrics.
519    pub fn metrics(&self) -> &CacheMetrics {
520        &self.metrics
521    }
522
523    /// Access the configuration.
524    pub fn config(&self) -> &QuantizedCacheConfig {
525        &self.config
526    }
527
528    // ── encoding / decoding ────────────────────────────────────────────────
529
530    fn encode_vector(&self, vector: &[f32]) -> Vec<u8> {
531        match self.config.scheme {
532            QuantizationScheme::Scalar => vector
533                .iter()
534                .zip(&self.sq_params)
535                .map(|(&v, params)| params.quantize(v))
536                .collect(),
537            QuantizationScheme::Product => {
538                let n_sub = self.pq_codebooks.len();
539                if n_sub == 0 {
540                    return Vec::new();
541                }
542                let sub_dim = self.dimensions / n_sub;
543                (0..n_sub)
544                    .map(|s| {
545                        let sub = &vector[s * sub_dim..(s + 1) * sub_dim];
546                        self.pq_codebooks[s].encode(sub)
547                    })
548                    .collect()
549            }
550        }
551    }
552
553    fn decode_codes(&self, codes: &[u8]) -> Vec<f32> {
554        match self.config.scheme {
555            QuantizationScheme::Scalar => codes
556                .iter()
557                .zip(&self.sq_params)
558                .map(|(&code, params)| params.dequantize(code))
559                .collect(),
560            QuantizationScheme::Product => {
561                let n_sub = self.pq_codebooks.len();
562                if n_sub == 0 {
563                    return Vec::new();
564                }
565                let mut out = Vec::with_capacity(self.dimensions);
566                for (s, &code) in (0..n_sub).zip(codes.iter()) {
567                    out.extend_from_slice(self.pq_codebooks[s].decode(code));
568                }
569                out
570            }
571        }
572    }
573
574    /// Number of code bytes per vector.
575    fn code_length(&self) -> usize {
576        match self.config.scheme {
577            QuantizationScheme::Scalar => self.dimensions, // 1 byte per dim
578            QuantizationScheme::Product => self.config.pq_n_subspaces,
579        }
580    }
581
582    fn is_untrained(&self) -> bool {
583        match self.config.scheme {
584            QuantizationScheme::Scalar => self.sq_params.is_empty(),
585            QuantizationScheme::Product => self.pq_codebooks.is_empty(),
586        }
587    }
588}
589
// ── free functions ─────────────────────────────────────────────────────────

/// Squared Euclidean distance between two equal-length slices.
///
/// If the slices differ in length, only the common prefix is compared (a
/// consequence of `zip`); callers are expected to pass equal lengths.
#[inline]
fn euclidean_sq_slice(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc
}
603
/// L2-normalize a vector; an (almost) zero-norm input is returned unchanged
/// to avoid dividing by ~0.
fn normalize_vec(v: &[f32]) -> Vec<f32> {
    let squared_sum: f32 = v.iter().map(|&x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|&x| x / norm).collect()
}
613
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build an untrained scalar-quantization cache of `dims` dimensions.
    fn make_sq_cache(dims: usize) -> QuantizedEmbeddingCache {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Scalar,
            ..Default::default()
        };
        QuantizedEmbeddingCache::new(config, dims)
    }

    /// Build an untrained product-quantization cache with `n_sub` sub-spaces
    /// (small centroid count / iteration budget to keep tests fast).
    fn make_pq_cache(dims: usize, n_sub: usize) -> QuantizedEmbeddingCache {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Product,
            pq_n_subspaces: n_sub,
            pq_n_centroids: 8,
            pq_max_iters: 5,
            ..Default::default()
        };
        QuantizedEmbeddingCache::new(config, dims)
    }

    /// Generate `n` deterministic training vectors of length `dims`
    /// (values increase monotonically: element = (i*dims + d) * 0.01).
    fn training_vecs(n: usize, dims: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dims).map(|d| (i * dims + d) as f32 * 0.01).collect())
            .collect()
    }

    // ── SQ training ────────────────────────────────────────────────────────

    #[test]
    fn test_sq_train_succeeds() {
        let mut cache = make_sq_cache(4);
        let samples = training_vecs(50, 4);
        assert!(cache.train(&samples).is_ok());
        assert_eq!(cache.sq_params.len(), 4);
    }

    #[test]
    fn test_sq_train_empty_fails() {
        let mut cache = make_sq_cache(4);
        assert!(cache.train(&[]).is_err());
    }

    #[test]
    fn test_sq_train_wrong_dim_fails() {
        let mut cache = make_sq_cache(4);
        let samples = vec![vec![1.0_f32; 8]];
        assert!(cache.train(&samples).is_err());
    }

    #[test]
    fn test_sq_untrained_add_fails() {
        let mut cache = make_sq_cache(4);
        let err = cache.add("k".to_string(), vec![0.0; 4]);
        assert!(err.is_err());
    }

    // ── SQ add / get ───────────────────────────────────────────────────────

    #[test]
    fn test_sq_add_and_get() {
        let mut cache = make_sq_cache(4);
        let samples = training_vecs(50, 4);
        cache.train(&samples).unwrap();
        cache
            .add("v0".to_string(), vec![0.1, 0.2, 0.3, 0.4])
            .unwrap();
        let reconstructed = cache.get("v0");
        assert!(reconstructed.is_some());
        let r = reconstructed.unwrap();
        assert_eq!(r.len(), 4);
        // Reconstruction should be close to original
        for (orig, rec) in [0.1_f32, 0.2, 0.3, 0.4].iter().zip(&r) {
            assert!((orig - rec).abs() < 0.05, "Reconstruction error too large");
        }
    }

    #[test]
    fn test_sq_duplicate_id_fails() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        cache.add("k".to_string(), vec![0.0; 4]).unwrap();
        assert!(cache.add("k".to_string(), vec![1.0; 4]).is_err());
    }

    #[test]
    fn test_sq_get_missing_returns_none() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        assert!(cache.get("absent").is_none());
    }

    // ── SQ search ──────────────────────────────────────────────────────────

    #[test]
    fn test_sq_search_returns_nearest() {
        let mut cache = make_sq_cache(2);
        let samples = vec![vec![0.0_f32, 0.0], vec![1.0, 0.0], vec![5.0, 0.0]];
        cache.train(&samples).unwrap();
        cache.add("origin".to_string(), vec![0.0, 0.0]).unwrap();
        cache.add("near".to_string(), vec![0.5, 0.0]).unwrap();
        cache.add("far".to_string(), vec![5.0, 0.0]).unwrap();

        let results = cache.search(&[0.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "origin");
    }

    #[test]
    fn test_sq_search_top_k_ordering() {
        let mut cache = make_sq_cache(1);
        let samples: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32]).collect();
        cache.train(&samples).unwrap();
        for i in 0..10_u32 {
            cache.add(format!("v{}", i), vec![i as f32]).unwrap();
        }
        let results = cache.search(&[5.0], 3).unwrap();
        assert!(results.len() <= 3);
        // Results should be ascending distance
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1 + 1e-6);
        }
    }

    #[test]
    fn test_sq_search_empty_cache() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        let results = cache.search(&[0.0; 4], 5).unwrap();
        assert!(results.is_empty());
    }

    // ── SQ metrics ─────────────────────────────────────────────────────────

    #[test]
    fn test_sq_compression_ratio_greater_than_one() {
        let mut cache = make_sq_cache(32);
        cache.train(&training_vecs(100, 32)).unwrap();
        for i in 0..10 {
            cache.add(format!("v{}", i), vec![0.5; 32]).unwrap();
        }
        let m = cache.metrics();
        assert!(m.compression_ratio > 1.0);
        // 32 dims × 4 bytes fp32 vs 32 dims × 1 byte u8 → ratio ≈ 4
        assert!(
            (m.compression_ratio - 4.0).abs() < 0.5,
            "SQ ratio should be ~4"
        );
    }

    #[test]
    fn test_sq_metrics_vector_count() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        for i in 0..5 {
            cache.add(format!("v{}", i), vec![i as f32; 4]).unwrap();
        }
        assert_eq!(cache.metrics().vector_count, 5);
    }

    #[test]
    fn test_sq_queries_served_increments() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        cache.add("a".to_string(), vec![0.0; 4]).unwrap();
        cache.search(&[0.0; 4], 1).unwrap();
        cache.search(&[0.0; 4], 1).unwrap();
        assert_eq!(cache.metrics().queries_served, 2);
    }

    #[test]
    fn test_sq_reconstruction_error_reasonable() {
        let mut cache = make_sq_cache(4);
        let samples = training_vecs(100, 4);
        cache.train(&samples).unwrap();
        // For 8-bit SQ, reconstruction error should be small
        assert!(cache.metrics().mean_reconstruction_error < 0.1);
    }

    // ── PQ training ────────────────────────────────────────────────────────

    #[test]
    fn test_pq_train_succeeds() {
        let mut cache = make_pq_cache(8, 2);
        let samples = training_vecs(50, 8);
        assert!(cache.train(&samples).is_ok());
        assert_eq!(cache.pq_codebooks.len(), 2);
    }

    #[test]
    fn test_pq_train_indivisible_dims_fails() {
        let mut cache = make_pq_cache(7, 3); // 7 not divisible by 3
        let samples = training_vecs(30, 7);
        assert!(cache.train(&samples).is_err());
    }

    #[test]
    fn test_pq_add_and_get() {
        let mut cache = make_pq_cache(8, 2);
        let samples = training_vecs(50, 8);
        cache.train(&samples).unwrap();
        cache.add("v0".to_string(), vec![0.1; 8]).unwrap();
        let r = cache.get("v0").unwrap();
        assert_eq!(r.len(), 8);
    }

    #[test]
    fn test_pq_compression_ratio() {
        let mut cache = make_pq_cache(16, 4); // 4 sub-spaces
        cache.train(&training_vecs(50, 16)).unwrap();
        for i in 0..8 {
            cache.add(format!("v{}", i), vec![0.5; 16]).unwrap();
        }
        let m = cache.metrics();
        // 16 dims × 4 bytes = 64 bytes uncompressed; 4 codes × 1 byte = 4 bytes compressed → ratio = 16
        assert!(m.compression_ratio > 4.0, "PQ ratio should be > 4");
    }

    #[test]
    fn test_pq_search() {
        let mut cache = make_pq_cache(8, 2);
        let samples = training_vecs(50, 8);
        cache.train(&samples).unwrap();
        cache.add("a".to_string(), vec![0.0; 8]).unwrap();
        cache.add("b".to_string(), vec![10.0; 8]).unwrap();
        let results = cache.search(&[0.1; 8], 1).unwrap();
        assert!(!results.is_empty());
    }

    // ── normalization ──────────────────────────────────────────────────────

    #[test]
    fn test_normalized_vectors_stored_as_unit_length() {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Scalar,
            normalize: true,
            ..Default::default()
        };
        let mut cache = QuantizedEmbeddingCache::new(config, 4);
        let long_vecs: Vec<Vec<f32>> = (0..20)
            .map(|i| vec![i as f32 + 1.0, i as f32 + 2.0, 0.0, 0.0])
            .collect();
        cache.train(&long_vecs).unwrap();
        cache
            .add("v".to_string(), vec![3.0, 4.0, 0.0, 0.0])
            .unwrap();
        let r = cache.get("v").unwrap();
        let norm: f32 = r.iter().map(|&x| x * x).sum::<f32>().sqrt();
        // Reconstructed vector should be approximately unit length (quantization error allowed)
        assert!((norm - 1.0).abs() < 0.1, "norm={}, expected ~1.0", norm);
    }

    // ── config accessors ───────────────────────────────────────────────────

    #[test]
    fn test_config_accessors() {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Product,
            pq_n_subspaces: 4,
            pq_n_centroids: 16,
            ..Default::default()
        };
        let cache = QuantizedEmbeddingCache::new(config, 8);
        assert_eq!(cache.config().pq_n_subspaces, 4);
        assert_eq!(cache.config().pq_n_centroids, 16);
    }

    #[test]
    fn test_is_empty_initially() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_len_after_adds() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        for i in 0..5 {
            cache.add(format!("v{}", i), vec![0.0; 4]).unwrap();
        }
        assert_eq!(cache.len(), 5);
    }

    // ── add_with_metadata ──────────────────────────────────────────────────

    #[test]
    fn test_add_with_metadata() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        let mut meta = HashMap::new();
        meta.insert("tag".to_string(), "test".to_string());
        cache
            .add_with_metadata("m".to_string(), vec![0.0; 4], meta)
            .unwrap();
        assert_eq!(cache.len(), 1);
    }

    // ── scalar dim params ──────────────────────────────────────────────────

    #[test]
    fn test_scalar_dim_params_roundtrip() {
        let params = ScalarDimParams::new(-1.0, 1.0);
        let q = params.quantize(0.0);
        let r = params.dequantize(q);
        assert!((r - 0.0).abs() < 0.02);
    }

    #[test]
    fn test_scalar_dim_params_extremes() {
        let params = ScalarDimParams::new(0.0, 1.0);
        assert_eq!(params.quantize(0.0), 0);
        assert_eq!(params.quantize(1.0), 255);
    }
}