Skip to main content

citadel_vector/vendored/prism/
search.rs

1use super::construct::PrismIndex;
2use super::distance;
3use super::filter::Filter;
4
5use rayon::prelude::*;
6use std::cmp::Reverse;
7use std::collections::BinaryHeap;
8
/// One k-NN hit: the identifier of a stored point together with its
/// distance to the query.
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResult {
    /// Identifier of the matched point.
    pub id: u32,
    /// Distance between the query and the matched point.
    pub dist: f32,
}
15
/// Fixed-size bit set stored as 64-bit words, used by the searches to
/// remember which indices have already been visited.
struct Bitset {
    bits: Vec<u64>,
}

impl Bitset {
    /// Creates a set able to hold indices `0..n`, all initially clear.
    fn new(n: usize) -> Self {
        let word_count = n.div_ceil(64);
        Self {
            bits: vec![0u64; word_count],
        }
    }

    /// Marks index `i`. Returns `true` when the index was not marked before
    /// this call, `false` when it already was.
    ///
    /// Panics if `i` lies beyond the capacity given to [`Bitset::new`].
    #[inline]
    fn insert(&mut self, i: u32) -> bool {
        let mask = 1u64 << (i % 64);
        let slot = &mut self.bits[(i / 64) as usize];
        let newly_set = (*slot & mask) == 0;
        // OR-ing an already-set bit is a no-op, so no branch is needed.
        *slot |= mask;
        newly_set
    }

    /// Reports whether index `i` is marked, leaving the set untouched.
    ///
    /// Panics if `i` lies beyond the capacity given to [`Bitset::new`].
    #[inline]
    fn contains(&self, i: u32) -> bool {
        let slot = self.bits[(i / 64) as usize];
        (slot >> (i % 64)) & 1 == 1
    }
}
49
/// Issues a prefetch with the T0 locality hint for the cache line holding `ptr`.
///
/// # Safety
///
/// The CPU must support the `sse` target feature. The pointer is never
/// dereferenced by this function; it is only passed to the prefetch intrinsic.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn prefetch_t0(ptr: *const u8) {
    std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
}

/// Software prefetch hint for a single address.
///
/// On x86-64 this forwards to [`prefetch_t0`]; on every other architecture it
/// does nothing.
#[inline(always)]
fn prefetch_read(ptr: *const u8) {
    // SAFETY: SSE is part of the x86-64 baseline, so the target-feature
    // requirement of `prefetch_t0` always holds here, and a prefetch is only
    // a hint that never dereferences `ptr`.
    #[cfg(target_arch = "x86_64")]
    unsafe {
        prefetch_t0(ptr);
    }
    #[cfg(not(target_arch = "x86_64"))]
    let _ = ptr;
}

