risc0_zkvm/receipt/
merkle.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Minimal Merkle tree implementation used in the recursion system for
16//! committing to a group of control IDs.
17
18use alloc::vec::Vec;
19
20use anyhow::{ensure, Result};
21use borsh::{BorshDeserialize, BorshSerialize};
22use risc0_core::field::baby_bear::BabyBear;
23use risc0_zkp::core::{digest::Digest, hash::HashFn};
24use serde::{Deserialize, Serialize};
25
/// Depth of the Merkle tree that encodes the set of allowed control IDs.
///
/// A tree of this depth has `1 << ALLOWED_CODE_MERKLE_DEPTH` leaf slots.
// NOTE: Any change to this constant has to be coordinated with the circuit.
// To avoid ever needing such a circuit change, the depth is fixed at 8, which
// leaves room for more control IDs than we are likely to need.
pub const ALLOWED_CODE_MERKLE_DEPTH: usize = 8;
31
32/// Merkle tree implementation used in the recursion system to commit to a set
33/// of recursion programs, and to verify the inclusion of a given program in the
34/// set.
35#[non_exhaustive]
36pub struct MerkleGroup {
37    /// Depth of the Merkle tree.
38    pub depth: u32,
39
40    /// Ordered list of Merkle tree leaves, as Digests. It is expected that
41    /// these will be the control IDs for the committed set of recursion
42    /// programs.
43    pub leaves: Vec<Digest>,
44}
45
46/// An inclusion proof for the [MerkleGroup]. Used to verify inclusion of a
47/// given recursion program in the committed set.
48#[non_exhaustive]
49#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, BorshSerialize, BorshDeserialize)]
50pub struct MerkleProof {
51    /// Index of the leaf for which inclusion is being proven.
52    pub index: u32,
53
54    /// Sibling digests on the path from the root to the leaf.
55    /// Does not include the root of the leaf.
56    pub digests: Vec<Digest>,
57}
58
59impl MerkleGroup {
60    /// Create a new [MerkleGroup] from the given leaves.
61    /// Will fail if too many leaves are given for the default depth.
62    pub fn new(leaves: Vec<Digest>) -> Result<Self> {
63        let max_len = 1 << ALLOWED_CODE_MERKLE_DEPTH;
64        ensure!(
65            leaves.len() < max_len,
66            "a maximum of {max_len} leaves can be added to a MerkleGroup"
67        );
68        Ok(Self {
69            depth: ALLOWED_CODE_MERKLE_DEPTH as u32,
70            leaves,
71        })
72    }
73
74    /// Calculate the root of the [MerkleGroup].
75    pub fn calc_root(&self, hashfn: &dyn HashFn<BabyBear>) -> Digest {
76        self.calc_range_root(0, 1 << self.depth, hashfn)
77    }
78
79    fn leaf_or_empty(&self, index: u32) -> &Digest {
80        self.leaves.get(index as usize).unwrap_or(&Digest::ZERO)
81    }
82
83    fn calc_range_root(&self, start: u32, end: u32, hashfn: &dyn HashFn<BabyBear>) -> Digest {
84        assert!(start < end);
85        let res = if start + 1 == end {
86            *self.leaf_or_empty(start)
87        } else {
88            let mid = (start + end) / 2;
89            assert_eq!(mid - start, end - mid);
90
91            let left = self.calc_range_root(start, mid, hashfn);
92            let right = self.calc_range_root(mid, end, hashfn);
93            *hashfn.hash_pair(&left, &right)
94        };
95        res
96    }
97
98    /// Calculate and return a [MerkleProof] for the given leaf.
99    /// Will return an error if the given leaf is not in the tree.
100    #[cfg(feature = "prove")]
101    pub fn get_proof(
102        &self,
103        control_id: &Digest,
104        hashfn: &dyn HashFn<BabyBear>,
105    ) -> Result<MerkleProof> {
106        let Some(index) = self.leaves.iter().position(|elem| elem == control_id) else {
107            anyhow::bail!("Unable to find {control_id:?} in merkle group");
108        };
109        Ok(self.get_proof_by_index(index as u32, hashfn))
110    }
111
112    /// Calculate and return a [MerkleProof] for the given leaf.
113    /// Will panic if the given index is out of the range of leaves.
114    #[cfg(feature = "prove")]
115    pub fn get_proof_by_index(&self, index: u32, hashfn: &dyn HashFn<BabyBear>) -> MerkleProof {
116        let mut digests: Vec<Digest> = Vec::with_capacity(self.depth as usize);
117
118        let mut cur: Digest = self.leaves[index as usize];
119        let mut cur_index = index;
120        for i in 0..self.depth {
121            let sibling_start = (cur_index ^ 1) << i;
122            let sibling_end = sibling_start + (1 << i);
123            let sibling = self.calc_range_root(sibling_start, sibling_end, hashfn);
124            cur = if cur_index & 1 == 0 {
125                *hashfn.hash_pair(&cur, &sibling)
126            } else {
127                *hashfn.hash_pair(&sibling, &cur)
128            };
129            digests.push(sibling);
130            cur_index >>= 1;
131        }
132
133        MerkleProof { digests, index }
134    }
135}
136
137impl MerkleProof {
138    /// Verify the Merkle inclusion proof against the given leaf and root.
139    pub fn verify(
140        &self,
141        leaf: &Digest,
142        root: &Digest,
143        hashfn: &dyn HashFn<BabyBear>,
144    ) -> Result<()> {
145        ensure!(
146            self.root(leaf, hashfn) == *root,
147            "merkle proof verify failed"
148        );
149        Ok(())
150    }
151
152    /// Calculate the root of this branch by iteratively hashing, starting from the leaf.
153    pub fn root(&self, leaf: &Digest, hashfn: &dyn HashFn<BabyBear>) -> Digest {
154        let mut cur = *leaf;
155        let mut cur_index = self.index;
156        for sibling in &self.digests {
157            cur = if cur_index & 1 == 0 {
158                *hashfn.hash_pair(&cur, sibling)
159            } else {
160                *hashfn.hash_pair(sibling, &cur)
161            };
162            cur_index >>= 1;
163        }
164        cur
165    }
166}
167
#[cfg(test)]
#[cfg(feature = "prove")]
mod tests {
    use risc0_zkp::core::hash::poseidon2::Poseidon2HashSuite;

