luct_core/tree/
inclusion.rs1use crate::{
2 store::{AsyncStoreRead, Hashable, StoreRead},
3 tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L, V> Tree<N, L, V>
8where
9 N: StoreRead<NodeKey, HashOutput>,
10 V: Hashable,
11{
12 pub fn get_audit_proof(
14 &self,
15 head: &TreeHead,
16 index: u64,
17 ) -> Result<AuditProof, ProofGenerationError> {
18 if index >= head.tree_size {
19 return Err(ProofGenerationError::InvalidIndex {
20 tree_size: head.tree_size,
21 index,
22 });
23 }
24
25 let path = get_audit_proof(head, index, |key| {
26 self.nodes
27 .get(&key)
28 .ok_or(ProofGenerationError::KeyNotFound(key))
29 });
30 let mut path = path
31 .into_iter()
32 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
33
34 path.reverse();
35 Ok(AuditProof { index, path })
36 }
37}
38
39impl<N, L, V> Tree<N, L, V>
40where
41 N: AsyncStoreRead<NodeKey, HashOutput>,
42 V: Hashable,
43{
44 pub async fn get_audit_proof_async(
45 &self,
46 head: &TreeHead,
47 index: u64,
48 ) -> Result<AuditProof, ProofGenerationError> {
49 if index >= head.tree_size {
50 return Err(ProofGenerationError::InvalidIndex {
51 tree_size: head.tree_size,
52 index,
53 });
54 }
55
56 let path = get_audit_proof(head, index, |key| {
57 self.nodes
58 .get(key.clone())
59 .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
60 });
61 let path = join_all(path).await;
62 let mut path = path
63 .into_iter()
64 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
65
66 path.reverse();
67 Ok(AuditProof { index, path })
68 }
69}
70
71fn get_audit_proof<F, O>(head: &TreeHead, index: u64, get: F) -> Vec<O>
72where
73 F: Fn(NodeKey) -> O,
74{
75 let mut n = NodeKey::full_range(head.tree_size);
76 let m = index;
77
78 let mut path = vec![];
79
80 while !n.is_leaf() {
81 let (left, right) = n.split();
82 if m < right.start {
83 let elem = get(right);
84 path.push(elem);
85
86 n = left;
87 } else {
88 let elem = get(left);
89 path.push(elem);
90
91 n = right;
92 }
93 }
94
95 path
96}
97
98#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
99pub struct AuditProof {
100 pub(crate) index: u64,
101 pub(crate) path: Vec<HashOutput>,
102}
103
104impl AuditProof {
105 pub fn index(&self) -> u64 {
106 self.index
107 }
108
109 pub fn validate(
110 &self,
111 head: &TreeHead,
112 leaf: &impl Hashable,
113 ) -> Result<(), ProofValidationError> {
114 if head.tree_size < self.index {
115 return Err(ProofValidationError::InvalidIndex {
116 tree_size: head.tree_size,
117 index: self.index,
118 });
119 }
120
121 let mut f_n = self.index;
122 let mut s_n = head.tree_size - 1;
123 let mut r = leaf.hash();
124
125 for p in &self.path {
126 if s_n == 0 {
127 return Err(ProofValidationError::PathTooShort);
128 }
129
130 if f_n & 1 == 1 || f_n == s_n {
131 r = Node { left: *p, right: r }.hash();
132
133 while f_n & 1 != 1 && f_n != 0 {
134 f_n >>= 1;
135 s_n >>= 1;
136 }
137 } else {
138 r = Node { left: r, right: *p }.hash();
139 }
140
141 f_n >>= 1;
142 s_n >>= 1;
143 }
144
145 if s_n != 0 {
146 return Err(ProofValidationError::PathTooLong);
147 }
148 if r != head.head {
149 return Err(ProofValidationError::HashMismatch);
150 }
151
152 Ok(())
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::MemoryStore;

    #[test]
    fn compute_audit_proofs() {
        let tree = Tree::<_, _, String>::new(MemoryStore::default(), MemoryStore::default());

        tree.insert_entry("A".to_string());
        tree.insert_entry("B".to_string());
        tree.insert_entry("C".to_string());
        tree.insert_entry("D".to_string());
        tree.insert_entry("E".to_string());
        tree.insert_entry("F".to_string());
        tree.insert_entry("G".to_string());

        let head = tree.recompute_tree_head();

        let proof1 = tree.get_audit_proof(&head, 0).unwrap();
        assert_eq!(proof1.path.len(), 3);
        proof1.validate(&head, &"A".to_string()).unwrap();

        let proof2 = tree.get_audit_proof(&head, 3).unwrap();
        assert_eq!(proof2.path.len(), 3);
        proof2.validate(&head, &"D".to_string()).unwrap();

        let proof3 = tree.get_audit_proof(&head, 4).unwrap();
        assert_eq!(proof3.path.len(), 3);
        proof3.validate(&head, &"E".to_string()).unwrap();

        let proof4 = tree.get_audit_proof(&head, 6).unwrap();
        assert_eq!(proof4.path.len(), 2);
        proof4.validate(&head, &"G".to_string()).unwrap();
    }

    /// Malformed requests and tampered proofs must be rejected rather than accepted.
    #[test]
    fn rejects_invalid_proofs() {
        let tree = Tree::<_, _, String>::new(MemoryStore::default(), MemoryStore::default());
        for entry in ["A", "B", "C", "D", "E", "F", "G"] {
            tree.insert_entry(entry.to_string());
        }
        let head = tree.recompute_tree_head();

        // Proof generation refuses an index past the last leaf.
        assert!(matches!(
            tree.get_audit_proof(&head, 7),
            Err(ProofGenerationError::InvalidIndex {
                tree_size: 7,
                index: 7
            })
        ));

        let proof = tree.get_audit_proof(&head, 0).unwrap();

        // Correct path, wrong leaf.
        assert!(matches!(
            proof.validate(&head, &"B".to_string()),
            Err(ProofValidationError::HashMismatch)
        ));

        // A proof for one position must not validate a leaf from another position.
        let last = tree.get_audit_proof(&head, 6).unwrap();
        assert!(last.validate(&head, &"A".to_string()).is_err());

        // A path with its top element removed cannot reach the root.
        let mut short = proof.clone();
        short.path.pop();
        assert!(short.validate(&head, &"A".to_string()).is_err());

        // A path with an extra element overshoots the root.
        let mut long = proof.clone();
        long.path.push(proof.path[0]);
        assert!(long.validate(&head, &"A".to_string()).is_err());
    }
}