Skip to main content

proof_cat/commit/
merkle.rs

1//! Merkle tree commitment scheme.
2//!
3//! Commits to a vector of field elements by hashing each as a
4//! leaf, building a binary hash tree, and exposing the root as
5//! the binding commitment.  Opening proofs are sibling paths
6//! from leaf to root.
7
8use field_cat::FieldBytes;
9use sha2::{Digest, Sha256};
10
11use crate::error::Error;
12
13/// A Merkle tree root hash: the binding commitment.
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub struct MerkleRoot([u8; 32]);
16
17impl MerkleRoot {
18    /// The raw 32-byte root hash.
19    #[must_use]
20    pub fn as_bytes(&self) -> &[u8; 32] {
21        &self.0
22    }
23}
24
25/// A Merkle opening proof: sibling hashes from leaf to root.
26#[derive(Debug, Clone)]
27pub struct MerkleProof {
28    leaf_index: usize,
29    siblings: Vec<[u8; 32]>,
30}
31
32impl MerkleProof {
33    /// The index of the opened leaf.
34    #[must_use]
35    pub fn leaf_index(&self) -> usize {
36        self.leaf_index
37    }
38
39    /// The sibling hashes along the path to the root.
40    #[must_use]
41    pub fn siblings(&self) -> &[[u8; 32]] {
42        &self.siblings
43    }
44}
45
46/// A Merkle tree over field element leaves.
47///
48/// The tree is stored as a flat array where `nodes[1]` is the
49/// root, `nodes[2i]` and `nodes[2i+1]` are children of `nodes[i]`,
50/// and leaves occupy `nodes[n..2n]` where `n = 2^depth`.
51///
52/// # Examples
53///
54/// ```
55/// use field_cat::F101;
56/// use proof_cat::commit::merkle::MerkleTree;
57///
58/// let values = [F101::new(10), F101::new(20), F101::new(30), F101::new(40)];
59/// let tree = MerkleTree::from_field_values(&values);
60///
61/// // Open leaf 2 and verify the opening.
62/// let proof = tree.open(2)?;
63/// assert!(MerkleTree::verify_opening(
64///     &tree.root(), 2, &F101::new(30), &proof,
65/// ));
66///
67/// // A wrong value fails verification.
68/// assert!(!MerkleTree::verify_opening(
69///     &tree.root(), 2, &F101::new(99), &proof,
70/// ));
71/// # Ok::<(), proof_cat::Error>(())
72/// ```
73#[derive(Debug, Clone)]
74pub struct MerkleTree {
75    /// Flat array: index 0 unused, index 1 = root, leaves at [n..2n).
76    nodes: Vec<[u8; 32]>,
77    /// Tree depth (number of levels below root).
78    depth: usize,
79    /// Number of actual (non-padding) leaves.
80    leaf_count: usize,
81}
82
83/// Hash a leaf value with its index for domain separation.
84fn hash_leaf(index: usize, value_bytes: &[u8]) -> [u8; 32] {
85    let mut hasher = Sha256::new();
86    hasher.update(b"leaf:");
87    hasher.update(index.to_le_bytes());
88    hasher.update(value_bytes);
89    hasher.finalize().into()
90}
91
92/// Hash a padding leaf.
93fn hash_padding(index: usize) -> [u8; 32] {
94    let mut hasher = Sha256::new();
95    hasher.update(b"padding:");
96    hasher.update(index.to_le_bytes());
97    hasher.finalize().into()
98}
99
100/// Hash two children into a parent node.
101fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
102    let mut hasher = Sha256::new();
103    hasher.update(left);
104    hasher.update(right);
105    hasher.finalize().into()
106}
107
108/// Compute the next power of two >= n (minimum 1).
109fn next_power_of_two(n: usize) -> usize {
110    if n <= 1 { 1 } else { n.next_power_of_two() }
111}
112
113impl MerkleTree {
114    /// Build a Merkle tree from a slice of field elements.
115    ///
116    /// Pads to the next power of two with distinct padding leaves.
117    #[must_use]
118    pub fn from_field_values<F: FieldBytes>(values: &[F]) -> Self {
119        let leaf_count = values.len();
120        let n = next_power_of_two(leaf_count);
121        // Safety: trailing_zeros of a usize fits in usize on all platforms.
122        let depth = usize::try_from(n.trailing_zeros()).unwrap_or(0);
123
124        // Allocate flat array: 2*n entries, index 0 unused.
125        // Leaves at indices [n..2n).
126        let leaf_hashes: Vec<[u8; 32]> = (0..n)
127            .map(|i| {
128                if i < leaf_count {
129                    hash_leaf(i, &values[i].to_le_bytes())
130                } else {
131                    hash_padding(i)
132                }
133            })
134            .collect();
135
136        // Build the flat node array.
137        // Start with 2*n zeros, fill leaves, then compute parents.
138        let nodes_len = 2 * n;
139        let zeroed: Vec<[u8; 32]> = (0..nodes_len).map(|_| [0u8; 32]).collect();
140
141        // Place leaves at positions [n..2n).
142        let with_leaves: Vec<[u8; 32]> = zeroed
143            .iter()
144            .enumerate()
145            .map(|(idx, zero)| {
146                if idx >= n && idx < 2 * n {
147                    leaf_hashes[idx - n]
148                } else {
149                    *zero
150                }
151            })
152            .collect();
153
154        // Build internal nodes from bottom up: levels from (n-1) down to 1.
155        // Each parent[i] = hash(child[2i], child[2i+1]).
156        // We fold from the deepest internal level upward.
157        let nodes = (1..=depth).fold(with_leaves, |acc, level_from_bottom| {
158            // Nodes at this level: indices [start, end).
159            let start = n >> level_from_bottom;
160            let end = n >> (level_from_bottom - 1);
161            (0..acc.len())
162                .map(|idx| {
163                    if idx >= start && idx < end {
164                        hash_pair(&acc[idx * 2], &acc[idx * 2 + 1])
165                    } else {
166                        acc[idx]
167                    }
168                })
169                .collect()
170        });
171
172        Self {
173            nodes,
174            depth,
175            leaf_count,
176        }
177    }
178
179    /// The root commitment.
180    #[must_use]
181    pub fn root(&self) -> MerkleRoot {
182        if self.nodes.len() > 1 {
183            MerkleRoot(self.nodes[1])
184        } else {
185            MerkleRoot([0u8; 32])
186        }
187    }
188
189    /// The number of actual (non-padding) leaves.
190    #[must_use]
191    pub fn leaf_count(&self) -> usize {
192        self.leaf_count
193    }
194
195    /// Generate an opening proof for the leaf at `index`.
196    ///
197    /// # Errors
198    ///
199    /// Returns [`Error::LeafIndexOutOfBounds`] if `index >= leaf_count`.
200    pub fn open(&self, index: usize) -> Result<MerkleProof, Error> {
201        if index >= self.leaf_count {
202            Err(Error::LeafIndexOutOfBounds {
203                index,
204                leaf_count: self.leaf_count,
205            })
206        } else {
207            let n = 1usize << self.depth;
208            // Collect siblings from leaf position up to the root.
209            let siblings = (0..self.depth)
210                .scan(n + index, |pos, _| {
211                    let sibling_pos = *pos ^ 1;
212                    let sibling = self.nodes[sibling_pos];
213                    *pos /= 2;
214                    Some(sibling)
215                })
216                .collect();
217            Ok(MerkleProof {
218                leaf_index: index,
219                siblings,
220            })
221        }
222    }
223
224    /// Verify an opening proof against a root and leaf value.
225    ///
226    /// Recomputes the root from the leaf hash and sibling path,
227    /// then checks it matches the expected root.
228    #[must_use]
229    pub fn verify_opening<F: FieldBytes>(
230        root: &MerkleRoot,
231        index: usize,
232        value: &F,
233        proof: &MerkleProof,
234    ) -> bool {
235        let leaf_hash = hash_leaf(index, &value.to_le_bytes());
236        let n = 1usize << proof.siblings.len();
237        let computed_root = proof
238            .siblings
239            .iter()
240            .enumerate()
241            .fold((leaf_hash, n + index), |(current, pos), (_, sibling)| {
242                let parent = if pos % 2 == 0 {
243                    hash_pair(&current, sibling)
244                } else {
245                    hash_pair(sibling, &current)
246                };
247                (parent, pos / 2)
248            })
249            .0;
250        computed_root == root.0
251    }
252}
253
254#[cfg(test)]
255mod tests {
256    use super::*;
257    use field_cat::{BabyBear, F101};
258
259    #[test]
260    fn single_leaf_roundtrip() -> Result<(), Error> {
261        let tree = MerkleTree::from_field_values(&[F101::new(42)]);
262        let proof = tree.open(0)?;
263        assert!(MerkleTree::verify_opening(
264            &tree.root(),
265            0,
266            &F101::new(42),
267            &proof
268        ));
269        Ok(())
270    }
271
272    #[test]
273    fn two_leaf_roundtrip() -> Result<(), Error> {
274        let values = [BabyBear::new(10), BabyBear::new(20)];
275        let tree = MerkleTree::from_field_values(&values);
276        let proof0 = tree.open(0)?;
277        let proof1 = tree.open(1)?;
278        assert!(MerkleTree::verify_opening(
279            &tree.root(),
280            0,
281            &BabyBear::new(10),
282            &proof0
283        ));
284        assert!(MerkleTree::verify_opening(
285            &tree.root(),
286            1,
287            &BabyBear::new(20),
288            &proof1
289        ));
290        Ok(())
291    }
292
293    #[test]
294    fn tampered_value_fails() -> Result<(), Error> {
295        let tree = MerkleTree::from_field_values(&[F101::new(42)]);
296        let proof = tree.open(0)?;
297        // Wrong value:
298        assert!(!MerkleTree::verify_opening(
299            &tree.root(),
300            0,
301            &F101::new(99),
302            &proof
303        ));
304        Ok(())
305    }
306
307    #[test]
308    fn out_of_bounds_open_fails() {
309        let tree = MerkleTree::from_field_values(&[F101::new(1), F101::new(2)]);
310        assert!(tree.open(2).is_err());
311    }
312
313    #[test]
314    fn four_leaves() -> Result<(), Error> {
315        let values = [
316            BabyBear::new(1),
317            BabyBear::new(2),
318            BabyBear::new(3),
319            BabyBear::new(4),
320        ];
321        let tree = MerkleTree::from_field_values(&values);
322        (0..4).try_for_each(|i| {
323            let proof = tree.open(i)?;
324            assert!(
325                MerkleTree::verify_opening(&tree.root(), i, &values[i], &proof),
326                "failed at leaf {i}"
327            );
328            Ok(())
329        })
330    }
331}