proof_cat/commit/
merkle.rs1use sha2::{Digest, Sha256};
9
10use crate::error::Error;
11use crate::field::FieldBytes;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub struct MerkleRoot([u8; 32]);
16
17impl MerkleRoot {
18 #[must_use]
20 pub fn as_bytes(&self) -> &[u8; 32] {
21 &self.0
22 }
23}
24
25#[derive(Debug, Clone)]
27pub struct MerkleProof {
28 leaf_index: usize,
29 siblings: Vec<[u8; 32]>,
30}
31
32impl MerkleProof {
33 #[must_use]
35 pub fn leaf_index(&self) -> usize {
36 self.leaf_index
37 }
38
39 #[must_use]
41 pub fn siblings(&self) -> &[[u8; 32]] {
42 &self.siblings
43 }
44}
45
46#[derive(Debug, Clone)]
74pub struct MerkleTree {
75 nodes: Vec<[u8; 32]>,
77 depth: usize,
79 leaf_count: usize,
81}
82
83fn hash_leaf(index: usize, value_bytes: &[u8]) -> [u8; 32] {
85 let mut hasher = Sha256::new();
86 hasher.update(b"leaf:");
87 hasher.update(index.to_le_bytes());
88 hasher.update(value_bytes);
89 hasher.finalize().into()
90}
91
92fn hash_padding(index: usize) -> [u8; 32] {
94 let mut hasher = Sha256::new();
95 hasher.update(b"padding:");
96 hasher.update(index.to_le_bytes());
97 hasher.finalize().into()
98}
99
100fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
102 let mut hasher = Sha256::new();
103 hasher.update(left);
104 hasher.update(right);
105 hasher.finalize().into()
106}
107
108fn next_power_of_two(n: usize) -> usize {
110 if n <= 1 { 1 } else { n.next_power_of_two() }
111}
112
113impl MerkleTree {
114 #[must_use]
118 pub fn from_field_values<F: FieldBytes>(values: &[F]) -> Self {
119 let leaf_count = values.len();
120 let n = next_power_of_two(leaf_count);
121 let depth = usize::try_from(n.trailing_zeros()).unwrap_or(0);
123
124 let leaf_hashes: Vec<[u8; 32]> = (0..n)
127 .map(|i| {
128 if i < leaf_count {
129 hash_leaf(i, &values[i].to_le_bytes())
130 } else {
131 hash_padding(i)
132 }
133 })
134 .collect();
135
136 let nodes_len = 2 * n;
139 let zeroed: Vec<[u8; 32]> = (0..nodes_len).map(|_| [0u8; 32]).collect();
140
141 let with_leaves: Vec<[u8; 32]> = zeroed
143 .iter()
144 .enumerate()
145 .map(|(idx, zero)| {
146 if idx >= n && idx < 2 * n {
147 leaf_hashes[idx - n]
148 } else {
149 *zero
150 }
151 })
152 .collect();
153
154 let nodes = (1..=depth).fold(with_leaves, |acc, level_from_bottom| {
158 let start = n >> level_from_bottom;
160 let end = n >> (level_from_bottom - 1);
161 (0..acc.len())
162 .map(|idx| {
163 if idx >= start && idx < end {
164 hash_pair(&acc[idx * 2], &acc[idx * 2 + 1])
165 } else {
166 acc[idx]
167 }
168 })
169 .collect()
170 });
171
172 Self {
173 nodes,
174 depth,
175 leaf_count,
176 }
177 }
178
179 #[must_use]
181 pub fn root(&self) -> MerkleRoot {
182 if self.nodes.len() > 1 {
183 MerkleRoot(self.nodes[1])
184 } else {
185 MerkleRoot([0u8; 32])
186 }
187 }
188
189 #[must_use]
191 pub fn leaf_count(&self) -> usize {
192 self.leaf_count
193 }
194
195 pub fn open(&self, index: usize) -> Result<MerkleProof, Error> {
201 if index >= self.leaf_count {
202 Err(Error::LeafIndexOutOfBounds {
203 index,
204 leaf_count: self.leaf_count,
205 })
206 } else {
207 let n = 1usize << self.depth;
208 let siblings = (0..self.depth)
210 .scan(n + index, |pos, _| {
211 let sibling_pos = *pos ^ 1;
212 let sibling = self.nodes[sibling_pos];
213 *pos /= 2;
214 Some(sibling)
215 })
216 .collect();
217 Ok(MerkleProof {
218 leaf_index: index,
219 siblings,
220 })
221 }
222 }
223
224 #[must_use]
229 pub fn verify_opening<F: FieldBytes>(
230 root: &MerkleRoot,
231 index: usize,
232 value: &F,
233 proof: &MerkleProof,
234 ) -> bool {
235 let leaf_hash = hash_leaf(index, &value.to_le_bytes());
236 let n = 1usize << proof.siblings.len();
237 let computed_root = proof
238 .siblings
239 .iter()
240 .enumerate()
241 .fold((leaf_hash, n + index), |(current, pos), (_, sibling)| {
242 let parent = if pos % 2 == 0 {
243 hash_pair(¤t, sibling)
244 } else {
245 hash_pair(sibling, ¤t)
246 };
247 (parent, pos / 2)
248 })
249 .0;
250 computed_root == root.0
251 }
252}
253
254#[cfg(test)]
255mod tests {
256 use super::*;
257 use crate::field::BabyBear;
258 use plonkish_cat::F101;
259
260 #[test]
261 fn single_leaf_roundtrip() -> Result<(), Error> {
262 let tree = MerkleTree::from_field_values(&[F101::new(42)]);
263 let proof = tree.open(0)?;
264 assert!(MerkleTree::verify_opening(
265 &tree.root(),
266 0,
267 &F101::new(42),
268 &proof
269 ));
270 Ok(())
271 }
272
273 #[test]
274 fn two_leaf_roundtrip() -> Result<(), Error> {
275 let values = [BabyBear::new(10), BabyBear::new(20)];
276 let tree = MerkleTree::from_field_values(&values);
277 let proof0 = tree.open(0)?;
278 let proof1 = tree.open(1)?;
279 assert!(MerkleTree::verify_opening(
280 &tree.root(),
281 0,
282 &BabyBear::new(10),
283 &proof0
284 ));
285 assert!(MerkleTree::verify_opening(
286 &tree.root(),
287 1,
288 &BabyBear::new(20),
289 &proof1
290 ));
291 Ok(())
292 }
293
294 #[test]
295 fn tampered_value_fails() -> Result<(), Error> {
296 let tree = MerkleTree::from_field_values(&[F101::new(42)]);
297 let proof = tree.open(0)?;
298 assert!(!MerkleTree::verify_opening(
300 &tree.root(),
301 0,
302 &F101::new(99),
303 &proof
304 ));
305 Ok(())
306 }
307
308 #[test]
309 fn out_of_bounds_open_fails() {
310 let tree = MerkleTree::from_field_values(&[F101::new(1), F101::new(2)]);
311 assert!(tree.open(2).is_err());
312 }
313
314 #[test]
315 fn four_leaves() -> Result<(), Error> {
316 let values = [
317 BabyBear::new(1),
318 BabyBear::new(2),
319 BabyBear::new(3),
320 BabyBear::new(4),
321 ];
322 let tree = MerkleTree::from_field_values(&values);
323 (0..4).try_for_each(|i| {
324 let proof = tree.open(i)?;
325 assert!(
326 MerkleTree::verify_opening(&tree.root(), i, &values[i], &proof),
327 "failed at leaf {i}"
328 );
329 Ok(())
330 })
331 }
332}