citadel_vector/vendored/prism/
partition.rs1use super::point::PointStore;
/// A group of points that share exactly the same attribute values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Cell {
    /// The shared attribute tuple; `values[j]` is the value of attribute `j`
    /// for every point in this cell.
    pub values: Vec<u32>,
    /// Ids of the points in this cell (in ascending order when produced by
    /// `PartitionTree::build`).
    pub point_ids: Vec<u32>,
}
12pub struct PartitionTree {
17 pub cells: Vec<Cell>,
19 pub split_order: Vec<usize>,
21 pub k: usize,
23}
24
25impl PartitionTree {
26 pub fn build(store: &PointStore) -> Self {
29 let k = store.k();
30 let n = store.len;
31
32 let mut order: Vec<usize> = (0..k).collect();
34 order.sort_by_key(|&b| std::cmp::Reverse(store.cardinality(b)));
35
36 let mut groups: std::collections::HashMap<Vec<u32>, Vec<u32>> =
38 std::collections::HashMap::new();
39 for i in 0..n {
40 let key: Vec<u32> = (0..k).map(|j| store.attr(i as u32, j)).collect();
41 groups.entry(key).or_default().push(i as u32);
42 }
43
44 let mut cells: Vec<Cell> = groups
45 .into_iter()
46 .map(|(values, point_ids)| Cell { values, point_ids })
47 .collect();
48 cells.sort_unstable_by(|a, b| a.values.cmp(&b.values));
51
52 Self {
53 cells,
54 split_order: order,
55 k,
56 }
57 }
58
59 pub fn filter_cells(&self, constraints: &[(usize, Vec<u32>)]) -> Vec<usize> {
63 self.cells
64 .iter()
65 .enumerate()
66 .filter(|(_, cell)| {
67 constraints
68 .iter()
69 .all(|(j, allowed)| allowed.contains(&cell.values[*j]))
70 })
71 .map(|(i, _)| i)
72 .collect()
73 }
74
75 pub fn count_points(&self, cell_indices: &[usize]) -> usize {
77 cell_indices
78 .iter()
79 .map(|&i| self.cells[i].point_ids.len())
80 .sum()
81 }
82
83 pub fn collect_points(&self, cell_indices: &[usize]) -> Vec<u32> {
85 let mut pts = Vec::new();
86 for &i in cell_indices {
87 pts.extend_from_slice(&self.cells[i].point_ids);
88 }
89 pts
90 }
91
92 pub fn cell_of(&self, store: &PointStore, point_id: u32) -> Option<usize> {
94 let key: Vec<u32> = (0..self.k).map(|j| store.attr(point_id, j)).collect();
95 self.cells.iter().position(|c| c.values == key)
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// Six points whose two attributes enumerate every pair in {0,1,2} x {0,1}:
    /// attribute 0 is `[0, 0, 1, 1, 2, 2]` and attribute 1 is `[0, 1, 0, 1, 0, 1]`
    /// across points 0..6. Vector contents are all zero and irrelevant here.
    fn sample_store() -> PointStore {
        let vectors = vec![0.0f32; 6 * 2];
        let attrs = vec![vec![0, 0, 1, 1, 2, 2], vec![0, 1, 0, 1, 0, 1]];
        PointStore::from_parts(vectors, 2, attrs)
    }

    #[test]
    fn test_partition_tree() {
        let store = sample_store();
        let tree = PartitionTree::build(&store);
        assert_eq!(tree.cells.len(), 6);

        let cells = tree.filter_cells(&[(0, vec![0])]);
        let pts = tree.collect_points(&cells);
        assert_eq!(pts.len(), 2);

        let cells = tree.filter_cells(&[(0, vec![1]), (1, vec![0])]);
        let pts = tree.collect_points(&cells);
        assert_eq!(pts.len(), 1);
    }

    #[test]
    fn test_counts_and_lookup() {
        let store = sample_store();
        let tree = PartitionTree::build(&store);

        // No constraints keeps every cell and therefore every point.
        let all = tree.filter_cells(&[]);
        assert_eq!(all.len(), tree.cells.len());
        assert_eq!(tree.count_points(&all), 6);

        // count_points must agree with the length of collect_points.
        let cells = tree.filter_cells(&[(1, vec![1])]);
        assert_eq!(tree.count_points(&cells), 3);
        assert_eq!(tree.count_points(&cells), tree.collect_points(&cells).len());

        // A value that never occurs matches no cell.
        assert!(tree.filter_cells(&[(0, vec![9])]).is_empty());

        // Every built point is found in the cell that cell_of reports.
        for id in 0..6u32 {
            let idx = tree
                .cell_of(&store, id)
                .expect("every point used to build the tree has a cell");
            assert!(tree.cells[idx].point_ids.contains(&id));
        }

        // split_order is a permutation of the attribute indices.
        let mut order = tree.split_order.clone();
        order.sort_unstable();
        assert_eq!(order, vec![0, 1]);
    }
}