// ailake_index/ivf_pq.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
// IVF-PQ index: Inverted File Index with Product Quantization.
//
// vs HNSW tradeoffs:
//   - Index size: ~100x smaller (PQ codes vs raw vectors + graph pointers)
//   - S3 reads: sequential inverted-list scan vs random graph traversal
//   - Recall: slightly lower at same memory, tunable via nprobe
//   - Build: O(n * nlist) distance computations per k-means iteration vs O(n log n) HNSW insertions
//
// Two variants selected by IvfPqConfig::residual:
//   false (default) — global PQ codebook trained on raw vectors. Simpler, backward-compatible.
//   true (residual) — PQ trained on per-cluster residuals (vec - coarse_centroid).
//     Same bytes/vector, ~2-4pp better recall@10 because residuals have lower variance
//     per sub-space than raw vectors. Search builds one ADC table per probed cluster
//     (O(nprobe * M * K) extra table entries), which is negligible compared with the scan cost.

use std::borrow::Cow;

use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};

use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
use ailake_vec::{kmeans_centroids, PQCodebook};

23/// K-means dispatch: NVIDIA CUDA → AMD ROCm → CPU rayon fallback.
24fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
25    if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
26        debug!(
27            "ailake: IVF-PQ k-means used NVIDIA CUDA (n={} k={} max_iter={})",
28            vecs.len(),
29            k,
30            max_iter
31        );
32        return result;
33    }
34    if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
35        debug!(
36            "ailake: IVF-PQ k-means used AMD ROCm (n={} k={} max_iter={})",
37            vecs.len(),
38            k,
39            max_iter
40        );
41        return result;
42    }
43    debug!(
44        "ailake: IVF-PQ k-means using CPU rayon (n={} k={} max_iter={})",
45        vecs.len(),
46        k,
47        max_iter
48    );
49    kmeans_centroids(vecs, k, max_iter)
50}
51
52/// Configuration for IVF-PQ index construction and search.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct IvfPqConfig {
55    /// Number of coarse Voronoi cells (inverted lists).
56    /// Rule of thumb: sqrt(n) for balanced recall/speed. Default 256.
57    pub nlist: usize,
58    /// Cells probed per query. Higher = better recall, more compute.
59    /// nprobe=1 is ANN; nprobe=nlist is exact. Default 8.
60    pub nprobe: usize,
61    /// PQ sub-vector count M. Must divide dim. Default 8.
62    pub pq_m: usize,
63    /// PQ centroids per sub-space K. Must be ≤ 256 (u8 codes). Default 256.
64    pub pq_k: usize,
65    /// k-means max iterations for both coarse and PQ training. Default 25.
66    pub max_iter: usize,
67    /// Residual PQ: train PQ on per-cluster residuals (vec - coarse_centroid) instead of raw
68    /// vectors. Same storage cost, ~2-4pp better recall@10. Default false (backward-compatible).
69    #[serde(default)]
70    pub residual: bool,
71}
72
73impl Default for IvfPqConfig {
74    fn default() -> Self {
75        Self {
76            nlist: 256,
77            nprobe: 8,
78            pq_m: 8,
79            pq_k: 256,
80            max_iter: 25,
81            residual: false,
82        }
83    }
84}
85
86impl IvfPqConfig {
87    /// Derive sensible defaults from vector dimensionality.
88    pub fn for_dim(dim: usize) -> Self {
89        let pq_m = (dim / 8).clamp(4, 96);
90        Self {
91            pq_m: find_valid_pq_m(pq_m, dim),
92            ..Self::default()
93        }
94    }
95
96    /// Derive sensible defaults from both dimensionality and dataset size.
97    ///
98    /// `nlist` scales with sqrt(n_vectors) so each cluster gets ~4 vectors on average.
99    /// Clamped to [16, 1024] to avoid degenerate configs on tiny or huge datasets.
100    pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
101        let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
102        let nprobe = (nlist / 4).max(1); // 25% coverage — better candidate quality than nlist/8
103        let pq_m_hint = (dim / 8).clamp(4, 96);
104        Self {
105            nlist,
106            nprobe,
107            pq_m: find_valid_pq_m(pq_m_hint, dim),
108            pq_k: 256,
109            max_iter: 25,
110            residual: false,
111        }
112    }
113
114    /// Enable residual PQ — train on per-cluster residuals for better recall at same storage.
115    pub fn with_residual(mut self) -> Self {
116        self.residual = true;
117        self
118    }
119}
120
121pub struct IvfPqIndex {
122    pub config: IvfPqConfig,
123    pub metric: VectorMetric,
124    pub dim: usize,
125    /// Coarse cluster centroids: [nlist × dim]
126    coarse_centroids: Vec<Vec<f32>>,
127    /// PQ codebook — trained on raw vectors (residual=false) or residuals (residual=true)
128    pq: PQCodebook,
129    /// Inverted lists: row IDs per cluster
130    inv_row_ids: Vec<Vec<u64>>,
131    /// PQ codes per cluster, flat: inv_codes[i].len() == inv_row_ids[i].len() * pq_m
132    inv_codes: Vec<Vec<u8>>,
133    /// Whether codes are residual-encoded (vec - coarse_centroid)
134    residual: bool,
135}
136
137/// Shared codebook trained once and reused across all shards.
138/// When multiple shards share the same codebook, their ADC distances are
139/// numerically comparable — cross-shard merge is correct without reranking.
140#[derive(Clone)]
141pub struct IvfPqCodebook {
142    pub coarse_centroids: Vec<Vec<f32>>,
143    pub pq: PQCodebook,
144    pub nlist: usize,
145    pub nprobe: usize,
146    pub pq_m: usize,
147    pub dim: usize,
148    pub metric: VectorMetric,
149    pub residual: bool,
150}
151
152impl IvfPqIndex {
153    /// Train IVF-PQ index (trains its own coarse + PQ codebook).
154    pub fn train(
155        row_ids: &[RowId],
156        vectors: &[Vec<f32>],
157        metric: VectorMetric,
158        config: IvfPqConfig,
159    ) -> AilakeResult<Self> {
160        let codebook = Self::train_codebook(vectors, metric, &config)?;
161        Self::build_with_codebook(row_ids, vectors, &codebook)
162    }
163
164    /// Train only the coarse quantizer + PQ codebook (no inverted lists).
165    /// Call once, then reuse across shards via `build_with_codebook`.
166    pub fn train_codebook(
167        vectors: &[Vec<f32>],
168        metric: VectorMetric,
169        config: &IvfPqConfig,
170    ) -> AilakeResult<IvfPqCodebook> {
171        let n = vectors.len();
172        if n == 0 {
173            return Err(AilakeError::Catalog(
174                "IVF-PQ training requires at least 1 vector".into(),
175            ));
176        }
177        let dim = vectors[0].len();
178
179        let normed_storage: Vec<Vec<f32>>;
180        let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
181            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
182            &normed_storage
183        } else {
184            vectors
185        };
186
187        let nlist = config.nlist.min(n);
188        if nlist < config.nlist {
189            warn!(
190                "ailake: IVF-PQ nlist clamped from {} to {} (n={} vectors); \
191                 consider using HNSW for small datasets",
192                config.nlist, nlist, n
193            );
194        }
195        let nprobe = config.nprobe.min(nlist);
196        let pq_m = find_valid_pq_m(config.pq_m, dim);
197
198        info!(
199            "ailake: training IVF-PQ codebook — n={} dim={} nlist={} nprobe={} pq_m={}",
200            n, dim, nlist, nprobe, pq_m
201        );
202
203        let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
204
205        let pq_train_vecs: Vec<Vec<f32>>;
206        let pq_input: &[Vec<f32>] = if config.residual {
207            // Compute per-cluster residuals and train PQ on them.
208            // Lower variance per sub-space → better codebook quality at same K.
209            let assignments: Vec<usize> = vecs
210                .iter()
211                .map(|v| nearest_idx(v, &coarse_centroids))
212                .collect();
213            pq_train_vecs = vecs
214                .iter()
215                .zip(assignments.iter())
216                .map(|(v, &c)| {
217                    v.iter()
218                        .zip(coarse_centroids[c].iter())
219                        .map(|(a, b)| a - b)
220                        .collect()
221                })
222                .collect();
223            &pq_train_vecs
224        } else {
225            vecs
226        };
227
228        let pq = PQCodebook::train_with_kmeans(
229            pq_input,
230            pq_m,
231            config.pq_k.min(256),
232            config.max_iter,
233            kmeans_dispatch,
234        )
235        .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
236
237        Ok(IvfPqCodebook {
238            coarse_centroids,
239            pq,
240            nlist,
241            nprobe,
242            pq_m,
243            dim,
244            metric,
245            residual: config.residual,
246        })
247    }
248
249    /// Build inverted lists using a pre-trained codebook. No k-means training.
250    /// All shards built from the same codebook produce comparable ADC distances.
251    pub fn build_with_codebook(
252        row_ids: &[RowId],
253        vectors: &[Vec<f32>],
254        codebook: &IvfPqCodebook,
255    ) -> AilakeResult<Self> {
256        let n = vectors.len();
257        if n == 0 {
258            return Err(AilakeError::Catalog(
259                "IVF-PQ build requires at least 1 vector".into(),
260            ));
261        }
262
263        let normed_storage: Vec<Vec<f32>>;
264        let vecs: &[Vec<f32>] = if codebook.metric == VectorMetric::Cosine {
265            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
266            &normed_storage
267        } else {
268            vectors
269        };
270
271        let nlist = codebook.nlist;
272        let assignments: Vec<usize> = vecs
273            .iter()
274            .map(|v| nearest_idx(v, &codebook.coarse_centroids))
275            .collect();
276
277        let mut inv_row_ids = vec![Vec::new(); nlist];
278        let mut inv_codes = vec![Vec::new(); nlist];
279
280        for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
281            let codes = if codebook.residual {
282                let centroid = &codebook.coarse_centroids[list_idx];
283                let residual: Vec<f32> =
284                    v.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect();
285                codebook.pq.encode(&residual)
286            } else {
287                codebook.pq.encode(v)
288            };
289            inv_row_ids[list_idx].push(row_ids[i].0);
290            inv_codes[list_idx].extend_from_slice(&codes);
291        }
292
293        Ok(IvfPqIndex {
294            config: IvfPqConfig {
295                nlist: codebook.nlist,
296                nprobe: codebook.nprobe,
297                pq_m: codebook.pq_m,
298                pq_k: codebook.pq.num_centroids,
299                max_iter: 0,
300                residual: codebook.residual,
301            },
302            metric: codebook.metric,
303            dim: codebook.dim,
304            coarse_centroids: codebook.coarse_centroids.clone(),
305            pq: codebook.pq.clone(),
306            inv_row_ids,
307            inv_codes,
308            residual: codebook.residual,
309        })
310    }
311
312    /// Search for approximate nearest neighbors.
313    ///
314    /// `nprobe` overrides `config.nprobe` when `Some`. `ef` is ignored (HNSW compat shim).
315    pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
316        let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
317
318        let q_normed: Vec<f32>;
319        let q: &[f32] = if self.metric == VectorMetric::Cosine {
320            q_normed = l2_normalize(query);
321            &q_normed
322        } else {
323            query
324        };
325
326        // Select nprobe nearest coarse centroids
327        let mut c_dists: Vec<(usize, f32)> = self
328            .coarse_centroids
329            .iter()
330            .enumerate()
331            .map(|(i, c)| (i, l2_sq(q, c)))
332            .collect();
333        c_dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
334        c_dists.truncate(nprobe);
335
336        // Non-residual: precompute one global ADC table for the full query.
337        // Residual: per-cluster ADC table — query residual differs per centroid.
338        let global_adc = if !self.residual {
339            Some(self.pq.compute_adc_table(q))
340        } else {
341            None
342        };
343
344        // Scan selected inverted lists
345        let pq_m = self.config.pq_m;
346        let mut candidates: Vec<(RowId, f32)> = Vec::new();
347
348        for (list_idx, _) in &c_dists {
349            let row_ids = &self.inv_row_ids[*list_idx];
350            let codes_flat = &self.inv_codes[*list_idx];
351
352            // For residual mode, subtract coarse centroid from query before computing ADC.
353            let cluster_adc;
354            let adc_table = if self.residual {
355                let centroid = &self.coarse_centroids[*list_idx];
356                let q_res: Vec<f32> = q.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect();
357                cluster_adc = self.pq.compute_adc_table(&q_res);
358                &cluster_adc
359            } else {
360                // SAFETY: global_adc is Some when !self.residual (set above).
361                global_adc
362                    .as_ref()
363                    .expect("global_adc must be Some for non-residual path")
364            };
365
366            for (j, &rid) in row_ids.iter().enumerate() {
367                let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
368                let dist = self.pq.adc_distance(codes, adc_table);
369                candidates.push((RowId(rid), dist));
370            }
371        }
372
373        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
374        candidates.truncate(top_k);
375        candidates
376    }
377
378    pub fn node_count(&self) -> u64 {
379        self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
380    }
381
382    pub fn dim(&self) -> usize {
383        self.dim
384    }
385}
386
// ── Serialization ──────────────────────────────────────────────────────────────
//
// Format: IvfPqSnapshotCore (bincode v1) followed by one trailing byte holding the residual
// flag. Writers always emit the byte; files from older writers lack it.
//
// The `residual` flag is appended as a single byte AFTER the bincode payload rather
// than as a struct field. bincode v1 is purely positional, so #[serde(default)] has
// no effect on missing bytes at EOF. Appending the flag as a trailing byte gives us
// clean backward compatibility:
//
//   Legacy file (no trailing byte)        → cursor at EOF after core → residual = false
//   Current file (trailing 0x00 or 0x01)  → one byte left after core → residual = false / true
//   Go / C++ readers                      → stop after inv_codes, trailing byte ignored
//
// This extends the format safely without a versioned header rewrite.

