1use super::construct::PrismIndex;
2use super::distance;
3use super::filter::Filter;
4
5use rayon::prelude::*;
6use std::cmp::Reverse;
7use std::collections::BinaryHeap;
8
/// One hit returned by a search: a point id together with its distance to the query.
///
/// Every search path fills `dist` with `distance::distance` under the index's configured
/// metric, and result lists are sorted by it in ascending order.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the point in the index's point store.
    pub id: u32,
    /// Exact distance between the query and this point under the configured metric.
    pub dist: f32,
}
15
/// Fixed-capacity set of `u32` indices packed into 64-bit words.
///
/// The capacity given to [`Bitset::new`] is rounded up to a whole number of words; an
/// index beyond that panics on the out-of-bounds word access.
struct Bitset {
    bits: Vec<u64>,
}

impl Bitset {
    /// Creates an empty set able to hold indices `0..n` (rounded up to a multiple of 64).
    fn new(n: usize) -> Self {
        let word_count = n.div_ceil(64);
        Bitset {
            bits: vec![0; word_count],
        }
    }

    /// Returns the word index and the single-bit mask that address index `i`.
    #[inline]
    fn slot(i: u32) -> (usize, u64) {
        ((i >> 6) as usize, 1u64 << (i & 63))
    }

    /// Adds `i` to the set, returning `true` if it was not present before.
    #[inline]
    fn insert(&mut self, i: u32) -> bool {
        let (w, mask) = Self::slot(i);
        let word = &mut self.bits[w];
        let was_clear = *word & mask == 0;
        // Setting an already-set bit is a no-op, so OR-ing unconditionally is fine.
        *word |= mask;
        was_clear
    }

