1use crate::{
2 store::{AsyncStoreRead, Hashable, StoreRead},
3 tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L, V> Tree<N, L, V>
8where
9 N: StoreRead<NodeKey, HashOutput>,
10 V: Hashable,
11{
12 pub fn get_consistency_proof(
14 &self,
15 first: &TreeHead,
16 second: &TreeHead,
17 ) -> Result<ConsistencyProof, ProofGenerationError> {
18 if first.tree_size > second.tree_size {
19 return Err(ProofGenerationError::InvalidTreeSize {
20 expected: first.tree_size,
21 received: second.tree_size,
22 });
23 }
24
25 let path = get_consistency_proof(first, second, |key| {
26 self.nodes
27 .get(&key)
28 .ok_or(ProofGenerationError::KeyNotFound(key))
29 });
30 let mut path = path
31 .into_iter()
32 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
33
34 path.reverse();
35 Ok(ConsistencyProof { path })
36 }
37}
38
39impl<N, L, V> Tree<N, L, V>
40where
41 N: AsyncStoreRead<NodeKey, HashOutput>,
42 V: Hashable,
43{
44 pub async fn get_consistency_proof_async(
45 &self,
46 first: &TreeHead,
47 second: &TreeHead,
48 ) -> Result<ConsistencyProof, ProofGenerationError> {
49 if first.tree_size >= second.tree_size {
50 return Err(ProofGenerationError::InvalidTreeSize {
51 expected: first.tree_size,
52 received: second.tree_size,
53 });
54 }
55
56 let path = get_consistency_proof(first, second, |key| {
57 self.nodes
58 .get(key.clone())
59 .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
60 });
61 let path = join_all(path).await;
62 let mut path = path
63 .into_iter()
64 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
65
66 path.reverse();
67 Ok(ConsistencyProof { path })
68 }
69}
70
71fn get_consistency_proof<F, O>(first: &TreeHead, second: &TreeHead, get: F) -> Vec<O>
72where
73 F: Fn(NodeKey) -> O,
74{
75 let mut n = NodeKey::full_range(second.tree_size);
76 let m = first.tree_size;
77 let mut known = true;
78
79 let mut path = vec![];
80
81 while m != n.end {
82 let (left, right) = n.split();
83 if m <= right.start {
84 let elem = get(right);
85 path.push(elem);
86 n = left;
87 } else {
88 let elem = get(left);
89 path.push(elem);
90
91 known = false;
92 n = right;
93 }
94 }
95
96 if !known {
97 let elem = get(n);
98 path.push(elem);
99 }
100
101 path
102}
103
104#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
105pub struct ConsistencyProof {
106 pub(crate) path: Vec<HashOutput>,
107}
108
109impl ConsistencyProof {
110 pub fn validate(
112 &self,
113 first: &TreeHead,
114 second: &TreeHead,
115 ) -> Result<(), ProofValidationError> {
116 if first.tree_size > second.tree_size {
117 return Err(ProofValidationError::InvalidTreeSize {
118 expected: first.tree_size,
119 received: second.tree_size,
120 });
121 };
122 if first == second && self.path.is_empty() {
123 return Ok(());
124 }
125
126 let path: Vec<&HashOutput> = if first.tree_size.is_power_of_two() {
127 std::iter::once(&first.head)
128 .chain(self.path.iter())
129 .collect()
130 } else {
131 self.path.iter().collect()
132 };
133
134 let mut f_n = first.tree_size - 1;
135 let mut s_n = second.tree_size - 1;
136
137 while f_n & 1 == 1 {
138 f_n >>= 1;
139 s_n >>= 1;
140 }
141
142 let mut f_r = *path[0];
143 let mut s_r = *path[0];
144
145 for &c in &path[1..] {
146 if s_n == 0 {
147 return Err(ProofValidationError::PathTooShort);
148 }
149
150 if f_n & 1 == 1 || f_n == s_n {
151 f_r = Node {
152 left: *c,
153 right: f_r,
154 }
155 .hash();
156
157 s_r = Node {
158 left: *c,
159 right: s_r,
160 }
161 .hash();
162
163 while f_n & 1 == 0 && f_n != 0 {
164 f_n >>= 1;
165 s_n >>= 1;
166 }
167 } else {
168 s_r = Node {
169 left: s_r,
170 right: *c,
171 }
172 .hash();
173 }
174
175 f_n >>= 1;
176 s_n >>= 1;
177 }
178
179 if s_n != 0 {
180 return Err(ProofValidationError::PathTooLong);
181 }
182
183 if f_r != first.head || s_r != second.head {
184 return Err(ProofValidationError::HashMismatch);
185 }
186
187 Ok(())
188 }
189}
190
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};

    use super::*;
    use crate::store::MemoryStore;

    #[test]
    fn compute_consistency_proofs() {
        let tree = Tree::<_, _, String>::new(MemoryStore::default(), MemoryStore::default());

        tree.insert_entry("A".to_string());
        tree.insert_entry("B".to_string());
        tree.insert_entry("C".to_string());

        let tree_head1 = tree.recompute_tree_head();

        tree.insert_entry("D".to_string());
        let tree_head2 = tree.recompute_tree_head();

        tree.insert_entry("E".to_string());
        tree.insert_entry("F".to_string());
        let tree_head3 = tree.recompute_tree_head();

        tree.insert_entry("G".to_string());
        let tree_head4 = tree.recompute_tree_head();

        tree.insert_entry("H".to_string());

        let proof1 = tree
            .get_consistency_proof(&tree_head1, &tree_head4)
            .unwrap();
        assert_eq!(proof1.path.len(), 4);
        proof1.validate(&tree_head1, &tree_head4).unwrap();

        let proof2 = tree
            .get_consistency_proof(&tree_head2, &tree_head4)
            .unwrap();
        assert_eq!(proof2.path.len(), 1);
        assert_eq!(proof1.path[3], proof2.path[0]);
        proof2.validate(&tree_head2, &tree_head4).unwrap();

        let proof3 = tree
            .get_consistency_proof(&tree_head3, &tree_head4)
            .unwrap();
        assert_eq!(proof3.path.len(), 3);
        proof3.validate(&tree_head3, &tree_head4).unwrap();

        let proof4 = tree
            .get_consistency_proof(&tree_head4, &tree_head4)
            .unwrap();
        assert!(proof4.path.is_empty());
        proof4.validate(&tree_head4, &tree_head4).unwrap();

        // Tampered proofs must be rejected with an error rather than accepted.
        let mut truncated = proof1.clone();
        truncated.path.pop();
        assert!(truncated.validate(&tree_head1, &tree_head4).is_err());

        let mut extended = proof1.clone();
        extended.path.push(proof1.path[0]);
        assert!(extended.validate(&tree_head1, &tree_head4).is_err());

        // A proof only holds for the pair of heads it was generated for.
        assert!(proof2.validate(&tree_head1, &tree_head4).is_err());
    }

    #[test]
    fn randomized_consistency_proof() {
        let first_size = 4973;
        let second_size = 5009;
        let mut rng = ChaCha8Rng::seed_from_u64(1337);

        let tree = Tree::<_, _, HashOutput>::new(MemoryStore::default(), MemoryStore::default());

        for _ in 0..first_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let first_th = tree.recompute_tree_head();

        for _ in first_size..second_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let second_th = tree.recompute_tree_head();

        let proof = tree.get_consistency_proof(&first_th, &second_th).unwrap();
        proof.validate(&first_th, &second_th).unwrap();
    }
}
274}