use std::collections::HashSet;

use rand::prelude::*;
use rayon::prelude::*;

use super::binary::BinaryStore;
use super::distance::{self, Metric};
use super::graph::{AdjBuilder, Graph};
use super::partition::PartitionTree;
use super::point::PointStore;
use super::quantize::SQ8Store;

12#[derive(Clone, Debug)]
14pub struct PrismConfig {
15 pub m_local: usize,
17 pub m_greedy: usize,
19 pub m_random: usize,
21 pub t: usize,
23 pub alpha: f32,
25 pub vamana_alpha: f32,
27 pub beam_width: usize,
29 pub metric: Metric,
31 pub sigma_high: f32,
33 pub sigma_low: f32,
35 pub beta: f32,
37 pub epsilon: f32,
39 pub binary_rerank: usize,
42}
43
44impl Default for PrismConfig {
45 fn default() -> Self {
46 Self {
47 m_local: 16,
48 m_greedy: 12,
49 m_random: 4,
50 t: 2,
51 alpha: 1.0,
52 vamana_alpha: 1.0,
53 beam_width: 120,
54 metric: Metric::L2,
55 sigma_high: 0.10,
56 sigma_low: 0.001,
57 beta: 3.0,
58 epsilon: 0.2,
59 binary_rerank: 4,
60 }
61 }
62}
63
64pub struct PrismIndex {
66 pub store: PointStore,
67 pub tree: PartitionTree,
68 pub graph: Graph,
69 pub local_graph: Graph,
71 pub medoids: Vec<u32>,
72 pub global_medoid: u32,
73 pub point_cell: Vec<u32>,
75 pub original_ids: Vec<u32>,
77 pub sq8: SQ8Store,
79 pub binary: BinaryStore,
81 pub config: PrismConfig,
82}
83
84impl PrismIndex {
85 pub fn build(mut store: PointStore, config: PrismConfig) -> Self {
87 let n = store.len;
88 assert!(n > 0, "cannot build index from empty point store");
89 assert!(
90 config.m_random >= 4 && config.m_random % 2 == 0,
91 "m_random must be >= 4 and even (Friedman model requires d >= 4)"
92 );
93
94 if config.metric == Metric::Cosine {
99 let dim = store.dim;
100 distance::normalize_rows(&mut store.vectors, dim);
101 }
102
103 let tree = PartitionTree::build(&store);
104 let (store, tree, original_ids) = reorder_by_cell(store, tree);
105 let sq8 = SQ8Store::build(&store);
106 let binary = if config.binary_rerank > 0 {
107 BinaryStore::build(&store)
108 } else {
109 BinaryStore::empty(store.dim)
110 };
111
112 let mut point_cell = vec![0u32; n];
113 for (ci, cell) in tree.cells.iter().enumerate() {
114 for &pid in &cell.point_ids {
115 point_cell[pid as usize] = ci as u32;
116 }
117 }
118
119 let mut adj = AdjBuilder::new(n);
121 build_local_edges(&store, &tree, &sq8, &config, &mut adj);
122
123 let medoids = compute_medoids(&store, &tree, config.metric);
124
125 let local_graph = adj.snapshot();
126
127 if config.sigma_high > config.sigma_low {
131 build_greedy_cross_edges(
133 &store,
134 &tree,
135 &medoids,
136 &local_graph,
137 &sq8,
138 &point_cell,
139 &config,
140 &mut adj,
141 );
142
143 build_random_overlay(n, config.m_random, &mut adj);
145 }
146
147 let graph = adj.build();
148
149 let global_medoid = compute_global_medoid(&store, config.metric);
150
151 Self {
152 store,
153 tree,
154 graph,
155 local_graph,
156 medoids,
157 global_medoid,
158 point_cell,
159 original_ids,
160 sq8,
161 binary,
162 config,
163 }
164 }
165}
166
167fn reorder_by_cell(
169 store: PointStore,
170 mut tree: PartitionTree,
171) -> (PointStore, PartitionTree, Vec<u32>) {
172 let n = store.len;
173 let dim = store.dim;
174 let k = store.k();
175
176 let mut new_order: Vec<u32> = Vec::with_capacity(n);
177 for cell in &tree.cells {
178 new_order.extend_from_slice(&cell.point_ids);
179 }
180
181 let mut old_to_new = vec![0u32; n];
182 for (new_id, &old_id) in new_order.iter().enumerate() {
183 old_to_new[old_id as usize] = new_id as u32;
184 }
185
186 let mut new_vectors = vec![0.0f32; n * dim];
187 for (new_id, &old_id) in new_order.iter().enumerate() {
188 let src = &store.vectors[old_id as usize * dim..(old_id as usize + 1) * dim];
189 new_vectors[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
190 }
191
192 let mut new_attrs = Vec::with_capacity(k);
193 for j in 0..k {
194 let mut attr_col = vec![0u32; n];
195 for (new_id, &old_id) in new_order.iter().enumerate() {
196 attr_col[new_id] = store.attrs[j][old_id as usize];
197 }
198 new_attrs.push(attr_col);
199 }
200
201 for cell in &mut tree.cells {
202 for pid in &mut cell.point_ids {
203 *pid = old_to_new[*pid as usize];
204 }
205 }
206
207 let new_store = PointStore::from_parts(new_vectors, dim, new_attrs);
208 (new_store, tree, new_order)
209}
210
211fn build_local_edges(
214 store: &PointStore,
215 tree: &PartitionTree,
216 sq8: &SQ8Store,
217 config: &PrismConfig,
218 adj: &mut AdjBuilder,
219) {
220 let cell_edges: Vec<Vec<(u32, u32)>> = tree
221 .cells
222 .par_iter()
223 .map(|cell| {
224 let pts = &cell.point_ids;
225 let mut edges = Vec::new();
226 if pts.len() <= 1 {
227 return edges;
228 }
229
230 if pts.len() <= config.m_local + 1 {
231 for i in 0..pts.len() {
232 for j in (i + 1)..pts.len() {
233 edges.push((pts[i], pts[j]));
234 edges.push((pts[j], pts[i]));
235 }
236 }
237 } else {
238 let mut rng = rand::thread_rng();
239 build_vamana_cell(store, sq8, pts, config, &mut edges, &mut rng);
240 }
241 edges
242 })
243 .collect();
244
245 for edges in cell_edges {
246 for (src, dst) in edges {
247 adj.add_edge(src, dst);
248 }
249 }
250}
251
252fn build_vamana_cell(
254 store: &PointStore,
255 sq8: &SQ8Store,
256 pts: &[u32],
257 config: &PrismConfig,
258 edges: &mut Vec<(u32, u32)>,
259 rng: &mut impl Rng,
260) {
261 let n = pts.len();
262 let r = config.m_local;
263 let beam = n.min(config.beam_width);
264 let alpha = config.vamana_alpha;
265
266 let actual_r = r.min(n - 1);
267 let mut graph: Vec<Vec<usize>> = (0..n)
268 .map(|i| {
269 let mut neighbors = Vec::with_capacity(actual_r);
270 while neighbors.len() < actual_r {
271 let j = rng.gen_range(0..n);
272 if j != i && !neighbors.contains(&j) {
273 neighbors.push(j);
274 }
275 }
276 neighbors
277 })
278 .collect();
279
280 let dim = store.dim;
281 let mut centroid = vec![0.0f32; dim];
282 for &p in pts {
283 let v = store.vector(p);
284 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
285 *c += x;
286 }
287 }
288 let inv_n = 1.0 / n as f32;
289 for c in &mut centroid {
290 *c *= inv_n;
291 }
292 let entry = (0..n)
293 .min_by(|&a, &b| {
294 let da = distance::distance(¢roid, store.vector(pts[a]), config.metric);
295 let db = distance::distance(¢roid, store.vector(pts[b]), config.metric);
296 da.partial_cmp(&db).unwrap()
297 })
298 .unwrap();
299
300 for _pass in 0..2 {
301 let mut order: Vec<usize> = (0..n).collect();
302 order.shuffle(rng);
303
304 for &i in &order {
305 let search_results =
306 vamana_search_code(store, sq8, config.metric, pts, &graph, entry, pts[i], beam);
307
308 let mut candidates = search_results;
309 for &nb in &graph[i] {
310 if !candidates.contains(&nb) {
311 candidates.push(nb);
312 }
313 }
314
315 graph[i] = robust_prune(store, pts, i, &candidates, alpha, r, config.metric);
316
317 let new_neighbors: Vec<usize> = graph[i].clone();
318 for &j in &new_neighbors {
319 if !graph[j].contains(&i) {
320 graph[j].push(i);
321 if graph[j].len() > r {
322 let cands: Vec<usize> = graph[j].clone();
323 graph[j] = robust_prune(store, pts, j, &cands, alpha, r, config.metric);
324 }
325 }
326 }
327 }
328 }
329
330 for (i, neighbors) in graph.iter().enumerate() {
331 for &j in neighbors {
332 edges.push((pts[i], pts[j]));
333 }
334 }
335}
336
337#[inline]
341fn build_cand_dist(store: &PointStore, sq8: &SQ8Store, metric: Metric, a: u32, b: u32) -> u32 {
342 match metric {
343 Metric::L2 | Metric::Cosine => distance::l2_sq8(sq8.code(a), sq8.code(b)),
344 Metric::InnerProduct => distance::ord_key(distance::distance(
345 store.vector(a),
346 store.vector(b),
347 Metric::InnerProduct,
348 )),
349 }
350}
351
352#[allow(clippy::too_many_arguments)]
354fn vamana_search_code(
355 store: &PointStore,
356 sq8: &SQ8Store,
357 metric: Metric,
358 pts: &[u32],
359 graph: &[Vec<usize>],
360 entry: usize,
361 query_id: u32,
362 beam: usize,
363) -> Vec<usize> {
364 use std::cmp::Reverse;
365 use std::collections::BinaryHeap;
366
367 let mut visited = vec![false; pts.len()];
368 let mut candidates: BinaryHeap<Reverse<(u32, usize)>> = BinaryHeap::new();
369 let mut results: BinaryHeap<(u32, usize)> = BinaryHeap::new();
370
371 let d = build_cand_dist(store, sq8, metric, query_id, pts[entry]);
372 visited[entry] = true;
373 candidates.push(Reverse((d, entry)));
374 results.push((d, entry));
375
376 while let Some(Reverse((d, c))) = candidates.pop() {
377 if results.len() >= beam {
378 if let Some(&(worst, _)) = results.peek() {
379 if d > worst {
380 break;
381 }
382 }
383 }
384
385 for &w in &graph[c] {
386 if visited[w] {
387 continue;
388 }
389 visited[w] = true;
390 let wd = build_cand_dist(store, sq8, metric, query_id, pts[w]);
391 candidates.push(Reverse((wd, w)));
392 results.push((wd, w));
393 if results.len() > beam {
394 results.pop();
395 }
396 }
397 }
398
399 results.into_iter().map(|(_, idx)| idx).collect()
400}
401
402fn robust_prune(
404 store: &PointStore,
405 pts: &[u32],
406 p: usize,
407 candidates: &[usize],
408 alpha: f32,
409 r: usize,
410 metric: Metric,
411) -> Vec<usize> {
412 let p_vec = store.vector(pts[p]);
413 let mut sorted: Vec<(usize, f32)> = candidates
414 .iter()
415 .filter(|&&c| c != p)
416 .map(|&c| (c, distance::distance(p_vec, store.vector(pts[c]), metric)))
417 .collect();
418 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
419 sorted.dedup_by_key(|x| x.0);
420
421 let mut selected: Vec<usize> = Vec::with_capacity(r);
422 for &(c, d_pc) in &sorted {
423 if selected.len() >= r {
424 break;
425 }
426 let dominated = selected.iter().any(|&s| {
427 let d_cs = distance::distance(store.vector(pts[c]), store.vector(pts[s]), metric);
428 alpha * d_cs <= d_pc
429 });
430 if !dominated {
431 selected.push(c);
432 }
433 }
434 selected
435}
436
437#[allow(clippy::too_many_arguments)]
440fn build_greedy_cross_edges(
441 store: &PointStore,
442 tree: &PartitionTree,
443 medoids: &[u32],
444 local_graph: &Graph,
445 sq8: &SQ8Store,
446 point_cell: &[u32],
447 config: &PrismConfig,
448 adj: &mut AdjBuilder,
449) {
450 let n = store.len;
451 let k = store.k();
452 let t = config.t.min(k);
453 let beam = config.beam_width;
454 let subsets = t_subsets(k, t);
455 let use_sq8 = config.metric != Metric::InnerProduct;
458
459 let point_edges: Vec<Vec<u32>> = (0..n as u32)
460 .into_par_iter()
461 .map(|p_id| {
462 let p_cell_idx = point_cell[p_id as usize];
463 let p_vec = store.vector(p_id);
464
465 let p_code = sq8.code(p_id);
466 let mut cell_dists: Vec<(usize, u32)> = tree
467 .cells
468 .iter()
469 .enumerate()
470 .filter(|&(ci, _)| ci as u32 != p_cell_idx)
471 .map(|(ci, _)| {
472 let d = distance::l2_sq8(p_code, sq8.code(medoids[ci]));
473 (ci, d)
474 })
475 .collect();
476 cell_dists.sort_unstable_by_key(|&(_, d)| d);
477
478 let mut all_cand_ids: Vec<u32> = Vec::with_capacity(beam);
479 for &(ci, _) in &cell_dists {
480 let cell_size = tree.cells[ci].point_ids.len();
481
482 if use_sq8 && cell_size > beam * 2 {
483 let found = beam_search_sq8(sq8, local_graph, p_code, medoids[ci], beam);
484 for (id, _) in found {
485 all_cand_ids.push(id);
486 }
487 } else if use_sq8 {
488 let mut scored: Vec<(u32, u32)> = tree.cells[ci]
489 .point_ids
490 .iter()
491 .map(|&q| (q, distance::l2_sq8(p_code, sq8.code(q))))
492 .collect();
493 scored.sort_unstable_by_key(|&(_, d)| d);
494 for &(id, _) in scored.iter().take(beam) {
495 all_cand_ids.push(id);
496 }
497 } else {
498 for &q_id in &tree.cells[ci].point_ids {
499 all_cand_ids.push(q_id);
500 }
501 }
502
503 if all_cand_ids.len() >= beam {
504 break;
505 }
506 }
507
508 let mut candidates: Vec<(u32, f32)> = all_cand_ids
509 .iter()
510 .map(|&id| {
511 (
512 id,
513 distance::distance(p_vec, store.vector(id), config.metric),
514 )
515 })
516 .collect();
517 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
518 candidates.truncate(beam);
519
520 select_cross_neighbors(store, &candidates, config, &subsets)
521 })
522 .collect();
523
524 for (p_id, neighbors) in point_edges.into_iter().enumerate() {
525 for q_id in neighbors {
526 adj.add_edge(p_id as u32, q_id);
527 }
528 }
529}
530
531fn beam_search_sq8(
533 sq8: &SQ8Store,
534 graph: &Graph,
535 query_code: &[u8],
536 entry: u32,
537 beam: usize,
538) -> Vec<(u32, u32)> {
539 use std::cmp::Reverse;
540 use std::collections::BinaryHeap;
541
542 let mut visited = HashSet::new();
543 let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
544 let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
545
546 let d = distance::l2_sq8(query_code, sq8.code(entry));
547 visited.insert(entry);
548 candidates.push(Reverse((d, entry)));
549 results.push((d, entry));
550
551 while let Some(Reverse((d, c))) = candidates.pop() {
552 if results.len() >= beam {
553 if let Some(&(worst, _)) = results.peek() {
554 if d > worst {
555 break;
556 }
557 }
558 }
559
560 for &w in graph.neighbors(c) {
561 if !visited.insert(w) {
562 continue;
563 }
564 let wd = distance::l2_sq8(query_code, sq8.code(w));
565 candidates.push(Reverse((wd, w)));
566 results.push((wd, w));
567 if results.len() > beam {
568 results.pop();
569 }
570 }
571 }
572
573 results.into_iter().map(|(d, id)| (id, d)).collect()
574}
575
576pub(crate) fn select_cross_neighbors(
578 store: &PointStore,
579 candidates: &[(u32, f32)],
580 config: &PrismConfig,
581 subsets: &[Vec<usize>],
582) -> Vec<u32> {
583 let m_g = config.m_greedy;
584 let alpha = config.alpha;
585
586 if candidates.is_empty() || m_g == 0 {
587 return Vec::new();
588 }
589
590 let mut covered: HashSet<u64> = HashSet::new();
591 let mut selected = Vec::with_capacity(m_g);
592 let mut available: Vec<bool> = vec![true; candidates.len()];
593
594 for _ in 0..m_g {
595 let mut best_idx = None;
596 let mut best_score = f32::NEG_INFINITY;
597
598 for (idx, &(q_id, dist)) in candidates.iter().enumerate() {
599 if !available[idx] {
600 continue;
601 }
602
603 let new_tuples = count_new_tuples(store, q_id, &covered, subsets);
604
605 let score = if alpha == 0.0 || dist == 0.0 {
607 new_tuples as f32
608 } else {
609 (new_tuples as f32 + 0.001) / dist.powf(alpha)
610 };
611
612 if score > best_score {
613 best_score = score;
614 best_idx = Some(idx);
615 }
616 }
617
618 let Some(idx) = best_idx else { break };
619 selected.push(candidates[idx].0);
620 available[idx] = false;
621
622 add_tuples(store, candidates[idx].0, &mut covered, subsets);
623 }
624
625 selected
626}
627
628#[inline]
630fn tuple_key(combo: &[usize], store: &PointStore, q: u32) -> u64 {
631 let mut key: u64 = 0;
632 for (i, &j) in combo.iter().enumerate() {
633 let val = store.attr(q, j) as u64;
634 key |= ((j as u64) << 8 | val) << (i * 16);
635 }
636 key
637}
638
639fn count_new_tuples(
641 store: &PointStore,
642 q: u32,
643 covered: &HashSet<u64>,
644 subsets: &[Vec<usize>],
645) -> usize {
646 let mut count = 0;
647 for combo in subsets {
648 let key = tuple_key(combo, store, q);
649 if !covered.contains(&key) {
650 count += 1;
651 }
652 }
653 count
654}
655
656pub(crate) fn add_tuples(
658 store: &PointStore,
659 q: u32,
660 covered: &mut HashSet<u64>,
661 subsets: &[Vec<usize>],
662) {
663 for combo in subsets {
664 let key = tuple_key(combo, store, q);
665 covered.insert(key);
666 }
667}
668
/// Returns every `t`-element subset of `0..k` as a sorted index list, in lexicographic order.
///
/// `t == 0` yields a single empty subset; `t > k` yields none.
pub(crate) fn t_subsets(k: usize, t: usize) -> Vec<Vec<usize>> {
    let mut out = Vec::new();
    let mut prefix = Vec::with_capacity(t);
    generate_subsets(k, t, 0, &mut prefix, &mut out);
    out
}

