1use std::collections::HashMap;
11use std::io::{self, Cursor, Read, Write};
12use std::path::Path;
13
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use rand::prelude::*;
16use serde::{Deserialize, Serialize};
17
18use super::rabitq::QuantizedVector;
19
20const CENTROIDS_MAGIC: u32 = 0x48435643; #[allow(dead_code)]
25const IVF_MAGIC: u32 = 0x49565651; #[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct CoarseCentroids {
30 pub num_clusters: u32,
32 pub dim: usize,
34 pub centroids: Vec<f32>,
36 pub version: u64,
38}
39
40impl CoarseCentroids {
41 #[cfg(feature = "native")]
45 pub fn train(vectors: &[Vec<f32>], num_clusters: usize, max_iters: usize, _seed: u64) -> Self {
46 use kmeans::{EuclideanDistance, KMeans, KMeansConfig};
47
48 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
49 assert!(num_clusters > 0, "Need at least 1 cluster");
50
51 let actual_clusters = num_clusters.min(vectors.len());
52 let dim = vectors[0].len();
53
54 let samples: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
56
57 let kmean: KMeans<f32, 8, _> = KMeans::new(&samples, vectors.len(), dim, EuclideanDistance);
60 let result = kmean.kmeans_lloyd(
61 actual_clusters,
62 max_iters,
63 KMeans::init_kmeanplusplus,
64 &KMeansConfig::default(),
65 );
66
67 let centroids: Vec<f32> = result
69 .centroids
70 .iter()
71 .flat_map(|c| c.iter().copied())
72 .collect();
73
74 let version = std::time::SystemTime::now()
75 .duration_since(std::time::UNIX_EPOCH)
76 .unwrap_or_default()
77 .as_millis() as u64;
78
79 Self {
80 num_clusters: actual_clusters as u32,
81 dim,
82 centroids,
83 version,
84 }
85 }
86
87 #[cfg(not(feature = "native"))]
89 pub fn train(vectors: &[Vec<f32>], num_clusters: usize, max_iters: usize, seed: u64) -> Self {
90 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
91 assert!(num_clusters > 0, "Need at least 1 cluster");
92
93 let actual_clusters = num_clusters.min(vectors.len());
94 let dim = vectors[0].len();
95 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
96
97 let mut indices: Vec<usize> = (0..vectors.len()).collect();
99 indices.shuffle(&mut rng);
100
101 let mut centroids: Vec<f32> = indices[..actual_clusters]
102 .iter()
103 .flat_map(|&i| vectors[i].iter().copied())
104 .collect();
105
106 for _ in 0..max_iters {
108 let assignments: Vec<usize> = vectors
109 .iter()
110 .map(|v| Self::find_nearest_centroid_idx(v, ¢roids, dim))
111 .collect();
112
113 let mut new_centroids = vec![0.0f32; actual_clusters * dim];
114 let mut counts = vec![0usize; actual_clusters];
115
116 for (vec_idx, &cluster_id) in assignments.iter().enumerate() {
117 counts[cluster_id] += 1;
118 let offset = cluster_id * dim;
119 for (i, &val) in vectors[vec_idx].iter().enumerate() {
120 new_centroids[offset + i] += val;
121 }
122 }
123
124 for cluster_id in 0..actual_clusters {
125 if counts[cluster_id] > 0 {
126 let offset = cluster_id * dim;
127 for i in 0..dim {
128 new_centroids[offset + i] /= counts[cluster_id] as f32;
129 }
130 }
131 }
132
133 centroids = new_centroids;
134 }
135
136 let version = std::time::SystemTime::now()
137 .duration_since(std::time::UNIX_EPOCH)
138 .unwrap_or_default()
139 .as_millis() as u64;
140
141 Self {
142 num_clusters: actual_clusters as u32,
143 dim,
144 centroids,
145 version,
146 }
147 }
148
149 #[allow(dead_code)]
151 fn kmeans_plusplus_init(
152 vectors: &[Vec<f32>],
153 num_clusters: usize,
154 rng: &mut impl Rng,
155 ) -> Vec<f32> {
156 let dim = vectors[0].len();
157 let mut centroids = Vec::with_capacity(num_clusters * dim);
158
159 let first_idx = rng.random_range(0..vectors.len());
161 centroids.extend_from_slice(&vectors[first_idx]);
162
163 for _ in 1..num_clusters {
165 let mut distances: Vec<f32> = vectors
166 .iter()
167 .map(|v| {
168 let mut min_dist = f32::MAX;
169 for c in 0..(centroids.len() / dim) {
170 let offset = c * dim;
171 let dist: f32 = v
172 .iter()
173 .zip(¢roids[offset..offset + dim])
174 .map(|(&a, &b)| (a - b) * (a - b))
175 .sum();
176 min_dist = min_dist.min(dist);
177 }
178 min_dist
179 })
180 .collect();
181
182 let total: f32 = distances.iter().sum();
184 if total > 0.0 {
185 for d in &mut distances {
186 *d /= total;
187 }
188 }
189
190 let r: f32 = rng.random();
192 let mut cumsum = 0.0;
193 let mut chosen_idx = 0;
194 for (i, &d) in distances.iter().enumerate() {
195 cumsum += d;
196 if cumsum >= r {
197 chosen_idx = i;
198 break;
199 }
200 }
201
202 centroids.extend_from_slice(&vectors[chosen_idx]);
203 }
204
205 centroids
206 }
207
208 fn find_nearest_centroid_idx(vector: &[f32], centroids: &[f32], dim: usize) -> usize {
210 let num_clusters = centroids.len() / dim;
211 let mut best_idx = 0;
212 let mut best_dist = f32::MAX;
213
214 for c in 0..num_clusters {
215 let offset = c * dim;
216 let dist: f32 = vector
217 .iter()
218 .zip(¢roids[offset..offset + dim])
219 .map(|(&a, &b)| (a - b) * (a - b))
220 .sum();
221
222 if dist < best_dist {
223 best_dist = dist;
224 best_idx = c;
225 }
226 }
227
228 best_idx
229 }
230
231 pub fn find_nearest(&self, vector: &[f32]) -> u32 {
233 Self::find_nearest_centroid_idx(vector, &self.centroids, self.dim) as u32
234 }
235
236 pub fn find_k_nearest(&self, vector: &[f32], k: usize) -> Vec<u32> {
238 let mut distances: Vec<(u32, f32)> = (0..self.num_clusters)
239 .map(|c| {
240 let offset = c as usize * self.dim;
241 let dist: f32 = vector
242 .iter()
243 .zip(&self.centroids[offset..offset + self.dim])
244 .map(|(&a, &b)| (a - b) * (a - b))
245 .sum();
246 (c, dist)
247 })
248 .collect();
249
250 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
251 distances.truncate(k);
252 distances.into_iter().map(|(c, _)| c).collect()
253 }
254
255 pub fn get_centroid(&self, cluster_id: u32) -> &[f32] {
257 let offset = cluster_id as usize * self.dim;
258 &self.centroids[offset..offset + self.dim]
259 }
260
261 pub fn save(&self, path: &Path) -> io::Result<()> {
263 let mut file = std::fs::File::create(path)?;
264 self.write_to(&mut file)
265 }
266
267 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
269 writer.write_u32::<LittleEndian>(CENTROIDS_MAGIC)?;
270 writer.write_u32::<LittleEndian>(1)?; writer.write_u64::<LittleEndian>(self.version)?;
272 writer.write_u32::<LittleEndian>(self.num_clusters)?;
273 writer.write_u32::<LittleEndian>(self.dim as u32)?;
274
275 for &val in &self.centroids {
276 writer.write_f32::<LittleEndian>(val)?;
277 }
278
279 Ok(())
280 }
281
282 pub fn load(path: &Path) -> io::Result<Self> {
284 let data = std::fs::read(path)?;
285 Self::read_from(&mut Cursor::new(data))
286 }
287
288 pub fn read_from<R: Read>(reader: &mut R) -> io::Result<Self> {
290 let magic = reader.read_u32::<LittleEndian>()?;
291 if magic != CENTROIDS_MAGIC {
292 return Err(io::Error::new(
293 io::ErrorKind::InvalidData,
294 "Invalid centroids file magic",
295 ));
296 }
297
298 let _file_version = reader.read_u32::<LittleEndian>()?;
299 let version = reader.read_u64::<LittleEndian>()?;
300 let num_clusters = reader.read_u32::<LittleEndian>()?;
301 let dim = reader.read_u32::<LittleEndian>()? as usize;
302
303 let mut centroids = vec![0.0f32; num_clusters as usize * dim];
304 for val in &mut centroids {
305 *val = reader.read_f32::<LittleEndian>()?;
306 }
307
308 Ok(Self {
309 num_clusters,
310 dim,
311 centroids,
312 version,
313 })
314 }
315
316 pub fn to_bytes(&self) -> io::Result<Vec<u8>> {
318 let mut buf = Vec::new();
319 self.write_to(&mut buf)?;
320 Ok(buf)
321 }
322
323 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
325 Self::read_from(&mut Cursor::new(data))
326 }
327}
328
329#[derive(Debug, Clone, Default, Serialize, Deserialize)]
331pub struct ClusterData {
332 pub doc_ids: Vec<u32>,
334 pub binary_codes: Vec<QuantizedVector>,
336 pub raw_vectors: Option<Vec<Vec<f32>>>,
338}
339
340impl ClusterData {
341 pub fn new() -> Self {
342 Self::default()
343 }
344
345 pub fn len(&self) -> usize {
346 self.doc_ids.len()
347 }
348
349 pub fn is_empty(&self) -> bool {
350 self.doc_ids.is_empty()
351 }
352
353 pub fn append(&mut self, other: &ClusterData, doc_id_offset: u32) {
355 for &doc_id in &other.doc_ids {
356 self.doc_ids.push(doc_id + doc_id_offset);
357 }
358 self.binary_codes.extend(other.binary_codes.iter().cloned());
359
360 if let Some(ref other_raw) = other.raw_vectors {
361 let raw = self.raw_vectors.get_or_insert_with(Vec::new);
362 raw.extend(other_raw.iter().cloned());
363 }
364 }
365}
366
367#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct IVFConfig {
370 pub dim: usize,
372 pub seed: u64,
374 pub query_bits: u8,
376 pub store_raw: bool,
378 pub default_nprobe: usize,
380}
381
382impl IVFConfig {
383 pub fn new(dim: usize) -> Self {
384 Self {
385 dim,
386 seed: 42,
387 query_bits: 4,
388 store_raw: true,
389 default_nprobe: 32,
390 }
391 }
392}
393
394#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct IVFRaBitQIndex {
397 pub config: IVFConfig,
399 pub centroids_version: u64,
401 pub random_signs: Vec<i8>,
403 pub random_perm: Vec<u32>,
405 pub clusters: HashMap<u32, ClusterData>,
407 pub num_vectors: usize,
409}
410
411impl IVFRaBitQIndex {
412 pub fn new(config: IVFConfig, centroids_version: u64) -> Self {
414 let dim = config.dim;
415 let mut rng = rand::rngs::StdRng::seed_from_u64(config.seed);
416
417 let random_signs: Vec<i8> = (0..dim)
419 .map(|_| if rng.random::<bool>() { 1 } else { -1 })
420 .collect();
421
422 let mut random_perm: Vec<u32> = (0..dim as u32).collect();
424 for i in (1..dim).rev() {
425 let j = rng.random_range(0..=i);
426 random_perm.swap(i, j);
427 }
428
429 Self {
430 config,
431 centroids_version,
432 random_signs,
433 random_perm,
434 clusters: HashMap::new(),
435 num_vectors: 0,
436 }
437 }
438
439 pub fn build(
441 config: IVFConfig,
442 coarse_centroids: &CoarseCentroids,
443 vectors: &[Vec<f32>],
444 doc_ids: Option<&[u32]>,
445 ) -> Self {
446 let mut index = Self::new(config.clone(), coarse_centroids.version);
447
448 for (i, vector) in vectors.iter().enumerate() {
449 let doc_id = doc_ids.map(|ids| ids[i]).unwrap_or(i as u32);
450 index.add_vector(coarse_centroids, doc_id, vector);
451 }
452
453 index
454 }
455
456 pub fn add_vector(&mut self, coarse_centroids: &CoarseCentroids, doc_id: u32, vector: &[f32]) {
458 let cluster_id = coarse_centroids.find_nearest(vector);
460
461 let centroid = coarse_centroids.get_centroid(cluster_id);
463
464 let binary_code = self.quantize_vector(vector, centroid);
466
467 let cluster = self.clusters.entry(cluster_id).or_default();
469 cluster.doc_ids.push(doc_id);
470 cluster.binary_codes.push(binary_code);
471
472 if self.config.store_raw {
473 cluster
474 .raw_vectors
475 .get_or_insert_with(Vec::new)
476 .push(vector.to_vec());
477 }
478
479 self.num_vectors += 1;
480 }
481
482 fn quantize_vector(&self, raw: &[f32], centroid: &[f32]) -> QuantizedVector {
484 let dim = self.config.dim;
485
486 let mut centered: Vec<f32> = raw.iter().zip(centroid).map(|(&v, &c)| v - c).collect();
488
489 let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
490 let dist_to_centroid = norm;
491
492 if norm > 1e-10 {
494 for x in &mut centered {
495 *x /= norm;
496 }
497 }
498
499 let transformed: Vec<f32> = (0..dim)
501 .map(|i| {
502 let src_idx = self.random_perm[i] as usize;
503 centered[src_idx] * self.random_signs[src_idx] as f32
504 })
505 .collect();
506
507 let num_bytes = dim.div_ceil(8);
509 let mut bits = vec![0u8; num_bytes];
510 let mut popcount = 0u32;
511
512 for i in 0..dim {
513 if transformed[i] >= 0.0 {
514 bits[i / 8] |= 1 << (i % 8);
515 popcount += 1;
516 }
517 }
518
519 let scale = 1.0 / (dim as f32).sqrt();
521 let mut self_dot = 0.0f32;
522 for i in 0..dim {
523 let o_bar_i = if (bits[i / 8] >> (i % 8)) & 1 == 1 {
524 scale
525 } else {
526 -scale
527 };
528 self_dot += transformed[i] * o_bar_i;
529 }
530
531 QuantizedVector {
532 bits,
533 dist_to_centroid,
534 self_dot,
535 popcount,
536 }
537 }
538
539 pub fn search(
541 &self,
542 coarse_centroids: &CoarseCentroids,
543 query: &[f32],
544 k: usize,
545 nprobe: usize,
546 ) -> Vec<(u32, f32)> {
547 let nearest_clusters = coarse_centroids.find_k_nearest(query, nprobe);
549
550 let mut candidates: Vec<(u32, f32)> = Vec::new();
552
553 for cluster_id in nearest_clusters {
554 if let Some(cluster) = self.clusters.get(&cluster_id) {
555 let centroid = coarse_centroids.get_centroid(cluster_id);
556 let prepared = self.prepare_query(query, centroid);
557
558 for (i, binary_code) in cluster.binary_codes.iter().enumerate() {
559 let dist = self.estimate_distance(&prepared, binary_code);
560 candidates.push((cluster.doc_ids[i], dist));
561 }
562 }
563 }
564
565 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
567
568 let rerank_count = (k * 3).min(candidates.len());
570 if rerank_count > 0 {
571 let mut reranked: Vec<(u32, f32)> = Vec::with_capacity(rerank_count);
572
573 for &(doc_id, _) in candidates.iter().take(rerank_count) {
574 for cluster in self.clusters.values() {
576 if let Some(pos) = cluster.doc_ids.iter().position(|&d| d == doc_id) {
577 if let Some(ref raw_vecs) = cluster.raw_vectors {
578 let raw_vec = &raw_vecs[pos];
579 let dist: f32 = query
580 .iter()
581 .zip(raw_vec.iter())
582 .map(|(&a, &b)| (a - b).powi(2))
583 .sum();
584 reranked.push((doc_id, dist));
585 }
586 break;
587 }
588 }
589 }
590
591 if !reranked.is_empty() {
592 reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
593 reranked.truncate(k);
594 return reranked;
595 }
596 }
597
598 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
600 candidates.truncate(k);
601 candidates
602 }
603
604 fn prepare_query(&self, raw_query: &[f32], centroid: &[f32]) -> PreparedQuery {
606 let dim = self.config.dim;
607
608 let mut centered: Vec<f32> = raw_query
610 .iter()
611 .zip(centroid)
612 .map(|(&v, &c)| v - c)
613 .collect();
614
615 let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
616 let dist_to_centroid = norm;
617
618 if norm > 1e-10 {
620 for x in &mut centered {
621 *x /= norm;
622 }
623 }
624
625 let transformed: Vec<f32> = (0..dim)
627 .map(|i| {
628 let src_idx = self.random_perm[i] as usize;
629 centered[src_idx] * self.random_signs[src_idx] as f32
630 })
631 .collect();
632
633 let min_val = transformed.iter().cloned().fold(f32::INFINITY, f32::min);
635 let max_val = transformed
636 .iter()
637 .cloned()
638 .fold(f32::NEG_INFINITY, f32::max);
639 let lower = min_val;
640 let width = if max_val > min_val {
641 max_val - min_val
642 } else {
643 1.0
644 };
645
646 let quantized_vals: Vec<u8> = transformed
648 .iter()
649 .map(|&x| {
650 let normalized = (x - lower) / width;
651 (normalized * 15.0).round().clamp(0.0, 15.0) as u8
652 })
653 .collect();
654
655 let sum: u32 = quantized_vals.iter().map(|&x| x as u32).sum();
657
658 let num_luts = dim.div_ceil(4);
660 let mut luts = vec![[0u16; 16]; num_luts];
661
662 for (lut_idx, lut) in luts.iter_mut().enumerate() {
663 let base_dim = lut_idx * 4;
664 for pattern in 0u8..16 {
665 let mut dot = 0u16;
666 for bit in 0..4 {
667 let dim_idx = base_dim + bit;
668 if dim_idx < dim && (pattern >> bit) & 1 == 1 {
669 dot += quantized_vals[dim_idx] as u16;
670 }
671 }
672 lut[pattern as usize] = dot;
673 }
674 }
675
676 PreparedQuery {
677 dist_to_centroid,
678 lower,
679 width,
680 sum,
681 luts,
682 }
683 }
684
685 fn estimate_distance(&self, query: &PreparedQuery, vec: &QuantizedVector) -> f32 {
687 let dim = self.config.dim;
688
689 let mut dot_sum = 0u32;
691 for (lut_idx, lut) in query.luts.iter().enumerate() {
692 let base_bit = lut_idx * 4;
693 let byte_idx = base_bit / 8;
694 let bit_offset = base_bit % 8;
695
696 let byte = vec.bits.get(byte_idx).copied().unwrap_or(0);
697 let next_byte = vec.bits.get(byte_idx + 1).copied().unwrap_or(0);
698
699 let pattern = if bit_offset <= 4 {
700 (byte >> bit_offset) & 0x0F
701 } else {
702 ((byte >> bit_offset) | (next_byte << (8 - bit_offset))) & 0x0F
703 };
704
705 dot_sum += lut[pattern as usize] as u32;
706 }
707
708 let scale = 1.0 / (dim as f32).sqrt();
710
711 let sum_positive = vec.popcount as f32 * query.lower + dot_sum as f32 * query.width / 15.0;
714
715 let sum_all = dim as f32 * query.lower + query.sum as f32 * query.width / 15.0;
717
718 let q_obar_dot = scale * (2.0 * sum_positive - sum_all);
720
721 let q_o_estimate = if vec.self_dot.abs() > 1e-6 {
723 q_obar_dot / vec.self_dot
724 } else {
725 q_obar_dot
726 };
727
728 let q_o_clamped = q_o_estimate.clamp(-1.0, 1.0);
730
731 let dist_sq = vec.dist_to_centroid * vec.dist_to_centroid
733 + query.dist_to_centroid * query.dist_to_centroid
734 - 2.0 * vec.dist_to_centroid * query.dist_to_centroid * q_o_clamped;
735
736 dist_sq.max(0.0)
737 }
738
739 pub fn merge(
741 indexes: &[&IVFRaBitQIndex],
742 doc_id_offsets: &[u32],
743 ) -> Result<Self, &'static str> {
744 if indexes.is_empty() {
745 return Err("No indexes to merge");
746 }
747
748 let version = indexes[0].centroids_version;
750 for idx in indexes.iter().skip(1) {
751 if idx.centroids_version != version {
752 return Err("Cannot merge indexes with different centroid versions");
753 }
754 }
755
756 let config = indexes[0].config.clone();
757 let mut merged = Self::new(config, version);
758
759 for (seg_idx, index) in indexes.iter().enumerate() {
761 let offset = doc_id_offsets[seg_idx];
762
763 for (&cluster_id, cluster_data) in &index.clusters {
764 let merged_cluster = merged.clusters.entry(cluster_id).or_default();
765
766 merged_cluster.append(cluster_data, offset);
767 }
768
769 merged.num_vectors += index.num_vectors;
770 }
771
772 Ok(merged)
773 }
774
775 pub fn num_clusters(&self) -> usize {
777 self.clusters.len()
778 }
779
780 pub fn len(&self) -> usize {
782 self.num_vectors
783 }
784
785 pub fn is_empty(&self) -> bool {
786 self.num_vectors == 0
787 }
788}
789
/// A query pre-processed against one cluster centroid, produced by
/// `prepare_query` and consumed by `estimate_distance`.
struct PreparedQuery {
    /// Norm of the query's residual from the centroid.
    dist_to_centroid: f32,
    /// Smallest component of the transformed residual; level 0 maps to it.
    lower: f32,
    /// Range covered by the quantization levels (largest minus smallest
    /// component, or 1.0 when they coincide).
    width: f32,
    /// Sum of all quantized component levels. `estimate_distance` reads it,
    /// so no dead-code suppression is needed.
    sum: u32,
    /// One table per group of four dimensions, indexed by a 4-bit pattern of
    /// code bits, giving the sum of quantized levels at the set bits.
    luts: Vec<[u16; 16]>,
}
799
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `count` vectors of `dim` uniformly random components.
    fn random_vectors(rng: &mut rand::rngs::StdRng, count: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..count)
            .map(|_| (0..dim).map(|_| rng.random::<f32>()).collect())
            .collect()
    }

    /// Training yields the requested number of centroids with the input dimension.
    #[test]
    fn test_coarse_centroids_train() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, 1000, 64);

        let centroids = CoarseCentroids::train(&vectors, 16, 10, 42);

        assert_eq!(centroids.num_clusters, 16);
        assert_eq!(centroids.dim, 64);
        assert_eq!(centroids.centroids.len(), 16 * 64);
    }

    /// Serializing and parsing centroids reproduces every field.
    #[test]
    fn test_coarse_centroids_save_load() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let vectors = random_vectors(&mut rng, 100, 32);

        let centroids = CoarseCentroids::train(&vectors, 8, 5, 42);
        let bytes = centroids.to_bytes().unwrap();
        let loaded = CoarseCentroids::from_bytes(&bytes).unwrap();

        assert_eq!(centroids.num_clusters, loaded.num_clusters);
        assert_eq!(centroids.dim, loaded.dim);
        assert_eq!(centroids.centroids, loaded.centroids);
        assert_eq!(centroids.version, loaded.version);
    }

    /// A built index reports all vectors and returns `k` results in ascending distance order.
    #[test]
    fn test_ivf_build_and_search() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let dim = 64;

        let vectors = random_vectors(&mut rng, 1000, dim);

        let centroids = CoarseCentroids::train(&vectors, 16, 10, 42);

        let config = IVFConfig::new(dim);
        let index = IVFRaBitQIndex::build(config, &centroids, &vectors, None);

        assert_eq!(index.len(), 1000);
        assert!(index.num_clusters() <= 16);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>()).collect();
        let results = index.search(&centroids, &query, 10, 4);

        assert_eq!(results.len(), 10);
        assert!(results.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    /// Merging two indexes keeps every vector and applies the doc id offsets.
    #[test]
    fn test_ivf_merge() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let dim = 32;

        let vectors1 = random_vectors(&mut rng, 500, dim);
        let vectors2 = random_vectors(&mut rng, 500, dim);

        let all_vectors: Vec<Vec<f32>> = vectors1.iter().chain(vectors2.iter()).cloned().collect();
        let centroids = CoarseCentroids::train(&all_vectors, 8, 10, 42);

        let config = IVFConfig::new(dim);
        let index1 = IVFRaBitQIndex::build(config.clone(), &centroids, &vectors1, None);
        let index2 = IVFRaBitQIndex::build(config, &centroids, &vectors2, None);

        let merged = IVFRaBitQIndex::merge(&[&index1, &index2], &[0, 500]).unwrap();

        assert_eq!(merged.len(), 1000);

        let query: Vec<f32> = (0..dim).map(|_| rng.random::<f32>()).collect();
        let results = merged.search(&centroids, &query, 10, 4);

        assert_eq!(results.len(), 10);
        // Ids from the two segments occupy 0..500 and 500..1000 after the offsets.
        assert!(results.iter().all(|&(doc_id, _)| doc_id < 1000));
    }
}