1use std::collections::HashMap;
34
35use super::distance::{cmp_f32, l2_squared_simd, DistanceResult};
36use super::hnsw::NodeId;
37
/// Tuning parameters for an [`IvfIndex`] (inverted-file vector index).
#[derive(Clone, Debug, PartialEq)]
pub struct IvfConfig {
    /// Number of inverted lists (k-means clusters) requested when training.
    pub n_lists: usize,
    /// Number of nearest lists scanned per query by [`IvfIndex::search`].
    pub n_probes: usize,
    /// Dimensionality of the indexed vectors; reported by stats and serialized.
    pub dimension: usize,
    /// Upper bound on k-means refinement rounds during training.
    pub max_iterations: usize,
    /// Training stops early once no centroid moves farther than this distance.
    pub convergence_threshold: f32,
}
52
53impl Default for IvfConfig {
54 fn default() -> Self {
55 Self {
56 n_lists: 100,
57 n_probes: 10,
58 dimension: 128,
59 max_iterations: 50,
60 convergence_threshold: 1e-4,
61 }
62 }
63}
64
65impl IvfConfig {
66 pub fn new(dimension: usize, n_lists: usize) -> Self {
67 Self {
68 n_lists,
69 n_probes: (n_lists / 10).max(1),
70 dimension,
71 ..Default::default()
72 }
73 }
74
75 pub fn with_probes(mut self, n_probes: usize) -> Self {
76 self.n_probes = n_probes;
77 self
78 }
79}
80
81#[derive(Clone)]
83struct IvfList {
84 centroid: Vec<f32>,
86 ids: Vec<NodeId>,
88 vectors: Vec<Vec<f32>>,
90}
91
92impl IvfList {
93 fn new(centroid: Vec<f32>) -> Self {
94 Self {
95 centroid,
96 ids: Vec::new(),
97 vectors: Vec::new(),
98 }
99 }
100
101 fn add(&mut self, id: NodeId, vector: Vec<f32>) {
102 self.ids.push(id);
103 self.vectors.push(vector);
104 }
105
106 fn len(&self) -> usize {
107 self.ids.len()
108 }
109
110 fn is_empty(&self) -> bool {
111 self.ids.is_empty()
112 }
113}
114
115pub struct IvfIndex {
117 config: IvfConfig,
118 lists: Vec<IvfList>,
120 id_to_list: HashMap<NodeId, usize>,
122 trained: bool,
124 count: usize,
126 next_id: NodeId,
128}
129
130impl IvfIndex {
131 pub fn new(config: IvfConfig) -> Self {
133 Self {
134 config,
135 lists: Vec::new(),
136 id_to_list: HashMap::new(),
137 trained: false,
138 count: 0,
139 next_id: 0,
140 }
141 }
142
143 pub fn with_dimension(dimension: usize) -> Self {
145 Self::new(IvfConfig::new(dimension, 100))
146 }
147
148 pub fn train(&mut self, vectors: &[Vec<f32>]) {
150 if vectors.is_empty() {
151 return;
152 }
153
154 let n_lists = self.config.n_lists.min(vectors.len());
155
156 let centroids = self.kmeans_plusplus_init(vectors, n_lists);
158
159 let final_centroids = self.kmeans(vectors, centroids);
161
162 self.lists = final_centroids.into_iter().map(IvfList::new).collect();
164
165 self.trained = true;
166 }
167
168 fn kmeans_plusplus_init(&self, vectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
170 let mut centroids = Vec::with_capacity(k);
171
172 if vectors.is_empty() || k == 0 {
173 return centroids;
174 }
175
176 centroids.push(vectors[vectors.len() / 2].clone());
178
179 for _ in 1..k {
181 let mut distances: Vec<f32> = vectors
182 .iter()
183 .map(|v| {
184 centroids
185 .iter()
186 .map(|c| l2_squared_simd(v, c))
187 .fold(f32::MAX, f32::min)
188 })
189 .collect();
190
191 let total: f32 = distances.iter().sum();
193 if total > 0.0 {
194 for d in &mut distances {
195 *d /= total;
196 }
197 }
198
199 let max_idx = distances
201 .iter()
202 .enumerate()
203 .max_by(|(la, a), (lb, b)| cmp_f32(**a, **b).then_with(|| la.cmp(lb)))
204 .map(|(i, _)| i)
205 .unwrap_or(0);
206
207 centroids.push(vectors[max_idx].clone());
208 }
209
210 centroids
211 }
212
213 fn kmeans(&self, vectors: &[Vec<f32>], mut centroids: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
215 let dim = self.config.dimension;
216 let k = centroids.len();
217
218 for _ in 0..self.config.max_iterations {
219 let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
221 for (i, vector) in vectors.iter().enumerate() {
222 let nearest = self.find_nearest_centroid(vector, ¢roids);
223 assignments[nearest].push(i);
224 }
225
226 let mut new_centroids = Vec::with_capacity(k);
228 let mut max_shift: f32 = 0.0;
229
230 for (cluster_idx, indices) in assignments.iter().enumerate() {
231 if indices.is_empty() {
232 new_centroids.push(centroids[cluster_idx].clone());
234 continue;
235 }
236
237 let mut new_centroid = vec![0.0f32; dim];
239 for &idx in indices {
240 for (j, val) in vectors[idx].iter().enumerate() {
241 if j < dim {
242 new_centroid[j] += val;
243 }
244 }
245 }
246 for val in &mut new_centroid {
247 *val /= indices.len() as f32;
248 }
249
250 let shift = l2_squared_simd(&new_centroid, ¢roids[cluster_idx]).sqrt();
252 max_shift = max_shift.max(shift);
253
254 new_centroids.push(new_centroid);
255 }
256
257 centroids = new_centroids;
258
259 if max_shift < self.config.convergence_threshold {
261 break;
262 }
263 }
264
265 centroids
266 }
267
268 fn find_nearest_centroid(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
270 centroids
271 .iter()
272 .enumerate()
273 .map(|(i, c)| (i, l2_squared_simd(vector, c)))
274 .min_by(|(li, la), (ri, rb)| cmp_f32(*la, *rb).then_with(|| li.cmp(ri)))
275 .map(|(i, _)| i)
276 .unwrap_or(0)
277 }
278
279 fn find_nearest_centroids(&self, vector: &[f32], k: usize) -> Vec<usize> {
281 let mut distances: Vec<(usize, f32)> = self
282 .lists
283 .iter()
284 .enumerate()
285 .map(|(i, list)| (i, l2_squared_simd(vector, &list.centroid)))
286 .collect();
287
288 distances.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
289 distances.into_iter().take(k).map(|(i, _)| i).collect()
290 }
291
292 pub fn add(&mut self, vector: Vec<f32>) -> NodeId {
294 let id = self.next_id;
295 self.next_id += 1;
296 self.add_with_id(id, vector);
297 id
298 }
299
300 pub fn add_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
302 if !self.trained || self.lists.is_empty() {
303 if self.lists.is_empty() {
305 self.lists.push(IvfList::new(vector.clone()));
306 self.trained = true;
307 }
308 }
309
310 let list_idx = self.find_nearest_centroid(
311 &vector,
312 &self
313 .lists
314 .iter()
315 .map(|l| l.centroid.clone())
316 .collect::<Vec<_>>(),
317 );
318
319 self.lists[list_idx].add(id, vector);
320 self.id_to_list.insert(id, list_idx);
321 self.count += 1;
322 }
323
324 pub fn add_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
326 vectors.into_iter().map(|v| self.add(v)).collect()
327 }
328
329 pub fn add_batch_with_ids(&mut self, items: Vec<(NodeId, Vec<f32>)>) {
331 for (id, vector) in items {
332 self.add_with_id(id, vector);
333 }
334 }
335
336 pub fn remove(&mut self, id: NodeId) -> bool {
338 if let Some(list_idx) = self.id_to_list.remove(&id) {
339 let list = &mut self.lists[list_idx];
340 if let Some(pos) = list.ids.iter().position(|&x| x == id) {
341 list.ids.remove(pos);
342 list.vectors.remove(pos);
343 self.count = self.count.saturating_sub(1);
344 return true;
345 }
346 }
347 false
348 }
349
350 pub fn search(&self, query: &[f32], k: usize) -> Vec<DistanceResult> {
352 self.search_with_probes(query, k, self.config.n_probes)
353 }
354
355 pub fn search_with_probes(
357 &self,
358 query: &[f32],
359 k: usize,
360 n_probes: usize,
361 ) -> Vec<DistanceResult> {
362 if self.lists.is_empty() {
363 return Vec::new();
364 }
365
366 let probes = self.find_nearest_centroids(query, n_probes);
367
368 let mut candidates: Vec<DistanceResult> = Vec::new();
370 for list_idx in probes {
371 let list = &self.lists[list_idx];
372 for (i, vector) in list.vectors.iter().enumerate() {
373 let distance = l2_squared_simd(query, vector).sqrt();
374 candidates.push(DistanceResult::new(list.ids[i], distance));
375 }
376 }
377
378 candidates.sort_by(|a, b| cmp_f32(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
380 candidates.truncate(k);
381 candidates
382 }
383
384 pub fn get(&self, id: NodeId) -> Option<&[f32]> {
386 if let Some(&list_idx) = self.id_to_list.get(&id) {
387 let list = &self.lists[list_idx];
388 if let Some(pos) = list.ids.iter().position(|&x| x == id) {
389 return Some(&list.vectors[pos]);
390 }
391 }
392 None
393 }
394
395 pub fn contains(&self, id: NodeId) -> bool {
397 self.id_to_list.contains_key(&id)
398 }
399
400 pub fn len(&self) -> usize {
402 self.count
403 }
404
405 pub fn is_empty(&self) -> bool {
407 self.count == 0
408 }
409
410 pub fn n_lists(&self) -> usize {
412 self.lists.len()
413 }
414
415 pub fn stats(&self) -> IvfStats {
417 let sizes: Vec<usize> = self.lists.iter().map(|l| l.len()).collect();
418 let non_empty = sizes.iter().filter(|&&s| s > 0).count();
419
420 let avg = if non_empty > 0 {
421 sizes.iter().sum::<usize>() as f64 / non_empty as f64
422 } else {
423 0.0
424 };
425
426 let max = sizes.iter().copied().max().unwrap_or(0);
427 let min = sizes.iter().filter(|&&s| s > 0).copied().min().unwrap_or(0);
428
429 IvfStats {
430 total_vectors: self.count,
431 n_lists: self.lists.len(),
432 non_empty_lists: non_empty,
433 avg_list_size: avg,
434 max_list_size: max,
435 min_list_size: min,
436 dimension: self.config.dimension,
437 trained: self.trained,
438 }
439 }
440
441 pub fn to_bytes(&self) -> Vec<u8> {
443 let mut bytes = Vec::new();
444 bytes.extend_from_slice(b"IVF1");
445 bytes.extend_from_slice(&(self.config.n_lists as u32).to_le_bytes());
446 bytes.extend_from_slice(&(self.config.n_probes as u32).to_le_bytes());
447 bytes.extend_from_slice(&(self.config.dimension as u32).to_le_bytes());
448 bytes.extend_from_slice(&(self.config.max_iterations as u32).to_le_bytes());
449 bytes.extend_from_slice(&self.config.convergence_threshold.to_le_bytes());
450 bytes.push(if self.trained { 1 } else { 0 });
451 bytes.extend_from_slice(&(self.count as u64).to_le_bytes());
452 bytes.extend_from_slice(&self.next_id.to_le_bytes());
453 bytes.extend_from_slice(&(self.lists.len() as u32).to_le_bytes());
454
455 for list in &self.lists {
456 bytes.extend_from_slice(&(list.centroid.len() as u32).to_le_bytes());
457 for value in &list.centroid {
458 bytes.extend_from_slice(&value.to_le_bytes());
459 }
460
461 bytes.extend_from_slice(&(list.ids.len() as u32).to_le_bytes());
462 for id in &list.ids {
463 bytes.extend_from_slice(&id.to_le_bytes());
464 }
465
466 bytes.extend_from_slice(&(list.vectors.len() as u32).to_le_bytes());
467 for vector in &list.vectors {
468 bytes.extend_from_slice(&(vector.len() as u32).to_le_bytes());
469 for value in vector {
470 bytes.extend_from_slice(&value.to_le_bytes());
471 }
472 }
473 }
474
475 bytes
476 }
477
478 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
480 if bytes.len() < 41 {
481 return Err("data too short".to_string());
482 }
483 if &bytes[0..4] != b"IVF1" {
484 return Err("invalid IVF magic".to_string());
485 }
486
487 let mut pos = 4usize;
488 let read_u32 = |buf: &[u8], pos: &mut usize| -> Result<u32, String> {
489 if *pos + 4 > buf.len() {
490 return Err("truncated IVF payload".to_string());
491 }
492 let value =
493 u32::from_le_bytes([buf[*pos], buf[*pos + 1], buf[*pos + 2], buf[*pos + 3]]);
494 *pos += 4;
495 Ok(value)
496 };
497 let read_u64 = |buf: &[u8], pos: &mut usize| -> Result<u64, String> {
498 if *pos + 8 > buf.len() {
499 return Err("truncated IVF payload".to_string());
500 }
501 let value = u64::from_le_bytes([
502 buf[*pos],
503 buf[*pos + 1],
504 buf[*pos + 2],
505 buf[*pos + 3],
506 buf[*pos + 4],
507 buf[*pos + 5],
508 buf[*pos + 6],
509 buf[*pos + 7],
510 ]);
511 *pos += 8;
512 Ok(value)
513 };
514 let read_f32 = |buf: &[u8], pos: &mut usize| -> Result<f32, String> {
515 if *pos + 4 > buf.len() {
516 return Err("truncated IVF payload".to_string());
517 }
518 let value =
519 f32::from_le_bytes([buf[*pos], buf[*pos + 1], buf[*pos + 2], buf[*pos + 3]]);
520 *pos += 4;
521 Ok(value)
522 };
523
524 let config = IvfConfig {
525 n_lists: read_u32(bytes, &mut pos)? as usize,
526 n_probes: read_u32(bytes, &mut pos)? as usize,
527 dimension: read_u32(bytes, &mut pos)? as usize,
528 max_iterations: read_u32(bytes, &mut pos)? as usize,
529 convergence_threshold: read_f32(bytes, &mut pos)?,
530 };
531 if pos >= bytes.len() {
532 return Err("truncated IVF payload".to_string());
533 }
534 let trained = bytes[pos] == 1;
535 pos += 1;
536 let count = read_u64(bytes, &mut pos)? as usize;
537 let next_id = read_u64(bytes, &mut pos)?;
538 let list_count = read_u32(bytes, &mut pos)? as usize;
539
540 let mut lists = Vec::with_capacity(list_count);
541 let mut id_to_list = HashMap::new();
542 for list_idx in 0..list_count {
543 let centroid_len = read_u32(bytes, &mut pos)? as usize;
544 let mut centroid = Vec::with_capacity(centroid_len);
545 for _ in 0..centroid_len {
546 centroid.push(read_f32(bytes, &mut pos)?);
547 }
548
549 let id_count = read_u32(bytes, &mut pos)? as usize;
550 let mut ids = Vec::with_capacity(id_count);
551 for _ in 0..id_count {
552 let id = read_u64(bytes, &mut pos)?;
553 id_to_list.insert(id, list_idx);
554 ids.push(id);
555 }
556
557 let vector_count = read_u32(bytes, &mut pos)? as usize;
558 let mut vectors = Vec::with_capacity(vector_count);
559 for _ in 0..vector_count {
560 let vector_len = read_u32(bytes, &mut pos)? as usize;
561 let mut vector = Vec::with_capacity(vector_len);
562 for _ in 0..vector_len {
563 vector.push(read_f32(bytes, &mut pos)?);
564 }
565 vectors.push(vector);
566 }
567
568 lists.push(IvfList {
569 centroid,
570 ids,
571 vectors,
572 });
573 }
574
575 Ok(Self {
576 config,
577 lists,
578 id_to_list,
579 trained,
580 count,
581 next_id,
582 })
583 }
584}
585
/// Occupancy summary of an [`IvfIndex`], as produced by `IvfIndex::stats`.
#[derive(Debug, Clone)]
pub struct IvfStats {
    /// Number of vectors stored in the index.
    pub total_vectors: usize,
    /// Number of inverted lists, empty ones included.
    pub n_lists: usize,
    /// Number of lists holding at least one vector.
    pub non_empty_lists: usize,
    /// Mean size of the non-empty lists (0.0 when every list is empty).
    pub avg_list_size: f64,
    /// Size of the largest list.
    pub max_list_size: usize,
    /// Size of the smallest non-empty list (0 when every list is empty).
    pub min_list_size: usize,
    /// Configured vector dimensionality.
    pub dimension: usize,
    /// Whether the index has centroids.
    pub trained: bool,
}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// Cheap deterministic generator: component `i` for `seed` is
    /// `((seed * 1103515245 + i * 12345) % 1000) / 1000`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut out = Vec::with_capacity(dim);
        for i in 0..dim as u64 {
            let bucket = (seed * 1103515245 + i * 12345) % 1000;
            out.push(bucket as f32 / 1000.0);
        }
        out
    }

    #[test]
    fn test_ivf_basic() {
        let training: Vec<Vec<f32>> = (0..100u64).map(|seed| random_vector(8, seed)).collect();
        let mut ivf = IvfIndex::new(IvfConfig::new(8, 4));

        ivf.train(&training);
        assert!(ivf.trained);
        assert_eq!(ivf.n_lists(), 4);

        for (id, v) in (0u64..).zip(&training) {
            ivf.add_with_id(id, v.clone());
        }
        assert_eq!(ivf.len(), 100);
    }

    #[test]
    fn test_ivf_search() {
        const DIM: usize = 8;
        let mut ivf = IvfIndex::new(IvfConfig {
            n_lists: 4,
            n_probes: 2,
            dimension: DIM,
            ..Default::default()
        });

        // Four well-separated clusters (components near 0, 10, 20, 30), 25 points each.
        let vectors: Vec<Vec<f32>> = (0..4)
            .flat_map(|cluster| {
                (0..25).map(move |i| {
                    let mut v = vec![cluster as f32 * 10.0; DIM];
                    v[0] += i as f32 * 0.01;
                    v
                })
            })
            .collect();

        ivf.train(&vectors);
        for (id, v) in vectors.iter().enumerate() {
            ivf.add_with_id(id as u64, v.clone());
        }

        // The query lies inside cluster 0, whose ids are 0..25.
        let results = ivf.search(&[0.05; DIM], 5);
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.id < 25));
    }

    #[test]
    fn test_ivf_remove() {
        let mut ivf = IvfIndex::new(IvfConfig::new(4, 2));
        let axes: [(u64, [f32; 4]); 3] = [
            (1, [1.0, 0.0, 0.0, 0.0]),
            (2, [0.0, 1.0, 0.0, 0.0]),
            (3, [0.0, 0.0, 1.0, 0.0]),
        ];
        for (id, v) in axes.iter() {
            ivf.add_with_id(*id, v.to_vec());
        }

        assert_eq!(ivf.len(), 3);
        assert!(ivf.contains(2));

        assert!(ivf.remove(2));
        assert_eq!(ivf.len(), 2);
        assert!(!ivf.contains(2));
    }

    #[test]
    fn test_ivf_stats() {
        let training: Vec<Vec<f32>> = (0..3).map(|x| vec![x as f32, 0.0, 0.0, 0.0]).collect();
        let mut ivf = IvfIndex::new(IvfConfig::new(4, 3));

        ivf.train(&training);
        for (id, v) in training.into_iter().enumerate() {
            ivf.add_with_id(id as u64, v);
        }

        let stats = ivf.stats();
        assert_eq!(stats.total_vectors, 3);
        assert_eq!(stats.n_lists, 3);
        assert!(stats.trained);
    }
}