1use super::{
2 tree::{MerkleTree, get_merkle_proof},
3 hash::{Hash, Leaf},
4 utils,
5};
6use bytemuck::{Pod, Zeroable};
7
8pub fn get_cached_merkle_proof<const TREE_HEIGHT: usize>(
11 tree: &MerkleTree<TREE_HEIGHT>,
12 leaf_index: usize,
13 cached_layer_number: usize,
14 cached_layer_nodes: &[Hash],
15 fetch_leaf: impl Fn(usize) -> Option<Leaf>,
16) -> Vec<Hash> {
17
18 assert!(cached_layer_number <= TREE_HEIGHT, "cached_layer_number exceeds tree height");
19 assert!(leaf_index < (1usize << TREE_HEIGHT), "leaf_index out of capacity");
20 assert!(!cached_layer_nodes.is_empty(), "cached_layer_nodes must not be empty");
21
22 let subtree = Subtree::new(leaf_index, cached_layer_number, TREE_HEIGHT);
23
24 let zero = tree.zero_values[0].as_leaf();
25 let start = subtree.leaf_start;
26 let end = subtree.leaf_start + subtree.leaf_count;
27
28 let lower_leaves: Vec<Leaf> = (start..end)
30 .map(|i| fetch_leaf(i).unwrap_or(zero))
31 .collect();
32
33 let target_relative = leaf_index - subtree.leaf_start;
34
35 subtree.get_merkle_proof(
36 tree,
37 target_relative,
38 &lower_leaves,
39 cached_layer_nodes,
40 )
41}
42
43#[repr(C)]
44#[derive(Clone, Copy, PartialEq, Debug, Pod, Zeroable)]
45pub struct Subtree {
46 pub subtree_node_index: usize,
47 pub pos_in_layer: usize,
48 pub leaf_start: usize,
49 pub leaf_count: usize,
50 pub lower_height: usize,
51 pub upper_height: usize,
52}
53
54impl Subtree {
55 pub fn new(
56 leaf_index: usize,
57 layer_number: usize,
58 height: usize,
59 ) -> Self {
60 compute_subtree_metadata(leaf_index, layer_number, height)
61 }
62
63 pub fn get_merkle_proof<const TREE_HEIGHT: usize>(
64 &self,
65 tree: &MerkleTree<TREE_HEIGHT>,
66 target_leaf_relative: usize,
67 lower_leaves: &[Leaf],
68 cached_layer: &[Hash],
69 ) -> Vec<Hash> {
70 get_proof_with_metadata(
71 self,
72 tree,
73 target_leaf_relative,
74 lower_leaves,
75 cached_layer,
76 )
77 }
78}
79
80fn compute_subtree_metadata(
82 leaf_index: usize,
83 layer_number: usize,
84 height: usize,
85) -> Subtree {
86 assert!(layer_number <= height, "layer_number > height");
87 assert!(leaf_index < (1usize << height), "leaf_index out of capacity");
88
89 let subtree_node_index =
91 utils::find_ancestor(layer_number, leaf_index, height);
92
93 let pos_in_layer = subtree_node_index -
94 utils::first_index_in_layer(layer_number, height);
95
96 let (leaf_start, leaf_count) =
98 utils::descendant_range(subtree_node_index, 0, height);
99
100 Subtree {
101 subtree_node_index,
102 pos_in_layer,
103 leaf_start,
104 leaf_count,
105 lower_height: layer_number,
106 upper_height: height - layer_number,
107 }
108}
109
110fn get_proof_with_metadata<const TREE_HEIGHT: usize>(
112 meta: &Subtree,
113 tree: &MerkleTree<TREE_HEIGHT>,
114 target_leaf_relative: usize,
115 lower_leaves: &[Leaf],
116 cached_layer: &[Hash],
117) -> Vec<Hash> {
118 assert_eq!(
119 lower_leaves.len(),
120 meta.leaf_count,
121 "lower_leaves length mismatch"
122 );
123 assert!(
124 meta.pos_in_layer < cached_layer.len(),
125 "cached layer missing target node"
126 );
127
128 let lower_proof = get_merkle_proof(
129 lower_leaves,
130 &tree.zero_values,
131 target_leaf_relative,
132 meta.lower_height,
133 );
134
135 let upper_leaves_as_leaf: Vec<Leaf> =
136 cached_layer.iter().map(|h| h.as_leaf()).collect();
137
138 let upper_proof = get_merkle_proof(
139 &upper_leaves_as_leaf,
140 &tree.zero_values[meta.lower_height..],
141 meta.pos_in_layer,
142 meta.upper_height,
143 );
144
145 let mut proof = Vec::with_capacity(TREE_HEIGHT);
146 proof.extend(lower_proof);
147 proof.extend(upper_proof);
148
149 debug_assert_eq!(proof.len(), TREE_HEIGHT);
150
151 proof
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::{verify, MerkleTree};
    use std::collections::HashMap;

    /// Checks the structural invariants of the metadata for one fixed split.
    #[test]
    fn split_properties() {
        const TREE_HEIGHT: usize = 3;
        let layer_number = 2;
        let leaf_index = 5;

        let meta = compute_subtree_metadata(leaf_index, layer_number, TREE_HEIGHT);

        assert_eq!(meta.lower_height + meta.upper_height, TREE_HEIGHT);
        assert_eq!(meta.leaf_count, 1usize << layer_number);

        // The target leaf must lie inside the span covered by the subtree.
        let covered = meta.leaf_start..meta.leaf_start + meta.leaf_count;
        assert!(covered.contains(&leaf_index));

        let layer_first = utils::first_index_in_layer(layer_number, TREE_HEIGHT);
        assert_eq!(meta.subtree_node_index, layer_first + meta.pos_in_layer);
    }

    /// Compares the split proof against the proof computed over all leaves at once.
    #[test]
    fn split_baseline() {
        const TREE_HEIGHT: usize = 5;
        const FILLED: usize = 4;
        let layer_number = 2;
        let leaf_index = 4usize;

        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = MerkleTree::<TREE_HEIGHT>::new(seeds);
        let mut leaves_by_index = HashMap::new();

        // Insert leaves 0..=FILLED into the tree and remember each one by index.
        for i in 0..=FILLED {
            let label = format!("val_{i}");
            tree.try_add(&[label.as_bytes()]).unwrap();
            leaves_by_index.insert(i, Leaf::new(&[label.as_bytes()]));
        }

        let fetch = |idx: usize| leaves_by_index.get(&idx).copied();
        let zero = tree.zero_values[0].as_leaf();

        let leaves: Vec<Leaf> = (0..=FILLED).map(|i| fetch(i).unwrap()).collect();
        let cached_layer = tree.get_layer_nodes(&leaves, layer_number);

        println!("cached_layer len: {:?}", cached_layer.len());

        let meta = compute_subtree_metadata(leaf_index, layer_number, TREE_HEIGHT);
        let subtree_end = meta.leaf_start + meta.leaf_count;
        let lower: Vec<Leaf> = (meta.leaf_start..subtree_end)
            .map(|g| fetch(g).unwrap_or(zero))
            .collect();

        let relative_index = leaf_index - meta.leaf_start;
        let proof_split =
            get_proof_with_metadata(&meta, &tree, relative_index, &lower, &cached_layer);

        let baseline = get_merkle_proof(&leaves, &tree.zero_values, leaf_index, TREE_HEIGHT);

        assert_eq!(proof_split, baseline);
        assert!(verify(
            tree.get_root(),
            &proof_split,
            leaves_by_index[&leaf_index]
        ));
    }

    /// Exercises the public entry point on a larger, partially filled tree.
    #[test]
    fn split_large_tree() {
        const TREE_HEIGHT: usize = 18;
        const FILLED: usize = 1 << 12;
        let layer_number = 10;

        let seeds: &[&[u8]] = &[b"large_tree"];
        let mut tree = MerkleTree::<TREE_HEIGHT>::new(seeds);
        let mut leaves_by_index = HashMap::new();

        for i in 0..FILLED {
            let bytes = (i as u64).to_le_bytes();
            let leaf = Leaf::new(&[&bytes]);
            leaves_by_index.insert(i, leaf);
            tree.try_add_leaf(leaf).unwrap();
        }

        let fetch = |idx: usize| leaves_by_index.get(&idx).copied();
        let zero = tree.zero_values[0].as_leaf();

        let leaves: Vec<Leaf> = (0..FILLED).map(|i| fetch(i).unwrap()).collect();
        let cached_layer = tree.get_layer_nodes(&leaves, layer_number);

        assert_eq!(cached_layer.len(), 4);

        let leaf_index = 1234usize;
        let proof_split = get_cached_merkle_proof(
            &tree,
            leaf_index,
            layer_number,
            &cached_layer,
            |i| Some(fetch(i).unwrap_or(zero)),
        );

        let baseline = get_merkle_proof(&leaves, &tree.zero_values, leaf_index, TREE_HEIGHT);

        assert_eq!(proof_split, baseline);
        assert!(verify(
            tree.get_root(),
            &proof_split,
            fetch(leaf_index).unwrap()
        ));
    }
}