Skip to main content

nodedb_vector/
ivf.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! IVF-PQ index for billion-scale datasets.
4//!
5//! Inverted File with Product Quantization: partition vectors into Voronoi
6//! cells using k-means centroids, PQ-compress within cells.
7
8use crate::distance::{DistanceMetric, distance};
9use crate::hnsw::SearchResult;
10use crate::quantize::pq::PqCodec;
11
12/// IVF-PQ index configuration.
13#[derive(Clone)]
14pub struct IvfPqParams {
15    /// Number of Voronoi cells (partitions). Typical: sqrt(N).
16    pub n_cells: usize,
17    /// Number of PQ subvectors. Must divide dimension evenly.
18    pub pq_m: usize,
19    /// Centroids per PQ subvector (fixed at 256 for u8 encoding).
20    pub pq_k: usize,
21    /// Number of cells to probe at query time. Higher = better recall.
22    pub nprobe: usize,
23    /// Distance metric.
24    pub metric: DistanceMetric,
25}
26
27impl Default for IvfPqParams {
28    fn default() -> Self {
29        Self {
30            n_cells: 256,
31            pq_m: 8,
32            pq_k: 256,
33            nprobe: 16,
34            metric: DistanceMetric::L2,
35        }
36    }
37}
38
39/// IVF-PQ index: inverted file with product quantization.
40pub struct IvfPqIndex {
41    dim: usize,
42    params: IvfPqParams,
43    /// Coarse centroids: `n_cells` × `dim` FP32 vectors.
44    centroids: Vec<Vec<f32>>,
45    /// PQ codec trained on the dataset.
46    pq: Option<PqCodec>,
47    /// Per-cell inverted lists: `cells[cell_id]` = list of (vector_id, pq_code).
48    cells: Vec<Vec<(u32, Vec<u8>)>>,
49    /// Total vectors indexed.
50    count: u32,
51}
52
53impl IvfPqIndex {
54    /// Create an empty IVF-PQ index.
55    pub fn new(dim: usize, params: IvfPqParams) -> Self {
56        Self {
57            dim,
58            params,
59            centroids: Vec::new(),
60            pq: None,
61            cells: Vec::new(),
62            count: 0,
63        }
64    }
65
66    /// Train the index from a set of vectors.
67    pub fn train(&mut self, vectors: &[&[f32]]) {
68        assert!(!vectors.is_empty());
69        assert!(self.dim > 0);
70        assert!(
71            self.dim.is_multiple_of(self.params.pq_m),
72            "dim {} must be divisible by pq_m {}",
73            self.dim,
74            self.params.pq_m
75        );
76
77        let n_cells = self.params.n_cells.min(vectors.len());
78        self.centroids = kmeans_centroids(vectors, self.dim, n_cells, 20);
79        self.cells = vec![Vec::new(); self.centroids.len()];
80
81        let mut residuals: Vec<Vec<f32>> = Vec::with_capacity(vectors.len());
82        for v in vectors {
83            let cell = self.nearest_centroid(v);
84            let res: Vec<f32> = v
85                .iter()
86                .zip(&self.centroids[cell])
87                .map(|(a, b)| a - b)
88                .collect();
89            residuals.push(res);
90        }
91        let res_refs: Vec<&[f32]> = residuals.iter().map(|r| r.as_slice()).collect();
92        self.pq = Some(PqCodec::train(
93            &res_refs,
94            self.dim,
95            self.params.pq_m,
96            self.params.pq_k,
97            20,
98        ));
99    }
100
101    /// Add a vector to the index. Returns the assigned ID.
102    pub fn add(&mut self, vector: &[f32]) -> u32 {
103        assert_eq!(vector.len(), self.dim);
104        let pq = self
105            .pq
106            .as_ref()
107            .expect("index must be trained before add()");
108
109        let cell = self.nearest_centroid(vector);
110        let residual: Vec<f32> = vector
111            .iter()
112            .zip(&self.centroids[cell])
113            .map(|(a, b)| a - b)
114            .collect();
115        let code = pq.encode(&residual);
116        let id = self.count;
117        self.cells[cell].push((id, code));
118        self.count += 1;
119        id
120    }
121
122    /// Batch add vectors.
123    pub fn add_batch(&mut self, vectors: &[&[f32]]) {
124        for v in vectors {
125            self.add(v);
126        }
127    }
128
129    /// Search: find top-k nearest neighbors.
130    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
131        assert_eq!(query.len(), self.dim);
132        if self.centroids.is_empty() || self.count == 0 {
133            return Vec::new();
134        }
135
136        let pq = match &self.pq {
137            Some(p) => p,
138            None => return Vec::new(),
139        };
140
141        let nprobe = self.params.nprobe.min(self.centroids.len());
142        let mut centroid_dists: Vec<(usize, f32)> = self
143            .centroids
144            .iter()
145            .enumerate()
146            .map(|(i, c)| (i, distance(query, c, self.params.metric)))
147            .collect();
148        centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
149
150        let mut candidates: Vec<SearchResult> = Vec::new();
151
152        for &(cell_idx, _) in centroid_dists.iter().take(nprobe) {
153            let residual_query: Vec<f32> = query
154                .iter()
155                .zip(&self.centroids[cell_idx])
156                .map(|(q, c)| q - c)
157                .collect();
158            let table = match pq.build_distance_table(&residual_query) {
159                Ok(t) => t,
160                Err(e) => {
161                    tracing::warn!(error = %e, "IVF PQ build_distance_table budget exhausted; skipping cell");
162                    continue;
163                }
164            };
165
166            for (id, code) in &self.cells[cell_idx] {
167                let dist = pq.asymmetric_distance(&table, code);
168                candidates.push(SearchResult {
169                    id: *id,
170                    distance: dist,
171                });
172            }
173        }
174
175        if candidates.len() > top_k {
176            candidates.select_nth_unstable_by(top_k, |a, b| {
177                a.distance
178                    .partial_cmp(&b.distance)
179                    .unwrap_or(std::cmp::Ordering::Equal)
180            });
181            candidates.truncate(top_k);
182        }
183        candidates.sort_by(|a, b| {
184            a.distance
185                .partial_cmp(&b.distance)
186                .unwrap_or(std::cmp::Ordering::Equal)
187        });
188        candidates
189    }
190
191    fn nearest_centroid(&self, vector: &[f32]) -> usize {
192        let mut best = 0;
193        let mut best_dist = f32::MAX;
194        for (i, c) in self.centroids.iter().enumerate() {
195            let d = distance(vector, c, self.params.metric);
196            if d < best_dist {
197                best_dist = d;
198                best = i;
199            }
200        }
201        best
202    }
203
204    pub fn len(&self) -> usize {
205        self.count as usize
206    }
207
208    pub fn is_empty(&self) -> bool {
209        self.count == 0
210    }
211
212    pub fn dim(&self) -> usize {
213        self.dim
214    }
215
216    pub fn n_cells(&self) -> usize {
217        self.centroids.len()
218    }
219}
220
221fn kmeans_centroids(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
222    let n = data.len();
223    let k = k.min(n);
224    if k == 0 {
225        return Vec::new();
226    }
227
228    let mut centroids: Vec<Vec<f32>> = vec![data[0].to_vec()];
229    let mut min_dists = vec![f32::MAX; n];
230
231    // Initialize min_dists against the first centroid.
232    for (i, point) in data.iter().enumerate() {
233        let d = distance(point, &centroids[0], DistanceMetric::L2);
234        if d < min_dists[i] {
235            min_dists[i] = d;
236        }
237    }
238
239    let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
240    for _ in 1..k {
241        let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
242        let next_idx = if total < f64::EPSILON {
243            0
244        } else {
245            let target = rng.next_f64() * total;
246            let mut acc = 0.0f64;
247            let mut chosen = n - 1;
248            for (i, &d) in min_dists.iter().enumerate() {
249                acc += d as f64;
250                if acc >= target {
251                    chosen = i;
252                    break;
253                }
254            }
255            chosen
256        };
257        centroids.push(data[next_idx].to_vec());
258        let last = centroids.last().expect("just pushed");
259        for (i, point) in data.iter().enumerate() {
260            let d = distance(point, last, DistanceMetric::L2);
261            if d < min_dists[i] {
262                min_dists[i] = d;
263            }
264        }
265    }
266
267    let mut assignments = vec![0usize; n];
268    for _ in 0..max_iter {
269        let mut changed = false;
270        for (i, point) in data.iter().enumerate() {
271            let mut best = 0;
272            let mut best_d = f32::MAX;
273            for (c, centroid) in centroids.iter().enumerate() {
274                let d = distance(point, centroid, DistanceMetric::L2);
275                if d < best_d {
276                    best_d = d;
277                    best = c;
278                }
279            }
280            if assignments[i] != best {
281                assignments[i] = best;
282                changed = true;
283            }
284        }
285        if !changed {
286            break;
287        }
288        let mut sums = vec![vec![0.0f32; dim]; k];
289        let mut counts = vec![0usize; k];
290        for (i, point) in data.iter().enumerate() {
291            let c = assignments[i];
292            counts[c] += 1;
293            for d in 0..dim {
294                sums[c][d] += point[d];
295            }
296        }
297        for c in 0..k {
298            if counts[c] > 0 {
299                for d in 0..dim {
300                    centroids[c][d] = sums[c][d] / counts[c] as f32;
301                }
302            }
303        }
304    }
305    centroids
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: component `d` of vector `i` is `(i * dim + d) * 0.01`.
    fn make_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let base = i * dim;
            let row: Vec<f32> = (base..base + dim).map(|v| (v as f32) * 0.01).collect();
            out.push(row);
        }
        out
    }

    /// Train on 1000 vectors, index them, and expect a stored vector to be
    /// among its own top-5 neighbors.
    #[test]
    fn train_and_search() {
        let vecs = make_vectors(1000, 16);
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();

        let params = IvfPqParams {
            n_cells: 32,
            pq_m: 4,
            pq_k: 32,
            nprobe: 8,
            metric: DistanceMetric::L2,
        };
        let mut idx = IvfPqIndex::new(16, params);
        idx.train(&refs);
        idx.add_batch(&refs);

        assert_eq!(idx.len(), 1000);

        let results = idx.search(&vecs[500], 5);
        assert_eq!(results.len(), 5);
        let found_self = results.iter().any(|r| r.id == 500);
        assert!(found_self, "exact match not found in top-5");
    }

    /// Searching an untrained index yields no results.
    #[test]
    fn empty_index() {
        let idx = IvfPqIndex::new(8, IvfPqParams::default());
        let hits = idx.search(&[0.0; 8], 5);
        assert!(hits.is_empty());
    }
}