brine_tree/
utils.rs

1use super::error::{BrineTreeError, ProgramResult};
2
3#[inline]
4/// Check a condition and return a custom error if false.
5pub fn check_condition(condition: bool, err: BrineTreeError) -> ProgramResult {
6    if condition {
7        Ok(())
8    } else {
9        Err(err)
10    }
11}
12
/// Return the global index of the first (left-most) node on `layer` in a
/// perfect binary Merkle tree of `height`.
///
/// Nodes are numbered layer by layer, left to right. Leaves are layer 0 and
/// occupy `0..2^height`, each higher layer follows immediately after the one
/// below it, and the root is layer `height`. Layer `k` holds `2^(height - k)`
/// nodes, so the first index on `layer` equals the number of nodes on all
/// layers below it: `2^(height+1) - 2^(height+1-layer)`.
///
/// `layer == height + 1` is accepted and yields the total node count
/// `2^(height+1) - 1` (one past the root).
///
/// # Panics
///
/// Panics (in debug builds) if `layer > height + 1`.
pub fn first_index_in_layer(layer: usize, height: usize) -> usize {
    debug_assert!(layer <= height + 1, "layer must be <= height + 1");

    // Early return also avoids shifting by `height + 1` below, which would
    // overflow for `height == usize::BITS - 1` even though the answer is 0.
    if layer == 0 {
        return 0;
    }

    // Factored form of 2^(height+1) - 2^(height+1-layer): `layer` one-bits
    // shifted left by `height + 1 - layer`. Unlike the subtraction form it never
    // materialises 2^(height+1), so every layer `0..=height` is computable even
    // when `height == usize::BITS - 1`.
    ((1usize << layer) - 1) << (height + 1 - layer)
}
22
23/// Return the ancestor index of `node_index` at absolute `target_layer`
24/// for a perfect binary Merkle tree of `height`.
25/// Layers are numbered with leaves at 0 and root at `height`.
26///
27/// Example (height = 3):
28///   - find_ancestor(2, 3, 3)  == 12
29///   - find_ancestor(2, 10, 3) == 13
30pub fn find_ancestor(target_layer: usize, node_index: usize, height: usize) -> usize {
31    assert!(target_layer <= height, "target_layer exceeds tree height");
32
33    // Determine the layer of `node_index`.
34    // Find the largest L such that first_index_in_layer(L) <= node_index.
35    let mut src_layer = 0usize;
36    while src_layer < height && node_index >= first_index_in_layer(src_layer + 1, height) {
37        src_layer += 1;
38    }
39
40    assert!(
41        target_layer >= src_layer,
42        "target_layer must be >= the node's current layer (ancestor lookup)"
43    );
44
45    // Position within its source layer, then shift right by the number of layers we go up.
46    let pos_in_src = node_index - first_index_in_layer(src_layer, height);
47    let up = target_layer - src_layer;
48
49    first_index_in_layer(target_layer, height) + (pos_in_src >> up)
50}
51
52/// Return the (start, count) range of descendant indices of `node_index`
53/// located at `target_layer` (leaves = 0, root = `height`) in a perfect
54/// binary Merkle tree indexed by layer left→right as in the prompt.
55///
56/// Same-layer query returns (node_index, 1).
57pub fn descendant_range(node_index: usize, target_layer: usize, height: usize) -> (usize, usize) {
58    assert!(target_layer <= height, "target_layer exceeds tree height");
59
60    // Max valid index for a perfect tree with given height.
61    let last_index = (1usize << (height + 1)) - 2;
62    assert!(
63        node_index <= last_index,
64        "node_index out of range for given height"
65    );
66
67    // Determine the layer of `node_index`.
68    let mut src_layer = 0usize;
69    while src_layer < height && node_index >= first_index_in_layer(src_layer + 1, height) {
70        src_layer += 1;
71    }
72
73    assert!(
74        target_layer <= src_layer,
75        "target_layer must be <= the node's current layer (descendant lookup)"
76    );
77
78    // Position within source layer; expand down by `down` layers.
79    let pos_in_src = node_index - first_index_in_layer(src_layer, height);
80    let down = src_layer - target_layer;
81    let count = 1usize << down;
82    let start = first_index_in_layer(target_layer, height) + (pos_in_src << down);
83
84    (start, count)
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn group_to_parents_layer1_height_3() {
        // Leaves 0..7 -> layer 1 parents 8..11
        assert_eq!(find_ancestor(1, 0, 3), 8);
        assert_eq!(find_ancestor(1, 1, 3), 8);

        assert_eq!(find_ancestor(1, 2, 3), 9);
        assert_eq!(find_ancestor(1, 3, 3), 9);

        assert_eq!(find_ancestor(1, 4, 3), 10);
        assert_eq!(find_ancestor(1, 5, 3), 10);

        assert_eq!(find_ancestor(1, 6, 3), 11);
        assert_eq!(find_ancestor(1, 7, 3), 11);
    }

    #[test]
    fn group_to_parents_layer2_height_3() {
        // Layer 1 nodes 8..11 -> layer 2 parents 12..13
        assert_eq!(find_ancestor(2, 8, 3), 12);
        assert_eq!(find_ancestor(2, 9, 3), 12);
        assert_eq!(find_ancestor(2, 10, 3), 13);
        assert_eq!(find_ancestor(2, 11, 3), 13);
    }

    #[test]
    fn all_to_root_height_3() {
        // Root index for height=3 is 14
        for idx in 0..=14 {
            assert_eq!(find_ancestor(3, idx, 3), 14);
        }
    }

    #[test]
    fn same_layer_identity_height_3() {
        // Asking for the same layer should return the same index.
        assert_eq!(find_ancestor(0, 5, 3), 5);
        assert_eq!(find_ancestor(1, 8, 3), 8);
        assert_eq!(find_ancestor(2, 13, 3), 13);
        assert_eq!(find_ancestor(3, 14, 3), 14);
    }

    /// Expand a `(start, count)` pair into the explicit list of indices.
    fn expand((start, count): (usize, usize)) -> Vec<usize> {
        (start..start + count).collect()
    }

    #[test]
    fn left_and_right_subtrees_height_3() {
        // Node 12 (layer 2, left) -> leaves 0..4 and layer-1 nodes 8..10 (half-open ranges)
        assert_eq!(descendant_range(12, 0, 3), (0, 4));
        assert_eq!(expand(descendant_range(12, 0, 3)), (0..4).collect::<Vec<_>>());
        assert_eq!(descendant_range(12, 1, 3), (8, 2));
        assert_eq!(descendant_range(12, 2, 3), (12, 1)); // same-layer

        // Node 13 (layer 2, right) -> leaves 4..8 and layer-1 nodes 10..12 (half-open ranges)
        assert_eq!(descendant_range(13, 0, 3), (4, 4));
        assert_eq!(expand(descendant_range(13, 0, 3)), (4..8).collect::<Vec<_>>());
        assert_eq!(descendant_range(13, 1, 3), (10, 2));
        assert_eq!(descendant_range(13, 2, 3), (13, 1)); // same-layer
    }

    #[test]
    fn mid_level_node_to_leaves_height_3() {
        // Node 9 (second node of layer 1) -> leaves [2, 3]
        assert_eq!(descendant_range(9, 0, 3), (2, 2));
        assert_eq!(expand(descendant_range(9, 0, 3)), vec![2, 3]);
        assert_eq!(descendant_range(9, 1, 3), (9, 1)); // same-layer
    }

    #[test]
    fn root_covers_each_whole_layer_height_3() {
        // Layer sizes 8, 4, 2, 1 -> layers start at 0, 8, 12, 14.
        assert_eq!(descendant_range(14, 0, 3), (0, 8));
        assert_eq!(descendant_range(14, 1, 3), (8, 4));
        assert_eq!(descendant_range(14, 2, 3), (12, 2));
        assert_eq!(descendant_range(14, 3, 3), (14, 1));
    }

    #[test]
    fn descendants_round_trip_through_find_ancestor_height_3() {
        // Every descendant reported for a node must map back to that node.
        let height = 3;
        let root = 14;
        for layer in 0..=height {
            // The root's descendants on `layer` are exactly all nodes of that layer.
            for node in expand(descendant_range(root, layer, height)) {
                for target in 0..=layer {
                    for d in expand(descendant_range(node, target, height)) {
                        assert_eq!(find_ancestor(layer, d, height), node);
                    }
                }
            }
        }
    }

    #[test]
    #[should_panic(expected = "target_layer exceeds tree height")]
    fn panics_when_target_layer_too_high() {
        let _ = find_ancestor(4, 0, 3);
    }

    #[test]
    #[should_panic(expected = "target_layer must be >= the node's current layer")]
    fn find_ancestor_panics_when_target_below_node() {
        // Node 8 lives on layer 1; layer 0 is below it.
        let _ = find_ancestor(0, 8, 3);
    }

    #[test]
    #[should_panic(expected = "node_index out of range for given height")]
    fn descendant_range_panics_past_root() {
        let _ = descendant_range(15, 0, 3);
    }

    #[test]
    #[should_panic(expected = "target_layer must be <= the node's current layer")]
    fn descendant_range_panics_when_target_above_node() {
        // Node 8 lives on layer 1; layer 2 is above it.
        let _ = descendant_range(8, 2, 3);
    }
}