1use crate::distance::{DistanceMetric, distance};
9use crate::hnsw::SearchResult;
10use crate::quantize::pq::PqCodec;
11
12#[derive(Clone)]
14pub struct IvfPqParams {
15 pub n_cells: usize,
17 pub pq_m: usize,
19 pub pq_k: usize,
21 pub nprobe: usize,
23 pub metric: DistanceMetric,
25}
26
27impl Default for IvfPqParams {
28 fn default() -> Self {
29 Self {
30 n_cells: 256,
31 pq_m: 8,
32 pq_k: 256,
33 nprobe: 16,
34 metric: DistanceMetric::L2,
35 }
36 }
37}
38
39pub struct IvfPqIndex {
41 dim: usize,
42 params: IvfPqParams,
43 centroids: Vec<Vec<f32>>,
45 pq: Option<PqCodec>,
47 cells: Vec<Vec<(u32, Vec<u8>)>>,
49 count: u32,
51}
52
53impl IvfPqIndex {
54 pub fn new(dim: usize, params: IvfPqParams) -> Self {
56 Self {
57 dim,
58 params,
59 centroids: Vec::new(),
60 pq: None,
61 cells: Vec::new(),
62 count: 0,
63 }
64 }
65
66 pub fn train(&mut self, vectors: &[&[f32]]) {
68 assert!(!vectors.is_empty());
69 assert!(self.dim > 0);
70 assert!(
71 self.dim.is_multiple_of(self.params.pq_m),
72 "dim {} must be divisible by pq_m {}",
73 self.dim,
74 self.params.pq_m
75 );
76
77 let n_cells = self.params.n_cells.min(vectors.len());
78 self.centroids = kmeans_centroids(vectors, self.dim, n_cells, 20);
79 self.cells = vec![Vec::new(); self.centroids.len()];
80
81 let mut residuals: Vec<Vec<f32>> = Vec::with_capacity(vectors.len());
83 for v in vectors {
84 let cell = self.nearest_centroid(v);
85 let res: Vec<f32> = v
86 .iter()
87 .zip(&self.centroids[cell])
88 .map(|(a, b)| a - b)
89 .collect();
90 residuals.push(res);
91 }
92 let res_refs: Vec<&[f32]> = residuals.iter().map(|r| r.as_slice()).collect();
93 self.pq = Some(PqCodec::train(
94 &res_refs,
95 self.dim,
96 self.params.pq_m,
97 self.params.pq_k,
98 20,
99 ));
100 }
101
102 pub fn add(&mut self, vector: &[f32]) -> u32 {
104 assert_eq!(vector.len(), self.dim);
105 let pq = self
106 .pq
107 .as_ref()
108 .expect("index must be trained before add()");
109
110 let cell = self.nearest_centroid(vector);
111 let residual: Vec<f32> = vector
112 .iter()
113 .zip(&self.centroids[cell])
114 .map(|(a, b)| a - b)
115 .collect();
116 let code = pq.encode(&residual);
117 let id = self.count;
118 self.cells[cell].push((id, code));
119 self.count += 1;
120 id
121 }
122
123 pub fn add_batch(&mut self, vectors: &[&[f32]]) {
125 for v in vectors {
126 self.add(v);
127 }
128 }
129
130 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
132 assert_eq!(query.len(), self.dim);
133 if self.centroids.is_empty() || self.count == 0 {
134 return Vec::new();
135 }
136
137 let pq = match &self.pq {
138 Some(p) => p,
139 None => return Vec::new(),
140 };
141
142 let nprobe = self.params.nprobe.min(self.centroids.len());
143 let mut centroid_dists: Vec<(usize, f32)> = self
144 .centroids
145 .iter()
146 .enumerate()
147 .map(|(i, c)| (i, distance(query, c, self.params.metric)))
148 .collect();
149 centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
150
151 let mut candidates: Vec<SearchResult> = Vec::new();
152
153 for &(cell_idx, _) in centroid_dists.iter().take(nprobe) {
154 let residual_query: Vec<f32> = query
155 .iter()
156 .zip(&self.centroids[cell_idx])
157 .map(|(q, c)| q - c)
158 .collect();
159 let table = match pq.build_distance_table(&residual_query) {
160 Ok(t) => t,
161 Err(e) => {
162 tracing::warn!(error = %e, "IVF PQ build_distance_table budget exhausted; skipping cell");
163 continue;
164 }
165 };
166
167 for (id, code) in &self.cells[cell_idx] {
168 let dist = pq.asymmetric_distance(&table, code);
169 candidates.push(SearchResult {
170 id: *id,
171 distance: dist,
172 });
173 }
174 }
175
176 if candidates.len() > top_k {
177 candidates.select_nth_unstable_by(top_k, |a, b| {
178 a.distance
179 .partial_cmp(&b.distance)
180 .unwrap_or(std::cmp::Ordering::Equal)
181 });
182 candidates.truncate(top_k);
183 }
184 candidates.sort_by(|a, b| {
185 a.distance
186 .partial_cmp(&b.distance)
187 .unwrap_or(std::cmp::Ordering::Equal)
188 });
189 candidates
190 }
191
192 fn nearest_centroid(&self, vector: &[f32]) -> usize {
193 let mut best = 0;
194 let mut best_dist = f32::MAX;
195 for (i, c) in self.centroids.iter().enumerate() {
196 let d = distance(vector, c, self.params.metric);
197 if d < best_dist {
198 best_dist = d;
199 best = i;
200 }
201 }
202 best
203 }
204
205 pub fn len(&self) -> usize {
206 self.count as usize
207 }
208
209 pub fn is_empty(&self) -> bool {
210 self.count == 0
211 }
212
213 pub fn dim(&self) -> usize {
214 self.dim
215 }
216
217 pub fn n_cells(&self) -> usize {
218 self.centroids.len()
219 }
220}
221
222fn kmeans_centroids(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
223 let n = data.len();
224 let k = k.min(n);
225 if k == 0 {
226 return Vec::new();
227 }
228
229 let mut centroids: Vec<Vec<f32>> = vec![data[0].to_vec()];
230 let mut min_dists = vec![f32::MAX; n];
231
232 for (i, point) in data.iter().enumerate() {
234 let d = distance(point, ¢roids[0], DistanceMetric::L2);
235 if d < min_dists[i] {
236 min_dists[i] = d;
237 }
238 }
239
240 let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
241 for _ in 1..k {
242 let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
243 let next_idx = if total < f64::EPSILON {
244 0
245 } else {
246 let target = rng.next_f64() * total;
247 let mut acc = 0.0f64;
248 let mut chosen = n - 1;
249 for (i, &d) in min_dists.iter().enumerate() {
250 acc += d as f64;
251 if acc >= target {
252 chosen = i;
253 break;
254 }
255 }
256 chosen
257 };
258 centroids.push(data[next_idx].to_vec());
259 let last = centroids.last().expect("just pushed");
260 for (i, point) in data.iter().enumerate() {
261 let d = distance(point, last, DistanceMetric::L2);
262 if d < min_dists[i] {
263 min_dists[i] = d;
264 }
265 }
266 }
267
268 let mut assignments = vec![0usize; n];
269 for _ in 0..max_iter {
270 let mut changed = false;
271 for (i, point) in data.iter().enumerate() {
272 let mut best = 0;
273 let mut best_d = f32::MAX;
274 for (c, centroid) in centroids.iter().enumerate() {
275 let d = distance(point, centroid, DistanceMetric::L2);
276 if d < best_d {
277 best_d = d;
278 best = c;
279 }
280 }
281 if assignments[i] != best {
282 assignments[i] = best;
283 changed = true;
284 }
285 }
286 if !changed {
287 break;
288 }
289 let mut sums = vec![vec![0.0f32; dim]; k];
290 let mut counts = vec![0usize; k];
291 for (i, point) in data.iter().enumerate() {
292 let c = assignments[i];
293 counts[c] += 1;
294 for d in 0..dim {
295 sums[c][d] += point[d];
296 }
297 }
298 for c in 0..k {
299 if counts[c] > 0 {
300 for d in 0..dim {
301 centroids[c][d] = sums[c][d] / counts[c] as f32;
302 }
303 }
304 }
305 }
306 centroids
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` deterministic `dim`-dimensional vectors. Component `d` of
    /// vector `i` is `(i * dim + d) * 0.01`, so all components form one
    /// increasing sequence.
    fn make_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let start = i * dim;
            let row: Vec<f32> = (start..start + dim).map(|x| x as f32 * 0.01).collect();
            rows.push(row);
        }
        rows
    }

    #[test]
    fn train_and_search() {
        const DIM: usize = 16;
        let vecs = make_vectors(1000, DIM);
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();

        let params = IvfPqParams {
            n_cells: 32,
            pq_m: 4,
            pq_k: 32,
            nprobe: 8,
            metric: DistanceMetric::L2,
        };
        let mut idx = IvfPqIndex::new(DIM, params);
        idx.train(&refs);
        idx.add_batch(&refs);
        assert_eq!(idx.len(), 1000);

        let hits = idx.search(&vecs[500], 5);
        assert_eq!(hits.len(), 5);
        let found = hits.iter().any(|r| r.id == 500);
        assert!(found, "exact match not found in top-5");
    }

    #[test]
    fn empty_index() {
        let idx = IvfPqIndex::new(8, IvfPqParams::default());
        let hits = idx.search(&[0.0; 8], 5);
        assert!(hits.is_empty());
    }
}