1use super::construct::PrismIndex;
2use super::distance;
3use super::filter::Filter;
4
5use rayon::prelude::*;
6use std::cmp::Reverse;
7use std::collections::BinaryHeap;
8
/// One hit produced by the search entry points in this module.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the matched point (the value the search code passes to `store.vector`).
    pub id: u32,
    /// Distance between the query and this point, as computed by `distance::distance`
    /// under the index's configured metric. Result lists are sorted ascending on this field.
    pub dist: f32,
}
15
/// Fixed-capacity set of `u32` indices, stored one bit per index in 64-bit words.
struct Bitset {
    bits: Vec<u64>,
}

impl Bitset {
    /// Creates an empty set able to hold indices `0..n` (capacity is rounded up to a
    /// whole number of 64-bit words).
    fn new(n: usize) -> Self {
        let words = n.div_ceil(64);
        Self {
            bits: vec![0; words],
        }
    }

    /// Splits an index into the position of its word and the single-bit mask inside it.
    #[inline]
    fn locate(i: u32) -> (usize, u64) {
        ((i >> 6) as usize, 1u64 << (i & 63))
    }

    /// Adds `i` to the set and returns `true` only if it was not present before.
    ///
    /// Panics if `i` lies beyond the capacity chosen in [`Bitset::new`].
    #[inline]
    fn insert(&mut self, i: u32) -> bool {
        let (word, mask) = Self::locate(i);
        let slot = &mut self.bits[word];
        let newly_added = (*slot & mask) == 0;
        // OR-ing an already-set bit is a no-op, so no branch is needed.
        *slot |= mask;
        newly_added
    }

