Skip to main content

ailake_index/
ivf_pq.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
// IVF-PQ index: Inverted File Index with Product Quantization.
//
// Trade-offs vs HNSW:
//   - Index size: ~100x smaller (PQ codes vs raw vectors + graph pointers)
//   - S3 reads: sequential inverted-list scan vs random graph traversal
//   - Recall: slightly lower at the same memory budget, tunable via nprobe
//   - Build: O(n * nlist * iters) k-means vs O(n log n) HNSW insertions
//
// Non-residual variant: a single global PQ codebook trained on all vectors.
// Simpler than per-cluster residual PQ; adequate for dim >= 64.
12
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
16use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
17use ailake_vec::{kmeans_centroids, PQCodebook};
18
19/// K-means dispatch: NVIDIA CUDA → AMD ROCm → CPU rayon fallback.
20fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
21    if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
22        debug!(
23            "ailake: IVF-PQ k-means used NVIDIA CUDA (n={} k={} max_iter={})",
24            vecs.len(),
25            k,
26            max_iter
27        );
28        return result;
29    }
30    if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
31        debug!(
32            "ailake: IVF-PQ k-means used AMD ROCm (n={} k={} max_iter={})",
33            vecs.len(),
34            k,
35            max_iter
36        );
37        return result;
38    }
39    debug!(
40        "ailake: IVF-PQ k-means using CPU rayon (n={} k={} max_iter={})",
41        vecs.len(),
42        k,
43        max_iter
44    );
45    kmeans_centroids(vecs, k, max_iter)
46}
47
48/// Configuration for IVF-PQ index construction and search.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct IvfPqConfig {
51    /// Number of coarse Voronoi cells (inverted lists).
52    /// Rule of thumb: sqrt(n) for balanced recall/speed. Default 256.
53    pub nlist: usize,
54    /// Cells probed per query. Higher = better recall, more compute.
55    /// nprobe=1 is ANN; nprobe=nlist is exact. Default 8.
56    pub nprobe: usize,
57    /// PQ sub-vector count M. Must divide dim. Default 8.
58    pub pq_m: usize,
59    /// PQ centroids per sub-space K. Must be ≤ 256 (u8 codes). Default 256.
60    pub pq_k: usize,
61    /// k-means max iterations for both coarse and PQ training. Default 25.
62    pub max_iter: usize,
63}
64
65impl Default for IvfPqConfig {
66    fn default() -> Self {
67        Self {
68            nlist: 256,
69            nprobe: 8,
70            pq_m: 8,
71            pq_k: 256,
72            max_iter: 25,
73        }
74    }
75}
76
77impl IvfPqConfig {
78    /// Derive sensible defaults from vector dimensionality.
79    pub fn for_dim(dim: usize) -> Self {
80        let pq_m = (dim / 16).clamp(4, 64);
81        Self {
82            pq_m: find_valid_pq_m(pq_m, dim),
83            ..Self::default()
84        }
85    }
86
87    /// Derive sensible defaults from both dimensionality and dataset size.
88    ///
89    /// `nlist` scales with sqrt(n_vectors) so each cluster gets ~4 vectors on average.
90    /// Clamped to [16, 1024] to avoid degenerate configs on tiny or huge datasets.
91    pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
92        let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
93        let nprobe = (nlist / 4).max(1); // 25% coverage — better candidate quality than nlist/8
94        let pq_m_hint = (dim / 16).clamp(4, 64);
95        Self {
96            nlist,
97            nprobe,
98            pq_m: find_valid_pq_m(pq_m_hint, dim),
99            pq_k: 256,
100            max_iter: 25,
101        }
102    }
103}
104
105pub struct IvfPqIndex {
106    pub config: IvfPqConfig,
107    pub metric: VectorMetric,
108    pub dim: usize,
109    /// Coarse cluster centroids: [nlist × dim]
110    coarse_centroids: Vec<Vec<f32>>,
111    /// Global PQ codebook trained on all vectors
112    pq: PQCodebook,
113    /// Inverted lists: row IDs per cluster
114    inv_row_ids: Vec<Vec<u64>>,
115    /// PQ codes per cluster, flat: inv_codes[i].len() == inv_row_ids[i].len() * pq_m
116    inv_codes: Vec<Vec<u8>>,
117}
118
119/// Shared codebook trained once and reused across all shards.
120/// When multiple shards share the same codebook, their ADC distances are
121/// numerically comparable — cross-shard merge is correct without reranking.
122#[derive(Clone)]
123pub struct IvfPqCodebook {
124    pub coarse_centroids: Vec<Vec<f32>>,
125    pub pq: PQCodebook,
126    pub nlist: usize,
127    pub nprobe: usize,
128    pub pq_m: usize,
129    pub dim: usize,
130    pub metric: VectorMetric,
131}
132
133impl IvfPqIndex {
134    /// Train IVF-PQ index (trains its own coarse + PQ codebook).
135    pub fn train(
136        row_ids: &[RowId],
137        vectors: &[Vec<f32>],
138        metric: VectorMetric,
139        config: IvfPqConfig,
140    ) -> AilakeResult<Self> {
141        let codebook = Self::train_codebook(vectors, metric, &config)?;
142        Self::build_with_codebook(row_ids, vectors, &codebook)
143    }
144
145    /// Train only the coarse quantizer + PQ codebook (no inverted lists).
146    /// Call once, then reuse across shards via `build_with_codebook`.
147    pub fn train_codebook(
148        vectors: &[Vec<f32>],
149        metric: VectorMetric,
150        config: &IvfPqConfig,
151    ) -> AilakeResult<IvfPqCodebook> {
152        let n = vectors.len();
153        if n == 0 {
154            return Err(AilakeError::Catalog(
155                "IVF-PQ training requires at least 1 vector".into(),
156            ));
157        }
158        let dim = vectors[0].len();
159
160        let normed_storage: Vec<Vec<f32>>;
161        let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
162            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
163            &normed_storage
164        } else {
165            vectors
166        };
167
168        let nlist = config.nlist.min(n);
169        if nlist < config.nlist {
170            warn!(
171                "ailake: IVF-PQ nlist clamped from {} to {} (n={} vectors); \
172                 consider using HNSW for small datasets",
173                config.nlist, nlist, n
174            );
175        }
176        let nprobe = config.nprobe.min(nlist);
177        let pq_m = find_valid_pq_m(config.pq_m, dim);
178
179        info!(
180            "ailake: training IVF-PQ codebook — n={} dim={} nlist={} nprobe={} pq_m={}",
181            n, dim, nlist, nprobe, pq_m
182        );
183
184        let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
185        let pq = PQCodebook::train_with_kmeans(
186            vecs,
187            pq_m,
188            config.pq_k.min(256),
189            config.max_iter,
190            kmeans_dispatch,
191        )
192        .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
193
194        Ok(IvfPqCodebook {
195            coarse_centroids,
196            pq,
197            nlist,
198            nprobe,
199            pq_m,
200            dim,
201            metric,
202        })
203    }
204
205    /// Build inverted lists using a pre-trained codebook. No k-means training.
206    /// All shards built from the same codebook produce comparable ADC distances.
207    pub fn build_with_codebook(
208        row_ids: &[RowId],
209        vectors: &[Vec<f32>],
210        codebook: &IvfPqCodebook,
211    ) -> AilakeResult<Self> {
212        let n = vectors.len();
213        if n == 0 {
214            return Err(AilakeError::Catalog(
215                "IVF-PQ build requires at least 1 vector".into(),
216            ));
217        }
218
219        let normed_storage: Vec<Vec<f32>>;
220        let vecs: &[Vec<f32>] = if codebook.metric == VectorMetric::Cosine {
221            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
222            &normed_storage
223        } else {
224            vectors
225        };
226
227        let nlist = codebook.nlist;
228        let assignments: Vec<usize> = vecs
229            .iter()
230            .map(|v| nearest_idx(v, &codebook.coarse_centroids))
231            .collect();
232
233        let mut inv_row_ids = vec![Vec::new(); nlist];
234        let mut inv_codes = vec![Vec::new(); nlist];
235
236        for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
237            let codes = codebook.pq.encode(v);
238            inv_row_ids[list_idx].push(row_ids[i].0);
239            inv_codes[list_idx].extend_from_slice(&codes);
240        }
241
242        Ok(IvfPqIndex {
243            config: IvfPqConfig {
244                nlist: codebook.nlist,
245                nprobe: codebook.nprobe,
246                pq_m: codebook.pq_m,
247                pq_k: codebook.pq.num_centroids,
248                max_iter: 0,
249            },
250            metric: codebook.metric,
251            dim: codebook.dim,
252            coarse_centroids: codebook.coarse_centroids.clone(),
253            pq: codebook.pq.clone(),
254            inv_row_ids,
255            inv_codes,
256        })
257    }
258
259    /// Search for approximate nearest neighbors.
260    ///
261    /// `nprobe` overrides `config.nprobe` when `Some`. `ef` is ignored (HNSW compat shim).
262    pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
263        let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
264
265        let q_normed: Vec<f32>;
266        let q: &[f32] = if self.metric == VectorMetric::Cosine {
267            q_normed = l2_normalize(query);
268            &q_normed
269        } else {
270            query
271        };
272
273        // Select nprobe nearest coarse centroids
274        let mut c_dists: Vec<(usize, f32)> = self
275            .coarse_centroids
276            .iter()
277            .enumerate()
278            .map(|(i, c)| (i, l2_sq(q, c)))
279            .collect();
280        c_dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
281        c_dists.truncate(nprobe);
282
283        // Precompute ADC table once for the full query
284        let adc_table = self.pq.compute_adc_table(q);
285
286        // Scan selected inverted lists
287        let pq_m = self.config.pq_m;
288        let mut candidates: Vec<(RowId, f32)> = Vec::new();
289
290        for (list_idx, _) in &c_dists {
291            let row_ids = &self.inv_row_ids[*list_idx];
292            let codes_flat = &self.inv_codes[*list_idx];
293
294            for (j, &rid) in row_ids.iter().enumerate() {
295                let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
296                let dist = self.pq.adc_distance(codes, &adc_table);
297                candidates.push((RowId(rid), dist));
298            }
299        }
300
301        candidates.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
302        candidates.truncate(top_k);
303        candidates
304    }
305
306    pub fn node_count(&self) -> u64 {
307        self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
308    }
309
310    pub fn dim(&self) -> usize {
311        self.dim
312    }
313}
314
// ── Serialization ──────────────────────────────────────────────────────────────
316
317#[derive(Serialize, Deserialize)]
318struct IvfPqSnapshot {
319    nlist: usize,
320    nprobe: usize,
321    pq_m: usize,
322    pq_k: usize,
323    max_iter: usize,
324    dim: usize,
325    metric: u8,
326    coarse_flat: Vec<f32>, // [nlist * dim]
327    pq: PQCodebook,
328    inv_row_ids: Vec<Vec<u64>>,
329    inv_codes: Vec<Vec<u8>>, // flat per list: len == inv_row_ids[i].len() * pq_m
330}
331
332pub struct IvfPqSerializer;
333
334impl IvfPqSerializer {
335    pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
336        let coarse_flat: Vec<f32> = index
337            .coarse_centroids
338            .iter()
339            .flat_map(|c| c.iter().copied())
340            .collect();
341        let snap = IvfPqSnapshot {
342            nlist: index.config.nlist,
343            nprobe: index.config.nprobe,
344            pq_m: index.config.pq_m,
345            pq_k: index.config.pq_k,
346            max_iter: index.config.max_iter,
347            dim: index.dim,
348            metric: metric_to_u8(index.metric),
349            coarse_flat,
350            pq: index.pq.clone(),
351            inv_row_ids: index.inv_row_ids.clone(),
352            inv_codes: index.inv_codes.clone(),
353        };
354        bincode::serialize(&snap).map_err(|e| AilakeError::Bincode(e.to_string()))
355    }
356
357    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
358        let snap: IvfPqSnapshot =
359            bincode::deserialize(bytes).map_err(|e| AilakeError::Bincode(e.to_string()))?;
360        let metric = u8_to_metric(snap.metric)?;
361        let coarse_centroids: Vec<Vec<f32>> = snap
362            .coarse_flat
363            .chunks_exact(snap.dim)
364            .map(|c| c.to_vec())
365            .collect();
366        Ok(IvfPqIndex {
367            config: IvfPqConfig {
368                nlist: snap.nlist,
369                nprobe: snap.nprobe,
370                pq_m: snap.pq_m,
371                pq_k: snap.pq_k,
372                max_iter: snap.max_iter,
373            },
374            metric,
375            dim: snap.dim,
376            coarse_centroids,
377            pq: snap.pq,
378            inv_row_ids: snap.inv_row_ids,
379            inv_codes: snap.inv_codes,
380        })
381    }
382}
383
// ── Helpers ────────────────────────────────────────────────────────────────────
385
/// Return `v` scaled to unit Euclidean length.
///
/// Vectors whose norm is below `1e-9` are copied unchanged rather than divided by a
/// near-zero value; this covers the zero vector and the empty slice. A NaN norm is not
/// "below" the threshold, so such vectors are still divided (yielding NaNs).
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|&x| x / norm).collect()
}
394
/// Squared Euclidean distance between `a` and `b`.
///
/// Elements are paired with `zip`, so if the slices differ in length the surplus tail of
/// the longer one is ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&x, &y)| (x - y).powi(2));
    squared_diffs.sum()
}
398
399fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
400    centroids
401        .iter()
402        .enumerate()
403        .map(|(i, c)| (i, l2_sq(v, c)))
404        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
405        .map(|(i, _)| i)
406        .unwrap_or(0)
407}
408
/// Return the largest `m` in `1..=requested` that divides `dim` evenly.
///
/// Falls back to `1` when `requested` is 0. Every `m` divides `dim == 0`, so in that case
/// `requested` itself is returned.
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&m| dim.is_multiple_of(m))
        .unwrap_or(1)
}
418
419fn metric_to_u8(m: VectorMetric) -> u8 {
420    match m {
421        VectorMetric::Cosine => 0,
422        VectorMetric::Euclidean => 1,
423        VectorMetric::DotProduct => 2,
424        VectorMetric::NormalizedCosine => 3,
425    }
426}
427
428fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
429    match v {
430        0 => Ok(VectorMetric::Cosine),
431        1 => Ok(VectorMetric::Euclidean),
432        2 => Ok(VectorMetric::DotProduct),
433        3 => Ok(VectorMetric::NormalizedCosine),
434        _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` one-hot vectors of length `dim`, where vector `i` has 1.0 at position
    /// `i % dim`. Row IDs are `0..n`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        (0..n)
            .map(|i| {
                let mut one_hot = vec![0.0f32; dim];
                one_hot[i % dim] = 1.0;
                (RowId(i as u64), one_hot)
            })
            .unzip()
    }

    /// Small config shared by these tests: 2 PQ sub-vectors with 4 centroids each.
    fn tiny_config(nlist: usize, nprobe: usize, max_iter: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
        }
    }

    #[test]
    fn train_and_search_basic() {
        let dim = 8;
        let (ids, vecs) = make_vecs(64, dim);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(4, 2, 10))
            .unwrap();
        assert_eq!(idx.node_count(), 64);

        let results = idx.search(&vecs[0], 5, None);
        assert!(!results.is_empty());
        // The query is itself indexed, so the best hit should be (approximately) itself.
        assert!(results[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let dim = 4;
        let (ids, vecs) = make_vecs(32, dim);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, tiny_config(4, 2, 10))
            .unwrap();
        assert!(!idx.search(&vecs[0], 1, None).is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let dim = 8;
        let (ids, vecs) = make_vecs(32, dim);
        let original =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(4, 2, 10))
                .unwrap();
        let bytes = IvfPqSerializer::to_bytes(&original).unwrap();
        let restored = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(restored.node_count(), original.node_count());
        assert_eq!(restored.dim(), original.dim());

        let before = original.search(&vecs[0], 5, None);
        let after = restored.search(&vecs[0], 5, None);
        assert_eq!(before.len(), after.len());
        for ((id_before, _), (id_after, _)) in before.iter().zip(after.iter()) {
            assert_eq!(id_before, id_after, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn nlist_clamped_to_n() {
        let dim = 4;
        // Fewer vectors than the requested nlist, so training must clamp it to 10.
        let (ids, vecs) = make_vecs(10, dim);
        let idx = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(256, 8, 5))
            .unwrap();
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }
}