// citadel_vector/vendored/prism/construct.rs

1use super::binary::BinaryStore;
2use super::distance::{self, Metric};
3use super::graph::{AdjBuilder, Graph};
4use super::partition::PartitionTree;
5use super::point::PointStore;
6use super::quantize::SQ8Store;
7
8use rand::prelude::*;
9use rayon::prelude::*;
10use std::collections::HashSet;
11
12/// Configuration for PRISM index construction.
13#[derive(Clone, Debug)]
14pub struct PrismConfig {
15    /// Local degree (edges within each leaf cell).
16    pub m_local: usize,
17    /// Greedy cross-partition degree.
18    pub m_greedy: usize,
19    /// Random cross-partition degree (must be even).
20    pub m_random: usize,
21    /// Covering strength for attribute-diverse selection.
22    pub t: usize,
23    /// Proximity-diversity tradeoff for cross-neighbor selection (0 = pure diversity).
24    pub alpha: f32,
25    /// Vamana pruning parameter (standard DiskANN: 1.2).
26    pub vamana_alpha: f32,
27    /// Beam width for candidate search during construction (paper: 10 * M_g).
28    pub beam_width: usize,
29    /// Distance metric.
30    pub metric: Metric,
31    /// Selectivity threshold for HIGH regime.
32    pub sigma_high: f32,
33    /// Selectivity threshold for LOW regime.
34    pub sigma_low: f32,
35    /// Bridge budget multiplier for MID regime.
36    pub beta: f32,
37    /// Search pruning tolerance for filtered queries.
38    pub epsilon: f32,
39    /// Binary pre-filter rerank factor. Top `binary_rerank * ef` Hamming candidates
40    /// are reranked with SQ8. 0 disables binary pre-filter.
41    pub binary_rerank: usize,
42}
43
44impl Default for PrismConfig {
45    fn default() -> Self {
46        Self {
47            m_local: 16,
48            m_greedy: 12,
49            m_random: 4,
50            t: 2,
51            alpha: 1.0,
52            vamana_alpha: 1.0,
53            beam_width: 120,
54            metric: Metric::L2,
55            sigma_high: 0.10,
56            sigma_low: 0.001,
57            beta: 3.0,
58            epsilon: 0.2,
59            binary_rerank: 4,
60        }
61    }
62}
63
64/// The complete PRISM index.
65pub struct PrismIndex {
66    pub store: PointStore,
67    pub tree: PartitionTree,
68    pub graph: Graph,
69    /// Local-only graph (intra-cell edges) for per-cell graph search.
70    pub local_graph: Graph,
71    pub medoids: Vec<u32>,
72    pub global_medoid: u32,
73    /// Reverse mapping: point_id → cell index.
74    pub point_cell: Vec<u32>,
75    /// Maps internal ID → original ID.
76    pub original_ids: Vec<u32>,
77    /// Scalar-quantized vectors for distance computation.
78    pub sq8: SQ8Store,
79    /// Binary codes for Hamming pre-filter.
80    pub binary: BinaryStore,
81    pub config: PrismConfig,
82}
83
84impl PrismIndex {
85    /// Build a PRISM index from a PointStore (Algorithm 2).
86    pub fn build(store: PointStore, config: PrismConfig) -> Self {
87        let n = store.len;
88        assert!(n > 0, "cannot build index from empty point store");
89        assert!(
90            config.m_random >= 4 && config.m_random % 2 == 0,
91            "m_random must be >= 4 and even (Friedman model requires d >= 4)"
92        );
93
94        let tree = PartitionTree::build(&store);
95        let (store, tree, original_ids) = reorder_by_cell(store, tree);
96        let sq8 = SQ8Store::build(&store);
97        let binary = BinaryStore::build(&store);
98
99        let mut point_cell = vec![0u32; n];
100        for (ci, cell) in tree.cells.iter().enumerate() {
101            for &pid in &cell.point_ids {
102                point_cell[pid as usize] = ci as u32;
103            }
104        }
105
106        // Local Vamana graphs within each cell
107        let mut adj = AdjBuilder::new(n);
108        let t0 = std::time::Instant::now();
109        build_local_edges(&store, &tree, &sq8, &config, &mut adj);
110        let local_edges = adj.total_edges();
111        eprintln!(
112            "  Local edges: {:.1}s, {} edges ({:.1}/node)",
113            t0.elapsed().as_secs_f64(),
114            local_edges,
115            local_edges as f64 / n as f64
116        );
117
118        let t0 = std::time::Instant::now();
119        let medoids = compute_medoids(&store, &tree, config.metric);
120        eprintln!("  Medoids: {:.1}s", t0.elapsed().as_secs_f64());
121
122        let local_graph = adj.snapshot();
123
124        // Greedy cross-partition edges (attribute-diverse selection)
125        let t0 = std::time::Instant::now();
126        build_greedy_cross_edges(
127            &store,
128            &tree,
129            &medoids,
130            &local_graph,
131            &sq8,
132            &point_cell,
133            &config,
134            &mut adj,
135        );
136        let cross_edges = adj.total_edges() - local_edges;
137        eprintln!(
138            "  Cross edges: {:.1}s, {} edges ({:.1}/node)",
139            t0.elapsed().as_secs_f64(),
140            cross_edges,
141            cross_edges as f64 / n as f64
142        );
143
144        // Random regular overlay (Friedman permutation model)
145        let edges_before = adj.total_edges();
146        let t0 = std::time::Instant::now();
147        build_random_overlay(n, config.m_random, &mut adj);
148        let random_edges = adj.total_edges() - edges_before;
149        eprintln!(
150            "  Random overlay: {:.1}s, {} edges ({:.1}/node)",
151            t0.elapsed().as_secs_f64(),
152            random_edges,
153            random_edges as f64 / n as f64
154        );
155
156        let graph = adj.build();
157
158        let global_medoid = compute_global_medoid(&store, config.metric);
159
160        Self {
161            store,
162            tree,
163            graph,
164            local_graph,
165            medoids,
166            global_medoid,
167            point_cell,
168            original_ids,
169            sq8,
170            binary,
171            config,
172        }
173    }
174}
175
176/// Reorder so points in the same cell are contiguous. Returns (store, tree, original_ids).
177fn reorder_by_cell(
178    store: PointStore,
179    mut tree: PartitionTree,
180) -> (PointStore, PartitionTree, Vec<u32>) {
181    let n = store.len;
182    let dim = store.dim;
183    let k = store.k();
184
185    // Build new ordering: cell 0's points, then cell 1's, etc.
186    let mut new_order: Vec<u32> = Vec::with_capacity(n);
187    for cell in &tree.cells {
188        new_order.extend_from_slice(&cell.point_ids);
189    }
190
191    // old_to_new[old_id] = new_id
192    let mut old_to_new = vec![0u32; n];
193    for (new_id, &old_id) in new_order.iter().enumerate() {
194        old_to_new[old_id as usize] = new_id as u32;
195    }
196
197    // Reorder vectors
198    let mut new_vectors = vec![0.0f32; n * dim];
199    for (new_id, &old_id) in new_order.iter().enumerate() {
200        let src = &store.vectors[old_id as usize * dim..(old_id as usize + 1) * dim];
201        new_vectors[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
202    }
203
204    // Reorder attributes
205    let mut new_attrs = Vec::with_capacity(k);
206    for j in 0..k {
207        let mut attr_col = vec![0u32; n];
208        for (new_id, &old_id) in new_order.iter().enumerate() {
209            attr_col[new_id] = store.attrs[j][old_id as usize];
210        }
211        new_attrs.push(attr_col);
212    }
213
214    // Update tree cell point IDs to new IDs
215    for cell in &mut tree.cells {
216        for pid in &mut cell.point_ids {
217            *pid = old_to_new[*pid as usize];
218        }
219    }
220
221    let new_store = PointStore::from_parts(new_vectors, dim, new_attrs);
222    (new_store, tree, new_order)
223}
224
225/// Build local Vamana graphs within each cell. Small cells get complete graphs,
226/// larger cells use greedy Vamana construction with robust pruning.
227fn build_local_edges(
228    store: &PointStore,
229    tree: &PartitionTree,
230    sq8: &SQ8Store,
231    config: &PrismConfig,
232    adj: &mut AdjBuilder,
233) {
234    let cell_edges: Vec<Vec<(u32, u32)>> = tree
235        .cells
236        .par_iter()
237        .map(|cell| {
238            let pts = &cell.point_ids;
239            let mut edges = Vec::new();
240            if pts.len() <= 1 {
241                return edges;
242            }
243
244            if pts.len() <= config.m_local + 1 {
245                for i in 0..pts.len() {
246                    for j in (i + 1)..pts.len() {
247                        edges.push((pts[i], pts[j]));
248                        edges.push((pts[j], pts[i]));
249                    }
250                }
251            } else {
252                let mut rng = rand::thread_rng();
253                build_vamana_cell(store, sq8, pts, config, &mut edges, &mut rng);
254            }
255            edges
256        })
257        .collect();
258
259    for edges in cell_edges {
260        for (src, dst) in edges {
261            adj.add_edge(src, dst);
262        }
263    }
264}
265
266/// Vamana construction within a single cell: SQ8 beam search + f32 pruning, two passes.
267fn build_vamana_cell(
268    store: &PointStore,
269    sq8: &SQ8Store,
270    pts: &[u32],
271    config: &PrismConfig,
272    edges: &mut Vec<(u32, u32)>,
273    rng: &mut impl Rng,
274) {
275    let n = pts.len();
276    let r = config.m_local;
277    let beam = n.min(config.beam_width);
278    let alpha = config.vamana_alpha;
279
280    // Random initial graph
281    let actual_r = r.min(n - 1);
282    let mut graph: Vec<Vec<usize>> = (0..n)
283        .map(|i| {
284            let mut neighbors = Vec::with_capacity(actual_r);
285            while neighbors.len() < actual_r {
286                let j = rng.gen_range(0..n);
287                if j != i && !neighbors.contains(&j) {
288                    neighbors.push(j);
289                }
290            }
291            neighbors
292        })
293        .collect();
294
295    // Medoid as entry point
296    let dim = store.dim;
297    let mut centroid = vec![0.0f32; dim];
298    for &p in pts {
299        let v = store.vector(p);
300        for (c, &x) in centroid.iter_mut().zip(v.iter()) {
301            *c += x;
302        }
303    }
304    let inv_n = 1.0 / n as f32;
305    for c in &mut centroid {
306        *c *= inv_n;
307    }
308    let entry = (0..n)
309        .min_by(|&a, &b| {
310            let da = distance::distance(&centroid, store.vector(pts[a]), config.metric);
311            let db = distance::distance(&centroid, store.vector(pts[b]), config.metric);
312            da.partial_cmp(&db).unwrap()
313        })
314        .unwrap();
315
316    for _pass in 0..2 {
317        let mut order: Vec<usize> = (0..n).collect();
318        order.shuffle(rng);
319
320        for &i in &order {
321            let search_results = vamana_search_sq8(sq8, pts, &graph, entry, pts[i], beam);
322
323            // Union search results with current neighbors
324            let mut candidates = search_results;
325            for &nb in &graph[i] {
326                if !candidates.contains(&nb) {
327                    candidates.push(nb);
328                }
329            }
330
331            graph[i] = robust_prune(store, pts, i, &candidates, alpha, r, config.metric);
332
333            // Reverse edges
334            let new_neighbors: Vec<usize> = graph[i].clone();
335            for &j in &new_neighbors {
336                if !graph[j].contains(&i) {
337                    graph[j].push(i);
338                    if graph[j].len() > r {
339                        let cands: Vec<usize> = graph[j].clone();
340                        graph[j] = robust_prune(store, pts, j, &cands, alpha, r, config.metric);
341                    }
342                }
343            }
344        }
345    }
346
347    for (i, neighbors) in graph.iter().enumerate() {
348        for &j in neighbors {
349            edges.push((pts[i], pts[j]));
350        }
351    }
352}
353
354/// SQ8 beam search within a cell's local graph. Returns visited local indices.
355fn vamana_search_sq8(
356    sq8: &SQ8Store,
357    pts: &[u32],
358    graph: &[Vec<usize>],
359    entry: usize,
360    query_id: u32,
361    beam: usize,
362) -> Vec<usize> {
363    use std::cmp::Reverse;
364    use std::collections::BinaryHeap;
365
366    let q_code = sq8.code(query_id);
367    let mut visited = vec![false; pts.len()];
368    let mut candidates: BinaryHeap<Reverse<(u32, usize)>> = BinaryHeap::new();
369    let mut results: BinaryHeap<(u32, usize)> = BinaryHeap::new();
370
371    let d = distance::l2_sq8(q_code, sq8.code(pts[entry]));
372    visited[entry] = true;
373    candidates.push(Reverse((d, entry)));
374    results.push((d, entry));
375
376    while let Some(Reverse((d, c))) = candidates.pop() {
377        if results.len() >= beam {
378            if let Some(&(worst, _)) = results.peek() {
379                if d > worst {
380                    break;
381                }
382            }
383        }
384
385        for &w in &graph[c] {
386            if visited[w] {
387                continue;
388            }
389            visited[w] = true;
390            let wd = distance::l2_sq8(q_code, sq8.code(pts[w]));
391            candidates.push(Reverse((wd, w)));
392            results.push((wd, w));
393            if results.len() > beam {
394                results.pop();
395            }
396        }
397    }
398
399    results.into_iter().map(|(_, idx)| idx).collect()
400}
401
402/// Robust prune: rejects c if dist(c, selected) < α·dist(p, c).
403fn robust_prune(
404    store: &PointStore,
405    pts: &[u32],
406    p: usize,
407    candidates: &[usize],
408    alpha: f32,
409    r: usize,
410    metric: Metric,
411) -> Vec<usize> {
412    let p_vec = store.vector(pts[p]);
413    let mut sorted: Vec<(usize, f32)> = candidates
414        .iter()
415        .filter(|&&c| c != p)
416        .map(|&c| (c, distance::distance(p_vec, store.vector(pts[c]), metric)))
417        .collect();
418    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
419    sorted.dedup_by_key(|x| x.0);
420
421    let mut selected: Vec<usize> = Vec::with_capacity(r);
422    for &(c, d_pc) in &sorted {
423        if selected.len() >= r {
424            break;
425        }
426        let dominated = selected.iter().any(|&s| {
427            let d_cs = distance::distance(store.vector(pts[c]), store.vector(pts[s]), metric);
428            alpha * d_cs <= d_pc
429        });
430        if !dominated {
431            selected.push(c);
432        }
433    }
434    selected
435}
436
437/// Greedy attribute-diverse cross-partition edges. SQ8 beam search for
438/// candidate discovery, f32 rerank, parallelized across points.
439#[allow(clippy::too_many_arguments)]
440fn build_greedy_cross_edges(
441    store: &PointStore,
442    tree: &PartitionTree,
443    medoids: &[u32],
444    local_graph: &Graph,
445    sq8: &SQ8Store,
446    point_cell: &[u32],
447    config: &PrismConfig,
448    adj: &mut AdjBuilder,
449) {
450    let n = store.len;
451    let k = store.k();
452    let t = config.t.min(k);
453    let beam = config.beam_width;
454    let subsets = t_subsets(k, t);
455    let use_sq8 = config.metric == Metric::L2;
456
457    let point_edges: Vec<Vec<u32>> = (0..n as u32)
458        .into_par_iter()
459        .map(|p_id| {
460            let p_cell_idx = point_cell[p_id as usize];
461            let p_vec = store.vector(p_id);
462
463            // Rank cells by SQ8 medoid distance
464            let p_code = sq8.code(p_id);
465            let mut cell_dists: Vec<(usize, u32)> = tree
466                .cells
467                .iter()
468                .enumerate()
469                .filter(|&(ci, _)| ci as u32 != p_cell_idx)
470                .map(|(ci, _)| {
471                    let d = distance::l2_sq8(p_code, sq8.code(medoids[ci]));
472                    (ci, d)
473                })
474                .collect();
475            cell_dists.sort_unstable_by_key(|&(_, d)| d);
476
477            // Beam search closest cells for candidates
478            let mut all_cand_ids: Vec<u32> = Vec::with_capacity(beam);
479            for &(ci, _) in &cell_dists {
480                let cell_size = tree.cells[ci].point_ids.len();
481
482                if use_sq8 && cell_size > beam * 2 {
483                    let found = beam_search_sq8(sq8, local_graph, p_code, medoids[ci], beam);
484                    for (id, _) in found {
485                        all_cand_ids.push(id);
486                    }
487                } else if use_sq8 {
488                    let mut scored: Vec<(u32, u32)> = tree.cells[ci]
489                        .point_ids
490                        .iter()
491                        .map(|&q| (q, distance::l2_sq8(p_code, sq8.code(q))))
492                        .collect();
493                    scored.sort_unstable_by_key(|&(_, d)| d);
494                    for &(id, _) in scored.iter().take(beam) {
495                        all_cand_ids.push(id);
496                    }
497                } else {
498                    for &q_id in &tree.cells[ci].point_ids {
499                        all_cand_ids.push(q_id);
500                    }
501                }
502
503                if all_cand_ids.len() >= beam {
504                    break;
505                }
506            }
507
508            // F32 rerank
509            let mut candidates: Vec<(u32, f32)> = all_cand_ids
510                .iter()
511                .map(|&id| {
512                    (
513                        id,
514                        distance::distance(p_vec, store.vector(id), config.metric),
515                    )
516                })
517                .collect();
518            candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
519            candidates.truncate(beam);
520
521            select_cross_neighbors(store, &candidates, config, &subsets)
522        })
523        .collect();
524
525    for (p_id, neighbors) in point_edges.into_iter().enumerate() {
526        for q_id in neighbors {
527            adj.add_edge(p_id as u32, q_id);
528        }
529    }
530}
531
532/// SQ8 beam search through a cell's local graph. Returns (point_id, sq8_distance).
533fn beam_search_sq8(
534    sq8: &SQ8Store,
535    graph: &Graph,
536    query_code: &[u8],
537    entry: u32,
538    beam: usize,
539) -> Vec<(u32, u32)> {
540    use std::cmp::Reverse;
541    use std::collections::BinaryHeap;
542
543    let mut visited = HashSet::new();
544    let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
545    let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
546
547    let d = distance::l2_sq8(query_code, sq8.code(entry));
548    visited.insert(entry);
549    candidates.push(Reverse((d, entry)));
550    results.push((d, entry));
551
552    while let Some(Reverse((d, c))) = candidates.pop() {
553        if results.len() >= beam {
554            if let Some(&(worst, _)) = results.peek() {
555                if d > worst {
556                    break;
557                }
558            }
559        }
560
561        for &w in graph.neighbors(c) {
562            if !visited.insert(w) {
563                continue;
564            }
565            let wd = distance::l2_sq8(query_code, sq8.code(w));
566            candidates.push(Reverse((wd, w)));
567            results.push((wd, w));
568            if results.len() > beam {
569                results.pop();
570            }
571        }
572    }
573
574    results.into_iter().map(|(d, id)| (id, d)).collect()
575}
576
577/// Attribute-diverse neighbor selection. Candidates sorted by distance.
578pub(crate) fn select_cross_neighbors(
579    store: &PointStore,
580    candidates: &[(u32, f32)],
581    config: &PrismConfig,
582    subsets: &[Vec<usize>],
583) -> Vec<u32> {
584    let m_g = config.m_greedy;
585    let alpha = config.alpha;
586
587    if candidates.is_empty() || m_g == 0 {
588        return Vec::new();
589    }
590
591    let mut covered: HashSet<u64> = HashSet::new();
592    let mut selected = Vec::with_capacity(m_g);
593    let mut available: Vec<bool> = vec![true; candidates.len()];
594
595    for _ in 0..m_g {
596        let mut best_idx = None;
597        let mut best_score = f32::NEG_INFINITY;
598
599        for (idx, &(q_id, dist)) in candidates.iter().enumerate() {
600            if !available[idx] {
601                continue;
602            }
603
604            let new_tuples = count_new_tuples(store, q_id, &covered, subsets);
605
606            // score = gain / cost; fall back to proximity when fully covered
607            let score = if alpha == 0.0 || dist == 0.0 {
608                new_tuples as f32
609            } else {
610                (new_tuples as f32 + 0.001) / dist.powf(alpha)
611            };
612
613            if score > best_score {
614                best_score = score;
615                best_idx = Some(idx);
616            }
617        }
618
619        let Some(idx) = best_idx else { break };
620        selected.push(candidates[idx].0);
621        available[idx] = false;
622
623        add_tuples(store, candidates[idx].0, &mut covered, subsets);
624    }
625
626    selected
627}
628
629/// Encode (combo, values) as u64 key. Supports up to 8 dims, values < 256.
630#[inline]
631fn tuple_key(combo: &[usize], store: &PointStore, q: u32) -> u64 {
632    let mut key: u64 = 0;
633    for (i, &j) in combo.iter().enumerate() {
634        let val = store.attr(q, j) as u64;
635        key |= ((j as u64) << 8 | val) << (i * 16);
636    }
637    key
638}
639
640/// Count how many new t-tuples a candidate would contribute.
641fn count_new_tuples(
642    store: &PointStore,
643    q: u32,
644    covered: &HashSet<u64>,
645    subsets: &[Vec<usize>],
646) -> usize {
647    let mut count = 0;
648    for combo in subsets {
649        let key = tuple_key(combo, store, q);
650        if !covered.contains(&key) {
651            count += 1;
652        }
653    }
654    count
655}
656
657/// Add all t-tuples of a point to the covered set.
658pub(crate) fn add_tuples(
659    store: &PointStore,
660    q: u32,
661    covered: &mut HashSet<u64>,
662    subsets: &[Vec<usize>],
663) {
664    for combo in subsets {
665        let key = tuple_key(combo, store, q);
666        covered.insert(key);
667    }
668}
669
/// Returns every `t`-element subset of `0..k` as an ascending index list, in
/// lexicographic order (e.g. `t_subsets(3, 2)` → `[[0, 1], [0, 2], [1, 2]]`).
/// `t == 0` yields a single empty subset; `t > k` yields none.
pub(crate) fn t_subsets(k: usize, t: usize) -> Vec<Vec<usize>> {
    if t > k {
        return Vec::new();
    }

    // Odometer: `combo` always holds the next subset to emit.
    let mut combo: Vec<usize> = (0..t).collect();
    let mut result = Vec::new();
    loop {
        result.push(combo.clone());

        // Rightmost slot that can still advance (slot i tops out at k - t + i).
        let Some(slot) = (0..t).rev().find(|&i| combo[i] < k - t + i) else {
            return result;
        };
        combo[slot] += 1;
        for i in slot + 1..t {
            combo[i] = combo[i - 1] + 1;
        }
    }
}
695
696/// Random cross-partition edges via Friedman permutation model.
697pub(crate) fn build_random_overlay(n: usize, m_random: usize, adj: &mut AdjBuilder) {
698    if m_random == 0 || n <= 1 {
699        return;
700    }
701    let mut rng = rand::thread_rng();
702    let half = m_random / 2;
703
704    for _ in 0..half {
705        // Random permutation
706        let mut perm: Vec<u32> = (0..n as u32).collect();
707        perm.shuffle(&mut rng);
708        for (i, &j) in perm.iter().enumerate() {
709            if i as u32 != j {
710                adj.add_undirected(i as u32, j);
711            }
712        }
713    }
714}
715
716/// Per-cell medoid via centroid-nearest approximation.
717fn compute_medoids(store: &PointStore, tree: &PartitionTree, metric: Metric) -> Vec<u32> {
718    let dim = store.dim;
719    tree.cells
720        .iter()
721        .map(|cell| {
722            let pts = &cell.point_ids;
723            if pts.len() == 1 {
724                return pts[0];
725            }
726            // Compute centroid
727            let mut centroid = vec![0.0f32; dim];
728            for &p in pts {
729                let v = store.vector(p);
730                for (c, &x) in centroid.iter_mut().zip(v.iter()) {
731                    *c += x;
732                }
733            }
734            let inv_n = 1.0 / pts.len() as f32;
735            for c in &mut centroid {
736                *c *= inv_n;
737            }
738            // Return point closest to centroid
739            *pts.iter()
740                .min_by(|&&a, &&b| {
741                    let da = distance::distance(&centroid, store.vector(a), metric);
742                    let db = distance::distance(&centroid, store.vector(b), metric);
743                    da.partial_cmp(&db).unwrap()
744                })
745                .unwrap()
746        })
747        .collect()
748}
749
750/// Compute global medoid: the point closest to the centroid of the entire dataset.
751fn compute_global_medoid(store: &PointStore, metric: Metric) -> u32 {
752    let n = store.len;
753    let dim = store.dim;
754    let mut centroid = vec![0.0f32; dim];
755    for i in 0..n as u32 {
756        let v = store.vector(i);
757        for (c, &x) in centroid.iter_mut().zip(v.iter()) {
758            *c += x;
759        }
760    }
761    let inv_n = 1.0 / n as f32;
762    for c in &mut centroid {
763        *c *= inv_n;
764    }
765    (0..n as u32)
766        .min_by(|&a, &b| {
767            let da = distance::distance(&centroid, store.vector(a), metric);
768            let db = distance::distance(&centroid, store.vector(b), metric);
769            da.partial_cmp(&db).unwrap()
770        })
771        .unwrap()
772}
773
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    #[test]
    fn test_build_small() {
        let mut store = PointStore::new(2, 2);
        // 4 points, 2 attributes with 2 values each → 4 cells
        store.push(&[0.0, 0.0], &[0, 0]);
        store.push(&[1.0, 0.0], &[0, 1]);
        store.push(&[0.0, 1.0], &[1, 0]);
        store.push(&[1.0, 1.0], &[1, 1]);

        let config = PrismConfig {
            m_local: 2,
            m_greedy: 2,
            m_random: 4,
            t: 1,
            alpha: 0.0,
            beam_width: 10,
            ..Default::default()
        };

        let index = PrismIndex::build(store, config);
        assert_eq!(index.tree.cells.len(), 4);
        assert_eq!(index.medoids.len(), 4);

        // Renumbering must be a permutation of the input IDs.
        let mut ids = index.original_ids.clone();
        ids.sort_unstable();
        assert_eq!(ids, vec![0u32, 1, 2, 3]);

        for i in 0..4u32 {
            // The recorded cell of each point actually contains it.
            let cell = &index.tree.cells[index.point_cell[i as usize] as usize];
            assert!(cell.point_ids.contains(&i));
            // Each point should have some neighbors.
            assert!(index.graph.degree(i) > 0);
        }
    }

    #[test]
    fn test_t_subsets() {
        let subs = t_subsets(4, 2);
        assert_eq!(subs.len(), 6); // C(4,2) = 6
        assert_eq!(
            subs,
            vec![vec![0, 1], vec![0, 2], vec![0, 3], vec![1, 2], vec![1, 3], vec![2, 3]]
        );
        assert_eq!(t_subsets(3, 1), vec![vec![0], vec![1], vec![2]]);
        // Edge cases: t == 0 gives one empty subset, t > k gives none.
        assert_eq!(t_subsets(3, 0), vec![Vec::<usize>::new()]);
        assert!(t_subsets(2, 3).is_empty());
    }
}