Skip to main content

citadel_vector/vendored/prism/
construct.rs

1use super::binary::BinaryStore;
2use super::distance::{self, Metric};
3use super::graph::{AdjBuilder, Graph};
4use super::partition::PartitionTree;
5use super::point::PointStore;
6use super::quantize::SQ8Store;
7
8use rand::prelude::*;
9use rayon::prelude::*;
10use std::collections::HashSet;
11
12/// Configuration for PRISM index construction.
13#[derive(Clone, Debug)]
14pub struct PrismConfig {
15    /// Local degree (edges within each leaf cell).
16    pub m_local: usize,
17    /// Greedy cross-partition degree.
18    pub m_greedy: usize,
19    /// Random cross-partition degree (must be even).
20    pub m_random: usize,
21    /// Covering strength for attribute-diverse selection.
22    pub t: usize,
23    /// Proximity-diversity tradeoff for cross-neighbor selection (0 = pure diversity).
24    pub alpha: f32,
25    /// Vamana pruning parameter (standard DiskANN: 1.2).
26    pub vamana_alpha: f32,
27    /// Beam width for candidate search during construction (paper: 10 * M_g).
28    pub beam_width: usize,
29    /// Distance metric.
30    pub metric: Metric,
31    /// Selectivity threshold for HIGH regime.
32    pub sigma_high: f32,
33    /// Selectivity threshold for LOW regime.
34    pub sigma_low: f32,
35    /// Bridge budget multiplier for MID regime.
36    pub beta: f32,
37    /// Search pruning tolerance for filtered queries.
38    pub epsilon: f32,
39    /// Binary pre-filter rerank factor. Top `binary_rerank * ef` Hamming candidates
40    /// are reranked with SQ8. 0 disables binary pre-filter.
41    pub binary_rerank: usize,
42}
43
44impl Default for PrismConfig {
45    fn default() -> Self {
46        Self {
47            m_local: 16,
48            m_greedy: 12,
49            m_random: 4,
50            t: 2,
51            alpha: 1.0,
52            vamana_alpha: 1.0,
53            beam_width: 120,
54            metric: Metric::L2,
55            sigma_high: 0.10,
56            sigma_low: 0.001,
57            beta: 3.0,
58            epsilon: 0.2,
59            binary_rerank: 4,
60        }
61    }
62}
63
64/// The complete PRISM index.
65pub struct PrismIndex {
66    pub store: PointStore,
67    pub tree: PartitionTree,
68    pub graph: Graph,
69    /// Local-only graph (intra-cell edges) for per-cell graph search.
70    pub local_graph: Graph,
71    pub medoids: Vec<u32>,
72    pub global_medoid: u32,
73    /// Reverse mapping: point_id -> cell index.
74    pub point_cell: Vec<u32>,
75    /// Maps internal ID -> original ID.
76    pub original_ids: Vec<u32>,
77    /// Scalar-quantized vectors for distance computation.
78    pub sq8: SQ8Store,
79    /// Binary codes for Hamming pre-filter.
80    pub binary: BinaryStore,
81    pub config: PrismConfig,
82}
83
84impl PrismIndex {
85    /// Build a PRISM index from a PointStore (Algorithm 2).
86    pub fn build(mut store: PointStore, config: PrismConfig) -> Self {
87        let n = store.len;
88        assert!(n > 0, "cannot build index from empty point store");
89        assert!(
90            config.m_random >= 4 && config.m_random % 2 == 0,
91            "m_random must be >= 4 and even (Friedman model requires d >= 4)"
92        );
93
94        // Cosine: normalize once at build so SQ8-L2 code distances are
95        // rank-equivalent to cosine (L2^2 = 2 - 2cos on unit vectors). The
96        // exact rerank is scale-invariant, so reported distances and segment
97        // rehydration from raw table rows are unaffected.
98        if config.metric == Metric::Cosine {
99            let dim = store.dim;
100            distance::normalize_rows(&mut store.vectors, dim);
101        }
102
103        let tree = PartitionTree::build(&store);
104        let (store, tree, original_ids) = reorder_by_cell(store, tree);
105        let sq8 = SQ8Store::build(&store);
106        let binary = if config.binary_rerank > 0 {
107            BinaryStore::build(&store)
108        } else {
109            BinaryStore::empty(store.dim)
110        };
111
112        let mut point_cell = vec![0u32; n];
113        for (ci, cell) in tree.cells.iter().enumerate() {
114            for &pid in &cell.point_ids {
115                point_cell[pid as usize] = ci as u32;
116            }
117        }
118
119        // Local Vamana graphs within each cell
120        let mut adj = AdjBuilder::new(n);
121        build_local_edges(&store, &tree, &sq8, &config, &mut adj);
122
123        let medoids = compute_medoids(&store, &tree, config.metric);
124
125        let local_graph = adj.snapshot();
126
127        // The global graph (cross edges + random overlay) is traversed only by
128        // REGIME_MID; when sigma_high <= sigma_low that regime is unreachable
129        // and the two most expensive construction phases would build dead edges.
130        if config.sigma_high > config.sigma_low {
131            // Greedy cross-partition edges (attribute-diverse selection)
132            build_greedy_cross_edges(
133                &store,
134                &tree,
135                &medoids,
136                &local_graph,
137                &sq8,
138                &point_cell,
139                &config,
140                &mut adj,
141            );
142
143            // Random regular overlay (Friedman permutation model)
144            build_random_overlay(n, config.m_random, &mut adj);
145        }
146
147        let graph = adj.build();
148
149        let global_medoid = compute_global_medoid(&store, config.metric);
150
151        Self {
152            store,
153            tree,
154            graph,
155            local_graph,
156            medoids,
157            global_medoid,
158            point_cell,
159            original_ids,
160            sq8,
161            binary,
162            config,
163        }
164    }
165}
166
167/// Reorder so points in the same cell are contiguous. Returns (store, tree, original_ids).
168fn reorder_by_cell(
169    store: PointStore,
170    mut tree: PartitionTree,
171) -> (PointStore, PartitionTree, Vec<u32>) {
172    let n = store.len;
173    let dim = store.dim;
174    let k = store.k();
175
176    let mut new_order: Vec<u32> = Vec::with_capacity(n);
177    for cell in &tree.cells {
178        new_order.extend_from_slice(&cell.point_ids);
179    }
180
181    let mut old_to_new = vec![0u32; n];
182    for (new_id, &old_id) in new_order.iter().enumerate() {
183        old_to_new[old_id as usize] = new_id as u32;
184    }
185
186    let mut new_vectors = vec![0.0f32; n * dim];
187    for (new_id, &old_id) in new_order.iter().enumerate() {
188        let src = &store.vectors[old_id as usize * dim..(old_id as usize + 1) * dim];
189        new_vectors[new_id * dim..(new_id + 1) * dim].copy_from_slice(src);
190    }
191
192    let mut new_attrs = Vec::with_capacity(k);
193    for j in 0..k {
194        let mut attr_col = vec![0u32; n];
195        for (new_id, &old_id) in new_order.iter().enumerate() {
196            attr_col[new_id] = store.attrs[j][old_id as usize];
197        }
198        new_attrs.push(attr_col);
199    }
200
201    for cell in &mut tree.cells {
202        for pid in &mut cell.point_ids {
203            *pid = old_to_new[*pid as usize];
204        }
205    }
206
207    let new_store = PointStore::from_parts(new_vectors, dim, new_attrs);
208    (new_store, tree, new_order)
209}
210
211/// Build local Vamana graphs within each cell. Small cells get complete graphs,
212/// larger cells use greedy Vamana construction with robust pruning.
213fn build_local_edges(
214    store: &PointStore,
215    tree: &PartitionTree,
216    sq8: &SQ8Store,
217    config: &PrismConfig,
218    adj: &mut AdjBuilder,
219) {
220    let cell_edges: Vec<Vec<(u32, u32)>> = tree
221        .cells
222        .par_iter()
223        .map(|cell| {
224            let pts = &cell.point_ids;
225            let mut edges = Vec::new();
226            if pts.len() <= 1 {
227                return edges;
228            }
229
230            if pts.len() <= config.m_local + 1 {
231                for i in 0..pts.len() {
232                    for j in (i + 1)..pts.len() {
233                        edges.push((pts[i], pts[j]));
234                        edges.push((pts[j], pts[i]));
235                    }
236                }
237            } else {
238                let mut rng = rand::thread_rng();
239                build_vamana_cell(store, sq8, pts, config, &mut edges, &mut rng);
240            }
241            edges
242        })
243        .collect();
244
245    for edges in cell_edges {
246        for (src, dst) in edges {
247            adj.add_edge(src, dst);
248        }
249    }
250}
251
252/// Vamana construction within a single cell: code-space beam search + f32 pruning, two passes.
253fn build_vamana_cell(
254    store: &PointStore,
255    sq8: &SQ8Store,
256    pts: &[u32],
257    config: &PrismConfig,
258    edges: &mut Vec<(u32, u32)>,
259    rng: &mut impl Rng,
260) {
261    let n = pts.len();
262    let r = config.m_local;
263    let beam = n.min(config.beam_width);
264    let alpha = config.vamana_alpha;
265
266    let actual_r = r.min(n - 1);
267    let mut graph: Vec<Vec<usize>> = (0..n)
268        .map(|i| {
269            let mut neighbors = Vec::with_capacity(actual_r);
270            while neighbors.len() < actual_r {
271                let j = rng.gen_range(0..n);
272                if j != i && !neighbors.contains(&j) {
273                    neighbors.push(j);
274                }
275            }
276            neighbors
277        })
278        .collect();
279
280    let dim = store.dim;
281    let mut centroid = vec![0.0f32; dim];
282    for &p in pts {
283        let v = store.vector(p);
284        for (c, &x) in centroid.iter_mut().zip(v.iter()) {
285            *c += x;
286        }
287    }
288    let inv_n = 1.0 / n as f32;
289    for c in &mut centroid {
290        *c *= inv_n;
291    }
292    let entry = (0..n)
293        .min_by(|&a, &b| {
294            let da = distance::distance(&centroid, store.vector(pts[a]), config.metric);
295            let db = distance::distance(&centroid, store.vector(pts[b]), config.metric);
296            da.partial_cmp(&db).unwrap()
297        })
298        .unwrap();
299
300    for _pass in 0..2 {
301        let mut order: Vec<usize> = (0..n).collect();
302        order.shuffle(rng);
303
304        for &i in &order {
305            let search_results =
306                vamana_search_code(store, sq8, config.metric, pts, &graph, entry, pts[i], beam);
307
308            let mut candidates = search_results;
309            for &nb in &graph[i] {
310                if !candidates.contains(&nb) {
311                    candidates.push(nb);
312                }
313            }
314
315            graph[i] = robust_prune(store, pts, i, &candidates, alpha, r, config.metric);
316
317            let new_neighbors: Vec<usize> = graph[i].clone();
318            for &j in &new_neighbors {
319                if !graph[j].contains(&i) {
320                    graph[j].push(i);
321                    if graph[j].len() > r {
322                        let cands: Vec<usize> = graph[j].clone();
323                        graph[j] = robust_prune(store, pts, j, &cands, alpha, r, config.metric);
324                    }
325                }
326            }
327        }
328    }
329
330    for (i, neighbors) in graph.iter().enumerate() {
331        for &j in neighbors {
332            edges.push((pts[i], pts[j]));
333        }
334    }
335}
336
337/// Heap-ordered candidate distance between two stored points. L2 and
338/// (build-normalized) Cosine rank by SQ8 codes; InnerProduct cannot be ranked
339/// in code-space L2, so it uses the exact f32 metric via a total-order key.
340#[inline]
341fn build_cand_dist(store: &PointStore, sq8: &SQ8Store, metric: Metric, a: u32, b: u32) -> u32 {
342    match metric {
343        Metric::L2 | Metric::Cosine => distance::l2_sq8(sq8.code(a), sq8.code(b)),
344        Metric::InnerProduct => distance::ord_key(distance::distance(
345            store.vector(a),
346            store.vector(b),
347            Metric::InnerProduct,
348        )),
349    }
350}
351
352/// Code-space beam search within a cell's local graph. Returns visited local indices.
353#[allow(clippy::too_many_arguments)]
354fn vamana_search_code(
355    store: &PointStore,
356    sq8: &SQ8Store,
357    metric: Metric,
358    pts: &[u32],
359    graph: &[Vec<usize>],
360    entry: usize,
361    query_id: u32,
362    beam: usize,
363) -> Vec<usize> {
364    use std::cmp::Reverse;
365    use std::collections::BinaryHeap;
366
367    let mut visited = vec![false; pts.len()];
368    let mut candidates: BinaryHeap<Reverse<(u32, usize)>> = BinaryHeap::new();
369    let mut results: BinaryHeap<(u32, usize)> = BinaryHeap::new();
370
371    let d = build_cand_dist(store, sq8, metric, query_id, pts[entry]);
372    visited[entry] = true;
373    candidates.push(Reverse((d, entry)));
374    results.push((d, entry));
375
376    while let Some(Reverse((d, c))) = candidates.pop() {
377        if results.len() >= beam {
378            if let Some(&(worst, _)) = results.peek() {
379                if d > worst {
380                    break;
381                }
382            }
383        }
384
385        for &w in &graph[c] {
386            if visited[w] {
387                continue;
388            }
389            visited[w] = true;
390            let wd = build_cand_dist(store, sq8, metric, query_id, pts[w]);
391            candidates.push(Reverse((wd, w)));
392            results.push((wd, w));
393            if results.len() > beam {
394                results.pop();
395            }
396        }
397    }
398
399    results.into_iter().map(|(_, idx)| idx).collect()
400}
401
402/// Robust prune: rejects c if alpha * dist(c, selected) <= dist(p, c).
403fn robust_prune(
404    store: &PointStore,
405    pts: &[u32],
406    p: usize,
407    candidates: &[usize],
408    alpha: f32,
409    r: usize,
410    metric: Metric,
411) -> Vec<usize> {
412    let p_vec = store.vector(pts[p]);
413    let mut sorted: Vec<(usize, f32)> = candidates
414        .iter()
415        .filter(|&&c| c != p)
416        .map(|&c| (c, distance::distance(p_vec, store.vector(pts[c]), metric)))
417        .collect();
418    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
419    sorted.dedup_by_key(|x| x.0);
420
421    let mut selected: Vec<usize> = Vec::with_capacity(r);
422    for &(c, d_pc) in &sorted {
423        if selected.len() >= r {
424            break;
425        }
426        let dominated = selected.iter().any(|&s| {
427            let d_cs = distance::distance(store.vector(pts[c]), store.vector(pts[s]), metric);
428            alpha * d_cs <= d_pc
429        });
430        if !dominated {
431            selected.push(c);
432        }
433    }
434    selected
435}
436
437/// Greedy attribute-diverse cross-partition edges. SQ8 beam search for
438/// candidate discovery, f32 rerank, parallelized across points.
439#[allow(clippy::too_many_arguments)]
440fn build_greedy_cross_edges(
441    store: &PointStore,
442    tree: &PartitionTree,
443    medoids: &[u32],
444    local_graph: &Graph,
445    sq8: &SQ8Store,
446    point_cell: &[u32],
447    config: &PrismConfig,
448    adj: &mut AdjBuilder,
449) {
450    let n = store.len;
451    let k = store.k();
452    let t = config.t.min(k);
453    let beam = config.beam_width;
454    let subsets = t_subsets(k, t);
455    // SQ8-L2 candidate ranking is rank-faithful for L2 and for Cosine (vectors
456    // are build-normalized); InnerProduct falls back to exact full-cell scans.
457    let use_sq8 = config.metric != Metric::InnerProduct;
458
459    let point_edges: Vec<Vec<u32>> = (0..n as u32)
460        .into_par_iter()
461        .map(|p_id| {
462            let p_cell_idx = point_cell[p_id as usize];
463            let p_vec = store.vector(p_id);
464
465            let p_code = sq8.code(p_id);
466            let mut cell_dists: Vec<(usize, u32)> = tree
467                .cells
468                .iter()
469                .enumerate()
470                .filter(|&(ci, _)| ci as u32 != p_cell_idx)
471                .map(|(ci, _)| {
472                    let d = distance::l2_sq8(p_code, sq8.code(medoids[ci]));
473                    (ci, d)
474                })
475                .collect();
476            cell_dists.sort_unstable_by_key(|&(_, d)| d);
477
478            let mut all_cand_ids: Vec<u32> = Vec::with_capacity(beam);
479            for &(ci, _) in &cell_dists {
480                let cell_size = tree.cells[ci].point_ids.len();
481
482                if use_sq8 && cell_size > beam * 2 {
483                    let found = beam_search_sq8(sq8, local_graph, p_code, medoids[ci], beam);
484                    for (id, _) in found {
485                        all_cand_ids.push(id);
486                    }
487                } else if use_sq8 {
488                    let mut scored: Vec<(u32, u32)> = tree.cells[ci]
489                        .point_ids
490                        .iter()
491                        .map(|&q| (q, distance::l2_sq8(p_code, sq8.code(q))))
492                        .collect();
493                    scored.sort_unstable_by_key(|&(_, d)| d);
494                    for &(id, _) in scored.iter().take(beam) {
495                        all_cand_ids.push(id);
496                    }
497                } else {
498                    for &q_id in &tree.cells[ci].point_ids {
499                        all_cand_ids.push(q_id);
500                    }
501                }
502
503                if all_cand_ids.len() >= beam {
504                    break;
505                }
506            }
507
508            let mut candidates: Vec<(u32, f32)> = all_cand_ids
509                .iter()
510                .map(|&id| {
511                    (
512                        id,
513                        distance::distance(p_vec, store.vector(id), config.metric),
514                    )
515                })
516                .collect();
517            candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
518            candidates.truncate(beam);
519
520            select_cross_neighbors(store, &candidates, config, &subsets)
521        })
522        .collect();
523
524    for (p_id, neighbors) in point_edges.into_iter().enumerate() {
525        for q_id in neighbors {
526            adj.add_edge(p_id as u32, q_id);
527        }
528    }
529}
530
531/// SQ8 beam search through a cell's local graph. Returns (point_id, sq8_distance).
532fn beam_search_sq8(
533    sq8: &SQ8Store,
534    graph: &Graph,
535    query_code: &[u8],
536    entry: u32,
537    beam: usize,
538) -> Vec<(u32, u32)> {
539    use std::cmp::Reverse;
540    use std::collections::BinaryHeap;
541
542    let mut visited = HashSet::new();
543    let mut candidates: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::new();
544    let mut results: BinaryHeap<(u32, u32)> = BinaryHeap::new();
545
546    let d = distance::l2_sq8(query_code, sq8.code(entry));
547    visited.insert(entry);
548    candidates.push(Reverse((d, entry)));
549    results.push((d, entry));
550
551    while let Some(Reverse((d, c))) = candidates.pop() {
552        if results.len() >= beam {
553            if let Some(&(worst, _)) = results.peek() {
554                if d > worst {
555                    break;
556                }
557            }
558        }
559
560        for &w in graph.neighbors(c) {
561            if !visited.insert(w) {
562                continue;
563            }
564            let wd = distance::l2_sq8(query_code, sq8.code(w));
565            candidates.push(Reverse((wd, w)));
566            results.push((wd, w));
567            if results.len() > beam {
568                results.pop();
569            }
570        }
571    }
572
573    results.into_iter().map(|(d, id)| (id, d)).collect()
574}
575
576/// Attribute-diverse neighbor selection. Candidates sorted by distance.
577pub(crate) fn select_cross_neighbors(
578    store: &PointStore,
579    candidates: &[(u32, f32)],
580    config: &PrismConfig,
581    subsets: &[Vec<usize>],
582) -> Vec<u32> {
583    let m_g = config.m_greedy;
584    let alpha = config.alpha;
585
586    if candidates.is_empty() || m_g == 0 {
587        return Vec::new();
588    }
589
590    let mut covered: HashSet<u64> = HashSet::new();
591    let mut selected = Vec::with_capacity(m_g);
592    let mut available: Vec<bool> = vec![true; candidates.len()];
593
594    for _ in 0..m_g {
595        let mut best_idx = None;
596        let mut best_score = f32::NEG_INFINITY;
597
598        for (idx, &(q_id, dist)) in candidates.iter().enumerate() {
599            if !available[idx] {
600                continue;
601            }
602
603            let new_tuples = count_new_tuples(store, q_id, &covered, subsets);
604
605            // score = gain / cost; fall back to proximity when fully covered
606            let score = if alpha == 0.0 || dist == 0.0 {
607                new_tuples as f32
608            } else {
609                (new_tuples as f32 + 0.001) / dist.powf(alpha)
610            };
611
612            if score > best_score {
613                best_score = score;
614                best_idx = Some(idx);
615            }
616        }
617
618        let Some(idx) = best_idx else { break };
619        selected.push(candidates[idx].0);
620        available[idx] = false;
621
622        add_tuples(store, candidates[idx].0, &mut covered, subsets);
623    }
624
625    selected
626}
627
628/// Encode (combo, values) as u64 key. Supports up to 8 dims, values < 256.
629#[inline]
630fn tuple_key(combo: &[usize], store: &PointStore, q: u32) -> u64 {
631    let mut key: u64 = 0;
632    for (i, &j) in combo.iter().enumerate() {
633        let val = store.attr(q, j) as u64;
634        key |= ((j as u64) << 8 | val) << (i * 16);
635    }
636    key
637}
638
639/// Count how many new t-tuples a candidate would contribute.
640fn count_new_tuples(
641    store: &PointStore,
642    q: u32,
643    covered: &HashSet<u64>,
644    subsets: &[Vec<usize>],
645) -> usize {
646    let mut count = 0;
647    for combo in subsets {
648        let key = tuple_key(combo, store, q);
649        if !covered.contains(&key) {
650            count += 1;
651        }
652    }
653    count
654}
655
656/// Add all t-tuples of a point to the covered set.
657pub(crate) fn add_tuples(
658    store: &PointStore,
659    q: u32,
660    covered: &mut HashSet<u64>,
661    subsets: &[Vec<usize>],
662) {
663    for combo in subsets {
664        let key = tuple_key(combo, store, q);
665        covered.insert(key);
666    }
667}
668
669/// Generate all t-element subsets of [0..k].
670pub(crate) fn t_subsets(k: usize, t: usize) -> Vec<Vec<usize>> {
671    let mut result = Vec::new();
672    let mut combo = Vec::with_capacity(t);
673    generate_subsets(k, t, 0, &mut combo, &mut result);
674    result
675}
676
/// Recursive helper for `t_subsets`: extends the partial subset `combo` with
/// elements from `start..k` and pushes every completion of length `t` onto
/// `result`. `combo` is restored to its input contents before returning.
fn generate_subsets(
    k: usize,
    t: usize,
    start: usize,
    combo: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    if combo.len() == t {
        result.push(combo.to_vec());
    } else {
        for next in start..k {
            combo.push(next);
            generate_subsets(k, t, next + 1, combo, result);
            combo.pop();
        }
    }
}
694
695/// Random cross-partition edges via Friedman permutation model.
696pub(crate) fn build_random_overlay(n: usize, m_random: usize, adj: &mut AdjBuilder) {
697    if m_random == 0 || n <= 1 {
698        return;
699    }
700    let mut rng = rand::thread_rng();
701    let half = m_random / 2;
702
703    for _ in 0..half {
704        let mut perm: Vec<u32> = (0..n as u32).collect();
705        perm.shuffle(&mut rng);
706        for (i, &j) in perm.iter().enumerate() {
707            if i as u32 != j {
708                adj.add_undirected(i as u32, j);
709            }
710        }
711    }
712}
713
714/// Per-cell medoid via centroid-nearest approximation.
715fn compute_medoids(store: &PointStore, tree: &PartitionTree, metric: Metric) -> Vec<u32> {
716    let dim = store.dim;
717    tree.cells
718        .iter()
719        .map(|cell| {
720            let pts = &cell.point_ids;
721            if pts.len() == 1 {
722                return pts[0];
723            }
724            let mut centroid = vec![0.0f32; dim];
725            for &p in pts {
726                let v = store.vector(p);
727                for (c, &x) in centroid.iter_mut().zip(v.iter()) {
728                    *c += x;
729                }
730            }
731            let inv_n = 1.0 / pts.len() as f32;
732            for c in &mut centroid {
733                *c *= inv_n;
734            }
735            *pts.iter()
736                .min_by(|&&a, &&b| {
737                    let da = distance::distance(&centroid, store.vector(a), metric);
738                    let db = distance::distance(&centroid, store.vector(b), metric);
739                    da.partial_cmp(&db).unwrap()
740                })
741                .unwrap()
742        })
743        .collect()
744}
745
746/// Compute global medoid: the point closest to the centroid of the entire dataset.
747fn compute_global_medoid(store: &PointStore, metric: Metric) -> u32 {
748    let n = store.len;
749    let dim = store.dim;
750    let mut centroid = vec![0.0f32; dim];
751    for i in 0..n as u32 {
752        let v = store.vector(i);
753        for (c, &x) in centroid.iter_mut().zip(v.iter()) {
754            *c += x;
755        }
756    }
757    let inv_n = 1.0 / n as f32;
758    for c in &mut centroid {
759        *c *= inv_n;
760    }
761    (0..n as u32)
762        .min_by(|&a, &b| {
763            let da = distance::distance(&centroid, store.vector(a), metric);
764            let db = distance::distance(&centroid, store.vector(b), metric);
765            da.partial_cmp(&db).unwrap()
766        })
767        .unwrap()
768}
769
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// Smoke test: four points, one per attribute combination, build into
    /// four cells and every point ends up with at least one edge.
    #[test]
    fn test_build_small() {
        // 4 points, 2 attributes with 2 values each = 4 cells.
        let mut store = PointStore::new(2, 2);
        store.push(&[0.0, 0.0], &[0, 0]);
        store.push(&[1.0, 0.0], &[0, 1]);
        store.push(&[0.0, 1.0], &[1, 0]);
        store.push(&[1.0, 1.0], &[1, 1]);

        let config = PrismConfig {
            beam_width: 10,
            alpha: 0.0,
            t: 1,
            m_random: 4,
            m_greedy: 2,
            m_local: 2,
            ..Default::default()
        };

        let index = PrismIndex::build(store, config);
        assert_eq!(index.tree.cells.len(), 4);
        assert_eq!(index.medoids.len(), 4);
        // Each point should have some neighbors.
        for point in 0..4u32 {
            assert!(index.graph.degree(point) > 0);
        }
    }

    /// `t_subsets` yields C(k, t) subsets.
    #[test]
    fn test_t_subsets() {
        let pairs_of_four = t_subsets(4, 2);
        assert_eq!(pairs_of_four.len(), 6); // C(4,2) = 6
        let singles_of_three = t_subsets(3, 1);
        assert_eq!(singles_of_three.len(), 3);
    }
}