402#[derive(Serialize, Deserialize)]
403struct IvfPqSnapshotCore {
404    nlist: usize,
405    nprobe: usize,
406    pq_m: usize,
407    pq_k: usize,
408    max_iter: usize,
409    dim: usize,
410    metric: u8,
411    coarse_flat: Vec<f32>, // [nlist * dim]
412    pq: PQCodebook,
413    inv_row_ids: Vec<Vec<u64>>,
414    inv_codes: Vec<Vec<u8>>, // flat per list: len == inv_row_ids[i].len() * pq_m
415}
416
417pub struct IvfPqSerializer;
418
419impl IvfPqSerializer {
420    pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
421        let coarse_flat: Vec<f32> = index
422            .coarse_centroids
423            .iter()
424            .flat_map(|c| c.iter().copied())
425            .collect();
426        let core = IvfPqSnapshotCore {
427            nlist: index.config.nlist,
428            nprobe: index.config.nprobe,
429            pq_m: index.config.pq_m,
430            pq_k: index.config.pq_k,
431            max_iter: index.config.max_iter,
432            dim: index.dim,
433            metric: metric_to_u8(index.metric),
434            coarse_flat,
435            pq: index.pq.clone(),
436            inv_row_ids: index.inv_row_ids.clone(),
437            inv_codes: index.inv_codes.clone(),
438        };
439        let mut bytes =
440            bincode::serialize(&core).map_err(|e| AilakeError::Bincode(e.to_string()))?;
441        // Append residual flag as trailing byte (0x00 = false, 0x01 = true).
442        // Legacy readers (Go, C++, old Rust) stop at this boundary and ignore it.
443        bytes.push(u8::from(index.residual));
444        Ok(bytes)
445    }
446
447    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
448        // Use a cursor so we know exactly how many bytes were consumed by the core.
449        // Any remaining byte after deserialization is the residual flag.
450        let mut cursor = std::io::Cursor::new(bytes);
451        let core: IvfPqSnapshotCore = bincode::deserialize_from(&mut cursor)
452            .map_err(|e| AilakeError::Bincode(e.to_string()))?;
453
454        let residual = if (cursor.position() as usize) < bytes.len() {
455            bytes[cursor.position() as usize] != 0
456        } else {
457            false // legacy file — no trailing byte
458        };
459
460        let metric = u8_to_metric(core.metric)?;
461        let coarse_centroids: Vec<Vec<f32>> = core
462            .coarse_flat
463            .chunks_exact(core.dim)
464            .map(|c| c.to_vec())
465            .collect();
466        Ok(IvfPqIndex {
467            config: IvfPqConfig {
468                nlist: core.nlist,
469                nprobe: core.nprobe,
470                pq_m: core.pq_m,
471                pq_k: core.pq_k,
472                max_iter: core.max_iter,
473                residual,
474            },
475            metric,
476            dim: core.dim,
477            coarse_centroids,
478            pq: core.pq,
479            inv_row_ids: core.inv_row_ids,
480            inv_codes: core.inv_codes,
481            residual,
482        })
483    }
484}
485
// ── Helpers ────────────────────────────────────────────────────────────────────

