// citadel_vector/vendored/prism/ivf.rs

1//! IVF2 (geometric clusters x tag posting lists) + MQCB batch search.
2//!
3//! Two-level inverted index: K-means clusters for geometric proximity,
4//! per-cluster tag posting lists for attribute filtering. Vectors stored
5//! once (no duplication). Intra-cluster tag-affinity sort for sequential
6//! memory access on posting list scans.
7
8use super::binary::BinaryStore;
9use super::distance;
10
11use rand::prelude::*;
12use rayon::prelude::*;
13use std::cell::UnsafeCell;
14use std::collections::BinaryHeap;
15
/// Sparse pattern matrix in CSR layout (same `indptr`/`indices` layout as
/// `scipy.sparse.csr_matrix`; no value array is stored).
///
/// Row `r` owns the column ids `indices[indptr[r]..indptr[r + 1]]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpMat {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row pointers into `indices`: one entry per row plus a trailing end offset.
    pub indptr: Vec<i64>,
    /// Column ids of the stored entries, concatenated row by row.
    pub indices: Vec<i32>,
}
23
/// Type-erased flat vector storage (u8 or f32).
///
/// Vectors are stored back to back in one flat buffer; the dimensionality is
/// tracked by the owner of the store, not by the store itself.
#[derive(Debug, Clone, PartialEq)]
pub enum VecStore {
    /// Components stored as `u8`.
    U8(Vec<u8>),
    /// Components stored as `f32`.
    F32(Vec<f32>),
}
29
/// Borrowed query batch (flat, nq x dim).
///
/// Only holds a slice reference, so it is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QueryStore<'a> {
    /// Queries with `u8` components.
    U8(&'a [u8]),
    /// Queries with `f32` components.
    F32(&'a [f32]),
}
35
/// Single query vector slice.
///
/// Carved out of a [`QueryStore`] batch (`dim` consecutive components) and
/// handed to `compute_dist`, which requires the variant to match the
/// variant of the index's `VecStore`.
enum QueryVec<'a> {
    /// Query with `u8` components.
    U8(&'a [u8]),
    /// Query with `f32` components.
    F32(&'a [f32]),
}
41
42/// Distance suitable for heap ordering. For u8: raw u32 from l2_sq8.
43/// For f32: f32::to_bits() (monotonic for non-negative IEEE 754 floats).
44#[inline]
45fn compute_dist(store: &VecStore, gid: usize, query: &QueryVec, dim: usize) -> u32 {
46    match (store, query) {
47        (VecStore::U8(v), QueryVec::U8(q)) => distance::l2_sq8(q, &v[gid * dim..(gid + 1) * dim]),
48        (VecStore::F32(v), QueryVec::F32(q)) => {
49            distance::l2_squared(q, &v[gid * dim..(gid + 1) * dim]).to_bits()
50        }
51        _ => unreachable!("mismatched vector/query types"),
52    }
53}
54
/// IVF2 index: geometric clusters x per-cluster tag posting lists.
///
/// Vectors are stored once, reordered so that every cluster is contiguous.
/// Tag filtering goes through a two-step directory: `tag_offsets` selects a
/// cluster's slice of `tag_index`, and each `tag_index` entry points at a
/// run of cluster-local ids inside `posting_ids`.
pub struct IvfIndex {
    /// Reordered vectors (contiguous per cluster).
    pub vectors: VecStore,
    /// Mapping: reordered_id -> original_id.
    pub original_ids: Vec<u32>,
    /// Cluster boundaries: cluster c spans [cluster_starts[c]..cluster_starts[c+1]).
    pub cluster_starts: Vec<u32>,
    /// Per-cluster tag index offsets: cluster c owns
    /// `tag_index[tag_offsets[c]..tag_offsets[c + 1]]`.
    tag_offsets: Vec<u32>,
    /// (tag_id, posting_start, posting_len) triples, sorted by tag_id within each cluster.
    /// `posting_start`/`posting_len` address a run inside `posting_ids`.
    tag_index: Vec<(u32, u32, u32)>,
    /// Flat array of local IDs for all (cluster, tag) posting lists.
    /// A local id is an offset from the owning cluster's `cluster_starts` entry.
    posting_ids: Vec<u32>,
    /// Per-tag list of clusters containing matching vectors (indexed by tag id).
    pub tag_clusters: Vec<Vec<u16>>,
    /// Vector dimensionality.
    pub dim: usize,
    /// Number of clusters.
    pub n_clusters: usize,
}
76
77impl IvfIndex {
    /// Build the IVF2 index from clustered vectors and metadata.
    ///
    /// Reorders vectors by cluster, sorts within each cluster by most popular
    /// tag (tag-affinity sort), and builds per-cluster tag posting lists.
    ///
    /// # Arguments
    /// * `base` - flat `n x dim` vectors in original id order.
    /// * `base_meta` - CSR tag matrix; row `i` lists the tag ids of vector `i`.
    /// * `assignments` - cluster id of each vector.
    /// * `n` - number of vectors.
    /// * `dim` - vector dimensionality.
    /// * `n_clusters` - number of clusters.
    ///
    /// # Panics
    /// Panics on out-of-range indexing, for example when an assignment is
    /// `>= n_clusters` or a tag id is `> base_meta.cols`.
    pub fn build(
        base: &VecStore,
        base_meta: &SpMat,
        assignments: &[u16],
        n: usize,
        dim: usize,
        n_clusters: usize,
    ) -> Self {
        // Counting sort, pass 1: cluster sizes -> prefix sums (cluster boundaries).
        let mut cluster_sizes = vec![0u32; n_clusters];
        for &a in assignments {
            cluster_sizes[a as usize] += 1;
        }
        let mut cluster_starts = vec![0u32; n_clusters + 1];
        for i in 0..n_clusters {
            cluster_starts[i + 1] = cluster_starts[i] + cluster_sizes[i];
        }

        // Counting sort, pass 2: new_order[new_id] = original id, clusters contiguous.
        let mut position = cluster_starts[..n_clusters].to_vec();
        let mut new_order = vec![0u32; n];
        for (i, &ci_raw) in assignments.iter().enumerate().take(n) {
            let ci = ci_raw as usize;
            let new_id = position[ci] as usize;
            new_order[new_id] = i as u32;
            position[ci] += 1;
        }

        // Expanded once per element type (u8 / f32): copies the vectors into
        // cluster order, then re-sorts each cluster by tag affinity. Mutates
        // `new_order` in place so it stays in sync with the returned vectors.
        macro_rules! reorder_and_sort {
            ($base_data:expr, $zero:expr, $T:ty) => {{
                let mut vecs = vec![$zero; n * dim];
                for (new_id, &old_id) in new_order.iter().enumerate() {
                    let src = &$base_data[old_id as usize * dim..(old_id as usize + 1) * dim];
                    vecs[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
                }

                // Tag-affinity sort within each cluster
                // Global tag frequencies over the whole dataset.
                let mut tag_freq = vec![0u32; base_meta.cols + 1];
                for &tag in &base_meta.indices {
                    tag_freq[tag as usize] += 1;
                }
                for ci in 0..n_clusters {
                    let cs = cluster_starts[ci] as usize;
                    let ce = cluster_starts[ci + 1] as usize;
                    if ce - cs <= 1 {
                        continue;
                    }

                    // Sort key per vector: its globally most frequent tag.
                    // Untagged vectors get u32::MAX and therefore sort last.
                    let mut sort_keys: Vec<(u32, usize)> = (0..ce - cs)
                        .map(|local| {
                            let old_id = new_order[cs + local] as usize;
                            let ms = base_meta.indptr[old_id] as usize;
                            let me = base_meta.indptr[old_id + 1] as usize;
                            let tag = base_meta.indices[ms..me]
                                .iter()
                                .max_by_key(|&&t| tag_freq[t as usize])
                                .map(|&t| t as u32)
                                .unwrap_or(u32::MAX);
                            (tag, local)
                        })
                        .collect();
                    sort_keys.sort_unstable_by_key(|&(tag, _)| tag);

                    // Apply the permutation from snapshots of the cluster's
                    // vectors and ids, so sources are never overwritten early.
                    let old_vecs: Vec<$T> = vecs[cs * dim..ce * dim].to_vec();
                    let old_ids: Vec<u32> = new_order[cs..ce].to_vec();
                    for (new_local, &(_, old_local)) in sort_keys.iter().enumerate() {
                        vecs[(cs + new_local) * dim..(cs + new_local + 1) * dim]
                            .copy_from_slice(&old_vecs[old_local * dim..(old_local + 1) * dim]);
                        new_order[cs + new_local] = old_ids[old_local];
                    }
                }
                vecs
            }};
        }

        let vectors = match base {
            VecStore::U8(data) => VecStore::U8(reorder_and_sort!(data, 0u8, u8)),
            VecStore::F32(data) => VecStore::F32(reorder_and_sort!(data, 0.0f32, f32)),
        };

        // old_to_new must come AFTER the intra-cluster tag-affinity sort.
        let mut old_to_new = vec![0u32; n];
        for (new_id, &old_id) in new_order.iter().enumerate() {
            old_to_new[old_id as usize] = new_id as u32;
        }

        let mut all_tag_entries: Vec<Vec<(u32, u32, u32)>> = Vec::with_capacity(n_clusters);
        let mut all_posting_ids: Vec<u32> = Vec::new();

        // One tag -> local-id map per cluster.
        let mut cluster_maps: Vec<std::collections::HashMap<u32, Vec<u32>>> = (0..n_clusters)
            .map(|_| std::collections::HashMap::new())
            .collect();

        for old_id in 0..n {
            let new_id = old_to_new[old_id] as usize;
            let ci = assignments[old_id] as usize;
            // Posting lists store ids relative to the cluster start.
            let local_id = new_id - cluster_starts[ci] as usize;

            let start = base_meta.indptr[old_id] as usize;
            let end = base_meta.indptr[old_id + 1] as usize;
            for &tag in &base_meta.indices[start..end] {
                cluster_maps[ci]
                    .entry(tag as u32)
                    .or_default()
                    .push(local_id as u32);
            }
        }

        // Flatten the maps: entries sorted by tag (lookup_tag binary-searches
        // them) and each posting list sorted ascending (needed for the
        // merge-style intersection of posting lists).
        for cluster_map in cluster_maps.iter_mut().take(n_clusters) {
            let mut entries: Vec<(u32, Vec<u32>)> = cluster_map.drain().collect();
            entries.sort_unstable_by_key(|&(tag, _)| tag);

            let mut cluster_entries = Vec::with_capacity(entries.len());
            for (tag, mut ids) in entries {
                ids.sort_unstable();
                let posting_start = all_posting_ids.len() as u32;
                let posting_len = ids.len() as u32;
                all_posting_ids.extend_from_slice(&ids);
                cluster_entries.push((tag, posting_start, posting_len));
            }
            all_tag_entries.push(cluster_entries);
        }

        // Concatenate the per-cluster directories; tag_offsets[c]..tag_offsets[c+1]
        // delimits cluster c's slice of tag_index.
        let mut tag_offsets = Vec::with_capacity(n_clusters + 1);
        let mut tag_index = Vec::new();
        let mut offset = 0u32;
        for entries in &all_tag_entries {
            tag_offsets.push(offset);
            tag_index.extend_from_slice(entries);
            offset += entries.len() as u32;
        }
        tag_offsets.push(offset);

        // Build per-tag cluster lists (for filtered cluster selection)
        // Clusters are visited in ascending order, so each list is sorted.
        let max_tag = tag_index.iter().map(|&(t, _, _)| t).max().unwrap_or(0) as usize;
        let mut tag_clusters: Vec<Vec<u16>> = vec![vec![]; max_tag + 1];
        for ci in 0..n_clusters {
            let start = tag_offsets[ci] as usize;
            let end = tag_offsets[ci + 1] as usize;
            for &(tag, _, _) in &tag_index[start..end] {
                tag_clusters[tag as usize].push(ci as u16);
            }
        }

        Self {
            vectors,
            original_ids: new_order,
            cluster_starts,
            tag_offsets,
            tag_index,
            posting_ids: all_posting_ids,
            tag_clusters,
            dim,
            n_clusters,
        }
    }
236
237    /// Look up local IDs matching a tag within a cluster.
238    #[inline]
239    fn lookup_tag(&self, cluster: usize, tag: u32) -> &[u32] {
240        let start = self.tag_offsets[cluster] as usize;
241        let end = self.tag_offsets[cluster + 1] as usize;
242        let entries = &self.tag_index[start..end];
243        match entries.binary_search_by_key(&tag, |&(t, _, _)| t) {
244            Ok(idx) => {
245                let (_, ps, pl) = entries[idx];
246                &self.posting_ids[ps as usize..(ps + pl) as usize]
247            }
248            Err(_) => &[],
249        }
250    }
251
252    /// Scan local ids in a cluster against the query. The Hamming pre-filter
253    /// applies when the candidate count exceeds the rerank budget.
254    #[allow(clippy::too_many_arguments)]
255    fn scan_cluster(
256        &self,
257        ci: usize,
258        lids: impl ExactSizeIterator<Item = u32>,
259        query: &QueryVec,
260        q_binary: &[u64],
261        binary: &BinaryStore,
262        ef: usize,
263        binary_rerank: usize,
264        heap: &mut BinaryHeap<(u32, u32)>,
265    ) {
266        let dim = self.dim;
267        let cluster_base = self.cluster_starts[ci] as usize;
268        let rerank_budget = binary_rerank * ef;
269
270        if binary_rerank > 0 && lids.len() > rerank_budget {
271            let mut candidates: Vec<(u32, u32)> = lids
272                .map(|lid| {
273                    let gid = (cluster_base + lid as usize) as u32;
274                    (distance::hamming(q_binary, binary.code(gid)), lid)
275                })
276                .collect();
277            let budget = rerank_budget.min(candidates.len());
278            candidates.select_nth_unstable_by_key(budget - 1, |&(d, _)| d);
279            candidates.truncate(budget);
280            for &(_, lid) in &candidates {
281                let gid = (cluster_base + lid as usize) as u32;
282                let dist = compute_dist(&self.vectors, gid as usize, query, dim);
283                let orig_id = self.original_ids[gid as usize];
284                heap_insert(heap, dist, orig_id, ef);
285            }
286        } else {
287            for lid in lids {
288                let gid = (cluster_base + lid as usize) as u32;
289                let dist = compute_dist(&self.vectors, gid as usize, query, dim);
290                let orig_id = self.original_ids[gid as usize];
291                heap_insert(heap, dist, orig_id, ef);
292            }
293        }
294    }
295
296    /// MQCB: processes queries grouped by cluster for L3 cache reuse.
297    #[allow(clippy::too_many_arguments)]
298    pub fn batch_search_mqcb(
299        &self,
300        queries: &QueryStore,
301        nq: usize,
302        query_tags: &[Vec<usize>],
303        query_binary: &[Vec<u64>],
304        query_top_clusters: &[Vec<usize>],
305        binary: &BinaryStore,
306        k: usize,
307        ef: usize,
308        n_probe: usize,
309        binary_rerank: usize,
310    ) -> Vec<Vec<u32>> {
311        let dim = self.dim;
312
313        // Invert: cluster -> list of query indices.
314        let mut cluster_queries: Vec<Vec<usize>> = vec![vec![]; self.n_clusters];
315        for (qi, top_clusters) in query_top_clusters.iter().enumerate().take(nq) {
316            let np = n_probe.min(top_clusters.len());
317            for &ci in &top_clusters[..np] {
318                cluster_queries[ci].push(qi);
319            }
320        }
321
322        // Per-query heaps. Safety: each qi appears at most once per cluster
323        // and clusters are processed sequentially, so no races.
324        struct HeapArray(Vec<UnsafeCell<BinaryHeap<(u32, u32)>>>);
325        unsafe impl Sync for HeapArray {}
326        impl HeapArray {
327            #[inline]
328            #[allow(clippy::mut_from_ref)]
329            unsafe fn get(&self, idx: usize) -> &mut BinaryHeap<(u32, u32)> {
330                &mut *self.0[idx].get()
331            }
332        }
333        let heaps = HeapArray(
334            (0..nq)
335                .map(|_| UnsafeCell::new(BinaryHeap::with_capacity(ef + 1)))
336                .collect(),
337        );
338
339        // Sequential cluster iteration for prefetcher-friendly memory access
340        for (ci, qi_list) in cluster_queries.iter().enumerate() {
341            if qi_list.is_empty() {
342                continue;
343            }
344
345            qi_list.par_iter().for_each(|&qi| {
346                let query = match queries {
347                    QueryStore::U8(data) => QueryVec::U8(&data[qi * dim..(qi + 1) * dim]),
348                    QueryStore::F32(data) => QueryVec::F32(&data[qi * dim..(qi + 1) * dim]),
349                };
350                let tags = &query_tags[qi];
351                let heap = unsafe { heaps.get(qi) };
352
353                match tags.len() {
354                    // Unfiltered: every point in the cluster is a candidate.
355                    0 => {
356                        let len = self.cluster_starts[ci + 1] - self.cluster_starts[ci];
357                        self.scan_cluster(
358                            ci,
359                            0..len,
360                            &query,
361                            &query_binary[qi],
362                            binary,
363                            ef,
364                            binary_rerank,
365                            heap,
366                        );
367                    }
368                    1 => {
369                        let matching = self.lookup_tag(ci, tags[0] as u32);
370                        self.scan_cluster(
371                            ci,
372                            matching.iter().copied(),
373                            &query,
374                            &query_binary[qi],
375                            binary,
376                            ef,
377                            binary_rerank,
378                            heap,
379                        );
380                    }
381                    // Conjunctive filter: a candidate must match EVERY tag.
382                    _ => {
383                        let lists: Vec<&[u32]> = tags
384                            .iter()
385                            .map(|&t| self.lookup_tag(ci, t as u32))
386                            .collect();
387                        let matching = intersect_postings(lists);
388                        self.scan_cluster(
389                            ci,
390                            matching.iter().copied(),
391                            &query,
392                            &query_binary[qi],
393                            binary,
394                            ef,
395                            binary_rerank,
396                            heap,
397                        );
398                    }
399                }
400            });
401        }
402
403        heaps
404            .0
405            .into_par_iter()
406            .map(|cell| {
407                let heap = cell.into_inner();
408                let mut results: Vec<(u32, u32)> = heap.into_vec();
409                results.sort_unstable_by_key(|&(d, _)| d);
410                results.iter().take(k).map(|&(_, id)| id).collect()
411            })
412            .collect()
413    }
414}
415
/// Offer `(dist, id)` to a max-heap that keeps at most `cap` entries.
///
/// While the heap has room the entry is pushed. Once full, the entry replaces
/// the current maximum only if its distance is strictly smaller; the
/// replacement goes through `PeekMut`, so it costs a single sift-down.
#[inline]
fn heap_insert(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    let candidate = (dist, id);
    if heap.len() < cap {
        heap.push(candidate);
        return;
    }
    match heap.peek_mut() {
        Some(mut worst) if dist < worst.0 => *worst = candidate,
        _ => {}
    }
}
427
/// Sorted k-way intersection of posting lists, smallest list first so the
/// accumulator only shrinks.
///
/// Every list must be sorted ascending; the result is sorted ascending too.
/// An empty `lists` yields an empty result (there is no universe to return),
/// rather than panicking.
fn intersect_postings(mut lists: Vec<&[u32]>) -> Vec<u32> {
    use std::cmp::Ordering;

    lists.sort_unstable_by_key(|l| l.len());
    let (seed, rest) = match lists.split_first() {
        Some(parts) => parts,
        None => return Vec::new(),
    };

    let mut acc: Vec<u32> = seed.to_vec();
    for list in rest {
        // Once nothing is left, no further list can add anything back.
        if acc.is_empty() {
            break;
        }
        let mut out = Vec::with_capacity(acc.len().min(list.len()));
        // Two-pointer merge walk over the two sorted sequences.
        let (mut i, mut j) = (0, 0);
        while i < acc.len() && j < list.len() {
            match acc[i].cmp(&list[j]) {
                Ordering::Less => i += 1,
                Ordering::Greater => j += 1,
                Ordering::Equal => {
                    out.push(acc[i]);
                    i += 1;
                    j += 1;
                }
            }
        }
        acc = out;
    }
    acc
}
454
/// Intersection of two ascending-sorted `u16` slices, returned in ascending
/// order. Runs a single merge-style pass over both inputs.
pub fn sorted_intersect_u16(a: &[u16], b: &[u16]) -> Vec<u16> {
    use std::cmp::Ordering;

    let mut common = Vec::new();
    let (mut left, mut right) = (a, b);
    // Peel the heads off both slices until either side runs out.
    while let (Some((&x, left_tail)), Some((&y, right_tail))) =
        (left.split_first(), right.split_first())
    {
        match x.cmp(&y) {
            Ordering::Less => left = left_tail,
            Ordering::Greater => right = right_tail,
            Ordering::Equal => {
                common.push(x);
                left = left_tail;
                right = right_tail;
            }
        }
    }
    common
}
472
473/// K-means clustering. Returns (assignments, centroids as VecStore matching input type).
474pub fn kmeans(
475    base: &VecStore,
476    n: usize,
477    dim: usize,
478    c: usize,
479    iters: usize,
480) -> (Vec<u16>, VecStore) {
481    let mut rng = StdRng::seed_from_u64(42);
482    let mut centroid_ids: Vec<usize> = (0..n).collect();
483    centroid_ids.shuffle(&mut rng);
484    centroid_ids.truncate(c);
485
486    let mut centroids_f32 = vec![0.0f32; c * dim];
487    match base {
488        VecStore::U8(data) => {
489            for (ci, &vid) in centroid_ids.iter().enumerate() {
490                for d in 0..dim {
491                    centroids_f32[ci * dim + d] = data[vid * dim + d] as f32;
492                }
493            }
494        }
495        VecStore::F32(data) => {
496            for (ci, &vid) in centroid_ids.iter().enumerate() {
497                centroids_f32[ci * dim..(ci + 1) * dim]
498                    .copy_from_slice(&data[vid * dim..(vid + 1) * dim]);
499            }
500        }
501    }
502
503    let mut assignments = vec![0u16; n];
504
505    for _ in 0..iters {
506        let new_assignments: Vec<u16> = match base {
507            VecStore::U8(data) => {
508                let centroids_u8: Vec<u8> = centroids_f32
509                    .iter()
510                    .map(|&x| x.round().clamp(0.0, 255.0) as u8)
511                    .collect();
512                (0..n)
513                    .into_par_iter()
514                    .map(|i| {
515                        let v = &data[i * dim..(i + 1) * dim];
516                        let mut best_c = 0u16;
517                        let mut best_d = u32::MAX;
518                        for ci in 0..c {
519                            let cent = &centroids_u8[ci * dim..(ci + 1) * dim];
520                            let d = distance::l2_sq8(v, cent);
521                            if d < best_d {
522                                best_d = d;
523                                best_c = ci as u16;
524                            }
525                        }
526                        best_c
527                    })
528                    .collect()
529            }
530            VecStore::F32(data) => (0..n)
531                .into_par_iter()
532                .map(|i| {
533                    let v = &data[i * dim..(i + 1) * dim];
534                    let mut best_c = 0u16;
535                    let mut best_d = f32::INFINITY;
536                    for ci in 0..c {
537                        let cent = &centroids_f32[ci * dim..(ci + 1) * dim];
538                        let d = distance::l2_squared(v, cent);
539                        if d < best_d {
540                            best_d = d;
541                            best_c = ci as u16;
542                        }
543                    }
544                    best_c
545                })
546                .collect(),
547        };
548        assignments = new_assignments;
549
550        // Centroid update accumulates in f64 to avoid f32 cancellation.
551        let mut sums = vec![0.0f64; c * dim];
552        let mut counts = vec![0u32; c];
553        match base {
554            VecStore::U8(data) => {
555                for i in 0..n {
556                    let ci = assignments[i] as usize;
557                    counts[ci] += 1;
558                    for d in 0..dim {
559                        sums[ci * dim + d] += data[i * dim + d] as f64;
560                    }
561                }
562            }
563            VecStore::F32(data) => {
564                for i in 0..n {
565                    let ci = assignments[i] as usize;
566                    counts[ci] += 1;
567                    for d in 0..dim {
568                        sums[ci * dim + d] += data[i * dim + d] as f64;
569                    }
570                }
571            }
572        }
573        for ci in 0..c {
574            if counts[ci] > 0 {
575                let inv = 1.0 / counts[ci] as f64;
576                for d in 0..dim {
577                    centroids_f32[ci * dim + d] = (sums[ci * dim + d] * inv) as f32;
578                }
579            }
580        }
581
582        // FAISS-style repair: reseed each empty cluster from a random point of
583        // the most populated cluster (a frozen empty centroid rarely wins an
584        // assignment again, so the effective cluster count would only shrink).
585        for ci in 0..c {
586            if counts[ci] > 0 {
587                continue;
588            }
589            let donor = (0..c).max_by_key(|&d| counts[d]).unwrap();
590            if counts[donor] <= 1 {
591                break;
592            }
593            let members: Vec<usize> = (0..n)
594                .filter(|&i| assignments[i] as usize == donor)
595                .collect();
596            let p = members[rng.gen_range(0..members.len())];
597            match base {
598                VecStore::U8(data) => {
599                    for d in 0..dim {
600                        centroids_f32[ci * dim + d] = data[p * dim + d] as f32;
601                    }
602                }
603                VecStore::F32(data) => {
604                    centroids_f32[ci * dim..(ci + 1) * dim]
605                        .copy_from_slice(&data[p * dim..(p + 1) * dim]);
606                }
607            }
608            assignments[p] = ci as u16;
609            counts[donor] -= 1;
610            counts[ci] = 1;
611        }
612    }
613
614    let centroids = match base {
615        VecStore::U8(_) => VecStore::U8(
616            centroids_f32
617                .iter()
618                .map(|&x| x.round().clamp(0.0, 255.0) as u8)
619                .collect(),
620        ),
621        VecStore::F32(_) => VecStore::F32(centroids_f32),
622    };
623
624    (assignments, centroids)
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prism::point::PointStore;

    /// Shared fixture: six 2-d points, four tags, two hand-assigned clusters
    /// (points 0-3 in cluster 0, points 4-5 in cluster 1). Tags per point:
    /// 0:{0,1,2} 1:{0,1} 2:{0,2} 3:{1,2} 4:{0,1,2} 5:{3}.
    fn fixture() -> (IvfIndex, BinaryStore) {
        // Row-major coordinates, one point per line.
        let flat: Vec<f32> = vec![
            0.0, 0.0, //
            0.1, 0.0, //
            0.2, 0.0, //
            0.3, 0.0, //
            5.0, 5.0, //
            5.1, 5.0, //
        ];
        let tag_sets: [&[i32]; 6] = [&[0, 1, 2], &[0, 1], &[0, 2], &[1, 2], &[0, 1, 2], &[3]];
        let n = tag_sets.len();

        // CSR encoding of the tag sets.
        let mut indptr: Vec<i64> = Vec::with_capacity(n + 1);
        indptr.push(0);
        let mut indices: Vec<i32> = Vec::new();
        for tags in tag_sets.iter() {
            indices.extend_from_slice(tags);
            indptr.push(indices.len() as i64);
        }
        let meta = SpMat {
            rows: n,
            cols: 4,
            indptr,
            indices,
        };

        let assignments: Vec<u16> = vec![0, 0, 0, 0, 1, 1];
        let base = VecStore::F32(flat.clone());
        let index = IvfIndex::build(&base, &meta, &assignments, n, 2, 2);

        let store = PointStore::from_parts(flat, 2, vec![vec![0; n]]);
        let binary = BinaryStore::build(&store);
        (index, binary)
    }

    /// Runs one query through the batch API, probing both clusters with
    /// `ef = 10` and the Hamming pre-filter disabled.
    fn run_query(
        index: &IvfIndex,
        binary: &BinaryStore,
        query: &[f32],
        tags: Vec<usize>,
        k: usize,
    ) -> Vec<u32> {
        let (ef, n_probe, binary_rerank) = (10, 2, 0);
        let query_code = binary.encode_query(query);
        let per_query = index.batch_search_mqcb(
            &QueryStore::F32(query),
            1,
            &[tags],
            &[query_code],
            &[vec![0, 1]],
            binary,
            k,
            ef,
            n_probe,
            binary_rerank,
        );
        per_query.into_iter().last().unwrap()
    }

    #[test]
    fn batch_zero_tags_scans_whole_clusters() {
        let (index, binary) = fixture();
        let mut hits = run_query(&index, &binary, &[5.05, 5.0], vec![], 2);
        hits.sort_unstable();
        assert_eq!(hits, vec![4, 5]);
    }

    #[test]
    fn batch_three_tags_enforces_full_conjunction() {
        let (index, binary) = fixture();
        let mut sorted = run_query(&index, &binary, &[0.05, 0.0], vec![0, 1, 2], 4);
        sorted.sort_unstable();
        // Only points 0 and 4 carry all three tags; point 1 matches just {0,1}
        // and must not leak through a first-two-tags-only intersection.
        assert_eq!(sorted, vec![0, 4]);
    }

    #[test]
    fn kmeans_reseeds_empty_clusters() {
        // 60 identical points + 4 outliers, 8 clusters: without repair most
        // centroids never win an assignment and stay empty forever.
        let mut flat = vec![0.0f32; 60 * 2];
        for &off in &[50.0f32, -50.0, 100.0, -100.0] {
            flat.extend_from_slice(&[off, off]);
        }
        let n = flat.len() / 2;
        let (assignments, _) = kmeans(&VecStore::F32(flat), n, 2, 8, 3);

        let mut members = [0usize; 8];
        for &a in &assignments {
            members[a as usize] += 1;
        }
        assert!(
            members.iter().all(|&m| m > 0),
            "every cluster must keep at least one member, got {assignments:?}"
        );
    }
}
734}