//! Vector quantization: scalar quantization, product quantization (PQ),
//! optimized product quantization (OPQ), and benchmarking utilities.

use ipfrs_core::{Error, Result};
use nalgebra::{DMatrix, DVector};
use serde::{Deserialize, Serialize};

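/// Configuration for a [`ScalarQuantizer`]: code width, signedness, and the
/// per-dimension value ranges learned during training.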
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizerConfig {
    /// Bits per quantized component (currently fixed at 8).
    pub bits: u8,
    /// If true, codes are interpreted as `i8`; otherwise as `u8`.
    pub signed: bool,
    /// Per-dimension minimum values observed during training.
    pub min_values: Vec<f32>,
    /// Per-dimension maximum values observed during training.
    pub max_values: Vec<f32>,
}

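/// Per-dimension scalar quantizer: maps each `f32` component into an 8-bit
/// code using the trained `[min, max]` range for that dimension.
///
/// Minimal usage sketch (types and methods as defined in this module):
/// ```ignore
/// let mut sq = ScalarQuantizer::uint8(4);
/// sq.train(&[vec![0.0, 0.5, 1.0, -0.5]])?;
/// let q = sq.quantize(&[0.1, 0.2, 0.3, 0.4])?;
/// let approx = sq.dequantize(&q)?;
/// ```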
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    config: ScalarQuantizerConfig,
    dimension: usize,
    trained: bool,
}

impl ScalarQuantizer {
    /// Creates an untrained quantizer for vectors of the given dimension.
    pub fn new(dimension: usize, signed: bool) -> Self {
        Self {
            config: ScalarQuantizerConfig {
                bits: 8,
                signed,
                // Sentinels so the first training sample always updates them.
                min_values: vec![f32::MAX; dimension],
                max_values: vec![f32::MIN; dimension],
            },
            dimension,
            trained: false,
        }
    }

    /// Convenience constructor for unsigned 8-bit (`u8`) codes.
    pub fn uint8(dimension: usize) -> Self {
        Self::new(dimension, false)
    }

    /// Convenience constructor for signed 8-bit (`i8`) codes.
    pub fn int8(dimension: usize) -> Self {
        Self::new(dimension, true)
    }

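    /// Learns per-dimension `[min, max]` ranges from a training set, then
    /// widens each range slightly (1% margin, or ±0.5 for near-constant
    /// dimensions) so values at the edges quantize without clipping.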
    pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<()> {
        if vectors.is_empty() {
            return Err(Error::InvalidInput(
                "Cannot train on empty vector set".to_string(),
            ));
        }

        for (i, vec) in vectors.iter().enumerate() {
            if vec.len() != self.dimension {
                return Err(Error::InvalidInput(format!(
                    "Vector {} has dimension {}, expected {}",
                    i,
                    vec.len(),
                    self.dimension
                )));
            }
        }

        self.config.min_values = vec![f32::MAX; self.dimension];
        self.config.max_values = vec![f32::MIN; self.dimension];

        for vec in vectors {
            for (i, &val) in vec.iter().enumerate() {
                if val < self.config.min_values[i] {
                    self.config.min_values[i] = val;
                }
                if val > self.config.max_values[i] {
                    self.config.max_values[i] = val;
                }
            }
        }

        for i in 0..self.dimension {
            let range = self.config.max_values[i] - self.config.min_values[i];
            if range < 1e-6 {
                // Degenerate (near-constant) dimension: open a fixed window.
                self.config.min_values[i] -= 0.5;
                self.config.max_values[i] += 0.5;
            } else {
                let margin = range * 0.01;
                self.config.min_values[i] -= margin;
                self.config.max_values[i] += margin;
            }
        }

        self.trained = true;
        Ok(())
    }

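    /// Extends the per-dimension ranges with a single vector. Unlike
    /// [`ScalarQuantizer::train`], no margin is applied.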
    pub fn train_incremental(&mut self, vector: &[f32]) -> Result<()> {
        if vector.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Vector has dimension {}, expected {}",
                vector.len(),
                self.dimension
            )));
        }

        for (i, &val) in vector.iter().enumerate() {
            if val < self.config.min_values[i] {
                self.config.min_values[i] = val;
            }
            if val > self.config.max_values[i] {
                self.config.max_values[i] = val;
            }
        }

        self.trained = true;
        Ok(())
    }

    /// Returns true once the quantizer has seen training data.
    pub fn is_trained(&self) -> bool {
        self.trained
    }

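    /// Quantizes one vector. Each component is normalized to `[0, 1]` within
    /// its trained range and scaled to 255 levels; in signed mode the level is
    /// shifted by -128 and stored as the bit pattern of an `i8`.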
    pub fn quantize(&self, vector: &[f32]) -> Result<QuantizedVector> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Quantizer must be trained before use".to_string(),
            ));
        }

        if vector.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Vector has dimension {}, expected {}",
                vector.len(),
                self.dimension
            )));
        }

        let mut quantized = Vec::with_capacity(self.dimension);

        for (i, &val) in vector.iter().enumerate() {
            let min = self.config.min_values[i];
            let max = self.config.max_values[i];
            let range = max - min;

            let normalized = if range > 1e-6 {
                ((val - min) / range).clamp(0.0, 1.0)
            } else {
                0.5
            };

            let q = if self.config.signed {
                // Map [0, 1] to [-128, 127], then keep the i8 bit pattern.
                ((normalized * 255.0 - 128.0).round() as i8) as u8
            } else {
                (normalized * 255.0).round() as u8
            };

            quantized.push(q);
        }

        Ok(QuantizedVector {
            data: quantized,
            signed: self.config.signed,
        })
    }

    /// Reconstructs an approximate `f32` vector from its codes (the inverse
    /// of [`ScalarQuantizer::quantize`], up to rounding error).
    pub fn dequantize(&self, quantized: &QuantizedVector) -> Result<Vec<f32>> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Quantizer must be trained before use".to_string(),
            ));
        }

        if quantized.data.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Quantized vector has dimension {}, expected {}",
                quantized.data.len(),
                self.dimension
            )));
        }

        let mut result = Vec::with_capacity(self.dimension);

        for (i, &q) in quantized.data.iter().enumerate() {
            let min = self.config.min_values[i];
            let max = self.config.max_values[i];
            let range = max - min;

            let normalized = if self.config.signed {
                ((q as i8) as f32 + 128.0) / 255.0
            } else {
                q as f32 / 255.0
            };

            let val = min + normalized * range;
            result.push(val);
        }

        Ok(result)
    }

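    /// L2 distance computed directly in the quantized (integer) domain,
    /// scaled by 1/255 to return to the normalized code range. Fast but
    /// approximate relative to the distance between the original vectors.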
    pub fn distance_l2_quantized(&self, a: &QuantizedVector, b: &QuantizedVector) -> Result<f32> {
        if a.data.len() != b.data.len() {
            return Err(Error::InvalidInput(
                "Vectors must have same dimension".to_string(),
            ));
        }

        let mut sum_sq: i64 = 0;

        for (qa, qb) in a.data.iter().zip(b.data.iter()) {
            let diff = if a.signed {
                (*qa as i8 as i64) - (*qb as i8 as i64)
            } else {
                (*qa as i64) - (*qb as i64)
            };
            sum_sq += diff * diff;
        }

        Ok((sum_sq as f32).sqrt() / 255.0)
    }

    /// Dot product in the quantized domain, scaled by 1/255² to approximate
    /// the dot product of the normalized values.
    pub fn dot_product_quantized(&self, a: &QuantizedVector, b: &QuantizedVector) -> Result<f32> {
        if a.data.len() != b.data.len() {
            return Err(Error::InvalidInput(
                "Vectors must have same dimension".to_string(),
            ));
        }

        let mut sum: i64 = 0;

        for (qa, qb) in a.data.iter().zip(b.data.iter()) {
            if a.signed {
                sum += (*qa as i8 as i64) * (*qb as i8 as i64);
            } else {
                sum += (*qa as i64) * (*qb as i64);
            }
        }

        Ok(sum as f32 / (255.0 * 255.0))
    }

    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Compression ratio versus `f32` storage: 4 bytes per component down to 1.
    pub fn compression_ratio(&self) -> f32 {
        4.0
    }

    /// Estimated memory in bytes for `num_vectors` quantized vectors: one byte
    /// per component plus the two per-dimension `f32` range vectors.
    pub fn memory_estimate(&self, num_vectors: usize) -> usize {
        num_vectors * self.dimension + 2 * self.dimension * 4
    }
}

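/// A scalar-quantized vector: one 8-bit code per dimension, tagged with the
/// signedness used to interpret the codes.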
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    pub data: Vec<u8>,
    pub signed: bool,
}

impl QuantizedVector {
    pub fn new(data: Vec<u8>, signed: bool) -> Self {
        Self { data, signed }
    }

    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    pub fn size_bytes(&self) -> usize {
        self.data.len()
    }
}

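/// Configuration and learned codebooks for a [`ProductQuantizer`]. Codebook
/// layout: `codebooks[subquantizer][centroid]` is a subvector of length
/// `dimension / num_subquantizers`.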
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizerConfig {
    /// Full vector dimension.
    pub dimension: usize,
    /// Number of subspaces the vector is split into.
    pub num_subquantizers: usize,
    /// Bits per code; each codebook has `2^bits` centroids.
    pub bits_per_subquantizer: u8,
    /// Learned centroids, indexed as `codebooks[subquantizer][centroid]`.
    pub codebooks: Vec<Vec<Vec<f32>>>,
}

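/// Product quantizer (PQ): splits each vector into `num_subquantizers`
/// contiguous subvectors and encodes each by its nearest codebook centroid.
///
/// Minimal usage sketch (API as defined in this module):
/// ```ignore
/// let mut pq = ProductQuantizer::new(8, 2, 4)?;
/// pq.train(&training_vectors, 10)?;
/// let code = pq.quantize(&vector)?;
/// let dist = pq.asymmetric_distance(&query, &code)?;
/// ```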
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    config: ProductQuantizerConfig,
    /// Dimension of each subspace (`dimension / num_subquantizers`).
    subdimension: usize,
    /// Number of centroids per codebook (`2^bits_per_subquantizer`).
    num_centroids: usize,
    trained: bool,
}

impl ProductQuantizer {
    pub fn new(dimension: usize, num_subquantizers: usize, bits: u8) -> Result<Self> {
        if !dimension.is_multiple_of(num_subquantizers) {
            return Err(Error::InvalidInput(format!(
                "Dimension {} must be divisible by num_subquantizers {}",
                dimension, num_subquantizers
            )));
        }

        // Codes are stored one byte each, so at most 8 bits (256 centroids).
        if bits > 8 {
            return Err(Error::InvalidInput(
                "Bits per subquantizer must be <= 8".to_string(),
            ));
        }

        let subdimension = dimension / num_subquantizers;
        let num_centroids = 1 << bits;

        Ok(Self {
            config: ProductQuantizerConfig {
                dimension,
                num_subquantizers,
                bits_per_subquantizer: bits,
                codebooks: Vec::new(),
            },
            subdimension,
            num_centroids,
            trained: false,
        })
    }

    /// Standard configuration: 8 subquantizers with 8 bits (256 centroids) each.
    pub fn standard(dimension: usize) -> Result<Self> {
        Self::new(dimension, 8, 8)
    }

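    /// Trains one k-means codebook per subspace: each training vector is split
    /// into `num_subquantizers` contiguous chunks, and the chunks belonging to
    /// subspace `sq` are clustered into `num_centroids` centroids.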
    pub fn train(&mut self, vectors: &[Vec<f32>], max_iterations: usize) -> Result<()> {
        if vectors.is_empty() {
            return Err(Error::InvalidInput(
                "Cannot train on empty vector set".to_string(),
            ));
        }

        for (i, vec) in vectors.iter().enumerate() {
            if vec.len() != self.config.dimension {
                return Err(Error::InvalidInput(format!(
                    "Vector {} has dimension {}, expected {}",
                    i,
                    vec.len(),
                    self.config.dimension
                )));
            }
        }

        self.config.codebooks = Vec::with_capacity(self.config.num_subquantizers);

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;

            let subvectors: Vec<Vec<f32>> =
                vectors.iter().map(|v| v[start..end].to_vec()).collect();

            let centroids = self.kmeans(&subvectors, self.num_centroids, max_iterations)?;
            self.config.codebooks.push(centroids);
        }

        self.trained = true;
        Ok(())
    }

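    /// Plain Lloyd's k-means with deterministic farthest-point initialization:
    /// the first centroid is the first sample, and each subsequent centroid is
    /// the sample farthest from all centroids chosen so far.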
    fn kmeans(&self, data: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
        if data.is_empty() {
            return Err(Error::InvalidInput("Empty data for k-means".to_string()));
        }

        let dim = data[0].len();
        let n = data.len();
        // Cannot have more distinct centroids than samples.
        let actual_k = k.min(n);
        let mut centroids = Vec::with_capacity(actual_k);

        centroids.push(data[0].clone());

        for _ in 1..actual_k {
            let mut best_idx = 0;
            let mut best_dist = 0.0f32;

            for (i, vec) in data.iter().enumerate() {
                let min_dist = centroids
                    .iter()
                    .map(|c| self.l2_distance(vec, c))
                    .fold(f32::MAX, |a, b| a.min(b));

                if min_dist > best_dist {
                    best_dist = min_dist;
                    best_idx = i;
                }
            }

            centroids.push(data[best_idx].clone());
        }

        let mut assignments = vec![0usize; n];

        for _iter in 0..max_iterations {
            let mut changed = false;
            for (i, vec) in data.iter().enumerate() {
                let nearest = centroids
                    .iter()
                    .enumerate()
                    .map(|(j, c)| (j, self.l2_distance(vec, c)))
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                    .map(|(j, _)| j)
                    .unwrap_or(0);

                if assignments[i] != nearest {
                    assignments[i] = nearest;
                    changed = true;
                }
            }

            if !changed {
                break;
            }

            let mut new_centroids = vec![vec![0.0f32; dim]; actual_k];
            let mut counts = vec![0usize; actual_k];

            for (i, vec) in data.iter().enumerate() {
                let cluster = assignments[i];
                counts[cluster] += 1;
                for (j, &val) in vec.iter().enumerate() {
                    new_centroids[cluster][j] += val;
                }
            }

            for (i, centroid) in new_centroids.iter_mut().enumerate() {
                if counts[i] > 0 {
                    for val in centroid.iter_mut() {
                        *val /= counts[i] as f32;
                    }
                } else {
                    // Empty cluster: keep its previous centroid.
                    *centroid = centroids[i].clone();
                }
            }

            centroids = new_centroids;
        }

        // Pad up to k so every codebook index stays valid even when n < k.
        while centroids.len() < k {
            centroids.push(centroids[centroids.len() - 1].clone());
        }

        Ok(centroids)
    }

    fn l2_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// Encodes a vector as one codebook index per subspace (nearest centroid).
    pub fn quantize(&self, vector: &[f32]) -> Result<PQCode> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained before use".to_string(),
            ));
        }

        if vector.len() != self.config.dimension {
            return Err(Error::InvalidInput(format!(
                "Vector has dimension {}, expected {}",
                vector.len(),
                self.config.dimension
            )));
        }

        let mut codes = Vec::with_capacity(self.config.num_subquantizers);

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;
            let subvector = &vector[start..end];

            let codebook = &self.config.codebooks[sq];
            let (best_idx, _) = codebook
                .iter()
                .enumerate()
                .map(|(i, c)| (i, self.l2_distance(subvector, c)))
                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .unwrap_or((0, 0.0));

            codes.push(best_idx as u8);
        }

        Ok(PQCode { codes })
    }

    /// Reconstructs an approximate vector by concatenating the centroids
    /// selected by each code.
    pub fn dequantize(&self, code: &PQCode) -> Result<Vec<f32>> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained before use".to_string(),
            ));
        }

        if code.codes.len() != self.config.num_subquantizers {
            return Err(Error::InvalidInput(format!(
                "PQ code has {} elements, expected {}",
                code.codes.len(),
                self.config.num_subquantizers
            )));
        }

        let mut result = Vec::with_capacity(self.config.dimension);

        for (sq, &idx) in code.codes.iter().enumerate() {
            let centroid = &self.config.codebooks[sq][idx as usize];
            result.extend_from_slice(centroid);
        }

        Ok(result)
    }

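    /// Asymmetric distance computation (ADC): L2 distance between a
    /// full-precision query and the reconstruction implied by a PQ code,
    /// without materializing the reconstructed vector.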
    pub fn asymmetric_distance(&self, query: &[f32], code: &PQCode) -> Result<f32> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained".to_string(),
            ));
        }

        if query.len() != self.config.dimension {
            return Err(Error::InvalidInput(format!(
                "Query has dimension {}, expected {}",
                query.len(),
                self.config.dimension
            )));
        }

        let mut total_dist_sq = 0.0f32;

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;
            let subquery = &query[start..end];
            let centroid = &self.config.codebooks[sq][code.codes[sq] as usize];

            for (q, c) in subquery.iter().zip(centroid.iter()) {
                let diff = q - c;
                total_dist_sq += diff * diff;
            }
        }

        Ok(total_dist_sq.sqrt())
    }

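    /// Precomputes, for one query, the squared distance from each sub-query to
    /// every centroid in the corresponding codebook. With the table in hand,
    /// the distance to any code is a sum of `num_subquantizers` lookups (see
    /// [`ProductQuantizer::distance_from_table`]).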
    pub fn compute_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>> {
        if !self.trained {
            return Err(Error::InvalidInput(
                "Product quantizer must be trained".to_string(),
            ));
        }

        if query.len() != self.config.dimension {
            return Err(Error::InvalidInput(format!(
                "Query has dimension {}, expected {}",
                query.len(),
                self.config.dimension
            )));
        }

        let mut table = Vec::with_capacity(self.config.num_subquantizers);

        for sq in 0..self.config.num_subquantizers {
            let start = sq * self.subdimension;
            let end = start + self.subdimension;
            let subquery = &query[start..end];

            let distances: Vec<f32> = self.config.codebooks[sq]
                .iter()
                .map(|c| {
                    subquery
                        .iter()
                        .zip(c.iter())
                        .map(|(q, c)| (q - c) * (q - c))
                        .sum::<f32>()
                })
                .collect();

            table.push(distances);
        }

        Ok(table)
    }

    /// Sums the precomputed squared sub-distances for `code` and returns the
    /// square root; equivalent to [`ProductQuantizer::asymmetric_distance`].
    pub fn distance_from_table(&self, table: &[Vec<f32>], code: &PQCode) -> f32 {
        let mut total = 0.0f32;
        for (sq, &idx) in code.codes.iter().enumerate() {
            total += table[sq][idx as usize];
        }
        total.sqrt()
    }

    /// Compression ratio versus `f32` storage: `dimension * 4` bytes reduced
    /// to one byte per subquantizer.
    pub fn compression_ratio(&self) -> f32 {
        (self.config.dimension * 4) as f32 / self.config.num_subquantizers as f32
    }

    pub fn is_trained(&self) -> bool {
        self.trained
    }
}

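/// Compact PQ encoding of one vector: one codebook index per subquantizer.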
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQCode {
    pub codes: Vec<u8>,
}

impl PQCode {
    pub fn size_bytes(&self) -> usize {
        self.codes.len()
    }
}

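/// Optimized product quantization (OPQ): learns an orthogonal rotation that is
/// applied to vectors before PQ encoding, which can reduce quantization error
/// when the data is correlated across subspaces.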
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizedProductQuantizer {
    pq: ProductQuantizer,
    /// Learned orthogonal rotation; `None` until training completes.
    #[serde(with = "rotation_matrix_serde")]
    rotation: Option<DMatrix<f32>>,
    rotation_trained: bool,
}

/// Serde adapter for `Option<DMatrix<f32>>`: stores the shape plus the
/// column-major element slice, matching `DMatrix::from_vec`.
mod rotation_matrix_serde {
    use super::DMatrix;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    #[derive(Serialize, Deserialize)]
    struct MatrixData {
        nrows: usize,
        ncols: usize,
        data: Vec<f32>,
    }

    pub fn serialize<S>(
        matrix: &Option<DMatrix<f32>>,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let opt_data = matrix.as_ref().map(|m| MatrixData {
            nrows: m.nrows(),
            ncols: m.ncols(),
            data: m.as_slice().to_vec(),
        });
        opt_data.serialize(serializer)
    }

    pub fn deserialize<'de, D>(
        deserializer: D,
    ) -> std::result::Result<Option<DMatrix<f32>>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let opt: Option<MatrixData> = Option::deserialize(deserializer)?;
        Ok(opt.map(|data| DMatrix::from_vec(data.nrows, data.ncols, data.data)))
    }
}

impl OptimizedProductQuantizer {
    pub fn new(dimension: usize, num_subquantizers: usize, bits: u8) -> Result<Self> {
        let pq = ProductQuantizer::new(dimension, num_subquantizers, bits)?;
        Ok(Self {
            pq,
            rotation: None,
            rotation_trained: false,
        })
    }

    /// Standard configuration: 8 subquantizers with 8 bits each.
    pub fn standard(dimension: usize) -> Result<Self> {
        Self::new(dimension, 8, 8)
    }

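    /// Alternating optimization: repeatedly (1) rotate the training vectors,
    /// (2) retrain the inner PQ on the rotated data, and (3) re-estimate the
    /// rotation from the current reconstructions via SVD.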
    pub fn train(
        &mut self,
        vectors: &[Vec<f32>],
        max_iterations: usize,
        rotation_iterations: usize,
    ) -> Result<()> {
        if vectors.is_empty() {
            return Err(Error::InvalidInput(
                "Cannot train on empty vector set".to_string(),
            ));
        }

        let dim = self.pq.config.dimension;

        for (i, vec) in vectors.iter().enumerate() {
            if vec.len() != dim {
                return Err(Error::InvalidInput(format!(
                    "Vector {} has dimension {}, expected {}",
                    i,
                    vec.len(),
                    dim
                )));
            }
        }

        // Start from the identity rotation and refine it each iteration.
        let mut rotation = DMatrix::<f32>::identity(dim, dim);

        for iteration in 0..rotation_iterations {
            let rotated = self.apply_rotation_batch(vectors, &rotation);

            self.pq.train(&rotated, max_iterations)?;

            if iteration < rotation_iterations - 1 {
                rotation = self.learn_rotation(vectors, &self.pq)?;
            }
        }

        self.rotation = Some(rotation);
        self.rotation_trained = true;

        Ok(())
    }

    /// Estimates the rotation that best aligns the original vectors with their
    /// PQ reconstructions: accumulate the cross-covariance of vectors and
    /// reconstructions, take its SVD, and return `U * V^T` (the orthogonal
    /// Procrustes solution).
    fn learn_rotation(&self, vectors: &[Vec<f32>], pq: &ProductQuantizer) -> Result<DMatrix<f32>> {
        let dim = pq.config.dimension;
        let n = vectors.len();

        let mut cov = DMatrix::<f32>::zeros(dim, dim);

        for vec in vectors {
            let code = pq.quantize(vec)?;
            let reconstructed = pq.dequantize(&code)?;

            let v = DVector::from_vec(vec.clone());
            let r = DVector::from_vec(reconstructed);
            cov += v * r.transpose();
        }

        cov /= n as f32;

        let svd = cov.svd(true, true);

        match (svd.u, svd.v_t) {
            (Some(u), Some(vt)) => Ok(u * vt),
            _ => {
                // SVD did not produce both factors; fall back to no rotation.
                Ok(DMatrix::identity(dim, dim))
            }
        }
    }

    fn apply_rotation_batch(&self, vectors: &[Vec<f32>], rotation: &DMatrix<f32>) -> Vec<Vec<f32>> {
        vectors
            .iter()
            .map(|v| self.apply_rotation(v, rotation))
            .collect()
    }

    fn apply_rotation(&self, vector: &[f32], rotation: &DMatrix<f32>) -> Vec<f32> {
        let v = DVector::from_vec(vector.to_vec());
        let rotated = rotation * v;
        rotated.as_slice().to_vec()
    }

    /// Rotates the vector, then encodes it with the inner PQ.
    pub fn quantize(&self, vector: &[f32]) -> Result<PQCode> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        let rotated = match &self.rotation {
            Some(r) => self.apply_rotation(vector, r),
            None => vector.to_vec(),
        };

        self.pq.quantize(&rotated)
    }

    /// Decodes with the inner PQ, then applies the inverse rotation. Since the
    /// rotation is orthogonal, its transpose is its inverse.
    pub fn dequantize(&self, code: &PQCode) -> Result<Vec<f32>> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        let rotated = self.pq.dequantize(code)?;

        match &self.rotation {
            Some(r) => {
                let r_inv = r.transpose();
                Ok(self.apply_rotation(&rotated, &r_inv))
            }
            None => Ok(rotated),
        }
    }

    /// Rotates the query, then delegates to the inner PQ's asymmetric distance.
    pub fn asymmetric_distance(&self, query: &[f32], code: &PQCode) -> Result<f32> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        let rotated_query = match &self.rotation {
            Some(r) => self.apply_rotation(query, r),
            None => query.to_vec(),
        };

        self.pq.asymmetric_distance(&rotated_query, code)
    }

    pub fn compute_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>> {
        if !self.is_trained() {
            return Err(Error::InvalidInput("OPQ must be trained".to_string()));
        }

        let rotated_query = match &self.rotation {
            Some(r) => self.apply_rotation(query, r),
            None => query.to_vec(),
        };

        self.pq.compute_distance_table(&rotated_query)
    }

    pub fn distance_from_table(&self, table: &[Vec<f32>], code: &PQCode) -> f32 {
        self.pq.distance_from_table(table, code)
    }

    pub fn compression_ratio(&self) -> f32 {
        self.pq.compression_ratio()
    }

    pub fn is_trained(&self) -> bool {
        self.pq.is_trained() && self.rotation_trained
    }

    #[allow(dead_code)]
    pub fn inner_pq(&self) -> &ProductQuantizer {
        &self.pq
    }
}

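/// Accuracy and resource metrics for one quantization method on one dataset.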
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationBenchmark {
    /// Recall@k pairs: `(k, fraction of true neighbors retained)`.
    pub recall_at_k: Vec<(usize, f32)>,
    pub compression_ratio: f32,
    /// Mean L2 reconstruction error over the dataset.
    pub avg_quantization_error: f32,
    /// Maximum L2 reconstruction error over the dataset.
    pub max_quantization_error: f32,
    /// Bytes saved relative to `f32` storage.
    pub memory_savings: usize,
    /// Rough relative search-speed factor (estimated, not measured).
    pub speed_factor: f32,
}

impl QuantizationBenchmark {
    pub fn summary(&self) -> String {
        let recall_str: Vec<String> = self
            .recall_at_k
            .iter()
            .map(|(k, r)| format!("R@{}: {:.2}%", k, r * 100.0))
            .collect();

        format!(
            "Compression: {:.1}x, Avg Error: {:.4}, {}, Memory Saved: {} bytes",
            self.compression_ratio,
            self.avg_quantization_error,
            recall_str.join(", "),
            self.memory_savings
        )
    }
}

/// Helpers for measuring quantizer quality against brute-force ground truth.
pub struct QuantizationBenchmarker;

impl QuantizationBenchmarker {
    /// Benchmarks a trained scalar quantizer: reconstruction error over
    /// `vectors`, plus recall@k for each `k` by comparing quantized-domain
    /// nearest neighbors of `queries` against `ground_truth`.
    pub fn benchmark_scalar(
        quantizer: &ScalarQuantizer,
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        ground_truth: &[Vec<usize>],
        k_values: &[usize],
    ) -> Result<QuantizationBenchmark> {
        if !quantizer.is_trained() {
            return Err(Error::InvalidInput("Quantizer must be trained".to_string()));
        }

        let quantized: Vec<QuantizedVector> = vectors
            .iter()
            .map(|v| quantizer.quantize(v))
            .collect::<Result<Vec<_>>>()?;

        let mut total_error = 0.0f32;
        let mut max_error = 0.0f32;

        for (i, qv) in quantized.iter().enumerate() {
            let restored = quantizer.dequantize(qv)?;
            let error: f32 = vectors[i]
                .iter()
                .zip(restored.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();
            total_error += error;
            max_error = max_error.max(error);
        }

        let avg_error = total_error / vectors.len() as f32;

        let mut recall_at_k = Vec::new();

        for &k in k_values {
            let mut total_recall = 0.0f32;

            for (qi, query) in queries.iter().enumerate() {
                let query_quantized = quantizer.quantize(query)?;

                let mut distances: Vec<(usize, f32)> = quantized
                    .iter()
                    .enumerate()
                    .map(|(i, qv)| {
                        let dist = quantizer
                            .distance_l2_quantized(&query_quantized, qv)
                            .unwrap_or(f32::MAX);
                        (i, dist)
                    })
                    .collect();

                distances
                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                let found: std::collections::HashSet<usize> =
                    distances.iter().take(k).map(|(i, _)| *i).collect();

                let gt: std::collections::HashSet<usize> =
                    ground_truth[qi].iter().take(k).copied().collect();

                let intersection = found.intersection(&gt).count();
                total_recall += intersection as f32 / k.min(gt.len()) as f32;
            }

            let recall = total_recall / queries.len() as f32;
            recall_at_k.push((k, recall));
        }

        // 4 bytes per f32 component versus 1 byte per quantized component.
        let original_size = vectors.len() * vectors[0].len() * 4;
        let quantized_size = vectors.len() * vectors[0].len();
        let memory_savings = original_size - quantized_size;

        Ok(QuantizationBenchmark {
            recall_at_k,
            compression_ratio: quantizer.compression_ratio(),
            avg_quantization_error: avg_error,
            max_quantization_error: max_error,
            memory_savings,
            speed_factor: 2.0, // rough estimate for u8 arithmetic, not measured
        })
    }

    /// Benchmarks a trained product quantizer, using distance tables for the
    /// recall computation.
    pub fn benchmark_pq(
        pq: &ProductQuantizer,
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        ground_truth: &[Vec<usize>],
        k_values: &[usize],
    ) -> Result<QuantizationBenchmark> {
        if !pq.is_trained() {
            return Err(Error::InvalidInput("PQ must be trained".to_string()));
        }

        let codes: Vec<PQCode> = vectors
            .iter()
            .map(|v| pq.quantize(v))
            .collect::<Result<Vec<_>>>()?;

        let mut total_error = 0.0f32;
        let mut max_error = 0.0f32;

        for (i, code) in codes.iter().enumerate() {
            let restored = pq.dequantize(code)?;
            let error: f32 = vectors[i]
                .iter()
                .zip(restored.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();
            total_error += error;
            max_error = max_error.max(error);
        }

        let avg_error = total_error / vectors.len() as f32;

        let mut recall_at_k = Vec::new();

        for &k in k_values {
            let mut total_recall = 0.0f32;

            for (qi, query) in queries.iter().enumerate() {
                let table = pq.compute_distance_table(query)?;

                let mut distances: Vec<(usize, f32)> = codes
                    .iter()
                    .enumerate()
                    .map(|(i, code)| (i, pq.distance_from_table(&table, code)))
                    .collect();

                distances
                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                let found: std::collections::HashSet<usize> =
                    distances.iter().take(k).map(|(i, _)| *i).collect();

                let gt: std::collections::HashSet<usize> =
                    ground_truth[qi].iter().take(k).copied().collect();

                let intersection = found.intersection(&gt).count();
                total_recall += intersection as f32 / k.min(gt.len()) as f32;
            }

            let recall = total_recall / queries.len() as f32;
            recall_at_k.push((k, recall));
        }

        let original_size = vectors.len() * vectors[0].len() * 4;
        let quantized_size = vectors.len() * codes[0].size_bytes();
        let memory_savings = original_size.saturating_sub(quantized_size);

        Ok(QuantizationBenchmark {
            recall_at_k,
            compression_ratio: pq.compression_ratio(),
            avg_quantization_error: avg_error,
            max_quantization_error: max_error,
            memory_savings,
            speed_factor: 4.0, // rough estimate for table lookups, not measured
        })
    }

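    /// Brute-force exact k-nearest-neighbor indices (squared L2) for each
    /// query, used as the reference when computing recall.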
    pub fn compute_ground_truth(
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        k: usize,
    ) -> Vec<Vec<usize>> {
        queries
            .iter()
            .map(|query| {
                let mut distances: Vec<(usize, f32)> = vectors
                    .iter()
                    .enumerate()
                    .map(|(i, v)| {
                        let dist: f32 = query
                            .iter()
                            .zip(v.iter())
                            .map(|(a, b)| (a - b).powi(2))
                            .sum();
                        (i, dist)
                    })
                    .collect();

                distances
                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                distances.iter().take(k).map(|(i, _)| *i).collect()
            })
            .collect()
    }

    /// Runs scalar quantization and, when the dimension is divisible by 8,
    /// product quantization over the same data and returns both results.
    pub fn compare_methods(
        vectors: &[Vec<f32>],
        queries: &[Vec<f32>],
        k_values: &[usize],
    ) -> Result<QuantizationComparison> {
        let max_k = *k_values.iter().max().unwrap_or(&10);
        let ground_truth = Self::compute_ground_truth(vectors, queries, max_k);

        let mut sq = ScalarQuantizer::uint8(vectors[0].len());
        sq.train(vectors)?;
        let scalar_results =
            Self::benchmark_scalar(&sq, vectors, queries, &ground_truth, k_values)?;

        let dim = vectors[0].len();
        let pq_results = if dim >= 8 && dim.is_multiple_of(8) {
            let mut pq = ProductQuantizer::new(dim, 8, 8)?;
            pq.train(vectors, 20)?;
            Some(Self::benchmark_pq(
                &pq,
                vectors,
                queries,
                &ground_truth,
                k_values,
            )?)
        } else {
            None
        };

        Ok(QuantizationComparison {
            scalar: scalar_results,
            product: pq_results,
            dataset_size: vectors.len(),
            dimension: dim,
        })
    }
}

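/// Side-by-side results of scalar and (optional) product quantization on the
/// same dataset.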
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationComparison {
    pub scalar: QuantizationBenchmark,
    /// `None` when the dimension is not divisible by 8.
    pub product: Option<QuantizationBenchmark>,
    pub dataset_size: usize,
    pub dimension: usize,
}

impl QuantizationComparison {
    pub fn summary(&self) -> String {
        let mut result = format!(
            "Dataset: {} vectors, {} dimensions\n\nScalar Quantization:\n {}\n",
            self.dataset_size,
            self.dimension,
            self.scalar.summary()
        );

        if let Some(ref pq) = self.product {
            result.push_str(&format!("\nProduct Quantization:\n {}\n", pq.summary()));
        }

        result
    }

    /// Returns the method name with the higher recall at `k`, along with that
    /// recall value.
    pub fn best_method_for_k(&self, k: usize) -> (&str, f32) {
        let scalar_recall = self
            .scalar
            .recall_at_k
            .iter()
            .find(|(kv, _)| *kv == k)
            .map(|(_, r)| *r)
            .unwrap_or(0.0);

        if let Some(ref pq) = self.product {
            let pq_recall = pq
                .recall_at_k
                .iter()
                .find(|(kv, _)| *kv == k)
                .map(|(_, r)| *r)
                .unwrap_or(0.0);

            if pq_recall > scalar_recall {
                return ("ProductQuantization", pq_recall);
            }
        }

        ("ScalarQuantization", scalar_recall)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantizer_uint8() {
        let mut quantizer = ScalarQuantizer::uint8(4);

        let vectors = vec![
            vec![0.0, 0.5, 1.0, -0.5],
            vec![1.0, 0.0, 0.5, 0.5],
            vec![0.5, 0.5, 0.0, 0.0],
        ];

        quantizer.train(&vectors).unwrap();
        assert!(quantizer.is_trained());

        let original = vec![0.5, 0.25, 0.75, 0.0];
        let quantized = quantizer.quantize(&original).unwrap();
        let restored = quantizer.dequantize(&quantized).unwrap();

        for (o, r) in original.iter().zip(restored.iter()) {
            assert!((o - r).abs() < 0.05, "Expected {} ~= {}", o, r);
        }

        assert_eq!(quantizer.compression_ratio(), 4.0);
    }

    #[test]
    fn test_scalar_quantizer_int8() {
        let mut quantizer = ScalarQuantizer::int8(4);

        let vectors = vec![vec![-1.0, 0.0, 1.0, -0.5], vec![1.0, -1.0, 0.5, 0.5]];

        quantizer.train(&vectors).unwrap();

        let original = vec![0.0, -0.5, 0.5, 0.25];
        let quantized = quantizer.quantize(&original).unwrap();
        let restored = quantizer.dequantize(&quantized).unwrap();

        for (o, r) in original.iter().zip(restored.iter()) {
            assert!((o - r).abs() < 0.1, "Expected {} ~= {}", o, r);
        }
    }

    #[test]
    fn test_scalar_distance() {
        let mut quantizer = ScalarQuantizer::uint8(4);

        let vectors = vec![vec![0.0, 0.0, 0.0, 0.0], vec![1.0, 1.0, 1.0, 1.0]];

        quantizer.train(&vectors).unwrap();

        let a = quantizer.quantize(&[0.0, 0.0, 0.0, 0.0]).unwrap();
        let b = quantizer.quantize(&[1.0, 1.0, 1.0, 1.0]).unwrap();

        let dist = quantizer.distance_l2_quantized(&a, &b).unwrap();
        assert!(dist > 0.0, "Distance should be positive");
    }

    #[test]
    fn test_product_quantizer() {
        // 8 dimensions, 2 subquantizers, 4 bits (16 centroids each).
        let mut pq = ProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..8).map(|j| i as f32 * 0.01 + j as f32 * 0.1).collect())
            .collect();

        pq.train(&vectors, 10).unwrap();
        assert!(pq.is_trained());

        let original = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = pq.quantize(&original).unwrap();
        let restored = pq.dequantize(&code).unwrap();

        assert_eq!(code.codes.len(), 2);
        assert_eq!(restored.len(), 8);

        // 8 * 4 bytes of f32 compressed to 2 bytes of codes.
        assert_eq!(pq.compression_ratio(), 16.0);
    }

    #[test]
    fn test_pq_distance_table() {
        let mut pq = ProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| i as f32 * 0.02 + j as f32 * 0.1).collect())
            .collect();

        pq.train(&vectors, 5).unwrap();

        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = pq.quantize(&vectors[0]).unwrap();

        let direct_dist = pq.asymmetric_distance(&query, &code).unwrap();
        let table = pq.compute_distance_table(&query).unwrap();
        let table_dist = pq.distance_from_table(&table, &code);

        assert!((direct_dist - table_dist).abs() < 1e-5);
    }

    #[test]
    fn test_quantization_benchmarker() {
        let dim = 8;
        let n_vectors = 50;
        let n_queries = 5;

        let vectors: Vec<Vec<f32>> = (0..n_vectors)
            .map(|i| (0..dim).map(|j| (i as f32 + j as f32) * 0.1).collect())
            .collect();

        let queries: Vec<Vec<f32>> = (0..n_queries)
            .map(|i| {
                (0..dim)
                    .map(|j| (i as f32 * 2.0 + j as f32) * 0.1)
                    .collect()
            })
            .collect();

        let gt = QuantizationBenchmarker::compute_ground_truth(&vectors, &queries, 10);
        assert_eq!(gt.len(), n_queries);
        assert_eq!(gt[0].len(), 10);

        let mut sq = ScalarQuantizer::uint8(dim);
        sq.train(&vectors).unwrap();

        let sq_benchmark =
            QuantizationBenchmarker::benchmark_scalar(&sq, &vectors, &queries, &gt, &[1, 5, 10])
                .unwrap();

        assert_eq!(sq_benchmark.recall_at_k.len(), 3);
        assert!(sq_benchmark.compression_ratio > 1.0);
        assert!(sq_benchmark.memory_savings > 0);
    }

    #[test]
    fn test_quantization_comparison() {
        let dim = 8;
        let n_vectors = 100;
        let n_queries = 10;

        let vectors: Vec<Vec<f32>> = (0..n_vectors)
            .map(|i| (0..dim).map(|j| (i as f32 + j as f32) * 0.05).collect())
            .collect();

        let queries: Vec<Vec<f32>> = (0..n_queries)
            .map(|i| {
                (0..dim)
                    .map(|j| (i as f32 * 3.0 + j as f32) * 0.05)
                    .collect()
            })
            .collect();

        let comparison =
            QuantizationBenchmarker::compare_methods(&vectors, &queries, &[1, 5, 10]).unwrap();

        assert_eq!(comparison.dataset_size, n_vectors);
        assert_eq!(comparison.dimension, dim);
        // dim == 8 is divisible by 8, so the PQ branch should have run.
        assert!(comparison.product.is_some());

        let (method, recall) = comparison.best_method_for_k(10);
        assert!(!method.is_empty());
        assert!((0.0..=1.0).contains(&recall));
    }

    #[test]
    fn test_opq_basic() {
        let mut opq = OptimizedProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..8).map(|j| i as f32 * 0.01 + j as f32 * 0.1).collect())
            .collect();

        opq.train(&vectors, 10, 5).unwrap();
        assert!(opq.is_trained());

        let original = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = opq.quantize(&original).unwrap();
        let restored = opq.dequantize(&code).unwrap();

        assert_eq!(code.codes.len(), 2);
        assert_eq!(restored.len(), 8);

        assert_eq!(opq.compression_ratio(), 16.0);

        let query = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];
        let dist = opq.asymmetric_distance(&query, &code).unwrap();
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_opq_distance_table() {
        let mut opq = OptimizedProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| i as f32 * 0.02 + j as f32 * 0.1).collect())
            .collect();

        opq.train(&vectors, 5, 3).unwrap();

        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code = opq.quantize(&vectors[0]).unwrap();

        let direct_dist = opq.asymmetric_distance(&query, &code).unwrap();
        let table = opq.compute_distance_table(&query).unwrap();
        let table_dist = opq.distance_from_table(&table, &code);

        assert!((direct_dist - table_dist).abs() < 1e-4);
    }

    #[test]
    fn test_opq_serialization() {
        let mut opq = OptimizedProductQuantizer::new(8, 2, 4).unwrap();

        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| i as f32 * 0.02 + j as f32 * 0.1).collect())
            .collect();

        opq.train(&vectors, 5, 3).unwrap();

        let serialized = oxicode::serde::encode_to_vec(&opq, oxicode::config::standard()).unwrap();

        let deserialized: OptimizedProductQuantizer =
            oxicode::serde::decode_owned_from_slice(&serialized, oxicode::config::standard())
                .map(|(v, _)| v)
                .unwrap();

        assert!(deserialized.is_trained());
        assert_eq!(deserialized.compression_ratio(), opq.compression_ratio());

        let original = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let code1 = opq.quantize(&original).unwrap();
        let code2 = deserialized.quantize(&original).unwrap();

        assert_eq!(code1.codes, code2.codes);
    }
}