/// Recursive helper for `t_subsets`: extends `combo` with indices `>= start` until it
/// holds `t` entries, pushing each completed combination into `result`.
fn generate_subsets(
    k: usize,
    t: usize,
    start: usize,
    combo: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    if combo.len() != t {
        for next in start..k {
            combo.push(next);
            generate_subsets(k, t, next + 1, combo, result);
            combo.pop();
        }
    } else {
        result.push(combo.to_vec());
    }
}

695pub(crate) fn build_random_overlay(n: usize, m_random: usize, adj: &mut AdjBuilder) {
697 if m_random == 0 || n <= 1 {
698 return;
699 }
700 let mut rng = rand::thread_rng();
701 let half = m_random / 2;
702
703 for _ in 0..half {
704 let mut perm: Vec<u32> = (0..n as u32).collect();
705 perm.shuffle(&mut rng);
706 for (i, &j) in perm.iter().enumerate() {
707 if i as u32 != j {
708 adj.add_undirected(i as u32, j);
709 }
710 }
711 }
712}
713
714fn compute_medoids(store: &PointStore, tree: &PartitionTree, metric: Metric) -> Vec<u32> {
716 let dim = store.dim;
717 tree.cells
718 .iter()
719 .map(|cell| {
720 let pts = &cell.point_ids;
721 if pts.len() == 1 {
722 return pts[0];
723 }
724 let mut centroid = vec![0.0f32; dim];
725 for &p in pts {
726 let v = store.vector(p);
727 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
728 *c += x;
729 }
730 }
731 let inv_n = 1.0 / pts.len() as f32;
732 for c in &mut centroid {
733 *c *= inv_n;
734 }
735 *pts.iter()
736 .min_by(|&&a, &&b| {
737 let da = distance::distance(¢roid, store.vector(a), metric);
738 let db = distance::distance(¢roid, store.vector(b), metric);
739 da.partial_cmp(&db).unwrap()
740 })
741 .unwrap()
742 })
743 .collect()
744}
745
746fn compute_global_medoid(store: &PointStore, metric: Metric) -> u32 {
748 let n = store.len;
749 let dim = store.dim;
750 let mut centroid = vec![0.0f32; dim];
751 for i in 0..n as u32 {
752 let v = store.vector(i);
753 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
754 *c += x;
755 }
756 }
757 let inv_n = 1.0 / n as f32;
758 for c in &mut centroid {
759 *c *= inv_n;
760 }
761 (0..n as u32)
762 .min_by(|&a, &b| {
763 let da = distance::distance(¢roid, store.vector(a), metric);
764 let db = distance::distance(¢roid, store.vector(b), metric);
765 da.partial_cmp(&db).unwrap()
766 })
767 .unwrap()
768}
769
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    #[test]
    fn test_build_small() {
        let mut store = PointStore::new(2, 2);
        // Four corners of the unit square, each with a distinct attribute pair.
        store.push(&[0.0, 0.0], &[0, 0]);
        store.push(&[1.0, 0.0], &[0, 1]);
        store.push(&[0.0, 1.0], &[1, 0]);
        store.push(&[1.0, 1.0], &[1, 1]);

        let config = PrismConfig {
            m_local: 2,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            alpha: 0.0,
            beam_width: 10,
            ..Default::default()
        };

        let index = PrismIndex::build(store, config);
        assert_eq!(index.tree.cells.len(), 4);
        assert_eq!(index.medoids.len(), 4);

        // Every point must have at least one neighbor in the final graph.
        for i in 0..4u32 {
            assert!(index.graph.degree(i) > 0);
        }

        // Reordering must be a permutation of the input ids.
        let mut ids = index.original_ids.clone();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 1, 2, 3]);

        // Each medoid belongs to its own cell, and `point_cell` agrees with the tree.
        for (ci, cell) in index.tree.cells.iter().enumerate() {
            assert!(cell.point_ids.contains(&index.medoids[ci]));
            for &pid in &cell.point_ids {
                assert_eq!(index.point_cell[pid as usize], ci as u32);
            }
        }
    }

    #[test]
    fn test_t_subsets() {
        let subs = t_subsets(4, 2);
        assert_eq!(subs.len(), 6);
        assert_eq!(subs.first(), Some(&vec![0, 1]));
        assert_eq!(subs.last(), Some(&vec![2, 3]));
        assert!(subs.iter().all(|s| s.len() == 2 && s[0] < s[1]));

        assert_eq!(t_subsets(3, 1), vec![vec![0], vec![1], vec![2]]);

        // Edge cases: the single empty subset, and t larger than k.
        assert_eq!(t_subsets(3, 0), vec![Vec::<usize>::new()]);
        assert!(t_subsets(2, 3).is_empty());
    }
}