Skip to main content

ailake_index/
ivf_pq.rs

1// IVF-PQ index: Inverted File Index with Product Quantization.
2//
3// vs HNSW tradeoffs:
4//   - Index size: ~100x smaller (PQ codes vs raw vectors + graph pointers)
5//   - S3 reads: sequential inverted-list scan vs random graph traversal
6//   - Recall: slightly lower at same memory, tunable via nprobe
7//   - Build: O(n * nlist) k-means vs O(n log n) HNSW insertions
8//
9// Non-residual variant: global PQ codebook trained on all vectors.
10// Simpler than per-cluster residual PQ, adequate for dim >= 64.
11
12use serde::{Deserialize, Serialize};
13
14use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
15use ailake_vec::{kmeans_centroids, PQCodebook};
16
17/// K-means dispatch: NVIDIA CUDA → AMD ROCm → CPU rayon fallback.
18fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
19    if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
20        return result;
21    }
22    if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
23        return result;
24    }
25    kmeans_centroids(vecs, k, max_iter)
26}
27
/// Configuration for IVF-PQ index construction and search.
///
/// Values need not be perfectly consistent with the data. `IvfPqIndex::train`
/// caps `nlist`, `nprobe` and `pq_k`, and adjusts `pq_m`, before using them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqConfig {
    /// Number of coarse Voronoi cells (inverted lists).
    /// Rule of thumb: about sqrt(n) for balanced recall/speed. Default 256.
    /// Training caps this at the number of training vectors.
    pub nlist: usize,
    /// Cells probed per query. Higher = better recall, more compute. Default 8.
    /// `nprobe == nlist` scans every inverted list. Distances are still PQ
    /// approximations, so even then results are not guaranteed to be exact.
    /// Training caps this at `nlist`; `IvfPqIndex::search` can override it per query.
    pub nprobe: usize,
    /// PQ sub-vector count M. Default 8.
    /// Should divide `dim`. If it does not, training uses the largest value
    /// below it that does (`find_valid_pq_m`).
    pub pq_m: usize,
    /// PQ centroids per sub-space K. Default 256.
    /// Codes are stored as `u8`, so training caps this at 256.
    pub pq_k: usize,
    /// k-means max iterations for both coarse and PQ training. Default 25.
    pub max_iter: usize,
}
44
45impl Default for IvfPqConfig {
46    fn default() -> Self {
47        Self {
48            nlist: 256,
49            nprobe: 8,
50            pq_m: 8,
51            pq_k: 256,
52            max_iter: 25,
53        }
54    }
55}
56
57impl IvfPqConfig {
58    /// Derive sensible defaults from vector dimensionality.
59    pub fn for_dim(dim: usize) -> Self {
60        let pq_m = (dim / 16).clamp(4, 64);
61        Self {
62            pq_m: find_valid_pq_m(pq_m, dim),
63            ..Self::default()
64        }
65    }
66
67    /// Derive sensible defaults from both dimensionality and dataset size.
68    ///
69    /// `nlist` scales with sqrt(n_vectors) so each cluster gets ~4 vectors on average.
70    /// Clamped to [16, 1024] to avoid degenerate configs on tiny or huge datasets.
71    pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
72        let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
73        let nprobe = (nlist / 8).max(1);
74        let pq_m_hint = (dim / 16).clamp(4, 64);
75        Self {
76            nlist,
77            nprobe,
78            pq_m: find_valid_pq_m(pq_m_hint, dim),
79            pq_k: 256,
80            max_iter: 25,
81        }
82    }
83}
84
/// Trained IVF-PQ index.
///
/// Holds the coarse centroids, one global PQ codebook, and one inverted list
/// (row IDs + PQ codes) per coarse cell. Build it with `IvfPqIndex::train` or
/// restore it with `IvfPqSerializer::from_bytes`.
pub struct IvfPqIndex {
    /// Configuration recorded at training time. `train` overwrites `nlist`,
    /// `nprobe` and `pq_m` with the values it actually used.
    pub config: IvfPqConfig,
    /// Metric chosen at training time. `Cosine` makes both training vectors
    /// and queries L2-normalized before any distance is computed.
    pub metric: VectorMetric,
    /// Vector dimensionality (length of the first training vector).
    pub dim: usize,
    /// Coarse cluster centroids: [nlist × dim]
    coarse_centroids: Vec<Vec<f32>>,
    /// Global PQ codebook trained on all vectors
    pq: PQCodebook,
    /// Inverted lists: raw row IDs (`RowId.0`) per cluster, indexed by coarse centroid
    inv_row_ids: Vec<Vec<u64>>,
    /// PQ codes per cluster, flat: inv_codes[i].len() == inv_row_ids[i].len() * pq_m
    /// (the j-th row's codes are inv_codes[i][j * pq_m..(j + 1) * pq_m])
    inv_codes: Vec<Vec<u8>>,
}
98
99impl IvfPqIndex {
100    /// Train IVF-PQ index.
101    pub fn train(
102        row_ids: &[RowId],
103        vectors: &[Vec<f32>],
104        metric: VectorMetric,
105        config: IvfPqConfig,
106    ) -> AilakeResult<Self> {
107        let n = vectors.len();
108        if n == 0 {
109            return Err(AilakeError::Catalog(
110                "IVF-PQ training requires at least 1 vector".into(),
111            ));
112        }
113        let dim = vectors[0].len();
114
115        let normed_storage: Vec<Vec<f32>>;
116        let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
117            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
118            &normed_storage
119        } else {
120            vectors
121        };
122
123        let nlist = config.nlist.min(n);
124        let nprobe = config.nprobe.min(nlist);
125        let pq_m = find_valid_pq_m(config.pq_m, dim);
126
127        // Train coarse centroids + PQ codebook, using GPU k-means when available.
128        let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
129
130        // Assign each vector to its nearest coarse centroid
131        let assignments: Vec<usize> = vecs
132            .iter()
133            .map(|v| nearest_idx(v, &coarse_centroids))
134            .collect();
135
136        // Train global PQ on all vectors
137        let pq = PQCodebook::train_with_kmeans(
138            vecs,
139            pq_m,
140            config.pq_k.min(256),
141            config.max_iter,
142            kmeans_dispatch,
143        )
144        .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
145
146        // Build inverted lists
147        let mut inv_row_ids = vec![Vec::new(); nlist];
148        let mut inv_codes = vec![Vec::new(); nlist];
149
150        for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
151            let codes = pq.encode(v);
152            inv_row_ids[list_idx].push(row_ids[i].0);
153            inv_codes[list_idx].extend_from_slice(&codes);
154        }
155
156        Ok(IvfPqIndex {
157            config: IvfPqConfig {
158                nlist,
159                nprobe,
160                pq_m,
161                ..config
162            },
163            metric,
164            dim,
165            coarse_centroids,
166            pq,
167            inv_row_ids,
168            inv_codes,
169        })
170    }
171
172    /// Search for approximate nearest neighbors.
173    ///
174    /// `nprobe` overrides `config.nprobe` when `Some`. `ef` is ignored (HNSW compat shim).
175    pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
176        let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
177
178        let q_normed: Vec<f32>;
179        let q: &[f32] = if self.metric == VectorMetric::Cosine {
180            q_normed = l2_normalize(query);
181            &q_normed
182        } else {
183            query
184        };
185
186        // Select nprobe nearest coarse centroids
187        let mut c_dists: Vec<(usize, f32)> = self
188            .coarse_centroids
189            .iter()
190            .enumerate()
191            .map(|(i, c)| (i, l2_sq(q, c)))
192            .collect();
193        c_dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
194        c_dists.truncate(nprobe);
195
196        // Precompute ADC table once for the full query
197        let adc_table = self.pq.compute_adc_table(q);
198
199        // Scan selected inverted lists
200        let pq_m = self.config.pq_m;
201        let mut candidates: Vec<(RowId, f32)> = Vec::new();
202
203        for (list_idx, _) in &c_dists {
204            let row_ids = &self.inv_row_ids[*list_idx];
205            let codes_flat = &self.inv_codes[*list_idx];
206
207            for (j, &rid) in row_ids.iter().enumerate() {
208                let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
209                let dist = self.pq.adc_distance(codes, &adc_table);
210                candidates.push((RowId(rid), dist));
211            }
212        }
213
214        candidates.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
215        candidates.truncate(top_k);
216        candidates
217    }
218
219    pub fn node_count(&self) -> u64 {
220        self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
221    }
222
223    pub fn dim(&self) -> usize {
224        self.dim
225    }
226}
227
228// ── Serialization ──────────────────────────────────────────────────────────────
229
/// Flat, serde-friendly mirror of `IvfPqIndex`, used by `IvfPqSerializer`.
///
/// Encoded with `bincode`, which writes struct fields positionally. Reordering,
/// removing or retyping fields changes the byte format and breaks snapshots
/// that were already written.
#[derive(Serialize, Deserialize)]
struct IvfPqSnapshot {
    nlist: usize,
    nprobe: usize,
    pq_m: usize,
    pq_k: usize,
    max_iter: usize,
    dim: usize,
    metric: u8, // encoded by metric_to_u8, decoded by u8_to_metric
    coarse_flat: Vec<f32>, // [nlist * dim], centroids concatenated row by row
    pq: PQCodebook,
    inv_row_ids: Vec<Vec<u64>>,
    inv_codes: Vec<Vec<u8>>, // flat per list: len == inv_row_ids[i].len() * pq_m
}
244
245pub struct IvfPqSerializer;
246
247impl IvfPqSerializer {
248    pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
249        let coarse_flat: Vec<f32> = index
250            .coarse_centroids
251            .iter()
252            .flat_map(|c| c.iter().copied())
253            .collect();
254        let snap = IvfPqSnapshot {
255            nlist: index.config.nlist,
256            nprobe: index.config.nprobe,
257            pq_m: index.config.pq_m,
258            pq_k: index.config.pq_k,
259            max_iter: index.config.max_iter,
260            dim: index.dim,
261            metric: metric_to_u8(index.metric),
262            coarse_flat,
263            pq: index.pq.clone(),
264            inv_row_ids: index.inv_row_ids.clone(),
265            inv_codes: index.inv_codes.clone(),
266        };
267        bincode::serialize(&snap).map_err(|e| AilakeError::Bincode(e.to_string()))
268    }
269
270    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
271        let snap: IvfPqSnapshot =
272            bincode::deserialize(bytes).map_err(|e| AilakeError::Bincode(e.to_string()))?;
273        let metric = u8_to_metric(snap.metric)?;
274        let coarse_centroids: Vec<Vec<f32>> = snap
275            .coarse_flat
276            .chunks_exact(snap.dim)
277            .map(|c| c.to_vec())
278            .collect();
279        Ok(IvfPqIndex {
280            config: IvfPqConfig {
281                nlist: snap.nlist,
282                nprobe: snap.nprobe,
283                pq_m: snap.pq_m,
284                pq_k: snap.pq_k,
285                max_iter: snap.max_iter,
286            },
287            metric,
288            dim: snap.dim,
289            coarse_centroids,
290            pq: snap.pq,
291            inv_row_ids: snap.inv_row_ids,
292            inv_codes: snap.inv_codes,
293        })
294    }
295}
296
297// ── Helpers ────────────────────────────────────────────────────────────────────
298
/// Scale `v` to unit L2 length.
///
/// Vectors whose norm is below `1e-9` (including all-zero and empty vectors)
/// are returned unchanged, avoiding a division by (near) zero.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let sum_of_squares: f32 = v.iter().map(|&c| c * c).sum();
    let norm = sum_of_squares.sqrt();
    if norm < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|&component| component / norm).collect()
}
307
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared, because `zip` stops at the shorter
/// slice. Callers must pass equal-length inputs for a meaningful result.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Index of the centroid in `centroids` nearest to `v` by squared L2 distance.
///
/// Ties resolve to the lowest index. Distances are compared with
/// `f32::total_cmp`, so a NaN (e.g. from a NaN coordinate) no longer panics.
/// Returns `0` when `centroids` is empty; callers must not index with the
/// result in that case.
fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
    centroids
        .iter()
        .enumerate()
        .map(|(i, c)| (i, l2_sq(v, c)))
        .min_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(i, _)| i)
        .unwrap_or(0)
}
321
/// Return the largest `m` in `1..=requested` that divides `dim` evenly.
///
/// Returns `1` when `requested` is `0`. For `requested >= 1` the search always
/// succeeds, because `1` divides every `dim`.
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&m| dim.is_multiple_of(m))
        .unwrap_or(1)
}
331
332fn metric_to_u8(m: VectorMetric) -> u8 {
333    match m {
334        VectorMetric::Cosine => 0,
335        VectorMetric::Euclidean => 1,
336        VectorMetric::DotProduct => 2,
337    }
338}
339
340fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
341    match v {
342        0 => Ok(VectorMetric::Cosine),
343        1 => Ok(VectorMetric::Euclidean),
344        2 => Ok(VectorMetric::DotProduct),
345        _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` one-hot vectors of width `dim`, paired with row IDs `0..n`.
    /// The hot coordinate cycles through `0..dim`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        (0..n)
            .map(|i| {
                let mut one_hot = vec![0.0f32; dim];
                one_hot[i % dim] = 1.0;
                (RowId(i as u64), one_hot)
            })
            .unzip()
    }

    /// Small-codebook config shared by every test (pq_m = 2, pq_k = 4).
    fn tiny_config(nlist: usize, nprobe: usize, max_iter: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
        }
    }

    #[test]
    fn train_and_search_basic() {
        let (ids, vecs) = make_vecs(64, 8);
        let config = tiny_config(4, 2, 10);
        let index = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert_eq!(index.node_count(), 64);

        let hits = index.search(&vecs[0], 5, None);
        assert!(!hits.is_empty());
        // The query vector is itself indexed, so the best hit should be ~zero distance.
        assert!(hits[0].1 < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let (ids, vecs) = make_vecs(32, 4);
        let config = tiny_config(4, 2, 10);
        let index = IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, config).unwrap();
        assert!(!index.search(&vecs[0], 1, None).is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let config = tiny_config(4, 2, 10);
        let original = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&original).unwrap();
        let restored = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(restored.node_count(), original.node_count());
        assert_eq!(restored.dim(), original.dim());

        let before = original.search(&vecs[0], 5, None);
        let after = restored.search(&vecs[0], 5, None);
        assert_eq!(before.len(), after.len());
        for ((id_before, _), (id_after, _)) in before.iter().zip(&after) {
            assert_eq!(id_before, id_after, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn nlist_clamped_to_n() {
        // Only 10 vectors but nlist = 256 requested: training must cap nlist at 10.
        let (ids, vecs) = make_vecs(10, 4);
        let config = tiny_config(256, 8, 5);
        let index = IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, config).unwrap();
        assert!(index.config.nlist <= 10);
        assert_eq!(index.node_count(), 10);
    }
}