Skip to main content

citadel_vector/vendored/prism/
search.rs

1use super::construct::PrismIndex;
2use super::distance;
3use super::filter::Filter;
4
5use rayon::prelude::*;
6use std::cmp::Reverse;
7use std::collections::BinaryHeap;
8
/// One k-NN hit produced by `PrismIndex::search` / `PrismIndex::batch_search`.
///
/// `dist` is the exact f32 distance under the index metric, recomputed from the
/// stored vector rather than the SQ8 code. Result lists are sorted ascending by
/// it, so a smaller value means a closer point.
/// NOTE(review): for `InnerProduct` this is presumably the negated dot product.
/// Confirm against `distance::distance`.
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResult {
    /// Id of the matched point in the index's point store.
    pub id: u32,
    /// Exact distance from the query to that point.
    pub dist: f32,
}
15
/// Fixed-capacity bit set used to mark visited point ids in O(1).
///
/// Holds one bit per index in `0..n`, rounded up to whole 64-bit words.
/// Indices past the last word panic on access. Callers size it either to a
/// single cell (local search) or to the whole store (bridge routing).
struct Bitset {
    bits: Vec<u64>,
}

impl Bitset {
    /// Creates an all-clear set with room for indices `0..n`.
    fn new(n: usize) -> Self {
        let words = n.div_ceil(64);
        Bitset {
            bits: vec![0; words],
        }
    }

    /// Splits index `i` into its word position and single-bit mask.
    #[inline]
    fn locate(i: u32) -> (usize, u64) {
        ((i >> 6) as usize, 1u64 << (i & 63))
    }

    /// Marks `i` and reports whether it was previously unmarked.
    #[inline]
    fn insert(&mut self, i: u32) -> bool {
        let (word_idx, mask) = Self::locate(i);
        let slot = &mut self.bits[word_idx];
        let was_clear = *slot & mask == 0;
        *slot |= mask;
        was_clear
    }

    /// Reports whether `i` is marked, without changing the set.
    #[inline]
    fn contains(&self, i: u32) -> bool {
        let (word_idx, mask) = Self::locate(i);
        self.bits[word_idx] & mask != 0
    }
}
49
/// Issues an `_MM_HINT_T0` prefetch (into L1) for the cache line holding `ptr`.
///
/// # Safety
/// The caller must be running on a CPU with SSE, as required by the
/// `target_feature` attribute. This holds on every x86_64 target.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn prefetch_t0(ptr: *const u8) {
    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
    _mm_prefetch(ptr.cast::<i8>(), _MM_HINT_T0);
}
57
58/// Software prefetch hint.
59#[inline(always)]
60fn prefetch_read(ptr: *const u8) {
61    #[cfg(target_arch = "x86_64")]
62    unsafe {
63        prefetch_t0(ptr);
64    }
65    #[cfg(not(target_arch = "x86_64"))]
66    let _ = ptr;
67}
68
69/// Prefetch `len` bytes starting at `ptr`.
70#[inline(always)]
71fn prefetch_range(ptr: *const u8, len: usize) {
72    let mut offset = 0;
73    while offset < len {
74        prefetch_read(unsafe { ptr.add(offset) });
75        offset += 64;
76    }
77}
78
/// Totally ordered `f32` wrapper, so distances can be stored in a `BinaryHeap`.
///
/// Both ordering and equality use `f32::total_cmp`, which is a lawful total
/// order: `-0.0 < +0.0`, and a positive NaN sorts above `+inf`. A NaN distance
/// therefore behaves as the worst candidate. It no longer compares `Equal` to
/// everything, which violated `Ord`'s contract and could corrupt the heap.
#[derive(Clone, Copy, Debug)]
struct OrdF32(f32);

