1use anyhow::{anyhow, Result};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26
/// Strategy used to compress embedding vectors into compact byte codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Per-dimension scalar quantization: each f32 dimension maps to one u8.
    Scalar,
    /// Product quantization: the vector is split into subspaces, each encoded
    /// as the index of its nearest codebook centroid.
    Product,
}
37
/// Per-dimension affine quantization parameters for the scalar scheme.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarDimParams {
    /// Minimum value observed for this dimension during training.
    pub min: f32,
    /// Maximum value observed for this dimension during training.
    pub max: f32,
    /// Multiplier mapping `(v - min)` onto the 0..=255 code range.
    pub scale: f32,
}
50
51impl ScalarDimParams {
52 fn new(min: f32, max: f32) -> Self {
53 let range = max - min;
54 let scale = if range > 1e-9 { 255.0 / range } else { 1.0 };
55 Self { min, max, scale }
56 }
57
58 #[inline]
60 pub fn quantize(&self, v: f32) -> u8 {
61 ((v - self.min) * self.scale).clamp(0.0, 255.0) as u8
62 }
63
64 #[inline]
66 pub fn dequantize(&self, code: u8) -> f32 {
67 self.min + (code as f32) / self.scale
68 }
69}
70
/// K-means codebook for one product-quantization subspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PqCodebook {
    /// Number of centroids actually stored (may be fewer than requested).
    pub n_centroids: usize,
    /// Dimensionality of each sub-vector in this subspace.
    pub sub_dim: usize,
    /// Flattened centroid matrix: `n_centroids * sub_dim` values, row-major.
    pub centroids: Vec<f32>,
}
83
84impl PqCodebook {
85 fn train(sub_vectors: &[Vec<f32>], n_centroids: usize, max_iters: usize) -> Self {
87 let sub_dim = if sub_vectors.is_empty() {
88 0
89 } else {
90 sub_vectors[0].len()
91 };
92
93 if sub_vectors.is_empty() || n_centroids == 0 || sub_dim == 0 {
94 return Self {
95 n_centroids,
96 sub_dim,
97 centroids: Vec::new(),
98 };
99 }
100
101 let actual_k = n_centroids.min(sub_vectors.len());
102
103 let mut centroids: Vec<Vec<f32>> = sub_vectors.iter().take(actual_k).cloned().collect();
105
106 for _ in 0..max_iters {
107 let mut assignments: Vec<usize> = Vec::with_capacity(sub_vectors.len());
109 for sv in sub_vectors {
110 let best = centroids
111 .iter()
112 .enumerate()
113 .map(|(i, c)| (i, euclidean_sq_slice(sv, c)))
114 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
115 .map(|(i, _)| i)
116 .unwrap_or(0);
117 assignments.push(best);
118 }
119
120 let mut new_centroids = vec![vec![0.0_f32; sub_dim]; actual_k];
122 let mut counts = vec![0usize; actual_k];
123 for (sv, &asgn) in sub_vectors.iter().zip(&assignments) {
124 for (d, &v) in sv.iter().enumerate() {
125 new_centroids[asgn][d] += v;
126 }
127 counts[asgn] += 1;
128 }
129 for (c, count) in new_centroids.iter_mut().zip(&counts) {
130 if *count > 0 {
131 for v in c.iter_mut() {
132 *v /= *count as f32;
133 }
134 }
135 }
136 centroids = new_centroids;
137 }
138
139 let flat: Vec<f32> = centroids.into_iter().flatten().collect();
140 Self {
141 n_centroids: actual_k,
142 sub_dim,
143 centroids: flat,
144 }
145 }
146
147 pub fn encode(&self, sub_vec: &[f32]) -> u8 {
149 let best = (0..self.n_centroids)
150 .map(|i| {
151 let offset = i * self.sub_dim;
152 let centroid = &self.centroids[offset..offset + self.sub_dim];
153 (i, euclidean_sq_slice(sub_vec, centroid))
154 })
155 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
156 .map(|(i, _)| i)
157 .unwrap_or(0);
158 (best & 0xFF) as u8
159 }
160
161 pub fn decode(&self, code: u8) -> &[f32] {
163 let i = (code as usize).min(self.n_centroids.saturating_sub(1));
164 let offset = i * self.sub_dim;
165 &self.centroids[offset..offset + self.sub_dim]
166 }
167}
168
/// Tuning knobs for a quantized embedding cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedCacheConfig {
    /// Which quantization scheme to use.
    pub scheme: QuantizationScheme,
    /// Bits per scalar code. NOTE(review): appears unused — the scalar path
    /// always quantizes to 8 bits; confirm before relying on this field.
    pub sq_bits: u8,
    /// Number of PQ subspaces; must evenly divide the vector dimensionality.
    pub pq_n_subspaces: usize,
    /// Requested centroids per PQ subspace (effective count is also bounded
    /// by the number of training samples).
    pub pq_n_centroids: usize,
    /// Maximum k-means iterations per PQ codebook.
    pub pq_max_iters: usize,
    /// Whether vectors are L2-normalized before quantization and search.
    pub normalize: bool,
    /// Upper bound on the number of vectors used by `train`.
    pub max_training_samples: usize,
}
189
impl Default for QuantizedCacheConfig {
    /// Defaults: 8-bit scalar quantization; PQ configured as 8 subspaces ×
    /// 256 centroids with 25 k-means iterations; no normalization; at most
    /// 10 000 training samples.
    fn default() -> Self {
        Self {
            scheme: QuantizationScheme::Scalar,
            sq_bits: 8,
            pq_n_subspaces: 8,
            pq_n_centroids: 256,
            pq_max_iters: 25,
            normalize: false,
            max_training_samples: 10_000,
        }
    }
}
203
/// Running statistics about a quantized embedding cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheMetrics {
    /// Number of vectors currently stored.
    pub vector_count: usize,
    /// Dimensionality of the cached vectors.
    pub dimensions: usize,
    /// Total bytes used by compressed codes.
    pub compressed_bytes: usize,
    /// Total bytes the original f32 vectors would occupy (4 bytes per dim).
    pub uncompressed_bytes: usize,
    /// `uncompressed_bytes / compressed_bytes` (0.0 until vectors are added).
    pub compression_ratio: f64,
    /// Mean absolute per-dimension error measured on training samples.
    pub mean_reconstruction_error: f32,
    /// Number of `search` calls served.
    pub queries_served: u64,
    /// Total candidate distance computations across all searches.
    pub distance_computations: u64,
}
227
impl Default for CacheMetrics {
    /// All counters and ratios start at zero.
    fn default() -> Self {
        Self {
            vector_count: 0,
            dimensions: 0,
            compressed_bytes: 0,
            uncompressed_bytes: 0,
            compression_ratio: 0.0,
            mean_reconstruction_error: 0.0,
            queries_served: 0,
            distance_computations: 0,
        }
    }
}
242
/// One stored vector: its quantized byte code plus caller-supplied metadata.
#[derive(Debug, Clone)]
struct CompressedCode {
    /// Quantized representation (one byte per dimension for Scalar, one byte
    /// per subspace for Product).
    codes: Vec<u8>,
    /// Opaque key/value metadata attached at insert time.
    metadata: HashMap<String, String>,
}
253
/// In-memory embedding cache that stores vectors in quantized form and
/// answers nearest-neighbor queries over their reconstructions.
pub struct QuantizedEmbeddingCache {
    config: QuantizedCacheConfig,
    /// Dimensionality every stored/queried vector must have.
    dimensions: usize,
    /// Scalar-quantization parameters, one per dimension (Scalar scheme).
    sq_params: Vec<ScalarDimParams>,
    /// One codebook per subspace (Product scheme).
    pq_codebooks: Vec<PqCodebook>,
    /// Compressed vectors, indexed by insertion order.
    codes: Vec<CompressedCode>,
    /// External ID -> index into `codes`.
    id_to_idx: HashMap<String, usize>,
    /// Index into `codes` -> external ID (inverse of `id_to_idx`).
    idx_to_id: Vec<String>,
    metrics: CacheMetrics,
}
272
273impl QuantizedEmbeddingCache {
274 pub fn new(config: QuantizedCacheConfig, dimensions: usize) -> Self {
276 Self {
277 config,
278 dimensions,
279 sq_params: Vec::new(),
280 pq_codebooks: Vec::new(),
281 codes: Vec::new(),
282 id_to_idx: HashMap::new(),
283 idx_to_id: Vec::new(),
284 metrics: CacheMetrics {
285 dimensions,
286 ..Default::default()
287 },
288 }
289 }
290
291 pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
297 if training_vectors.is_empty() {
298 return Err(anyhow!("No training vectors provided"));
299 }
300 let dim = training_vectors[0].len();
301 if dim != self.dimensions {
302 return Err(anyhow!(
303 "Training vector dim {} ≠ cache dim {}",
304 dim,
305 self.dimensions
306 ));
307 }
308
309 let limit = training_vectors.len().min(self.config.max_training_samples);
310 let raw_samples = &training_vectors[..limit];
311
312 let normalized_storage: Vec<Vec<f32>>;
315 let samples: &[Vec<f32>] = if self.config.normalize {
316 normalized_storage = raw_samples.iter().map(|v| normalize_vec(v)).collect();
317 &normalized_storage
318 } else {
319 raw_samples
320 };
321
322 match self.config.scheme {
323 QuantizationScheme::Scalar => self.train_scalar(samples)?,
324 QuantizationScheme::Product => self.train_product(samples)?,
325 }
326
327 let error = self.measure_reconstruction_error(samples);
329 self.metrics.mean_reconstruction_error = error;
330
331 Ok(())
332 }
333
334 fn train_scalar(&mut self, samples: &[Vec<f32>]) -> Result<()> {
335 let mut dim_mins = vec![f32::INFINITY; self.dimensions];
336 let mut dim_maxs = vec![f32::NEG_INFINITY; self.dimensions];
337
338 for v in samples {
339 for (d, &val) in v.iter().enumerate() {
340 dim_mins[d] = dim_mins[d].min(val);
341 dim_maxs[d] = dim_maxs[d].max(val);
342 }
343 }
344
345 self.sq_params = dim_mins
346 .into_iter()
347 .zip(dim_maxs)
348 .map(|(mn, mx)| ScalarDimParams::new(mn, mx))
349 .collect();
350
351 Ok(())
352 }
353
354 fn train_product(&mut self, samples: &[Vec<f32>]) -> Result<()> {
355 let n_sub = self.config.pq_n_subspaces;
356 if self.dimensions % n_sub != 0 {
357 return Err(anyhow!(
358 "dimensions ({}) must be divisible by pq_n_subspaces ({})",
359 self.dimensions,
360 n_sub
361 ));
362 }
363 let sub_dim = self.dimensions / n_sub;
364
365 self.pq_codebooks = (0..n_sub)
366 .map(|s| {
367 let sub_vecs: Vec<Vec<f32>> = samples
368 .iter()
369 .map(|v| v[s * sub_dim..(s + 1) * sub_dim].to_vec())
370 .collect();
371 PqCodebook::train(
372 &sub_vecs,
373 self.config.pq_n_centroids,
374 self.config.pq_max_iters,
375 )
376 })
377 .collect();
378
379 Ok(())
380 }
381
382 fn measure_reconstruction_error(&self, samples: &[Vec<f32>]) -> f32 {
383 let limit = samples.len().min(200);
384 let mut total = 0.0_f32;
385 for v in &samples[..limit] {
386 let normalized = if self.config.normalize {
387 normalize_vec(v)
388 } else {
389 v.clone()
390 };
391 let codes = self.encode_vector(&normalized);
392 let reconstructed = self.decode_codes(&codes);
393 let err: f32 = normalized
394 .iter()
395 .zip(&reconstructed)
396 .map(|(&a, &b)| (a - b).abs())
397 .sum::<f32>()
398 / self.dimensions as f32;
399 total += err;
400 }
401 total / limit as f32
402 }
403
404 pub fn add(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
408 self.add_with_metadata(id, vector, HashMap::new())
409 }
410
411 pub fn add_with_metadata(
413 &mut self,
414 id: String,
415 vector: Vec<f32>,
416 metadata: HashMap<String, String>,
417 ) -> Result<()> {
418 if self.is_untrained() {
419 return Err(anyhow!("Cache not trained; call train() first"));
420 }
421 if vector.len() != self.dimensions {
422 return Err(anyhow!(
423 "Vector dim {} ≠ cache dim {}",
424 vector.len(),
425 self.dimensions
426 ));
427 }
428 if self.id_to_idx.contains_key(&id) {
429 return Err(anyhow!("ID '{}' already in cache", id));
430 }
431
432 let normalized = if self.config.normalize {
433 normalize_vec(&vector)
434 } else {
435 vector
436 };
437 let codes = self.encode_vector(&normalized);
438 let idx = self.codes.len();
439
440 self.codes.push(CompressedCode { codes, metadata });
441 self.id_to_idx.insert(id.clone(), idx);
442 self.idx_to_id.push(id);
443
444 let code_len = self.code_length();
446 self.metrics.vector_count += 1;
447 self.metrics.compressed_bytes += code_len;
448 self.metrics.uncompressed_bytes += self.dimensions * 4;
449 self.metrics.compression_ratio =
450 self.metrics.uncompressed_bytes as f64 / self.metrics.compressed_bytes.max(1) as f64;
451
452 Ok(())
453 }
454
455 pub fn get(&self, id: &str) -> Option<Vec<f32>> {
457 let idx = *self.id_to_idx.get(id)?;
458 Some(self.decode_codes(&self.codes[idx].codes))
459 }
460
461 pub fn len(&self) -> usize {
463 self.codes.len()
464 }
465
466 pub fn is_empty(&self) -> bool {
468 self.codes.is_empty()
469 }
470
471 pub fn search(&mut self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
478 if self.is_untrained() {
479 return Err(anyhow!("Cache not trained"));
480 }
481 if query.len() != self.dimensions {
482 return Err(anyhow!(
483 "Query dim {} ≠ cache dim {}",
484 query.len(),
485 self.dimensions
486 ));
487 }
488
489 let normalized_query = if self.config.normalize {
490 normalize_vec(query)
491 } else {
492 query.to_vec()
493 };
494
495 let mut distances: Vec<(usize, f32)> = self
496 .codes
497 .iter()
498 .enumerate()
499 .map(|(i, code)| {
500 let reconstructed = self.decode_codes(&code.codes);
501 let dist = euclidean_sq_slice(&normalized_query, &reconstructed).sqrt();
502 (i, dist)
503 })
504 .collect();
505
506 self.metrics.distance_computations += self.codes.len() as u64;
507 self.metrics.queries_served += 1;
508
509 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
510 distances.truncate(k);
511
512 Ok(distances
513 .into_iter()
514 .map(|(i, d)| (self.idx_to_id[i].clone(), d))
515 .collect())
516 }
517
518 pub fn metrics(&self) -> &CacheMetrics {
520 &self.metrics
521 }
522
523 pub fn config(&self) -> &QuantizedCacheConfig {
525 &self.config
526 }
527
528 fn encode_vector(&self, vector: &[f32]) -> Vec<u8> {
531 match self.config.scheme {
532 QuantizationScheme::Scalar => vector
533 .iter()
534 .zip(&self.sq_params)
535 .map(|(&v, params)| params.quantize(v))
536 .collect(),
537 QuantizationScheme::Product => {
538 let n_sub = self.pq_codebooks.len();
539 if n_sub == 0 {
540 return Vec::new();
541 }
542 let sub_dim = self.dimensions / n_sub;
543 (0..n_sub)
544 .map(|s| {
545 let sub = &vector[s * sub_dim..(s + 1) * sub_dim];
546 self.pq_codebooks[s].encode(sub)
547 })
548 .collect()
549 }
550 }
551 }
552
553 fn decode_codes(&self, codes: &[u8]) -> Vec<f32> {
554 match self.config.scheme {
555 QuantizationScheme::Scalar => codes
556 .iter()
557 .zip(&self.sq_params)
558 .map(|(&code, params)| params.dequantize(code))
559 .collect(),
560 QuantizationScheme::Product => {
561 let n_sub = self.pq_codebooks.len();
562 if n_sub == 0 {
563 return Vec::new();
564 }
565 let mut out = Vec::with_capacity(self.dimensions);
566 for (s, &code) in (0..n_sub).zip(codes.iter()) {
567 out.extend_from_slice(self.pq_codebooks[s].decode(code));
568 }
569 out
570 }
571 }
572 }
573
574 fn code_length(&self) -> usize {
576 match self.config.scheme {
577 QuantizationScheme::Scalar => self.dimensions, QuantizationScheme::Product => self.config.pq_n_subspaces,
579 }
580 }
581
582 fn is_untrained(&self) -> bool {
583 match self.config.scheme {
584 QuantizationScheme::Scalar => self.sq_params.is_empty(),
585 QuantizationScheme::Product => self.pq_codebooks.is_empty(),
586 }
587 }
588}
589
/// Squared Euclidean distance between two slices.
/// `zip` stops at the shorter slice, so extra elements are ignored.
#[inline]
fn euclidean_sq_slice(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        acc += diff * diff;
    }
    acc
}
603
/// Returns `v` scaled to unit L2 length.
/// Near-zero vectors are returned unchanged to avoid dividing by ~0.
fn normalize_vec(v: &[f32]) -> Vec<f32> {
    let sq_sum: f32 = v.iter().map(|&x| x * x).sum();
    let norm = sq_sum.sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|&x| x / norm).collect()
}
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar-quantization cache with default settings.
    fn make_sq_cache(dims: usize) -> QuantizedEmbeddingCache {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Scalar,
            ..Default::default()
        };
        QuantizedEmbeddingCache::new(config, dims)
    }

    /// Small product-quantization cache (8 centroids, 5 k-means iterations).
    fn make_pq_cache(dims: usize, n_sub: usize) -> QuantizedEmbeddingCache {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Product,
            pq_n_subspaces: n_sub,
            pq_n_centroids: 8,
            pq_max_iters: 5,
            ..Default::default()
        };
        QuantizedEmbeddingCache::new(config, dims)
    }

    /// Deterministic training set: strictly increasing values in 0.01 steps.
    fn training_vecs(n: usize, dims: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dims).map(|d| (i * dims + d) as f32 * 0.01).collect())
            .collect()
    }

    // --- scalar quantization: training ---

    #[test]
    fn test_sq_train_succeeds() {
        let mut cache = make_sq_cache(4);
        let samples = training_vecs(50, 4);
        assert!(cache.train(&samples).is_ok());
        assert_eq!(cache.sq_params.len(), 4);
    }

    #[test]
    fn test_sq_train_empty_fails() {
        let mut cache = make_sq_cache(4);
        assert!(cache.train(&[]).is_err());
    }

    #[test]
    fn test_sq_train_wrong_dim_fails() {
        let mut cache = make_sq_cache(4);
        let samples = vec![vec![1.0_f32; 8]];
        assert!(cache.train(&samples).is_err());
    }

    #[test]
    fn test_sq_untrained_add_fails() {
        let mut cache = make_sq_cache(4);
        let err = cache.add("k".to_string(), vec![0.0; 4]);
        assert!(err.is_err());
    }

    // --- scalar quantization: add / get ---

    #[test]
    fn test_sq_add_and_get() {
        let mut cache = make_sq_cache(4);
        let samples = training_vecs(50, 4);
        cache.train(&samples).unwrap();
        cache
            .add("v0".to_string(), vec![0.1, 0.2, 0.3, 0.4])
            .unwrap();
        let reconstructed = cache.get("v0");
        assert!(reconstructed.is_some());
        let r = reconstructed.unwrap();
        assert_eq!(r.len(), 4);
        for (orig, rec) in [0.1_f32, 0.2, 0.3, 0.4].iter().zip(&r) {
            assert!((orig - rec).abs() < 0.05, "Reconstruction error too large");
        }
    }

    #[test]
    fn test_sq_duplicate_id_fails() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        cache.add("k".to_string(), vec![0.0; 4]).unwrap();
        assert!(cache.add("k".to_string(), vec![1.0; 4]).is_err());
    }

    #[test]
    fn test_sq_get_missing_returns_none() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        assert!(cache.get("absent").is_none());
    }

    // --- scalar quantization: search ---

    #[test]
    fn test_sq_search_returns_nearest() {
        let mut cache = make_sq_cache(2);
        let samples = vec![vec![0.0_f32, 0.0], vec![1.0, 0.0], vec![5.0, 0.0]];
        cache.train(&samples).unwrap();
        cache.add("origin".to_string(), vec![0.0, 0.0]).unwrap();
        cache.add("near".to_string(), vec![0.5, 0.0]).unwrap();
        cache.add("far".to_string(), vec![5.0, 0.0]).unwrap();

        let results = cache.search(&[0.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "origin");
    }

    #[test]
    fn test_sq_search_top_k_ordering() {
        let mut cache = make_sq_cache(1);
        let samples: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32]).collect();
        cache.train(&samples).unwrap();
        for i in 0..10_u32 {
            cache.add(format!("v{}", i), vec![i as f32]).unwrap();
        }
        let results = cache.search(&[5.0], 3).unwrap();
        assert!(results.len() <= 3);
        // Distances must be non-decreasing.
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1 + 1e-6);
        }
    }

    #[test]
    fn test_sq_search_empty_cache() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        let results = cache.search(&[0.0; 4], 5).unwrap();
        assert!(results.is_empty());
    }

    // --- scalar quantization: metrics ---

    #[test]
    fn test_sq_compression_ratio_greater_than_one() {
        let mut cache = make_sq_cache(32);
        cache.train(&training_vecs(100, 32)).unwrap();
        for i in 0..10 {
            cache.add(format!("v{}", i), vec![0.5; 32]).unwrap();
        }
        let m = cache.metrics();
        assert!(m.compression_ratio > 1.0);
        // f32 (4 bytes) -> u8 (1 byte) per dimension.
        assert!(
            (m.compression_ratio - 4.0).abs() < 0.5,
            "SQ ratio should be ~4"
        );
    }

    #[test]
    fn test_sq_metrics_vector_count() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        for i in 0..5 {
            cache.add(format!("v{}", i), vec![i as f32; 4]).unwrap();
        }
        assert_eq!(cache.metrics().vector_count, 5);
    }

    #[test]
    fn test_sq_queries_served_increments() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        cache.add("a".to_string(), vec![0.0; 4]).unwrap();
        cache.search(&[0.0; 4], 1).unwrap();
        cache.search(&[0.0; 4], 1).unwrap();
        assert_eq!(cache.metrics().queries_served, 2);
    }

    #[test]
    fn test_sq_reconstruction_error_reasonable() {
        let mut cache = make_sq_cache(4);
        let samples = training_vecs(100, 4);
        cache.train(&samples).unwrap();
        assert!(cache.metrics().mean_reconstruction_error < 0.1);
    }

    // --- product quantization ---

    #[test]
    fn test_pq_train_succeeds() {
        let mut cache = make_pq_cache(8, 2);
        let samples = training_vecs(50, 8);
        assert!(cache.train(&samples).is_ok());
        assert_eq!(cache.pq_codebooks.len(), 2);
    }

    #[test]
    fn test_pq_train_indivisible_dims_fails() {
        let mut cache = make_pq_cache(7, 3);
        let samples = training_vecs(30, 7);
        assert!(cache.train(&samples).is_err());
    }

    #[test]
    fn test_pq_add_and_get() {
        let mut cache = make_pq_cache(8, 2);
        let samples = training_vecs(50, 8);
        cache.train(&samples).unwrap();
        cache.add("v0".to_string(), vec![0.1; 8]).unwrap();
        let r = cache.get("v0").unwrap();
        assert_eq!(r.len(), 8);
    }

    #[test]
    fn test_pq_compression_ratio() {
        let mut cache = make_pq_cache(16, 4);
        cache.train(&training_vecs(50, 16)).unwrap();
        for i in 0..8 {
            cache.add(format!("v{}", i), vec![0.5; 16]).unwrap();
        }
        let m = cache.metrics();
        assert!(m.compression_ratio > 4.0, "PQ ratio should be > 4");
    }

    #[test]
    fn test_pq_search() {
        let mut cache = make_pq_cache(8, 2);
        let samples = training_vecs(50, 8);
        cache.train(&samples).unwrap();
        cache.add("a".to_string(), vec![0.0; 8]).unwrap();
        cache.add("b".to_string(), vec![10.0; 8]).unwrap();
        let results = cache.search(&[0.1; 8], 1).unwrap();
        assert!(!results.is_empty());
    }

    // --- normalization ---

    #[test]
    fn test_normalized_vectors_stored_as_unit_length() {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Scalar,
            normalize: true,
            ..Default::default()
        };
        let mut cache = QuantizedEmbeddingCache::new(config, 4);
        let long_vecs: Vec<Vec<f32>> = (0..20)
            .map(|i| vec![i as f32 + 1.0, i as f32 + 2.0, 0.0, 0.0])
            .collect();
        cache.train(&long_vecs).unwrap();
        cache
            .add("v".to_string(), vec![3.0, 4.0, 0.0, 0.0])
            .unwrap();
        let r = cache.get("v").unwrap();
        let norm: f32 = r.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.1, "norm={}, expected ~1.0", norm);
    }

    // --- accessors and bookkeeping ---

    #[test]
    fn test_config_accessors() {
        let config = QuantizedCacheConfig {
            scheme: QuantizationScheme::Product,
            pq_n_subspaces: 4,
            pq_n_centroids: 16,
            ..Default::default()
        };
        let cache = QuantizedEmbeddingCache::new(config, 8);
        assert_eq!(cache.config().pq_n_subspaces, 4);
        assert_eq!(cache.config().pq_n_centroids, 16);
    }

    #[test]
    fn test_is_empty_initially() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_len_after_adds() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        for i in 0..5 {
            cache.add(format!("v{}", i), vec![0.0; 4]).unwrap();
        }
        assert_eq!(cache.len(), 5);
    }

    #[test]
    fn test_add_with_metadata() {
        let mut cache = make_sq_cache(4);
        cache.train(&training_vecs(10, 4)).unwrap();
        let mut meta = HashMap::new();
        meta.insert("tag".to_string(), "test".to_string());
        cache
            .add_with_metadata("m".to_string(), vec![0.0; 4], meta)
            .unwrap();
        assert_eq!(cache.len(), 1);
    }

    // --- ScalarDimParams unit tests ---

    #[test]
    fn test_scalar_dim_params_roundtrip() {
        let params = ScalarDimParams::new(-1.0, 1.0);
        let q = params.quantize(0.0);
        let r = params.dequantize(q);
        assert!((r - 0.0).abs() < 0.02);
    }

    #[test]
    fn test_scalar_dim_params_extremes() {
        let params = ScalarDimParams::new(0.0, 1.0);
        assert_eq!(params.quantize(0.0), 0);
        assert_eq!(params.quantize(1.0), 255);
    }
}