Skip to main content

citadel_vector/vendored/prism/
partition.rs

1use super::point::PointStore;
2
/// A single leaf cell in the Attribute Partition Tree.
///
/// A cell groups every point whose attribute combination is exactly `values`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Cell {
    /// Attribute values that define this cell: `values[j]` is the value of attribute `j`.
    pub values: Vec<u32>,
    /// Ids of the points belonging to this cell.
    pub point_ids: Vec<u32>,
}

12/// Attribute Partition Tree (Algorithm 1 from the paper).
13///
14/// Recursive balanced partition on attributes. Each leaf is a cell of points
15/// sharing the same attribute combination.
16pub struct PartitionTree {
17    /// All leaf cells.
18    pub cells: Vec<Cell>,
19    /// Attribute split order (permutation of [0..k]).
20    pub split_order: Vec<usize>,
21    /// Number of attribute dimensions.
22    pub k: usize,
23}
24
25impl PartitionTree {
26    /// Build the partition tree from a PointStore.
27    /// Split order: most-distinct-values first (information gain heuristic).
28    pub fn build(store: &PointStore) -> Self {
29        let k = store.k();
30        let n = store.len;
31
32        // Determine split order: descending by cardinality
33        let mut order: Vec<usize> = (0..k).collect();
34        order.sort_by_key(|&b| std::cmp::Reverse(store.cardinality(b)));
35
36        // Group points by their full attribute combination
37        let mut groups: std::collections::HashMap<Vec<u32>, Vec<u32>> =
38            std::collections::HashMap::new();
39        for i in 0..n {
40            let key: Vec<u32> = (0..k).map(|j| store.attr(i as u32, j)).collect();
41            groups.entry(key).or_default().push(i as u32);
42        }
43
44        let mut cells: Vec<Cell> = groups
45            .into_iter()
46            .map(|(values, point_ids)| Cell { values, point_ids })
47            .collect();
48        // HashMap iteration order is random per process; sort so identical
49        // data always builds identical (and byte-identical persisted) indexes.
50        cells.sort_unstable_by(|a, b| a.values.cmp(&b.values));
51
52        Self {
53            cells,
54            split_order: order,
55            k,
56        }
57    }
58
59    /// Find all cells compatible with a filter.
60    /// A cell is compatible if for every constrained attribute j,
61    /// the cell's value on j is in the allowed set.
62    pub fn filter_cells(&self, constraints: &[(usize, Vec<u32>)]) -> Vec<usize> {
63        self.cells
64            .iter()
65            .enumerate()
66            .filter(|(_, cell)| {
67                constraints
68                    .iter()
69                    .all(|(j, allowed)| allowed.contains(&cell.values[*j]))
70            })
71            .map(|(i, _)| i)
72            .collect()
73    }
74
75    /// Total number of points across given cell indices.
76    pub fn count_points(&self, cell_indices: &[usize]) -> usize {
77        cell_indices
78            .iter()
79            .map(|&i| self.cells[i].point_ids.len())
80            .sum()
81    }
82
83    /// Get all point ids in the given cell indices.
84    pub fn collect_points(&self, cell_indices: &[usize]) -> Vec<u32> {
85        let mut pts = Vec::new();
86        for &i in cell_indices {
87            pts.extend_from_slice(&self.cells[i].point_ids);
88        }
89        pts
90    }
91
92    /// Find which cell a point belongs to. Returns cell index.
93    pub fn cell_of(&self, store: &PointStore, point_id: u32) -> Option<usize> {
94        let key: Vec<u32> = (0..self.k).map(|j| store.attr(point_id, j)).collect();
95        self.cells.iter().position(|c| c.values == key)
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::super::point::PointStore;
    use super::*;

    /// 6 points, 2 attributes: color (3 values) and size (2 values). Every
    /// (color, size) combination occurs exactly once, at point id
    /// `2 * color + size`.
    fn sample_store() -> PointStore {
        let vectors = vec![0.0f32; 6 * 2];
        let attrs = vec![
            vec![0, 0, 1, 1, 2, 2], // color
            vec![0, 1, 0, 1, 0, 1], // size
        ];
        PointStore::from_parts(vectors, 2, attrs)
    }

    /// Snapshot of the cells as plain tuples, so trees can be compared
    /// without requiring `PartialEq` on `Cell`.
    fn snapshot(tree: &PartitionTree) -> Vec<(Vec<u32>, Vec<u32>)> {
        tree.cells
            .iter()
            .map(|c| (c.values.clone(), c.point_ids.clone()))
            .collect()
    }

    #[test]
    fn test_partition_tree() {
        let store = sample_store();
        let tree = PartitionTree::build(&store);
        assert_eq!(tree.cells.len(), 6); // 3*2 = 6 distinct combos

        // Filter: color=0
        let cells = tree.filter_cells(&[(0, vec![0])]);
        let pts = tree.collect_points(&cells);
        assert_eq!(pts.len(), 2);

        // Filter: color=1 AND size=0
        let cells = tree.filter_cells(&[(0, vec![1]), (1, vec![0])]);
        let pts = tree.collect_points(&cells);
        assert_eq!(pts.len(), 1);
    }

    #[test]
    fn test_cells_sorted_and_deterministic() {
        let store = sample_store();
        let tree = PartitionTree::build(&store);

        // Cells are ordered lexicographically by attribute values, so in this
        // sample cell i holds exactly point i.
        let expected: Vec<(Vec<u32>, Vec<u32>)> = (0..6u32)
            .map(|p| (vec![p / 2, p % 2], vec![p]))
            .collect();
        assert_eq!(snapshot(&tree), expected);

        // Rebuilding from the same data must give the same layout.
        assert_eq!(snapshot(&PartitionTree::build(&store)), expected);

        // color has more distinct values than size, so it is split first.
        assert_eq!(tree.split_order, vec![0, 1]);
        assert_eq!(tree.k, 2);
    }

    #[test]
    fn test_multi_value_filter_and_lookup() {
        let store = sample_store();
        let tree = PartitionTree::build(&store);

        // Filter: color in {0, 2}
        let cells = tree.filter_cells(&[(0, vec![0, 2])]);
        assert_eq!(tree.count_points(&cells), 4);
        assert_eq!(tree.collect_points(&cells), vec![0, 1, 4, 5]);

        // No constraints matches every cell.
        let all = tree.filter_cells(&[]);
        assert_eq!(all.len(), tree.cells.len());
        assert_eq!(tree.count_points(&all), 6);

        // Every point maps back to the cell that actually contains it.
        for p in 0..6u32 {
            let idx = tree.cell_of(&store, p).expect("every point has a cell");
            assert!(tree.cells[idx].point_ids.contains(&p));
        }
    }
}
126}