/// Return a copy of `v` scaled to unit Euclidean length.
///
/// If the norm is below `1e-9` (empty or all-zero vectors, for example), the vector has no
/// usable direction and the copy is returned unscaled.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let squared_norm: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_norm.sqrt();
    let mut out = v.to_vec();
    if norm < 1e-9 {
        return out;
    }
    for x in &mut out {
        *x /= norm;
    }
    out
}

/// Squared Euclidean distance between `a` and `b`.
///
/// Elements are paired positionally and iteration stops at the shorter slice, so callers
/// must pass equal-length inputs for a meaningful result.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff.powi(2)
        })
        .sum()
}

501fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
502    centroids
503        .iter()
504        .enumerate()
505        .map(|(i, c)| (i, l2_sq(v, c)))
506        .min_by(|a, b| a.1.total_cmp(&b.1))
507        .map(|(i, _)| i)
508        .unwrap_or(0)
509}
510
/// Largest sub-vector count `m` in `1..=requested` that divides `dim` evenly.
///
/// For `requested >= 1` a result always exists, because 1 divides every `dim`. When
/// `requested == 0` the candidate range is empty and `1` is returned. Since every `m` divides
/// `0`, `dim == 0` returns `requested` unchanged (when it is non-zero).
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&m| dim.is_multiple_of(m))
        .unwrap_or(1)
}

