1use crate::pq::{ProductQuantizer, PQConfig};
31use crate::sq::{F16Quantizer, Int8Quantizer, VectorQuantizer};
32use crate::{beam_search, BeamSearchConfig, GraphIndex, DiskANN, DiskAnnError, DiskAnnParams};
33use anndists::prelude::Distance;
34use rayon::prelude::*;
35use serde::{Deserialize, Serialize};
36use std::fs::File;
37use std::io::{Read, Write};
38
39#[inline]
41pub(crate) fn quantized_distance_from_codes(
42 query: &[f32],
43 idx: usize,
44 codes: &[u8],
45 code_size: usize,
46 quantizer: &QuantizerState,
47 pq_table: Option<&[f32]>,
48) -> f32 {
49 let code_start = idx * code_size;
50 let code = &codes[code_start..code_start + code_size];
51 match quantizer {
52 QuantizerState::PQ(pq) => {
53 if let Some(table) = pq_table {
54 pq.distance_with_table(table, code)
55 } else {
56 pq.asymmetric_distance(query, code)
57 }
58 }
59 QuantizerState::F16(f16q) => f16q.asymmetric_distance(query, code),
60 QuantizerState::Int8(int8q) => int8q.asymmetric_distance(query, code),
61 }
62}
63
64pub(crate) fn quantized_search(
69 graph: &dyn GraphIndex,
70 codes: &[u8],
71 code_size: usize,
72 quantizer: &QuantizerState,
73 start_ids: &[u32],
74 query: &[f32],
75 k: usize,
76 beam_width: usize,
77 rerank_size: usize,
78 filter_fn: impl Fn(u32) -> bool,
79 config: BeamSearchConfig,
80) -> Vec<(u32, f32)> {
81 let pq_table: Option<Vec<f32>> = match quantizer {
83 QuantizerState::PQ(pq) => Some(pq.create_distance_table(query)),
84 _ => None,
85 };
86
87 let search_k = if rerank_size > 0 { rerank_size.max(k) } else { k };
88
89 let mut results = beam_search(
90 start_ids,
91 beam_width,
92 search_k,
93 |id| quantized_distance_from_codes(query, id as usize, codes, code_size, quantizer, pq_table.as_deref()),
94 |id| graph.get_neighbors(id),
95 &filter_fn,
96 config,
97 );
98
99 if rerank_size > 0 {
101 results = results
102 .iter()
103 .map(|&(id, _)| {
104 let exact_dist = graph.distance_to(query, id);
105 (id, exact_dist)
106 })
107 .collect();
108 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
109 results.truncate(k);
110 }
111
112 results
113}
114
/// Magic number identifying a quantized sidecar: the ASCII bytes `QANN` read as a
/// big-endian `u32` (`0x51414E4E`). The sidecar header stores it little-endian.
const MAGIC: u32 = u32::from_be_bytes(*b"QANN");
/// Sidecar format version; any other value is rejected when loading.
const VERSION: u32 = 1;
119
/// Search-time options for a quantized index.
///
/// The default (`rerank_size: 0`) disables exact re-ranking.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct QuantizedConfig {
    /// Number of candidates to re-score with exact distances after the quantized
    /// beam search. The search collects `max(rerank_size, k)` candidates before
    /// truncating to `k`. `0` disables re-ranking.
    pub rerank_size: usize,
}
133
134#[derive(Serialize, Deserialize, Clone)]
139pub(crate) enum QuantizerState {
140 PQ(ProductQuantizer),
141 F16(F16Quantizer),
142 Int8(Int8Quantizer),
143}
144
145impl QuantizerState {
146 fn quantizer_type_id(&self) -> u8 {
147 match self {
148 QuantizerState::PQ(_) => 0,
149 QuantizerState::F16(_) => 1,
150 QuantizerState::Int8(_) => 2,
151 }
152 }
153
154}
155
156pub struct QuantizedDiskANN<D>
162where
163 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
164{
165 base: DiskANN<D>,
167 codes: Vec<u8>,
170 code_size: usize,
172 quantizer: QuantizerState,
174 rerank_size: usize,
176}
177
178impl<D> QuantizedDiskANN<D>
183where
184 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
185{
186 pub fn from_pq(
190 base: DiskANN<D>,
191 pq: ProductQuantizer,
192 config: QuantizedConfig,
193 ) -> Self {
194 let n = base.num_vectors;
195 let code_size = pq.stats().code_size_bytes;
196 let codes = encode_all_pq(&base, &pq, n);
197 Self {
198 base,
199 codes,
200 code_size,
201 quantizer: QuantizerState::PQ(pq),
202 rerank_size: config.rerank_size,
203 }
204 }
205
206 pub fn from_f16(base: DiskANN<D>, config: QuantizedConfig) -> Self {
208 let dim = base.dim;
209 let n = base.num_vectors;
210 let f16q = F16Quantizer::new(dim);
211 let code_size = dim * 2;
212 let codes = encode_all_generic(&base, &f16q, n, code_size);
213 Self {
214 base,
215 codes,
216 code_size,
217 quantizer: QuantizerState::F16(f16q),
218 rerank_size: config.rerank_size,
219 }
220 }
221
222 pub fn from_int8(
224 base: DiskANN<D>,
225 int8q: Int8Quantizer,
226 config: QuantizedConfig,
227 ) -> Self {
228 let n = base.num_vectors;
229 let code_size = int8q.dim();
230 let codes = encode_all_generic(&base, &int8q, n, code_size);
231 Self {
232 base,
233 codes,
234 code_size,
235 quantizer: QuantizerState::Int8(int8q),
236 rerank_size: config.rerank_size,
237 }
238 }
239
240 pub fn num_vectors(&self) -> usize {
246 self.base.num_vectors
247 }
248
249 pub fn dim(&self) -> usize {
251 self.base.dim
252 }
253
254 pub fn base(&self) -> &DiskANN<D> {
256 &self.base
257 }
258
259 pub fn get_vector(&self, idx: usize) -> Vec<f32> {
261 self.base.get_vector(idx)
262 }
263
264 pub fn search_with_dists(
274 &self,
275 query: &[f32],
276 k: usize,
277 beam_width: usize,
278 ) -> Vec<(u32, f32)> {
279 assert_eq!(
280 query.len(),
281 self.base.dim,
282 "Query dim {} != index dim {}",
283 query.len(),
284 self.base.dim
285 );
286
287 quantized_search(
288 &self.base,
289 &self.codes,
290 self.code_size,
291 &self.quantizer,
292 &[self.base.medoid_id],
293 query,
294 k,
295 beam_width,
296 self.rerank_size,
297 |_| true,
298 BeamSearchConfig::default(),
299 )
300 }
301
302 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
304 self.search_with_dists(query, k, beam_width)
305 .into_iter()
306 .map(|(id, _)| id)
307 .collect()
308 }
309
310 pub fn search_batch(
312 &self,
313 queries: &[Vec<f32>],
314 k: usize,
315 beam_width: usize,
316 ) -> Vec<Vec<u32>> {
317 queries
318 .par_iter()
319 .map(|q| self.search(q, k, beam_width))
320 .collect()
321 }
322
323 pub fn search_filtered_with_dists(
333 &self,
334 query: &[f32],
335 k: usize,
336 beam_width: usize,
337 labels: &[Vec<u64>],
338 filter: &crate::Filter,
339 ) -> Vec<(u32, f32)> {
340 assert_eq!(
341 query.len(),
342 self.base.dim,
343 "Query dim {} != index dim {}",
344 query.len(),
345 self.base.dim
346 );
347 assert_eq!(
348 labels.len(),
349 self.base.num_vectors,
350 "Labels count {} != index vectors {}",
351 labels.len(),
352 self.base.num_vectors
353 );
354
355 if matches!(filter, crate::Filter::None) {
357 return self.search_with_dists(query, k, beam_width);
358 }
359
360 let expanded_beam = (beam_width * 4).max(k * 10);
361
362 quantized_search(
363 &self.base,
364 &self.codes,
365 self.code_size,
366 &self.quantizer,
367 &[self.base.medoid_id],
368 query,
369 k,
370 beam_width,
371 self.rerank_size,
372 |id| filter.matches(&labels[id as usize]),
373 BeamSearchConfig {
374 expanded_beam: Some(expanded_beam),
375 max_iterations: Some(expanded_beam * 2),
376 early_term_factor: Some(1.5),
377 },
378 )
379 }
380
381 pub fn search_filtered(
383 &self,
384 query: &[f32],
385 k: usize,
386 beam_width: usize,
387 labels: &[Vec<u64>],
388 filter: &crate::Filter,
389 ) -> Vec<u32> {
390 self.search_filtered_with_dists(query, k, beam_width, labels, filter)
391 .into_iter()
392 .map(|(id, _)| id)
393 .collect()
394 }
395
396 pub fn save_quantized(&self, path: &str) -> Result<(), DiskAnnError> {
417 let bytes = self.quantized_to_bytes();
418 let mut file = File::create(path)?;
419 file.write_all(&bytes)?;
420 file.sync_all()?;
421 Ok(())
422 }
423
424 pub fn open(
426 base_path: &str,
427 quantized_path: &str,
428 dist: D,
429 config: QuantizedConfig,
430 ) -> Result<Self, DiskAnnError> {
431 let base = DiskANN::open_index_with(base_path, dist)?;
432 let mut file = File::open(quantized_path)?;
433 let mut bytes = Vec::new();
434 file.read_to_end(&mut bytes)?;
435 Self::from_quantized_bytes(&base, &bytes, config)
436 }
437
438 pub fn to_bytes(&self) -> Vec<u8> {
440 let base_bytes = self.base.to_bytes();
442 let quantized_bytes = self.quantized_to_bytes();
443 let mut out = Vec::with_capacity(8 + base_bytes.len() + quantized_bytes.len());
444 out.extend_from_slice(&(base_bytes.len() as u64).to_le_bytes());
445 out.extend_from_slice(&base_bytes);
446 out.extend_from_slice(&quantized_bytes);
447 out
448 }
449
450 pub fn from_bytes(
452 bytes: &[u8],
453 dist: D,
454 config: QuantizedConfig,
455 ) -> Result<Self, DiskAnnError> {
456 if bytes.len() < 8 {
457 return Err(DiskAnnError::IndexError("Buffer too small".into()));
458 }
459 let base_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
460 if bytes.len() < 8 + base_len {
461 return Err(DiskAnnError::IndexError("Buffer too small for base index".into()));
462 }
463 let base_bytes = bytes[8..8 + base_len].to_vec();
464 let quantized_bytes = &bytes[8 + base_len..];
465
466 let base = DiskANN::from_bytes(base_bytes, dist)?;
467 Self::from_quantized_bytes(&base, quantized_bytes, config)
468 }
469
470 fn quantized_to_bytes(&self) -> Vec<u8> {
475 let quantizer_data = bincode::serialize(&self.quantizer).unwrap();
476 let num_vectors = self.base.num_vectors as u64;
477 let code_size = self.code_size as u64;
478 let quantizer_data_len = quantizer_data.len() as u64;
479
480 let header_size = 4 + 4 + 1 + 8 + 8 + 8; let total = header_size + quantizer_data.len() + self.codes.len();
482 let mut out = Vec::with_capacity(total);
483
484 out.extend_from_slice(&MAGIC.to_le_bytes());
485 out.extend_from_slice(&VERSION.to_le_bytes());
486 out.push(self.quantizer.quantizer_type_id());
487 out.extend_from_slice(&num_vectors.to_le_bytes());
488 out.extend_from_slice(&code_size.to_le_bytes());
489 out.extend_from_slice(&quantizer_data_len.to_le_bytes());
490 out.extend_from_slice(&quantizer_data);
491 out.extend_from_slice(&self.codes);
492 out
493 }
494
495 fn from_quantized_bytes(
496 base: &DiskANN<D>,
497 bytes: &[u8],
498 config: QuantizedConfig,
499 ) -> Result<Self, DiskAnnError> {
500 let header_size = 4 + 4 + 1 + 8 + 8 + 8;
501 if bytes.len() < header_size {
502 return Err(DiskAnnError::IndexError("Quantized data too small".into()));
503 }
504
505 let mut pos = 0;
506
507 let magic = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap());
508 pos += 4;
509 if magic != MAGIC {
510 return Err(DiskAnnError::IndexError(format!(
511 "Invalid magic: expected 0x{:08X}, got 0x{:08X}",
512 MAGIC, magic
513 )));
514 }
515
516 let version = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap());
517 pos += 4;
518 if version != VERSION {
519 return Err(DiskAnnError::IndexError(format!(
520 "Unsupported version: {}",
521 version
522 )));
523 }
524
525 let _quantizer_type = bytes[pos];
526 pos += 1;
527
528 let num_vectors = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
529 pos += 8;
530
531 let code_size = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
532 pos += 8;
533
534 let quantizer_data_len =
535 u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
536 pos += 8;
537
538 if bytes.len() < pos + quantizer_data_len {
539 return Err(DiskAnnError::IndexError("Truncated quantizer data".into()));
540 }
541 let quantizer: QuantizerState =
542 bincode::deserialize(&bytes[pos..pos + quantizer_data_len])?;
543 pos += quantizer_data_len;
544
545 let codes_len = num_vectors * code_size;
546 if bytes.len() < pos + codes_len {
547 return Err(DiskAnnError::IndexError("Truncated codes data".into()));
548 }
549 let codes = bytes[pos..pos + codes_len].to_vec();
550
551 let base_bytes = base.to_bytes();
553 let owned_base = DiskANN::from_bytes(base_bytes, base.dist)?;
554
555 Ok(Self {
556 base: owned_base,
557 codes,
558 code_size,
559 quantizer,
560 rerank_size: config.rerank_size,
561 })
562 }
563}
564
565impl<D> QuantizedDiskANN<D>
570where
571 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
572{
573 pub fn build_pq(
575 vectors: &[Vec<f32>],
576 dist: D,
577 file_path: &str,
578 ann_params: DiskAnnParams,
579 pq_config: PQConfig,
580 config: QuantizedConfig,
581 ) -> Result<Self, DiskAnnError> {
582 let base = DiskANN::build_index_with_params(vectors, dist, file_path, ann_params)?;
583 let pq = ProductQuantizer::train(vectors, pq_config)?;
584 Ok(Self::from_pq(base, pq, config))
585 }
586
587 pub fn build_f16(
589 vectors: &[Vec<f32>],
590 dist: D,
591 file_path: &str,
592 ann_params: DiskAnnParams,
593 config: QuantizedConfig,
594 ) -> Result<Self, DiskAnnError> {
595 let base = DiskANN::build_index_with_params(vectors, dist, file_path, ann_params)?;
596 Ok(Self::from_f16(base, config))
597 }
598
599 pub fn build_int8(
601 vectors: &[Vec<f32>],
602 dist: D,
603 file_path: &str,
604 ann_params: DiskAnnParams,
605 config: QuantizedConfig,
606 ) -> Result<Self, DiskAnnError> {
607 let base = DiskANN::build_index_with_params(vectors, dist, file_path, ann_params)?;
608 let int8q = Int8Quantizer::train(vectors)?;
609 Ok(Self::from_int8(base, int8q, config))
610 }
611}
612
613fn encode_all_pq<D>(
618 base: &DiskANN<D>,
619 pq: &ProductQuantizer,
620 n: usize,
621) -> Vec<u8>
622where
623 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
624{
625 let code_size = pq.stats().code_size_bytes;
626 let vectors: Vec<Vec<f32>> = (0..n).map(|i| base.get_vector(i)).collect();
627 let encoded: Vec<Vec<u8>> = vectors.par_iter().map(|v| pq.encode(v)).collect();
628 let mut flat = Vec::with_capacity(n * code_size);
629 for code in &encoded {
630 flat.extend_from_slice(code);
631 }
632 flat
633}
634
635fn encode_all_generic<D, Q>(
636 base: &DiskANN<D>,
637 quantizer: &Q,
638 n: usize,
639 code_size: usize,
640) -> Vec<u8>
641where
642 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
643 Q: VectorQuantizer,
644{
645 let vectors: Vec<Vec<f32>> = (0..n).map(|i| base.get_vector(i)).collect();
646 let encoded: Vec<Vec<u8>> = vectors.par_iter().map(|v| quantizer.encode(v)).collect();
647 let mut flat = Vec::with_capacity(n * code_size);
648 for code in &encoded {
649 flat.extend_from_slice(code);
650 }
651 flat
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::DistL2;
    use rand::prelude::*;
    use rand::SeedableRng;
    use std::collections::HashSet;

    /// Deterministic uniform-[0, 1) vectors for reproducible tests.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>()).collect())
            .collect()
    }

    /// Exact k nearest neighbours by squared L2 distance (ground truth).
    fn brute_force_knn(vectors: &[Vec<f32>], query: &[f32], k: usize) -> Vec<u32> {
        let mut dists: Vec<(u32, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| {
                let d: f32 = query.iter().zip(v).map(|(a, b)| (a - b) * (a - b)).sum();
                (i as u32, d)
            })
            .collect();
        dists.sort_by(|a, b| a.1.total_cmp(&b.1));
        dists.iter().take(k).map(|(i, _)| *i).collect()
    }

    /// Fraction of `ground_truth` ids present in `retrieved`.
    fn recall_at_k(retrieved: &[u32], ground_truth: &[u32]) -> f32 {
        let gt_set: HashSet<u32> = ground_truth.iter().copied().collect();
        let hits = retrieved.iter().filter(|id| gt_set.contains(id)).count();
        hits as f32 / ground_truth.len() as f32
    }

    #[test]
    fn test_quantized_pq_basic() {
        let path = "test_quantized_pq_basic.db";
        let _ = std::fs::remove_file(path);

        let dim = 32;
        let vectors = random_vectors(200, dim, 42);
        let pq_config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };
        let config = QuantizedConfig { rerank_size: 0 };
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 128,
            alpha: 1.2,
        };

        let index = QuantizedDiskANN::<DistL2>::build_pq(
            &vectors, DistL2 {}, path, ann_params, pq_config, config,
        )
        .unwrap();

        let query = &vectors[0];
        let results = index.search(query, 10, 64);
        assert_eq!(results.len(), 10);

        let gt = brute_force_knn(&vectors, query, 10);
        let recall = recall_at_k(&results, &gt);
        assert!(
            recall >= 0.3,
            "PQ recall@10 too low: {recall} (expected >= 0.3)"
        );

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_quantized_f16_basic() {
        let path = "test_quantized_f16_basic.db";
        let _ = std::fs::remove_file(path);

        let dim = 32;
        let vectors = random_vectors(200, dim, 43);
        let config = QuantizedConfig { rerank_size: 0 };
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 128,
            alpha: 1.2,
        };

        let index =
            QuantizedDiskANN::<DistL2>::build_f16(&vectors, DistL2 {}, path, ann_params, config)
                .unwrap();

        let query = &vectors[0];
        let results = index.search(query, 10, 64);
        assert_eq!(results.len(), 10);

        let gt = brute_force_knn(&vectors, query, 10);
        let recall = recall_at_k(&results, &gt);
        assert!(
            recall >= 0.7,
            "F16 recall@10 too low: {recall} (expected >= 0.7)"
        );

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_quantized_int8_basic() {
        let path = "test_quantized_int8_basic.db";
        let _ = std::fs::remove_file(path);

        let dim = 32;
        let vectors = random_vectors(200, dim, 44);
        let config = QuantizedConfig { rerank_size: 0 };
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 128,
            alpha: 1.2,
        };

        let index =
            QuantizedDiskANN::<DistL2>::build_int8(&vectors, DistL2 {}, path, ann_params, config)
                .unwrap();

        let query = &vectors[0];
        let results = index.search(query, 10, 64);
        assert_eq!(results.len(), 10);

        let gt = brute_force_knn(&vectors, query, 10);
        let recall = recall_at_k(&results, &gt);
        assert!(
            recall >= 0.7,
            "Int8 recall@10 too low: {recall} (expected >= 0.7)"
        );

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_reranking_improves_recall() {
        let path_no_rr = "test_quantized_no_rerank.db";
        let path_rr = "test_quantized_rerank.db";
        let _ = std::fs::remove_file(path_no_rr);
        let _ = std::fs::remove_file(path_rr);

        let dim = 32;
        let vectors = random_vectors(200, dim, 45);
        let pq_config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 128,
            alpha: 1.2,
        };

        let index_no_rr = QuantizedDiskANN::<DistL2>::build_pq(
            &vectors,
            DistL2 {},
            path_no_rr,
            ann_params,
            pq_config,
            QuantizedConfig { rerank_size: 0 },
        )
        .unwrap();

        let index_rr = QuantizedDiskANN::<DistL2>::build_pq(
            &vectors,
            DistL2 {},
            path_rr,
            ann_params,
            pq_config,
            QuantizedConfig { rerank_size: 50 },
        )
        .unwrap();

        let num_queries = 20;
        let mut total_no_rr = 0.0f32;
        let mut total_rr = 0.0f32;

        for query in vectors.iter().take(num_queries) {
            let gt = brute_force_knn(&vectors, query, 10);
            let res_no_rr = index_no_rr.search(query, 10, 64);
            let res_rr = index_rr.search(query, 10, 64);
            total_no_rr += recall_at_k(&res_no_rr, &gt);
            total_rr += recall_at_k(&res_rr, &gt);
        }

        let avg_no_rr = total_no_rr / num_queries as f32;
        let avg_rr = total_rr / num_queries as f32;

        // Exact re-ranking should never make recall meaningfully worse.
        assert!(
            avg_rr >= avg_no_rr - 0.05,
            "Re-ranking should not significantly degrade recall: no_rr={avg_no_rr}, rr={avg_rr}"
        );

        let _ = std::fs::remove_file(path_no_rr);
        let _ = std::fs::remove_file(path_rr);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let base_path = "test_quantized_save_base.db";
        let sidecar_path = "test_quantized_save_sidecar.qann";
        let _ = std::fs::remove_file(base_path);
        let _ = std::fs::remove_file(sidecar_path);

        let dim = 32;
        let vectors = random_vectors(100, dim, 46);
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 64,
            alpha: 1.2,
        };
        let config = QuantizedConfig { rerank_size: 10 };

        let index = QuantizedDiskANN::<DistL2>::build_f16(
            &vectors, DistL2 {}, base_path, ann_params, config,
        )
        .unwrap();

        let query = &vectors[0];
        let res_before = index.search(query, 5, 32);

        index.save_quantized(sidecar_path).unwrap();

        let loaded =
            QuantizedDiskANN::<DistL2>::open(base_path, sidecar_path, DistL2 {}, config).unwrap();

        assert_eq!(loaded.num_vectors(), index.num_vectors());
        assert_eq!(loaded.dim(), index.dim());

        let res_after = loaded.search(query, 5, 32);
        assert_eq!(res_before, res_after);

        let _ = std::fs::remove_file(base_path);
        let _ = std::fs::remove_file(sidecar_path);
    }

    #[test]
    fn test_to_bytes_from_bytes() {
        let path = "test_quantized_bytes_rt.db";
        let _ = std::fs::remove_file(path);

        let dim = 32;
        let vectors = random_vectors(100, dim, 47);
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 64,
            alpha: 1.2,
        };
        let config = QuantizedConfig { rerank_size: 0 };

        let index =
            QuantizedDiskANN::<DistL2>::build_int8(&vectors, DistL2 {}, path, ann_params, config)
                .unwrap();

        let query = &vectors[0];
        let res_before = index.search(query, 5, 32);

        let bytes = index.to_bytes();
        let loaded = QuantizedDiskANN::<DistL2>::from_bytes(&bytes, DistL2 {}, config).unwrap();

        assert_eq!(loaded.num_vectors(), index.num_vectors());
        let res_after = loaded.search(query, 5, 32);
        assert_eq!(res_before, res_after);

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_from_bytes_rejects_corrupt_input() {
        let path = "test_quantized_corrupt.db";
        let _ = std::fs::remove_file(path);

        let vectors = random_vectors(100, 32, 50);
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 64,
            alpha: 1.2,
        };
        let config = QuantizedConfig { rerank_size: 0 };

        let index =
            QuantizedDiskANN::<DistL2>::build_f16(&vectors, DistL2 {}, path, ann_params, config)
                .unwrap();
        let bytes = index.to_bytes();

        // Too short to hold even the base-length prefix.
        assert!(QuantizedDiskANN::<DistL2>::from_bytes(&bytes[..4], DistL2 {}, config).is_err());

        // The sidecar magic follows the 8-byte length prefix and the base bytes.
        let base_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
        let mut bad_magic = bytes.clone();
        bad_magic[8 + base_len] ^= 0xFF;
        assert!(QuantizedDiskANN::<DistL2>::from_bytes(&bad_magic, DistL2 {}, config).is_err());

        // Dropping the final code byte must be reported as truncation.
        let truncated = &bytes[..bytes.len() - 1];
        assert!(QuantizedDiskANN::<DistL2>::from_bytes(truncated, DistL2 {}, config).is_err());

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_pq_distance_table_used() {
        let path = "test_quantized_pq_table.db";
        let _ = std::fs::remove_file(path);

        let dim = 32;
        let vectors = random_vectors(100, dim, 48);
        let pq_config = PQConfig {
            num_subspaces: 4,
            num_centroids: 64,
            kmeans_iterations: 10,
            training_sample_size: 0,
        };
        let config = QuantizedConfig { rerank_size: 0 };
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 64,
            alpha: 1.2,
        };

        let index = QuantizedDiskANN::<DistL2>::build_pq(
            &vectors, DistL2 {}, path, ann_params, pq_config, config,
        )
        .unwrap();

        let query = &vectors[0];
        let results = index.search_with_dists(query, 10, 32);

        // Table-based PQ distances must be finite and non-negative.
        for (id, dist) in &results {
            assert!(*dist >= 0.0, "Negative distance for id {id}: {dist}");
            assert!(*dist < f32::MAX, "MAX distance for id {id}");
        }

        // Results must come back in ascending distance order.
        for pair in results.windows(2) {
            assert!(
                pair[0].1 <= pair[1].1 + 1e-6,
                "Distances not sorted: {} > {}",
                pair[0].1,
                pair[1].1
            );
        }

        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_quantized_search_with_dists() {
        let path = "test_quantized_with_dists.db";
        let _ = std::fs::remove_file(path);

        let dim = 32;
        let vectors = random_vectors(100, dim, 49);
        let config = QuantizedConfig { rerank_size: 20 };
        let ann_params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 64,
            alpha: 1.2,
        };

        let index =
            QuantizedDiskANN::<DistL2>::build_f16(&vectors, DistL2 {}, path, ann_params, config)
                .unwrap();

        let query = &vectors[0];
        let results = index.search_with_dists(query, 5, 32);
        assert_eq!(results.len(), 5);

        // With re-ranking enabled, returned distances are exact L2 distances.
        for (id, dist) in &results {
            let v = index.get_vector(*id as usize);
            let exact: f32 = query
                .iter()
                .zip(&v)
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f32>()
                .sqrt();
            assert!(
                (dist - exact).abs() < 1e-4,
                "Distance mismatch for id {id}: returned {dist}, exact {exact}"
            );
        }

        for pair in results.windows(2) {
            assert!(pair[0].1 <= pair[1].1 + 1e-6);
        }

        let _ = std::fs::remove_file(path);
    }
}