// Equality must agree with `Ord` (required by `Eq`), so it goes through `cmp`
// rather than IEEE `==`.
impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for OrdF32 {}
impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
96
/// Offers `(dist, id)` to a max-heap capped at `cap` entries.
///
/// The heap keeps the `cap` smallest distances seen so far. Below capacity,
/// every offer is accepted. At capacity, an offer replaces the current worst
/// entry (the heap top) only if it is strictly closer, so ties are rejected.
#[inline]
fn heap_insert_sq8(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    if heap.len() < cap {
        heap.push((dist, id));
        return;
    }
    let beats_worst = matches!(heap.peek(), Some(&(worst, _)) if dist < worst);
    if beats_worst {
        heap.pop();
        heap.push((dist, id));
    }
}
109
110impl PrismIndex {
111    /// Filtered k-NN search with automatic regime selection.
112    pub fn search(&self, query: &[f32], filter: &Filter, k: usize, ef: usize) -> Vec<SearchResult> {
113        assert_eq!(query.len(), self.store.dim);
114
115        // Match the build-time Cosine normalization so code-space distances
116        // stay rank-faithful. Reported distances are unchanged (cosine is
117        // scale-invariant in both arguments).
118        let normalized;
119        let query = if self.config.metric == distance::Metric::Cosine {
120            normalized = distance::normalized(query);
121            normalized.as_slice()
122        } else {
123            query
124        };
125
126        let cell_indices = self.tree.filter_cells(filter.constraints());
127        let n_f = self.tree.count_points(&cell_indices);
128        let sigma = n_f as f32 / self.store.len as f32;
129        if sigma >= self.config.sigma_high {
130            self.regime_high_filtered(query, &cell_indices, k, ef)
131        } else if sigma > self.config.sigma_low {
132            self.regime_mid(query, &cell_indices, k, ef)
133        } else {
134            self.regime_low(query, filter, &cell_indices, k)
135        }
136    }
137
138    /// Heap-ordered candidate distance from the query to point `p`. L2 and
139    /// (build-normalized) Cosine rank by SQ8 codes; InnerProduct cannot be
140    /// ranked in code-space L2, so it ranks by the exact f32 metric mapped to
141    /// a total-order key (mirrors the construct-side `use_sq8` gate).
142    #[inline]
143    fn cand_dist(&self, query: &[f32], q_code: &[u8], p: u32) -> u32 {
144        match self.config.metric {
145            distance::Metric::L2 | distance::Metric::Cosine => {
146                distance::l2_sq8(q_code, self.sq8.code(p))
147            }
148            distance::Metric::InnerProduct => distance::ord_key(distance::distance(
149                query,
150                self.store.vector(p),
151                distance::Metric::InnerProduct,
152            )),
153        }
154    }
155
156    /// Per-cell SQ8 search: brute-force scan for small cells,
157    /// Vamana graph beam search for large cells.
158    fn regime_high_filtered(
159        &self,
160        query: &[f32],
161        cell_indices: &[usize],
162        k: usize,
163        ef: usize,
164    ) -> Vec<SearchResult> {
165        if cell_indices.is_empty() {
166            return Vec::new();
167        }
168
169        let q_code = self.sq8.quantize_query(query);
170        let q_binary = if self.config.binary_rerank > 0 {
171            self.binary.encode_query(query)
172        } else {
173            Vec::new()
174        };
175        let mut merged: BinaryHeap<(u32, u32)> = BinaryHeap::new();
176
177        if cell_indices.len() == self.tree.cells.len() {
178            // All cells match: binary pre-filter -> code-space rerank over entire index.
179            let n = self.store.len as u32;
180            let rerank_budget = self.config.binary_rerank * ef;
181            if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
182                let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
183                for p in 0..n {
184                    let hd = distance::hamming(&q_binary, self.binary.code(p));
185                    heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
186                }
187                for (_, p) in binary_heap {
188                    let dist = self.cand_dist(query, &q_code, p);
189                    heap_insert_sq8(&mut merged, dist, p, ef);
190                }
191            } else {
192                for p in 0..n {
193                    let dist = self.cand_dist(query, &q_code, p);
194                    heap_insert_sq8(&mut merged, dist, p, ef);
195                }
196            }
197        } else {
198            // Visit cells nearest-medoid-first so the ef heap tightens early.
199            let mut ranked: Vec<(usize, u32)> = cell_indices
200                .iter()
201                .map(|&ci| {
202                    let d = self.cand_dist(query, &q_code, self.medoids[ci]);
203                    (ci, d)
204                })
205                .collect();
206            ranked.sort_unstable_by_key(|&(_, d)| d);
207
208            let scan_threshold = (ef * self.config.m_local).max(2000);
209
210            for &(ci, _) in &ranked {
211                let cands = self.search_cell(query, &q_code, &q_binary, ci, ef, scan_threshold);
212                for (cand_dist, id) in cands {
213                    heap_insert_sq8(&mut merged, cand_dist, id, ef);
214                }
215            }
216        }
217
218        let mut results: Vec<SearchResult> = merged
219            .into_iter()
220            .map(|(_, id)| SearchResult {
221                id,
222                dist: distance::distance(query, self.store.vector(id), self.config.metric),
223            })
224            .collect();
225        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
226        results.truncate(k);
227        results
228    }
229
230    /// Bridge routing for medium selectivity. Traverses full graph, using
231    /// non-matching nodes as bridges when bridge score > tau. SQ8 traversal, f32 rerank.
232    fn regime_mid(
233        &self,
234        query: &[f32],
235        compatible_cells: &[usize],
236        k: usize,
237        ef: usize,
238    ) -> Vec<SearchResult> {
239        if compatible_cells.is_empty() {
240            return Vec::new();
241        }
242
243        let q_code = self.sq8.quantize_query(query);
244
245        let n_cells = self.tree.cells.len();
246        let mut cell_match = vec![false; n_cells];
247        for &ci in compatible_cells {
248            cell_match[ci] = true;
249        }
250
251        let (_, entry) = compatible_cells
252            .iter()
253            .map(|&ci| {
254                let d = self.cand_dist(query, &q_code, self.medoids[ci]);
255                (d, self.medoids[ci])
256            })
257            .min_by_key(|&(d, _)| d)
258            .unwrap();
259
260        let entry_dist = self.cand_dist(query, &q_code, entry);
261
262        let mut visited = Bitset::new(self.store.len);
263        visited.insert(entry);
264
265        let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
266        let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
267
268        candidates.push(Reverse((entry_dist, entry)));
269        results.push((entry_dist, entry));
270
271        let bridge_budget = (self.config.beta * ef as f32) as usize;
272        let mut bridges_used = 0usize;
273        let epsilon_factor = ((1.0 + self.config.epsilon) * (1.0 + self.config.epsilon)) as f64;
274
275        // Bridge threshold tau = sigma / (1 + sigma), sigma = selectivity.
276        let n_f: usize = compatible_cells
277            .iter()
278            .map(|&ci| self.tree.cells[ci].point_ids.len())
279            .sum();
280        let sigma = n_f as f32 / self.store.len as f32;
281        let tau = sigma / (1.0 + sigma);
282
283        while let Some(Reverse((d, c))) = candidates.pop() {
284            if results.len() >= ef {
285                if let Some(&(worst, _)) = results.peek() {
286                    if (d as f64) > (worst as f64) * epsilon_factor {
287                        break;
288                    }
289                }
290            }
291
292            if bridges_used >= bridge_budget {
293                break;
294            }
295
296            let neighbors = self.graph.neighbors(c);
297            let sq8_dim = self.store.dim;
298
299            let mut unvisited_buf: Vec<u32> = Vec::with_capacity(neighbors.len());
300            for &w in neighbors {
301                if visited.insert(w) {
302                    unvisited_buf.push(w);
303                    prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
304                }
305            }
306
307            for &w in &unvisited_buf {
308                let wd = self.cand_dist(query, &q_code, w);
309                let w_cell = self.point_cell[w as usize];
310
311                if cell_match[w_cell as usize] {
312                    heap_insert_sq8(&mut results, wd, w, ef);
313                    candidates.push(Reverse((wd, w)));
314                } else {
315                    let w_neighbors = self.graph.neighbors(w);
316                    if !w_neighbors.is_empty() {
317                        let matching_unvisited = w_neighbors
318                            .iter()
319                            .filter(|&&u| {
320                                cell_match[self.point_cell[u as usize] as usize]
321                                    && !visited.contains(u)
322                            })
323                            .count();
324                        let fraction = matching_unvisited as f32 / w_neighbors.len() as f32;
325
326                        // Bridge score: matching fraction x proximity.
327                        let r = results.peek().map_or(1.0f32, |&(worst, _)| worst as f32);
328                        let bridge_score = fraction / (1.0 + wd as f32 / r.max(1.0));
329
330                        if bridge_score > tau {
331                            candidates.push(Reverse((wd, w)));
332                            bridges_used += 1;
333                        }
334                    }
335                }
336            }
337        }
338
339        let mut final_results: Vec<SearchResult> = results
340            .into_iter()
341            .map(|(_, id)| SearchResult {
342                id,
343                dist: distance::distance(query, self.store.vector(id), self.config.metric),
344            })
345            .collect();
346        final_results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
347        final_results.truncate(k);
348        final_results
349    }
350
351    /// Code-space beam search within a cell's local graph. Returns (id, cand_dist) pairs.
352    fn greedy_search_cell_sq8(
353        &self,
354        query: &[f32],
355        q_code: &[u8],
356        cell_idx: usize,
357        ef: usize,
358    ) -> Vec<(u32, u32)> {
359        let pts = &self.tree.cells[cell_idx].point_ids;
360        let base = pts[0];
361        let sq8_dim = self.store.dim;
362
363        let entry = self.medoids[cell_idx];
364        let entry_dist = self.cand_dist(query, q_code, entry);
365
366        let mut visited = Bitset::new(pts.len());
367        visited.insert(entry - base);
368
369        let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
370        let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
371        let mut unvisited: Vec<u32> = Vec::with_capacity(32);
372
373        candidates.push(Reverse((entry_dist, entry)));
374        results.push((entry_dist, entry));
375
376        while let Some(Reverse((d, c))) = candidates.pop() {
377            if results.len() >= ef {
378                if let Some(&(worst, _)) = results.peek() {
379                    if d > worst {
380                        break;
381                    }
382                }
383            }
384
385            unvisited.clear();
386            for &w in self.local_graph.neighbors(c) {
387                if visited.insert(w - base) {
388                    unvisited.push(w);
389                    prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
390                }
391            }
392
393            for &w in &unvisited {
394                let wd = self.cand_dist(query, q_code, w);
395                if results.len() < ef {
396                    candidates.push(Reverse((wd, w)));
397                    results.push((wd, w));
398                } else if let Some(&(worst, _)) = results.peek() {
399                    if wd < worst {
400                        results.pop();
401                        results.push((wd, w));
402                        candidates.push(Reverse((wd, w)));
403                    }
404                }
405            }
406        }
407
408        results
409            .into_vec()
410            .into_iter()
411            .map(|(d, id)| (id, d))
412            .collect()
413    }
414
415    /// REGIME_LOW: brute-force within compatible cells for very selective filters.
416    fn regime_low(
417        &self,
418        query: &[f32],
419        filter: &Filter,
420        cell_indices: &[usize],
421        k: usize,
422    ) -> Vec<SearchResult> {
423        let mut heap: BinaryHeap<(OrdF32, u32)> = BinaryHeap::new();
424        for &ci in cell_indices {
425            for &p in &self.tree.cells[ci].point_ids {
426                if filter.matches(&self.store, p) {
427                    let dist = distance::distance(query, self.store.vector(p), self.config.metric);
428                    if heap.len() < k {
429                        heap.push((OrdF32(dist), p));
430                    } else if let Some(&(OrdF32(worst), _)) = heap.peek() {
431                        if dist < worst {
432                            heap.pop();
433                            heap.push((OrdF32(dist), p));
434                        }
435                    }
436                }
437            }
438        }
439        let mut results: Vec<SearchResult> = heap
440            .into_iter()
441            .map(|(OrdF32(d), id)| SearchResult { id, dist: d })
442            .collect();
443        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
444        results
445    }
446
447    /// MQCB (Multi-Query Cell Batching): groups queries by target cell so
448    /// cell data stays warm in L3 across queries. Cells processed in parallel,
449    /// queries within each cell sequentially.
450    pub fn batch_search(
451        &self,
452        queries: &[f32],
453        filters: &[Filter],
454        nq: usize,
455        k: usize,
456        ef: usize,
457    ) -> Vec<Vec<SearchResult>> {
458        let dim = self.store.dim;
459        let n_cells = self.tree.cells.len();
460        let scan_threshold = (ef * self.config.m_local).max(2000);
461
462        // Match the build-time Cosine normalization (see `search`).
463        let normalized;
464        let queries = if self.config.metric == distance::Metric::Cosine {
465            let mut buf = queries.to_vec();
466            distance::normalize_rows(&mut buf, dim);
467            normalized = buf;
468            normalized.as_slice()
469        } else {
470            queries
471        };
472
473        let query_info: Vec<(Vec<u8>, Vec<u64>, Vec<usize>)> = (0..nq)
474            .into_par_iter()
475            .map(|qi| {
476                let q = &queries[qi * dim..(qi + 1) * dim];
477                let q_code = self.sq8.quantize_query(q);
478                let q_binary = if self.config.binary_rerank > 0 {
479                    self.binary.encode_query(q)
480                } else {
481                    Vec::new()
482                };
483                let cells = self.tree.filter_cells(filters[qi].constraints());
484                (q_code, q_binary, cells)
485            })
486            .collect();
487
488        let mut high_regime: Vec<usize> = Vec::with_capacity(nq);
489        let mut mid_regime: Vec<usize> = Vec::new();
490        let mut low_regime: Vec<usize> = Vec::new();
491        let mut unfiltered: Vec<usize> = Vec::new();
492        for (qi, info) in query_info.iter().enumerate() {
493            let cells = &info.2;
494            if cells.len() >= n_cells {
495                unfiltered.push(qi);
496            } else {
497                let n_f: usize = cells
498                    .iter()
499                    .map(|&ci| self.tree.cells[ci].point_ids.len())
500                    .sum();
501                let sigma = n_f as f32 / self.store.len as f32;
502                if sigma >= self.config.sigma_high {
503                    high_regime.push(qi);
504                } else if sigma > self.config.sigma_low {
505                    mid_regime.push(qi);
506                } else {
507                    low_regime.push(qi);
508                }
509            }
510        }
511
512        let mut cell_queries: Vec<Vec<usize>> = vec![Vec::new(); n_cells];
513        for &qi in &high_regime {
514            for &ci in &query_info[qi].2 {
515                cell_queries[ci].push(qi);
516            }
517        }
518
519        // Cells in parallel, queries within each cell sequentially: cell data
520        // stays warm in cache across queries.
521        #[allow(clippy::type_complexity)]
522        let cell_results: Vec<Vec<(usize, Vec<(u32, u32)>)>> = cell_queries
523            .into_par_iter()
524            .enumerate()
525            .filter(|(_, qs)| !qs.is_empty())
526            .map(|(ci, qs)| {
527                qs.iter()
528                    .map(|&qi| {
529                        let q = &queries[qi * dim..(qi + 1) * dim];
530                        let q_code = &query_info[qi].0;
531                        let q_binary = &query_info[qi].1;
532                        let cands = self.search_cell(q, q_code, q_binary, ci, ef, scan_threshold);
533                        (qi, cands)
534                    })
535                    .collect()
536            })
537            .collect();
538
539        let mut query_heaps: Vec<BinaryHeap<(u32, u32)>> =
540            (0..nq).map(|_| BinaryHeap::new()).collect();
541        for cell_batch in cell_results {
542            for (qi, cands) in cell_batch {
543                for (sq8_dist, id) in cands {
544                    heap_insert_sq8(&mut query_heaps[qi], sq8_dist, id, ef);
545                }
546            }
547        }
548
549        let unfilt_heaps: Vec<(usize, BinaryHeap<(u32, u32)>)> = unfiltered
550            .par_iter()
551            .map(|&qi| {
552                let q = &queries[qi * dim..(qi + 1) * dim];
553                let q_code = &query_info[qi].0;
554                let q_binary = &query_info[qi].1;
555                let n = self.store.len as u32;
556                let rerank_budget = self.config.binary_rerank * ef;
557                let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
558                if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
559                    let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
560                    for p in 0..n {
561                        let hd = distance::hamming(q_binary, self.binary.code(p));
562                        heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
563                    }
564                    for (_, p) in binary_heap {
565                        let dist = self.cand_dist(q, q_code, p);
566                        heap_insert_sq8(&mut heap, dist, p, ef);
567                    }
568                } else {
569                    for p in 0..n {
570                        let dist = self.cand_dist(q, q_code, p);
571                        heap_insert_sq8(&mut heap, dist, p, ef);
572                    }
573                }
574                (qi, heap)
575            })
576            .collect();
577        for (qi, heap) in unfilt_heaps {
578            query_heaps[qi] = heap;
579        }
580
581        let mut all_results: Vec<Vec<SearchResult>> = query_heaps
582            .into_par_iter()
583            .enumerate()
584            .map(|(qi, heap)| {
585                if heap.is_empty() {
586                    return Vec::new();
587                }
588                let q = &queries[qi * dim..(qi + 1) * dim];
589                let mut results: Vec<SearchResult> = heap
590                    .into_iter()
591                    .map(|(_, id)| SearchResult {
592                        id,
593                        dist: distance::distance(q, self.store.vector(id), self.config.metric),
594                    })
595                    .collect();
596                results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
597                results.truncate(k);
598                results
599            })
600            .collect();
601
602        if !mid_regime.is_empty() {
603            let mid_results: Vec<(usize, Vec<SearchResult>)> = mid_regime
604                .par_iter()
605                .map(|&qi| {
606                    let q = &queries[qi * dim..(qi + 1) * dim];
607                    let cells = &query_info[qi].2;
608                    let results = self.regime_mid(q, cells, k, ef);
609                    (qi, results)
610                })
611                .collect();
612            for (qi, results) in mid_results {
613                all_results[qi] = results;
614            }
615        }
616
617        if !low_regime.is_empty() {
618            let low_results: Vec<(usize, Vec<SearchResult>)> = low_regime
619                .par_iter()
620                .map(|&qi| {
621                    let q = &queries[qi * dim..(qi + 1) * dim];
622                    let results = self.search(q, &filters[qi], k, ef);
623                    (qi, results)
624                })
625                .collect();
626            for (qi, results) in low_results {
627                all_results[qi] = results;
628            }
629        }
630
631        all_results
632    }
633
634    /// Search a single cell. Small cells: code-space scan (with optional binary
635    /// pre-filter). Large cells: graph search with adaptive ef. Returns
636    /// (cand_dist, point_id) pairs.
637    fn search_cell(
638        &self,
639        query: &[f32],
640        q_code: &[u8],
641        q_binary: &[u64],
642        cell_idx: usize,
643        ef: usize,
644        scan_threshold: usize,
645    ) -> Vec<(u32, u32)> {
646        let pts = &self.tree.cells[cell_idx].point_ids;
647        let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
648
649        if pts.len() <= scan_threshold {
650            let base = pts[0];
651            let rerank_budget = self.config.binary_rerank * ef;
652
653            if self.config.binary_rerank > 0 && pts.len() > rerank_budget {
654                let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
655                for i in 0..pts.len() {
656                    let p = base + i as u32;
657                    let hd = distance::hamming(q_binary, self.binary.code(p));
658                    heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
659                }
660                for (_, p) in binary_heap {
661                    let dist = self.cand_dist(query, q_code, p);
662                    heap_insert_sq8(&mut heap, dist, p, ef);
663                }
664            } else {
665                for i in 0..pts.len() {
666                    let p = base + i as u32;
667                    let dist = self.cand_dist(query, q_code, p);
668                    heap_insert_sq8(&mut heap, dist, p, ef);
669                }
670            }
671        } else {
672            // Scale the graph-search budget with cell size, capped at 5x ef.
673            let ef_cell = ef.max((pts.len() / 200).min(ef * 5));
674            let local = self.greedy_search_cell_sq8(query, q_code, cell_idx, ef_cell);
675            for (id, dist) in local {
676                heap_insert_sq8(&mut heap, dist, id, ef);
677            }
678        }
679
680        heap.into_vec()
681    }
682}
683
684#[cfg(test)]
685mod tests {
686    use super::super::construct::{PrismConfig, PrismIndex};
687    use super::super::distance;
688    use super::super::filter::Filter;
689    use super::super::point::PointStore;
690
691    fn build_test_index() -> PrismIndex {
692        let mut store = PointStore::new(2, 1);
693        for i in 0..10 {
694            let x = (i as f32) * 0.1;
695            let attr = if i < 5 { 0 } else { 1 };
696            store.push(&[x, x], &[attr]);
697        }
698        let config = PrismConfig {
699            m_local: 4,
700            m_greedy: 2,
701            m_random: 4,
702            t: 1,
703            alpha: 0.0,
704            beam_width: 10,
705            ..Default::default()
706        };
707        PrismIndex::build(store, config)
708    }
709
710    #[test]
711    fn test_search_no_filter() {
712        let index = build_test_index();
713        let results = index.search(&[0.25, 0.25], &Filter::none(), 3, 10);
714        assert_eq!(results.len(), 3);
715        for r in &results {
716            assert!(r.dist >= 0.0);
717        }
718    }
719
720    #[test]
721    fn test_search_with_filter() {
722        let index = build_test_index();
723        let filter = Filter::eq(0, 1);
724        let results = index.search(&[0.5, 0.5], &filter, 3, 10);
725        assert!(!results.is_empty());
726        for r in &results {
727            assert!(filter.matches(&index.store, r.id));
728        }
729    }
730
731    #[test]
732    fn test_graph_search_mid_selectivity() {
733        let dim = 16;
734        let n = 2000;
735        let n_vals = 20;
736        let mut store = PointStore::new(dim, 1);
737        for i in 0..n {
738            let vec: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32).sin()).collect();
739            store.push(&vec, &[(i % n_vals) as u32]);
740        }
741        let config = PrismConfig {
742            m_local: 4,
743            m_greedy: 2,
744            m_random: 4,
745            t: 1,
746            beam_width: 10,
747            ..Default::default()
748        };
749        let index = PrismIndex::build(store, config);
750
751        let query: Vec<f32> = (0..dim).map(|d| (d as f32 * 0.3).sin()).collect();
752        let filter = Filter::eq(0, 0);
753        let k = 5;
754        let ef = 10;
755
756        let results = index.search(&query, &filter, k, ef);
757        assert!(!results.is_empty());
758        assert!(results.len() <= k);
759        for r in &results {
760            assert!(filter.matches(&index.store, r.id));
761        }
762        for w in results.windows(2) {
763            assert!(w[0].dist <= w[1].dist);
764        }
765    }
766
767    #[test]
768    fn test_search_empty_filter() {
769        let index = build_test_index();
770        let filter = Filter::eq(0, 99);
771        let results = index.search(&[0.0, 0.0], &filter, 3, 10);
772        assert!(results.is_empty());
773    }
774
775    #[test]
776    fn test_regime_mid_bridge_routing() {
777        // 20 attribute values x 100 points/value = 2000 points; sigma_high=0.10,
778        // each value is 5% selectivity, so single-value filters route to MID.
779        let dim = 16;
780        let n = 2000;
781        let n_vals = 20;
782        let mut store = PointStore::new(dim, 1);
783        for i in 0..n {
784            let vec: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32).sin()).collect();
785            store.push(&vec, &[(i % n_vals) as u32]);
786        }
787        let config = PrismConfig {
788            m_local: 4,
789            m_greedy: 4,
790            m_random: 4,
791            t: 1,
792            beam_width: 20,
793            sigma_high: 0.10,
794            sigma_low: 0.001,
795            beta: 3.0,
796            epsilon: 0.2,
797            ..Default::default()
798        };
799        let index = PrismIndex::build(store, config);
800
801        // Value 0 matches 100 of 2000 points = 5% selectivity = MID regime.
802        let query: Vec<f32> = (0..dim).map(|d| (d as f32 * 0.3).sin()).collect();
803        let filter = Filter::eq(0, 0);
804        let k = 5;
805        let ef = 50;
806
807        let results = index.search(&query, &filter, k, ef);
808        assert!(!results.is_empty());
809        assert!(results.len() <= k);
810        for r in &results {
811            assert!(filter.matches(&index.store, r.id));
812        }
813        for w in results.windows(2) {
814            assert!(w[0].dist <= w[1].dist);
815        }
816    }
817
818    #[test]
819    fn test_batch_search_mixed_regimes() {
820        let dim = 8;
821        let n = 1000;
822        let n_vals = 10;
823        let mut store = PointStore::new(dim, 1);
824        for i in 0..n {
825            let vec: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32).sin()).collect();
826            store.push(&vec, &[(i % n_vals) as u32]);
827        }
828        let config = PrismConfig {
829            m_local: 4,
830            m_greedy: 4,
831            m_random: 4,
832            t: 1,
833            beam_width: 20,
834            sigma_high: 0.10,
835            sigma_low: 0.001,
836            ..Default::default()
837        };
838        let index = PrismIndex::build(store, config);
839
840        let k = 3;
841        let ef = 20;
842        let nq = 3;
843
844        // Query 0: unfiltered (sigma=1.0, HIGH); queries 1 and 2: single-value
845        // filters (10% selectivity, MID).
846        let queries: Vec<f32> = (0..nq)
847            .flat_map(|qi| (0..dim).map(move |d| ((qi * dim + d) as f32 * 0.5).sin()))
848            .collect();
849        let filters = vec![Filter::none(), Filter::eq(0, 0), Filter::eq(0, 5)];
850
851        let results = index.batch_search(&queries, &filters, nq, k, ef);
852        assert_eq!(results.len(), nq);
853        for (qi, res) in results.iter().enumerate() {
854            assert!(!res.is_empty(), "query {} returned no results", qi);
855            assert!(res.len() <= k);
856            for r in res {
857                assert!(filters[qi].matches(&index.store, r.id));
858            }
859        }
860    }
861
862    #[test]
863    fn inner_product_candidates_survive_l2_blind_spot() {
864        // 59 decoys hug the query in L2 with tiny dot products; one high-norm
865        // point is the true IP winner but the L2-farthest point in the set. An
866        // SQ8-L2 candidate heap (ef < n) would evict it before the rerank.
867        let mut store = PointStore::new(2, 1);
868        for i in 0..59 {
869            let j = (i as f32) * 0.001;
870            store.push(&[0.5 + j, j], &[0]);
871        }
872        store.push(&[20.0, 0.0], &[0]);
873        let config = PrismConfig {
874            m_local: 4,
875            m_greedy: 2,
876            m_random: 4,
877            t: 1,
878            beam_width: 10,
879            metric: distance::Metric::InnerProduct,
880            binary_rerank: 0,
881            ..Default::default()
882        };
883        let index = PrismIndex::build(store, config);
884
885        let results = index.search(&[1.0, 0.0], &Filter::none(), 1, 8);
886        assert_eq!(results[0].id, 59, "true IP winner must reach the rerank");
887        assert!((results[0].dist - (-20.0)).abs() < 1e-3);
888    }
889
890    #[test]
891    fn cosine_candidates_survive_unnormalized_inputs() {
892        // Unnormalized data: the best-angle point has a huge norm (L2-farthest
893        // from the raw query) and would be evicted from a raw SQ8-L2 candidate
894        // heap; build-time normalization keeps code distances angle-faithful.
895        let mut store = PointStore::new(2, 1);
896        for i in 0..59 {
897            let j = (i as f32) * 0.001;
898            store.push(&[j, 1.0 + j], &[0]);
899        }
900        store.push(&[50.0, 1.0], &[0]);
901        let config = PrismConfig {
902            m_local: 4,
903            m_greedy: 2,
904            m_random: 4,
905            t: 1,
906            beam_width: 10,
907            metric: distance::Metric::Cosine,
908            binary_rerank: 0,
909            ..Default::default()
910        };
911        let index = PrismIndex::build(store, config);
912
913        let results = index.search(&[3.0, 0.0], &Filter::none(), 1, 8);
914        assert_eq!(results[0].id, 59, "best-angle point must reach the rerank");
915        assert!(
916            results[0].dist < 0.01,
917            "dist {} is not ~1-cos",
918            results[0].dist
919        );
920    }
921
922    #[test]
923    fn test_binary_prefilter_recall() {
924        // The binary pre-filter is an approximation; results stay valid and
925        // ordered, just not necessarily identical to the pure SQ8 path.
926        let dim = 64;
927        let n = 2000;
928        let n_vals = 10;
929        let mut store = PointStore::new(dim, 1);
930        for i in 0..n {
931            let vec: Vec<f32> = (0..dim)
932                .map(|d| ((i * dim + d) as f32 * 0.01).sin())
933                .collect();
934            store.push(&vec, &[(i % n_vals) as u32]);
935        }
936
937        let config_binary = PrismConfig {
938            m_local: 4,
939            m_greedy: 2,
940            m_random: 4,
941            t: 1,
942            beam_width: 10,
943            binary_rerank: 4,
944            ..Default::default()
945        };
946        let index_binary = PrismIndex::build(store, config_binary);
947
948        let query: Vec<f32> = (0..dim).map(|d| (d as f32 * 0.3).sin()).collect();
949        let filter = Filter::eq(0, 0);
950        let k = 10;
951        let ef = 50;
952
953        let results_binary = index_binary.search(&query, &filter, k, ef);
954        assert!(!results_binary.is_empty());
955        assert!(results_binary.len() <= k);
956        for r in &results_binary {
957            assert!(filter.matches(&index_binary.store, r.id));
958        }
959        for w in results_binary.windows(2) {
960            assert!(w[0].dist <= w[1].dist);
961        }
962    }
963
964    #[test]
965    fn test_binary_prefilter_batch() {
966        let dim = 32;
967        let n = 500;
968        let n_vals = 5;
969        let mut store = PointStore::new(dim, 1);
970        for i in 0..n {
971            let vec: Vec<f32> = (0..dim)
972                .map(|d| ((i * dim + d) as f32 * 0.02).sin())
973                .collect();
974            store.push(&vec, &[(i % n_vals) as u32]);
975        }
976
977        let config = PrismConfig {
978            m_local: 4,
979            m_greedy: 2,
980            m_random: 4,
981            t: 1,
982            beam_width: 10,
983            binary_rerank: 4,
984            ..Default::default()
985        };
986        let index = PrismIndex::build(store, config);
987
988        let nq = 5;
989        let k = 5;
990        let ef = 20;
991        let queries: Vec<f32> = (0..nq)
992            .flat_map(|qi| (0..dim).map(move |d| ((qi * dim + d) as f32 * 0.1).sin()))
993            .collect();
994        let filters: Vec<Filter> = (0..nq)
995            .map(|qi| Filter::eq(0, (qi % n_vals) as u32))
996            .collect();
997
998        let results = index.batch_search(&queries, &filters, nq, k, ef);
999        assert_eq!(results.len(), nq);
1000        for (qi, res) in results.iter().enumerate() {
1001            assert!(!res.is_empty(), "query {} returned no results", qi);
1002            assert!(res.len() <= k);
1003            for r in res {
1004                assert!(filters[qi].matches(&index.store, r.id));
1005            }
1006        }
1007    }
1008}