1use super::binary::BinaryStore;
2use super::distance::{self, Metric};
3use super::graph::{AdjBuilder, Graph};
4use super::partition::PartitionTree;
5use super::point::PointStore;
6use super::quantize::SQ8Store;
7
8use rand::prelude::*;
9use rayon::prelude::*;
10use std::collections::HashSet;
11
12#[derive(Clone, Debug)]
14pub struct PrismConfig {
15 pub m_local: usize,
17 pub m_greedy: usize,
19 pub m_random: usize,
21 pub t: usize,
23 pub alpha: f32,
25 pub vamana_alpha: f32,
27 pub beam_width: usize,
29 pub metric: Metric,
31 pub sigma_high: f32,
33 pub sigma_low: f32,
35 pub beta: f32,
37 pub epsilon: f32,
39 pub binary_rerank: usize,
42}
43
44impl Default for PrismConfig {
45 fn default() -> Self {
46 Self {
47 m_local: 16,
48 m_greedy: 12,
49 m_random: 4,
50 t: 2,
51 alpha: 1.0,
52 vamana_alpha: 1.0,
53 beam_width: 120,
54 metric: Metric::L2,
55 sigma_high: 0.10,
56 sigma_low: 0.001,
57 beta: 3.0,
58 epsilon: 0.2,
59 binary_rerank: 4,
60 }
61 }
62}
63
64pub struct PrismIndex {
66 pub store: PointStore,
67 pub tree: PartitionTree,
68 pub graph: Graph,
69 pub local_graph: Graph,
71 pub medoids: Vec<u32>,
72 pub global_medoid: u32,
73 pub point_cell: Vec<u32>,
75 pub original_ids: Vec<u32>,
77 pub sq8: SQ8Store,
79 pub binary: BinaryStore,
81 pub config: PrismConfig,
82}
83
84impl PrismIndex {
85 pub fn build(store: PointStore, config: PrismConfig) -> Self {
87 let n = store.len;
88 assert!(n > 0, "cannot build index from empty point store");
89 assert!(
90 config.m_random >= 4 && config.m_random % 2 == 0,
91 "m_random must be >= 4 and even (Friedman model requires d >= 4)"
92 );
93
94 let tree = PartitionTree::build(&store);
95 let (store, tree, original_ids) = reorder_by_cell(store, tree);
96 let sq8 = SQ8Store::build(&store);
97 let binary = BinaryStore::build(&store);
98
99 let mut point_cell = vec![0u32; n];
100 for (ci, cell) in tree.cells.iter().enumerate() {
101 for &pid in &cell.point_ids {
102 point_cell[pid as usize] = ci as u32;
103 }
104 }
105
106 let mut adj = AdjBuilder::new(n);
108 let t0 = std::time::Instant::now();
109 build_local_edges(&store, &tree, &sq8, &config, &mut adj);
110 let local_edges = adj.total_edges();
111 eprintln!(
112 " Local edges: {:.1}s, {} edges ({:.1}/node)",
113 t0.elapsed().as_secs_f64(),
114 local_edges,
115 local_edges as f64 / n as f64
116 );
117
118 let t0 = std::time::Instant::now();
119 let medoids = compute_medoids(&store, &tree, config.metric);
120 eprintln!(" Medoids: {:.1}s", t0.elapsed().as_secs_f64());
121
122 let local_graph = adj.snapshot();
123
124 let t0 = std::time::Instant::now();
126 build_greedy_cross_edges(
127 &store,
128 &tree,
129 &medoids,
130 &local_graph,
131 &sq8,
132 &point_cell,
133 &config,
134 &mut adj,
135 );
136 let cross_edges = adj.total_edges() - local_edges;
137 eprintln!(
138 " Cross edges: {:.1}s, {} edges ({:.1}/node)",
139 t0.elapsed().as_secs_f64(),
140 cross_edges,
141 cross_edges as f64 / n as f64
142 );
143
144 let edges_before = adj.total_edges();
146 let t0 = std::time::Instant::now();
147 build_random_overlay(n, config.m_random, &mut adj);
148 let random_edges = adj.total_edges() - edges_before;
149 eprintln!(
150 " Random overlay: {:.1}s, {} edges ({:.1}/node)",
151 t0.elapsed().as_secs_f64(),
152 random_edges,
153 random_edges as f64 / n as f64
154 );
155
156 let graph = adj.build();
157
158 let global_medoid = compute_global_medoid(&store, config.metric);
159
160 Self {
161 store,
162 tree,
163 graph,
164 local_graph,
165 medoids,
166 global_medoid,
167 point_cell,
168 original_ids,
169 sq8,
170 binary,
171 config,
172 }
173 }
174}
175
176fn reorder_by_cell(
178 store: PointStore,
179 mut tree: PartitionTree,
180) -> (PointStore, PartitionTree, Vec<u32>) {
181 let n = store.len;
182 let dim = store.dim;
183 let k = store.k();
184
185 let mut new_order: Vec<u32> = Vec::with_capacity(n);
187 for cell in &tree.cells {
188 new_order.extend_from_slice(&cell.point_ids);
189 }
190
191 let mut old_to_new = vec![0u32; n];
193 for (new_id, &old_id) in new_order.iter().enumerate() {
194 old_to_new[old_id as usize] = new_id as u32;
195 }
196
197 let mut new_vectors = vec![0.0f32; n * dim];
199 for (new_id, &old_id) in new_order.iter().enumerate() {
200 let src = &store.vectors[old_id as usize * dim..(old_id as usize + 1) * dim];
201 new_vectors[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
202 }
203
204 let mut new_attrs = Vec::with_capacity(k);
206 for j in 0..k {
207 let mut attr_col = vec![0u32; n];
208 for (new_id, &old_id) in new_order.iter().enumerate() {
209 attr_col[new_id] = store.attrs[j][old_id as usize];
210 }
211 new_attrs.push(attr_col);
212 }
213
214 for cell in &mut tree.cells {
216 for pid in &mut cell.point_ids {
217 *pid = old_to_new[*pid as usize];
218 }
219 }
220
221 let new_store = PointStore::from_parts(new_vectors, dim, new_attrs);
222 (new_store, tree, new_order)
223}
224
225fn build_local_edges(
228 store: &PointStore,
229 tree: &PartitionTree,
230 sq8: &SQ8Store,
231 config: &PrismConfig,
232 adj: &mut AdjBuilder,
233) {
234 let cell_edges: Vec<Vec<(u32, u32)>> = tree
235 .cells
236 .par_iter()
237 .map(|cell| {
238 let pts = &cell.point_ids;
239 let mut edges = Vec::new();
240 if pts.len() <= 1 {
241 return edges;
242 }
243
244 if pts.len() <= config.m_local + 1 {
245 for i in 0..pts.len() {
246 for j in (i + 1)..pts.len() {
247 edges.push((pts[i], pts[j]));
248 edges.push((pts[j], pts[i]));
249 }
250 }
251 } else {
252 let mut rng = rand::thread_rng();
253 build_vamana_cell(store, sq8, pts, config, &mut edges, &mut rng);
254 }
255 edges
256 })
257 .collect();
258
259 for edges in cell_edges {
260 for (src, dst) in edges {
261 adj.add_edge(src, dst);
262 }
263 }
264}
265
266fn build_vamana_cell(
268 store: &PointStore,
269 sq8: &SQ8Store,
270 pts: &[u32],
271 config: &PrismConfig,
272 edges: &mut Vec<(u32, u32)>,
273 rng: &mut impl Rng,
274) {
275 let n = pts.len();
276 let r = config.m_local;
277 let beam = n.min(config.beam_width);
278 let alpha = config.vamana_alpha;
279
280 let actual_r = r.min(n - 1);
282 let mut graph: Vec<Vec<usize>> = (0..n)
283 .map(|i| {
284 let mut neighbors = Vec::with_capacity(actual_r);
285 while neighbors.len() < actual_r {
286 let j = rng.gen_range(0..n);
287 if j != i && !neighbors.contains(&j) {
288 neighbors.push(j);
289 }
290 }
291 neighbors
292 })
293 .collect();
294
295 let dim = store.dim;
297 let mut centroid = vec![0.0f32; dim];
298 for &p in pts {
299 let v = store.vector(p);
300 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
301 *c += x;
302 }
303 }
304 let inv_n = 1.0 / n as f32;
305 for c in &mut centroid {
306 *c *= inv_n;
307 }
308 let entry = (0..n)
309 .min_by(|&a, &b| {
310 let da = distance::distance(¢roid, store.vector(pts[a]), config.metric);
311 let db = distance::distance(¢roid, store.vector(pts[b]), config.metric);
312 da.partial_cmp(&db).unwrap()
313 })
314 .unwrap();
315
316 for _pass in 0..2 {
317 let mut order: Vec<usize> = (0..n).collect();
318 order.shuffle(rng);
319
320 for &i in &order {
321 let search_results = vamana_search_sq8(sq8, pts, &graph, entry, pts[i], beam);
322
323 let mut candidates = search_results;
325 for &nb in &graph[i] {
326 if !candidates.contains(&nb) {
327 candidates.push(nb);
328 }
329 }
330
331 graph[i] = robust_prune(store, pts, i, &candidates, alpha, r, config.metric);
332
333 let new_neighbors: Vec<usize> = graph[i].clone();
335 for &j in &new_neighbors {
336 if !graph[j].contains(&i) {
337 graph[j].push(i);
338 if graph[j].len() > r {
339 let cands: Vec<usize> = graph[j].clone();
340 graph[j] = robust_prune(store, pts, j, &cands, alpha, r, config.metric);
341 }
342 }
343 }
344 }
345 }
346
347 for (i, neighbors) in graph.iter().enumerate() {
348 for &j in neighbors {
349 edges.push((pts[i], pts[j]));
350 }
351 }
352}
353
354fn vamana_search_sq8(
356 sq8: &SQ8Store,
357 pts: &[u32],
358 graph: &[Vec<usize>],
359 entry: usize,
360 query_id: u32,
361 beam: usize,
362) -> Vec<usize> {
363 use std::cmp::Reverse;
364 use std::collections::BinaryHeap;
365
366 let q_code = sq8.code(query_id);
367 let mut visited = vec![false; pts.len()];
368 let mut candidates: BinaryHeap<Reverse<(u32, usize)>> = BinaryHeap::new();
369 let mut results: BinaryHeap<(u32, usize)> = BinaryHeap::new();
370
371 let d = distance::l2_sq8(q_code, sq8.code(pts[entry]));
372 visited[entry] = true;
373 candidates.push(Reverse((d, entry)));
374 results.push((d, entry));
375
376 while let Some(Reverse((d, c))) = candidates.pop() {
377 if results.len() >= beam {
378 if let Some(&(worst, _)) = results.peek() {
379 if d > worst {
380 break;
381 }
382 }
383 }
384
385 for &w in &graph[c] {
386 if visited[w] {
387 continue;
388 }
389 visited[w] = true;
390 let wd = distance::l2_sq8(q_code, sq8.code(pts[w]));
391 candidates.push(Reverse((wd, w)));
392 results.push((wd, w));
393 if results.len() > beam {
394 results.pop();
395 }
396 }
397 }
398
399 results.into_iter().map(|(_, idx)| idx).collect()
400}
401
402fn robust_prune(
404 store: &PointStore,
405 pts: &[u32],
406 p: usize,
407 candidates: &[usize],
408 alpha: f32,
409 r: usize,
410 metric: Metric,
411) -> Vec<usize> {
412 let p_vec = store.vector(pts[p]);
413 let mut sorted: Vec<(usize, f32)> = candidates
414 .iter()
415 .filter(|&&c| c != p)
416 .map(|&c| (c, distance::distance(p_vec, store.vector(pts[c]), metric)))
417 .collect();
418 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
419 sorted.dedup_by_key(|x| x.0);
420
421 let mut selected: Vec<usize> = Vec::with_capacity(r);
422 for &(c, d_pc) in &sorted {
423 if selected.len() >= r {
424 break;
425 }
426 let dominated = selected.iter().any(|&s| {
427 let d_cs = distance::distance(store.vector(pts[c]), store.vector(pts[s]), metric);
428 alpha * d_cs <= d_pc
429 });
430 if !dominated {
431 selected.push(c);
432 }
433 }
434 selected
435}
436
437#[allow(clippy::too_many_arguments)]
440fn build_greedy_cross_edges(
441 store: &PointStore,
442 tree: &PartitionTree,
443 medoids: &[u32],
444 local_graph: &Graph,
445 sq8: &SQ8Store,
446 point_cell: &[u32],
447 config: &PrismConfig,
448 adj: &mut AdjBuilder,
449) {
450 let n = store.len;
451 let k = store.k();
452 let t = config.t.min(k);
453 let beam = config.beam_width;
454 let subsets = t_subsets(k, t);
455 let use_sq8 = config.metric == Metric::L2;
456
457 let point_edges: Vec<Vec<u32>> = (0..n as u32)
458 .into_par_iter()
459 .map(|p_id| {
460 let p_cell_idx = point_cell[p_id as usize];
461 let p_vec = store.vector(p_id);
462
463 let p_code = sq8.code(p_id);
465 let mut cell_dists: Vec<(usize, u32)> = tree
466 .cells
467 .iter()
468 .enumerate()
469 .filter(|&(ci, _)| ci as u32 != p_cell_idx)
470 .map(|(ci, _)| {
471 let d = distance::l2_sq8(p_code, sq8.code(medoids[ci]));
472 (ci, d)
473 })
474 .collect();
475 cell_dists.sort_unstable_by_key(|&(_, d)| d);
476
477 let mut all_cand_ids: Vec<u32> = Vec::with_capacity(beam);
479 for &(ci, _) in &cell_dists {
480 let cell_size = tree.cells[ci].point_ids.len();
481
482 if use_sq8 && cell_size > beam * 2 {
483 let found = beam_search_sq8(sq8, local_graph, p_code, medoids[ci], beam);
484 for (id, _) in found {
485 all_cand_ids.push(id);
486 }
487 } else if use_sq8 {
488 let mut scored: Vec<(u32, u32)> = tree.cells[ci]
489 .point_ids
490 .iter()
491 .map(|&q| (q, distance::l2_sq8(p_code, sq8.code(q))))
492 .collect();
493 scored.sort_unstable_by_key(|&(_, d)| d);
494 for &(id, _) in scored.iter().take(beam) {
495 all_cand_ids.push(id);
496 }
497 } else {
498 for &q_id in &tree.cells[ci].point_ids {
499 all_cand_ids.push(q_id);
500 }
501 }
502
503 if all_cand_ids.len() >= beam {
504 break;
505 }
506 }
507
508 let mut candidates: Vec<(u32, f32)> = all_cand_ids
510 .iter()
511 .map(|&id| {
512 (
513 id,
514 distance::distance(p_vec, store.vector(id), config.metric),
515 )
516 })
517 .collect();
518 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
519 candidates.truncate(beam);
520
521 select_cross_neighbors(store, &candidates, config, &subsets)
522 })
523 .collect();
524
525 for (p_id, neighbors) in point_edges.into_iter().enumerate() {
526 for q_id in neighbors {
527 adj.add_edge(p_id as u32, q_id);
528 }
529 }
530}
531
532fn beam_search_sq8(
534 sq8: &SQ8Store,
535 graph: &Graph,
536 query_code: &[u8],
537 entry: u32,
538 beam: usize,
539) -> Vec<(u32, u32)> {
540 use std::cmp::Reverse;
541 use std::collections::BinaryHeap;
542
543 let mut visited = HashSet::new();
544 let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
545 let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
546
547 let d = distance::l2_sq8(query_code, sq8.code(entry));
548 visited.insert(entry);
549 candidates.push(Reverse((d, entry)));
550 results.push((d, entry));
551
552 while let Some(Reverse((d, c))) = candidates.pop() {
553 if results.len() >= beam {
554 if let Some(&(worst, _)) = results.peek() {
555 if d > worst {
556 break;
557 }
558 }
559 }
560
561 for &w in graph.neighbors(c) {
562 if !visited.insert(w) {
563 continue;
564 }
565 let wd = distance::l2_sq8(query_code, sq8.code(w));
566 candidates.push(Reverse((wd, w)));
567 results.push((wd, w));
568 if results.len() > beam {
569 results.pop();
570 }
571 }
572 }
573
574 results.into_iter().map(|(d, id)| (id, d)).collect()
575}
576
577pub(crate) fn select_cross_neighbors(
579 store: &PointStore,
580 candidates: &[(u32, f32)],
581 config: &PrismConfig,
582 subsets: &[Vec<usize>],
583) -> Vec<u32> {
584 let m_g = config.m_greedy;
585 let alpha = config.alpha;
586
587 if candidates.is_empty() || m_g == 0 {
588 return Vec::new();
589 }
590
591 let mut covered: HashSet<u64> = HashSet::new();
592 let mut selected = Vec::with_capacity(m_g);
593 let mut available: Vec<bool> = vec![true; candidates.len()];
594
595 for _ in 0..m_g {
596 let mut best_idx = None;
597 let mut best_score = f32::NEG_INFINITY;
598
599 for (idx, &(q_id, dist)) in candidates.iter().enumerate() {
600 if !available[idx] {
601 continue;
602 }
603
604 let new_tuples = count_new_tuples(store, q_id, &covered, subsets);
605
606 let score = if alpha == 0.0 || dist == 0.0 {
608 new_tuples as f32
609 } else {
610 (new_tuples as f32 + 0.001) / dist.powf(alpha)
611 };
612
613 if score > best_score {
614 best_score = score;
615 best_idx = Some(idx);
616 }
617 }
618
619 let Some(idx) = best_idx else { break };
620 selected.push(candidates[idx].0);
621 available[idx] = false;
622
623 add_tuples(store, candidates[idx].0, &mut covered, subsets);
624 }
625
626 selected
627}
628
629#[inline]
631fn tuple_key(combo: &[usize], store: &PointStore, q: u32) -> u64 {
632 let mut key: u64 = 0;
633 for (i, &j) in combo.iter().enumerate() {
634 let val = store.attr(q, j) as u64;
635 key |= ((j as u64) << 8 | val) << (i * 16);
636 }
637 key
638}
639
640fn count_new_tuples(
642 store: &PointStore,
643 q: u32,
644 covered: &HashSet<u64>,
645 subsets: &[Vec<usize>],
646) -> usize {
647 let mut count = 0;
648 for combo in subsets {
649 let key = tuple_key(combo, store, q);
650 if !covered.contains(&key) {
651 count += 1;
652 }
653 }
654 count
655}
656
657pub(crate) fn add_tuples(
659 store: &PointStore,
660 q: u32,
661 covered: &mut HashSet<u64>,
662 subsets: &[Vec<usize>],
663) {
664 for combo in subsets {
665 let key = tuple_key(combo, store, q);
666 covered.insert(key);
667 }
668}
669
/// Lists every `t`-element subset of `0..k` as an ascending index vector.
///
/// Subsets come in lexicographic order. `t == 0` yields one empty subset,
/// and `t > k` yields none.
pub(crate) fn t_subsets(k: usize, t: usize) -> Vec<Vec<usize>> {
    let mut subsets = Vec::new();
    generate_subsets(k, t, 0, &mut Vec::with_capacity(t), &mut subsets);
    subsets
}

