Skip to main content

citadel_vector/vendored/prism/
partition.rs

1use super::point::PointStore;
2
/// A single leaf cell in the Attribute Partition Tree.
///
/// A cell pairs one attribute combination with the ids of the points that
/// carry it. Two cells compare equal when both the combination and the id
/// list (including its order) are equal.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Cell {
    /// Attribute combination defining this cell: `values[j]` is the value of
    /// attribute `j`.
    pub values: Vec<u32>,
    /// Ids of the points belonging to this cell.
    pub point_ids: Vec<u32>,
}
11
12/// Attribute Partition Tree (Algorithm 1 from the paper).
13///
14/// Recursive balanced partition on attributes. Each leaf is a cell of points
15/// sharing the same attribute combination.
16pub struct PartitionTree {
17    /// All leaf cells.
18    pub cells: Vec<Cell>,
19    /// Attribute split order (permutation of [0..k]).
20    pub split_order: Vec<usize>,
21    /// Number of attribute dimensions.
22    pub k: usize,
23}
24
25impl PartitionTree {
26    /// Build the partition tree from a PointStore.
27    /// Split order: most-distinct-values first (information gain heuristic).
28    pub fn build(store: &PointStore) -> Self {
29        let k = store.k();
30        let n = store.len;
31
32        // Determine split order: descending by cardinality
33        let mut order: Vec<usize> = (0..k).collect();
34        order.sort_by_key(|&b| std::cmp::Reverse(store.cardinality(b)));
35
36        // Group points by their full attribute combination
37        let mut groups: std::collections::HashMap<Vec<u32>, Vec<u32>> =
38            std::collections::HashMap::new();
39        for i in 0..n {
40            let key: Vec<u32> = (0..k).map(|j| store.attr(i as u32, j)).collect();
41            groups.entry(key).or_default().push(i as u32);
42        }
43
44        let cells: Vec<Cell> = groups
45            .into_iter()
46            .map(|(values, point_ids)| Cell { values, point_ids })
47            .collect();
48
49        Self {
50            cells,
51            split_order: order,
52            k,
53        }
54    }
55
56    /// Find all cells compatible with a filter.
57    /// A cell is compatible if for every constrained attribute j,
58    /// the cell's value on j is in the allowed set.
59    pub fn filter_cells(&self, constraints: &[(usize, Vec<u32>)]) -> Vec<usize> {
60        self.cells
61            .iter()
62            .enumerate()
63            .filter(|(_, cell)| {
64                constraints
65                    .iter()
66                    .all(|(j, allowed)| allowed.contains(&cell.values[*j]))
67            })
68            .map(|(i, _)| i)
69            .collect()
70    }
71
72    /// Total number of points across given cell indices.
73    pub fn count_points(&self, cell_indices: &[usize]) -> usize {
74        cell_indices
75            .iter()
76            .map(|&i| self.cells[i].point_ids.len())
77            .sum()
78    }
79
80    /// Get all point ids in the given cell indices.
81    pub fn collect_points(&self, cell_indices: &[usize]) -> Vec<u32> {
82        let mut pts = Vec::new();
83        for &i in cell_indices {
84            pts.extend_from_slice(&self.cells[i].point_ids);
85        }
86        pts
87    }
88
89    /// Find which cell a point belongs to. Returns cell index.
90    pub fn cell_of(&self, store: &PointStore, point_id: u32) -> Option<usize> {
91        let key: Vec<u32> = (0..self.k).map(|j| store.attr(point_id, j)).collect();
92        self.cells.iter().position(|c| c.values == key)
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// Builds a tree over six points with two attributes and checks the
    /// cell count plus two filter queries.
    #[test]
    fn test_partition_tree() {
        // Attribute columns, one entry per point: color has 3 distinct
        // values, size has 2.
        let color = vec![0, 0, 1, 1, 2, 2];
        let size = vec![0, 1, 0, 1, 0, 1];
        let store = PointStore::from_parts(vec![0.0f32; 6 * 2], 2, vec![color, size]);

        let tree = PartitionTree::build(&store);
        // Every point has a unique (color, size) pair: 3 * 2 = 6 cells.
        assert_eq!(tree.cells.len(), 6);

        // color = 0 selects two points.
        let color_zero = tree.filter_cells(&[(0, vec![0])]);
        assert_eq!(tree.collect_points(&color_zero).len(), 2);

        // color = 1 AND size = 0 selects exactly one point.
        let color_one_size_zero = tree.filter_cells(&[(0, vec![1]), (1, vec![0])]);
        assert_eq!(tree.collect_points(&color_one_size_zero).len(), 1);
    }
}