    /// Returns whether `i` is currently in the set.
    ///
    /// Panics if `i` lies beyond the capacity chosen in [`Bitset::new`].
    #[inline]
    fn contains(&self, i: u32) -> bool {
        let (word, mask) = Self::locate(i);
        (self.bits[word] & mask) != 0
    }
}
49
/// Issues an `_MM_HINT_T0` prefetch hint for the address `ptr` (x86_64 only).
///
/// # Safety
///
/// The function is compiled with the `sse` target feature enabled, so it must only be
/// called on CPUs that support SSE.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn prefetch_t0(ptr: *const u8) {
    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

    _mm_prefetch::<_MM_HINT_T0>(ptr.cast::<i8>());
}
57
58#[inline(always)]
60fn prefetch_read(ptr: *const u8) {
61 #[cfg(target_arch = "x86_64")]
62 unsafe {
63 prefetch_t0(ptr);
64 }
65 #[cfg(not(target_arch = "x86_64"))]
66 let _ = ptr;
67}
68
69#[inline(always)]
71fn prefetch_range(ptr: *const u8, len: usize) {
72 let mut offset = 0;
73 while offset < len {
74 prefetch_read(unsafe { ptr.add(offset) });
75 offset += 64;
76 }
77}
78
/// `f32` wrapper with a total order, so distances can be stored in a `BinaryHeap`.
///
/// Ordinary numbers compare exactly as `f32::partial_cmp` orders them (so `0.0` and
/// `-0.0` are equal). Every NaN, whatever its sign bit, compares greater than every
/// number and equal to every other NaN. In a max-heap of distances a NaN is therefore the
/// "worst" entry and is the first to be evicted.
///
/// `PartialEq` is defined through `Ord::cmp`, so `a == b` holds exactly when
/// `a.cmp(&b) == Equal`, as the `Eq`/`Ord` contracts require.
#[derive(Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `partial_cmp` is `None` only when at least one side is NaN; in that case order
        // by "is NaN" (false < true) so NaN sorts last and NaN vs NaN is `Equal`.
        self.0
            .partial_cmp(&other.0)
            .unwrap_or_else(|| self.0.is_nan().cmp(&other.0.is_nan()))
    }
}
96
/// Offers `(dist, id)` to a max-heap that keeps at most `cap` entries with the smallest
/// `dist` values.
///
/// While the heap has fewer than `cap` entries the pair is always stored. Once full, it
/// replaces the current worst entry only when `dist` is strictly smaller than that
/// entry's distance. With `cap == 0` nothing is ever stored.
#[inline]
fn heap_insert_sq8(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    if heap.len() >= cap {
        match heap.peek() {
            // Strictly better than the current worst: make room for the newcomer.
            Some(&(worst, _)) if dist < worst => {
                heap.pop();
            }
            // Not an improvement (or nothing to evict): drop the offer.
            _ => return,
        }
    }
    heap.push((dist, id));
}
109
110impl PrismIndex {
111 pub fn search(&self, query: &[f32], filter: &Filter, k: usize, ef: usize) -> Vec<SearchResult> {
113 assert_eq!(query.len(), self.store.dim);
114
115 let normalized;
119 let query = if self.config.metric == distance::Metric::Cosine {
120 normalized = distance::normalized(query);
121 normalized.as_slice()
122 } else {
123 query
124 };
125
126 let cell_indices = self.tree.filter_cells(filter.constraints());
127 let n_f = self.tree.count_points(&cell_indices);
128 let sigma = n_f as f32 / self.store.len as f32;
129 if sigma >= self.config.sigma_high {
130 self.regime_high_filtered(query, &cell_indices, k, ef)
131 } else if sigma > self.config.sigma_low {
132 self.regime_mid(query, &cell_indices, k, ef)
133 } else {
134 self.regime_low(query, filter, &cell_indices, k)
135 }
136 }
137
138 #[inline]
143 fn cand_dist(&self, query: &[f32], q_code: &[u8], p: u32) -> u32 {
144 match self.config.metric {
145 distance::Metric::L2 | distance::Metric::Cosine => {
146 distance::l2_sq8(q_code, self.sq8.code(p))
147 }
148 distance::Metric::InnerProduct => distance::ord_key(distance::distance(
149 query,
150 self.store.vector(p),
151 distance::Metric::InnerProduct,
152 )),
153 }
154 }
155
156 fn regime_high_filtered(
159 &self,
160 query: &[f32],
161 cell_indices: &[usize],
162 k: usize,
163 ef: usize,
164 ) -> Vec<SearchResult> {
165 if cell_indices.is_empty() {
166 return Vec::new();
167 }
168
169 let q_code = self.sq8.quantize_query(query);
170 let q_binary = if self.config.binary_rerank > 0 {
171 self.binary.encode_query(query)
172 } else {
173 Vec::new()
174 };
175 let mut merged: BinaryHeap<(u32, u32)> = BinaryHeap::new();
176
177 if cell_indices.len() == self.tree.cells.len() {
178 let n = self.store.len as u32;
180 let rerank_budget = self.config.binary_rerank * ef;
181 if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
182 let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
183 for p in 0..n {
184 let hd = distance::hamming(&q_binary, self.binary.code(p));
185 heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
186 }
187 for (_, p) in binary_heap {
188 let dist = self.cand_dist(query, &q_code, p);
189 heap_insert_sq8(&mut merged, dist, p, ef);
190 }
191 } else {
192 for p in 0..n {
193 let dist = self.cand_dist(query, &q_code, p);
194 heap_insert_sq8(&mut merged, dist, p, ef);
195 }
196 }
197 } else {
198 let mut ranked: Vec<(usize, u32)> = cell_indices
200 .iter()
201 .map(|&ci| {
202 let d = self.cand_dist(query, &q_code, self.medoids[ci]);
203 (ci, d)
204 })
205 .collect();
206 ranked.sort_unstable_by_key(|&(_, d)| d);
207
208 let scan_threshold = (ef * self.config.m_local).max(2000);
209
210 for &(ci, _) in &ranked {
211 let cands = self.search_cell(query, &q_code, &q_binary, ci, ef, scan_threshold);
212 for (cand_dist, id) in cands {
213 heap_insert_sq8(&mut merged, cand_dist, id, ef);
214 }
215 }
216 }
217
218 let mut results: Vec<SearchResult> = merged
219 .into_iter()
220 .map(|(_, id)| SearchResult {
221 id,
222 dist: distance::distance(query, self.store.vector(id), self.config.metric),
223 })
224 .collect();
225 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
226 results.truncate(k);
227 results
228 }
229
230 fn regime_mid(
233 &self,
234 query: &[f32],
235 compatible_cells: &[usize],
236 k: usize,
237 ef: usize,
238 ) -> Vec<SearchResult> {
239 if compatible_cells.is_empty() {
240 return Vec::new();
241 }
242
243 let q_code = self.sq8.quantize_query(query);
244
245 let n_cells = self.tree.cells.len();
246 let mut cell_match = vec![false; n_cells];
247 for &ci in compatible_cells {
248 cell_match[ci] = true;
249 }
250
251 let (_, entry) = compatible_cells
252 .iter()
253 .map(|&ci| {
254 let d = self.cand_dist(query, &q_code, self.medoids[ci]);
255 (d, self.medoids[ci])
256 })
257 .min_by_key(|&(d, _)| d)
258 .unwrap();
259
260 let entry_dist = self.cand_dist(query, &q_code, entry);
261
262 let mut visited = Bitset::new(self.store.len);
263 visited.insert(entry);
264
265 let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
266 let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
267
268 candidates.push(Reverse((entry_dist, entry)));
269 results.push((entry_dist, entry));
270
271 let bridge_budget = (self.config.beta * ef as f32) as usize;
272 let mut bridges_used = 0usize;
273 let epsilon_factor = ((1.0 + self.config.epsilon) * (1.0 + self.config.epsilon)) as f64;
274
275 let n_f: usize = compatible_cells
277 .iter()
278 .map(|&ci| self.tree.cells[ci].point_ids.len())
279 .sum();
280 let sigma = n_f as f32 / self.store.len as f32;
281 let tau = sigma / (1.0 + sigma);
282
283 while let Some(Reverse((d, c))) = candidates.pop() {
284 if results.len() >= ef {
285 if let Some(&(worst, _)) = results.peek() {
286 if (d as f64) > (worst as f64) * epsilon_factor {
287 break;
288 }
289 }
290 }
291
292 if bridges_used >= bridge_budget {
293 break;
294 }
295
296 let neighbors = self.graph.neighbors(c);
297 let sq8_dim = self.store.dim;
298
299 let mut unvisited_buf: Vec<u32> = Vec::with_capacity(neighbors.len());
300 for &w in neighbors {
301 if visited.insert(w) {
302 unvisited_buf.push(w);
303 prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
304 }
305 }
306
307 for &w in &unvisited_buf {
308 let wd = self.cand_dist(query, &q_code, w);
309 let w_cell = self.point_cell[w as usize];
310
311 if cell_match[w_cell as usize] {
312 heap_insert_sq8(&mut results, wd, w, ef);
313 candidates.push(Reverse((wd, w)));
314 } else {
315 let w_neighbors = self.graph.neighbors(w);
316 if !w_neighbors.is_empty() {
317 let matching_unvisited = w_neighbors
318 .iter()
319 .filter(|&&u| {
320 cell_match[self.point_cell[u as usize] as usize]
321 && !visited.contains(u)
322 })
323 .count();
324 let fraction = matching_unvisited as f32 / w_neighbors.len() as f32;
325
326 let r = results.peek().map_or(1.0f32, |&(worst, _)| worst as f32);
328 let bridge_score = fraction / (1.0 + wd as f32 / r.max(1.0));
329
330 if bridge_score > tau {
331 candidates.push(Reverse((wd, w)));
332 bridges_used += 1;
333 }
334 }
335 }
336 }
337 }
338
339 let mut final_results: Vec<SearchResult> = results
340 .into_iter()
341 .map(|(_, id)| SearchResult {
342 id,
343 dist: distance::distance(query, self.store.vector(id), self.config.metric),
344 })
345 .collect();
346 final_results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
347 final_results.truncate(k);
348 final_results
349 }
350
351 fn greedy_search_cell_sq8(
353 &self,
354 query: &[f32],
355 q_code: &[u8],
356 cell_idx: usize,
357 ef: usize,
358 ) -> Vec<(u32, u32)> {
359 let pts = &self.tree.cells[cell_idx].point_ids;
360 let base = pts[0];
361 let sq8_dim = self.store.dim;
362
363 let entry = self.medoids[cell_idx];
364 let entry_dist = self.cand_dist(query, q_code, entry);
365
366 let mut visited = Bitset::new(pts.len());
367 visited.insert(entry - base);
368
369 let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
370 let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
371 let mut unvisited: Vec<u32> = Vec::with_capacity(32);
372
373 candidates.push(Reverse((entry_dist, entry)));
374 results.push((entry_dist, entry));
375
376 while let Some(Reverse((d, c))) = candidates.pop() {
377 if results.len() >= ef {
378 if let Some(&(worst, _)) = results.peek() {
379 if d > worst {
380 break;
381 }
382 }
383 }
384
385 unvisited.clear();
386 for &w in self.local_graph.neighbors(c) {
387 if visited.insert(w - base) {
388 unvisited.push(w);
389 prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
390 }
391 }
392
393 for &w in &unvisited {
394 let wd = self.cand_dist(query, q_code, w);
395 if results.len() < ef {
396 candidates.push(Reverse((wd, w)));
397 results.push((wd, w));
398 } else if let Some(&(worst, _)) = results.peek() {
399 if wd < worst {
400 results.pop();
401 results.push((wd, w));
402 candidates.push(Reverse((wd, w)));
403 }
404 }
405 }
406 }
407
408 results
409 .into_vec()
410 .into_iter()
411 .map(|(d, id)| (id, d))
412 .collect()
413 }
414
415 fn regime_low(
417 &self,
418 query: &[f32],
419 filter: &Filter,
420 cell_indices: &[usize],
421 k: usize,
422 ) -> Vec<SearchResult> {
423 let mut heap: BinaryHeap<(OrdF32, u32)> = BinaryHeap::new();
424 for &ci in cell_indices {
425 for &p in &self.tree.cells[ci].point_ids {
426 if filter.matches(&self.store, p) {
427 let dist = distance::distance(query, self.store.vector(p), self.config.metric);
428 if heap.len() < k {
429 heap.push((OrdF32(dist), p));
430 } else if let Some(&(OrdF32(worst), _)) = heap.peek() {
431 if dist < worst {
432 heap.pop();
433 heap.push((OrdF32(dist), p));
434 }
435 }
436 }
437 }
438 }
439 let mut results: Vec<SearchResult> = heap
440 .into_iter()
441 .map(|(OrdF32(d), id)| SearchResult { id, dist: d })
442 .collect();
443 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
444 results
445 }
446
447 pub fn batch_search(
451 &self,
452 queries: &[f32],
453 filters: &[Filter],
454 nq: usize,
455 k: usize,
456 ef: usize,
457 ) -> Vec<Vec<SearchResult>> {
458 let dim = self.store.dim;
459 let n_cells = self.tree.cells.len();
460 let scan_threshold = (ef * self.config.m_local).max(2000);
461
462 let normalized;
464 let queries = if self.config.metric == distance::Metric::Cosine {
465 let mut buf = queries.to_vec();
466 distance::normalize_rows(&mut buf, dim);
467 normalized = buf;
468 normalized.as_slice()
469 } else {
470 queries
471 };
472
473 let query_info: Vec<(Vec<u8>, Vec<u64>, Vec<usize>)> = (0..nq)
474 .into_par_iter()
475 .map(|qi| {
476 let q = &queries[qi * dim..(qi + 1) * dim];
477 let q_code = self.sq8.quantize_query(q);
478 let q_binary = if self.config.binary_rerank > 0 {
479 self.binary.encode_query(q)
480 } else {
481 Vec::new()
482 };
483 let cells = self.tree.filter_cells(filters[qi].constraints());
484 (q_code, q_binary, cells)
485 })
486 .collect();
487
488 let mut high_regime: Vec<usize> = Vec::with_capacity(nq);
489 let mut mid_regime: Vec<usize> = Vec::new();
490 let mut low_regime: Vec<usize> = Vec::new();
491 let mut unfiltered: Vec<usize> = Vec::new();
492 for (qi, info) in query_info.iter().enumerate() {
493 let cells = &info.2;
494 if cells.len() >= n_cells {
495 unfiltered.push(qi);
496 } else {
497 let n_f: usize = cells
498 .iter()
499 .map(|&ci| self.tree.cells[ci].point_ids.len())
500 .sum();
501 let sigma = n_f as f32 / self.store.len as f32;
502 if sigma >= self.config.sigma_high {
503 high_regime.push(qi);
504 } else if sigma > self.config.sigma_low {
505 mid_regime.push(qi);
506 } else {
507 low_regime.push(qi);
508 }
509 }
510 }
511
512 let mut cell_queries: Vec<Vec<usize>> = vec![Vec::new(); n_cells];
513 for &qi in &high_regime {
514 for &ci in &query_info[qi].2 {
515 cell_queries[ci].push(qi);
516 }
517 }
518
519 #[allow(clippy::type_complexity)]
522 let cell_results: Vec<Vec<(usize, Vec<(u32, u32)>)>> = cell_queries
523 .into_par_iter()
524 .enumerate()
525 .filter(|(_, qs)| !qs.is_empty())
526 .map(|(ci, qs)| {
527 qs.iter()
528 .map(|&qi| {
529 let q = &queries[qi * dim..(qi + 1) * dim];
530 let q_code = &query_info[qi].0;
531 let q_binary = &query_info[qi].1;
532 let cands = self.search_cell(q, q_code, q_binary, ci, ef, scan_threshold);
533 (qi, cands)
534 })
535 .collect()
536 })
537 .collect();
538
539 let mut query_heaps: Vec<BinaryHeap<(u32, u32)>> =
540 (0..nq).map(|_| BinaryHeap::new()).collect();
541 for cell_batch in cell_results {
542 for (qi, cands) in cell_batch {
543 for (sq8_dist, id) in cands {
544 heap_insert_sq8(&mut query_heaps[qi], sq8_dist, id, ef);
545 }
546 }
547 }
548
549 let unfilt_heaps: Vec<(usize, BinaryHeap<(u32, u32)>)> = unfiltered
550 .par_iter()
551 .map(|&qi| {
552 let q = &queries[qi * dim..(qi + 1) * dim];
553 let q_code = &query_info[qi].0;
554 let q_binary = &query_info[qi].1;
555 let n = self.store.len as u32;
556 let rerank_budget = self.config.binary_rerank * ef;
557 let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
558 if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
559 let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
560 for p in 0..n {
561 let hd = distance::hamming(q_binary, self.binary.code(p));
562 heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
563 }
564 for (_, p) in binary_heap {
565 let dist = self.cand_dist(q, q_code, p);
566 heap_insert_sq8(&mut heap, dist, p, ef);
567 }
568 } else {
569 for p in 0..n {
570 let dist = self.cand_dist(q, q_code, p);
571 heap_insert_sq8(&mut heap, dist, p, ef);
572 }
573 }
574 (qi, heap)
575 })
576 .collect();
577 for (qi, heap) in unfilt_heaps {
578 query_heaps[qi] = heap;
579 }
580
581 let mut all_results: Vec<Vec<SearchResult>> = query_heaps
582 .into_par_iter()
583 .enumerate()
584 .map(|(qi, heap)| {
585 if heap.is_empty() {
586 return Vec::new();
587 }
588 let q = &queries[qi * dim..(qi + 1) * dim];
589 let mut results: Vec<SearchResult> = heap
590 .into_iter()
591 .map(|(_, id)| SearchResult {
592 id,
593 dist: distance::distance(q, self.store.vector(id), self.config.metric),
594 })
595 .collect();
596 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
597 results.truncate(k);
598 results
599 })
600 .collect();
601
602 if !mid_regime.is_empty() {
603 let mid_results: Vec<(usize, Vec<SearchResult>)> = mid_regime
604 .par_iter()
605 .map(|&qi| {
606 let q = &queries[qi * dim..(qi + 1) * dim];
607 let cells = &query_info[qi].2;
608 let results = self.regime_mid(q, cells, k, ef);
609 (qi, results)
610 })
611 .collect();
612 for (qi, results) in mid_results {
613 all_results[qi] = results;
614 }
615 }
616
617 if !low_regime.is_empty() {
618 let low_results: Vec<(usize, Vec<SearchResult>)> = low_regime
619 .par_iter()
620 .map(|&qi| {
621 let q = &queries[qi * dim..(qi + 1) * dim];
622 let results = self.search(q, &filters[qi], k, ef);
623 (qi, results)
624 })
625 .collect();
626 for (qi, results) in low_results {
627 all_results[qi] = results;
628 }
629 }
630
631 all_results
632 }
633
634 fn search_cell(
638 &self,
639 query: &[f32],
640 q_code: &[u8],
641 q_binary: &[u64],
642 cell_idx: usize,
643 ef: usize,
644 scan_threshold: usize,
645 ) -> Vec<(u32, u32)> {
646 let pts = &self.tree.cells[cell_idx].point_ids;
647 let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
648
649 if pts.len() <= scan_threshold {
650 let base = pts[0];
651 let rerank_budget = self.config.binary_rerank * ef;
652
653 if self.config.binary_rerank > 0 && pts.len() > rerank_budget {
654 let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
655 for i in 0..pts.len() {
656 let p = base + i as u32;
657 let hd = distance::hamming(q_binary, self.binary.code(p));
658 heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
659 }
660 for (_, p) in binary_heap {
661 let dist = self.cand_dist(query, q_code, p);
662 heap_insert_sq8(&mut heap, dist, p, ef);
663 }
664 } else {
665 for i in 0..pts.len() {
666 let p = base + i as u32;
667 let dist = self.cand_dist(query, q_code, p);
668 heap_insert_sq8(&mut heap, dist, p, ef);
669 }
670 }
671 } else {
672 let ef_cell = ef.max((pts.len() / 200).min(ef * 5));
674 let local = self.greedy_search_cell_sq8(query, q_code, cell_idx, ef_cell);
675 for (id, dist) in local {
676 heap_insert_sq8(&mut heap, dist, id, ef);
677 }
678 }
679
680 heap.into_vec()
681 }
682}
683
684#[cfg(test)]
685mod tests {
686 use super::super::construct::{PrismConfig, PrismIndex};
687 use super::super::distance;
688 use super::super::filter::Filter;
689 use super::super::point::PointStore;
690
691 fn build_test_index() -> PrismIndex {
692 let mut store = PointStore::new(2, 1);
693 for i in 0..10 {
694 let x = (i as f32) * 0.1;
695 let attr = if i < 5 { 0 } else { 1 };
696 store.push(&[x, x], &[attr]);
697 }
698 let config = PrismConfig {
699 m_local: 4,
700 m_greedy: 2,
701 m_random: 4,
702 t: 1,
703 alpha: 0.0,
704 beam_width: 10,
705 ..Default::default()
706 };
707 PrismIndex::build(store, config)
708 }
709
/// An unfiltered search on the ten-point index returns exactly `k` hits with
/// non-negative distances.
#[test]
fn test_search_no_filter() {
    let index = build_test_index();
    let hits = index.search(&[0.25, 0.25], &Filter::none(), 3, 10);
    assert_eq!(hits.len(), 3);
    assert!(hits.iter().all(|hit| hit.dist >= 0.0));
}
719
/// A filtered search returns at least one hit, and every hit satisfies the filter.
#[test]
fn test_search_with_filter() {
    let index = build_test_index();
    let attr_filter = Filter::eq(0, 1);
    let hits = index.search(&[0.5, 0.5], &attr_filter, 3, 10);
    assert!(!hits.is_empty());
    assert!(hits
        .iter()
        .all(|hit| attr_filter.matches(&index.store, hit.id)));
}
730
/// Filtered search over 2000 points whose attribute cycles through 20 values: the hits
/// must be non-empty, at most `k`, all match the filter, and be sorted by distance.
#[test]
fn test_graph_search_mid_selectivity() {
    let dim = 16;
    let n_points = 2000;
    let n_attr_values = 20;

    let mut store = PointStore::new(dim, 1);
    for i in 0..n_points {
        let point: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32).sin()).collect();
        store.push(&point, &[(i % n_attr_values) as u32]);
    }
    let index = PrismIndex::build(
        store,
        PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            ..Default::default()
        },
    );

    let query: Vec<f32> = (0..dim).map(|d| (d as f32 * 0.3).sin()).collect();
    let filter = Filter::eq(0, 0);
    let k = 5;
    let ef = 10;

    let hits = index.search(&query, &filter, k, ef);
    assert!(!hits.is_empty());
    assert!(hits.len() <= k);
    assert!(hits.iter().all(|hit| filter.matches(&index.store, hit.id)));
    assert!(hits.windows(2).all(|pair| pair[0].dist <= pair[1].dist));
}
766
/// Filtering on an attribute value that no point carries yields no results.
#[test]
fn test_search_empty_filter() {
    let index = build_test_index();
    let unmatched = Filter::eq(0, 99);
    let hits = index.search(&[0.0, 0.0], &unmatched, 3, 10);
    assert!(hits.is_empty());
}
774
/// Search with explicit `sigma_high` / `sigma_low` / `beta` / `epsilon` settings. The
/// filter selects one attribute value out of 20, which should fall between the two
/// thresholds and so exercise the mid regime. The hits must be non-empty, at most `k`,
/// all match the filter, and be sorted by distance.
#[test]
fn test_regime_mid_bridge_routing() {
    let dim = 16;
    let n_points = 2000;
    let n_attr_values = 20;

    let mut store = PointStore::new(dim, 1);
    for i in 0..n_points {
        let point: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32).sin()).collect();
        store.push(&point, &[(i % n_attr_values) as u32]);
    }
    let index = PrismIndex::build(
        store,
        PrismConfig {
            m_local: 4,
            m_greedy: 4,
            m_random: 4,
            t: 1,
            beam_width: 20,
            sigma_high: 0.10,
            sigma_low: 0.001,
            beta: 3.0,
            epsilon: 0.2,
            ..Default::default()
        },
    );

    let query: Vec<f32> = (0..dim).map(|d| (d as f32 * 0.3).sin()).collect();
    let filter = Filter::eq(0, 0);
    let k = 5;
    let ef = 50;

    let hits = index.search(&query, &filter, k, ef);
    assert!(!hits.is_empty());
    assert!(hits.len() <= k);
    assert!(hits.iter().all(|hit| filter.matches(&index.store, hit.id)));
    assert!(hits.windows(2).all(|pair| pair[0].dist <= pair[1].dist));
}
817
/// Batch search with one unfiltered and two filtered queries: each query must return
/// between 1 and `k` hits, all of which match that query's filter.
#[test]
fn test_batch_search_mixed_regimes() {
    let dim = 8;
    let n_points = 1000;
    let n_attr_values = 10;

    let mut store = PointStore::new(dim, 1);
    for i in 0..n_points {
        let point: Vec<f32> = (0..dim).map(|d| ((i * dim + d) as f32).sin()).collect();
        store.push(&point, &[(i % n_attr_values) as u32]);
    }
    let index = PrismIndex::build(
        store,
        PrismConfig {
            m_local: 4,
            m_greedy: 4,
            m_random: 4,
            t: 1,
            beam_width: 20,
            sigma_high: 0.10,
            sigma_low: 0.001,
            ..Default::default()
        },
    );

    let k = 3;
    let ef = 20;
    let nq = 3;

    // Flat row-major query buffer: `nq` rows of `dim` values.
    let queries: Vec<f32> = (0..nq)
        .flat_map(|qi| (0..dim).map(move |d| ((qi * dim + d) as f32 * 0.5).sin()))
        .collect();
    let filters = vec![Filter::none(), Filter::eq(0, 0), Filter::eq(0, 5)];

    let per_query = index.batch_search(&queries, &filters, nq, k, ef);
    assert_eq!(per_query.len(), nq);
    for (qi, (hits, filter)) in per_query.iter().zip(&filters).enumerate() {
        assert!(!hits.is_empty(), "query {} returned no results", qi);
        assert!(hits.len() <= k);
        assert!(hits.iter().all(|hit| filter.matches(&index.store, hit.id)));
    }
}
861
/// Under the inner-product metric, the single far-away point `(20, 0)` must be the top
/// hit for query `(1, 0)`, with a reported distance of about `-20`. The other 59 points
/// cluster near `(0.5, 0)`.
#[test]
fn inner_product_candidates_survive_l2_blind_spot() {
    let mut store = PointStore::new(2, 1);
    for i in 0..59 {
        let jitter = (i as f32) * 0.001;
        store.push(&[0.5 + jitter, jitter], &[0]);
    }
    // Point 59: far from the cluster, but the largest inner product with the query.
    store.push(&[20.0, 0.0], &[0]);

    let index = PrismIndex::build(
        store,
        PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            metric: distance::Metric::InnerProduct,
            binary_rerank: 0,
            ..Default::default()
        },
    );

    let hits = index.search(&[1.0, 0.0], &Filter::none(), 1, 8);
    let top = &hits[0];
    assert_eq!(top.id, 59, "true IP winner must reach the rerank");
    assert!((top.dist - (-20.0)).abs() < 1e-3);
}
889
/// Under the cosine metric with unnormalised inputs, the point `(50, 1)` — the one most
/// closely aligned with query `(3, 0)` — must be the top hit with a distance below 0.01.
/// The other 59 points lie close to the `(0, 1)` direction.
#[test]
fn cosine_candidates_survive_unnormalized_inputs() {
    let mut store = PointStore::new(2, 1);
    for i in 0..59 {
        let jitter = (i as f32) * 0.001;
        store.push(&[jitter, 1.0 + jitter], &[0]);
    }
    // Point 59: large magnitude, nearly parallel to the query.
    store.push(&[50.0, 1.0], &[0]);

    let index = PrismIndex::build(
        store,
        PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            metric: distance::Metric::Cosine,
            binary_rerank: 0,
            ..Default::default()
        },
    );

    let hits = index.search(&[3.0, 0.0], &Filter::none(), 1, 8);
    let top = &hits[0];
    assert_eq!(top.id, 59, "best-angle point must reach the rerank");
    assert!(top.dist < 0.01, "dist {} is not ~1-cos", top.dist);
}
921
/// Filtered search with the binary prefilter enabled (`binary_rerank: 4`): the hits must
/// be non-empty, at most `k`, all match the filter, and be sorted by distance.
#[test]
fn test_binary_prefilter_recall() {
    let dim = 64;
    let n_points = 2000;
    let n_attr_values = 10;

    let mut store = PointStore::new(dim, 1);
    for i in 0..n_points {
        let point: Vec<f32> = (0..dim)
            .map(|d| ((i * dim + d) as f32 * 0.01).sin())
            .collect();
        store.push(&point, &[(i % n_attr_values) as u32]);
    }

    let index_binary = PrismIndex::build(
        store,
        PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            binary_rerank: 4,
            ..Default::default()
        },
    );

    let query: Vec<f32> = (0..dim).map(|d| (d as f32 * 0.3).sin()).collect();
    let filter = Filter::eq(0, 0);
    let k = 10;
    let ef = 50;

    let hits = index_binary.search(&query, &filter, k, ef);
    assert!(!hits.is_empty());
    assert!(hits.len() <= k);
    assert!(hits
        .iter()
        .all(|hit| filter.matches(&index_binary.store, hit.id)));
    assert!(hits.windows(2).all(|pair| pair[0].dist <= pair[1].dist));
}
963
/// Batch search with the binary prefilter enabled: five queries, each filtered on a
/// different attribute value, must each return between 1 and `k` hits that all match
/// their own filter.
#[test]
fn test_binary_prefilter_batch() {
    let dim = 32;
    let n_points = 500;
    let n_attr_values = 5;

    let mut store = PointStore::new(dim, 1);
    for i in 0..n_points {
        let point: Vec<f32> = (0..dim)
            .map(|d| ((i * dim + d) as f32 * 0.02).sin())
            .collect();
        store.push(&point, &[(i % n_attr_values) as u32]);
    }

    let index = PrismIndex::build(
        store,
        PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            binary_rerank: 4,
            ..Default::default()
        },
    );

    let nq = 5;
    let k = 5;
    let ef = 20;
    // Flat row-major query buffer: `nq` rows of `dim` values.
    let queries: Vec<f32> = (0..nq)
        .flat_map(|qi| (0..dim).map(move |d| ((qi * dim + d) as f32 * 0.1).sin()))
        .collect();
    let filters: Vec<Filter> = (0..nq)
        .map(|qi| Filter::eq(0, (qi % n_attr_values) as u32))
        .collect();

    let per_query = index.batch_search(&queries, &filters, nq, k, ef);
    assert_eq!(per_query.len(), nq);
    for (qi, (hits, filter)) in per_query.iter().zip(&filters).enumerate() {
        assert!(!hits.is_empty(), "query {} returned no results", qi);
        assert!(hits.len() <= k);
        assert!(hits.iter().all(|hit| filter.matches(&index.store, hit.id)));
    }
}
1008}