1use crate::{
2 store::{AsyncStoreRead, Hashable, StoreRead},
3 tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L> Tree<N, L>
8where
9 N: StoreRead<Key = NodeKey, Value = HashOutput>,
10{
11 pub fn get_consistency_proof(
13 &self,
14 first: &TreeHead,
15 second: &TreeHead,
16 ) -> Result<ConsistencyProof, ProofGenerationError> {
17 if first.tree_size > second.tree_size {
18 return Err(ProofGenerationError::InvalidTreeSize {
19 expected: first.tree_size,
20 received: second.tree_size,
21 });
22 }
23
24 let path = get_consistency_proof(first, second, |key| {
25 self.nodes
26 .get(&key)
27 .ok_or(ProofGenerationError::KeyNotFound(key))
28 });
29 let mut path = path
30 .into_iter()
31 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
32
33 path.reverse();
34 Ok(ConsistencyProof { path })
35 }
36}
37
38impl<N, L> Tree<N, L>
39where
40 N: AsyncStoreRead<Key = NodeKey, Value = HashOutput>,
41{
42 pub async fn get_consistency_proof_async(
43 &self,
44 first: &TreeHead,
45 second: &TreeHead,
46 ) -> Result<ConsistencyProof, ProofGenerationError> {
47 if first.tree_size >= second.tree_size {
48 return Err(ProofGenerationError::InvalidTreeSize {
49 expected: first.tree_size,
50 received: second.tree_size,
51 });
52 }
53
54 let path = get_consistency_proof(first, second, |key| {
55 self.nodes
56 .get(key.clone())
57 .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
58 });
59 let path = join_all(path).await;
60 let mut path = path
61 .into_iter()
62 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
63
64 path.reverse();
65 Ok(ConsistencyProof { path })
66 }
67}
68
69fn get_consistency_proof<F, O>(first: &TreeHead, second: &TreeHead, get: F) -> Vec<O>
70where
71 F: Fn(NodeKey) -> O,
72{
73 let mut n = NodeKey::full_range(second.tree_size);
74 let m = first.tree_size;
75 let mut known = true;
76
77 let mut path = vec![];
78
79 while m != n.end {
80 let (left, right) = n.split();
81 if m <= right.start {
82 let elem = get(right);
83 path.push(elem);
84 n = left;
85 } else {
86 let elem = get(left);
87 path.push(elem);
88
89 known = false;
90 n = right;
91 }
92 }
93
94 if !known {
95 let elem = get(n);
96 path.push(elem);
97 }
98
99 path
100}
101
102#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
103pub struct ConsistencyProof {
104 pub(crate) path: Vec<HashOutput>,
105}
106
107impl ConsistencyProof {
108 pub fn validate(
110 &self,
111 first: &TreeHead,
112 second: &TreeHead,
113 ) -> Result<(), ProofValidationError> {
114 if first.tree_size > second.tree_size {
115 return Err(ProofValidationError::InvalidTreeSize {
116 expected: first.tree_size,
117 received: second.tree_size,
118 });
119 };
120 if first == second && self.path.is_empty() {
121 return Ok(());
122 }
123
124 let path: Vec<&HashOutput> = if first.tree_size.is_power_of_two() {
125 std::iter::once(&first.head)
126 .chain(self.path.iter())
127 .collect()
128 } else {
129 self.path.iter().collect()
130 };
131
132 let mut f_n = first.tree_size - 1;
133 let mut s_n = second.tree_size - 1;
134
135 while f_n & 1 == 1 {
136 f_n >>= 1;
137 s_n >>= 1;
138 }
139
140 let mut f_r = *path[0];
141 let mut s_r = *path[0];
142
143 for &c in &path[1..] {
144 if s_n == 0 {
145 return Err(ProofValidationError::PathTooShort);
146 }
147
148 if f_n & 1 == 1 || f_n == s_n {
149 f_r = Node {
150 left: *c,
151 right: f_r,
152 }
153 .hash();
154
155 s_r = Node {
156 left: *c,
157 right: s_r,
158 }
159 .hash();
160
161 while f_n & 1 == 0 && f_n != 0 {
162 f_n >>= 1;
163 s_n >>= 1;
164 }
165 } else {
166 s_r = Node {
167 left: s_r,
168 right: *c,
169 }
170 .hash();
171 }
172
173 f_n >>= 1;
174 s_n >>= 1;
175 }
176
177 if s_n != 0 {
178 return Err(ProofValidationError::PathTooLong);
179 }
180
181 if f_r != first.head || s_r != second.head {
182 return Err(ProofValidationError::HashMismatch);
183 }
184
185 Ok(())
186 }
187}
188
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};

    use super::*;
    use crate::store::MemoryStore;

    /// Consistency proofs between several historical heads of a small tree, including the
    /// trivial proof between a head and itself.
    #[test]
    fn compute_consistency_proofs() {
        let tree = Tree::<MemoryStore<NodeKey, HashOutput>, MemoryStore<u64, String>>::new(
            MemoryStore::default(),
            MemoryStore::default(),
        );

        tree.insert_entry("A".to_string());
        tree.insert_entry("B".to_string());
        tree.insert_entry("C".to_string());

        let tree_head1 = tree.recompute_tree_head();

        tree.insert_entry("D".to_string());
        let tree_head2 = tree.recompute_tree_head();

        tree.insert_entry("E".to_string());
        tree.insert_entry("F".to_string());
        let tree_head3 = tree.recompute_tree_head();

        tree.insert_entry("G".to_string());
        let tree_head4 = tree.recompute_tree_head();

        tree.insert_entry("H".to_string());

        let proof1 = tree
            .get_consistency_proof(&tree_head1, &tree_head4)
            .unwrap();
        assert_eq!(proof1.path.len(), 4);
        proof1.validate(&tree_head1, &tree_head4).unwrap();

        let proof2 = tree
            .get_consistency_proof(&tree_head2, &tree_head4)
            .unwrap();
        assert_eq!(proof2.path.len(), 1);
        assert_eq!(proof1.path[3], proof2.path[0]);
        proof2.validate(&tree_head2, &tree_head4).unwrap();

        let proof3 = tree
            .get_consistency_proof(&tree_head3, &tree_head4)
            .unwrap();
        assert_eq!(proof3.path.len(), 3);
        proof3.validate(&tree_head3, &tree_head4).unwrap();

        let proof4 = tree
            .get_consistency_proof(&tree_head4, &tree_head4)
            .unwrap();
        assert!(proof4.path.is_empty());
        proof4.validate(&tree_head4, &tree_head4).unwrap();
    }

    /// Truncated, padded, or mis-ordered proofs must be rejected.
    #[test]
    fn tampered_consistency_proofs_are_rejected() {
        let tree = Tree::<MemoryStore<NodeKey, HashOutput>, MemoryStore<u64, String>>::new(
            MemoryStore::default(),
            MemoryStore::default(),
        );

        for entry in ["A", "B", "C"] {
            tree.insert_entry(entry.to_string());
        }
        let first = tree.recompute_tree_head();

        for entry in ["D", "E", "F", "G"] {
            tree.insert_entry(entry.to_string());
        }
        let second = tree.recompute_tree_head();

        let proof = tree.get_consistency_proof(&first, &second).unwrap();
        proof.validate(&first, &second).unwrap();

        let mut truncated = proof.clone();
        truncated.path.pop();
        assert!(truncated.validate(&first, &second).is_err());

        let mut extended = proof.clone();
        extended.path.push(proof.path[0]);
        assert!(extended.validate(&first, &second).is_err());

        assert!(proof.validate(&second, &first).is_err());
    }

    /// Consistency proof between two large, randomly filled trees of non-power-of-two sizes.
    #[test]
    fn randomized_consistency_proof() {
        let first_size = 4973;
        let second_size = 5009;
        let mut rng = ChaCha8Rng::seed_from_u64(1337);

        let tree = Tree::<MemoryStore<NodeKey, HashOutput>, MemoryStore<u64, HashOutput>>::new(
            MemoryStore::default(),
            MemoryStore::default(),
        );

        for _ in 0..first_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let first_th = tree.recompute_tree_head();

        for _ in first_size..second_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let second_th = tree.recompute_tree_head();

        let proof = tree.get_consistency_proof(&first_th, &second_th).unwrap();
        proof.validate(&first_th, &second_th).unwrap();
    }
}