proof_cat/commit/
merkle.rs1use field_cat::FieldBytes;
9use sha2::{Digest, Sha256};
10
11use crate::error::Error;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub struct MerkleRoot([u8; 32]);
16
17impl MerkleRoot {
18 #[must_use]
20 pub fn as_bytes(&self) -> &[u8; 32] {
21 &self.0
22 }
23}
24
25#[derive(Debug, Clone)]
27pub struct MerkleProof {
28 leaf_index: usize,
29 siblings: Vec<[u8; 32]>,
30}
31
32impl MerkleProof {
33 #[must_use]
35 pub fn leaf_index(&self) -> usize {
36 self.leaf_index
37 }
38
39 #[must_use]
41 pub fn siblings(&self) -> &[[u8; 32]] {
42 &self.siblings
43 }
44}
45
46#[derive(Debug, Clone)]
74pub struct MerkleTree {
75 nodes: Vec<[u8; 32]>,
77 depth: usize,
79 leaf_count: usize,
81}
82
83fn hash_leaf(index: usize, value_bytes: &[u8]) -> [u8; 32] {
85 let mut hasher = Sha256::new();
86 hasher.update(b"leaf:");
87 hasher.update(index.to_le_bytes());
88 hasher.update(value_bytes);
89 hasher.finalize().into()
90}
91
92fn hash_padding(index: usize) -> [u8; 32] {
94 let mut hasher = Sha256::new();
95 hasher.update(b"padding:");
96 hasher.update(index.to_le_bytes());
97 hasher.finalize().into()
98}
99
100fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
102 let mut hasher = Sha256::new();
103 hasher.update(left);
104 hasher.update(right);
105 hasher.finalize().into()
106}
107
108fn next_power_of_two(n: usize) -> usize {
110 if n <= 1 { 1 } else { n.next_power_of_two() }
111}
112
113impl MerkleTree {
114 #[must_use]
118 pub fn from_field_values<F: FieldBytes>(values: &[F]) -> Self {
119 let leaf_count = values.len();
120 let n = next_power_of_two(leaf_count);
121 let depth = usize::try_from(n.trailing_zeros()).unwrap_or(0);
123
124 let leaf_hashes: Vec<[u8; 32]> = (0..n)
127 .map(|i| {
128 if i < leaf_count {
129 hash_leaf(i, &values[i].to_le_bytes())
130 } else {
131 hash_padding(i)
132 }
133 })
134 .collect();
135
136 let nodes_len = 2 * n;
139 let zeroed: Vec<[u8; 32]> = (0..nodes_len).map(|_| [0u8; 32]).collect();
140
141 let with_leaves: Vec<[u8; 32]> = zeroed
143 .iter()
144 .enumerate()
145 .map(|(idx, zero)| {
146 if idx >= n && idx < 2 * n {
147 leaf_hashes[idx - n]
148 } else {
149 *zero
150 }
151 })
152 .collect();
153
154 let nodes = (1..=depth).fold(with_leaves, |acc, level_from_bottom| {
158 let start = n >> level_from_bottom;
160 let end = n >> (level_from_bottom - 1);
161 (0..acc.len())
162 .map(|idx| {
163 if idx >= start && idx < end {
164 hash_pair(&acc[idx * 2], &acc[idx * 2 + 1])
165 } else {
166 acc[idx]
167 }
168 })
169 .collect()
170 });
171
172 Self {
173 nodes,
174 depth,
175 leaf_count,
176 }
177 }
178
179 #[must_use]
181 pub fn root(&self) -> MerkleRoot {
182 if self.nodes.len() > 1 {
183 MerkleRoot(self.nodes[1])
184 } else {
185 MerkleRoot([0u8; 32])
186 }
187 }
188
189 #[must_use]
191 pub fn leaf_count(&self) -> usize {
192 self.leaf_count
193 }
194
195 pub fn open(&self, index: usize) -> Result<MerkleProof, Error> {
201 if index >= self.leaf_count {
202 Err(Error::LeafIndexOutOfBounds {
203 index,
204 leaf_count: self.leaf_count,
205 })
206 } else {
207 let n = 1usize << self.depth;
208 let siblings = (0..self.depth)
210 .scan(n + index, |pos, _| {
211 let sibling_pos = *pos ^ 1;
212 let sibling = self.nodes[sibling_pos];
213 *pos /= 2;
214 Some(sibling)
215 })
216 .collect();
217 Ok(MerkleProof {
218 leaf_index: index,
219 siblings,
220 })
221 }
222 }
223
224 #[must_use]
229 pub fn verify_opening<F: FieldBytes>(
230 root: &MerkleRoot,
231 index: usize,
232 value: &F,
233 proof: &MerkleProof,
234 ) -> bool {
235 let leaf_hash = hash_leaf(index, &value.to_le_bytes());
236 let n = 1usize << proof.siblings.len();
237 let computed_root = proof
238 .siblings
239 .iter()
240 .enumerate()
241 .fold((leaf_hash, n + index), |(current, pos), (_, sibling)| {
242 let parent = if pos % 2 == 0 {
243 hash_pair(¤t, sibling)
244 } else {
245 hash_pair(sibling, ¤t)
246 };
247 (parent, pos / 2)
248 })
249 .0;
250 computed_root == root.0
251 }
252}
253
254#[cfg(test)]
255mod tests {
256 use super::*;
257 use field_cat::{BabyBear, F101};
258
259 #[test]
260 fn single_leaf_roundtrip() -> Result<(), Error> {
261 let tree = MerkleTree::from_field_values(&[F101::new(42)]);
262 let proof = tree.open(0)?;
263 assert!(MerkleTree::verify_opening(
264 &tree.root(),
265 0,
266 &F101::new(42),
267 &proof
268 ));
269 Ok(())
270 }
271
272 #[test]
273 fn two_leaf_roundtrip() -> Result<(), Error> {
274 let values = [BabyBear::new(10), BabyBear::new(20)];
275 let tree = MerkleTree::from_field_values(&values);
276 let proof0 = tree.open(0)?;
277 let proof1 = tree.open(1)?;
278 assert!(MerkleTree::verify_opening(
279 &tree.root(),
280 0,
281 &BabyBear::new(10),
282 &proof0
283 ));
284 assert!(MerkleTree::verify_opening(
285 &tree.root(),
286 1,
287 &BabyBear::new(20),
288 &proof1
289 ));
290 Ok(())
291 }
292
293 #[test]
294 fn tampered_value_fails() -> Result<(), Error> {
295 let tree = MerkleTree::from_field_values(&[F101::new(42)]);
296 let proof = tree.open(0)?;
297 assert!(!MerkleTree::verify_opening(
299 &tree.root(),
300 0,
301 &F101::new(99),
302 &proof
303 ));
304 Ok(())
305 }
306
307 #[test]
308 fn out_of_bounds_open_fails() {
309 let tree = MerkleTree::from_field_values(&[F101::new(1), F101::new(2)]);
310 assert!(tree.open(2).is_err());
311 }
312
313 #[test]
314 fn four_leaves() -> Result<(), Error> {
315 let values = [
316 BabyBear::new(1),
317 BabyBear::new(2),
318 BabyBear::new(3),
319 BabyBear::new(4),
320 ];
321 let tree = MerkleTree::from_field_values(&values);
322 (0..4).try_for_each(|i| {
323 let proof = tree.open(i)?;
324 assert!(
325 MerkleTree::verify_opening(&tree.root(), i, &values[i], &proof),
326 "failed at leaf {i}"
327 );
328 Ok(())
329 })
330 }
331}