brine_tree/
subtree.rs

1use super::{
2    tree::{MerkleTree, get_merkle_proof},
3    hash::{Hash, Leaf},
4    utils,
5};
6use bytemuck::{Pod, Zeroable};
7
8/// Generates a Merkle proof using a cached layer and a closure to fetch only the required leaves
9/// by index.
10pub fn get_cached_merkle_proof<const TREE_HEIGHT: usize>(
11    tree: &MerkleTree<TREE_HEIGHT>,
12    leaf_index: usize,
13    cached_layer_number: usize,
14    cached_layer_nodes: &[Hash],
15    fetch_leaf: impl Fn(usize) -> Option<Leaf>,
16) -> Vec<Hash> {
17
18    assert!(cached_layer_number <= TREE_HEIGHT, "cached_layer_number exceeds tree height");
19    assert!(leaf_index < (1usize << TREE_HEIGHT), "leaf_index out of capacity");
20    assert!(!cached_layer_nodes.is_empty(), "cached_layer_nodes must not be empty");
21
22    let subtree = Subtree::new(leaf_index, cached_layer_number, TREE_HEIGHT);
23
24    let zero = tree.zero_values[0].as_leaf();
25    let start = subtree.leaf_start;
26    let end = subtree.leaf_start + subtree.leaf_count;
27
28    // Should be way smaller than the total leaf count unless the root layer is chosen.
29    let lower_leaves: Vec<Leaf> = (start..end)
30        .map(|i| fetch_leaf(i).unwrap_or(zero))
31        .collect();
32
33    let target_relative = leaf_index - subtree.leaf_start;
34
35    subtree.get_merkle_proof(
36        tree,
37        target_relative,
38        &lower_leaves,
39        cached_layer_nodes,
40    )
41}
42
43#[repr(C)]
44#[derive(Clone, Copy, PartialEq, Debug, Pod, Zeroable)]
45pub struct Subtree {
46    pub subtree_node_index: usize,
47    pub pos_in_layer: usize,
48    pub leaf_start: usize,
49    pub leaf_count: usize,
50    pub lower_height: usize,
51    pub upper_height: usize,
52}
53
54impl Subtree {
55    pub fn new(
56        leaf_index: usize,
57        layer_number: usize,
58        height: usize,
59    ) -> Self {
60        compute_subtree_metadata(leaf_index, layer_number, height)
61    }
62
63    pub fn get_merkle_proof<const TREE_HEIGHT: usize>(
64        &self,
65        tree: &MerkleTree<TREE_HEIGHT>,
66        target_leaf_relative: usize,
67        lower_leaves: &[Leaf],
68        cached_layer: &[Hash],
69    ) -> Vec<Hash> {
70        get_proof_with_metadata(
71            self,
72            tree,
73            target_leaf_relative,
74            lower_leaves,
75            cached_layer,
76        )
77    }
78}
79
80/// Creates metadata for generating a split Merkle proof.
81fn compute_subtree_metadata(
82    leaf_index: usize,
83    layer_number: usize,
84    height: usize,
85) -> Subtree {
86    assert!(layer_number <= height, "layer_number > height");
87    assert!(leaf_index < (1usize << height), "leaf_index out of capacity");
88
89    // Ancestor at the split layer
90    let subtree_node_index =
91        utils::find_ancestor(layer_number, leaf_index, height);
92
93    let pos_in_layer = subtree_node_index - 
94        utils::first_index_in_layer(layer_number, height);
95
96    // Range of leaves covered by that ancestor
97    let (leaf_start, leaf_count) =
98        utils::descendant_range(subtree_node_index, 0, height);
99
100    Subtree {
101        subtree_node_index,
102        pos_in_layer,
103        leaf_start,
104        leaf_count,
105        lower_height: layer_number,
106        upper_height: height - layer_number,
107    }
108}
109
110/// Generates a split Merkle proof for a specific layer into the Merkle tree using split metadata.
111fn get_proof_with_metadata<const TREE_HEIGHT: usize>(
112    meta: &Subtree,
113    tree: &MerkleTree<TREE_HEIGHT>,
114    target_leaf_relative: usize,
115    lower_leaves: &[Leaf],
116    cached_layer: &[Hash],
117) -> Vec<Hash> {
118    assert_eq!(
119        lower_leaves.len(),
120        meta.leaf_count,
121        "lower_leaves length mismatch"
122    );
123    assert!(
124        meta.pos_in_layer < cached_layer.len(),
125        "cached layer missing target node"
126    );
127
128    let lower_proof = get_merkle_proof(
129        lower_leaves,
130        &tree.zero_values,
131        target_leaf_relative,
132        meta.lower_height,
133    );
134
135    let upper_leaves_as_leaf: Vec<Leaf> =
136        cached_layer.iter().map(|h| h.as_leaf()).collect();
137
138    let upper_proof = get_merkle_proof(
139        &upper_leaves_as_leaf,
140        &tree.zero_values[meta.lower_height..],
141        meta.pos_in_layer,
142        meta.upper_height,
143    );
144
145    let mut proof = Vec::with_capacity(TREE_HEIGHT);
146    proof.extend(lower_proof);
147    proof.extend(upper_proof);
148
149    debug_assert_eq!(proof.len(), TREE_HEIGHT);
150
151    proof
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::{verify, MerkleTree};

    /// The split metadata partitions the tree height and covers the target leaf.
    #[test]
    fn split_properties() {
        const TREE_HEIGHT: usize = 3; // tree height (root at layer 3)

        let layer_number = 2; // split layer
        let leaf_index = 5; // any leaf < 8

        let meta = compute_subtree_metadata(leaf_index, layer_number, TREE_HEIGHT);

        assert_eq!(meta.lower_height + meta.upper_height, TREE_HEIGHT);
        assert_eq!(meta.leaf_count, 1usize << layer_number);
        assert!(
            leaf_index >= meta.leaf_start && leaf_index < meta.leaf_start + meta.leaf_count
        );
        assert_eq!(
            meta.subtree_node_index,
            utils::first_index_in_layer(layer_number, TREE_HEIGHT) + meta.pos_in_layer
        );
    }

    /// A split proof over a small, partially filled tree matches the unsplit proof.
    #[test]
    fn split_baseline() {
        const TREE_HEIGHT: usize = 5;
        const FILLED: usize = 5; // fill leaves 0..5 (out of 32)

        let layer_number = 2;
        let leaf_index = 4usize; // its subtree (leaves 4..8) is only partially filled

        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = MerkleTree::<TREE_HEIGHT>::new(seeds);
        let mut leaves_by_index = std::collections::HashMap::new();

        for i in 0..FILLED {
            let value = format!("val_{i}");
            tree.try_add(&[value.as_bytes()]).unwrap();
            leaves_by_index.insert(i, Leaf::new(&[value.as_bytes()]));
        }

        let fetch = |idx: usize| leaves_by_index.get(&idx).copied();
        let zero = tree.zero_values[0].as_leaf();

        let leaves: Vec<Leaf> = (0..FILLED).map(|i| fetch(i).unwrap()).collect();
        let cached_layer = tree.get_layer_nodes(&leaves, layer_number);

        let meta = compute_subtree_metadata(leaf_index, layer_number, TREE_HEIGHT);
        let lower: Vec<Leaf> = (meta.leaf_start..meta.leaf_start + meta.leaf_count)
            .map(|g| fetch(g).unwrap_or(zero))
            .collect();

        let proof_split = get_proof_with_metadata(
            &meta,
            &tree,
            leaf_index - meta.leaf_start, // index relative to the subtree
            &lower,
            &cached_layer,
        );

        let baseline = get_merkle_proof(&leaves, &tree.zero_values, leaf_index, TREE_HEIGHT);

        assert_eq!(proof_split, baseline);
        assert!(verify(tree.get_root(), &proof_split, leaves_by_index[&leaf_index]));
    }

    /// The public entry point produces the same proof as the unsplit computation on a larger
    /// tree, fetching only the leaves of one subtree.
    #[test]
    fn split_large_tree() {
        const TREE_HEIGHT: usize = 18; // 262 144-leaf capacity
        const FILLED: usize = 1 << 12; // fill 4096 leaves

        let layer_number = 10;
        let seeds: &[&[u8]] = &[b"large_tree"];
        let mut tree = MerkleTree::<TREE_HEIGHT>::new(seeds);
        let mut leaves_by_index = std::collections::HashMap::new();

        for i in 0..FILLED {
            let bytes = (i as u64).to_le_bytes();
            let leaf = Leaf::new(&[&bytes]);
            leaves_by_index.insert(i, leaf);
            tree.try_add_leaf(leaf).unwrap();
        }

        let fetch = |idx: usize| leaves_by_index.get(&idx).copied();

        let leaves: Vec<Leaf> = (0..FILLED).map(|i| fetch(i).unwrap()).collect();
        let cached_layer = tree.get_layer_nodes(&leaves, layer_number);

        // only 4 nodes in layer 10 are actually non-zero; a fully filled tree would have 256
        assert_eq!(cached_layer.len(), 4);

        let leaf_index = 1234usize;
        // Pass the raw fetcher so `get_cached_merkle_proof`'s own missing-leaf fallback is the
        // code path under test.
        let proof_split = get_cached_merkle_proof(
            &tree,
            leaf_index,
            layer_number,
            &cached_layer,
            &fetch,
        );

        let baseline = get_merkle_proof(&leaves, &tree.zero_values, leaf_index, TREE_HEIGHT);

        assert_eq!(proof_split, baseline);
        assert!(verify(tree.get_root(), &proof_split, fetch(leaf_index).unwrap()));
    }
}