    use super::*;

    /// Count how many sibling digests two proofs have in common, comparing
    /// from the root end of each proof downward and stopping at the first
    /// difference.
    fn shared_levels(a: &MerkleProof, b: &MerkleProof) -> usize {
        a.digests
            .iter()
            .rev()
            .zip(b.digests.iter().rev())
            .position(|(a_elem, b_elem)| a_elem != b_elem)
            .unwrap_or(std::cmp::min(a.digests.len(), b.digests.len()))
    }

    #[test]
    fn basics() {
        let digest1 = Digest::new([1, 2, 3, 4, 5, 6, 7, 8]);
        let digest2 = Digest::new([9, 10, 11, 12, 13, 14, 15, 16]);
        let digest3 = Digest::new([17, 18, 19, 20, 21, 22, 23, 24]);

        let suite = Poseidon2HashSuite::new_suite();
        let hashfn = suite.hashfn.as_ref();

        let grp = MerkleGroup {
            depth: 4,
            leaves: Vec::from([digest1, digest2, digest3]),
        };
        let root = grp.calc_root(hashfn);
        tracing::trace!("Root: {root:?}");
        let proof1 = grp.get_proof_by_index(0, hashfn);
        tracing::trace!("Proof1: {proof1:?}");
        let proof2 = grp.get_proof_by_index(1, hashfn);
        tracing::trace!("Proof2: {proof2:?}");
        let proof3 = grp.get_proof_by_index(2, hashfn);
        tracing::trace!("Proof3: {proof3:?}");

        // Every proof must verify against its own leaf.
        proof1.verify(&digest1, &root, hashfn).unwrap();
        proof2.verify(&digest2, &root, hashfn).unwrap();
        proof3.verify(&digest3, &root, hashfn).unwrap();

        // A proof must not verify for a leaf other than its own.
        assert!(proof1.verify(&digest2, &root, hashfn).is_err());

        // proof1 and proof2 should share 3 levels of proof, whereas proof3
        // should only share 2 levels with either of them.
        assert_eq!(shared_levels(&proof1, &proof2), 3);
        assert_eq!(shared_levels(&proof2, &proof3), 2);
        assert_eq!(shared_levels(&proof1, &proof3), 2);
    }
}
216}