1use std::collections::HashMap;
34
35use super::distance::{cmp_f32, l2_squared_simd, DistanceResult};
36use super::hnsw::NodeId;
37
/// Tuning parameters for an IVF index.
#[derive(Debug, Clone)]
pub struct IvfConfig {
    /// How many inverted lists (k-means clusters) training tries to build.
    pub n_lists: usize,
    /// How many of the nearest lists a default search scans.
    pub n_probes: usize,
    /// Vector dimension; training sizes every centroid to this length.
    pub dimension: usize,
    /// Upper bound on k-means iterations during training.
    pub max_iterations: usize,
    /// Training stops early once no centroid moves (by L2 distance) at least this far.
    pub convergence_threshold: f32,
}
52
53impl Default for IvfConfig {
54 fn default() -> Self {
55 Self {
56 n_lists: 100,
57 n_probes: 10,
58 dimension: 128,
59 max_iterations: 50,
60 convergence_threshold: 1e-4,
61 }
62 }
63}
64
65impl IvfConfig {
66 pub fn new(dimension: usize, n_lists: usize) -> Self {
67 Self {
68 n_lists,
69 n_probes: (n_lists / 10).max(1),
70 dimension,
71 ..Default::default()
72 }
73 }
74
75 pub fn with_probes(mut self, n_probes: usize) -> Self {
76 self.n_probes = n_probes;
77 self
78 }
79}
80
81#[derive(Clone)]
83struct IvfList {
84 centroid: Vec<f32>,
86 ids: Vec<NodeId>,
88 vectors: Vec<Vec<f32>>,
90}
91
92impl IvfList {
93 fn new(centroid: Vec<f32>) -> Self {
94 Self {
95 centroid,
96 ids: Vec::new(),
97 vectors: Vec::new(),
98 }
99 }
100
101 fn add(&mut self, id: NodeId, vector: Vec<f32>) {
102 self.ids.push(id);
103 self.vectors.push(vector);
104 }
105
106 fn len(&self) -> usize {
107 self.ids.len()
108 }
109
110 fn is_empty(&self) -> bool {
111 self.ids.is_empty()
112 }
113}
114
115pub struct IvfIndex {
117 config: IvfConfig,
118 lists: Vec<IvfList>,
120 id_to_list: HashMap<NodeId, usize>,
122 trained: bool,
124 count: usize,
126 next_id: NodeId,
128}
129
130impl IvfIndex {
131 pub fn new(config: IvfConfig) -> Self {
133 Self {
134 config,
135 lists: Vec::new(),
136 id_to_list: HashMap::new(),
137 trained: false,
138 count: 0,
139 next_id: 0,
140 }
141 }
142
143 pub fn with_dimension(dimension: usize) -> Self {
145 Self::new(IvfConfig::new(dimension, 100))
146 }
147
148 pub fn train(&mut self, vectors: &[Vec<f32>]) {
150 if vectors.is_empty() {
151 return;
152 }
153
154 let n_lists = self.config.n_lists.min(vectors.len());
155
156 let centroids = self.kmeans_plusplus_init(vectors, n_lists);
158
159 let final_centroids = self.kmeans(vectors, centroids);
161
162 self.lists = final_centroids.into_iter().map(IvfList::new).collect();
164
165 self.trained = true;
166 }
167
168 fn kmeans_plusplus_init(&self, vectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
170 let mut centroids = Vec::with_capacity(k);
171
172 if vectors.is_empty() || k == 0 {
173 return centroids;
174 }
175
176 centroids.push(vectors[vectors.len() / 2].clone());
178
179 for _ in 1..k {
181 let mut distances: Vec<f32> = vectors
182 .iter()
183 .map(|v| {
184 centroids
185 .iter()
186 .map(|c| l2_squared_simd(v, c))
187 .fold(f32::MAX, f32::min)
188 })
189 .collect();
190
191 let total: f32 = distances.iter().sum();
193 if total > 0.0 {
194 for d in &mut distances {
195 *d /= total;
196 }
197 }
198
199 let max_idx = distances
201 .iter()
202 .enumerate()
203 .max_by(|(la, a), (lb, b)| cmp_f32(**a, **b).then_with(|| la.cmp(lb)))
204 .map(|(i, _)| i)
205 .unwrap_or(0);
206
207 centroids.push(vectors[max_idx].clone());
208 }
209
210 centroids
211 }
212
213 fn kmeans(&self, vectors: &[Vec<f32>], mut centroids: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
215 let dim = self.config.dimension;
216 let k = centroids.len();
217
218 for _ in 0..self.config.max_iterations {
219 let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
221 for (i, vector) in vectors.iter().enumerate() {
222 let nearest = self.find_nearest_centroid(vector, ¢roids);
223 assignments[nearest].push(i);
224 }
225
226 let mut new_centroids = Vec::with_capacity(k);
228 let mut max_shift: f32 = 0.0;
229
230 for (cluster_idx, indices) in assignments.iter().enumerate() {
231 if indices.is_empty() {
232 new_centroids.push(centroids[cluster_idx].clone());
234 continue;
235 }
236
237 let mut new_centroid = vec![0.0f32; dim];
239 for &idx in indices {
240 for (j, val) in vectors[idx].iter().enumerate() {
241 if j < dim {
242 new_centroid[j] += val;
243 }
244 }
245 }
246 for val in &mut new_centroid {
247 *val /= indices.len() as f32;
248 }
249
250 let shift = l2_squared_simd(&new_centroid, ¢roids[cluster_idx]).sqrt();
252 max_shift = max_shift.max(shift);
253
254 new_centroids.push(new_centroid);
255 }
256
257 centroids = new_centroids;
258
259 if max_shift < self.config.convergence_threshold {
261 break;
262 }
263 }
264
265 centroids
266 }
267
268 fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
270 centroids
271 .iter()
272 .enumerate()
273 .map(|(i, c)| (i, l2_squared_simd(vector, c)))
274 .min_by(|(li, la), (ri, rb)| cmp_f32(*la, *rb).then_with(|| li.cmp(ri)))
275 .map(|(i, _)| i)
276 .unwrap_or(0)
277 }
278
279 fn find_nearest_centroids(&self, vector: &[f32], k: usize) -> Vec<usize> {
281 let mut distances: Vec<(usize, f32)> = self
282 .lists
283 .iter()
284 .enumerate()
285 .map(|(i, list)| (i, l2_squared_simd(vector, &list.centroid)))
286 .collect();
287
288 distances.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
289 distances.into_iter().take(k).map(|(i, _)| i).collect()
290 }
291
292 pub fn add(&mut self, vector: Vec<f32>) -> NodeId {
294 let id = self.next_id;
295 self.next_id += 1;
296 self.add_with_id(id, vector);
297 id
298 }
299
300 pub fn add_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
302 if !self.trained || self.lists.is_empty() {
303 if self.lists.is_empty() {
305 self.lists.push(IvfList::new(vector.clone()));
306 self.trained = true;
307 }
308 }
309
310 let list_idx = self.find_nearest_centroid(
311 &vector,
312 &self
313 .lists
314 .iter()
315 .map(|l| l.centroid.clone())
316 .collect::<Vec<_>>(),
317 );
318
319 self.lists[list_idx].add(id, vector);
320 self.id_to_list.insert(id, list_idx);
321 self.count += 1;
322 }
323
324 pub fn add_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
326 vectors.into_iter().map(|v| self.add(v)).collect()
327 }
328
329 pub fn add_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
331 for (id, vector) in items {
332 self.add_with_id(id, vector);
333 }
334 }
335
336 pub fn remove(&mut self, id: NodeId) -> bool {
338 if let Some(list_idx) = self.id_to_list.remove(&id) {
339 let list = &mut self.lists[list_idx];
340 if let Some(pos) = list.ids.iter().position(|&x| x == id) {
341 list.ids.remove(pos);
342 list.vectors.remove(pos);
343 self.count = self.count.saturating_sub(1);
344 return true;
345 }
346 }
347 false
348 }
349
350 pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
352 self.search_with_probes(query, k, self.config.n_probes)
353 }
354
355 pub fn search_with_probes(
357 &self,
358 query: &[f32],
359 k: usize,
360 n_probes: usize,
361 ) -> Vec<DistanceResult> {
362 if self.lists.is_empty() {
363 return Vec::new();
364 }
365
366 let probes = self.find_nearest_centroids(query, n_probes);
367
368 let mut candidates: Vec<DistanceResult> = Vec::new();
370 for list_idx in probes {
371 let list = &self.lists[list_idx];
372 for (i, vector) in list.vectors.iter().enumerate() {
373 let distance = l2_squared_simd(query, vector).sqrt();
374 candidates.push(DistanceResult::new(list.ids[i], distance));
375 }
376 }
377
378 candidates.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
380 candidates.truncate(k);
381 candidates
382 }
383
384 pub fn get(&self, id: NodeId) -> Option<&[f32]> {
386 if let Some(&list_idx) = self.id_to_list.get(&id) {
387 let list = &self.lists[list_idx];
388 if let Some(pos) = list.ids.iter().position(|&x| x == id) {
389 return Some(&list.vectors[pos]);
390 }
391 }
392 None
393 }
394
395 pub fn contains(&self, id: NodeId) -> bool {
397 self.id_to_list.contains_key(&id)
398 }
399
400 pub fn len(&self) -> usize {
402 self.count
403 }
404
405 pub fn is_empty(&self) -> bool {
407 self.count == 0
408 }
409
410 pub fn n_lists(&self) -> usize {
412 self.lists.len()
413 }
414
415 pub fn stats(&self) -> IvfStats {
417 let sizes: Vec<usize> = self.lists.iter().map(|l| l.len()).collect();
418 let non_empty = sizes.iter().filter(|&&s| s > 0).count();
419
420 let avg = if non_empty > 0 {
421 sizes.iter().sum::<usize>() as f64 / non_empty as f64
422 } else {
423 0.0
424 };
425
426 let max = sizes.iter().copied().max().unwrap_or(0);
427 let min = sizes.iter().filter(|&&s| s > 0).copied().min().unwrap_or(0);
428
429 IvfStats {
430 total_vectors: self.count,
431 n_lists: self.lists.len(),
432 non_empty_lists: non_empty,
433 avg_list_size: avg,
434 max_list_size: max,
435 min_list_size: min,
436 dimension: self.config.dimension,
437 trained: self.trained,
438 }
439 }
440
441 pub fn to_bytes(&self) -> Vec<u8> {
446 let lists = self
447 .lists
448 .iter()
449 .map(|list| reddb_file::IvfListLayout {
450 centroid: list.centroid.clone(),
451 ids: list.ids.clone(),
452 vectors: list.vectors.clone(),
453 })
454 .collect();
455 let layout = reddb_file::IvfIndexLayout {
456 n_lists: self.config.n_lists,
457 n_probes: self.config.n_probes,
458 dimension: self.config.dimension,
459 max_iterations: self.config.max_iterations,
460 convergence_threshold: self.config.convergence_threshold,
461 trained: self.trained,
462 count: self.count,
463 next_id: self.next_id,
464 lists,
465 };
466 reddb_file::encode_ivf_index(&layout)
467 }
468
469 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
471 let layout = reddb_file::decode_ivf_index(bytes).map_err(|e| e.to_string())?;
472
473 let config = IvfConfig {
474 n_lists: layout.n_lists,
475 n_probes: layout.n_probes,
476 dimension: layout.dimension,
477 max_iterations: layout.max_iterations,
478 convergence_threshold: layout.convergence_threshold,
479 };
480
481 let mut lists = Vec::with_capacity(layout.lists.len());
482 let mut id_to_list = HashMap::new();
483 for (list_idx, list) in layout.lists.into_iter().enumerate() {
484 for &id in &list.ids {
485 id_to_list.insert(id, list_idx);
486 }
487 lists.push(IvfList {
488 centroid: list.centroid,
489 ids: list.ids,
490 vectors: list.vectors,
491 });
492 }
493
494 Ok(Self {
495 config,
496 lists,
497 id_to_list,
498 trained: layout.trained,
499 count: layout.count,
500 next_id: layout.next_id,
501 })
502 }
503}
504
/// Snapshot of index size statistics.
#[derive(Clone, Debug)]
pub struct IvfStats {
    /// Number of stored vectors.
    pub total_vectors: usize,
    /// Number of lists, including empty ones.
    pub n_lists: usize,
    /// Number of lists holding at least one vector.
    pub non_empty_lists: usize,
    /// Mean size over the non-empty lists (0.0 when all lists are empty).
    pub avg_list_size: f64,
    /// Size of the largest list (0 when there are no lists).
    pub max_list_size: usize,
    /// Size of the smallest non-empty list (0 when all lists are empty).
    pub min_list_size: usize,
    /// Configured vector dimension.
    pub dimension: usize,
    /// Whether the index has centroids.
    pub trained: bool,
}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[0, 1)`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut components = Vec::with_capacity(dim);
        for i in 0..dim as u64 {
            let raw = (seed * 1103515245 + i * 12345) % 1000;
            components.push(raw as f32 / 1000.0);
        }
        components
    }

    #[test]
    fn test_ivf_basic() {
        let mut index = IvfIndex::new(IvfConfig::new(8, 4));
        let samples: Vec<Vec<f32>> = (0..100u64).map(|seed| random_vector(8, seed)).collect();

        index.train(&samples);
        assert!(index.trained);
        assert_eq!(index.n_lists(), 4);

        for (id, sample) in (0u64..).zip(samples.iter()) {
            index.add_with_id(id, sample.clone());
        }
        assert_eq!(index.len(), 100);
    }

    #[test]
    fn test_ivf_search() {
        let dim = 8;
        let mut index = IvfIndex::new(IvfConfig {
            n_lists: 4,
            n_probes: 2,
            dimension: dim,
            ..Default::default()
        });

        // Four tight clusters centred at 0, 10, 20 and 30 on every axis;
        // ids 0..25 belong to the cluster at the origin.
        let points: Vec<Vec<f32>> = (0..4)
            .flat_map(|cluster| {
                let base = cluster as f32 * 10.0;
                (0..25).map(move |i| {
                    let mut p = vec![base; dim];
                    p[0] += i as f32 * 0.01;
                    p
                })
            })
            .collect();

        index.train(&points);
        for (id, point) in (0u64..).zip(points.iter()) {
            index.add_with_id(id, point.clone());
        }

        let query = vec![0.05; dim];
        let results = index.search(&query, 5);

        assert!(!results.is_empty());
        assert!(results.iter().all(|hit| hit.id < 25));
    }

    #[test]
    fn test_ivf_remove() {
        let mut index = IvfIndex::new(IvfConfig::new(4, 2));
        let one_hot = |axis: usize| -> Vec<f32> {
            let mut v = vec![0.0; 4];
            v[axis] = 1.0;
            v
        };
        for &(id, axis) in [(1, 0), (2, 1), (3, 2)].iter() {
            index.add_with_id(id, one_hot(axis));
        }

        assert_eq!(index.len(), 3);
        assert!(index.contains(2));

        assert!(index.remove(2));
        assert_eq!(index.len(), 2);
        assert!(!index.contains(2));
    }

    #[test]
    fn test_ivf_stats() {
        let mut index = IvfIndex::new(IvfConfig::new(4, 3));
        let points: Vec<Vec<f32>> = (0..3).map(|x| vec![x as f32, 0.0, 0.0, 0.0]).collect();

        index.train(&points);
        for (id, point) in (0u64..).zip(points.iter()) {
            index.add_with_id(id, point.clone());
        }

        let summary = index.stats();
        assert_eq!(summary.total_vectors, 3);
        assert_eq!(summary.n_lists, 3);
        assert!(summary.trained);
    }
}