1use super::binary::BinaryStore;
9use super::distance;
10
11use rand::prelude::*;
12use rayon::prelude::*;
13use std::cell::UnsafeCell;
14use std::collections::BinaryHeap;
15
/// Structure-only sparse matrix in compressed-sparse-row layout (no stored values).
///
/// The column indices of row `i` are `indices[indptr[i] as usize..indptr[i + 1] as usize]`,
/// so `indptr` holds `rows + 1` offsets. In this module the columns are tag ids
/// attached to each base vector.
pub struct SpMat {
    /// Number of rows (one per base vector).
    pub rows: usize,
    /// Number of columns. `IvfIndex::build` sizes its tag-frequency table as
    /// `cols + 1`, so column indices are presumably `<= cols` — confirm with callers.
    pub cols: usize,
    /// Row offsets into `indices`; length `rows + 1`.
    pub indptr: Vec<i64>,
    /// Concatenated column (tag) indices of all rows.
    pub indices: Vec<i32>,
}
23
/// Owned, row-major storage of fixed-dimension vectors.
///
/// Row `i` of dimension `dim` is the slice `[i * dim..(i + 1) * dim]` of the
/// inner buffer.
pub enum VecStore {
    /// Byte-quantized components.
    U8(Vec<u8>),
    /// Single-precision components.
    F32(Vec<f32>),
}
29
/// Borrowed, row-major batch of query vectors.
///
/// Query `qi` is `[qi * dim..(qi + 1) * dim]`. Its element type must match the
/// index's `VecStore` variant, otherwise distance computation panics.
pub enum QueryStore<'a> {
    /// Byte-quantized queries.
    U8(&'a [u8]),
    /// Single-precision queries.
    F32(&'a [f32]),
}
35
/// A single query vector, sliced out of a `QueryStore`.
enum QueryVec<'a> {
    /// Byte-quantized query.
    U8(&'a [u8]),
    /// Single-precision query.
    F32(&'a [f32]),
}
41
42#[inline]
45fn compute_dist(store: &VecStore, gid: usize, query: &QueryVec, dim: usize) -> u32 {
46 match (store, query) {
47 (VecStore::U8(v), QueryVec::U8(q)) => distance::l2_sq8(q, &v[gid * dim..(gid + 1) * dim]),
48 (VecStore::F32(v), QueryVec::F32(q)) => {
49 distance::l2_squared(q, &v[gid * dim..(gid + 1) * dim]).to_bits()
50 }
51 _ => unreachable!("mismatched vector/query types"),
52 }
53}
54
/// Inverted-file (IVF) index with per-cluster tag posting lists.
///
/// Vectors are stored cluster-major. Cluster `ci` owns global ids
/// `cluster_starts[ci]..cluster_starts[ci + 1]`, and a cluster-local id `lid`
/// maps to global id `cluster_starts[ci] + lid`.
pub struct IvfIndex {
    /// Base vectors reordered into cluster-major order.
    pub vectors: VecStore,
    /// `original_ids[gid]` is the caller's row index for global id `gid`.
    pub original_ids: Vec<u32>,
    /// Cluster offsets into the global id space; length `n_clusters + 1`.
    pub cluster_starts: Vec<u32>,
    /// Cluster `ci`'s tag table is `tag_index[tag_offsets[ci]..tag_offsets[ci + 1]]`.
    tag_offsets: Vec<u32>,
    /// `(tag, posting_start, posting_len)` entries, sorted by tag within each cluster.
    tag_index: Vec<(u32, u32, u32)>,
    /// Concatenated posting lists of ascending cluster-local ids.
    posting_ids: Vec<u32>,
    /// `tag_clusters[tag]` lists, ascending, the clusters that contain `tag`.
    pub tag_clusters: Vec<Vec<u16>>,
    /// Vector dimensionality.
    pub dim: usize,
    /// Number of clusters.
    pub n_clusters: usize,
}
76
77impl IvfIndex {
78 pub fn build(
83 base: &VecStore,
84 base_meta: &SpMat,
85 assignments: &[u16],
86 n: usize,
87 dim: usize,
88 n_clusters: usize,
89 ) -> Self {
90 let mut cluster_sizes = vec![0u32; n_clusters];
91 for &a in assignments {
92 cluster_sizes[a as usize] += 1;
93 }
94 let mut cluster_starts = vec![0u32; n_clusters + 1];
95 for i in 0..n_clusters {
96 cluster_starts[i + 1] = cluster_starts[i] + cluster_sizes[i];
97 }
98
99 let mut position = cluster_starts[..n_clusters].to_vec();
100 let mut new_order = vec![0u32; n];
101 for (i, &ci_raw) in assignments.iter().enumerate().take(n) {
102 let ci = ci_raw as usize;
103 let new_id = position[ci] as usize;
104 new_order[new_id] = i as u32;
105 position[ci] += 1;
106 }
107
108 macro_rules! reorder_and_sort {
109 ($base_data:expr, $zero:expr, $T:ty) => {{
110 let mut vecs = vec![$zero; n * dim];
111 for (new_id, &old_id) in new_order.iter().enumerate() {
112 let src = &$base_data[old_id as usize * dim..(old_id as usize + 1) * dim];
113 vecs[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
114 }
115
116 let mut tag_freq = vec![0u32; base_meta.cols + 1];
118 for &tag in &base_meta.indices {
119 tag_freq[tag as usize] += 1;
120 }
121 for ci in 0..n_clusters {
122 let cs = cluster_starts[ci] as usize;
123 let ce = cluster_starts[ci + 1] as usize;
124 if ce - cs <= 1 {
125 continue;
126 }
127
128 let mut sort_keys: Vec<(u32, usize)> = (0..ce - cs)
129 .map(|local| {
130 let old_id = new_order[cs + local] as usize;
131 let ms = base_meta.indptr[old_id] as usize;
132 let me = base_meta.indptr[old_id + 1] as usize;
133 let tag = base_meta.indices[ms..me]
134 .iter()
135 .max_by_key(|&&t| tag_freq[t as usize])
136 .map(|&t| t as u32)
137 .unwrap_or(u32::MAX);
138 (tag, local)
139 })
140 .collect();
141 sort_keys.sort_unstable_by_key(|&(tag, _)| tag);
142
143 let old_vecs: Vec<$T> = vecs[cs * dim..ce * dim].to_vec();
144 let old_ids: Vec<u32> = new_order[cs..ce].to_vec();
145 for (new_local, &(_, old_local)) in sort_keys.iter().enumerate() {
146 vecs[(cs + new_local) * dim..(cs + new_local + 1) * dim]
147 .copy_from_slice(&old_vecs[old_local * dim..(old_local + 1) * dim]);
148 new_order[cs + new_local] = old_ids[old_local];
149 }
150 }
151 vecs
152 }};
153 }
154
155 let vectors = match base {
156 VecStore::U8(data) => VecStore::U8(reorder_and_sort!(data, 0u8, u8)),
157 VecStore::F32(data) => VecStore::F32(reorder_and_sort!(data, 0.0f32, f32)),
158 };
159
160 let mut old_to_new = vec![0u32; n];
162 for (new_id, &old_id) in new_order.iter().enumerate() {
163 old_to_new[old_id as usize] = new_id as u32;
164 }
165
166 let mut all_tag_entries: Vec<Vec<(u32, u32, u32)>> = Vec::with_capacity(n_clusters);
167 let mut all_posting_ids: Vec<u32> = Vec::new();
168
169 let mut cluster_maps: Vec<std::collections::HashMap<u32, Vec<u32>>> = (0..n_clusters)
170 .map(|_| std::collections::HashMap::new())
171 .collect();
172
173 for old_id in 0..n {
174 let new_id = old_to_new[old_id] as usize;
175 let ci = assignments[old_id] as usize;
176 let local_id = new_id - cluster_starts[ci] as usize;
177
178 let start = base_meta.indptr[old_id] as usize;
179 let end = base_meta.indptr[old_id + 1] as usize;
180 for &tag in &base_meta.indices[start..end] {
181 cluster_maps[ci]
182 .entry(tag as u32)
183 .or_default()
184 .push(local_id as u32);
185 }
186 }
187
188 for cluster_map in cluster_maps.iter_mut().take(n_clusters) {
189 let mut entries: Vec<(u32, Vec<u32>)> = cluster_map.drain().collect();
190 entries.sort_unstable_by_key(|&(tag, _)| tag);
191
192 let mut cluster_entries = Vec::with_capacity(entries.len());
193 for (tag, mut ids) in entries {
194 ids.sort_unstable();
195 let posting_start = all_posting_ids.len() as u32;
196 let posting_len = ids.len() as u32;
197 all_posting_ids.extend_from_slice(&ids);
198 cluster_entries.push((tag, posting_start, posting_len));
199 }
200 all_tag_entries.push(cluster_entries);
201 }
202
203 let mut tag_offsets = Vec::with_capacity(n_clusters + 1);
204 let mut tag_index = Vec::new();
205 let mut offset = 0u32;
206 for entries in &all_tag_entries {
207 tag_offsets.push(offset);
208 tag_index.extend_from_slice(entries);
209 offset += entries.len() as u32;
210 }
211 tag_offsets.push(offset);
212
213 let max_tag = tag_index.iter().map(|&(t, _, _)| t).max().unwrap_or(0) as usize;
215 let mut tag_clusters: Vec<Vec<u16>> = vec![vec![]; max_tag + 1];
216 for ci in 0..n_clusters {
217 let start = tag_offsets[ci] as usize;
218 let end = tag_offsets[ci + 1] as usize;
219 for &(tag, _, _) in &tag_index[start..end] {
220 tag_clusters[tag as usize].push(ci as u16);
221 }
222 }
223
224 Self {
225 vectors,
226 original_ids: new_order,
227 cluster_starts,
228 tag_offsets,
229 tag_index,
230 posting_ids: all_posting_ids,
231 tag_clusters,
232 dim,
233 n_clusters,
234 }
235 }
236
237 #[inline]
239 fn lookup_tag(&self, cluster: usize, tag: u32) -> &[u32] {
240 let start = self.tag_offsets[cluster] as usize;
241 let end = self.tag_offsets[cluster + 1] as usize;
242 let entries = &self.tag_index[start..end];
243 match entries.binary_search_by_key(&tag, |&(t, _, _)| t) {
244 Ok(idx) => {
245 let (_, ps, pl) = entries[idx];
246 &self.posting_ids[ps as usize..(ps + pl) as usize]
247 }
248 Err(_) => &[],
249 }
250 }
251
252 #[allow(clippy::too_many_arguments)]
255 fn scan_cluster(
256 &self,
257 ci: usize,
258 lids: impl ExactSizeIterator<Item = u32>,
259 query: &QueryVec,
260 q_binary: &[u64],
261 binary: &BinaryStore,
262 ef: usize,
263 binary_rerank: usize,
264 heap: &mut BinaryHeap<(u32, u32)>,
265 ) {
266 let dim = self.dim;
267 let cluster_base = self.cluster_starts[ci] as usize;
268 let rerank_budget = binary_rerank * ef;
269
270 if binary_rerank > 0 && lids.len() > rerank_budget {
271 let mut candidates: Vec<(u32, u32)> = lids
272 .map(|lid| {
273 let gid = (cluster_base + lid as usize) as u32;
274 (distance::hamming(q_binary, binary.code(gid)), lid)
275 })
276 .collect();
277 let budget = rerank_budget.min(candidates.len());
278 candidates.select_nth_unstable_by_key(budget - 1, |&(d, _)| d);
279 candidates.truncate(budget);
280 for &(_, lid) in &candidates {
281 let gid = (cluster_base + lid as usize) as u32;
282 let dist = compute_dist(&self.vectors, gid as usize, query, dim);
283 let orig_id = self.original_ids[gid as usize];
284 heap_insert(heap, dist, orig_id, ef);
285 }
286 } else {
287 for lid in lids {
288 let gid = (cluster_base + lid as usize) as u32;
289 let dist = compute_dist(&self.vectors, gid as usize, query, dim);
290 let orig_id = self.original_ids[gid as usize];
291 heap_insert(heap, dist, orig_id, ef);
292 }
293 }
294 }
295
296 #[allow(clippy::too_many_arguments)]
298 pub fn batch_search_mqcb(
299 &self,
300 queries: &QueryStore,
301 nq: usize,
302 query_tags: &[Vec<usize>],
303 query_binary: &[Vec<u64>],
304 query_top_clusters: &[Vec<usize>],
305 binary: &BinaryStore,
306 k: usize,
307 ef: usize,
308 n_probe: usize,
309 binary_rerank: usize,
310 ) -> Vec<Vec<u32>> {
311 let dim = self.dim;
312
313 let mut cluster_queries: Vec<Vec<usize>> = vec![vec![]; self.n_clusters];
315 for (qi, top_clusters) in query_top_clusters.iter().enumerate().take(nq) {
316 let np = n_probe.min(top_clusters.len());
317 for &ci in &top_clusters[..np] {
318 cluster_queries[ci].push(qi);
319 }
320 }
321
322 struct HeapArray(Vec<UnsafeCell<BinaryHeap<(u32, u32)>>>);
325 unsafe impl Sync for HeapArray {}
326 impl HeapArray {
327 #[inline]
328 #[allow(clippy::mut_from_ref)]
329 unsafe fn get(&self, idx: usize) -> &mut BinaryHeap<(u32, u32)> {
330 &mut *self.0[idx].get()
331 }
332 }
333 let heaps = HeapArray(
334 (0..nq)
335 .map(|_| UnsafeCell::new(BinaryHeap::with_capacity(ef + 1)))
336 .collect(),
337 );
338
339 for (ci, qi_list) in cluster_queries.iter().enumerate() {
341 if qi_list.is_empty() {
342 continue;
343 }
344
345 qi_list.par_iter().for_each(|&qi| {
346 let query = match queries {
347 QueryStore::U8(data) => QueryVec::U8(&data[qi * dim..(qi + 1) * dim]),
348 QueryStore::F32(data) => QueryVec::F32(&data[qi * dim..(qi + 1) * dim]),
349 };
350 let tags = &query_tags[qi];
351 let heap = unsafe { heaps.get(qi) };
352
353 match tags.len() {
354 0 => {
356 let len = self.cluster_starts[ci + 1] - self.cluster_starts[ci];
357 self.scan_cluster(
358 ci,
359 0..len,
360 &query,
361 &query_binary[qi],
362 binary,
363 ef,
364 binary_rerank,
365 heap,
366 );
367 }
368 1 => {
369 let matching = self.lookup_tag(ci, tags[0] as u32);
370 self.scan_cluster(
371 ci,
372 matching.iter().copied(),
373 &query,
374 &query_binary[qi],
375 binary,
376 ef,
377 binary_rerank,
378 heap,
379 );
380 }
381 _ => {
383 let lists: Vec<&[u32]> = tags
384 .iter()
385 .map(|&t| self.lookup_tag(ci, t as u32))
386 .collect();
387 let matching = intersect_postings(lists);
388 self.scan_cluster(
389 ci,
390 matching.iter().copied(),
391 &query,
392 &query_binary[qi],
393 binary,
394 ef,
395 binary_rerank,
396 heap,
397 );
398 }
399 }
400 });
401 }
402
403 heaps
404 .0
405 .into_par_iter()
406 .map(|cell| {
407 let heap = cell.into_inner();
408 let mut results: Vec<(u32, u32)> = heap.into_vec();
409 results.sort_unstable_by_key(|&(d, _)| d);
410 results.iter().take(k).map(|&(_, id)| id).collect()
411 })
412 .collect()
413 }
414}
415
/// Offers `(dist, id)` to a bounded max-heap that retains the `cap` entries
/// with the smallest `dist`.
///
/// While the heap holds fewer than `cap` entries every candidate is admitted.
/// After that, the root (the current worst entry) is replaced only by a
/// strictly closer candidate, so at the cutoff an earlier entry wins a tie.
#[inline]
fn heap_insert(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    if heap.len() < cap {
        heap.push((dist, id));
        return;
    }
    // `PeekMut` restores heap order when the guard is dropped.
    if let Some(mut worst) = heap.peek_mut() {
        if worst.0 > dist {
            *worst = (dist, id);
        }
    }
}
427
/// Intersects ascending posting lists and returns the ids present in all of them.
///
/// Lists are processed shortest first so the running result starts as small as
/// possible. Work stops early once that result becomes empty.
///
/// # Panics
/// Panics if `lists` is empty.
fn intersect_postings(mut lists: Vec<&[u32]>) -> Vec<u32> {
    lists.sort_unstable_by_key(|list| list.len());
    let mut survivors = lists[0].to_vec();
    for other in lists.iter().skip(1) {
        if survivors.is_empty() {
            break;
        }
        // Merge-walk `other` alongside `survivors` (both ascending): advance the
        // cursor past smaller ids, and keep a survivor only if it is matched.
        let mut cursor = 0;
        survivors.retain(|&id| {
            while cursor < other.len() && other[cursor] < id {
                cursor += 1;
            }
            let hit = cursor < other.len() && other[cursor] == id;
            if hit {
                cursor += 1;
            }
            hit
        });
    }
    survivors
}
454
/// Returns the elements common to two ascending `u16` slices, in ascending order.
///
/// This is a two-cursor merge. A value repeated in both inputs is emitted as
/// many times as it occurs in the input where it appears least often.
pub fn sorted_intersect_u16(a: &[u16], b: &[u16]) -> Vec<u16> {
    let mut common = Vec::new();
    let (mut rest_a, mut rest_b) = (a, b);
    while let (Some((&x, tail_a)), Some((&y, tail_b))) = (rest_a.split_first(), rest_b.split_first())
    {
        if x < y {
            rest_a = tail_a;
        } else if x > y {
            rest_b = tail_b;
        } else {
            common.push(x);
            rest_a = tail_a;
            rest_b = tail_b;
        }
    }
    common
}
472
473pub fn kmeans(
475 base: &VecStore,
476 n: usize,
477 dim: usize,
478 c: usize,
479 iters: usize,
480) -> (Vec<u16>, VecStore) {
481 let mut rng = StdRng::seed_from_u64(42);
482 let mut centroid_ids: Vec<usize> = (0..n).collect();
483 centroid_ids.shuffle(&mut rng);
484 centroid_ids.truncate(c);
485
486 let mut centroids_f32 = vec![0.0f32; c * dim];
487 match base {
488 VecStore::U8(data) => {
489 for (ci, &vid) in centroid_ids.iter().enumerate() {
490 for d in 0..dim {
491 centroids_f32[ci * dim + d] = data[vid * dim + d] as f32;
492 }
493 }
494 }
495 VecStore::F32(data) => {
496 for (ci, &vid) in centroid_ids.iter().enumerate() {
497 centroids_f32[ci * dim..(ci + 1) * dim]
498 .copy_from_slice(&data[vid * dim..(vid + 1) * dim]);
499 }
500 }
501 }
502
503 let mut assignments = vec![0u16; n];
504
505 for _ in 0..iters {
506 let new_assignments: Vec<u16> = match base {
507 VecStore::U8(data) => {
508 let centroids_u8: Vec<u8> = centroids_f32
509 .iter()
510 .map(|&x| x.round().clamp(0.0, 255.0) as u8)
511 .collect();
512 (0..n)
513 .into_par_iter()
514 .map(|i| {
515 let v = &data[i * dim..(i + 1) * dim];
516 let mut best_c = 0u16;
517 let mut best_d = u32::MAX;
518 for ci in 0..c {
519 let cent = ¢roids_u8[ci * dim..(ci + 1) * dim];
520 let d = distance::l2_sq8(v, cent);
521 if d < best_d {
522 best_d = d;
523 best_c = ci as u16;
524 }
525 }
526 best_c
527 })
528 .collect()
529 }
530 VecStore::F32(data) => (0..n)
531 .into_par_iter()
532 .map(|i| {
533 let v = &data[i * dim..(i + 1) * dim];
534 let mut best_c = 0u16;
535 let mut best_d = f32::INFINITY;
536 for ci in 0..c {
537 let cent = ¢roids_f32[ci * dim..(ci + 1) * dim];
538 let d = distance::l2_squared(v, cent);
539 if d < best_d {
540 best_d = d;
541 best_c = ci as u16;
542 }
543 }
544 best_c
545 })
546 .collect(),
547 };
548 assignments = new_assignments;
549
550 let mut sums = vec![0.0f64; c * dim];
552 let mut counts = vec![0u32; c];
553 match base {
554 VecStore::U8(data) => {
555 for i in 0..n {
556 let ci = assignments[i] as usize;
557 counts[ci] += 1;
558 for d in 0..dim {
559 sums[ci * dim + d] += data[i * dim + d] as f64;
560 }
561 }
562 }
563 VecStore::F32(data) => {
564 for i in 0..n {
565 let ci = assignments[i] as usize;
566 counts[ci] += 1;
567 for d in 0..dim {
568 sums[ci * dim + d] += data[i * dim + d] as f64;
569 }
570 }
571 }
572 }
573 for ci in 0..c {
574 if counts[ci] > 0 {
575 let inv = 1.0 / counts[ci] as f64;
576 for d in 0..dim {
577 centroids_f32[ci * dim + d] = (sums[ci * dim + d] * inv) as f32;
578 }
579 }
580 }
581
582 for ci in 0..c {
586 if counts[ci] > 0 {
587 continue;
588 }
589 let donor = (0..c).max_by_key(|&d| counts[d]).unwrap();
590 if counts[donor] <= 1 {
591 break;
592 }
593 let members: Vec<usize> = (0..n)
594 .filter(|&i| assignments[i] as usize == donor)
595 .collect();
596 let p = members[rng.gen_range(0..members.len())];
597 match base {
598 VecStore::U8(data) => {
599 for d in 0..dim {
600 centroids_f32[ci * dim + d] = data[p * dim + d] as f32;
601 }
602 }
603 VecStore::F32(data) => {
604 centroids_f32[ci * dim..(ci + 1) * dim]
605 .copy_from_slice(&data[p * dim..(p + 1) * dim]);
606 }
607 }
608 assignments[p] = ci as u16;
609 counts[donor] -= 1;
610 counts[ci] = 1;
611 }
612 }
613
614 let centroids = match base {
615 VecStore::U8(_) => VecStore::U8(
616 centroids_f32
617 .iter()
618 .map(|&x| x.round().clamp(0.0, 255.0) as u8)
619 .collect(),
620 ),
621 VecStore::F32(_) => VecStore::F32(centroids_f32),
622 };
623
624 (assignments, centroids)
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prism::point::PointStore;

    /// Six 2-D points in two clusters with overlapping tag sets.
    /// Points 0-3 lie near the origin (cluster 0) and points 4-5 near (5, 5)
    /// (cluster 1). Tag 3 is carried only by point 5.
    fn fixture() -> (IvfIndex, BinaryStore) {
        let points: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.2, 0.0],
            vec![0.3, 0.0],
            vec![5.0, 5.0],
            vec![5.1, 5.0],
        ];
        let tag_sets: Vec<Vec<i32>> = vec![
            vec![0, 1, 2],
            vec![0, 1],
            vec![0, 2],
            vec![1, 2],
            vec![0, 1, 2],
            vec![3],
        ];
        let flat: Vec<f32> = points.iter().flatten().copied().collect();
        let mut indptr = vec![0i64];
        let mut indices = Vec::new();
        for tags in &tag_sets {
            indices.extend_from_slice(tags);
            indptr.push(indices.len() as i64);
        }
        let meta = SpMat {
            rows: points.len(),
            cols: 4,
            indptr,
            indices,
        };
        let assignments: Vec<u16> = vec![0, 0, 0, 0, 1, 1];
        let base = VecStore::F32(flat.clone());
        let index = IvfIndex::build(&base, &meta, &assignments, points.len(), 2, 2);
        let store = PointStore::from_parts(flat, 2, vec![vec![0; points.len()]]);
        let binary = BinaryStore::build(&store);
        (index, binary)
    }

    /// Runs one query over both clusters with `ef = 10` and no binary pre-filter.
    fn run_query(
        index: &IvfIndex,
        binary: &BinaryStore,
        query: &[f32],
        tags: Vec<usize>,
        k: usize,
    ) -> Vec<u32> {
        let qb = binary.encode_query(query);
        let mut results = index.batch_search_mqcb(
            &QueryStore::F32(query),
            1,
            &[tags],
            &[qb],
            &[vec![0, 1]],
            binary,
            k,
            10,
            2,
            0,
        );
        results.pop().unwrap()
    }

    #[test]
    fn batch_zero_tags_scans_whole_clusters() {
        let (index, binary) = fixture();
        let mut ids = run_query(&index, &binary, &[5.05, 5.0], Vec::new(), 2);
        ids.sort_unstable();
        assert_eq!(ids, vec![4, 5]);
    }

    #[test]
    fn batch_three_tags_enforces_full_conjunction() {
        let (index, binary) = fixture();
        let ids = run_query(&index, &binary, &[0.05, 0.0], vec![0, 1, 2], 4);
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 4]);
    }

    #[test]
    fn batch_single_tag_uses_posting_list() {
        let (index, binary) = fixture();
        let ids = run_query(&index, &binary, &[0.0, 0.0], vec![3], 3);
        assert_eq!(ids, vec![5]);
    }

    #[test]
    fn batch_results_sorted_by_distance_and_capped_at_k() {
        let (index, binary) = fixture();
        let ids = run_query(&index, &binary, &[0.0, 0.0], Vec::new(), 3);
        assert_eq!(ids, vec![0, 1, 2]);
    }

    #[test]
    fn intersect_postings_keeps_common_ids() {
        assert_eq!(
            intersect_postings(vec![&[1u32, 3, 5, 7][..], &[0, 3, 5, 9][..], &[3, 4, 5][..]]),
            vec![3, 5]
        );
        assert!(intersect_postings(vec![&[1u32, 2][..], &[3, 4][..]]).is_empty());
    }

    #[test]
    fn sorted_intersect_u16_merges() {
        assert_eq!(sorted_intersect_u16(&[1, 2, 4, 8], &[2, 3, 4]), vec![2, 4]);
        assert!(sorted_intersect_u16(&[], &[1]).is_empty());
    }

    #[test]
    fn heap_insert_keeps_cap_smallest() {
        let mut heap = std::collections::BinaryHeap::new();
        for (dist, id) in [(5u32, 0u32), (1, 1), (3, 2), (0, 3), (3, 4)] {
            heap_insert(&mut heap, dist, id, 2);
        }
        assert_eq!(heap.into_sorted_vec(), vec![(0, 3), (1, 1)]);
    }

    #[test]
    fn kmeans_reseeds_empty_clusters() {
        let n = 64;
        let mut flat = vec![0.0f32; n * 2];
        for (i, off) in [(60, 50.0f32), (61, -50.0), (62, 100.0), (63, -100.0)] {
            flat[i * 2] = off;
            flat[i * 2 + 1] = off;
        }
        let (assignments, _) = kmeans(&VecStore::F32(flat), n, 2, 8, 3);
        let mut seen = [false; 8];
        for &a in &assignments {
            seen[a as usize] = true;
        }
        assert!(
            seen.iter().all(|&s| s),
            "every cluster must keep at least one member, got {assignments:?}"
        );
    }
}