citadel_vector/vendored/prism/
partition.rs1use super::point::PointStore;
2
/// One partition cell: a distinct attribute tuple together with the points that carry it.
#[derive(Debug, Clone)]
pub struct Cell {
    /// Attribute values shared by every point in the cell, indexed by attribute position.
    pub values: Vec<u32>,
    /// Ids of the points whose attribute tuple equals `values`.
    pub point_ids: Vec<u32>,
}
11
12pub struct PartitionTree {
17 pub cells: Vec<Cell>,
19 pub split_order: Vec<usize>,
21 pub k: usize,
23}
24
25impl PartitionTree {
26 pub fn build(store: &PointStore) -> Self {
29 let k = store.k();
30 let n = store.len;
31
32 let mut order: Vec<usize> = (0..k).collect();
34 order.sort_by_key(|&b| std::cmp::Reverse(store.cardinality(b)));
35
36 let mut groups: std::collections::HashMap<Vec<u32>, Vec<u32>> =
38 std::collections::HashMap::new();
39 for i in 0..n {
40 let key: Vec<u32> = (0..k).map(|j| store.attr(i as u32, j)).collect();
41 groups.entry(key).or_default().push(i as u32);
42 }
43
44 let cells: Vec<Cell> = groups
45 .into_iter()
46 .map(|(values, point_ids)| Cell { values, point_ids })
47 .collect();
48
49 Self {
50 cells,
51 split_order: order,
52 k,
53 }
54 }
55
56 pub fn filter_cells(&self, constraints: &[(usize, Vec<u32>)]) -> Vec<usize> {
60 self.cells
61 .iter()
62 .enumerate()
63 .filter(|(_, cell)| {
64 constraints
65 .iter()
66 .all(|(j, allowed)| allowed.contains(&cell.values[*j]))
67 })
68 .map(|(i, _)| i)
69 .collect()
70 }
71
72 pub fn count_points(&self, cell_indices: &[usize]) -> usize {
74 cell_indices
75 .iter()
76 .map(|&i| self.cells[i].point_ids.len())
77 .sum()
78 }
79
80 pub fn collect_points(&self, cell_indices: &[usize]) -> Vec<u32> {
82 let mut pts = Vec::new();
83 for &i in cell_indices {
84 pts.extend_from_slice(&self.cells[i].point_ids);
85 }
86 pts
87 }
88
89 pub fn cell_of(&self, store: &PointStore, point_id: u32) -> Option<usize> {
91 let key: Vec<u32> = (0..self.k).map(|j| store.attr(point_id, j)).collect();
92 self.cells.iter().position(|c| c.values == key)
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// Six points over two attributes; every point carries a distinct attribute pair.
    #[test]
    fn test_partition_tree() {
        // Column-major attributes: first row is attribute 0, second row is attribute 1.
        let attrs = vec![
            vec![0, 0, 1, 1, 2, 2],
            vec![0, 1, 0, 1, 0, 1],
        ];
        // Six zeroed vectors; the `2` passed below is presumably the vector dimension.
        let vectors = vec![0.0f32; 6 * 2];
        let store = PointStore::from_parts(vectors, 2, attrs);

        let tree = PartitionTree::build(&store);
        assert_eq!(tree.cells.len(), 6);

        // Attribute 0 == 0 is shared by two points.
        let matched = tree.filter_cells(&[(0, vec![0])]);
        let selected = tree.collect_points(&matched);
        assert_eq!(selected.len(), 2);

        // Attribute 0 == 1 and attribute 1 == 0 pins down a single point.
        let matched = tree.filter_cells(&[(0, vec![1]), (1, vec![0])]);
        let selected = tree.collect_points(&matched);
        assert_eq!(selected.len(), 1);
    }
}