Skip to main content

luct_core/tree/inclusion.rs

1use crate::{
2    store::{AsyncStoreRead, Hashable, StoreRead},
3    tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L> Tree<N, L>
8where
9    N: StoreRead<Key = NodeKey, Value = HashOutput>,
10{
11    /// This follows RFC 9162 2.1.3.1
12    pub fn get_audit_proof(
13        &self,
14        head: &TreeHead,
15        index: u64,
16    ) -> Result<AuditProof, ProofGenerationError> {
17        if index >= head.tree_size {
18            return Err(ProofGenerationError::InvalidIndex {
19                tree_size: head.tree_size,
20                index,
21            });
22        }
23
24        let path = get_audit_proof(head, index, |key| {
25            self.nodes
26                .get(&key)
27                .ok_or(ProofGenerationError::KeyNotFound(key))
28        });
29        let mut path = path
30            .into_iter()
31            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
32
33        path.reverse();
34        Ok(AuditProof { index, path })
35    }
36}
37
38impl<N, L> Tree<N, L>
39where
40    N: AsyncStoreRead<Key = NodeKey, Value = HashOutput>,
41{
42    pub async fn get_audit_proof_async(
43        &self,
44        head: &TreeHead,
45        index: u64,
46    ) -> Result<AuditProof, ProofGenerationError> {
47        if index >= head.tree_size {
48            return Err(ProofGenerationError::InvalidIndex {
49                tree_size: head.tree_size,
50                index,
51            });
52        }
53
54        let path = get_audit_proof(head, index, |key| {
55            self.nodes
56                .get(key.clone())
57                .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
58        });
59        let path = join_all(path).await;
60        let mut path = path
61            .into_iter()
62            .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
63
64        path.reverse();
65        Ok(AuditProof { index, path })
66    }
67}
68
69fn get_audit_proof<F, O>(head: &TreeHead, index: u64, get: F) -> Vec<O>
70where
71    F: Fn(NodeKey) -> O,
72{
73    let mut n = NodeKey::full_range(head.tree_size);
74    let m = index;
75
76    let mut path = vec![];
77
78    while !n.is_leaf() {
79        let (left, right) = n.split();
80        if m < right.start {
81            let elem = get(right);
82            path.push(elem);
83
84            n = left;
85        } else {
86            let elem = get(left);
87            path.push(elem);
88
89            n = right;
90        }
91    }
92
93    path
94}
95
96#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
97pub struct AuditProof {
98    pub(crate) index: u64,
99    pub(crate) path: Vec<HashOutput>,
100}
101
102impl AuditProof {
103    pub fn index(&self) -> u64 {
104        self.index
105    }
106
107    pub fn validate(
108        &self,
109        head: &TreeHead,
110        leaf: &impl Hashable,
111    ) -> Result<(), ProofValidationError> {
112        if head.tree_size < self.index {
113            return Err(ProofValidationError::InvalidIndex {
114                tree_size: head.tree_size,
115                index: self.index,
116            });
117        }
118
119        let mut f_n = self.index;
120        let mut s_n = head.tree_size - 1;
121        let mut r = leaf.hash();
122
123        for p in &self.path {
124            if s_n == 0 {
125                return Err(ProofValidationError::PathTooShort);
126            }
127
128            if f_n & 1 == 1 || f_n == s_n {
129                r = Node { left: *p, right: r }.hash();
130
131                while f_n & 1 != 1 && f_n != 0 {
132                    f_n >>= 1;
133                    s_n >>= 1;
134                }
135            } else {
136                r = Node { left: r, right: *p }.hash();
137            }
138
139            f_n >>= 1;
140            s_n >>= 1;
141        }
142
143        if s_n != 0 {
144            return Err(ProofValidationError::PathTooLong);
145        }
146        if r != head.head {
147            return Err(ProofValidationError::HashMismatch);
148        }
149
150        Ok(())
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::MemoryStore;

    type TestTree = Tree<MemoryStore<NodeKey, HashOutput>, MemoryStore<u64, String>>;

    /// Builds an in-memory tree containing `entries` (in order) and returns it with its head.
    fn build_tree(entries: &[&str]) -> (TestTree, TreeHead) {
        let tree = TestTree::new(MemoryStore::default(), MemoryStore::default());
        for &entry in entries {
            tree.insert_entry(entry.to_string());
        }
        let head = tree.recompute_tree_head();
        (tree, head)
    }

    #[test]
    fn compute_audit_proofs() {
        let (tree, head) = build_tree(&["A", "B", "C", "D", "E", "F", "G"]);

        let proof1 = tree.get_audit_proof(&head, 0).unwrap();
        assert_eq!(proof1.path.len(), 3);
        proof1.validate(&head, &"A".to_string()).unwrap();

        let proof2 = tree.get_audit_proof(&head, 3).unwrap();
        assert_eq!(proof2.path.len(), 3);
        proof2.validate(&head, &"D".to_string()).unwrap();

        let proof3 = tree.get_audit_proof(&head, 4).unwrap();
        assert_eq!(proof3.path.len(), 3);
        proof3.validate(&head, &"E".to_string()).unwrap();

        let proof4 = tree.get_audit_proof(&head, 6).unwrap();
        assert_eq!(proof4.path.len(), 2);
        proof4.validate(&head, &"G".to_string()).unwrap();
    }

    #[test]
    fn audit_proofs_validate_for_every_leaf_and_size() {
        let entries = ["A", "B", "C", "D", "E", "F", "G", "H", "I"];
        // Covers the single-leaf tree as well as full and ragged (non power of two) trees.
        for size in 1..=entries.len() {
            let (tree, head) = build_tree(&entries[..size]);
            for (index, entry) in entries[..size].iter().copied().enumerate() {
                let proof = tree.get_audit_proof(&head, index as u64).unwrap();
                assert_eq!(proof.index(), index as u64);
                proof.validate(&head, &entry.to_string()).unwrap();
            }
        }
    }

    #[test]
    fn audit_proof_rejects_wrong_leaf_and_tampered_proofs() {
        let (tree, head) = build_tree(&["A", "B", "C", "D", "E", "F", "G"]);
        let proof = tree.get_audit_proof(&head, 2).unwrap();
        let leaf = "C".to_string();
        proof.validate(&head, &leaf).unwrap();

        // A genuine proof must not vouch for a different leaf.
        assert!(matches!(
            proof.validate(&head, &"X".to_string()),
            Err(ProofValidationError::HashMismatch)
        ));

        // A proof missing its top-most sibling is rejected.
        let mut truncated = proof.clone();
        truncated.path.pop();
        assert!(truncated.validate(&head, &leaf).is_err());

        // A proof with an extra trailing sibling is rejected.
        let mut extended = proof.clone();
        extended.path.push(proof.path[0]);
        assert!(extended.validate(&head, &leaf).is_err());

        // The same path does not prove the leaf at a different index.
        let shifted = AuditProof {
            index: 3,
            path: proof.path.clone(),
        };
        assert!(shifted.validate(&head, &leaf).is_err());
    }

    #[test]
    fn audit_proof_generation_rejects_out_of_range_index() {
        let (tree, head) = build_tree(&["A", "B", "C"]);
        assert!(matches!(
            tree.get_audit_proof(&head, 3),
            Err(ProofGenerationError::InvalidIndex {
                tree_size: 3,
                index: 3
            })
        ));
    }
}