luct_core/tree/
inclusion.rs1use crate::{
2 store::{AsyncStoreRead, Hashable, StoreRead},
3 tree::{HashOutput, Node, NodeKey, ProofGenerationError, ProofValidationError, Tree, TreeHead},
4};
5use futures::{FutureExt, future::join_all};
6
7impl<N, L> Tree<N, L>
8where
9 N: StoreRead<Key = NodeKey, Value = HashOutput>,
10{
11 pub fn get_audit_proof(
13 &self,
14 head: &TreeHead,
15 index: u64,
16 ) -> Result<AuditProof, ProofGenerationError> {
17 if index >= head.tree_size {
18 return Err(ProofGenerationError::InvalidIndex {
19 tree_size: head.tree_size,
20 index,
21 });
22 }
23
24 let path = get_audit_proof(head, index, |key| {
25 self.nodes
26 .get(&key)
27 .ok_or(ProofGenerationError::KeyNotFound(key))
28 });
29 let mut path = path
30 .into_iter()
31 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
32
33 path.reverse();
34 Ok(AuditProof { index, path })
35 }
36}
37
38impl<N, L> Tree<N, L>
39where
40 N: AsyncStoreRead<Key = NodeKey, Value = HashOutput>,
41{
42 pub async fn get_audit_proof_async(
43 &self,
44 head: &TreeHead,
45 index: u64,
46 ) -> Result<AuditProof, ProofGenerationError> {
47 if index >= head.tree_size {
48 return Err(ProofGenerationError::InvalidIndex {
49 tree_size: head.tree_size,
50 index,
51 });
52 }
53
54 let path = get_audit_proof(head, index, |key| {
55 self.nodes
56 .get(key.clone())
57 .map(|result| result.ok_or(ProofGenerationError::KeyNotFound(key)))
58 });
59 let path = join_all(path).await;
60 let mut path = path
61 .into_iter()
62 .collect::<Result<Vec<HashOutput>, ProofGenerationError>>()?;
63
64 path.reverse();
65 Ok(AuditProof { index, path })
66 }
67}
68
69fn get_audit_proof<F, O>(head: &TreeHead, index: u64, get: F) -> Vec<O>
70where
71 F: Fn(NodeKey) -> O,
72{
73 let mut n = NodeKey::full_range(head.tree_size);
74 let m = index;
75
76 let mut path = vec![];
77
78 while !n.is_leaf() {
79 let (left, right) = n.split();
80 if m < right.start {
81 let elem = get(right);
82 path.push(elem);
83
84 n = left;
85 } else {
86 let elem = get(left);
87 path.push(elem);
88
89 n = right;
90 }
91 }
92
93 path
94}
95
96#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
97pub struct AuditProof {
98 pub(crate) index: u64,
99 pub(crate) path: Vec<HashOutput>,
100}
101
102impl AuditProof {
103 pub fn index(&self) -> u64 {
104 self.index
105 }
106
107 pub fn validate(
108 &self,
109 head: &TreeHead,
110 leaf: &impl Hashable,
111 ) -> Result<(), ProofValidationError> {
112 if head.tree_size < self.index {
113 return Err(ProofValidationError::InvalidIndex {
114 tree_size: head.tree_size,
115 index: self.index,
116 });
117 }
118
119 let mut f_n = self.index;
120 let mut s_n = head.tree_size - 1;
121 let mut r = leaf.hash();
122
123 for p in &self.path {
124 if s_n == 0 {
125 return Err(ProofValidationError::PathTooShort);
126 }
127
128 if f_n & 1 == 1 || f_n == s_n {
129 r = Node { left: *p, right: r }.hash();
130
131 while f_n & 1 != 1 && f_n != 0 {
132 f_n >>= 1;
133 s_n >>= 1;
134 }
135 } else {
136 r = Node { left: r, right: *p }.hash();
137 }
138
139 f_n >>= 1;
140 s_n >>= 1;
141 }
142
143 if s_n != 0 {
144 return Err(ProofValidationError::PathTooLong);
145 }
146 if r != head.head {
147 return Err(ProofValidationError::HashMismatch);
148 }
149
150 Ok(())
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::MemoryStore;

    /// Generates proofs for several leaves of a seven-entry tree and checks
    /// both the path length and that each proof validates against the head.
    #[test]
    fn compute_audit_proofs() {
        let tree = Tree::<MemoryStore<NodeKey, HashOutput>, MemoryStore<u64, String>>::new(
            MemoryStore::default(),
            MemoryStore::default(),
        );

        for entry in ["A", "B", "C", "D", "E", "F", "G"] {
            tree.insert_entry(entry.to_string());
        }

        let head = tree.recompute_tree_head();

        // (leaf index, leaf value, expected number of hashes in the path)
        let cases = [(0, "A", 3), (3, "D", 3), (4, "E", 3), (6, "G", 2)];

        for (index, entry, expected_len) in cases {
            let proof = tree.get_audit_proof(&head, index).unwrap();
            assert_eq!(proof.path.len(), expected_len);
            proof.validate(&head, &entry.to_string()).unwrap();
        }
    }
}