Skip to main content

diskann_rs/
quantized.rs

1//! # Integrated Quantized Search for DiskANN
2//!
3//! Fuses quantization into graph traversal so beam search uses compressed
4//! in-memory codes for distance computation instead of reading full f32
5//! vectors from disk. Optionally re-ranks top candidates with exact vectors.
6//!
7//! ## Architecture
8//!
9//! `QuantizedDiskANN<D>` wraps an existing `DiskANN<D>`, storing compressed
10//! codes in a flat `Vec<u8>` buffer for cache-friendly access. An enum
11//! `QuantizerState` dispatches between PQ, F16, and Int8 without dynamic
12//! dispatch overhead on the hot path.
13//!
14//! ## Example
15//!
16//! ```ignore
17//! use anndists::dist::DistL2;
18//! use diskann_rs::{QuantizedDiskANN, QuantizedConfig};
19//! use diskann_rs::pq::PQConfig;
20//!
21//! let vectors = vec![vec![0.0f32; 64]; 1000];
22//! let config = QuantizedConfig { rerank_size: 50 };
23//! let index = QuantizedDiskANN::<DistL2>::build_pq(
24//!     &vectors, DistL2{}, "index.db", Default::default(), PQConfig::default(), config,
25//! ).unwrap();
26//!
27//! let results = index.search(&vec![0.0; 64], 10, 64);
28//! ```
29
30use crate::pq::{ProductQuantizer, PQConfig};
31use crate::sq::{F16Quantizer, Int8Quantizer, VectorQuantizer};
32use crate::{beam_search, BeamSearchConfig, GraphIndex, DiskANN, DiskAnnError, DiskAnnParams};
33use anndists::prelude::Distance;
34use rayon::prelude::*;
35use serde::{Deserialize, Serialize};
36use std::fs::File;
37use std::io::{Read, Write};
38
39/// Compute quantized distance for a single candidate from flat code buffer.
40#[inline]
41pub(crate) fn quantized_distance_from_codes(
42    query: &[f32],
43    idx: usize,
44    codes: &[u8],
45    code_size: usize,
46    quantizer: &QuantizerState,
47    pq_table: Option<&[f32]>,
48) -> f32 {
49    let code_start = idx * code_size;
50    let code = &codes[code_start..code_start + code_size];
51    match quantizer {
52        QuantizerState::PQ(pq) => {
53            if let Some(table) = pq_table {
54                pq.distance_with_table(table, code)
55            } else {
56                pq.asymmetric_distance(query, code)
57            }
58        }
59        QuantizerState::F16(f16q) => f16q.asymmetric_distance(query, code),
60        QuantizerState::Int8(int8q) => int8q.asymmetric_distance(query, code),
61    }
62}
63
64/// Shared quantized search implementation usable with any `GraphIndex`.
65///
66/// Performs beam search using quantized distances, with optional re-ranking
67/// using exact distances from the graph and optional label-based filtering.
68pub(crate) fn quantized_search(
69    graph: &dyn GraphIndex,
70    codes: &[u8],
71    code_size: usize,
72    quantizer: &QuantizerState,
73    start_ids: &[u32],
74    query: &[f32],
75    k: usize,
76    beam_width: usize,
77    rerank_size: usize,
78    filter_fn: impl Fn(u32) -> bool,
79    config: BeamSearchConfig,
80) -> Vec<(u32, f32)> {
81    // Pre-query: build PQ distance table if applicable
82    let pq_table: Option<Vec<f32>> = match quantizer {
83        QuantizerState::PQ(pq) => Some(pq.create_distance_table(query)),
84        _ => None,
85    };
86
87    let search_k = if rerank_size > 0 { rerank_size.max(k) } else { k };
88
89    let mut results = beam_search(
90        start_ids,
91        beam_width,
92        search_k,
93        |id| quantized_distance_from_codes(query, id as usize, codes, code_size, quantizer, pq_table.as_deref()),
94        |id| graph.get_neighbors(id),
95        &filter_fn,
96        config,
97    );
98
99    // Re-ranking phase
100    if rerank_size > 0 {
101        results = results
102            .iter()
103            .map(|&(id, _)| {
104                let exact_dist = graph.distance_to(query, id);
105                (id, exact_dist)
106            })
107            .collect();
108        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
109        results.truncate(k);
110    }
111
112    results
113}
114
/// Magic number identifying quantized sidecar files: the ASCII bytes "QANN"
/// interpreted as a big-endian `u32` (`0x51414E4E`).
const MAGIC: u32 = u32::from_be_bytes(*b"QANN");
/// Sidecar file format version written by this module.
const VERSION: u32 = 1;
119
/// Tuning knobs for quantized search.
///
/// The default configuration disables re-ranking (`rerank_size == 0`).
#[derive(Clone, Copy, Debug, Default)]
pub struct QuantizedConfig {
    /// How many of the best quantized candidates are re-scored with exact
    /// vectors once the quantized search has finished.
    /// `0` turns re-ranking off (faster, lower recall).
    pub rerank_size: usize,
}
133
134/// Monomorphic quantizer state — avoids dynamic dispatch on the hot path.
135///
136/// PQ uses `create_distance_table()` once per query then `distance_with_table()`
137/// per candidate (O(M) lookups). F16/Int8 use `asymmetric_distance()` directly.
138#[derive(Serialize, Deserialize, Clone)]
139pub(crate) enum QuantizerState {
140    PQ(ProductQuantizer),
141    F16(F16Quantizer),
142    Int8(Int8Quantizer),
143}
144
145impl QuantizerState {
146    fn quantizer_type_id(&self) -> u8 {
147        match self {
148            QuantizerState::PQ(_) => 0,
149            QuantizerState::F16(_) => 1,
150            QuantizerState::Int8(_) => 2,
151        }
152    }
153
154}
155
156/// Quantized DiskANN index — wraps a `DiskANN<D>` with compressed in-memory codes.
157///
158/// During beam search, distance computations use the compressed codes in RAM
159/// instead of reading full f32 vectors from the backing store. After search,
160/// an optional re-ranking phase re-scores top candidates with exact vectors.
161pub struct QuantizedDiskANN<D>
162where
163    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
164{
165    /// The underlying DiskANN index (graph + full vectors on disk/mmap)
166    base: DiskANN<D>,
167    /// Compressed codes in RAM — contiguous flat buffer.
168    /// `codes[i * code_size .. (i+1) * code_size]` is the code for vector i.
169    codes: Vec<u8>,
170    /// Bytes per code
171    code_size: usize,
172    /// Quantizer state (PQ, F16, or Int8)
173    quantizer: QuantizerState,
174    /// Number of candidates to re-rank with exact vectors (0 = no re-ranking)
175    rerank_size: usize,
176}
177
178// ---------------------------------------------------------------------------
179// Construction from existing components
180// ---------------------------------------------------------------------------
181
182impl<D> QuantizedDiskANN<D>
183where
184    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
185{
186    /// Create from an existing index and a pre-trained ProductQuantizer.
187    ///
188    /// Encodes all vectors from `base` into the in-memory codes buffer.
189    pub fn from_pq(
190        base: DiskANN<D>,
191        pq: ProductQuantizer,
192        config: QuantizedConfig,
193    ) -> Self {
194        let n = base.num_vectors;
195        let code_size = pq.stats().code_size_bytes;
196        let codes = encode_all_pq(&base, &pq, n);
197        Self {
198            base,
199            codes,
200            code_size,
201            quantizer: QuantizerState::PQ(pq),
202            rerank_size: config.rerank_size,
203        }
204    }
205
206    /// Create from an existing index using F16 quantization.
207    pub fn from_f16(base: DiskANN<D>, config: QuantizedConfig) -> Self {
208        let dim = base.dim;
209        let n = base.num_vectors;
210        let f16q = F16Quantizer::new(dim);
211        let code_size = dim * 2;
212        let codes = encode_all_generic(&base, &f16q, n, code_size);
213        Self {
214            base,
215            codes,
216            code_size,
217            quantizer: QuantizerState::F16(f16q),
218            rerank_size: config.rerank_size,
219        }
220    }
221
222    /// Create from an existing index and a pre-trained Int8Quantizer.
223    pub fn from_int8(
224        base: DiskANN<D>,
225        int8q: Int8Quantizer,
226        config: QuantizedConfig,
227    ) -> Self {
228        let n = base.num_vectors;
229        let code_size = int8q.dim();
230        let codes = encode_all_generic(&base, &int8q, n, code_size);
231        Self {
232            base,
233            codes,
234            code_size,
235            quantizer: QuantizerState::Int8(int8q),
236            rerank_size: config.rerank_size,
237        }
238    }
239
240    // -----------------------------------------------------------------------
241    // Accessors
242    // -----------------------------------------------------------------------
243
244    /// Number of vectors in the index.
245    pub fn num_vectors(&self) -> usize {
246        self.base.num_vectors
247    }
248
249    /// Vector dimensionality.
250    pub fn dim(&self) -> usize {
251        self.base.dim
252    }
253
254    /// Reference to the underlying base index.
255    pub fn base(&self) -> &DiskANN<D> {
256        &self.base
257    }
258
259    /// Get a full-precision vector from the base index.
260    pub fn get_vector(&self, idx: usize) -> Vec<f32> {
261        self.base.get_vector(idx)
262    }
263
264    // -----------------------------------------------------------------------
265    // Search
266    // -----------------------------------------------------------------------
267
268    /// Search with quantized beam search, returning `(id, distance)` pairs.
269    ///
270    /// Distances in the returned results are:
271    /// - If `rerank_size > 0`: exact distances from the base index
272    /// - Otherwise: approximate distances from the quantizer
273    pub fn search_with_dists(
274        &self,
275        query: &[f32],
276        k: usize,
277        beam_width: usize,
278    ) -> Vec<(u32, f32)> {
279        assert_eq!(
280            query.len(),
281            self.base.dim,
282            "Query dim {} != index dim {}",
283            query.len(),
284            self.base.dim
285        );
286
287        quantized_search(
288            &self.base,
289            &self.codes,
290            self.code_size,
291            &self.quantizer,
292            &[self.base.medoid_id],
293            query,
294            k,
295            beam_width,
296            self.rerank_size,
297            |_| true,
298            BeamSearchConfig::default(),
299        )
300    }
301
302    /// Search returning only neighbor IDs.
303    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
304        self.search_with_dists(query, k, beam_width)
305            .into_iter()
306            .map(|(id, _)| id)
307            .collect()
308    }
309
310    /// Batch search (parallel over queries).
311    pub fn search_batch(
312        &self,
313        queries: &[Vec<f32>],
314        k: usize,
315        beam_width: usize,
316    ) -> Vec<Vec<u32>> {
317        queries
318            .par_iter()
319            .map(|q| self.search(q, k, beam_width))
320            .collect()
321    }
322
323    /// Search with quantized beam search and a metadata filter.
324    ///
325    /// Composes quantized distance computation with label-based filtering.
326    /// `labels` must have one entry per vector in the index, with the same
327    /// ordering as the base index vectors.
328    ///
329    /// Distances in the returned results are:
330    /// - If `rerank_size > 0`: exact distances from the base index
331    /// - Otherwise: approximate distances from the quantizer
332    pub fn search_filtered_with_dists(
333        &self,
334        query: &[f32],
335        k: usize,
336        beam_width: usize,
337        labels: &[Vec<u64>],
338        filter: &crate::Filter,
339    ) -> Vec<(u32, f32)> {
340        assert_eq!(
341            query.len(),
342            self.base.dim,
343            "Query dim {} != index dim {}",
344            query.len(),
345            self.base.dim
346        );
347        assert_eq!(
348            labels.len(),
349            self.base.num_vectors,
350            "Labels count {} != index vectors {}",
351            labels.len(),
352            self.base.num_vectors
353        );
354
355        // For unfiltered search, use the fast path
356        if matches!(filter, crate::Filter::None) {
357            return self.search_with_dists(query, k, beam_width);
358        }
359
360        let expanded_beam = (beam_width * 4).max(k * 10);
361
362        quantized_search(
363            &self.base,
364            &self.codes,
365            self.code_size,
366            &self.quantizer,
367            &[self.base.medoid_id],
368            query,
369            k,
370            beam_width,
371            self.rerank_size,
372            |id| filter.matches(&labels[id as usize]),
373            BeamSearchConfig {
374                expanded_beam: Some(expanded_beam),
375                max_iterations: Some(expanded_beam * 2),
376                early_term_factor: Some(1.5),
377            },
378        )
379    }
380
381    /// Search with filter, returning only neighbor IDs.
382    pub fn search_filtered(
383        &self,
384        query: &[f32],
385        k: usize,
386        beam_width: usize,
387        labels: &[Vec<u64>],
388        filter: &crate::Filter,
389    ) -> Vec<u32> {
390        self.search_filtered_with_dists(query, k, beam_width, labels, filter)
391            .into_iter()
392            .map(|(id, _)| id)
393            .collect()
394    }
395
396    // -----------------------------------------------------------------------
397    // Persistence — sidecar file
398    // -----------------------------------------------------------------------
399
400    /// Save the quantizer state and codes to a sidecar file.
401    ///
402    /// The base index should be saved separately via `DiskANN::build_index` or
403    /// already persisted on disk.
404    ///
405    /// ## Sidecar Format
406    /// ```text
407    /// [magic: u32 = 0x51414E4E]
408    /// [version: u32 = 1]
409    /// [quantizer_type: u8]        // 0=PQ, 1=F16, 2=Int8
410    /// [num_vectors: u64]
411    /// [code_size: u64]
412    /// [quantizer_data_len: u64]
413    /// [quantizer_data: bincode]
414    /// [codes: num_vectors * code_size bytes]
415    /// ```
416    pub fn save_quantized(&self, path: &str) -> Result<(), DiskAnnError> {
417        let bytes = self.quantized_to_bytes();
418        let mut file = File::create(path)?;
419        file.write_all(&bytes)?;
420        file.sync_all()?;
421        Ok(())
422    }
423
424    /// Open a QuantizedDiskANN from a base index path and a sidecar path.
425    pub fn open(
426        base_path: &str,
427        quantized_path: &str,
428        dist: D,
429        config: QuantizedConfig,
430    ) -> Result<Self, DiskAnnError> {
431        let base = DiskANN::open_index_with(base_path, dist)?;
432        let mut file = File::open(quantized_path)?;
433        let mut bytes = Vec::new();
434        file.read_to_end(&mut bytes)?;
435        Self::from_quantized_bytes(&base, &bytes, config)
436    }
437
438    /// Serialize the quantizer + codes to bytes (without the base index).
439    pub fn to_bytes(&self) -> Vec<u8> {
440        // Format: [base_len:u64][base_bytes][quantized_bytes]
441        let base_bytes = self.base.to_bytes();
442        let quantized_bytes = self.quantized_to_bytes();
443        let mut out = Vec::with_capacity(8 + base_bytes.len() + quantized_bytes.len());
444        out.extend_from_slice(&(base_bytes.len() as u64).to_le_bytes());
445        out.extend_from_slice(&base_bytes);
446        out.extend_from_slice(&quantized_bytes);
447        out
448    }
449
450    /// Deserialize from bytes (base index + quantized data).
451    pub fn from_bytes(
452        bytes: &[u8],
453        dist: D,
454        config: QuantizedConfig,
455    ) -> Result<Self, DiskAnnError> {
456        if bytes.len() < 8 {
457            return Err(DiskAnnError::IndexError("Buffer too small".into()));
458        }
459        let base_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
460        if bytes.len() < 8 + base_len {
461            return Err(DiskAnnError::IndexError("Buffer too small for base index".into()));
462        }
463        let base_bytes = bytes[8..8 + base_len].to_vec();
464        let quantized_bytes = &bytes[8 + base_len..];
465
466        let base = DiskANN::from_bytes(base_bytes, dist)?;
467        Self::from_quantized_bytes(&base, quantized_bytes, config)
468    }
469
470    // -----------------------------------------------------------------------
471    // Internal serialization helpers
472    // -----------------------------------------------------------------------
473
474    fn quantized_to_bytes(&self) -> Vec<u8> {
475        let quantizer_data = bincode::serialize(&self.quantizer).unwrap();
476        let num_vectors = self.base.num_vectors as u64;
477        let code_size = self.code_size as u64;
478        let quantizer_data_len = quantizer_data.len() as u64;
479
480        let header_size = 4 + 4 + 1 + 8 + 8 + 8; // magic + version + type + num_vectors + code_size + qdata_len
481        let total = header_size + quantizer_data.len() + self.codes.len();
482        let mut out = Vec::with_capacity(total);
483
484        out.extend_from_slice(&MAGIC.to_le_bytes());
485        out.extend_from_slice(&VERSION.to_le_bytes());
486        out.push(self.quantizer.quantizer_type_id());
487        out.extend_from_slice(&num_vectors.to_le_bytes());
488        out.extend_from_slice(&code_size.to_le_bytes());
489        out.extend_from_slice(&quantizer_data_len.to_le_bytes());
490        out.extend_from_slice(&quantizer_data);
491        out.extend_from_slice(&self.codes);
492        out
493    }
494
495    fn from_quantized_bytes(
496        base: &DiskANN<D>,
497        bytes: &[u8],
498        config: QuantizedConfig,
499    ) -> Result<Self, DiskAnnError> {
500        let header_size = 4 + 4 + 1 + 8 + 8 + 8;
501        if bytes.len() < header_size {
502            return Err(DiskAnnError::IndexError("Quantized data too small".into()));
503        }
504
505        let mut pos = 0;
506
507        let magic = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap());
508        pos += 4;
509        if magic != MAGIC {
510            return Err(DiskAnnError::IndexError(format!(
511                "Invalid magic: expected 0x{:08X}, got 0x{:08X}",
512                MAGIC, magic
513            )));
514        }
515
516        let version = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap());
517        pos += 4;
518        if version != VERSION {
519            return Err(DiskAnnError::IndexError(format!(
520                "Unsupported version: {}",
521                version
522            )));
523        }
524
525        let _quantizer_type = bytes[pos];
526        pos += 1;
527
528        let num_vectors = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
529        pos += 8;
530
531        let code_size = u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
532        pos += 8;
533
534        let quantizer_data_len =
535            u64::from_le_bytes(bytes[pos..pos + 8].try_into().unwrap()) as usize;
536        pos += 8;
537
538        if bytes.len() < pos + quantizer_data_len {
539            return Err(DiskAnnError::IndexError("Truncated quantizer data".into()));
540        }
541        let quantizer: QuantizerState =
542            bincode::deserialize(&bytes[pos..pos + quantizer_data_len])?;
543        pos += quantizer_data_len;
544
545        let codes_len = num_vectors * code_size;
546        if bytes.len() < pos + codes_len {
547            return Err(DiskAnnError::IndexError("Truncated codes data".into()));
548        }
549        let codes = bytes[pos..pos + codes_len].to_vec();
550
551        // Clone the base index data — we need ownership
552        let base_bytes = base.to_bytes();
553        let owned_base = DiskANN::from_bytes(base_bytes, base.dist)?;
554
555        Ok(Self {
556            base: owned_base,
557            codes,
558            code_size,
559            quantizer,
560            rerank_size: config.rerank_size,
561        })
562    }
563}
564
565// ---------------------------------------------------------------------------
566// Combined builders (vectors -> graph + quantizer + codes)
567// ---------------------------------------------------------------------------
568
569impl<D> QuantizedDiskANN<D>
570where
571    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
572{
573    /// Build a PQ-quantized index from scratch.
574    pub fn build_pq(
575        vectors: &[Vec<f32>],
576        dist: D,
577        file_path: &str,
578        ann_params: DiskAnnParams,
579        pq_config: PQConfig,
580        config: QuantizedConfig,
581    ) -> Result<Self, DiskAnnError> {
582        let base = DiskANN::build_index_with_params(vectors, dist, file_path, ann_params)?;
583        let pq = ProductQuantizer::train(vectors, pq_config)?;
584        Ok(Self::from_pq(base, pq, config))
585    }
586
587    /// Build an F16-quantized index from scratch.
588    pub fn build_f16(
589        vectors: &[Vec<f32>],
590        dist: D,
591        file_path: &str,
592        ann_params: DiskAnnParams,
593        config: QuantizedConfig,
594    ) -> Result<Self, DiskAnnError> {
595        let base = DiskANN::build_index_with_params(vectors, dist, file_path, ann_params)?;
596        Ok(Self::from_f16(base, config))
597    }
598
599    /// Build an Int8-quantized index from scratch.
600    pub fn build_int8(
601        vectors: &[Vec<f32>],
602        dist: D,
603        file_path: &str,
604        ann_params: DiskAnnParams,
605        config: QuantizedConfig,
606    ) -> Result<Self, DiskAnnError> {
607        let base = DiskANN::build_index_with_params(vectors, dist, file_path, ann_params)?;
608        let int8q = Int8Quantizer::train(vectors)?;
609        Ok(Self::from_int8(base, int8q, config))
610    }
611}
612
613// ---------------------------------------------------------------------------
614// Encoding helpers — parallel encoding of all vectors
615// ---------------------------------------------------------------------------
616
617fn encode_all_pq<D>(
618    base: &DiskANN<D>,
619    pq: &ProductQuantizer,
620    n: usize,
621) -> Vec<u8>
622where
623    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
624{
625    let code_size = pq.stats().code_size_bytes;
626    let vectors: Vec<Vec<f32>> = (0..n).map(|i| base.get_vector(i)).collect();
627    let encoded: Vec<Vec<u8>> = vectors.par_iter().map(|v| pq.encode(v)).collect();
628    let mut flat = Vec::with_capacity(n * code_size);
629    for code in &encoded {
630        flat.extend_from_slice(code);
631    }
632    flat
633}
634
635fn encode_all_generic<D, Q>(
636    base: &DiskANN<D>,
637    quantizer: &Q,
638    n: usize,
639    code_size: usize,
640) -> Vec<u8>
641where
642    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
643    Q: VectorQuantizer,
644{
645    let vectors: Vec<Vec<f32>> = (0..n).map(|i| base.get_vector(i)).collect();
646    let encoded: Vec<Vec<u8>> = vectors.par_iter().map(|v| quantizer.encode(v)).collect();
647    let mut flat = Vec::with_capacity(n * code_size);
648    for code in &encoded {
649        flat.extend_from_slice(code);
650    }
651    flat
652}
653
654// ---------------------------------------------------------------------------
655// Tests
656// ---------------------------------------------------------------------------
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use anndists::dist::DistL2;
662    use rand::prelude::*;
663    use rand::SeedableRng;
664    use std::collections::HashSet;
665
666    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
667        let mut rng = StdRng::seed_from_u64(seed);
668        (0..n)
669            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>()).collect())
670            .collect()
671    }
672
673    fn brute_force_knn(vectors: &[Vec<f32>], query: &[f32], k: usize) -> Vec<u32> {
674        let mut dists: Vec<(u32, f32)> = vectors
675            .iter()
676            .enumerate()
677            .map(|(i, v)| {
678                let d: f32 = query.iter().zip(v).map(|(a, b)| (a - b) * (a - b)).sum();
679                (i as u32, d)
680            })
681            .collect();
682        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
683        dists.iter().take(k).map(|(i, _)| *i).collect()
684    }
685
686    fn recall_at_k(retrieved: &[u32], ground_truth: &[u32]) -> f32 {
687        let gt_set: HashSet<u32> = ground_truth.iter().copied().collect();
688        let hits = retrieved.iter().filter(|id| gt_set.contains(id)).count();
689        hits as f32 / ground_truth.len() as f32
690    }
691
692    // Test 1: PQ basic search
693    #[test]
694    fn test_quantized_pq_basic() {
695        let path = "test_quantized_pq_basic.db";
696        let _ = std::fs::remove_file(path);
697
698        let dim = 32;
699        let vectors = random_vectors(200, dim, 42);
700        let pq_config = PQConfig {
701            num_subspaces: 4,
702            num_centroids: 64,
703            kmeans_iterations: 10,
704            training_sample_size: 0,
705        };
706        let config = QuantizedConfig { rerank_size: 0 };
707        let ann_params = DiskAnnParams {
708            max_degree: 32,
709            build_beam_width: 128,
710            alpha: 1.2,
711        };
712
713        let index = QuantizedDiskANN::<DistL2>::build_pq(
714            &vectors, DistL2 {}, path, ann_params, pq_config, config,
715        )
716        .unwrap();
717
718        let query = &vectors[0];
719        let results = index.search(query, 10, 64);
720        assert_eq!(results.len(), 10);
721
722        let gt = brute_force_knn(&vectors, query, 10);
723        let recall = recall_at_k(&results, &gt);
724        assert!(
725            recall >= 0.3,
726            "PQ recall@10 too low: {recall} (expected >= 0.3)"
727        );
728
729        let _ = std::fs::remove_file(path);
730    }
731
732    // Test 2: F16 basic search
733    #[test]
734    fn test_quantized_f16_basic() {
735        let path = "test_quantized_f16_basic.db";
736        let _ = std::fs::remove_file(path);
737
738        let dim = 32;
739        let vectors = random_vectors(200, dim, 43);
740        let config = QuantizedConfig { rerank_size: 0 };
741        let ann_params = DiskAnnParams {
742            max_degree: 32,
743            build_beam_width: 128,
744            alpha: 1.2,
745        };
746
747        let index = QuantizedDiskANN::<DistL2>::build_f16(
748            &vectors, DistL2 {}, path, ann_params, config,
749        )
750        .unwrap();
751
752        let query = &vectors[0];
753        let results = index.search(query, 10, 64);
754        assert_eq!(results.len(), 10);
755
756        let gt = brute_force_knn(&vectors, query, 10);
757        let recall = recall_at_k(&results, &gt);
758        assert!(
759            recall >= 0.7,
760            "F16 recall@10 too low: {recall} (expected >= 0.7)"
761        );
762
763        let _ = std::fs::remove_file(path);
764    }
765
766    // Test 3: Int8 basic search
767    #[test]
768    fn test_quantized_int8_basic() {
769        let path = "test_quantized_int8_basic.db";
770        let _ = std::fs::remove_file(path);
771
772        let dim = 32;
773        let vectors = random_vectors(200, dim, 44);
774        let config = QuantizedConfig { rerank_size: 0 };
775        let ann_params = DiskAnnParams {
776            max_degree: 32,
777            build_beam_width: 128,
778            alpha: 1.2,
779        };
780
781        let index = QuantizedDiskANN::<DistL2>::build_int8(
782            &vectors, DistL2 {}, path, ann_params, config,
783        )
784        .unwrap();
785
786        let query = &vectors[0];
787        let results = index.search(query, 10, 64);
788        assert_eq!(results.len(), 10);
789
790        let gt = brute_force_knn(&vectors, query, 10);
791        let recall = recall_at_k(&results, &gt);
792        assert!(
793            recall >= 0.7,
794            "Int8 recall@10 too low: {recall} (expected >= 0.7)"
795        );
796
797        let _ = std::fs::remove_file(path);
798    }
799
800    // Test 4: Re-ranking improves recall
801    #[test]
802    fn test_reranking_improves_recall() {
803        let path_no_rr = "test_quantized_no_rerank.db";
804        let path_rr = "test_quantized_rerank.db";
805        let _ = std::fs::remove_file(path_no_rr);
806        let _ = std::fs::remove_file(path_rr);
807
808        let dim = 32;
809        let vectors = random_vectors(200, dim, 45);
810        let pq_config = PQConfig {
811            num_subspaces: 4,
812            num_centroids: 64,
813            kmeans_iterations: 10,
814            training_sample_size: 0,
815        };
816        let ann_params = DiskAnnParams {
817            max_degree: 32,
818            build_beam_width: 128,
819            alpha: 1.2,
820        };
821
822        let index_no_rr = QuantizedDiskANN::<DistL2>::build_pq(
823            &vectors,
824            DistL2 {},
825            path_no_rr,
826            ann_params,
827            pq_config,
828            QuantizedConfig { rerank_size: 0 },
829        )
830        .unwrap();
831
832        let index_rr = QuantizedDiskANN::<DistL2>::build_pq(
833            &vectors,
834            DistL2 {},
835            path_rr,
836            ann_params,
837            pq_config,
838            QuantizedConfig { rerank_size: 50 },
839        )
840        .unwrap();
841
842        // Average recall over several queries
843        let num_queries = 20;
844        let mut total_no_rr = 0.0f32;
845        let mut total_rr = 0.0f32;
846
847        for i in 0..num_queries {
848            let query = &vectors[i];
849            let gt = brute_force_knn(&vectors, query, 10);
850            let res_no_rr = index_no_rr.search(query, 10, 64);
851            let res_rr = index_rr.search(query, 10, 64);
852            total_no_rr += recall_at_k(&res_no_rr, &gt);
853            total_rr += recall_at_k(&res_rr, &gt);
854        }
855
856        let avg_no_rr = total_no_rr / num_queries as f32;
857        let avg_rr = total_rr / num_queries as f32;
858
859        // Re-ranking should generally improve or at least match recall
860        assert!(
861            avg_rr >= avg_no_rr - 0.05,
862            "Re-ranking should not significantly degrade recall: no_rr={avg_no_rr}, rr={avg_rr}"
863        );
864
865        let _ = std::fs::remove_file(path_no_rr);
866        let _ = std::fs::remove_file(path_rr);
867    }
868
869    // Test 5: Save/load sidecar round-trip
870    #[test]
871    fn test_save_load_roundtrip() {
872        let base_path = "test_quantized_save_base.db";
873        let sidecar_path = "test_quantized_save_sidecar.qann";
874        let _ = std::fs::remove_file(base_path);
875        let _ = std::fs::remove_file(sidecar_path);
876
877        let dim = 32;
878        let vectors = random_vectors(100, dim, 46);
879        let ann_params = DiskAnnParams {
880            max_degree: 32,
881            build_beam_width: 64,
882            alpha: 1.2,
883        };
884        let config = QuantizedConfig { rerank_size: 10 };
885
886        let index = QuantizedDiskANN::<DistL2>::build_f16(
887            &vectors, DistL2 {}, base_path, ann_params, config,
888        )
889        .unwrap();
890
891        // Search before save
892        let query = &vectors[0];
893        let res_before = index.search(query, 5, 32);
894
895        // Save sidecar
896        index.save_quantized(sidecar_path).unwrap();
897
898        // Reload
899        let loaded = QuantizedDiskANN::<DistL2>::open(
900            base_path,
901            sidecar_path,
902            DistL2 {},
903            config,
904        )
905        .unwrap();
906
907        assert_eq!(loaded.num_vectors(), index.num_vectors());
908        assert_eq!(loaded.dim(), index.dim());
909
910        let res_after = loaded.search(query, 5, 32);
911        assert_eq!(res_before, res_after);
912
913        let _ = std::fs::remove_file(base_path);
914        let _ = std::fs::remove_file(sidecar_path);
915    }
916
917    // Test 6: to_bytes / from_bytes round-trip
918    #[test]
919    fn test_to_bytes_from_bytes() {
920        let path = "test_quantized_bytes_rt.db";
921        let _ = std::fs::remove_file(path);
922
923        let dim = 32;
924        let vectors = random_vectors(100, dim, 47);
925        let ann_params = DiskAnnParams {
926            max_degree: 32,
927            build_beam_width: 64,
928            alpha: 1.2,
929        };
930        let config = QuantizedConfig { rerank_size: 0 };
931
932        let index = QuantizedDiskANN::<DistL2>::build_int8(
933            &vectors, DistL2 {}, path, ann_params, config,
934        )
935        .unwrap();
936
937        let query = &vectors[0];
938        let res_before = index.search(query, 5, 32);
939
940        let bytes = index.to_bytes();
941        let loaded =
942            QuantizedDiskANN::<DistL2>::from_bytes(&bytes, DistL2 {}, config).unwrap();
943
944        assert_eq!(loaded.num_vectors(), index.num_vectors());
945        let res_after = loaded.search(query, 5, 32);
946        assert_eq!(res_before, res_after);
947
948        let _ = std::fs::remove_file(path);
949    }
950
951    // Test 7: PQ distance table produces valid results
952    #[test]
953    fn test_pq_distance_table_used() {
954        let path = "test_quantized_pq_table.db";
955        let _ = std::fs::remove_file(path);
956
957        let dim = 32;
958        let vectors = random_vectors(100, dim, 48);
959        let pq_config = PQConfig {
960            num_subspaces: 4,
961            num_centroids: 64,
962            kmeans_iterations: 10,
963            training_sample_size: 0,
964        };
965        let config = QuantizedConfig { rerank_size: 0 };
966        let ann_params = DiskAnnParams {
967            max_degree: 32,
968            build_beam_width: 64,
969            alpha: 1.2,
970        };
971
972        let index = QuantizedDiskANN::<DistL2>::build_pq(
973            &vectors, DistL2 {}, path, ann_params, pq_config, config,
974        )
975        .unwrap();
976
977        let query = &vectors[0];
978        let results = index.search_with_dists(query, 10, 32);
979
980        // All returned distances should be non-negative
981        for (id, dist) in &results {
982            assert!(*dist >= 0.0, "Negative distance for id {id}: {dist}");
983            assert!(*dist < f32::MAX, "MAX distance for id {id}");
984        }
985
986        // Distances should be non-decreasing
987        for pair in results.windows(2) {
988            assert!(
989                pair[0].1 <= pair[1].1 + 1e-6,
990                "Distances not sorted: {} > {}",
991                pair[0].1,
992                pair[1].1
993            );
994        }
995
996        let _ = std::fs::remove_file(path);
997    }
998
999    // Test 8: search_with_dists returns correct distances
1000    #[test]
1001    fn test_quantized_search_with_dists() {
1002        let path = "test_quantized_with_dists.db";
1003        let _ = std::fs::remove_file(path);
1004
1005        let dim = 32;
1006        let vectors = random_vectors(100, dim, 49);
1007        let config = QuantizedConfig { rerank_size: 20 };
1008        let ann_params = DiskAnnParams {
1009            max_degree: 32,
1010            build_beam_width: 64,
1011            alpha: 1.2,
1012        };
1013
1014        let index = QuantizedDiskANN::<DistL2>::build_f16(
1015            &vectors, DistL2 {}, path, ann_params, config,
1016        )
1017        .unwrap();
1018
1019        let query = &vectors[0];
1020        let results = index.search_with_dists(query, 5, 32);
1021        assert_eq!(results.len(), 5);
1022
1023        // With reranking, distances should be exact L2 (DistL2 returns sqrt of sum-of-squares)
1024        for (id, dist) in &results {
1025            let v = index.get_vector(*id as usize);
1026            let exact: f32 = query
1027                .iter()
1028                .zip(&v)
1029                .map(|(a, b)| (a - b) * (a - b))
1030                .sum::<f32>()
1031                .sqrt();
1032            assert!(
1033                (dist - exact).abs() < 1e-4,
1034                "Distance mismatch for id {id}: returned {dist}, exact {exact}"
1035            );
1036        }
1037
1038        // Distances should be non-decreasing
1039        for pair in results.windows(2) {
1040            assert!(pair[0].1 <= pair[1].1 + 1e-6);
1041        }
1042
1043        let _ = std::fs::remove_file(path);
1044    }
1045}