/// Prefetches `len` bytes starting at `ptr`, issuing one hint every 64 bytes.
///
/// The addresses are computed with `wrapping_add`, so this safe function stays
/// sound even when `len` is larger than the allocation behind `ptr`: no
/// in-bounds pointer arithmetic is required and nothing is dereferenced.
#[inline(always)]
fn prefetch_range(ptr: *const u8, len: usize) {
    // Stride between consecutive prefetch hints, in bytes.
    const STRIDE: usize = 64;
    for offset in (0..len).step_by(STRIDE) {
        prefetch_read(ptr.wrapping_add(offset));
    }
}
78
/// Totally ordered `f32` wrapper so distances can be stored in a `BinaryHeap`.
///
/// Ordering and equality both follow [`f32::total_cmp`], which is a genuine
/// total order: positive NaN sorts above every other value (so in a max-heap
/// it is the first element to be evicted) and `-0.0` sorts below `+0.0`.
/// Mapping incomparable values to `Equal` instead would break the transitivity
/// that `Ord` implementations must provide.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    // Defined through `cmp` so that `==` always agrees with `Ord`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for OrdF32 {}
impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
96
/// Offers `(dist, id)` to a max-heap that retains only the `cap` smallest
/// entries.
///
/// While the heap holds fewer than `cap` entries the pair is always added.
/// Once it is full the pair replaces the current largest entry only when
/// `dist` is strictly smaller than that entry's distance; ties are rejected.
#[inline]
fn heap_insert_sq8(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    if heap.len() >= cap {
        let beats_worst = heap.peek().is_some_and(|&(worst, _)| dist < worst);
        if !beats_worst {
            return;
        }
        heap.pop();
    }
    heap.push((dist, id));
}
109
110impl PrismIndex {
111    /// Filtered k-NN search with automatic regime selection.
112    pub fn search(&self, query: &[f32], filter: &Filter, k: usize, ef: usize) -> Vec<SearchResult> {
113        assert_eq!(query.len(), self.store.dim);
114
115        let cell_indices = self.tree.filter_cells(filter.constraints());
116        let n_f = self.tree.count_points(&cell_indices);
117        let sigma = n_f as f32 / self.store.len as f32;
118        if sigma >= self.config.sigma_high {
119            self.regime_high_filtered(query, &cell_indices, k, ef)
120        } else if sigma > self.config.sigma_low {
121            self.regime_mid(query, &cell_indices, k, ef)
122        } else {
123            self.regime_low(query, filter, &cell_indices, k)
124        }
125    }
126
127    /// Per-cell SQ8 search: brute-force scan for small cells,
128    /// Vamana graph beam search for large cells.
129    fn regime_high_filtered(
130        &self,
131        query: &[f32],
132        cell_indices: &[usize],
133        k: usize,
134        ef: usize,
135    ) -> Vec<SearchResult> {
136        if cell_indices.is_empty() {
137            return Vec::new();
138        }
139
140        let q_code = self.sq8.quantize_query(query);
141        let q_binary = self.binary.encode_query(query);
142        let mut merged: BinaryHeap<(u32, u32)> = BinaryHeap::new();
143
144        if cell_indices.len() == self.tree.cells.len() {
145            // All cells match: binary pre-filter → SQ8 rerank over entire index.
146            let n = self.store.len as u32;
147            let rerank_budget = self.config.binary_rerank * ef;
148            if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
149                let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
150                for p in 0..n {
151                    let hd = distance::hamming(&q_binary, self.binary.code(p));
152                    heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
153                }
154                for (_, p) in binary_heap {
155                    let dist = distance::l2_sq8(&q_code, self.sq8.code(p));
156                    heap_insert_sq8(&mut merged, dist, p, ef);
157                }
158            } else {
159                for p in 0..n {
160                    let dist = distance::l2_sq8(&q_code, self.sq8.code(p));
161                    heap_insert_sq8(&mut merged, dist, p, ef);
162                }
163            }
164        } else {
165            // Rank cells by SQ8 medoid distance
166            let mut ranked: Vec<(usize, u32)> = cell_indices
167                .iter()
168                .map(|&ci| {
169                    let d = distance::l2_sq8(&q_code, self.sq8.code(self.medoids[ci]));
170                    (ci, d)
171                })
172                .collect();
173            ranked.sort_unstable_by_key(|&(_, d)| d);
174
175            let scan_threshold = (ef * self.config.m_local).max(2000);
176
177            for &(ci, _) in &ranked {
178                let cands = self.search_cell(&q_code, &q_binary, ci, ef, scan_threshold);
179                for (sq8_dist, id) in cands {
180                    heap_insert_sq8(&mut merged, sq8_dist, id, ef);
181                }
182            }
183        }
184
185        // F32 rerank → top-k
186        let mut results: Vec<SearchResult> = merged
187            .into_iter()
188            .map(|(_, id)| SearchResult {
189                id,
190                dist: distance::distance(query, self.store.vector(id), self.config.metric),
191            })
192            .collect();
193        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
194        results.truncate(k);
195        results
196    }
197
198    /// Bridge routing for medium selectivity. Traverses full graph, using
199    /// non-matching nodes as bridges when bridge score > τ. SQ8 traversal, f32 rerank.
200    fn regime_mid(
201        &self,
202        query: &[f32],
203        compatible_cells: &[usize],
204        k: usize,
205        ef: usize,
206    ) -> Vec<SearchResult> {
207        if compatible_cells.is_empty() {
208            return Vec::new();
209        }
210
211        let q_code = self.sq8.quantize_query(query);
212
213        // Cell compatibility lookup
214        let n_cells = self.tree.cells.len();
215        let mut cell_match = vec![false; n_cells];
216        for &ci in compatible_cells {
217            cell_match[ci] = true;
218        }
219
220        // Entry: closest compatible medoid
221        let (_, entry) = compatible_cells
222            .iter()
223            .map(|&ci| {
224                let d = distance::l2_sq8(&q_code, self.sq8.code(self.medoids[ci]));
225                (d, self.medoids[ci])
226            })
227            .min_by_key(|&(d, _)| d)
228            .unwrap();
229
230        let entry_dist = distance::l2_sq8(&q_code, self.sq8.code(entry));
231
232        let mut visited = Bitset::new(self.store.len);
233        visited.insert(entry);
234
235        let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
236        let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
237
238        candidates.push(Reverse((entry_dist, entry)));
239        results.push((entry_dist, entry));
240
241        let bridge_budget = (self.config.beta * ef as f32) as usize;
242        let mut bridges_used = 0usize;
243        let epsilon_factor = ((1.0 + self.config.epsilon) * (1.0 + self.config.epsilon)) as f64;
244
245        // Bridge threshold τ = σ/(1+σ) where σ = selectivity
246        let n_f: usize = compatible_cells
247            .iter()
248            .map(|&ci| self.tree.cells[ci].point_ids.len())
249            .sum();
250        let sigma = n_f as f32 / self.store.len as f32;
251        let tau = sigma / (1.0 + sigma);
252
253        while let Some(Reverse((d, c))) = candidates.pop() {
254            // Early termination
255            if results.len() >= ef {
256                if let Some(&(worst, _)) = results.peek() {
257                    if (d as f64) > (worst as f64) * epsilon_factor {
258                        break;
259                    }
260                }
261            }
262
263            if bridges_used >= bridge_budget {
264                break;
265            }
266
267            // Explore neighbors
268            let neighbors = self.graph.neighbors(c);
269            let sq8_dim = self.store.dim;
270
271            let mut unvisited_buf: Vec<u32> = Vec::with_capacity(neighbors.len());
272            for &w in neighbors {
273                if visited.insert(w) {
274                    unvisited_buf.push(w);
275                    prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
276                }
277            }
278
279            for &w in &unvisited_buf {
280                let wd = distance::l2_sq8(&q_code, self.sq8.code(w));
281                let w_cell = self.point_cell[w as usize];
282
283                if cell_match[w_cell as usize] {
284                    // Matching node
285                    heap_insert_sq8(&mut results, wd, w, ef);
286                    candidates.push(Reverse((wd, w)));
287                } else {
288                    // Bridge routing
289                    let w_neighbors = self.graph.neighbors(w);
290                    if !w_neighbors.is_empty() {
291                        let matching_unvisited = w_neighbors
292                            .iter()
293                            .filter(|&&u| {
294                                cell_match[self.point_cell[u as usize] as usize]
295                                    && !visited.contains(u)
296                            })
297                            .count();
298                        let fraction = matching_unvisited as f32 / w_neighbors.len() as f32;
299
300                        // Bridge score: matching fraction × proximity
301                        let r = results.peek().map_or(1.0f32, |&(worst, _)| worst as f32);
302                        let bridge_score = fraction / (1.0 + wd as f32 / r.max(1.0));
303
304                        if bridge_score > tau {
305                            candidates.push(Reverse((wd, w)));
306                            bridges_used += 1;
307                        }
308                    }
309                }
310            }
311        }
312
313        // F32 rerank → top-k
314        let mut final_results: Vec<SearchResult> = results
315            .into_iter()
316            .map(|(_, id)| SearchResult {
317                id,
318                dist: distance::distance(query, self.store.vector(id), self.config.metric),
319            })
320            .collect();
321        final_results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
322        final_results.truncate(k);
323        final_results
324    }
325
326    /// SQ8 beam search within a cell's local graph. Returns (id, sq8_dist) pairs.
327    fn greedy_search_cell_sq8(&self, q_code: &[u8], cell_idx: usize, ef: usize) -> Vec<(u32, u32)> {
328        let pts = &self.tree.cells[cell_idx].point_ids;
329        let base = pts[0];
330        let sq8_dim = self.store.dim;
331
332        let entry = self.medoids[cell_idx];
333        let entry_dist = distance::l2_sq8(q_code, self.sq8.code(entry));
334
335        let mut visited = Bitset::new(pts.len());
336        visited.insert(entry - base);
337
338        let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
339        let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
340        let mut unvisited: Vec<u32> = Vec::with_capacity(32);
341
342        candidates.push(Reverse((entry_dist, entry)));
343        results.push((entry_dist, entry));
344
345        while let Some(Reverse((d, c))) = candidates.pop() {
346            if results.len() >= ef {
347                if let Some(&(worst, _)) = results.peek() {
348                    if d > worst {
349                        break;
350                    }
351                }
352            }
353
354            unvisited.clear();
355            for &w in self.local_graph.neighbors(c) {
356                if visited.insert(w - base) {
357                    unvisited.push(w);
358                    prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
359                }
360            }
361
362            for &w in &unvisited {
363                let wd = distance::l2_sq8(q_code, self.sq8.code(w));
364                if results.len() < ef {
365                    candidates.push(Reverse((wd, w)));
366                    results.push((wd, w));
367                } else if let Some(&(worst, _)) = results.peek() {
368                    if wd < worst {
369                        results.pop();
370                        results.push((wd, w));
371                        candidates.push(Reverse((wd, w)));
372                    }
373                }
374            }
375        }
376
377        results
378            .into_vec()
379            .into_iter()
380            .map(|(d, id)| (id, d))
381            .collect()
382    }
383
384    /// REGIME_LOW — brute-force within compatible cells for very selective filters.
385    fn regime_low(
386        &self,
387        query: &[f32],
388        filter: &Filter,
389        cell_indices: &[usize],
390        k: usize,
391    ) -> Vec<SearchResult> {
392        let mut heap: BinaryHeap<(OrdF32, u32)> = BinaryHeap::new();
393        for &ci in cell_indices {
394            for &p in &self.tree.cells[ci].point_ids {
395                if filter.matches(&self.store, p) {
396                    let dist = distance::distance(query, self.store.vector(p), self.config.metric);
397                    if heap.len() < k {
398                        heap.push((OrdF32(dist), p));
399                    } else if let Some(&(OrdF32(worst), _)) = heap.peek() {
400                        if dist < worst {
401                            heap.pop();
402                            heap.push((OrdF32(dist), p));
403                        }
404                    }
405                }
406            }
407        }
408        let mut results: Vec<SearchResult> = heap
409            .into_iter()
410            .map(|(OrdF32(d), id)| SearchResult { id, dist: d })
411            .collect();
412        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
413        results
414    }
415
416    /// MQCB (Multi-Query Cell Batching) — groups queries by target cell so
417    /// cell data stays warm in L3 across queries. Cells processed in parallel,
418    /// queries within each cell sequentially.
419    pub fn batch_search(
420        &self,
421        queries: &[f32],
422        filters: &[Filter],
423        nq: usize,
424        k: usize,
425        ef: usize,
426    ) -> Vec<Vec<SearchResult>> {
427        let dim = self.store.dim;
428        let n_cells = self.tree.cells.len();
429        let scan_threshold = (ef * self.config.m_local).max(2000);
430
431        // Precompute query codes and matching cells
432        let query_info: Vec<(Vec<u8>, Vec<u64>, Vec<usize>)> = (0..nq)
433            .into_par_iter()
434            .map(|qi| {
435                let q = &queries[qi * dim..(qi + 1) * dim];
436                let q_code = self.sq8.quantize_query(q);
437                let q_binary = self.binary.encode_query(q);
438                let cells = self.tree.filter_cells(filters[qi].constraints());
439                (q_code, q_binary, cells)
440            })
441            .collect();
442
443        // Classify queries by regime
444        let mut high_regime: Vec<usize> = Vec::with_capacity(nq);
445        let mut mid_regime: Vec<usize> = Vec::new();
446        let mut low_regime: Vec<usize> = Vec::new();
447        let mut unfiltered: Vec<usize> = Vec::new();
448        for (qi, info) in query_info.iter().enumerate() {
449            let cells = &info.2;
450            if cells.len() >= n_cells {
451                unfiltered.push(qi);
452            } else {
453                let n_f: usize = cells
454                    .iter()
455                    .map(|&ci| self.tree.cells[ci].point_ids.len())
456                    .sum();
457                let sigma = n_f as f32 / self.store.len as f32;
458                if sigma >= self.config.sigma_high {
459                    high_regime.push(qi);
460                } else if sigma > self.config.sigma_low {
461                    mid_regime.push(qi);
462                } else {
463                    low_regime.push(qi);
464                }
465            }
466        }
467
468        // Group HIGH-regime queries by cell for MQCB
469        let mut cell_queries: Vec<Vec<usize>> = vec![Vec::new(); n_cells];
470        for &qi in &high_regime {
471            for &ci in &query_info[qi].2 {
472                cell_queries[ci].push(qi);
473            }
474        }
475
476        // MQCB: cells in parallel, queries within each cell sequentially (cache warmth)
477        #[allow(clippy::type_complexity)]
478        let cell_results: Vec<Vec<(usize, Vec<(u32, u32)>)>> = cell_queries
479            .into_par_iter()
480            .enumerate()
481            .filter(|(_, qs)| !qs.is_empty())
482            .map(|(ci, qs)| {
483                qs.iter()
484                    .map(|&qi| {
485                        let q_code = &query_info[qi].0;
486                        let q_binary = &query_info[qi].1;
487                        let cands = self.search_cell(q_code, q_binary, ci, ef, scan_threshold);
488                        (qi, cands)
489                    })
490                    .collect()
491            })
492            .collect();
493
494        // Merge cell results into per-query SQ8 heaps
495        let mut query_heaps: Vec<BinaryHeap<(u32, u32)>> =
496            (0..nq).map(|_| BinaryHeap::new()).collect();
497        for cell_batch in cell_results {
498            for (qi, cands) in cell_batch {
499                for (sq8_dist, id) in cands {
500                    heap_insert_sq8(&mut query_heaps[qi], sq8_dist, id, ef);
501                }
502            }
503        }
504
505        // Unfiltered queries: binary pre-filter → SQ8 rerank
506        let unfilt_heaps: Vec<(usize, BinaryHeap<(u32, u32)>)> = unfiltered
507            .par_iter()
508            .map(|&qi| {
509                let q_code = &query_info[qi].0;
510                let q_binary = &query_info[qi].1;
511                let n = self.store.len as u32;
512                let rerank_budget = self.config.binary_rerank * ef;
513                let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
514                if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
515                    let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
516                    for p in 0..n {
517                        let hd = distance::hamming(q_binary, self.binary.code(p));
518                        heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
519                    }
520                    for (_, p) in binary_heap {
521                        let dist = distance::l2_sq8(q_code, self.sq8.code(p));
522                        heap_insert_sq8(&mut heap, dist, p, ef);
523                    }
524                } else {
525                    for p in 0..n {
526                        let dist = distance::l2_sq8(q_code, self.sq8.code(p));
527                        heap_insert_sq8(&mut heap, dist, p, ef);
528                    }
529                }
530                (qi, heap)
531            })
532            .collect();
533        for (qi, heap) in unfilt_heaps {
534            query_heaps[qi] = heap;
535        }
536
537        // F32 rerank
538        let mut all_results: Vec<Vec<SearchResult>> = query_heaps
539            .into_par_iter()
540            .enumerate()
541            .map(|(qi, heap)| {
542                if heap.is_empty() {
543                    return Vec::new();
544                }
545                let q = &queries[qi * dim..(qi + 1) * dim];
546                let mut results: Vec<SearchResult> = heap
547                    .into_iter()
548                    .map(|(_, id)| SearchResult {
549                        id,
550                        dist: distance::distance(q, self.store.vector(id), self.config.metric),
551                    })
552                    .collect();
553                results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
554                results.truncate(k);
555                results
556            })
557            .collect();
558
559        // MID-regime queries (bridge routing)
560        if !mid_regime.is_empty() {
561            let mid_results: Vec<(usize, Vec<SearchResult>)> = mid_regime
562                .par_iter()
563                .map(|&qi| {
564                    let q = &queries[qi * dim..(qi + 1) * dim];
565                    let cells = &query_info[qi].2;
566                    let results = self.regime_mid(q, cells, k, ef);
567                    (qi, results)
568                })
569                .collect();
570            for (qi, results) in mid_results {
571                all_results[qi] = results;
572            }
573        }
574
575        // LOW-regime queries (brute-force)
576        if !low_regime.is_empty() {
577            let low_results: Vec<(usize, Vec<SearchResult>)> = low_regime
578                .par_iter()
579                .map(|&qi| {
580                    let q = &queries[qi * dim..(qi + 1) * dim];
581                    let results = self.search(q, &filters[qi], k, ef);
582                    (qi, results)
583                })
584                .collect();
585            for (qi, results) in low_results {
586                all_results[qi] = results;
587            }
588        }
589
590        all_results
591    }
592
593    /// Search a single cell. Small cells: SQ8 scan (with optional binary pre-filter).
594    /// Large cells: graph search with adaptive ef. Returns (sq8_dist, point_id) pairs.
595    fn search_cell(
596        &self,
597        q_code: &[u8],
598        q_binary: &[u64],
599        cell_idx: usize,
600        ef: usize,
601        scan_threshold: usize,
602    ) -> Vec<(u32, u32)> {
603        let pts = &self.tree.cells[cell_idx].point_ids;
604        let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
605
606        if pts.len() <= scan_threshold {
607            let base = pts[0];
608            let rerank_budget = self.config.binary_rerank * ef;
609
610            if self.config.binary_rerank > 0 && pts.len() > rerank_budget {
611                // Binary Hamming pre-filter → SQ8 rerank
612                let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
613                for i in 0..pts.len() {
614                    let p = base + i as u32;
615                    let hd = distance::hamming(q_binary, self.binary.code(p));
616                    heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
617                }
618                for (_, p) in binary_heap {
619                    let dist = distance::l2_sq8(q_code, self.sq8.code(p));
620                    heap_insert_sq8(&mut heap, dist, p, ef);
621                }
622            } else {
623                // Pure SQ8 scan (small cell or binary pre-filter disabled)
624                for i in 0..pts.len() {
625                    let p = base + i as u32;
626                    let dist = distance::l2_sq8(q_code, self.sq8.code(p));
627                    heap_insert_sq8(&mut heap, dist, p, ef);
628                }
629            }
630        } else {
631            // Adaptive ef: scale graph search budget for large cells
632            let ef_cell = ef.max((pts.len() / 200).min(ef * 5));
633            let local = self.greedy_search_cell_sq8(q_code, cell_idx, ef_cell);
634            for (id, dist) in local {
635                heap_insert_sq8(&mut heap, dist, id, ef);
636            }
637        }
638
639        heap.into_vec()
640    }
641}
642
#[cfg(test)]
mod tests {
    use super::super::construct::{PrismConfig, PrismIndex};
    use super::super::filter::Filter;
    use super::super::point::PointStore;

