Skip to main content

ailake_index/
ivf_pq.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// IVF-PQ index: Inverted File Index with Product Quantization.
3//
4// vs HNSW tradeoffs:
5//   - Index size: ~100x smaller (PQ codes vs raw vectors + graph pointers)
6//   - S3 reads: sequential inverted-list scan vs random graph traversal
7//   - Recall: slightly lower at same memory, tunable via nprobe
8//   - Build: O(n * nlist) k-means vs O(n log n) HNSW insertions
9//
10// Two variants selected by IvfPqConfig::residual:
11//   false (default) — global PQ codebook trained on raw vectors. Simpler, backward-compatible.
12//   true (residual) — PQ trained on per-cluster residuals (vec - coarse_centroid).
13//     Same bytes/vector, ~2-4pp better recall@10 because residuals have lower variance
14//     per sub-space than raw vectors. Search uses per-cluster ADC tables (O(nprobe*M*K)
15//     extra precomputation, negligible vs scan cost).
16
17use serde::{Deserialize, Serialize};
18use tracing::{debug, info, warn};
19
20use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
21use ailake_vec::{kmeans_centroids, PQCodebook};
22
23/// K-means dispatch: NVIDIA CUDA → AMD ROCm → CPU rayon fallback.
24fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
25    if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
26        debug!(
27            "ailake: IVF-PQ k-means used NVIDIA CUDA (n={} k={} max_iter={})",
28            vecs.len(),
29            k,
30            max_iter
31        );
32        return result;
33    }
34    if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
35        debug!(
36            "ailake: IVF-PQ k-means used AMD ROCm (n={} k={} max_iter={})",
37            vecs.len(),
38            k,
39            max_iter
40        );
41        return result;
42    }
43    debug!(
44        "ailake: IVF-PQ k-means using CPU rayon (n={} k={} max_iter={})",
45        vecs.len(),
46        k,
47        max_iter
48    );
49    kmeans_centroids(vecs, k, max_iter)
50}
51
52/// Configuration for IVF-PQ index construction and search.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct IvfPqConfig {
55    /// Number of coarse Voronoi cells (inverted lists).
56    /// Rule of thumb: sqrt(n) for balanced recall/speed. Default 256.
57    pub nlist: usize,
58    /// Cells probed per query. Higher = better recall, more compute.
59    /// nprobe=1 is ANN; nprobe=nlist is exact. Default 8.
60    pub nprobe: usize,
61    /// PQ sub-vector count M. Must divide dim. Default 8.
62    pub pq_m: usize,
63    /// PQ centroids per sub-space K. Must be ≤ 256 (u8 codes). Default 256.
64    pub pq_k: usize,
65    /// k-means max iterations for both coarse and PQ training. Default 25.
66    pub max_iter: usize,
67    /// Residual PQ: train PQ on per-cluster residuals (vec - coarse_centroid) instead of raw
68    /// vectors. Same storage cost, ~2-4pp better recall@10. Default false (backward-compatible).
69    #[serde(default)]
70    pub residual: bool,
71}
72
73impl Default for IvfPqConfig {
74    fn default() -> Self {
75        Self {
76            nlist: 256,
77            nprobe: 8,
78            pq_m: 8,
79            pq_k: 256,
80            max_iter: 25,
81            residual: false,
82        }
83    }
84}
85
86impl IvfPqConfig {
87    /// Derive sensible defaults from vector dimensionality.
88    pub fn for_dim(dim: usize) -> Self {
89        let pq_m = (dim / 8).clamp(4, 96);
90        Self {
91            pq_m: find_valid_pq_m(pq_m, dim),
92            ..Self::default()
93        }
94    }
95
96    /// Derive sensible defaults from both dimensionality and dataset size.
97    ///
98    /// `nlist` scales with sqrt(n_vectors) so each cluster gets ~4 vectors on average.
99    /// Clamped to [16, 1024] to avoid degenerate configs on tiny or huge datasets.
100    pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
101        let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
102        let nprobe = (nlist / 4).max(1); // 25% coverage — better candidate quality than nlist/8
103        let pq_m_hint = (dim / 8).clamp(4, 96);
104        Self {
105            nlist,
106            nprobe,
107            pq_m: find_valid_pq_m(pq_m_hint, dim),
108            pq_k: 256,
109            max_iter: 25,
110            residual: false,
111        }
112    }
113
114    /// Enable residual PQ — train on per-cluster residuals for better recall at same storage.
115    pub fn with_residual(mut self) -> Self {
116        self.residual = true;
117        self
118    }
119}
120
121pub struct IvfPqIndex {
122    pub config: IvfPqConfig,
123    pub metric: VectorMetric,
124    pub dim: usize,
125    /// Coarse cluster centroids: [nlist × dim]
126    coarse_centroids: Vec<Vec<f32>>,
127    /// PQ codebook — trained on raw vectors (residual=false) or residuals (residual=true)
128    pq: PQCodebook,
129    /// Inverted lists: row IDs per cluster
130    inv_row_ids: Vec<Vec<u64>>,
131    /// PQ codes per cluster, flat: inv_codes[i].len() == inv_row_ids[i].len() * pq_m
132    inv_codes: Vec<Vec<u8>>,
133    /// Whether codes are residual-encoded (vec - coarse_centroid)
134    residual: bool,
135}
136
137/// Shared codebook trained once and reused across all shards.
138/// When multiple shards share the same codebook, their ADC distances are
139/// numerically comparable — cross-shard merge is correct without reranking.
140#[derive(Clone)]
141pub struct IvfPqCodebook {
142    pub coarse_centroids: Vec<Vec<f32>>,
143    pub pq: PQCodebook,
144    pub nlist: usize,
145    pub nprobe: usize,
146    pub pq_m: usize,
147    pub dim: usize,
148    pub metric: VectorMetric,
149    pub residual: bool,
150}
151
152impl IvfPqIndex {
153    /// Train IVF-PQ index (trains its own coarse + PQ codebook).
154    pub fn train(
155        row_ids: &[RowId],
156        vectors: &[Vec<f32>],
157        metric: VectorMetric,
158        config: IvfPqConfig,
159    ) -> AilakeResult<Self> {
160        let codebook = Self::train_codebook(vectors, metric, &config)?;
161        Self::build_with_codebook(row_ids, vectors, &codebook)
162    }
163
164    /// Train only the coarse quantizer + PQ codebook (no inverted lists).
165    /// Call once, then reuse across shards via `build_with_codebook`.
166    pub fn train_codebook(
167        vectors: &[Vec<f32>],
168        metric: VectorMetric,
169        config: &IvfPqConfig,
170    ) -> AilakeResult<IvfPqCodebook> {
171        let n = vectors.len();
172        if n == 0 {
173            return Err(AilakeError::Catalog(
174                "IVF-PQ training requires at least 1 vector".into(),
175            ));
176        }
177        let dim = vectors[0].len();
178
179        let normed_storage: Vec<Vec<f32>>;
180        let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
181            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
182            &normed_storage
183        } else {
184            vectors
185        };
186
187        let nlist = config.nlist.min(n);
188        if nlist < config.nlist {
189            warn!(
190                "ailake: IVF-PQ nlist clamped from {} to {} (n={} vectors); \
191                 consider using HNSW for small datasets",
192                config.nlist, nlist, n
193            );
194        }
195        let nprobe = config.nprobe.min(nlist);
196        let pq_m = find_valid_pq_m(config.pq_m, dim);
197
198        info!(
199            "ailake: training IVF-PQ codebook — n={} dim={} nlist={} nprobe={} pq_m={}",
200            n, dim, nlist, nprobe, pq_m
201        );
202
203        let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
204
205        let pq_train_vecs: Vec<Vec<f32>>;
206        let pq_input: &[Vec<f32>] = if config.residual {
207            // Compute per-cluster residuals and train PQ on them.
208            // Lower variance per sub-space → better codebook quality at same K.
209            let assignments: Vec<usize> = vecs
210                .iter()
211                .map(|v| nearest_idx(v, &coarse_centroids))
212                .collect();
213            pq_train_vecs = vecs
214                .iter()
215                .zip(assignments.iter())
216                .map(|(v, &c)| {
217                    v.iter()
218                        .zip(coarse_centroids[c].iter())
219                        .map(|(a, b)| a - b)
220                        .collect()
221                })
222                .collect();
223            &pq_train_vecs
224        } else {
225            vecs
226        };
227
228        let pq = PQCodebook::train_with_kmeans(
229            pq_input,
230            pq_m,
231            config.pq_k.min(256),
232            config.max_iter,
233            kmeans_dispatch,
234        )
235        .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
236
237        Ok(IvfPqCodebook {
238            coarse_centroids,
239            pq,
240            nlist,
241            nprobe,
242            pq_m,
243            dim,
244            metric,
245            residual: config.residual,
246        })
247    }
248
249    /// Build inverted lists using a pre-trained codebook. No k-means training.
250    /// All shards built from the same codebook produce comparable ADC distances.
251    pub fn build_with_codebook(
252        row_ids: &[RowId],
253        vectors: &[Vec<f32>],
254        codebook: &IvfPqCodebook,
255    ) -> AilakeResult<Self> {
256        let n = vectors.len();
257        if n == 0 {
258            return Err(AilakeError::Catalog(
259                "IVF-PQ build requires at least 1 vector".into(),
260            ));
261        }
262
263        let normed_storage: Vec<Vec<f32>>;
264        let vecs: &[Vec<f32>] = if codebook.metric == VectorMetric::Cosine {
265            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
266            &normed_storage
267        } else {
268            vectors
269        };
270
271        let nlist = codebook.nlist;
272        let assignments: Vec<usize> = vecs
273            .iter()
274            .map(|v| nearest_idx(v, &codebook.coarse_centroids))
275            .collect();
276
277        let mut inv_row_ids = vec![Vec::new(); nlist];
278        let mut inv_codes = vec![Vec::new(); nlist];
279
280        for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
281            let codes = if codebook.residual {
282                let centroid = &codebook.coarse_centroids[list_idx];
283                let residual: Vec<f32> =
284                    v.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect();
285                codebook.pq.encode(&residual)
286            } else {
287                codebook.pq.encode(v)
288            };
289            inv_row_ids[list_idx].push(row_ids[i].0);
290            inv_codes[list_idx].extend_from_slice(&codes);
291        }
292
293        Ok(IvfPqIndex {
294            config: IvfPqConfig {
295                nlist: codebook.nlist,
296                nprobe: codebook.nprobe,
297                pq_m: codebook.pq_m,
298                pq_k: codebook.pq.num_centroids,
299                max_iter: 0,
300                residual: codebook.residual,
301            },
302            metric: codebook.metric,
303            dim: codebook.dim,
304            coarse_centroids: codebook.coarse_centroids.clone(),
305            pq: codebook.pq.clone(),
306            inv_row_ids,
307            inv_codes,
308            residual: codebook.residual,
309        })
310    }
311
312    /// Search for approximate nearest neighbors.
313    ///
314    /// `nprobe` overrides `config.nprobe` when `Some`. `ef` is ignored (HNSW compat shim).
315    pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
316        let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
317
318        let q_normed: Vec<f32>;
319        let q: &[f32] = if self.metric == VectorMetric::Cosine {
320            q_normed = l2_normalize(query);
321            &q_normed
322        } else {
323            query
324        };
325
326        // Select nprobe nearest coarse centroids
327        let mut c_dists: Vec<(usize, f32)> = self
328            .coarse_centroids
329            .iter()
330            .enumerate()
331            .map(|(i, c)| (i, l2_sq(q, c)))
332            .collect();
333        c_dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
334        c_dists.truncate(nprobe);
335
336        // Non-residual: precompute one global ADC table for the full query.
337        // Residual: per-cluster ADC table — query residual differs per centroid.
338        let global_adc = if !self.residual {
339            Some(self.pq.compute_adc_table(q))
340        } else {
341            None
342        };
343
344        // Scan selected inverted lists
345        let pq_m = self.config.pq_m;
346        let mut candidates: Vec<(RowId, f32)> = Vec::new();
347
348        for (list_idx, _) in &c_dists {
349            let row_ids = &self.inv_row_ids[*list_idx];
350            let codes_flat = &self.inv_codes[*list_idx];
351
352            // For residual mode, subtract coarse centroid from query before computing ADC.
353            let cluster_adc;
354            let adc_table = if self.residual {
355                let centroid = &self.coarse_centroids[*list_idx];
356                let q_res: Vec<f32> = q.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect();
357                cluster_adc = self.pq.compute_adc_table(&q_res);
358                &cluster_adc
359            } else {
360                // SAFETY: global_adc is Some when !self.residual (set above).
361                global_adc
362                    .as_ref()
363                    .expect("global_adc must be Some for non-residual path")
364            };
365
366            for (j, &rid) in row_ids.iter().enumerate() {
367                let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
368                let dist = self.pq.adc_distance(codes, adc_table);
369                candidates.push((RowId(rid), dist));
370            }
371        }
372
373        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
374        candidates.truncate(top_k);
375        candidates
376    }
377
378    pub fn node_count(&self) -> u64 {
379        self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
380    }
381
382    pub fn dim(&self) -> usize {
383        self.dim
384    }
385}
386
387// ── Serialization ──────────────────────────────────────────────────────────────
388//
389// Format: IvfPqSnapshotCore (bincode v1) + optional trailing byte for residual flag.
390//
391// The `residual` flag is appended as a single byte AFTER the bincode payload rather
392// than as a struct field.  bincode v1 is purely positional — #[serde(default)] has
393// no effect on missing bytes at EOF.  Appending the flag as a trailing byte gives us
394// clean backward compatibility:
395//
396//   Legacy file (no trailing byte) → cursor at EOF after core → residual = false
397//   New file (trailing byte 0x01)  → cursor has 1 byte left  → residual = true
398//   Go / C++ readers               → stop after inv_codes, trailing byte ignored
399//
400// This is the only safe way to extend the format without a versioned header rewrite.
401
402#[derive(Serialize, Deserialize)]
403struct IvfPqSnapshotCore {
404    nlist: usize,
405    nprobe: usize,
406    pq_m: usize,
407    pq_k: usize,
408    max_iter: usize,
409    dim: usize,
410    metric: u8,
411    coarse_flat: Vec<f32>, // [nlist * dim]
412    pq: PQCodebook,
413    inv_row_ids: Vec<Vec<u64>>,
414    inv_codes: Vec<Vec<u8>>, // flat per list: len == inv_row_ids[i].len() * pq_m
415}
416
417pub struct IvfPqSerializer;
418
419impl IvfPqSerializer {
420    pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
421        let coarse_flat: Vec<f32> = index
422            .coarse_centroids
423            .iter()
424            .flat_map(|c| c.iter().copied())
425            .collect();
426        let core = IvfPqSnapshotCore {
427            nlist: index.config.nlist,
428            nprobe: index.config.nprobe,
429            pq_m: index.config.pq_m,
430            pq_k: index.config.pq_k,
431            max_iter: index.config.max_iter,
432            dim: index.dim,
433            metric: metric_to_u8(index.metric),
434            coarse_flat,
435            pq: index.pq.clone(),
436            inv_row_ids: index.inv_row_ids.clone(),
437            inv_codes: index.inv_codes.clone(),
438        };
439        let mut bytes =
440            bincode::serialize(&core).map_err(|e| AilakeError::Bincode(e.to_string()))?;
441        // Append residual flag as trailing byte (0x00 = false, 0x01 = true).
442        // Legacy readers (Go, C++, old Rust) stop at this boundary and ignore it.
443        bytes.push(u8::from(index.residual));
444        Ok(bytes)
445    }
446
447    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
448        // Use a cursor so we know exactly how many bytes were consumed by the core.
449        // Any remaining byte after deserialization is the residual flag.
450        let mut cursor = std::io::Cursor::new(bytes);
451        let core: IvfPqSnapshotCore = bincode::deserialize_from(&mut cursor)
452            .map_err(|e| AilakeError::Bincode(e.to_string()))?;
453
454        let residual = if (cursor.position() as usize) < bytes.len() {
455            bytes[cursor.position() as usize] != 0
456        } else {
457            false // legacy file — no trailing byte
458        };
459
460        let metric = u8_to_metric(core.metric)?;
461        let coarse_centroids: Vec<Vec<f32>> = core
462            .coarse_flat
463            .chunks_exact(core.dim)
464            .map(|c| c.to_vec())
465            .collect();
466        Ok(IvfPqIndex {
467            config: IvfPqConfig {
468                nlist: core.nlist,
469                nprobe: core.nprobe,
470                pq_m: core.pq_m,
471                pq_k: core.pq_k,
472                max_iter: core.max_iter,
473                residual,
474            },
475            metric,
476            dim: core.dim,
477            coarse_centroids,
478            pq: core.pq,
479            inv_row_ids: core.inv_row_ids,
480            inv_codes: core.inv_codes,
481            residual,
482        })
483    }
484}
485
486// ── Helpers ────────────────────────────────────────────────────────────────────
487
/// Returns `v` scaled to unit L2 norm.
///
/// A vector whose norm is below `1e-9` is returned unchanged, which avoids
/// dividing by (near) zero.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    let norm = sum_of_squares.sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    let mut unit = Vec::with_capacity(v.len());
    for &x in v {
        unit.push(x / norm);
    }
    unit
}
496
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared when the slices differ in length.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff.powi(2)
        })
        .sum()
}
500
501fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
502    centroids
503        .iter()
504        .enumerate()
505        .map(|(i, c)| (i, l2_sq(v, c)))
506        .min_by(|a, b| a.1.total_cmp(&b.1))
507        .map(|(i, _)| i)
508        .unwrap_or(0)
509}
510
/// Returns the largest `m` in `1..=requested` that divides `dim` evenly.
///
/// Falls back to 1 when no candidate qualifies, which includes `requested == 0`.
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&m| dim.is_multiple_of(m))
        .unwrap_or(1)
}
520
521fn metric_to_u8(m: VectorMetric) -> u8 {
522    match m {
523        VectorMetric::Cosine => 0,
524        VectorMetric::Euclidean => 1,
525        VectorMetric::DotProduct => 2,
526        VectorMetric::NormalizedCosine => 3,
527    }
528}
529
530fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
531    match v {
532        0 => Ok(VectorMetric::Cosine),
533        1 => Ok(VectorMetric::Euclidean),
534        2 => Ok(VectorMetric::DotProduct),
535        3 => Ok(VectorMetric::NormalizedCosine),
536        _ => Err(AilakeError::Catalog(format!(
537            "IVF-PQ codebook deserialization: unknown metric byte {v} (valid: 0=Cosine, 1=Euclidean, 2=DotProduct, 3=NormalizedCosine)"
538        ))),
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` one-hot vectors of length `dim` (vector `i` is 1.0 on axis
    /// `i % dim`) together with row ids `0..n`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let row_ids: Vec<RowId> = (0..n).map(|i| RowId(i as u64)).collect();
        let vecs: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                let mut v = vec![0.0f32; dim];
                v[i % dim] = 1.0;
                v
            })
            .collect();
        (row_ids, vecs)
    }

    /// Config used by every test: a tiny PQ (`pq_m = 2`, `pq_k = 4`) with the
    /// remaining knobs chosen per test.
    fn small_config(nlist: usize, nprobe: usize, max_iter: usize, residual: bool) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
            residual,
        }
    }

    /// Asserts that two result lists contain the same row ids in the same order.
    fn assert_same_row_ids(before: &[(RowId, f32)], after: &[(RowId, f32)]) {
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(after.iter()) {
            assert_eq!(a.0, b.0, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn train_and_search_basic() {
        let dim = 8;
        let (ids, vecs) = make_vecs(64, dim);
        let config = small_config(4, 2, 10, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert_eq!(idx.node_count(), 64);

        let query = vecs[0].clone();
        let results = idx.search(&query, 5, None);
        assert!(!results.is_empty());
        // Top result should be close to query
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let dim = 4;
        let (ids, vecs) = make_vecs(32, dim);
        let config = small_config(4, 2, 10, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, config).unwrap();
        let results = idx.search(&vecs[0], 1, None);
        assert!(!results.is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let dim = 8;
        let (ids, vecs) = make_vecs(32, dim);
        let config = small_config(4, 2, 10, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.node_count(), idx.node_count());
        assert_eq!(idx2.dim(), idx.dim());

        let q = vecs[0].clone();
        assert_same_row_ids(&idx.search(&q, 5, None), &idx2.search(&q, 5, None));
    }

    #[test]
    fn nlist_clamped_to_n() {
        let dim = 4;
        let (ids, vecs) = make_vecs(10, dim); // fewer vectors than the requested nlist
        let config = small_config(256, 8, 5, false); // nlist will be clamped to 10
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }

    #[test]
    fn residual_pq_search_finds_nearest() {
        let dim = 8;
        let (ids, vecs) = make_vecs(64, dim);
        let config = small_config(4, 4, 10, true);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert_eq!(idx.node_count(), 64);
        assert!(idx.residual);

        let query = vecs[0].clone();
        let results = idx.search(&query, 5, None);
        assert!(!results.is_empty());
        assert!(
            results[0].1 < 0.1,
            "nearest residual-PQ result should be close to query"
        );
    }

    #[test]
    fn residual_pq_serialize_roundtrip() {
        let dim = 8;
        let (ids, vecs) = make_vecs(32, dim);
        let config = small_config(4, 2, 10, true);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.node_count(), idx.node_count());
        assert!(idx2.residual, "residual flag must survive roundtrip");

        let q = vecs[0].clone();
        assert_same_row_ids(&idx.search(&q, 5, None), &idx2.search(&q, 5, None));
    }

    #[test]
    fn non_residual_snapshot_deserializes_as_false() {
        let dim = 8;
        let (ids, vecs) = make_vecs(16, dim);
        let config = small_config(2, 1, 5, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();

        // Current format: the flag byte is present and zero.
        let mut bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();
        assert!(
            !idx2.residual,
            "non-residual index must deserialize as residual=false"
        );

        // Legacy format: the snapshot ends right after the core. Dropping the
        // trailing flag byte reproduces that layout.
        assert_eq!(bytes.pop(), Some(0), "trailing byte should be the residual flag");
        let legacy = IvfPqSerializer::from_bytes(&bytes).unwrap();
        assert!(
            !legacy.residual,
            "snapshot without a flag byte must deserialize as residual=false"
        );
        assert_eq!(legacy.node_count(), idx.node_count());
    }
}