1use crate::{
4 hasher::Hasher,
5 node::{Node, NodeChildType},
6};
7use rayon::prelude::*;
8
/// One step of a Merkle authentication path: the hash of a sibling node and
/// the side on which it is combined with the running hash.
#[derive(Debug, Clone)]
pub struct ProofNode {
    /// Hash of the sibling node at this level of the tree.
    pub hash: String,
    /// Side of the sibling during verification: `Left` means the sibling hash
    /// is concatenated before the running hash, `Right` means after it.
    pub child_type: NodeChildType,
}
17
18#[derive(Debug)]
20pub struct MerkleProof {
21 pub path: Vec<ProofNode>,
23 pub leaf_index: usize,
25}
26
/// Produces and checks Merkle inclusion proofs.
pub trait Proofer {
    /// Generates an inclusion proof for the leaf at `index`.
    ///
    /// Returns `None` when `index` does not identify a leaf (for
    /// `DefaultProofer`, when `index` is not smaller than the leaf count).
    fn generate(&self, index: usize) -> Option<MerkleProof>;

    /// Returns `true` if `proof` shows that `data` is included in the tree
    /// whose root hash is `root_hash`.
    ///
    /// `data` is the raw leaf content, not its hash; `DefaultProofer` hashes
    /// it with its hasher before walking the proof path.
    fn verify<T>(&self, proof: &MerkleProof, data: T, root_hash: &str) -> bool
    where
        T: AsRef<[u8]>;
}
54
55pub struct DefaultProofer<H: Hasher> {
56 hasher: H,
57 levels: Vec<Vec<Node>>,
58}
59
60impl<H> DefaultProofer<H>
61where
62 H: Hasher,
63{
64 pub fn new(hasher: H, leaves: Vec<Node>) -> Self {
65 let mut levels = Vec::new();
66 levels.push(leaves.clone());
67
68 let mut current_level = leaves;
69 while current_level.len() > 1 {
70 if current_level.len() % 2 != 0 {
71 current_level.push(current_level.last().unwrap().clone());
72 }
73 let next_level: Vec<Node> = current_level
74 .par_chunks(2)
75 .map(|pair| {
76 let (left, right) = (&pair[0], &pair[1]);
77 let combined = [left.hash().as_bytes(), right.hash().as_bytes()].concat();
78 let hash = hasher.hash(&combined);
79 Node::new_internal(hash, left.clone(), right.clone())
80 })
81 .collect();
82
83 levels.push(next_level.clone());
84 current_level = next_level;
85 }
86
87 Self { hasher, levels }
88 }
89
90 pub fn verify_hash(&self, proof: &MerkleProof, hash: String, root_hash: &str) -> bool {
91 let mut current_hash = hash;
92 for proof_node in &proof.path {
94 let combined: String = match proof_node.child_type {
95 NodeChildType::Left => format!("{}{}", proof_node.hash, current_hash),
96 NodeChildType::Right => format!("{}{}", current_hash, proof_node.hash),
97 };
98 current_hash = self.hasher.hash(combined.as_bytes());
99 }
100
101 current_hash == root_hash
103 }
104}
105
106impl<H> Proofer for DefaultProofer<H>
107where
108 H: Hasher,
109{
110 fn generate(&self, index: usize) -> Option<MerkleProof> {
111 if index >= self.levels[0].len() {
112 return None;
113 }
114
115 let mut path = Vec::new();
116 let mut current_index = index;
117
118 for level in &self.levels[..self.levels.len() - 1] {
119 let sibling_index = (current_index ^ 1).min(level.len() - 1);
121
122 let sibling = &level[sibling_index];
123
124 let child_type = if sibling_index < current_index {
125 NodeChildType::Left
126 } else {
127 NodeChildType::Right
128 };
129
130 path.push(ProofNode {
131 hash: sibling.hash().to_string(),
132 child_type,
133 });
134
135 current_index >>= 1;
136 }
137
138 Some(MerkleProof {
139 path,
140 leaf_index: index,
141 })
142 }
143
144 fn verify<T>(&self, proof: &MerkleProof, data: T, root_hash: &str) -> bool
145 where
146 T: AsRef<[u8]>,
147 {
148 let hash: String = self.hasher.hash(data.as_ref());
150 self.verify_hash(proof, hash, root_hash)
151 }
152}
153
#[cfg(test)]
mod tests {
    use crate::{hasher::*, merkletree::MerkleTree};

    use super::*;

    #[test]
    fn test_proof_generation_and_verification_dummy() {
        let hasher = DummyHasher;
        let data = vec!["a", "b", "c", "d"];
        let tree = MerkleTree::new(hasher.clone(), data.clone());
        let proofer = DefaultProofer::new(hasher, tree.leaves());

        for (index, item) in data.iter().enumerate() {
            let proof = proofer.generate(index).unwrap();

            assert!(proofer.verify(&proof, item, tree.root().hash()));
        }
    }

    #[test]
    fn test_proof_generation_and_verification_sha256() {
        let hasher = SHA256Hasher::new();
        let data = vec!["a", "b", "c", "d"];
        let tree = MerkleTree::new(hasher.clone(), data.clone());
        let proofer = DefaultProofer::new(hasher, tree.leaves());

        for (index, item) in data.iter().enumerate() {
            let proof = proofer.generate(index).unwrap();

            assert!(proofer.verify(&proof, item, tree.root().hash()));
        }
    }

    /// Odd-sized levels pair their last node with itself. Proofs for every
    /// leaf must still verify against the root that `new` built.
    #[test]
    fn test_proof_generation_and_verification_odd_leaf_count() {
        let all = ["a", "b", "c", "d", "e", "f", "g"];
        for len in [3usize, 5, 7] {
            let hasher = SHA256Hasher::new();
            let data = all[..len].to_vec();
            let tree = MerkleTree::new(hasher.clone(), data.clone());
            let proofer = DefaultProofer::new(hasher, tree.leaves());
            let root = proofer.levels.last().unwrap()[0].hash().to_string();

            for (index, item) in data.iter().enumerate() {
                let proof = proofer.generate(index).unwrap();
                assert!(
                    proofer.verify(&proof, item, &root),
                    "leaf {} of {} failed to verify",
                    index,
                    len
                );
            }
        }
    }

    #[test]
    fn test_proof_generation_out_of_range() {
        let hasher = SHA256Hasher::new();
        let data = vec!["a", "b", "c", "d"];
        let tree = MerkleTree::new(hasher.clone(), data);
        let proofer = DefaultProofer::new(hasher, tree.leaves());

        let leaf_count = proofer.levels[0].len();
        assert!(proofer.generate(leaf_count).is_none());
        assert!(proofer.generate(usize::MAX).is_none());
    }

    #[test]
    fn test_proof_not_valid() {
        let hasher = SHA256Hasher::new();
        let data = vec!["a", "b", "c", "d"];
        let tree = MerkleTree::new(hasher.clone(), data.clone());
        let proofer = DefaultProofer::new(hasher, tree.leaves());

        let proof = proofer.generate(0).unwrap();

        assert!(proofer.verify(&proof, b"a", tree.root().hash()));
        assert!(!proofer.verify(&proof, b"b", tree.root().hash()));
        assert!(!proofer.verify(&proof, b"c", tree.root().hash()));
        assert!(!proofer.verify(&proof, b"d", tree.root().hash()));

        assert!(!proofer.verify(&proof, b"e", tree.root().hash()));
    }
}