    /// Ten 2-D points on the diagonal; the first five carry attribute 0, the
    /// last five attribute 1.
    fn build_test_index() -> PrismIndex {
        let mut store = PointStore::new(2, 1);
        for i in 0..10 {
            let coord = (i as f32) * 0.1;
            let attr = if i < 5 { 0 } else { 1 };
            store.push(&[coord, coord], &[attr]);
        }
        PrismIndex::build(
            store,
            PrismConfig {
                m_local: 4,
                m_greedy: 2,
                m_random: 4,
                t: 1,
                alpha: 0.0,
                beam_width: 10,
                ..Default::default()
            },
        )
    }

    /// Store of `n` points whose component `d` is `sin((i * dim + d) * freq)`
    /// and whose single attribute cycles through `0..n_vals`.
    fn sine_store(dim: usize, n: usize, n_vals: usize, freq: f32) -> PointStore {
        let mut store = PointStore::new(dim, 1);
        for i in 0..n {
            let point: Vec<f32> = (0..dim)
                .map(|d| ((i * dim + d) as f32 * freq).sin())
                .collect();
            store.push(&point, &[(i % n_vals) as u32]);
        }
        store
    }

    /// Query vector used by the single-query tests.
    fn probe_query(dim: usize) -> Vec<f32> {
        (0..dim).map(|d| (d as f32 * 0.3).sin()).collect()
    }

