1use super::binary::BinaryStore;
9use super::distance;
10
11use rand::prelude::*;
12use rayon::prelude::*;
13use std::cell::UnsafeCell;
14use std::collections::BinaryHeap;
15
/// Sparsity pattern of a matrix in CSR (compressed sparse row) layout.
///
/// Only the structure is stored — there is no values array — so an entry
/// simply records that a row contains a given column. Elsewhere in this file
/// rows are base vectors and columns are tag ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpMat {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row pointers: the entries of row `r` are `indices[indptr[r]..indptr[r + 1]]`.
    pub indptr: Vec<i64>,
    /// Column index of every stored entry, laid out row after row.
    pub indices: Vec<i32>,
}
23
/// Owned, flat (row-major) storage for a set of equally sized vectors.
///
/// Vector `i` of dimension `dim` occupies `data[i * dim..(i + 1) * dim]`.
#[derive(Debug, Clone, PartialEq)]
pub enum VecStore {
    /// Vectors with `u8` components.
    U8(Vec<u8>),
    /// Vectors with `f32` components.
    F32(Vec<f32>),
}
29
/// Borrowed, flat (row-major) storage for a batch of query vectors.
///
/// Query `q` of dimension `dim` occupies `data[q * dim..(q + 1) * dim]`.
/// The type only holds a slice reference, so it is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QueryStore<'a> {
    /// Queries with `u8` components.
    U8(&'a [u8]),
    /// Queries with `f32` components.
    F32(&'a [f32]),
}
35
/// A single borrowed query vector (one row cut out of a [`QueryStore`]).
///
/// The type only holds a slice reference, so it is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
enum QueryVec<'a> {
    /// Query with `u8` components.
    U8(&'a [u8]),
    /// Query with `f32` components.
    F32(&'a [f32]),
}
41
42#[inline]
45fn compute_dist(store: &VecStore, gid: usize, query: &QueryVec, dim: usize) -> u32 {
46 match (store, query) {
47 (VecStore::U8(v), QueryVec::U8(q)) => distance::l2_sq8(q, &v[gid * dim..(gid + 1) * dim]),
48 (VecStore::F32(v), QueryVec::F32(q)) => {
49 distance::l2_squared(q, &v[gid * dim..(gid + 1) * dim]).to_bits()
50 }
51 _ => unreachable!("mismatched vector/query types"),
52 }
53}
54
/// Inverted-file index: vectors grouped by cluster, with per-cluster posting
/// lists that map a tag to the vectors of that cluster carrying it.
pub struct IvfIndex {
    /// All vectors, reordered so that each cluster occupies a contiguous range of slots.
    pub vectors: VecStore,
    /// `original_ids[slot]` is the id the vector in `slot` had in the input data.
    pub original_ids: Vec<u32>,
    /// Cluster `c` occupies slots `cluster_starts[c]..cluster_starts[c + 1]`
    /// (`n_clusters + 1` entries).
    pub cluster_starts: Vec<u32>,
    /// The tag entries of cluster `c` are `tag_index[tag_offsets[c]..tag_offsets[c + 1]]`
    /// (`n_clusters + 1` entries).
    tag_offsets: Vec<u32>,
    /// `(tag, posting_start, posting_len)` triples, sorted by tag within each cluster;
    /// the posting list is `posting_ids[posting_start..posting_start + posting_len]`.
    tag_index: Vec<(u32, u32, u32)>,
    /// Concatenated posting lists. Each list holds cluster-local ids
    /// (slot minus the cluster's first slot) in ascending order.
    posting_ids: Vec<u32>,
    /// For every tag id, the clusters that contain at least one vector with
    /// that tag, in ascending cluster order.
    pub tag_clusters: Vec<Vec<u16>>,
    /// Number of components per vector.
    pub dim: usize,
    /// Number of clusters.
    pub n_clusters: usize,
}
76
77impl IvfIndex {
78 pub fn build(
83 base: &VecStore,
84 base_meta: &SpMat,
85 assignments: &[u16],
86 n: usize,
87 dim: usize,
88 n_clusters: usize,
89 ) -> Self {
90 let mut cluster_sizes = vec![0u32; n_clusters];
92 for &a in assignments {
93 cluster_sizes[a as usize] += 1;
94 }
95 let mut cluster_starts = vec![0u32; n_clusters + 1];
96 for i in 0..n_clusters {
97 cluster_starts[i + 1] = cluster_starts[i] + cluster_sizes[i];
98 }
99
100 let mut position = cluster_starts[..n_clusters].to_vec();
102 let mut new_order = vec![0u32; n];
103 for (i, &ci_raw) in assignments.iter().enumerate().take(n) {
104 let ci = ci_raw as usize;
105 let new_id = position[ci] as usize;
106 new_order[new_id] = i as u32;
107 position[ci] += 1;
108 }
109
110 macro_rules! reorder_and_sort {
112 ($base_data:expr, $zero:expr, $T:ty) => {{
113 let mut vecs = vec![$zero; n * dim];
114 for (new_id, &old_id) in new_order.iter().enumerate() {
115 let src = &$base_data[old_id as usize * dim..(old_id as usize + 1) * dim];
116 vecs[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
117 }
118
119 let mut tag_freq = vec![0u32; base_meta.cols + 1];
121 for &tag in &base_meta.indices {
122 tag_freq[tag as usize] += 1;
123 }
124 for ci in 0..n_clusters {
125 let cs = cluster_starts[ci] as usize;
126 let ce = cluster_starts[ci + 1] as usize;
127 if ce - cs <= 1 {
128 continue;
129 }
130
131 let mut sort_keys: Vec<(u32, usize)> = (0..ce - cs)
132 .map(|local| {
133 let old_id = new_order[cs + local] as usize;
134 let ms = base_meta.indptr[old_id] as usize;
135 let me = base_meta.indptr[old_id + 1] as usize;
136 let tag = base_meta.indices[ms..me]
137 .iter()
138 .max_by_key(|&&t| tag_freq[t as usize])
139 .map(|&t| t as u32)
140 .unwrap_or(u32::MAX);
141 (tag, local)
142 })
143 .collect();
144 sort_keys.sort_unstable_by_key(|&(tag, _)| tag);
145
146 let old_vecs: Vec<$T> = vecs[cs * dim..ce * dim].to_vec();
147 let old_ids: Vec<u32> = new_order[cs..ce].to_vec();
148 for (new_local, &(_, old_local)) in sort_keys.iter().enumerate() {
149 vecs[(cs + new_local) * dim..(cs + new_local + 1) * dim]
150 .copy_from_slice(&old_vecs[old_local * dim..(old_local + 1) * dim]);
151 new_order[cs + new_local] = old_ids[old_local];
152 }
153 }
154 vecs
155 }};
156 }
157
158 let vectors = match base {
159 VecStore::U8(data) => VecStore::U8(reorder_and_sort!(data, 0u8, u8)),
160 VecStore::F32(data) => VecStore::F32(reorder_and_sort!(data, 0.0f32, f32)),
161 };
162
163 let mut old_to_new = vec![0u32; n];
165 for (new_id, &old_id) in new_order.iter().enumerate() {
166 old_to_new[old_id as usize] = new_id as u32;
167 }
168
169 let mut all_tag_entries: Vec<Vec<(u32, u32, u32)>> = Vec::with_capacity(n_clusters);
171 let mut all_posting_ids: Vec<u32> = Vec::new();
172
173 let mut cluster_maps: Vec<std::collections::HashMap<u32, Vec<u32>>> = (0..n_clusters)
174 .map(|_| std::collections::HashMap::new())
175 .collect();
176
177 for old_id in 0..n {
178 let new_id = old_to_new[old_id] as usize;
179 let ci = assignments[old_id] as usize;
180 let local_id = new_id - cluster_starts[ci] as usize;
181
182 let start = base_meta.indptr[old_id] as usize;
183 let end = base_meta.indptr[old_id + 1] as usize;
184 for &tag in &base_meta.indices[start..end] {
185 cluster_maps[ci]
186 .entry(tag as u32)
187 .or_default()
188 .push(local_id as u32);
189 }
190 }
191
192 for cluster_map in cluster_maps.iter_mut().take(n_clusters) {
194 let mut entries: Vec<(u32, Vec<u32>)> = cluster_map.drain().collect();
195 entries.sort_unstable_by_key(|&(tag, _)| tag);
196
197 let mut cluster_entries = Vec::with_capacity(entries.len());
198 for (tag, mut ids) in entries {
199 ids.sort_unstable();
200 let posting_start = all_posting_ids.len() as u32;
201 let posting_len = ids.len() as u32;
202 all_posting_ids.extend_from_slice(&ids);
203 cluster_entries.push((tag, posting_start, posting_len));
204 }
205 all_tag_entries.push(cluster_entries);
206 }
207
208 let mut tag_offsets = Vec::with_capacity(n_clusters + 1);
210 let mut tag_index = Vec::new();
211 let mut offset = 0u32;
212 for entries in &all_tag_entries {
213 tag_offsets.push(offset);
214 tag_index.extend_from_slice(entries);
215 offset += entries.len() as u32;
216 }
217 tag_offsets.push(offset);
218
219 let total_posting = all_posting_ids.len();
220 let total_entries = tag_index.len();
221 eprintln!(
222 " IVF: {n_clusters} clusters, {total_entries} tag entries, {total_posting} posting IDs"
223 );
224
225 let max_tag = tag_index.iter().map(|&(t, _, _)| t).max().unwrap_or(0) as usize;
227 let mut tag_clusters: Vec<Vec<u16>> = vec![vec![]; max_tag + 1];
228 for ci in 0..n_clusters {
229 let start = tag_offsets[ci] as usize;
230 let end = tag_offsets[ci + 1] as usize;
231 for &(tag, _, _) in &tag_index[start..end] {
232 tag_clusters[tag as usize].push(ci as u16);
233 }
234 }
235
236 Self {
237 vectors,
238 original_ids: new_order,
239 cluster_starts,
240 tag_offsets,
241 tag_index,
242 posting_ids: all_posting_ids,
243 tag_clusters,
244 dim,
245 n_clusters,
246 }
247 }
248
249 #[inline]
251 fn lookup_tag(&self, cluster: usize, tag: u32) -> &[u32] {
252 let start = self.tag_offsets[cluster] as usize;
253 let end = self.tag_offsets[cluster + 1] as usize;
254 let entries = &self.tag_index[start..end];
255 match entries.binary_search_by_key(&tag, |&(t, _, _)| t) {
256 Ok(idx) => {
257 let (_, ps, pl) = entries[idx];
258 &self.posting_ids[ps as usize..(ps + pl) as usize]
259 }
260 Err(_) => &[],
261 }
262 }
263
264 #[allow(clippy::too_many_arguments)]
266 fn scan_cluster(
267 &self,
268 ci: usize,
269 matching: &[u32],
270 query: &QueryVec,
271 q_binary: &[u64],
272 binary: &BinaryStore,
273 ef: usize,
274 binary_rerank: usize,
275 heap: &mut BinaryHeap<(u32, u32)>,
276 ) {
277 if matching.is_empty() {
278 return;
279 }
280 let dim = self.dim;
281 let cluster_base = self.cluster_starts[ci] as usize;
282 let rerank_budget = binary_rerank * ef;
283
284 if binary_rerank > 0 && matching.len() > rerank_budget {
285 let mut candidates: Vec<(u32, u32)> = matching
286 .iter()
287 .map(|&lid| {
288 let gid = (cluster_base + lid as usize) as u32;
289 (distance::hamming(q_binary, binary.code(gid)), lid)
290 })
291 .collect();
292 let budget = rerank_budget.min(candidates.len());
293 candidates.select_nth_unstable_by_key(budget - 1, |&(d, _)| d);
294 candidates.truncate(budget);
295 for &(_, lid) in &candidates {
296 let gid = (cluster_base + lid as usize) as u32;
297 let dist = compute_dist(&self.vectors, gid as usize, query, dim);
298 let orig_id = self.original_ids[gid as usize];
299 heap_insert(heap, dist, orig_id, ef);
300 }
301 } else {
302 for &lid in matching {
303 let gid = (cluster_base + lid as usize) as u32;
304 let dist = compute_dist(&self.vectors, gid as usize, query, dim);
305 let orig_id = self.original_ids[gid as usize];
306 heap_insert(heap, dist, orig_id, ef);
307 }
308 }
309 }
310
311 #[allow(clippy::too_many_arguments)]
313 fn scan_cluster_intersect(
314 &self,
315 ci: usize,
316 list_a: &[u32],
317 list_b: &[u32],
318 query: &QueryVec,
319 q_binary: &[u64],
320 binary: &BinaryStore,
321 ef: usize,
322 binary_rerank: usize,
323 heap: &mut BinaryHeap<(u32, u32)>,
324 ) {
325 let dim = self.dim;
326 let cluster_base = self.cluster_starts[ci] as usize;
327 let rerank_budget = binary_rerank * ef;
328
329 let est = list_a.len().min(list_b.len());
330
331 if binary_rerank > 0 && est > rerank_budget {
332 let mut candidates: Vec<(u32, u32)> = Vec::new();
333 let (mut i, mut j) = (0, 0);
334 while i < list_a.len() && j < list_b.len() {
335 let a = list_a[i];
336 let b = list_b[j];
337 if a < b {
338 i += 1;
339 } else if a > b {
340 j += 1;
341 } else {
342 let gid = (cluster_base + a as usize) as u32;
343 let hd = distance::hamming(q_binary, binary.code(gid));
344 candidates.push((hd, gid));
345 i += 1;
346 j += 1;
347 }
348 }
349 if candidates.len() > rerank_budget {
350 candidates.select_nth_unstable_by_key(rerank_budget - 1, |&(d, _)| d);
351 candidates.truncate(rerank_budget);
352 }
353 for &(_, gid) in &candidates {
354 let dist = compute_dist(&self.vectors, gid as usize, query, dim);
355 let orig_id = self.original_ids[gid as usize];
356 heap_insert(heap, dist, orig_id, ef);
357 }
358 } else {
359 let (mut i, mut j) = (0, 0);
360 while i < list_a.len() && j < list_b.len() {
361 let a = list_a[i];
362 let b = list_b[j];
363 if a < b {
364 i += 1;
365 } else if a > b {
366 j += 1;
367 } else {
368 let gid = (cluster_base + a as usize) as u32;
369 let dist = compute_dist(&self.vectors, gid as usize, query, dim);
370 let orig_id = self.original_ids[gid as usize];
371 heap_insert(heap, dist, orig_id, ef);
372 i += 1;
373 j += 1;
374 }
375 }
376 }
377 }
378
379 #[allow(clippy::too_many_arguments)]
381 pub fn batch_search_mqcb(
382 &self,
383 queries: &QueryStore,
384 nq: usize,
385 query_tags: &[Vec<usize>],
386 query_binary: &[Vec<u64>],
387 query_top_clusters: &[Vec<usize>],
388 binary: &BinaryStore,
389 k: usize,
390 ef: usize,
391 n_probe: usize,
392 binary_rerank: usize,
393 ) -> Vec<Vec<u32>> {
394 let dim = self.dim;
395
396 let mut cluster_queries: Vec<Vec<usize>> = vec![vec![]; self.n_clusters];
398 for (qi, top_clusters) in query_top_clusters.iter().enumerate().take(nq) {
399 let np = n_probe.min(top_clusters.len());
400 for &ci in &top_clusters[..np] {
401 cluster_queries[ci].push(qi);
402 }
403 }
404
405 struct HeapArray(Vec<UnsafeCell<BinaryHeap<(u32, u32)>>>);
408 unsafe impl Sync for HeapArray {}
409 impl HeapArray {
410 #[inline]
411 #[allow(clippy::mut_from_ref)]
412 unsafe fn get(&self, idx: usize) -> &mut BinaryHeap<(u32, u32)> {
413 &mut *self.0[idx].get()
414 }
415 }
416 let heaps = HeapArray(
417 (0..nq)
418 .map(|_| UnsafeCell::new(BinaryHeap::with_capacity(ef + 1)))
419 .collect(),
420 );
421
422 for (ci, qi_list) in cluster_queries.iter().enumerate() {
424 if qi_list.is_empty() {
425 continue;
426 }
427
428 qi_list.par_iter().for_each(|&qi| {
429 let query = match queries {
430 QueryStore::U8(data) => QueryVec::U8(&data[qi * dim..(qi + 1) * dim]),
431 QueryStore::F32(data) => QueryVec::F32(&data[qi * dim..(qi + 1) * dim]),
432 };
433 let tags = &query_tags[qi];
434 let heap = unsafe { heaps.get(qi) };
435
436 if tags.len() == 1 {
437 let matching = self.lookup_tag(ci, tags[0] as u32);
438 self.scan_cluster(
439 ci,
440 matching,
441 &query,
442 &query_binary[qi],
443 binary,
444 ef,
445 binary_rerank,
446 heap,
447 );
448 } else {
449 let list_a = self.lookup_tag(ci, tags[0] as u32);
450 let list_b = self.lookup_tag(ci, tags[1] as u32);
451 self.scan_cluster_intersect(
452 ci,
453 list_a,
454 list_b,
455 &query,
456 &query_binary[qi],
457 binary,
458 ef,
459 binary_rerank,
460 heap,
461 );
462 }
463 });
464 }
465
466 heaps
468 .0
469 .into_par_iter()
470 .map(|cell| {
471 let heap = cell.into_inner();
472 let mut results: Vec<(u32, u32)> = heap.into_vec();
473 results.sort_unstable_by_key(|&(d, _)| d);
474 results.iter().take(k).map(|&(_, id)| id).collect()
475 })
476 .collect()
477 }
478}
479
/// Offers `(dist, id)` to a max-heap that holds at most `cap` entries.
///
/// While the heap is below capacity the pair is simply pushed. Once it is
/// full, the pair replaces the current largest entry only if `dist` is
/// strictly smaller than that entry's distance; ties leave the heap
/// untouched. With `cap == 0` nothing is ever stored.
#[inline]
fn heap_insert(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    if heap.len() < cap {
        heap.push((dist, id));
        return;
    }
    match heap.peek_mut() {
        // Overwriting through `PeekMut` restores the heap order when it drops.
        Some(mut worst) if dist < worst.0 => *worst = (dist, id),
        _ => {}
    }
}
491
/// Intersects two ascending `u16` slices with a two-pointer merge.
///
/// Returns the common values in ascending order. Both inputs are expected to
/// be sorted ascending; a value is emitted once for every position at which
/// the two cursors meet on equal elements.
pub fn sorted_intersect_u16(a: &[u16], b: &[u16]) -> Vec<u16> {
    use std::cmp::Ordering;

    let mut common = Vec::new();
    let (mut rest_a, mut rest_b) = (a, b);
    while let (Some((&x, tail_a)), Some((&y, tail_b))) =
        (rest_a.split_first(), rest_b.split_first())
    {
        match x.cmp(&y) {
            Ordering::Less => rest_a = tail_a,
            Ordering::Greater => rest_b = tail_b,
            Ordering::Equal => {
                common.push(x);
                rest_a = tail_a;
                rest_b = tail_b;
            }
        }
    }
    common
}
509
510pub fn kmeans(
512 base: &VecStore,
513 n: usize,
514 dim: usize,
515 c: usize,
516 iters: usize,
517) -> (Vec<u16>, VecStore) {
518 let mut rng = StdRng::seed_from_u64(42);
519 let mut centroid_ids: Vec<usize> = (0..n).collect();
520 centroid_ids.shuffle(&mut rng);
521 centroid_ids.truncate(c);
522
523 let mut centroids_f32 = vec![0.0f32; c * dim];
524 match base {
525 VecStore::U8(data) => {
526 for (ci, &vid) in centroid_ids.iter().enumerate() {
527 for d in 0..dim {
528 centroids_f32[ci * dim + d] = data[vid * dim + d] as f32;
529 }
530 }
531 }
532 VecStore::F32(data) => {
533 for (ci, &vid) in centroid_ids.iter().enumerate() {
534 centroids_f32[ci * dim..(ci + 1) * dim]
535 .copy_from_slice(&data[vid * dim..(vid + 1) * dim]);
536 }
537 }
538 }
539
540 let mut assignments = vec![0u16; n];
541
542 for iter in 0..iters {
543 let t0 = std::time::Instant::now();
544
545 let new_assignments: Vec<u16> = match base {
547 VecStore::U8(data) => {
548 let centroids_u8: Vec<u8> = centroids_f32
549 .iter()
550 .map(|&x| x.round().clamp(0.0, 255.0) as u8)
551 .collect();
552 (0..n)
553 .into_par_iter()
554 .map(|i| {
555 let v = &data[i * dim..(i + 1) * dim];
556 let mut best_c = 0u16;
557 let mut best_d = u32::MAX;
558 for ci in 0..c {
559 let cent = ¢roids_u8[ci * dim..(ci + 1) * dim];
560 let d = distance::l2_sq8(v, cent);
561 if d < best_d {
562 best_d = d;
563 best_c = ci as u16;
564 }
565 }
566 best_c
567 })
568 .collect()
569 }
570 VecStore::F32(data) => (0..n)
571 .into_par_iter()
572 .map(|i| {
573 let v = &data[i * dim..(i + 1) * dim];
574 let mut best_c = 0u16;
575 let mut best_d = f32::INFINITY;
576 for ci in 0..c {
577 let cent = ¢roids_f32[ci * dim..(ci + 1) * dim];
578 let d = distance::l2_squared(v, cent);
579 if d < best_d {
580 best_d = d;
581 best_c = ci as u16;
582 }
583 }
584 best_c
585 })
586 .collect(),
587 };
588 assignments = new_assignments;
589
590 let mut sums = vec![0.0f64; c * dim];
592 let mut counts = vec![0u32; c];
593 match base {
594 VecStore::U8(data) => {
595 for i in 0..n {
596 let ci = assignments[i] as usize;
597 counts[ci] += 1;
598 for d in 0..dim {
599 sums[ci * dim + d] += data[i * dim + d] as f64;
600 }
601 }
602 }
603 VecStore::F32(data) => {
604 for i in 0..n {
605 let ci = assignments[i] as usize;
606 counts[ci] += 1;
607 for d in 0..dim {
608 sums[ci * dim + d] += data[i * dim + d] as f64;
609 }
610 }
611 }
612 }
613 for ci in 0..c {
614 if counts[ci] > 0 {
615 let inv = 1.0 / counts[ci] as f64;
616 for d in 0..dim {
617 centroids_f32[ci * dim + d] = (sums[ci * dim + d] * inv) as f32;
618 }
619 }
620 }
621
622 let min_s = counts.iter().min().unwrap();
623 let max_s = counts.iter().max().unwrap();
624 let empty = counts.iter().filter(|&&c| c == 0).count();
625 eprintln!(
626 " iter {}/{}: min={min_s}, max={max_s}, empty={empty} ({:.1}s)",
627 iter + 1,
628 iters,
629 t0.elapsed().as_secs_f64()
630 );
631 }
632
633 let centroids = match base {
634 VecStore::U8(_) => VecStore::U8(
635 centroids_f32
636 .iter()
637 .map(|&x| x.round().clamp(0.0, 255.0) as u8)
638 .collect(),
639 ),
640 VecStore::F32(_) => VecStore::F32(centroids_f32),
641 };
642
643 (assignments, centroids)
644}