Skip to main content

luct_core/tree/
consistency.rs

1use crate::{
2    store::{AsyncStoreRead, Hashable, StoreRead},
3    tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L, V> Tree<N, L, V>
8where
9    N: StoreRead<NodeKey, HashOutput>,
10    V: Hashable,
11{
12    /// This follows RFC 9162 2.1.4.1
13    pub fn get_consistency_proof(
14        &self,
15        first: &TreeHead,
16        second: &TreeHead,
17    ) -> Result<ConsistencyProof, ProofGenerationError> {
18        if first.tree_size > second.tree_size {
19            return Err(ProofGenerationError::InvalidTreeSize {
20                expected: first.tree_size,
21                received: second.tree_size,
22            });
23        }
24
25        let path = get_consistency_proof(first, second, |key| {
26            self.nodes
27                .get(&key)
28                .ok_or(ProofGenerationError::KeyNotFound(key))
29        });
30        let mut path = path
31            .into_iter()
32            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
33
34        path.reverse();
35        Ok(ConsistencyProof { path })
36    }
37}
38
39impl<N, L, V> Tree<N, L, V>
40where
41    N: AsyncStoreRead<NodeKey, HashOutput>,
42    V: Hashable,
43{
44    pub async fn get_consistency_proof_async(
45        &self,
46        first: &TreeHead,
47        second: &TreeHead,
48    ) -> Result<ConsistencyProof, ProofGenerationError> {
49        if first.tree_size >= second.tree_size {
50            return Err(ProofGenerationError::InvalidTreeSize {
51                expected: first.tree_size,
52                received: second.tree_size,
53            });
54        }
55
56        let path = get_consistency_proof(first, second, |key| {
57            self.nodes
58                .get(key.clone())
59                .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
60        });
61        let path = join_all(path).await;
62        let mut path = path
63            .into_iter()
64            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
65
66        path.reverse();
67        Ok(ConsistencyProof { path })
68    }
69}
70
71fn get_consistency_proof<F, O>(first: &TreeHead, second: &TreeHead, get: F) -> Vec<O>
72where
73    F: Fn(NodeKey) -> O,
74{
75    let mut n = NodeKey::full_range(second.tree_size);
76    let m = first.tree_size;
77    let mut known = true;
78
79    let mut path = vec![];
80
81    while m != n.end {
82        let (left, right) = n.split();
83        if m <= right.start {
84            let elem = get(right);
85            path.push(elem);
86            n = left;
87        } else {
88            let elem = get(left);
89            path.push(elem);
90
91            known = false;
92            n = right;
93        }
94    }
95
96    if !known {
97        let elem = get(n);
98        path.push(elem);
99    }
100
101    path
102}
103
104#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
105pub struct ConsistencyProof {
106    pub(crate) path: Vec<HashOutput>,
107}
108
109impl ConsistencyProof {
110    /// This follows RFC 9162 2.1.4.2
111    pub fn validate(
112        &self,
113        first: &TreeHead,
114        second: &TreeHead,
115    ) -> Result<(), ProofValidationError> {
116        if first.tree_size > second.tree_size {
117            return Err(ProofValidationError::InvalidTreeSize {
118                expected: first.tree_size,
119                received: second.tree_size,
120            });
121        };
122        if first == second && self.path.is_empty() {
123            return Ok(());
124        }
125
126        let path: Vec<&HashOutput> = if first.tree_size.is_power_of_two() {
127            std::iter::once(&first.head)
128                .chain(self.path.iter())
129                .collect()
130        } else {
131            self.path.iter().collect()
132        };
133
134        let mut f_n = first.tree_size - 1;
135        let mut s_n = second.tree_size - 1;
136
137        while f_n & 1 == 1 {
138            f_n >>= 1;
139            s_n >>= 1;
140        }
141
142        let mut f_r = *path[0];
143        let mut s_r = *path[0];
144
145        for &c in &path[1..] {
146            if s_n == 0 {
147                return Err(ProofValidationError::PathTooShort);
148            }
149
150            if f_n & 1 == 1 || f_n == s_n {
151                f_r = Node {
152                    left: *c,
153                    right: f_r,
154                }
155                .hash();
156
157                s_r = Node {
158                    left: *c,
159                    right: s_r,
160                }
161                .hash();
162
163                while f_n & 1 == 0 && f_n != 0 {
164                    f_n >>= 1;
165                    s_n >>= 1;
166                }
167            } else {
168                s_r = Node {
169                    left: s_r,
170                    right: *c,
171                }
172                .hash();
173            }
174
175            f_n >>= 1;
176            s_n >>= 1;
177        }
178
179        if s_n != 0 {
180            return Err(ProofValidationError::PathTooLong);
181        }
182
183        if f_r != first.head || s_r != second.head {
184            return Err(ProofValidationError::HashMismatch);
185        }
186
187        Ok(())
188    }
189}
190
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};

    use super::*;
    use crate::store::MemoryStore;

    /// Generates and validates consistency proofs between several heads of a
    /// small, hand-built tree, and checks that mismatched inputs are rejected.
    #[test]
    fn compute_consistency_proofs() {
        let tree = Tree::<_, _, String>::new(MemoryStore::default(), MemoryStore::default());

        tree.insert_entry("A".to_string());
        tree.insert_entry("B".to_string());
        tree.insert_entry("C".to_string());

        // Generate tree head
        let tree_head1 = tree.recompute_tree_head();

        tree.insert_entry("D".to_string());
        let tree_head2 = tree.recompute_tree_head();

        tree.insert_entry("E".to_string());
        tree.insert_entry("F".to_string());
        let tree_head3 = tree.recompute_tree_head();

        tree.insert_entry("G".to_string());
        let tree_head4 = tree.recompute_tree_head();

        // Grow the tree past the newest head used below: proofs are generated
        // against recorded heads, not against the current tree size.
        tree.insert_entry("H".to_string());

        let proof1 = tree
            .get_consistency_proof(&tree_head1, &tree_head4)
            .unwrap();
        assert_eq!(proof1.path.len(), 4);
        proof1.validate(&tree_head1, &tree_head4).unwrap();

        let proof2 = tree
            .get_consistency_proof(&tree_head2, &tree_head4)
            .unwrap();
        assert_eq!(proof2.path.len(), 1);
        assert_eq!(proof1.path[3], proof2.path[0]);
        proof2.validate(&tree_head2, &tree_head4).unwrap();

        let proof3 = tree
            .get_consistency_proof(&tree_head3, &tree_head4)
            .unwrap();
        assert_eq!(proof3.path.len(), 3);
        proof3.validate(&tree_head3, &tree_head4).unwrap();

        // Identical heads yield an empty proof that still validates.
        let proof4 = tree
            .get_consistency_proof(&tree_head4, &tree_head4)
            .unwrap();
        assert!(proof4.path.is_empty());
        proof4.validate(&tree_head4, &tree_head4).unwrap();

        // Heads passed in the wrong order are rejected because of their sizes.
        assert!(matches!(
            proof1.validate(&tree_head4, &tree_head1),
            Err(ProofValidationError::InvalidTreeSize { .. })
        ));

        // A proof generated for a different first head must not validate.
        assert!(proof2.validate(&tree_head1, &tree_head4).is_err());
    }

    /// Round-trips a consistency proof between two heads of a large tree
    /// filled with seeded pseudo-random entries.
    #[test]
    fn randomized_consistency_proof() {
        let first_size = 4973;
        let second_size = 5009;
        // Fixed seed keeps the test deterministic.
        let mut rng = ChaCha8Rng::seed_from_u64(1337);

        let tree = Tree::<_, _, HashOutput>::new(MemoryStore::default(), MemoryStore::default());

        for _ in 0..first_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let first_th = tree.recompute_tree_head();

        for _ in first_size..second_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let second_th = tree.recompute_tree_head();

        let proof = tree.get_consistency_proof(&first_th, &second_th).unwrap();
        proof.validate(&first_th, &second_th).unwrap();
    }
}