Skip to main content

nodedb_vector/
ivf.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! IVF-PQ index for billion-scale datasets.
4//!
5//! Inverted File with Product Quantization: partition vectors into Voronoi
6//! cells using k-means centroids, PQ-compress within cells.
7
8use crate::distance::{DistanceMetric, distance};
9use crate::hnsw::SearchResult;
10use crate::quantize::pq::PqCodec;
11
12/// IVF-PQ index configuration.
13#[derive(Clone)]
14pub struct IvfPqParams {
15    /// Number of Voronoi cells (partitions). Typical: sqrt(N).
16    pub n_cells: usize,
17    /// Number of PQ subvectors. Must divide dimension evenly.
18    pub pq_m: usize,
19    /// Centroids per PQ subvector (fixed at 256 for u8 encoding).
20    pub pq_k: usize,
21    /// Number of cells to probe at query time. Higher = better recall.
22    pub nprobe: usize,
23    /// Distance metric.
24    pub metric: DistanceMetric,
25}
26
27impl Default for IvfPqParams {
28    fn default() -> Self {
29        Self {
30            n_cells: 256,
31            pq_m: 8,
32            pq_k: 256,
33            nprobe: 16,
34            metric: DistanceMetric::L2,
35        }
36    }
37}
38
39/// IVF-PQ index: inverted file with product quantization.
40pub struct IvfPqIndex {
41    dim: usize,
42    params: IvfPqParams,
43    /// Coarse centroids: `n_cells` × `dim` FP32 vectors.
44    centroids: Vec<Vec<f32>>,
45    /// PQ codec trained on the dataset.
46    pq: Option<PqCodec>,
47    /// Per-cell inverted lists: `cells[cell_id]` = list of (vector_id, pq_code).
48    cells: Vec<Vec<(u32, Vec<u8>)>>,
49    /// Total vectors indexed.
50    count: u32,
51}
52
53impl IvfPqIndex {
54    /// Create an empty IVF-PQ index.
55    pub fn new(dim: usize, params: IvfPqParams) -> Self {
56        Self {
57            dim,
58            params,
59            centroids: Vec::new(),
60            pq: None,
61            cells: Vec::new(),
62            count: 0,
63        }
64    }
65
66    /// Train the index from a set of vectors.
67    pub fn train(&mut self, vectors: &[&[f32]]) {
68        assert!(!vectors.is_empty());
69        assert!(self.dim > 0);
70        assert!(
71            self.dim.is_multiple_of(self.params.pq_m),
72            "dim {} must be divisible by pq_m {}",
73            self.dim,
74            self.params.pq_m
75        );
76
77        let n_cells = self.params.n_cells.min(vectors.len());
78        self.centroids = kmeans_centroids(vectors, self.dim, n_cells, 20);
79        self.cells = vec![Vec::new(); self.centroids.len()];
80
81        // no-governor: cold IVF training; residuals built once during index build, governed at call site
82        let mut residuals: Vec<Vec<f32>> = Vec::with_capacity(vectors.len());
83        for v in vectors {
84            let cell = self.nearest_centroid(v);
85            let res: Vec<f32> = v
86                .iter()
87                .zip(&self.centroids[cell])
88                .map(|(a, b)| a - b)
89                .collect();
90            residuals.push(res);
91        }
92        let res_refs: Vec<&[f32]> = residuals.iter().map(|r| r.as_slice()).collect();
93        self.pq = Some(PqCodec::train(
94            &res_refs,
95            self.dim,
96            self.params.pq_m,
97            self.params.pq_k,
98            20,
99        ));
100    }
101
102    /// Add a vector to the index. Returns the assigned ID.
103    pub fn add(&mut self, vector: &[f32]) -> u32 {
104        assert_eq!(vector.len(), self.dim);
105        let pq = self
106            .pq
107            .as_ref()
108            .expect("index must be trained before add()");
109
110        let cell = self.nearest_centroid(vector);
111        let residual: Vec<f32> = vector
112            .iter()
113            .zip(&self.centroids[cell])
114            .map(|(a, b)| a - b)
115            .collect();
116        let code = pq.encode(&residual);
117        let id = self.count;
118        self.cells[cell].push((id, code));
119        self.count += 1;
120        id
121    }
122
123    /// Batch add vectors.
124    pub fn add_batch(&mut self, vectors: &[&[f32]]) {
125        for v in vectors {
126            self.add(v);
127        }
128    }
129
130    /// Search: find top-k nearest neighbors.
131    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
132        assert_eq!(query.len(), self.dim);
133        if self.centroids.is_empty() || self.count == 0 {
134            return Vec::new();
135        }
136
137        let pq = match &self.pq {
138            Some(p) => p,
139            None => return Vec::new(),
140        };
141
142        let nprobe = self.params.nprobe.min(self.centroids.len());
143        let mut centroid_dists: Vec<(usize, f32)> = self
144            .centroids
145            .iter()
146            .enumerate()
147            .map(|(i, c)| (i, distance(query, c, self.params.metric)))
148            .collect();
149        centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
150
151        let mut candidates: Vec<SearchResult> = Vec::new();
152
153        for &(cell_idx, _) in centroid_dists.iter().take(nprobe) {
154            let residual_query: Vec<f32> = query
155                .iter()
156                .zip(&self.centroids[cell_idx])
157                .map(|(q, c)| q - c)
158                .collect();
159            let table = match pq.build_distance_table(&residual_query) {
160                Ok(t) => t,
161                Err(e) => {
162                    tracing::warn!(error = %e, "IVF PQ build_distance_table budget exhausted; skipping cell");
163                    continue;
164                }
165            };
166
167            for (id, code) in &self.cells[cell_idx] {
168                let dist = pq.asymmetric_distance(&table, code);
169                candidates.push(SearchResult {
170                    id: *id,
171                    distance: dist,
172                });
173            }
174        }
175
176        if candidates.len() > top_k {
177            candidates.select_nth_unstable_by(top_k, |a, b| {
178                a.distance
179                    .partial_cmp(&b.distance)
180                    .unwrap_or(std::cmp::Ordering::Equal)
181            });
182            candidates.truncate(top_k);
183        }
184        candidates.sort_by(|a, b| {
185            a.distance
186                .partial_cmp(&b.distance)
187                .unwrap_or(std::cmp::Ordering::Equal)
188        });
189        candidates
190    }
191
192    fn nearest_centroid(&self, vector: &[f32]) -> usize {
193        let mut best = 0;
194        let mut best_dist = f32::MAX;
195        for (i, c) in self.centroids.iter().enumerate() {
196            let d = distance(vector, c, self.params.metric);
197            if d < best_dist {
198                best_dist = d;
199                best = i;
200            }
201        }
202        best
203    }
204
205    pub fn len(&self) -> usize {
206        self.count as usize
207    }
208
209    pub fn is_empty(&self) -> bool {
210        self.count == 0
211    }
212
213    pub fn dim(&self) -> usize {
214        self.dim
215    }
216
217    pub fn n_cells(&self) -> usize {
218        self.centroids.len()
219    }
220}
221
222fn kmeans_centroids(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
223    let n = data.len();
224    let k = k.min(n);
225    if k == 0 {
226        return Vec::new();
227    }
228
229    let mut centroids: Vec<Vec<f32>> = vec![data[0].to_vec()];
230    let mut min_dists = vec![f32::MAX; n];
231
232    // Initialize min_dists against the first centroid.
233    for (i, point) in data.iter().enumerate() {
234        let d = distance(point, &centroids[0], DistanceMetric::L2);
235        if d < min_dists[i] {
236            min_dists[i] = d;
237        }
238    }
239
240    let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
241    for _ in 1..k {
242        let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
243        let next_idx = if total < f64::EPSILON {
244            0
245        } else {
246            let target = rng.next_f64() * total;
247            let mut acc = 0.0f64;
248            let mut chosen = n - 1;
249            for (i, &d) in min_dists.iter().enumerate() {
250                acc += d as f64;
251                if acc >= target {
252                    chosen = i;
253                    break;
254                }
255            }
256            chosen
257        };
258        centroids.push(data[next_idx].to_vec());
259        let last = centroids.last().expect("just pushed");
260        for (i, point) in data.iter().enumerate() {
261            let d = distance(point, last, DistanceMetric::L2);
262            if d < min_dists[i] {
263                min_dists[i] = d;
264            }
265        }
266    }
267
268    let mut assignments = vec![0usize; n];
269    for _ in 0..max_iter {
270        let mut changed = false;
271        for (i, point) in data.iter().enumerate() {
272            let mut best = 0;
273            let mut best_d = f32::MAX;
274            for (c, centroid) in centroids.iter().enumerate() {
275                let d = distance(point, centroid, DistanceMetric::L2);
276                if d < best_d {
277                    best_d = d;
278                    best = c;
279                }
280            }
281            if assignments[i] != best {
282                assignments[i] = best;
283                changed = true;
284            }
285        }
286        if !changed {
287            break;
288        }
289        let mut sums = vec![vec![0.0f32; dim]; k];
290        let mut counts = vec![0usize; k];
291        for (i, point) in data.iter().enumerate() {
292            let c = assignments[i];
293            counts[c] += 1;
294            for d in 0..dim {
295                sums[c][d] += point[d];
296            }
297        }
298        for c in 0..k {
299            if counts[c] > 0 {
300                for d in 0..dim {
301                    centroids[c][d] = sums[c][d] / counts[c] as f32;
302                }
303            }
304        }
305    }
306    centroids
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` deterministic vectors of length `dim`; component `d` of
    /// vector `i` is `(i * dim + d) * 0.01`.
    fn make_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32) * 0.01).collect();
            out.push(row);
        }
        out
    }

    /// Train on 1000 vectors, index them all, and check that querying with
    /// an indexed vector returns that vector among the top 5.
    #[test]
    fn train_and_search() {
        let dataset = make_vectors(1000, 16);
        let views: Vec<&[f32]> = dataset.iter().map(Vec::as_slice).collect();

        let params = IvfPqParams {
            n_cells: 32,
            pq_m: 4,
            pq_k: 32,
            nprobe: 8,
            metric: DistanceMetric::L2,
        };
        let mut index = IvfPqIndex::new(16, params);
        index.train(&views);
        index.add_batch(&views);

        assert_eq!(index.len(), 1000);

        let hits = index.search(&dataset[500], 5);
        assert_eq!(hits.len(), 5);
        let found_self = hits.iter().any(|hit| hit.id == 500);
        assert!(found_self, "exact match not found in top-5");
    }

    /// Searching an index that was never trained yields no results.
    #[test]
    fn empty_index() {
        let index = IvfPqIndex::new(8, IvfPqParams::default());
        let hits = index.search(&[0.0; 8], 5);
        assert!(hits.is_empty());
    }
}
353}