/// Recursive helper for [`t_subsets`].
///
/// Extends the prefix `combo` with indices `>= start` and records every
/// completed `t`-element combination in `result`.
fn generate_subsets(
    k: usize,
    t: usize,
    start: usize,
    combo: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    if combo.len() == t {
        result.push(combo.clone());
    } else {
        for next in start..k {
            combo.push(next);
            generate_subsets(k, t, next + 1, combo, result);
            combo.pop();
        }
    }
}
695
696pub(crate) fn build_random_overlay(n: usize, m_random: usize, adj: &mut AdjBuilder) {
698 if m_random == 0 || n <= 1 {
699 return;
700 }
701 let mut rng = rand::thread_rng();
702 let half = m_random / 2;
703
704 for _ in 0..half {
705 let mut perm: Vec<u32> = (0..n as u32).collect();
707 perm.shuffle(&mut rng);
708 for (i, &j) in perm.iter().enumerate() {
709 if i as u32 != j {
710 adj.add_undirected(i as u32, j);
711 }
712 }
713 }
714}
715
716fn compute_medoids(store: &PointStore, tree: &PartitionTree, metric: Metric) -> Vec<u32> {
718 let dim = store.dim;
719 tree.cells
720 .iter()
721 .map(|cell| {
722 let pts = &cell.point_ids;
723 if pts.len() == 1 {
724 return pts[0];
725 }
726 let mut centroid = vec![0.0f32; dim];
728 for &p in pts {
729 let v = store.vector(p);
730 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
731 *c += x;
732 }
733 }
734 let inv_n = 1.0 / pts.len() as f32;
735 for c in &mut centroid {
736 *c *= inv_n;
737 }
738 *pts.iter()
740 .min_by(|&&a, &&b| {
741 let da = distance::distance(¢roid, store.vector(a), metric);
742 let db = distance::distance(¢roid, store.vector(b), metric);
743 da.partial_cmp(&db).unwrap()
744 })
745 .unwrap()
746 })
747 .collect()
748}
749
750fn compute_global_medoid(store: &PointStore, metric: Metric) -> u32 {
752 let n = store.len;
753 let dim = store.dim;
754 let mut centroid = vec![0.0f32; dim];
755 for i in 0..n as u32 {
756 let v = store.vector(i);
757 for (c, &x) in centroid.iter_mut().zip(v.iter()) {
758 *c += x;
759 }
760 }
761 let inv_n = 1.0 / n as f32;
762 for c in &mut centroid {
763 *c *= inv_n;
764 }
765 (0..n as u32)
766 .min_by(|&a, &b| {
767 let da = distance::distance(¢roid, store.vector(a), metric);
768 let db = distance::distance(¢roid, store.vector(b), metric);
769 da.partial_cmp(&db).unwrap()
770 })
771 .unwrap()
772}
773
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// Builds a tiny index over the four corners of the unit square, each with
    /// a distinct attribute pair, and checks its basic structure.
    #[test]
    fn test_build_small() {
        let mut store = PointStore::new(2, 2);
        store.push(&[0.0, 0.0], &[0, 0]);
        store.push(&[1.0, 0.0], &[0, 1]);
        store.push(&[0.0, 1.0], &[1, 0]);
        store.push(&[1.0, 1.0], &[1, 1]);

        let index = PrismIndex::build(
            store,
            PrismConfig {
                m_local: 2,
                m_greedy: 2,
                m_random: 4,
                t: 1,
                alpha: 0.0,
                beam_width: 10,
                ..Default::default()
            },
        );

        assert_eq!(index.tree.cells.len(), 4);
        assert_eq!(index.medoids.len(), 4);
        // The random overlay alone guarantees every point has some edge.
        for point in 0..4u32 {
            assert!(index.graph.degree(point) > 0, "point {point} has no edges");
        }
    }

    /// `t_subsets(k, t)` yields C(k, t) combinations.
    #[test]
    fn test_t_subsets() {
        assert_eq!(t_subsets(4, 2).len(), 6);
        assert_eq!(t_subsets(3, 1).len(), 3);
    }
}