1use crate::distance::{DistanceMetric, distance};
9use crate::hnsw::SearchResult;
10use crate::quantize::pq::PqCodec;
11
12#[derive(Clone)]
14pub struct IvfPqParams {
15 pub n_cells: usize,
17 pub pq_m: usize,
19 pub pq_k: usize,
21 pub nprobe: usize,
23 pub metric: DistanceMetric,
25}
26
27impl Default for IvfPqParams {
28 fn default() -> Self {
29 Self {
30 n_cells: 256,
31 pq_m: 8,
32 pq_k: 256,
33 nprobe: 16,
34 metric: DistanceMetric::L2,
35 }
36 }
37}
38
39pub struct IvfPqIndex {
41 dim: usize,
42 params: IvfPqParams,
43 centroids: Vec<Vec<f32>>,
45 pq: Option<PqCodec>,
47 cells: Vec<Vec<(u32, Vec<u8>)>>,
49 count: u32,
51}
52
53impl IvfPqIndex {
54 pub fn new(dim: usize, params: IvfPqParams) -> Self {
56 Self {
57 dim,
58 params,
59 centroids: Vec::new(),
60 pq: None,
61 cells: Vec::new(),
62 count: 0,
63 }
64 }
65
66 pub fn train(&mut self, vectors: &[&[f32]]) {
68 assert!(!vectors.is_empty());
69 assert!(self.dim > 0);
70 assert!(
71 self.dim.is_multiple_of(self.params.pq_m),
72 "dim {} must be divisible by pq_m {}",
73 self.dim,
74 self.params.pq_m
75 );
76
77 let n_cells = self.params.n_cells.min(vectors.len());
78 self.centroids = kmeans_centroids(vectors, self.dim, n_cells, 20);
79 self.cells = vec![Vec::new(); self.centroids.len()];
80
81 let mut residuals: Vec<Vec<f32>> = Vec::with_capacity(vectors.len());
82 for v in vectors {
83 let cell = self.nearest_centroid(v);
84 let res: Vec<f32> = v
85 .iter()
86 .zip(&self.centroids[cell])
87 .map(|(a, b)| a - b)
88 .collect();
89 residuals.push(res);
90 }
91 let res_refs: Vec<&[f32]> = residuals.iter().map(|r| r.as_slice()).collect();
92 self.pq = Some(PqCodec::train(
93 &res_refs,
94 self.dim,
95 self.params.pq_m,
96 self.params.pq_k,
97 20,
98 ));
99 }
100
101 pub fn add(&mut self, vector: &[f32]) -> u32 {
103 assert_eq!(vector.len(), self.dim);
104 let pq = self
105 .pq
106 .as_ref()
107 .expect("index must be trained before add()");
108
109 let cell = self.nearest_centroid(vector);
110 let residual: Vec<f32> = vector
111 .iter()
112 .zip(&self.centroids[cell])
113 .map(|(a, b)| a - b)
114 .collect();
115 let code = pq.encode(&residual);
116 let id = self.count;
117 self.cells[cell].push((id, code));
118 self.count += 1;
119 id
120 }
121
122 pub fn add_batch(&mut self, vectors: &[&[f32]]) {
124 for v in vectors {
125 self.add(v);
126 }
127 }
128
129 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
131 assert_eq!(query.len(), self.dim);
132 if self.centroids.is_empty() || self.count == 0 {
133 return Vec::new();
134 }
135
136 let pq = match &self.pq {
137 Some(p) => p,
138 None => return Vec::new(),
139 };
140
141 let nprobe = self.params.nprobe.min(self.centroids.len());
142 let mut centroid_dists: Vec<(usize, f32)> = self
143 .centroids
144 .iter()
145 .enumerate()
146 .map(|(i, c)| (i, distance(query, c, self.params.metric)))
147 .collect();
148 centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
149
150 let mut candidates: Vec<SearchResult> = Vec::new();
151
152 for &(cell_idx, _) in centroid_dists.iter().take(nprobe) {
153 let residual_query: Vec<f32> = query
154 .iter()
155 .zip(&self.centroids[cell_idx])
156 .map(|(q, c)| q - c)
157 .collect();
158 let table = match pq.build_distance_table(&residual_query) {
159 Ok(t) => t,
160 Err(e) => {
161 tracing::warn!(error = %e, "IVF PQ build_distance_table budget exhausted; skipping cell");
162 continue;
163 }
164 };
165
166 for (id, code) in &self.cells[cell_idx] {
167 let dist = pq.asymmetric_distance(&table, code);
168 candidates.push(SearchResult {
169 id: *id,
170 distance: dist,
171 });
172 }
173 }
174
175 if candidates.len() > top_k {
176 candidates.select_nth_unstable_by(top_k, |a, b| {
177 a.distance
178 .partial_cmp(&b.distance)
179 .unwrap_or(std::cmp::Ordering::Equal)
180 });
181 candidates.truncate(top_k);
182 }
183 candidates.sort_by(|a, b| {
184 a.distance
185 .partial_cmp(&b.distance)
186 .unwrap_or(std::cmp::Ordering::Equal)
187 });
188 candidates
189 }
190
191 fn nearest_centroid(&self, vector: &[f32]) -> usize {
192 let mut best = 0;
193 let mut best_dist = f32::MAX;
194 for (i, c) in self.centroids.iter().enumerate() {
195 let d = distance(vector, c, self.params.metric);
196 if d < best_dist {
197 best_dist = d;
198 best = i;
199 }
200 }
201 best
202 }
203
204 pub fn len(&self) -> usize {
205 self.count as usize
206 }
207
208 pub fn is_empty(&self) -> bool {
209 self.count == 0
210 }
211
212 pub fn dim(&self) -> usize {
213 self.dim
214 }
215
216 pub fn n_cells(&self) -> usize {
217 self.centroids.len()
218 }
219}
220
221fn kmeans_centroids(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
222 let n = data.len();
223 let k = k.min(n);
224 if k == 0 {
225 return Vec::new();
226 }
227
228 let mut centroids: Vec<Vec<f32>> = vec![data[0].to_vec()];
229 let mut min_dists = vec![f32::MAX; n];
230
231 for (i, point) in data.iter().enumerate() {
233 let d = distance(point, ¢roids[0], DistanceMetric::L2);
234 if d < min_dists[i] {
235 min_dists[i] = d;
236 }
237 }
238
239 let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
240 for _ in 1..k {
241 let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
242 let next_idx = if total < f64::EPSILON {
243 0
244 } else {
245 let target = rng.next_f64() * total;
246 let mut acc = 0.0f64;
247 let mut chosen = n - 1;
248 for (i, &d) in min_dists.iter().enumerate() {
249 acc += d as f64;
250 if acc >= target {
251 chosen = i;
252 break;
253 }
254 }
255 chosen
256 };
257 centroids.push(data[next_idx].to_vec());
258 let last = centroids.last().expect("just pushed");
259 for (i, point) in data.iter().enumerate() {
260 let d = distance(point, last, DistanceMetric::L2);
261 if d < min_dists[i] {
262 min_dists[i] = d;
263 }
264 }
265 }
266
267 let mut assignments = vec![0usize; n];
268 for _ in 0..max_iter {
269 let mut changed = false;
270 for (i, point) in data.iter().enumerate() {
271 let mut best = 0;
272 let mut best_d = f32::MAX;
273 for (c, centroid) in centroids.iter().enumerate() {
274 let d = distance(point, centroid, DistanceMetric::L2);
275 if d < best_d {
276 best_d = d;
277 best = c;
278 }
279 }
280 if assignments[i] != best {
281 assignments[i] = best;
282 changed = true;
283 }
284 }
285 if !changed {
286 break;
287 }
288 let mut sums = vec![vec![0.0f32; dim]; k];
289 let mut counts = vec![0usize; k];
290 for (i, point) in data.iter().enumerate() {
291 let c = assignments[i];
292 counts[c] += 1;
293 for d in 0..dim {
294 sums[c][d] += point[d];
295 }
296 }
297 for c in 0..k {
298 if counts[c] > 0 {
299 for d in 0..dim {
300 centroids[c][d] = sums[c][d] / counts[c] as f32;
301 }
302 }
303 }
304 }
305 centroids
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` deterministic vectors of length `dim`; component `d` of
    /// vector `i` equals `(i * dim + d) * 0.01`.
    fn make_vectors(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let base = i * dim;
            let row: Vec<f32> = (base..base + dim).map(|x| (x as f32) * 0.01).collect();
            rows.push(row);
        }
        rows
    }

    /// An indexed vector used as the query must come back among the top five.
    #[test]
    fn train_and_search() {
        const DIM: usize = 16;
        const N: usize = 1000;

        let vecs = make_vectors(N, DIM);
        let refs: Vec<&[f32]> = vecs.iter().map(Vec::as_slice).collect();

        let params = IvfPqParams {
            n_cells: 32,
            pq_m: 4,
            pq_k: 32,
            nprobe: 8,
            metric: DistanceMetric::L2,
        };
        let mut idx = IvfPqIndex::new(DIM, params);
        idx.train(&refs);
        idx.add_batch(&refs);
        assert_eq!(idx.len(), N);

        let results = idx.search(&vecs[500], 5);
        assert_eq!(results.len(), 5);

        let found = results.iter().any(|r| r.id == 500);
        assert!(found, "exact match not found in top-5");
    }

    /// Searching an index that was never trained yields no results.
    #[test]
    fn empty_index() {
        let idx = IvfPqIndex::new(8, IvfPqParams::default());
        let hits = idx.search(&[0.0; 8], 5);
        assert!(hits.is_empty());
    }
}