521fn metric_to_u8(m: VectorMetric) -> u8 {
522    match m {
523        VectorMetric::Cosine => 0,
524        VectorMetric::Euclidean => 1,
525        VectorMetric::DotProduct => 2,
526        VectorMetric::NormalizedCosine => 3,
527    }
528}
529
530fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
531    match v {
532        0 => Ok(VectorMetric::Cosine),
533        1 => Ok(VectorMetric::Euclidean),
534        2 => Ok(VectorMetric::DotProduct),
535        3 => Ok(VectorMetric::NormalizedCosine),
536        _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` one-hot vectors of length `dim` (vector `i` has 1.0 at position `i % dim`), paired
    /// with row ids `0..n`. Vectors repeat every `dim` rows, so exact duplicates exist.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let row_ids: Vec<RowId> = (0..n).map(|i| RowId(i as u64)).collect();
        let vecs: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                let mut v = vec![0.0f32; dim];
                v[i % dim] = 1.0;
                v
            })
            .collect();
        (row_ids, vecs)
    }

    /// Small config shared by every test below (pq_m = 2, pq_k = 4).
    fn small_config(nlist: usize, nprobe: usize, max_iter: usize, residual: bool) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
            residual,
        }
    }

    /// Assert that two result lists contain the same row ids in the same order.
    fn assert_same_ids(a: &[(RowId, f32)], b: &[(RowId, f32)]) {
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(b) {
            assert_eq!(x.0, y.0, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn train_and_search_basic() {
        let (ids, vecs) = make_vecs(64, 8);
        let config = small_config(4, 2, 10, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert_eq!(idx.node_count(), 64);

        let results = idx.search(&vecs[0], 5, None);
        assert!(!results.is_empty());
        // Top result should be close to query
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let (ids, vecs) = make_vecs(32, 4);
        let config = small_config(4, 2, 10, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, config).unwrap();
        let results = idx.search(&vecs[0], 1, None);
        assert!(!results.is_empty());

        // Cosine queries are normalized, so a rescaled query must behave identically.
        let scaled: Vec<f32> = vecs[0].iter().map(|x| x * 3.0).collect();
        let scaled_results = idx.search(&scaled, 1, None);
        assert_same_ids(&results, &scaled_results);
        assert_eq!(results[0].1, scaled_results[0].1);
    }

    #[test]
    fn serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let config = small_config(4, 2, 10, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.node_count(), idx.node_count());
        assert_eq!(idx2.dim(), idx.dim());
        assert_same_ids(&idx.search(&vecs[0], 5, None), &idx2.search(&vecs[0], 5, None));
    }

    #[test]
    fn nlist_clamped_to_n() {
        let (ids, vecs) = make_vecs(10, 4); // fewer vectors than the requested nlist
        let config = small_config(256, 8, 5, false); // nlist will be clamped to 10
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }

    #[test]
    fn residual_pq_search_finds_nearest() {
        let (ids, vecs) = make_vecs(64, 8);
        let config = small_config(4, 4, 10, true);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert_eq!(idx.node_count(), 64);
        assert!(idx.residual);

        let results = idx.search(&vecs[0], 5, None);
        assert!(!results.is_empty());
        assert!(
            results[0].1 < 0.1,
            "nearest residual-PQ result should be close to query"
        );
    }

    #[test]
    fn residual_pq_serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let config = small_config(4, 2, 10, true);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.node_count(), idx.node_count());
        assert!(idx2.residual, "residual flag must survive roundtrip");
        assert_same_ids(&idx.search(&vecs[0], 5, None), &idx2.search(&vecs[0], 5, None));
    }

    #[test]
    fn non_residual_snapshot_deserializes_as_false() {
        // Current format: the trailing flag byte is present and equals 0x00.
        let (ids, vecs) = make_vecs(16, 8);
        let config = small_config(2, 1, 5, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();
        let idx2 = IvfPqSerializer::from_bytes(&bytes).unwrap();
        assert!(
            !idx2.residual,
            "non-residual index must deserialize as residual=false"
        );
    }

    #[test]
    fn legacy_snapshot_without_trailing_byte_deserializes_as_non_residual() {
        let (ids, vecs) = make_vecs(16, 8);
        let config = small_config(2, 1, 5, false);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&idx).unwrap();

        // Current writers append a residual-flag byte; dropping it reproduces a legacy file.
        assert_eq!(bytes.last(), Some(&0u8), "non-residual flag byte must be 0x00");
        let legacy = &bytes[..bytes.len() - 1];
        let idx2 = IvfPqSerializer::from_bytes(legacy).unwrap();

        assert!(!idx2.residual, "legacy snapshot must deserialize as residual=false");
        assert_eq!(idx2.node_count(), idx.node_count());
        assert_same_ids(&idx.search(&vecs[0], 5, None), &idx2.search(&vecs[0], 5, None));
    }

    #[test]
    fn find_valid_pq_m_picks_largest_divisor() {
        assert_eq!(find_valid_pq_m(8, 64), 8);
        assert_eq!(find_valid_pq_m(7, 64), 4);
        assert_eq!(find_valid_pq_m(96, 7), 7);
        assert_eq!(find_valid_pq_m(0, 10), 1);
    }

    #[test]
    fn l2_normalize_scales_and_keeps_zero_vectors() {
        assert_eq!(l2_normalize(&[3.0, 4.0]), vec![0.6, 0.8]);
        assert_eq!(l2_normalize(&[0.0, 0.0]), vec![0.0, 0.0]);
    }
}