    /// Returns `true` if `i` is in the set.
    #[inline]
    fn contains(&self, i: u32) -> bool {
        let (w, mask) = Self::slot(i);
        self.bits[w] & mask != 0
    }
}
49
/// Issues a T0 prefetch hint (fetch into all cache levels) for the line containing `ptr`.
///
/// # Safety
/// The running CPU must support SSE (part of the `x86_64` baseline). The pointer is never
/// dereferenced — prefetch instructions do not fault — so any address value is accepted.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn prefetch_t0(ptr: *const u8) {
    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
    // SAFETY: the caller guarantees SSE is available; `_mm_prefetch` is only a cache hint.
    unsafe { _mm_prefetch::<_MM_HINT_T0>(ptr.cast::<i8>()) }
}
57
58#[inline(always)]
60fn prefetch_read(ptr: *const u8) {
61 #[cfg(target_arch = "x86_64")]
62 unsafe {
63 prefetch_t0(ptr);
64 }
65 #[cfg(not(target_arch = "x86_64"))]
66 let _ = ptr;
67}
68
69#[inline(always)]
71fn prefetch_range(ptr: *const u8, len: usize) {
72 let mut offset = 0;
73 while offset < len {
74 prefetch_read(unsafe { ptr.add(offset) });
75 offset += 64;
76 }
77}
78
/// `f32` wrapper with a genuine total order, so distances can be used as heap keys.
///
/// Ordering and equality both use IEEE 754 `totalOrder` (`f32::total_cmp`): every value,
/// including NaN, compares consistently and transitively. Positive NaN sorts after `+inf`,
/// and `-0.0` sorts before `+0.0`. The previous `partial_cmp().unwrap_or(Equal)` made NaN
/// "equal" to everything, which is not a valid `Ord` and leaves `BinaryHeap` behaviour
/// unspecified.
#[derive(Clone, Copy, Debug)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp` so that `Eq` is lawful (NaN == NaN here).
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
96
/// Offers `(dist, id)` to a bounded max-heap that keeps the `cap` smallest distances.
///
/// Below capacity the entry is always pushed. At capacity it replaces the current worst
/// entry only if `dist` is strictly smaller; ties with the worst are rejected. With
/// `cap == 0` nothing is ever inserted.
#[inline]
fn heap_insert_sq8(heap: &mut BinaryHeap<(u32, u32)>, dist: u32, id: u32, cap: usize) {
    if heap.len() >= cap {
        match heap.peek() {
            Some(&(worst, _)) if dist < worst => {
                heap.pop();
            }
            _ => return,
        }
    }
    heap.push((dist, id));
}
109
110impl PrismIndex {
111 pub fn search(&self, query: &[f32], filter: &Filter, k: usize, ef: usize) -> Vec<SearchResult> {
113 assert_eq!(query.len(), self.store.dim);
114
115 let cell_indices = self.tree.filter_cells(filter.constraints());
116 let n_f = self.tree.count_points(&cell_indices);
117 let sigma = n_f as f32 / self.store.len as f32;
118 if sigma >= self.config.sigma_high {
119 self.regime_high_filtered(query, &cell_indices, k, ef)
120 } else if sigma > self.config.sigma_low {
121 self.regime_mid(query, &cell_indices, k, ef)
122 } else {
123 self.regime_low(query, filter, &cell_indices, k)
124 }
125 }
126
127 fn regime_high_filtered(
130 &self,
131 query: &[f32],
132 cell_indices: &[usize],
133 k: usize,
134 ef: usize,
135 ) -> Vec<SearchResult> {
136 if cell_indices.is_empty() {
137 return Vec::new();
138 }
139
140 let q_code = self.sq8.quantize_query(query);
141 let q_binary = self.binary.encode_query(query);
142 let mut merged: BinaryHeap<(u32, u32)> = BinaryHeap::new();
143
144 if cell_indices.len() == self.tree.cells.len() {
145 let n = self.store.len as u32;
147 let rerank_budget = self.config.binary_rerank * ef;
148 if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
149 let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
150 for p in 0..n {
151 let hd = distance::hamming(&q_binary, self.binary.code(p));
152 heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
153 }
154 for (_, p) in binary_heap {
155 let dist = distance::l2_sq8(&q_code, self.sq8.code(p));
156 heap_insert_sq8(&mut merged, dist, p, ef);
157 }
158 } else {
159 for p in 0..n {
160 let dist = distance::l2_sq8(&q_code, self.sq8.code(p));
161 heap_insert_sq8(&mut merged, dist, p, ef);
162 }
163 }
164 } else {
165 let mut ranked: Vec<(usize, u32)> = cell_indices
167 .iter()
168 .map(|&ci| {
169 let d = distance::l2_sq8(&q_code, self.sq8.code(self.medoids[ci]));
170 (ci, d)
171 })
172 .collect();
173 ranked.sort_unstable_by_key(|&(_, d)| d);
174
175 let scan_threshold = (ef * self.config.m_local).max(2000);
176
177 for &(ci, _) in &ranked {
178 let cands = self.search_cell(&q_code, &q_binary, ci, ef, scan_threshold);
179 for (sq8_dist, id) in cands {
180 heap_insert_sq8(&mut merged, sq8_dist, id, ef);
181 }
182 }
183 }
184
185 let mut results: Vec<SearchResult> = merged
187 .into_iter()
188 .map(|(_, id)| SearchResult {
189 id,
190 dist: distance::distance(query, self.store.vector(id), self.config.metric),
191 })
192 .collect();
193 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
194 results.truncate(k);
195 results
196 }
197
198 fn regime_mid(
201 &self,
202 query: &[f32],
203 compatible_cells: &[usize],
204 k: usize,
205 ef: usize,
206 ) -> Vec<SearchResult> {
207 if compatible_cells.is_empty() {
208 return Vec::new();
209 }
210
211 let q_code = self.sq8.quantize_query(query);
212
213 let n_cells = self.tree.cells.len();
215 let mut cell_match = vec![false; n_cells];
216 for &ci in compatible_cells {
217 cell_match[ci] = true;
218 }
219
220 let (_, entry) = compatible_cells
222 .iter()
223 .map(|&ci| {
224 let d = distance::l2_sq8(&q_code, self.sq8.code(self.medoids[ci]));
225 (d, self.medoids[ci])
226 })
227 .min_by_key(|&(d, _)| d)
228 .unwrap();
229
230 let entry_dist = distance::l2_sq8(&q_code, self.sq8.code(entry));
231
232 let mut visited = Bitset::new(self.store.len);
233 visited.insert(entry);
234
235 let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
236 let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
237
238 candidates.push(Reverse((entry_dist, entry)));
239 results.push((entry_dist, entry));
240
241 let bridge_budget = (self.config.beta * ef as f32) as usize;
242 let mut bridges_used = 0usize;
243 let epsilon_factor = ((1.0 + self.config.epsilon) * (1.0 + self.config.epsilon)) as f64;
244
245 let n_f: usize = compatible_cells
247 .iter()
248 .map(|&ci| self.tree.cells[ci].point_ids.len())
249 .sum();
250 let sigma = n_f as f32 / self.store.len as f32;
251 let tau = sigma / (1.0 + sigma);
252
253 while let Some(Reverse((d, c))) = candidates.pop() {
254 if results.len() >= ef {
256 if let Some(&(worst, _)) = results.peek() {
257 if (d as f64) > (worst as f64) * epsilon_factor {
258 break;
259 }
260 }
261 }
262
263 if bridges_used >= bridge_budget {
264 break;
265 }
266
267 let neighbors = self.graph.neighbors(c);
269 let sq8_dim = self.store.dim;
270
271 let mut unvisited_buf: Vec<u32> = Vec::with_capacity(neighbors.len());
272 for &w in neighbors {
273 if visited.insert(w) {
274 unvisited_buf.push(w);
275 prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
276 }
277 }
278
279 for &w in &unvisited_buf {
280 let wd = distance::l2_sq8(&q_code, self.sq8.code(w));
281 let w_cell = self.point_cell[w as usize];
282
283 if cell_match[w_cell as usize] {
284 heap_insert_sq8(&mut results, wd, w, ef);
286 candidates.push(Reverse((wd, w)));
287 } else {
288 let w_neighbors = self.graph.neighbors(w);
290 if !w_neighbors.is_empty() {
291 let matching_unvisited = w_neighbors
292 .iter()
293 .filter(|&&u| {
294 cell_match[self.point_cell[u as usize] as usize]
295 && !visited.contains(u)
296 })
297 .count();
298 let fraction = matching_unvisited as f32 / w_neighbors.len() as f32;
299
300 let r = results.peek().map_or(1.0f32, |&(worst, _)| worst as f32);
302 let bridge_score = fraction / (1.0 + wd as f32 / r.max(1.0));
303
304 if bridge_score > tau {
305 candidates.push(Reverse((wd, w)));
306 bridges_used += 1;
307 }
308 }
309 }
310 }
311 }
312
313 let mut final_results: Vec<SearchResult> = results
315 .into_iter()
316 .map(|(_, id)| SearchResult {
317 id,
318 dist: distance::distance(query, self.store.vector(id), self.config.metric),
319 })
320 .collect();
321 final_results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
322 final_results.truncate(k);
323 final_results
324 }
325
326 fn greedy_search_cell_sq8(&self, q_code: &[u8], cell_idx: usize, ef: usize) -> Vec<(u32, u32)> {
328 let pts = &self.tree.cells[cell_idx].point_ids;
329 let base = pts[0];
330 let sq8_dim = self.store.dim;
331
332 let entry = self.medoids[cell_idx];
333 let entry_dist = distance::l2_sq8(q_code, self.sq8.code(entry));
334
335 let mut visited = Bitset::new(pts.len());
336 visited.insert(entry - base);
337
338 let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
339 let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
340 let mut unvisited: Vec<u32> = Vec::with_capacity(32);
341
342 candidates.push(Reverse((entry_dist, entry)));
343 results.push((entry_dist, entry));
344
345 while let Some(Reverse((d, c))) = candidates.pop() {
346 if results.len() >= ef {
347 if let Some(&(worst, _)) = results.peek() {
348 if d > worst {
349 break;
350 }
351 }
352 }
353
354 unvisited.clear();
355 for &w in self.local_graph.neighbors(c) {
356 if visited.insert(w - base) {
357 unvisited.push(w);
358 prefetch_range(self.sq8.code(w).as_ptr(), sq8_dim);
359 }
360 }
361
362 for &w in &unvisited {
363 let wd = distance::l2_sq8(q_code, self.sq8.code(w));
364 if results.len() < ef {
365 candidates.push(Reverse((wd, w)));
366 results.push((wd, w));
367 } else if let Some(&(worst, _)) = results.peek() {
368 if wd < worst {
369 results.pop();
370 results.push((wd, w));
371 candidates.push(Reverse((wd, w)));
372 }
373 }
374 }
375 }
376
377 results
378 .into_vec()
379 .into_iter()
380 .map(|(d, id)| (id, d))
381 .collect()
382 }
383
384 fn regime_low(
386 &self,
387 query: &[f32],
388 filter: &Filter,
389 cell_indices: &[usize],
390 k: usize,
391 ) -> Vec<SearchResult> {
392 let mut heap: BinaryHeap<(OrdF32, u32)> = BinaryHeap::new();
393 for &ci in cell_indices {
394 for &p in &self.tree.cells[ci].point_ids {
395 if filter.matches(&self.store, p) {
396 let dist = distance::distance(query, self.store.vector(p), self.config.metric);
397 if heap.len() < k {
398 heap.push((OrdF32(dist), p));
399 } else if let Some(&(OrdF32(worst), _)) = heap.peek() {
400 if dist < worst {
401 heap.pop();
402 heap.push((OrdF32(dist), p));
403 }
404 }
405 }
406 }
407 }
408 let mut results: Vec<SearchResult> = heap
409 .into_iter()
410 .map(|(OrdF32(d), id)| SearchResult { id, dist: d })
411 .collect();
412 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
413 results
414 }
415
416 pub fn batch_search(
420 &self,
421 queries: &[f32],
422 filters: &[Filter],
423 nq: usize,
424 k: usize,
425 ef: usize,
426 ) -> Vec<Vec<SearchResult>> {
427 let dim = self.store.dim;
428 let n_cells = self.tree.cells.len();
429 let scan_threshold = (ef * self.config.m_local).max(2000);
430
431 let query_info: Vec<(Vec<u8>, Vec<u64>, Vec<usize>)> = (0..nq)
433 .into_par_iter()
434 .map(|qi| {
435 let q = &queries[qi * dim..(qi + 1) * dim];
436 let q_code = self.sq8.quantize_query(q);
437 let q_binary = self.binary.encode_query(q);
438 let cells = self.tree.filter_cells(filters[qi].constraints());
439 (q_code, q_binary, cells)
440 })
441 .collect();
442
443 let mut high_regime: Vec<usize> = Vec::with_capacity(nq);
445 let mut mid_regime: Vec<usize> = Vec::new();
446 let mut low_regime: Vec<usize> = Vec::new();
447 let mut unfiltered: Vec<usize> = Vec::new();
448 for (qi, info) in query_info.iter().enumerate() {
449 let cells = &info.2;
450 if cells.len() >= n_cells {
451 unfiltered.push(qi);
452 } else {
453 let n_f: usize = cells
454 .iter()
455 .map(|&ci| self.tree.cells[ci].point_ids.len())
456 .sum();
457 let sigma = n_f as f32 / self.store.len as f32;
458 if sigma >= self.config.sigma_high {
459 high_regime.push(qi);
460 } else if sigma > self.config.sigma_low {
461 mid_regime.push(qi);
462 } else {
463 low_regime.push(qi);
464 }
465 }
466 }
467
468 let mut cell_queries: Vec<Vec<usize>> = vec![Vec::new(); n_cells];
470 for &qi in &high_regime {
471 for &ci in &query_info[qi].2 {
472 cell_queries[ci].push(qi);
473 }
474 }
475
476 #[allow(clippy::type_complexity)]
478 let cell_results: Vec<Vec<(usize, Vec<(u32, u32)>)>> = cell_queries
479 .into_par_iter()
480 .enumerate()
481 .filter(|(_, qs)| !qs.is_empty())
482 .map(|(ci, qs)| {
483 qs.iter()
484 .map(|&qi| {
485 let q_code = &query_info[qi].0;
486 let q_binary = &query_info[qi].1;
487 let cands = self.search_cell(q_code, q_binary, ci, ef, scan_threshold);
488 (qi, cands)
489 })
490 .collect()
491 })
492 .collect();
493
494 let mut query_heaps: Vec<BinaryHeap<(u32, u32)>> =
496 (0..nq).map(|_| BinaryHeap::new()).collect();
497 for cell_batch in cell_results {
498 for (qi, cands) in cell_batch {
499 for (sq8_dist, id) in cands {
500 heap_insert_sq8(&mut query_heaps[qi], sq8_dist, id, ef);
501 }
502 }
503 }
504
505 let unfilt_heaps: Vec<(usize, BinaryHeap<(u32, u32)>)> = unfiltered
507 .par_iter()
508 .map(|&qi| {
509 let q_code = &query_info[qi].0;
510 let q_binary = &query_info[qi].1;
511 let n = self.store.len as u32;
512 let rerank_budget = self.config.binary_rerank * ef;
513 let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
514 if self.config.binary_rerank > 0 && (n as usize) > rerank_budget {
515 let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
516 for p in 0..n {
517 let hd = distance::hamming(q_binary, self.binary.code(p));
518 heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
519 }
520 for (_, p) in binary_heap {
521 let dist = distance::l2_sq8(q_code, self.sq8.code(p));
522 heap_insert_sq8(&mut heap, dist, p, ef);
523 }
524 } else {
525 for p in 0..n {
526 let dist = distance::l2_sq8(q_code, self.sq8.code(p));
527 heap_insert_sq8(&mut heap, dist, p, ef);
528 }
529 }
530 (qi, heap)
531 })
532 .collect();
533 for (qi, heap) in unfilt_heaps {
534 query_heaps[qi] = heap;
535 }
536
537 let mut all_results: Vec<Vec<SearchResult>> = query_heaps
539 .into_par_iter()
540 .enumerate()
541 .map(|(qi, heap)| {
542 if heap.is_empty() {
543 return Vec::new();
544 }
545 let q = &queries[qi * dim..(qi + 1) * dim];
546 let mut results: Vec<SearchResult> = heap
547 .into_iter()
548 .map(|(_, id)| SearchResult {
549 id,
550 dist: distance::distance(q, self.store.vector(id), self.config.metric),
551 })
552 .collect();
553 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
554 results.truncate(k);
555 results
556 })
557 .collect();
558
559 if !mid_regime.is_empty() {
561 let mid_results: Vec<(usize, Vec<SearchResult>)> = mid_regime
562 .par_iter()
563 .map(|&qi| {
564 let q = &queries[qi * dim..(qi + 1) * dim];
565 let cells = &query_info[qi].2;
566 let results = self.regime_mid(q, cells, k, ef);
567 (qi, results)
568 })
569 .collect();
570 for (qi, results) in mid_results {
571 all_results[qi] = results;
572 }
573 }
574
575 if !low_regime.is_empty() {
577 let low_results: Vec<(usize, Vec<SearchResult>)> = low_regime
578 .par_iter()
579 .map(|&qi| {
580 let q = &queries[qi * dim..(qi + 1) * dim];
581 let results = self.search(q, &filters[qi], k, ef);
582 (qi, results)
583 })
584 .collect();
585 for (qi, results) in low_results {
586 all_results[qi] = results;
587 }
588 }
589
590 all_results
591 }
592
593 fn search_cell(
596 &self,
597 q_code: &[u8],
598 q_binary: &[u64],
599 cell_idx: usize,
600 ef: usize,
601 scan_threshold: usize,
602 ) -> Vec<(u32, u32)> {
603 let pts = &self.tree.cells[cell_idx].point_ids;
604 let mut heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
605
606 if pts.len() <= scan_threshold {
607 let base = pts[0];
608 let rerank_budget = self.config.binary_rerank * ef;
609
610 if self.config.binary_rerank > 0 && pts.len() > rerank_budget {
611 let mut binary_heap: BinaryHeap<(u32, u32)> = BinaryHeap::new();
613 for i in 0..pts.len() {
614 let p = base + i as u32;
615 let hd = distance::hamming(q_binary, self.binary.code(p));
616 heap_insert_sq8(&mut binary_heap, hd, p, rerank_budget);
617 }
618 for (_, p) in binary_heap {
619 let dist = distance::l2_sq8(q_code, self.sq8.code(p));
620 heap_insert_sq8(&mut heap, dist, p, ef);
621 }
622 } else {
623 for i in 0..pts.len() {
625 let p = base + i as u32;
626 let dist = distance::l2_sq8(q_code, self.sq8.code(p));
627 heap_insert_sq8(&mut heap, dist, p, ef);
628 }
629 }
630 } else {
631 let ef_cell = ef.max((pts.len() / 200).min(ef * 5));
633 let local = self.greedy_search_cell_sq8(q_code, cell_idx, ef_cell);
634 for (id, dist) in local {
635 heap_insert_sq8(&mut heap, dist, id, ef);
636 }
637 }
638
639 heap.into_vec()
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::super::construct::{PrismConfig, PrismIndex};
    use super::super::filter::Filter;
    use super::super::point::PointStore;
    use super::SearchResult;

    /// Ten 2-d points on the diagonal; ids 0..5 carry attribute 0, ids 5..10 attribute 1.
    fn build_test_index() -> PrismIndex {
        let mut store = PointStore::new(2, 1);
        for i in 0..10 {
            let x = i as f32 * 0.1;
            store.push(&[x, x], &[if i < 5 { 0 } else { 1 }]);
        }
        PrismIndex::build(
            store,
            PrismConfig {
                m_local: 4,
                m_greedy: 2,
                m_random: 4,
                t: 1,
                alpha: 0.0,
                beam_width: 10,
                ..Default::default()
            },
        )
    }

    /// `n` points of dimension `dim` with coordinates `sin((i * dim + d) * scale)` and a
    /// single attribute `i % n_vals`.
    fn sin_store(n: usize, dim: usize, n_vals: usize, scale: f32) -> PointStore {
        let mut store = PointStore::new(dim, 1);
        for i in 0..n {
            let coords: Vec<f32> = (0..dim)
                .map(|d| ((i * dim + d) as f32 * scale).sin())
                .collect();
            store.push(&coords, &[(i % n_vals) as u32]);
        }
        store
    }

    /// Deterministic probe vector with components `sin(0.3 * d)`.
    fn probe(dim: usize) -> Vec<f32> {
        (0..dim).map(|d| (d as f32 * 0.3).sin()).collect()
    }

    /// Checks the common contract of one filtered search: non-empty, at most `k` hits,
    /// every hit passes `filter`, and distances are non-decreasing.
    fn assert_filtered_topk(index: &PrismIndex, filter: &Filter, results: &[SearchResult], k: usize) {
        assert!(!results.is_empty());
        assert!(results.len() <= k);
        for hit in results {
            assert!(filter.matches(&index.store, hit.id));
        }
        for pair in results.windows(2) {
            assert!(pair[0].dist <= pair[1].dist);
        }
    }

    #[test]
    fn test_search_no_filter() {
        let index = build_test_index();
        let results = index.search(&[0.25, 0.25], &Filter::none(), 3, 10);
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.dist >= 0.0));
    }

    #[test]
    fn test_search_with_filter() {
        let index = build_test_index();
        let filter = Filter::eq(0, 1);
        let results = index.search(&[0.5, 0.5], &filter, 3, 10);
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| filter.matches(&index.store, r.id)));
    }

    #[test]
    fn test_graph_search_mid_selectivity() {
        let dim = 16;
        let config = PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            ..Default::default()
        };
        let index = PrismIndex::build(sin_store(2000, dim, 20, 1.0), config);

        let filter = Filter::eq(0, 0);
        let k = 5;
        let results = index.search(&probe(dim), &filter, k, 10);
        assert_filtered_topk(&index, &filter, &results, k);
    }

    #[test]
    fn test_search_empty_filter() {
        let index = build_test_index();
        let results = index.search(&[0.0, 0.0], &Filter::eq(0, 99), 3, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_regime_mid_bridge_routing() {
        let dim = 16;
        let config = PrismConfig {
            m_local: 4,
            m_greedy: 4,
            m_random: 4,
            t: 1,
            beam_width: 20,
            sigma_high: 0.10,
            sigma_low: 0.001,
            beta: 3.0,
            epsilon: 0.2,
            ..Default::default()
        };
        let index = PrismIndex::build(sin_store(2000, dim, 20, 1.0), config);

        let filter = Filter::eq(0, 0);
        let k = 5;
        let results = index.search(&probe(dim), &filter, k, 50);
        assert_filtered_topk(&index, &filter, &results, k);
    }

    #[test]
    fn test_batch_search_mixed_regimes() {
        let dim = 8;
        let config = PrismConfig {
            m_local: 4,
            m_greedy: 4,
            m_random: 4,
            t: 1,
            beam_width: 20,
            sigma_high: 0.10,
            sigma_low: 0.001,
            ..Default::default()
        };
        let index = PrismIndex::build(sin_store(1000, dim, 10, 1.0), config);

        let (nq, k, ef) = (3, 3, 20);
        let queries: Vec<f32> = (0..nq)
            .flat_map(|qi| (0..dim).map(move |d| ((qi * dim + d) as f32 * 0.5).sin()))
            .collect();
        let filters = vec![Filter::none(), Filter::eq(0, 0), Filter::eq(0, 5)];

        let results = index.batch_search(&queries, &filters, nq, k, ef);
        assert_eq!(results.len(), nq);
        for (qi, res) in results.iter().enumerate() {
            assert!(!res.is_empty(), "query {} returned no results", qi);
            assert!(res.len() <= k);
            assert!(res.iter().all(|r| filters[qi].matches(&index.store, r.id)));
        }
    }

    #[test]
    fn test_binary_prefilter_recall() {
        let dim = 64;
        let config_binary = PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            binary_rerank: 4,
            ..Default::default()
        };
        let index_binary = PrismIndex::build(sin_store(2000, dim, 10, 0.01), config_binary);

        let filter = Filter::eq(0, 0);
        let k = 10;
        let results_binary = index_binary.search(&probe(dim), &filter, k, 50);
        assert_filtered_topk(&index_binary, &filter, &results_binary, k);
    }

    #[test]
    fn test_binary_prefilter_batch() {
        let dim = 32;
        let n_vals = 5;
        let config = PrismConfig {
            m_local: 4,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            beam_width: 10,
            binary_rerank: 4,
            ..Default::default()
        };
        let index = PrismIndex::build(sin_store(500, dim, n_vals, 0.02), config);

        let (nq, k, ef) = (5, 5, 20);
        let queries: Vec<f32> = (0..nq)
            .flat_map(|qi| (0..dim).map(move |d| ((qi * dim + d) as f32 * 0.1).sin()))
            .collect();
        let filters: Vec<Filter> = (0..nq)
            .map(|qi| Filter::eq(0, (qi % n_vals) as u32))
            .collect();

        let results = index.batch_search(&queries, &filters, nq, k, ef);
        assert_eq!(results.len(), nq);
        for (qi, res) in results.iter().enumerate() {
            assert!(!res.is_empty(), "query {} returned no results", qi);
            assert!(res.len() <= k);
            assert!(res.iter().all(|r| filters[qi].matches(&index.store, r.id)));
        }
    }
}
911}