Skip to main content

ailake_index/
ivf_pq.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// IVF-PQ index: Inverted File Index with Product Quantization.
3//
4// vs HNSW tradeoffs:
5//   - Index size: ~100x smaller (PQ codes vs raw vectors + graph pointers)
6//   - S3 reads: sequential inverted-list scan vs random graph traversal
7//   - Recall: slightly lower at same memory, tunable via nprobe
8//   - Build: O(n * nlist) k-means vs O(n log n) HNSW insertions
9//
10// Non-residual variant: global PQ codebook trained on all vectors.
11// Simpler than per-cluster residual PQ, adequate for dim >= 64.
12
13use serde::{Deserialize, Serialize};
14
15use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
16use ailake_vec::{kmeans_centroids, PQCodebook};
17
18/// K-means dispatch: NVIDIA CUDA → AMD ROCm → CPU rayon fallback.
19fn kmeans_dispatch(vecs: &[Vec<f32>], k: usize, max_iter: usize) -> Vec<Vec<f32>> {
20    if let Some(result) = crate::gpu::try_nvidia_kmeans(vecs, k, max_iter) {
21        return result;
22    }
23    if let Some(result) = crate::gpu::try_rocm_kmeans(vecs, k, max_iter) {
24        return result;
25    }
26    kmeans_centroids(vecs, k, max_iter)
27}
28
29/// Configuration for IVF-PQ index construction and search.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct IvfPqConfig {
32    /// Number of coarse Voronoi cells (inverted lists).
33    /// Rule of thumb: sqrt(n) for balanced recall/speed. Default 256.
34    pub nlist: usize,
35    /// Cells probed per query. Higher = better recall, more compute.
36    /// nprobe=1 is ANN; nprobe=nlist is exact. Default 8.
37    pub nprobe: usize,
38    /// PQ sub-vector count M. Must divide dim. Default 8.
39    pub pq_m: usize,
40    /// PQ centroids per sub-space K. Must be ≤ 256 (u8 codes). Default 256.
41    pub pq_k: usize,
42    /// k-means max iterations for both coarse and PQ training. Default 25.
43    pub max_iter: usize,
44}
45
46impl Default for IvfPqConfig {
47    fn default() -> Self {
48        Self {
49            nlist: 256,
50            nprobe: 8,
51            pq_m: 8,
52            pq_k: 256,
53            max_iter: 25,
54        }
55    }
56}
57
58impl IvfPqConfig {
59    /// Derive sensible defaults from vector dimensionality.
60    pub fn for_dim(dim: usize) -> Self {
61        let pq_m = (dim / 16).clamp(4, 64);
62        Self {
63            pq_m: find_valid_pq_m(pq_m, dim),
64            ..Self::default()
65        }
66    }
67
68    /// Derive sensible defaults from both dimensionality and dataset size.
69    ///
70    /// `nlist` scales with sqrt(n_vectors) so each cluster gets ~4 vectors on average.
71    /// Clamped to [16, 1024] to avoid degenerate configs on tiny or huge datasets.
72    pub fn for_dataset(dim: usize, n_vectors: usize) -> Self {
73        let nlist = ((n_vectors as f64).sqrt() as usize).clamp(16, 1024);
74        let nprobe = (nlist / 8).max(1);
75        let pq_m_hint = (dim / 16).clamp(4, 64);
76        Self {
77            nlist,
78            nprobe,
79            pq_m: find_valid_pq_m(pq_m_hint, dim),
80            pq_k: 256,
81            max_iter: 25,
82        }
83    }
84}
85
86pub struct IvfPqIndex {
87    pub config: IvfPqConfig,
88    pub metric: VectorMetric,
89    pub dim: usize,
90    /// Coarse cluster centroids: [nlist × dim]
91    coarse_centroids: Vec<Vec<f32>>,
92    /// Global PQ codebook trained on all vectors
93    pq: PQCodebook,
94    /// Inverted lists: row IDs per cluster
95    inv_row_ids: Vec<Vec<u64>>,
96    /// PQ codes per cluster, flat: inv_codes[i].len() == inv_row_ids[i].len() * pq_m
97    inv_codes: Vec<Vec<u8>>,
98}
99
100impl IvfPqIndex {
101    /// Train IVF-PQ index.
102    pub fn train(
103        row_ids: &[RowId],
104        vectors: &[Vec<f32>],
105        metric: VectorMetric,
106        config: IvfPqConfig,
107    ) -> AilakeResult<Self> {
108        let n = vectors.len();
109        if n == 0 {
110            return Err(AilakeError::Catalog(
111                "IVF-PQ training requires at least 1 vector".into(),
112            ));
113        }
114        let dim = vectors[0].len();
115
116        let normed_storage: Vec<Vec<f32>>;
117        let vecs: &[Vec<f32>] = if metric == VectorMetric::Cosine {
118            normed_storage = vectors.iter().map(|v| l2_normalize(v)).collect();
119            &normed_storage
120        } else {
121            vectors
122        };
123
124        let nlist = config.nlist.min(n);
125        let nprobe = config.nprobe.min(nlist);
126        let pq_m = find_valid_pq_m(config.pq_m, dim);
127
128        // Train coarse centroids + PQ codebook, using GPU k-means when available.
129        let coarse_centroids = kmeans_dispatch(vecs, nlist, config.max_iter);
130
131        // Assign each vector to its nearest coarse centroid
132        let assignments: Vec<usize> = vecs
133            .iter()
134            .map(|v| nearest_idx(v, &coarse_centroids))
135            .collect();
136
137        // Train global PQ on all vectors
138        let pq = PQCodebook::train_with_kmeans(
139            vecs,
140            pq_m,
141            config.pq_k.min(256),
142            config.max_iter,
143            kmeans_dispatch,
144        )
145        .map_err(|e| AilakeError::Catalog(format!("PQ training failed: {e}")))?;
146
147        // Build inverted lists
148        let mut inv_row_ids = vec![Vec::new(); nlist];
149        let mut inv_codes = vec![Vec::new(); nlist];
150
151        for (i, (v, &list_idx)) in vecs.iter().zip(assignments.iter()).enumerate() {
152            let codes = pq.encode(v);
153            inv_row_ids[list_idx].push(row_ids[i].0);
154            inv_codes[list_idx].extend_from_slice(&codes);
155        }
156
157        Ok(IvfPqIndex {
158            config: IvfPqConfig {
159                nlist,
160                nprobe,
161                pq_m,
162                ..config
163            },
164            metric,
165            dim,
166            coarse_centroids,
167            pq,
168            inv_row_ids,
169            inv_codes,
170        })
171    }
172
173    /// Search for approximate nearest neighbors.
174    ///
175    /// `nprobe` overrides `config.nprobe` when `Some`. `ef` is ignored (HNSW compat shim).
176    pub fn search(&self, query: &[f32], top_k: usize, nprobe: Option<usize>) -> Vec<(RowId, f32)> {
177        let nprobe = nprobe.unwrap_or(self.config.nprobe).min(self.config.nlist);
178
179        let q_normed: Vec<f32>;
180        let q: &[f32] = if self.metric == VectorMetric::Cosine {
181            q_normed = l2_normalize(query);
182            &q_normed
183        } else {
184            query
185        };
186
187        // Select nprobe nearest coarse centroids
188        let mut c_dists: Vec<(usize, f32)> = self
189            .coarse_centroids
190            .iter()
191            .enumerate()
192            .map(|(i, c)| (i, l2_sq(q, c)))
193            .collect();
194        c_dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
195        c_dists.truncate(nprobe);
196
197        // Precompute ADC table once for the full query
198        let adc_table = self.pq.compute_adc_table(q);
199
200        // Scan selected inverted lists
201        let pq_m = self.config.pq_m;
202        let mut candidates: Vec<(RowId, f32)> = Vec::new();
203
204        for (list_idx, _) in &c_dists {
205            let row_ids = &self.inv_row_ids[*list_idx];
206            let codes_flat = &self.inv_codes[*list_idx];
207
208            for (j, &rid) in row_ids.iter().enumerate() {
209                let codes = &codes_flat[j * pq_m..(j + 1) * pq_m];
210                let dist = self.pq.adc_distance(codes, &adc_table);
211                candidates.push((RowId(rid), dist));
212            }
213        }
214
215        candidates.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
216        candidates.truncate(top_k);
217        candidates
218    }
219
220    pub fn node_count(&self) -> u64 {
221        self.inv_row_ids.iter().map(|l| l.len() as u64).sum()
222    }
223
224    pub fn dim(&self) -> usize {
225        self.dim
226    }
227}
228
229// ── Serialization ──────────────────────────────────────────────────────────────
230
231#[derive(Serialize, Deserialize)]
232struct IvfPqSnapshot {
233    nlist: usize,
234    nprobe: usize,
235    pq_m: usize,
236    pq_k: usize,
237    max_iter: usize,
238    dim: usize,
239    metric: u8,
240    coarse_flat: Vec<f32>, // [nlist * dim]
241    pq: PQCodebook,
242    inv_row_ids: Vec<Vec<u64>>,
243    inv_codes: Vec<Vec<u8>>, // flat per list: len == inv_row_ids[i].len() * pq_m
244}
245
246pub struct IvfPqSerializer;
247
248impl IvfPqSerializer {
249    pub fn to_bytes(index: &IvfPqIndex) -> AilakeResult<Vec<u8>> {
250        let coarse_flat: Vec<f32> = index
251            .coarse_centroids
252            .iter()
253            .flat_map(|c| c.iter().copied())
254            .collect();
255        let snap = IvfPqSnapshot {
256            nlist: index.config.nlist,
257            nprobe: index.config.nprobe,
258            pq_m: index.config.pq_m,
259            pq_k: index.config.pq_k,
260            max_iter: index.config.max_iter,
261            dim: index.dim,
262            metric: metric_to_u8(index.metric),
263            coarse_flat,
264            pq: index.pq.clone(),
265            inv_row_ids: index.inv_row_ids.clone(),
266            inv_codes: index.inv_codes.clone(),
267        };
268        bincode::serialize(&snap).map_err(|e| AilakeError::Bincode(e.to_string()))
269    }
270
271    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<IvfPqIndex> {
272        let snap: IvfPqSnapshot =
273            bincode::deserialize(bytes).map_err(|e| AilakeError::Bincode(e.to_string()))?;
274        let metric = u8_to_metric(snap.metric)?;
275        let coarse_centroids: Vec<Vec<f32>> = snap
276            .coarse_flat
277            .chunks_exact(snap.dim)
278            .map(|c| c.to_vec())
279            .collect();
280        Ok(IvfPqIndex {
281            config: IvfPqConfig {
282                nlist: snap.nlist,
283                nprobe: snap.nprobe,
284                pq_m: snap.pq_m,
285                pq_k: snap.pq_k,
286                max_iter: snap.max_iter,
287            },
288            metric,
289            dim: snap.dim,
290            coarse_centroids,
291            pq: snap.pq,
292            inv_row_ids: snap.inv_row_ids,
293            inv_codes: snap.inv_codes,
294        })
295    }
296}
297
298// ── Helpers ────────────────────────────────────────────────────────────────────
299
/// Returns `v` scaled to unit Euclidean length.
///
/// Vectors whose norm is below `1e-9` are returned as an unchanged copy so
/// that no division by a near-zero norm takes place.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    let magnitude = sum_of_squares.sqrt();
    if magnitude < 1e-9 {
        return v.to_vec();
    }
    v.iter().map(|component| component / magnitude).collect()
}
308
/// Squared Euclidean distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length the
/// surplus tail of the longer one is ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta.powi(2)
    });
    squared_diffs.sum()
}
312
313fn nearest_idx(v: &[f32], centroids: &[Vec<f32>]) -> usize {
314    centroids
315        .iter()
316        .enumerate()
317        .map(|(i, c)| (i, l2_sq(v, c)))
318        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
319        .map(|(i, _)| i)
320        .unwrap_or(0)
321}
322
/// Largest `m` in `1..=requested` that divides `dim` evenly.
///
/// Falls back to 1 when `requested` is 0 (the candidate range is empty).
pub fn find_valid_pq_m(requested: usize, dim: usize) -> usize {
    (1..=requested)
        .rev()
        .find(|&candidate| dim.is_multiple_of(candidate))
        .unwrap_or(1)
}
332
333fn metric_to_u8(m: VectorMetric) -> u8 {
334    match m {
335        VectorMetric::Cosine => 0,
336        VectorMetric::Euclidean => 1,
337        VectorMetric::DotProduct => 2,
338    }
339}
340
341fn u8_to_metric(v: u8) -> AilakeResult<VectorMetric> {
342    match v {
343        0 => Ok(VectorMetric::Cosine),
344        1 => Ok(VectorMetric::Euclidean),
345        2 => Ok(VectorMetric::DotProduct),
346        _ => Err(AilakeError::Catalog(format!("unknown metric byte: {v}"))),
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` one-hot vectors of width `dim` (vector `i` is hot at
    /// position `i % dim`) together with row ids `0..n`.
    fn make_vecs(n: usize, dim: usize) -> (Vec<RowId>, Vec<Vec<f32>>) {
        let mut row_ids = Vec::with_capacity(n);
        let mut vecs = Vec::with_capacity(n);
        for i in 0..n {
            row_ids.push(RowId(i as u64));
            let mut one_hot = vec![0.0f32; dim];
            one_hot[i % dim] = 1.0;
            vecs.push(one_hot);
        }
        (row_ids, vecs)
    }

    /// Test configuration with 2 PQ sub-vectors of 4 centroids each.
    fn tiny_config(nlist: usize, nprobe: usize, max_iter: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            nprobe,
            pq_m: 2,
            pq_k: 4,
            max_iter,
        }
    }

    #[test]
    fn train_and_search_basic() {
        let (ids, vecs) = make_vecs(64, 8);
        let idx =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(4, 2, 10)).unwrap();
        assert_eq!(idx.node_count(), 64);

        let results = idx.search(&vecs[0], 5, None);
        assert!(!results.is_empty());
        // The best hit for an indexed vector should be (approximately) itself.
        let (_, best_dist) = results[0];
        assert!(best_dist < 0.1, "nearest should be approximate self");
    }

    #[test]
    fn train_cosine_normalizes() {
        let (ids, vecs) = make_vecs(32, 4);
        let idx =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Cosine, tiny_config(4, 2, 10)).unwrap();
        let results = idx.search(&vecs[0], 1, None);
        assert!(!results.is_empty());
    }

    #[test]
    fn serialize_roundtrip() {
        let (ids, vecs) = make_vecs(32, 8);
        let original =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(4, 2, 10)).unwrap();
        let bytes = IvfPqSerializer::to_bytes(&original).unwrap();
        let restored = IvfPqSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(restored.node_count(), original.node_count());
        assert_eq!(restored.dim(), original.dim());

        // The same query must produce the same ranking before and after the roundtrip.
        let before = original.search(&vecs[0], 5, None);
        let after = restored.search(&vecs[0], 5, None);
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(after.iter()) {
            assert_eq!(a.0, b.0, "row_ids should match after roundtrip");
        }
    }

    #[test]
    fn nlist_clamped_to_n() {
        // Only 10 vectors, far fewer than the requested 256 lists.
        let (ids, vecs) = make_vecs(10, 4);
        let idx =
            IvfPqIndex::train(&ids, &vecs, VectorMetric::Euclidean, tiny_config(256, 8, 5)).unwrap();
        assert!(idx.config.nlist <= 10);
        assert_eq!(idx.node_count(), 10);
    }
}
445}