Skip to main content

luct_core/tree/consistency.rs

1use crate::{
2    store::{AsyncStoreRead, Hashable, StoreRead},
3    tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L> Tree<N, L>
8where
9    N: StoreRead<Key = NodeKey, Value = HashOutput>,
10{
11    /// This follows RFC 9162 2.1.4.1
12    pub fn get_consistency_proof(
13        &self,
14        first: &TreeHead,
15        second: &TreeHead,
16    ) -> Result<ConsistencyProof, ProofGenerationError> {
17        if first.tree_size > second.tree_size {
18            return Err(ProofGenerationError::InvalidTreeSize {
19                expected: first.tree_size,
20                received: second.tree_size,
21            });
22        }
23
24        let path = get_consistency_proof(first, second, |key| {
25            self.nodes
26                .get(&key)
27                .ok_or(ProofGenerationError::KeyNotFound(key))
28        });
29        let mut path = path
30            .into_iter()
31            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
32
33        path.reverse();
34        Ok(ConsistencyProof { path })
35    }
36}
37
38impl<N, L> Tree<N, L>
39where
40    N: AsyncStoreRead<Key = NodeKey, Value = HashOutput>,
41{
42    pub async fn get_consistency_proof_async(
43        &self,
44        first: &TreeHead,
45        second: &TreeHead,
46    ) -> Result<ConsistencyProof, ProofGenerationError> {
47        if first.tree_size >= second.tree_size {
48            return Err(ProofGenerationError::InvalidTreeSize {
49                expected: first.tree_size,
50                received: second.tree_size,
51            });
52        }
53
54        let path = get_consistency_proof(first, second, |key| {
55            self.nodes
56                .get(key.clone())
57                .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
58        });
59        let path = join_all(path).await;
60        let mut path = path
61            .into_iter()
62            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
63
64        path.reverse();
65        Ok(ConsistencyProof { path })
66    }
67}
68
69fn get_consistency_proof<F, O>(first: &TreeHead, second: &TreeHead, get: F) -> Vec<O>
70where
71    F: Fn(NodeKey) -> O,
72{
73    let mut n = NodeKey::full_range(second.tree_size);
74    let m = first.tree_size;
75    let mut known = true;
76
77    let mut path = vec![];
78
79    while m != n.end {
80        let (left, right) = n.split();
81        if m <= right.start {
82            let elem = get(right);
83            path.push(elem);
84            n = left;
85        } else {
86            let elem = get(left);
87            path.push(elem);
88
89            known = false;
90            n = right;
91        }
92    }
93
94    if !known {
95        let elem = get(n);
96        path.push(elem);
97    }
98
99    path
100}
101
102#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
103pub struct ConsistencyProof {
104    pub(crate) path: Vec<HashOutput>,
105}
106
107impl ConsistencyProof {
108    /// This follows RFC 9162 2.1.4.2
109    pub fn validate(
110        &self,
111        first: &TreeHead,
112        second: &TreeHead,
113    ) -> Result<(), ProofValidationError> {
114        if first.tree_size > second.tree_size {
115            return Err(ProofValidationError::InvalidTreeSize {
116                expected: first.tree_size,
117                received: second.tree_size,
118            });
119        };
120        if first == second && self.path.is_empty() {
121            return Ok(());
122        }
123
124        let path: Vec<&HashOutput> = if first.tree_size.is_power_of_two() {
125            std::iter::once(&first.head)
126                .chain(self.path.iter())
127                .collect()
128        } else {
129            self.path.iter().collect()
130        };
131
132        let mut f_n = first.tree_size - 1;
133        let mut s_n = second.tree_size - 1;
134
135        while f_n & 1 == 1 {
136            f_n >>= 1;
137            s_n >>= 1;
138        }
139
140        let mut f_r = *path[0];
141        let mut s_r = *path[0];
142
143        for &c in &path[1..] {
144            if s_n == 0 {
145                return Err(ProofValidationError::PathTooShort);
146            }
147
148            if f_n & 1 == 1 || f_n == s_n {
149                f_r = Node {
150                    left: *c,
151                    right: f_r,
152                }
153                .hash();
154
155                s_r = Node {
156                    left: *c,
157                    right: s_r,
158                }
159                .hash();
160
161                while f_n & 1 == 0 && f_n != 0 {
162                    f_n >>= 1;
163                    s_n >>= 1;
164                }
165            } else {
166                s_r = Node {
167                    left: s_r,
168                    right: *c,
169                }
170                .hash();
171            }
172
173            f_n >>= 1;
174            s_n >>= 1;
175        }
176
177        if s_n != 0 {
178            return Err(ProofValidationError::PathTooLong);
179        }
180
181        if f_r != first.head || s_r != second.head {
182            return Err(ProofValidationError::HashMismatch);
183        }
184
185        Ok(())
186    }
187}
188
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};

    use super::*;
    use crate::store::MemoryStore;

    #[test]
    fn compute_consistency_proofs() {
        let tree = Tree::<MemoryStore<NodeKey, HashOutput>, MemoryStore<u64, String>>::new(
            MemoryStore::default(),
            MemoryStore::default(),
        );

        tree.insert_entry("A".to_string());
        tree.insert_entry("B".to_string());
        tree.insert_entry("C".to_string());

        // Generate tree head
        let tree_head1 = tree.recompute_tree_head();

        tree.insert_entry("D".to_string());
        let tree_head2 = tree.recompute_tree_head();

        tree.insert_entry("E".to_string());
        tree.insert_entry("F".to_string());
        let tree_head3 = tree.recompute_tree_head();

        tree.insert_entry("G".to_string());
        let tree_head4 = tree.recompute_tree_head();

        tree.insert_entry("H".to_string());

        let proof1 = tree
            .get_consistency_proof(&tree_head1, &tree_head4)
            .unwrap();
        assert_eq!(proof1.path.len(), 4);
        proof1.validate(&tree_head1, &tree_head4).unwrap();

        // A proof must not validate against heads it was not generated for.
        assert!(proof1.validate(&tree_head2, &tree_head4).is_err());
        assert!(proof1.validate(&tree_head4, &tree_head1).is_err());

        let proof2 = tree
            .get_consistency_proof(&tree_head2, &tree_head4)
            .unwrap();
        assert_eq!(proof2.path.len(), 1);
        assert_eq!(proof1.path[3], proof2.path[0]);
        proof2.validate(&tree_head2, &tree_head4).unwrap();

        let proof3 = tree
            .get_consistency_proof(&tree_head3, &tree_head4)
            .unwrap();
        assert_eq!(proof3.path.len(), 3);
        proof3.validate(&tree_head3, &tree_head4).unwrap();

        let proof4 = tree
            .get_consistency_proof(&tree_head4, &tree_head4)
            .unwrap();
        assert!(proof4.path.is_empty());
        proof4.validate(&tree_head4, &tree_head4).unwrap();
    }

    #[test]
    fn randomized_consistency_proof() {
        let first_size = 4973;
        let second_size = 5009;
        let mut rng = ChaCha8Rng::seed_from_u64(1337);

        let tree = Tree::<MemoryStore<NodeKey, HashOutput>, MemoryStore<u64, HashOutput>>::new(
            MemoryStore::default(),
            MemoryStore::default(),
        );

        for _ in 0..first_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let first_th = tree.recompute_tree_head();

        for _ in first_size..second_size {
            let mut entry = [0; 32];
            rng.fill_bytes(&mut entry);
            tree.insert_entry(entry);
        }

        let second_th = tree.recompute_tree_head();

        let proof = tree.get_consistency_proof(&first_th, &second_th).unwrap();
        proof.validate(&first_th, &second_th).unwrap();
    }
}