Skip to main content

citadel_vector/vendored/prism/
ivf.rs

//! IVF² (geometric clusters × tag posting lists) + MQCB batch search.
//!
//! Two-level inverted index: k-means clusters provide geometric proximity, and
//! per-cluster tag posting lists provide attribute filtering. Each vector is
//! stored once (no duplication). Within a cluster, vectors are sorted by tag
//! affinity so that posting-list scans access memory sequentially.
7
8use super::binary::BinaryStore;
9use super::distance;
10
11use rand::prelude::*;
12use rayon::prelude::*;
13use std::cell::UnsafeCell;
14use std::collections::BinaryHeap;
15
/// CSR sparse matrix (same layout as `scipy.sparse.csr_matrix`).
///
/// Row `r`'s column indices are `indices[indptr[r] as usize..indptr[r + 1] as usize]`,
/// so `indptr` holds `rows + 1` offsets. Only the sparsity pattern is stored (no
/// values). In this module each row is one vector and its column indices are tag ids.
#[derive(Debug, Clone)]
pub struct SpMat {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row offsets into `indices` (`rows + 1` entries).
    pub indptr: Vec<i64>,
    /// Column indices of the non-zero entries, row by row.
    pub indices: Vec<i32>,
}
23
/// Type-erased flat vector storage (u8 or f32).
///
/// Vectors are stored back to back: vector `i` of dimension `dim` occupies
/// elements `i * dim..(i + 1) * dim`.
#[derive(Debug, Clone)]
pub enum VecStore {
    /// Byte-quantized components.
    U8(Vec<u8>),
    /// Single-precision float components.
    F32(Vec<f32>),
}
29
/// Borrowed query batch (flat, nq × dim).
///
/// Query `qi` occupies elements `qi * dim..(qi + 1) * dim`. The store only
/// borrows its data, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub enum QueryStore<'a> {
    /// Byte-quantized queries.
    U8(&'a [u8]),
    /// Single-precision float queries.
    F32(&'a [f32]),
}
35
/// Single query vector slice (one `dim`-length row of a [`QueryStore`]).
#[derive(Debug, Clone, Copy)]
enum QueryVec<'a> {
    U8(&'a [u8]),
    F32(&'a [f32]),
}
41
42/// Distance suitable for heap ordering. For u8: raw u32 from l2_sq8.
43/// For f32: f32::to_bits() (monotonic for non-negative IEEE 754 floats).
44#[inline]
45fn compute_dist(store: &VecStore, gid: usize, query: &QueryVec, dim: usize) -> u32 {
46    match (store, query) {
47        (VecStore::U8(v), QueryVec::U8(q)) => distance::l2_sq8(q, &v[gid * dim..(gid + 1) * dim]),
48        (VecStore::F32(v), QueryVec::F32(q)) => {
49            distance::l2_squared(q, &v[gid * dim..(gid + 1) * dim]).to_bits()
50        }
51        _ => unreachable!("mismatched vector/query types"),
52    }
53}
54
55/// IVF² index: geometric clusters × per-cluster tag posting lists.
56pub struct IvfIndex {
57    /// Reordered vectors (contiguous per cluster).
58    pub vectors: VecStore,
59    /// Mapping: reordered_id → original_id.
60    pub original_ids: Vec<u32>,
61    /// Cluster boundaries: cluster c spans [cluster_starts[c]..cluster_starts[c+1]).
62    pub cluster_starts: Vec<u32>,
63    /// Per-cluster tag index offsets.
64    tag_offsets: Vec<u32>,
65    /// (tag_id, posting_start, posting_len) triples, sorted by tag_id within each cluster.
66    tag_index: Vec<(u32, u32, u32)>,
67    /// Flat array of local IDs for all (cluster, tag) posting lists.
68    posting_ids: Vec<u32>,
69    /// Per-tag list of clusters containing matching vectors.
70    pub tag_clusters: Vec<Vec<u16>>,
71    /// Vector dimensionality.
72    pub dim: usize,
73    /// Number of clusters.
74    pub n_clusters: usize,
75}
76
77impl IvfIndex {
78    /// Build IVF² index from clustered vectors and metadata.
79    ///
80    /// Reorders vectors by cluster, sorts within each cluster by most popular
81    /// tag (tag-affinity sort), and builds per-cluster tag posting lists.
82    pub fn build(
83        base: &VecStore,
84        base_meta: &SpMat,
85        assignments: &[u16],
86        n: usize,
87        dim: usize,
88        n_clusters: usize,
89    ) -> Self {
90        // Compute cluster sizes and start offsets
91        let mut cluster_sizes = vec![0u32; n_clusters];
92        for &a in assignments {
93            cluster_sizes[a as usize] += 1;
94        }
95        let mut cluster_starts = vec![0u32; n_clusters + 1];
96        for i in 0..n_clusters {
97            cluster_starts[i + 1] = cluster_starts[i] + cluster_sizes[i];
98        }
99
100        // Build reordering: new_order[new_id] = old_id
101        let mut position = cluster_starts[..n_clusters].to_vec();
102        let mut new_order = vec![0u32; n];
103        for (i, &ci_raw) in assignments.iter().enumerate().take(n) {
104            let ci = ci_raw as usize;
105            let new_id = position[ci] as usize;
106            new_order[new_id] = i as u32;
107            position[ci] += 1;
108        }
109
110        // Reorder vectors by cluster (first pass)
111        macro_rules! reorder_and_sort {
112            ($base_data:expr, $zero:expr, $T:ty) => {{
113                let mut vecs = vec![$zero; n * dim];
114                for (new_id, &old_id) in new_order.iter().enumerate() {
115                    let src = &$base_data[old_id as usize * dim..(old_id as usize + 1) * dim];
116                    vecs[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
117                }
118
119                // Tag-affinity sort within each cluster
120                let mut tag_freq = vec![0u32; base_meta.cols + 1];
121                for &tag in &base_meta.indices {
122                    tag_freq[tag as usize] += 1;
123                }
124                for ci in 0..n_clusters {
125                    let cs = cluster_starts[ci] as usize;
126                    let ce = cluster_starts[ci + 1] as usize;
127                    if ce - cs <= 1 {
128                        continue;
129                    }
130
131                    let mut sort_keys: Vec<(u32, usize)> = (0..ce - cs)
132                        .map(|local| {
133                            let old_id = new_order[cs + local] as usize;
134                            let ms = base_meta.indptr[old_id] as usize;
135                            let me = base_meta.indptr[old_id + 1] as usize;
136                            let tag = base_meta.indices[ms..me]
137                                .iter()
138                                .max_by_key(|&&t| tag_freq[t as usize])
139                                .map(|&t| t as u32)
140                                .unwrap_or(u32::MAX);
141                            (tag, local)
142                        })
143                        .collect();
144                    sort_keys.sort_unstable_by_key(|&(tag, _)| tag);
145
146                    let old_vecs: Vec<$T> = vecs[cs * dim..ce * dim].to_vec();
147                    let old_ids: Vec<u32> = new_order[cs..ce].to_vec();
148                    for (new_local, &(_, old_local)) in sort_keys.iter().enumerate() {
149                        vecs[(cs + new_local) * dim..(cs + new_local + 1) * dim]
150                            .copy_from_slice(&old_vecs[old_local * dim..(old_local + 1) * dim]);
151                        new_order[cs + new_local] = old_ids[old_local];
152                    }
153                }
154                vecs
155            }};
156        }
157
158        let vectors = match base {
159            VecStore::U8(data) => VecStore::U8(reorder_and_sort!(data, 0u8, u8)),
160            VecStore::F32(data) => VecStore::F32(reorder_and_sort!(data, 0.0f32, f32)),
161        };
162
163        // Build old_to_new mapping (after intra-cluster sort)
164        let mut old_to_new = vec![0u32; n];
165        for (new_id, &old_id) in new_order.iter().enumerate() {
166            old_to_new[old_id as usize] = new_id as u32;
167        }
168
169        // Build per-cluster tag index using HashMap, then flatten
170        let mut all_tag_entries: Vec<Vec<(u32, u32, u32)>> = Vec::with_capacity(n_clusters);
171        let mut all_posting_ids: Vec<u32> = Vec::new();
172
173        let mut cluster_maps: Vec<std::collections::HashMap<u32, Vec<u32>>> = (0..n_clusters)
174            .map(|_| std::collections::HashMap::new())
175            .collect();
176
177        for old_id in 0..n {
178            let new_id = old_to_new[old_id] as usize;
179            let ci = assignments[old_id] as usize;
180            let local_id = new_id - cluster_starts[ci] as usize;
181
182            let start = base_meta.indptr[old_id] as usize;
183            let end = base_meta.indptr[old_id + 1] as usize;
184            for &tag in &base_meta.indices[start..end] {
185                cluster_maps[ci]
186                    .entry(tag as u32)
187                    .or_default()
188                    .push(local_id as u32);
189            }
190        }
191
192        // Flatten to sorted arrays
193        for cluster_map in cluster_maps.iter_mut().take(n_clusters) {
194            let mut entries: Vec<(u32, Vec<u32>)> = cluster_map.drain().collect();
195            entries.sort_unstable_by_key(|&(tag, _)| tag);
196
197            let mut cluster_entries = Vec::with_capacity(entries.len());
198            for (tag, mut ids) in entries {
199                ids.sort_unstable();
200                let posting_start = all_posting_ids.len() as u32;
201                let posting_len = ids.len() as u32;
202                all_posting_ids.extend_from_slice(&ids);
203                cluster_entries.push((tag, posting_start, posting_len));
204            }
205            all_tag_entries.push(cluster_entries);
206        }
207
208        // Build flat tag_offsets + tag_index
209        let mut tag_offsets = Vec::with_capacity(n_clusters + 1);
210        let mut tag_index = Vec::new();
211        let mut offset = 0u32;
212        for entries in &all_tag_entries {
213            tag_offsets.push(offset);
214            tag_index.extend_from_slice(entries);
215            offset += entries.len() as u32;
216        }
217        tag_offsets.push(offset);
218
219        let total_posting = all_posting_ids.len();
220        let total_entries = tag_index.len();
221        eprintln!(
222            "  IVF: {n_clusters} clusters, {total_entries} tag entries, {total_posting} posting IDs"
223        );
224
225        // Build per-tag cluster lists (for filtered cluster selection)
226        let max_tag = tag_index.iter().map(|&(t, _, _)| t).max().unwrap_or(0) as usize;
227        let mut tag_clusters: Vec<Vec<u16>> = vec![vec![]; max_tag + 1];
228        for ci in 0..n_clusters {
229            let start = tag_offsets[ci] as usize;
230            let end = tag_offsets[ci + 1] as usize;
231            for &(tag, _, _) in &tag_index[start..end] {
232                tag_clusters[tag as usize].push(ci as u16);
233            }
234        }
235
236        Self {
237            vectors,
238            original_ids: new_order,
239            cluster_starts,
240            tag_offsets,
241            tag_index,
242            posting_ids: all_posting_ids,
243            tag_clusters,
244            dim,
245            n_clusters,
246        }
247    }
248
249    /// Look up local IDs matching a tag within a cluster.
250    #[inline]
251    fn lookup_tag(&self, cluster: usize, tag: u32) -> &[u32] {
252        let start = self.tag_offsets[cluster] as usize;
253        let end = self.tag_offsets[cluster + 1] as usize;
254        let entries = &self.tag_index[start..end];
255        match entries.binary_search_by_key(&tag, |&(t, _, _)| t) {
256            Ok(idx) => {
257                let (_, ps, pl) = entries[idx];
258                &self.posting_ids[ps as usize..(ps + pl) as usize]
259            }
260            Err(_) => &[],
261        }
262    }
263
264    /// Scan matching vectors in a cluster against the query.
265    #[allow(clippy::too_many_arguments)]
266    fn scan_cluster(
267        &self,
268        ci: usize,
269        matching: &[u32],
270        query: &QueryVec,
271        q_binary: &[u64],
272        binary: &BinaryStore,
273        ef: usize,
274        binary_rerank: usize,
275        heap: &mut BinaryHeap<(u32, u32)>,
276    ) {
277        if matching.is_empty() {
278            return;
279        }
280        let dim = self.dim;
281        let cluster_base = self.cluster_starts[ci] as usize;
282        let rerank_budget = binary_rerank * ef;
283
284        if binary_rerank > 0 && matching.len() > rerank_budget {
285            let mut candidates: Vec<(u32, u32)> = matching
286                .iter()
287                .map(|&lid| {
288                    let gid = (cluster_base + lid as usize) as u32;
289                    (distance::hamming(q_binary, binary.code(gid)), lid)
290                })
291                .collect();
292            let budget = rerank_budget.min(candidates.len());
293            candidates.select_nth_unstable_by_key(budget - 1, |&(d, _)| d);
294            candidates.truncate(budget);
295            for &(_, lid) in &candidates {
296                let gid = (cluster_base + lid as usize) as u32;
297                let dist = compute_dist(&self.vectors, gid as usize, query, dim);
298                let orig_id = self.original_ids[gid as usize];
299                heap_insert(heap, dist, orig_id, ef);
300            }
301        } else {
302            for &lid in matching {
303                let gid = (cluster_base + lid as usize) as u32;
304                let dist = compute_dist(&self.vectors, gid as usize, query, dim);
305                let orig_id = self.original_ids[gid as usize];
306                heap_insert(heap, dist, orig_id, ef);
307            }
308        }
309    }
310
311    /// Intersect two sorted tag lists and scan matches.
312    #[allow(clippy::too_many_arguments)]
313    fn scan_cluster_intersect(
314        &self,
315        ci: usize,
316        list_a: &[u32],
317        list_b: &[u32],
318        query: &QueryVec,
319        q_binary: &[u64],
320        binary: &BinaryStore,
321        ef: usize,
322        binary_rerank: usize,
323        heap: &mut BinaryHeap<(u32, u32)>,
324    ) {
325        let dim = self.dim;
326        let cluster_base = self.cluster_starts[ci] as usize;
327        let rerank_budget = binary_rerank * ef;
328
329        let est = list_a.len().min(list_b.len());
330
331        if binary_rerank > 0 && est > rerank_budget {
332            let mut candidates: Vec<(u32, u32)> = Vec::new();
333            let (mut i, mut j) = (0, 0);
334            while i < list_a.len() && j < list_b.len() {
335                let a = list_a[i];
336                let b = list_b[j];
337                if a < b {
338                    i += 1;
339                } else if a > b {
340                    j += 1;
341                } else {
342                    let gid = (cluster_base + a as usize) as u32;
343                    let hd = distance::hamming(q_binary, binary.code(gid));
344                    candidates.push((hd, gid));
345                    i += 1;
346                    j += 1;
347                }
348            }
349            if candidates.len() > rerank_budget {
350                candidates.select_nth_unstable_by_key(rerank_budget - 1, |&(d, _)| d);
351                candidates.truncate(rerank_budget);
352            }
353            for &(_, gid) in &candidates {
354                let dist = compute_dist(&self.vectors, gid as usize, query, dim);
355                let orig_id = self.original_ids[gid as usize];
356                heap_insert(heap, dist, orig_id, ef);
357            }
358        } else {
359            let (mut i, mut j) = (0, 0);
360            while i < list_a.len() && j < list_b.len() {
361                let a = list_a[i];
362                let b = list_b[j];
363                if a < b {
364                    i += 1;
365                } else if a > b {
366                    j += 1;
367                } else {
368                    let gid = (cluster_base + a as usize) as u32;
369                    let dist = compute_dist(&self.vectors, gid as usize, query, dim);
370                    let orig_id = self.original_ids[gid as usize];
371                    heap_insert(heap, dist, orig_id, ef);
372                    i += 1;
373                    j += 1;
374                }
375            }
376        }
377    }
378
379    /// MQCB: processes queries grouped by cluster for L3 cache reuse.
380    #[allow(clippy::too_many_arguments)]
381    pub fn batch_search_mqcb(
382        &self,
383        queries: &QueryStore,
384        nq: usize,
385        query_tags: &[Vec<usize>],
386        query_binary: &[Vec<u64>],
387        query_top_clusters: &[Vec<usize>],
388        binary: &BinaryStore,
389        k: usize,
390        ef: usize,
391        n_probe: usize,
392        binary_rerank: usize,
393    ) -> Vec<Vec<u32>> {
394        let dim = self.dim;
395
396        // Invert: cluster → list of query indices
397        let mut cluster_queries: Vec<Vec<usize>> = vec![vec![]; self.n_clusters];
398        for (qi, top_clusters) in query_top_clusters.iter().enumerate().take(nq) {
399            let np = n_probe.min(top_clusters.len());
400            for &ci in &top_clusters[..np] {
401                cluster_queries[ci].push(qi);
402            }
403        }
404
405        // Per-query heaps. Safety: each qi appears at most once per cluster,
406        // clusters processed sequentially → no races.
407        struct HeapArray(Vec<UnsafeCell<BinaryHeap<(u32, u32)>>>);
408        unsafe impl Sync for HeapArray {}
409        impl HeapArray {
410            #[inline]
411            #[allow(clippy::mut_from_ref)]
412            unsafe fn get(&self, idx: usize) -> &mut BinaryHeap<(u32, u32)> {
413                &mut *self.0[idx].get()
414            }
415        }
416        let heaps = HeapArray(
417            (0..nq)
418                .map(|_| UnsafeCell::new(BinaryHeap::with_capacity(ef + 1)))
419                .collect(),
420        );
421
422        // Sequential cluster iteration for prefetcher-friendly memory access
423        for (ci, qi_list) in cluster_queries.iter().enumerate() {
424            if qi_list.is_empty() {
425                continue;
426            }
427
428            qi_list.par_iter().for_each(|&qi| {
429                let query = match queries {
430                    QueryStore::U8(data) => QueryVec::U8(&data[qi * dim..(qi + 1) * dim]),
431                    QueryStore::F32(data) => QueryVec::F32(&data[qi * dim..(qi + 1) * dim]),
432                };
433                let tags = &query_tags[qi];
434                let heap = unsafe { heaps.get(qi) };
435
436                if tags.len() == 1 {
437                    let matching = self.lookup_tag(ci, tags[0] as u32);
438                    self.scan_cluster(
439                        ci,
440                        matching,
441                        &query,
442                        &query_binary[qi],
443                        binary,
444                        ef,
445                        binary_rerank,
446                        heap,
447                    );
448                } else {
449                    let list_a = self.lookup_tag(ci, tags[0] as u32);
450                    let list_b = self.lookup_tag(ci, tags[1] as u32);
451                    self.scan_cluster_intersect(
452                        ci,
453                        list_a,
454                        list_b,
455                        &query,
456                        &query_binary[qi],
457                        binary,
458                        ef,
459                        binary_rerank,
460                        heap,
461                    );
462                }
463            });
464        }
465
466        // Extract top-k results
467        heaps
468            .0
469            .into_par_iter()
470            .map(|cell| {
471                let heap = cell.into_inner();
472                let mut results: Vec<(u32, u32)> = heap.into_vec();
473                results.sort_unstable_by_key(|&(d, _)| d);
474                results.iter().take(k).map(|&(_, id)| id).collect()
475            })
476            .collect()
477    }
478}
479
/// Offers `(dist, id)` to a max-heap that holds at most `cap` entries.
///
/// Below capacity the pair is pushed. At capacity it replaces the current
/// worst entry only when `dist` is strictly smaller, using a single sift-down
/// through `PeekMut`. With `cap == 0` nothing is ever stored.
#[inline]
fn heap_insert(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    if heap.len() >= cap {
        if let Some(mut worst) = heap.peek_mut() {
            if dist < worst.0 {
                *worst = (dist, id);
            }
        }
        return;
    }
    heap.push((dist, id));
}
491
/// Values common to two ascending `u16` slices, in ascending order.
///
/// A value is emitted once per matched pair, so duplicates appear as many times
/// as they occur in both inputs.
pub fn sorted_intersect_u16(a: &[u16], b: &[u16]) -> Vec<u16> {
    let mut common = Vec::new();
    let (mut left, mut right) = (a, b);
    while let (Some(&x), Some(&y)) = (left.first(), right.first()) {
        match x.cmp(&y) {
            std::cmp::Ordering::Less => left = &left[1..],
            std::cmp::Ordering::Greater => right = &right[1..],
            std::cmp::Ordering::Equal => {
                common.push(x);
                left = &left[1..];
                right = &right[1..];
            }
        }
    }
    common
}
509
510/// K-means clustering. Returns (assignments, centroids as VecStore matching input type).
511pub fn kmeans(
512    base: &VecStore,
513    n: usize,
514    dim: usize,
515    c: usize,
516    iters: usize,
517) -> (Vec<u16>, VecStore) {
518    let mut rng = StdRng::seed_from_u64(42);
519    let mut centroid_ids: Vec<usize> = (0..n).collect();
520    centroid_ids.shuffle(&mut rng);
521    centroid_ids.truncate(c);
522
523    let mut centroids_f32 = vec![0.0f32; c * dim];
524    match base {
525        VecStore::U8(data) => {
526            for (ci, &vid) in centroid_ids.iter().enumerate() {
527                for d in 0..dim {
528                    centroids_f32[ci * dim + d] = data[vid * dim + d] as f32;
529                }
530            }
531        }
532        VecStore::F32(data) => {
533            for (ci, &vid) in centroid_ids.iter().enumerate() {
534                centroids_f32[ci * dim..(ci + 1) * dim]
535                    .copy_from_slice(&data[vid * dim..(vid + 1) * dim]);
536            }
537        }
538    }
539
540    let mut assignments = vec![0u16; n];
541
542    for iter in 0..iters {
543        let t0 = std::time::Instant::now();
544
545        // Assignment step
546        let new_assignments: Vec<u16> = match base {
547            VecStore::U8(data) => {
548                let centroids_u8: Vec<u8> = centroids_f32
549                    .iter()
550                    .map(|&x| x.round().clamp(0.0, 255.0) as u8)
551                    .collect();
552                (0..n)
553                    .into_par_iter()
554                    .map(|i| {
555                        let v = &data[i * dim..(i + 1) * dim];
556                        let mut best_c = 0u16;
557                        let mut best_d = u32::MAX;
558                        for ci in 0..c {
559                            let cent = &centroids_u8[ci * dim..(ci + 1) * dim];
560                            let d = distance::l2_sq8(v, cent);
561                            if d < best_d {
562                                best_d = d;
563                                best_c = ci as u16;
564                            }
565                        }
566                        best_c
567                    })
568                    .collect()
569            }
570            VecStore::F32(data) => (0..n)
571                .into_par_iter()
572                .map(|i| {
573                    let v = &data[i * dim..(i + 1) * dim];
574                    let mut best_c = 0u16;
575                    let mut best_d = f32::INFINITY;
576                    for ci in 0..c {
577                        let cent = &centroids_f32[ci * dim..(ci + 1) * dim];
578                        let d = distance::l2_squared(v, cent);
579                        if d < best_d {
580                            best_d = d;
581                            best_c = ci as u16;
582                        }
583                    }
584                    best_c
585                })
586                .collect(),
587        };
588        assignments = new_assignments;
589
590        // Update step: accumulate in f64
591        let mut sums = vec![0.0f64; c * dim];
592        let mut counts = vec![0u32; c];
593        match base {
594            VecStore::U8(data) => {
595                for i in 0..n {
596                    let ci = assignments[i] as usize;
597                    counts[ci] += 1;
598                    for d in 0..dim {
599                        sums[ci * dim + d] += data[i * dim + d] as f64;
600                    }
601                }
602            }
603            VecStore::F32(data) => {
604                for i in 0..n {
605                    let ci = assignments[i] as usize;
606                    counts[ci] += 1;
607                    for d in 0..dim {
608                        sums[ci * dim + d] += data[i * dim + d] as f64;
609                    }
610                }
611            }
612        }
613        for ci in 0..c {
614            if counts[ci] > 0 {
615                let inv = 1.0 / counts[ci] as f64;
616                for d in 0..dim {
617                    centroids_f32[ci * dim + d] = (sums[ci * dim + d] * inv) as f32;
618                }
619            }
620        }
621
622        let min_s = counts.iter().min().unwrap();
623        let max_s = counts.iter().max().unwrap();
624        let empty = counts.iter().filter(|&&c| c == 0).count();
625        eprintln!(
626            "  iter {}/{}: min={min_s}, max={max_s}, empty={empty} ({:.1}s)",
627            iter + 1,
628            iters,
629            t0.elapsed().as_secs_f64()
630        );
631    }
632
633    let centroids = match base {
634        VecStore::U8(_) => VecStore::U8(
635            centroids_f32
636                .iter()
637                .map(|&x| x.round().clamp(0.0, 255.0) as u8)
638                .collect(),
639        ),
640        VecStore::F32(_) => VecStore::F32(centroids_f32),
641    };
642
643    (assignments, centroids)
644}