Skip to main content

luct_core/tree/
inclusion.rs

1use crate::{
2    store::{AsyncStoreRead, Hashable, StoreRead},
3    tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L, V> Tree<N, L, V>
8where
9    N: StoreRead<NodeKey, HashOutput>,
10    V: Hashable,
11{
12    /// This follows RFC 9162 2.1.3.1
13    pub fn get_audit_proof(
14        &self,
15        head: &TreeHead,
16        index: u64,
17    ) -> Result<AuditProof, ProofGenerationError> {
18        if index >= head.tree_size {
19            return Err(ProofGenerationError::InvalidIndex {
20                tree_size: head.tree_size,
21                index,
22            });
23        }
24
25        let path = get_audit_proof(head, index, |key| {
26            self.nodes
27                .get(&key)
28                .ok_or(ProofGenerationError::KeyNotFound(key))
29        });
30        let mut path = path
31            .into_iter()
32            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
33
34        path.reverse();
35        Ok(AuditProof { index, path })
36    }
37}
38
39impl<N, L, V> Tree<N, L, V>
40where
41    N: AsyncStoreRead<NodeKey, HashOutput>,
42    V: Hashable,
43{
44    pub async fn get_audit_proof_async(
45        &self,
46        head: &TreeHead,
47        index: u64,
48    ) -> Result<AuditProof, ProofGenerationError> {
49        if index >= head.tree_size {
50            return Err(ProofGenerationError::InvalidIndex {
51                tree_size: head.tree_size,
52                index,
53            });
54        }
55
56        let path = get_audit_proof(head, index, |key| {
57            self.nodes
58                .get(key.clone())
59                .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
60        });
61        let path = join_all(path).await;
62        let mut path = path
63            .into_iter()
64            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
65
66        path.reverse();
67        Ok(AuditProof { index, path })
68    }
69}
70
71fn get_audit_proof<F, O>(head: &TreeHead, index: u64, get: F) -> Vec<O>
72where
73    F: Fn(NodeKey) -> O,
74{
75    let mut n = NodeKey::full_range(head.tree_size);
76    let m = index;
77
78    let mut path = vec![];
79
80    while !n.is_leaf() {
81        let (left, right) = n.split();
82        if m < right.start {
83            let elem = get(right);
84            path.push(elem);
85
86            n = left;
87        } else {
88            let elem = get(left);
89            path.push(elem);
90
91            n = right;
92        }
93    }
94
95    path
96}
97
98#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
99pub struct AuditProof {
100    pub(crate) index: u64,
101    pub(crate) path: Vec<HashOutput>,
102}
103
104impl AuditProof {
105    pub fn index(&self) -> u64 {
106        self.index
107    }
108
109    pub fn validate(
110        &self,
111        head: &TreeHead,
112        leaf: &impl Hashable,
113    ) -> Result<(), ProofValidationError> {
114        if head.tree_size < self.index {
115            return Err(ProofValidationError::InvalidIndex {
116                tree_size: head.tree_size,
117                index: self.index,
118            });
119        }
120
121        let mut f_n = self.index;
122        let mut s_n = head.tree_size - 1;
123        let mut r = leaf.hash();
124
125        for p in &self.path {
126            if s_n == 0 {
127                return Err(ProofValidationError::PathTooShort);
128            }
129
130            if f_n & 1 == 1 || f_n == s_n {
131                r = Node { left: *p, right: r }.hash();
132
133                while f_n & 1 != 1 && f_n != 0 {
134                    f_n >>= 1;
135                    s_n >>= 1;
136                }
137            } else {
138                r = Node { left: r, right: *p }.hash();
139            }
140
141            f_n >>= 1;
142            s_n >>= 1;
143        }
144
145        if s_n != 0 {
146            return Err(ProofValidationError::PathTooLong);
147        }
148        if r != head.head {
149            return Err(ProofValidationError::HashMismatch);
150        }
151
152        Ok(())
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::MemoryStore;

    /// Leaf values inserted by the tests, in leaf-index order (an unbalanced 7-leaf tree).
    const ENTRIES: [&str; 7] = ["A", "B", "C", "D", "E", "F", "G"];

    /// Expected audit-path length per leaf: the tree splits as [0,4) | [4,7), and [4,7)
    /// splits as [4,6) | [6,7), so leaf 6 sits one level higher than the others.
    const EXPECTED_PATH_LENS: [usize; 7] = [3, 3, 3, 3, 3, 3, 2];

    #[test]
    fn compute_audit_proofs() {
        let tree = Tree::<_, _, String>::new(MemoryStore::default(), MemoryStore::default());
        for entry in ENTRIES.iter() {
            tree.insert_entry(entry.to_string());
        }

        let head = tree.recompute_tree_head();

        for (position, (entry, expected_len)) in
            ENTRIES.iter().zip(EXPECTED_PATH_LENS.iter()).enumerate()
        {
            let index = position as u64;
            let proof = tree.get_audit_proof(&head, index).unwrap();
            assert_eq!(proof.index(), index);
            assert_eq!(proof.path.len(), *expected_len, "path length for leaf {index}");
            proof.validate(&head, &entry.to_string()).unwrap();
        }
    }

    #[test]
    fn rejects_out_of_range_index_and_wrong_leaf() {
        let tree = Tree::<_, _, String>::new(MemoryStore::default(), MemoryStore::default());
        for entry in ENTRIES.iter() {
            tree.insert_entry(entry.to_string());
        }

        let head = tree.recompute_tree_head();

        // One past the last leaf is not a valid position.
        assert!(matches!(
            tree.get_audit_proof(&head, ENTRIES.len() as u64),
            Err(ProofGenerationError::InvalidIndex { .. })
        ));

        // A valid proof for leaf 0 ("A") must not vouch for a different value.
        let proof = tree.get_audit_proof(&head, 0).unwrap();
        assert!(matches!(
            proof.validate(&head, &"B".to_string()),
            Err(ProofValidationError::HashMismatch)
        ));
    }
}