    /// `nq` query vectors laid out back to back.
    fn batch_queries(nq: usize, dim: usize, freq: f32) -> Vec<f32> {
        (0..nq * dim).map(|j| (j as f32 * freq).sin()).collect()
    }

    /// Every returned point must satisfy `filter`.
    fn assert_all_match(index: &PrismIndex, filter: &Filter, results: &[super::SearchResult]) {
        for hit in results {
            assert!(filter.matches(&index.store, hit.id));
        }
    }

    /// Non-empty, at most `k` hits, all matching `filter`, ascending distance.
    fn assert_filtered_top_k(
        index: &PrismIndex,
        filter: &Filter,
        results: &[super::SearchResult],
        k: usize,
    ) {
        assert!(!results.is_empty());
        assert!(results.len() <= k);
        assert_all_match(index, filter, results);
        for pair in results.windows(2) {
            assert!(pair[0].dist <= pair[1].dist);
        }
    }

    /// One non-empty list of at most `k` matching hits per query.
    fn assert_batch(
        index: &PrismIndex,
        filters: &[Filter],
        results: &[Vec<super::SearchResult>],
        nq: usize,
        k: usize,
    ) {
        assert_eq!(results.len(), nq);
        for (qi, res) in results.iter().enumerate() {
            assert!(!res.is_empty(), "query {} returned no results", qi);
            assert!(res.len() <= k);
            assert_all_match(index, &filters[qi], res);
        }
    }

    #[test]
    fn test_search_no_filter() {
        let index = build_test_index();
        let results = index.search(&[0.25, 0.25], &Filter::none(), 3, 10);
        assert_eq!(results.len(), 3);
        for hit in &results {
            assert!(hit.dist >= 0.0);
        }
    }

    #[test]
    fn test_search_with_filter() {
        let index = build_test_index();
        let filter = Filter::eq(0, 1);
        let results = index.search(&[0.5, 0.5], &filter, 3, 10);
        assert!(!results.is_empty());
        assert_all_match(&index, &filter, &results);
    }

    #[test]
    fn test_graph_search_mid_selectivity() {
        // 2000 points, 20 attribute values; thresholds are the defaults.
        let dim = 16;
        let index = PrismIndex::build(
            sine_store(dim, 2000, 20, 1.0),
            PrismConfig {
                m_local: 4,
                m_greedy: 2,
                m_random: 4,
                t: 1,
                beam_width: 10,
                ..Default::default()
            },
        );

        let filter = Filter::eq(0, 0);
        let k = 5;
        let results = index.search(&probe_query(dim), &filter, k, 10);
        assert_filtered_top_k(&index, &filter, &results, k);
    }

    #[test]
    fn test_search_empty_filter() {
        let index = build_test_index();
        // No point carries attribute value 99.
        let results = index.search(&[0.0, 0.0], &Filter::eq(0, 99), 3, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_regime_mid_bridge_routing() {
        // 20 attribute values over 2000 points: one value selects 5% of the
        // store, which lies between sigma_low = 0.001 and sigma_high = 0.10.
        let dim = 16;
        let index = PrismIndex::build(
            sine_store(dim, 2000, 20, 1.0),
            PrismConfig {
                m_local: 4,
                m_greedy: 4,
                m_random: 4,
                t: 1,
                beam_width: 20,
                sigma_high: 0.10,
                sigma_low: 0.001,
                beta: 3.0,
                epsilon: 0.2,
                ..Default::default()
            },
        );

        let filter = Filter::eq(0, 0);
        let k = 5;
        let results = index.search(&probe_query(dim), &filter, k, 50);
        assert_filtered_top_k(&index, &filter, &results, k);
    }

    #[test]
    fn test_batch_search_mixed_regimes() {
        // One unfiltered query plus two single-value filters. With 10 values
        // over 1000 points each filter selects 10% of the store, i.e. exactly
        // the configured sigma_high boundary.
        let dim = 8;
        let index = PrismIndex::build(
            sine_store(dim, 1000, 10, 1.0),
            PrismConfig {
                m_local: 4,
                m_greedy: 4,
                m_random: 4,
                t: 1,
                beam_width: 20,
                sigma_high: 0.10,
                sigma_low: 0.001,
                ..Default::default()
            },
        );

        let (nq, k, ef) = (3, 3, 20);
        let queries = batch_queries(nq, dim, 0.5);
        let filters = vec![Filter::none(), Filter::eq(0, 0), Filter::eq(0, 5)];

        let results = index.batch_search(&queries, &filters, nq, k, ef);
        assert_batch(&index, &filters, &results, nq, k);
    }

    #[test]
    fn test_binary_prefilter_recall() {
        // Filtered search on an index built with binary_rerank = 4; checks
        // that the results are valid, filtered and sorted.
        let dim = 64;
        let index_binary = PrismIndex::build(
            sine_store(dim, 2000, 10, 0.01),
            PrismConfig {
                m_local: 4,
                m_greedy: 2,
                m_random: 4,
                t: 1,
                beam_width: 10,
                binary_rerank: 4,
                ..Default::default()
            },
        );

        let filter = Filter::eq(0, 0);
        let k = 10;
        let results_binary = index_binary.search(&probe_query(dim), &filter, k, 50);
        assert_filtered_top_k(&index_binary, &filter, &results_binary, k);
    }

    #[test]
    fn test_binary_prefilter_batch() {
        // batch_search on an index built with binary_rerank = 4.
        let dim = 32;
        let n_vals = 5;
        let index = PrismIndex::build(
            sine_store(dim, 500, n_vals, 0.02),
            PrismConfig {
                m_local: 4,
                m_greedy: 2,
                m_random: 4,
                t: 1,
                beam_width: 10,
                binary_rerank: 4,
                ..Default::default()
            },
        );

        let (nq, k, ef) = (5, 5, 20);
        let queries = batch_queries(nq, dim, 0.1);
        let filters: Vec<Filter> = (0..nq)
            .map(|qi| Filter::eq(0, (qi % n_vals) as u32))
            .collect();

        let results = index.batch_search(&queries, &filters, nq, k, ef);
        assert_batch(&index, &filters, &results, nq, k);
    }
}
911}