1use alloc::vec::Vec;
19
20use anyhow::{ensure, Result};
21use borsh::{BorshDeserialize, BorshSerialize};
22use risc0_core::field::baby_bear::BabyBear;
23use risc0_zkp::core::{digest::Digest, hash::HashFn};
24use serde::{Deserialize, Serialize};
25
/// Depth of the Merkle tree used by [`MerkleGroup::new`] to commit to a group of control IDs.
///
/// A tree of this depth has `2^ALLOWED_CODE_MERKLE_DEPTH` leaf slots, which bounds how many
/// leaves [`MerkleGroup::new`] accepts.
pub const ALLOWED_CODE_MERKLE_DEPTH: usize = 8;
31
32#[non_exhaustive]
36pub struct MerkleGroup {
37 pub depth: u32,
39
40 pub leaves: Vec<Digest>,
44}
45
46#[non_exhaustive]
49#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, BorshSerialize, BorshDeserialize)]
50pub struct MerkleProof {
51 pub index: u32,
53
54 pub digests: Vec<Digest>,
57}
58
59impl MerkleGroup {
60 pub fn new(leaves: Vec<Digest>) -> Result<Self> {
63 let max_len = 1 << ALLOWED_CODE_MERKLE_DEPTH;
64 ensure!(
65 leaves.len() < max_len,
66 "a maximum of {max_len} leaves can be added to a MerkleGroup"
67 );
68 Ok(Self {
69 depth: ALLOWED_CODE_MERKLE_DEPTH as u32,
70 leaves,
71 })
72 }
73
74 pub fn calc_root(&self, hashfn: &dyn HashFn<BabyBear>) -> Digest {
76 self.calc_range_root(0, 1 << self.depth, hashfn)
77 }
78
79 fn leaf_or_empty(&self, index: u32) -> &Digest {
80 self.leaves.get(index as usize).unwrap_or(&Digest::ZERO)
81 }
82
83 fn calc_range_root(&self, start: u32, end: u32, hashfn: &dyn HashFn<BabyBear>) -> Digest {
84 assert!(start < end);
85 let res = if start + 1 == end {
86 *self.leaf_or_empty(start)
87 } else {
88 let mid = (start + end) / 2;
89 assert_eq!(mid - start, end - mid);
90
91 let left = self.calc_range_root(start, mid, hashfn);
92 let right = self.calc_range_root(mid, end, hashfn);
93 *hashfn.hash_pair(&left, &right)
94 };
95 res
96 }
97
98 #[cfg(feature = "prove")]
101 pub fn get_proof(
102 &self,
103 control_id: &Digest,
104 hashfn: &dyn HashFn<BabyBear>,
105 ) -> Result<MerkleProof> {
106 let Some(index) = self.leaves.iter().position(|elem| elem == control_id) else {
107 anyhow::bail!("Unable to find {control_id:?} in merkle group");
108 };
109 Ok(self.get_proof_by_index(index as u32, hashfn))
110 }
111
112 #[cfg(feature = "prove")]
115 pub fn get_proof_by_index(&self, index: u32, hashfn: &dyn HashFn<BabyBear>) -> MerkleProof {
116 let mut digests: Vec<Digest> = Vec::with_capacity(self.depth as usize);
117
118 let mut cur: Digest = self.leaves[index as usize];
119 let mut cur_index = index;
120 for i in 0..self.depth {
121 let sibling_start = (cur_index ^ 1) << i;
122 let sibling_end = sibling_start + (1 << i);
123 let sibling = self.calc_range_root(sibling_start, sibling_end, hashfn);
124 cur = if cur_index & 1 == 0 {
125 *hashfn.hash_pair(&cur, &sibling)
126 } else {
127 *hashfn.hash_pair(&sibling, &cur)
128 };
129 digests.push(sibling);
130 cur_index >>= 1;
131 }
132
133 MerkleProof { digests, index }
134 }
135}
136
137impl MerkleProof {
138 pub fn verify(
140 &self,
141 leaf: &Digest,
142 root: &Digest,
143 hashfn: &dyn HashFn<BabyBear>,
144 ) -> Result<()> {
145 ensure!(
146 self.root(leaf, hashfn) == *root,
147 "merkle proof verify failed"
148 );
149 Ok(())
150 }
151
152 pub fn root(&self, leaf: &Digest, hashfn: &dyn HashFn<BabyBear>) -> Digest {
154 let mut cur = *leaf;
155 let mut cur_index = self.index;
156 for sibling in &self.digests {
157 cur = if cur_index & 1 == 0 {
158 *hashfn.hash_pair(&cur, sibling)
159 } else {
160 *hashfn.hash_pair(sibling, &cur)
161 };
162 cur_index >>= 1;
163 }
164 cur
165 }
166}
167
#[cfg(test)]
#[cfg(feature = "prove")]
mod tests {
    use risc0_zkp::core::hash::poseidon2::Poseidon2HashSuite;

    use super::*;

    /// Counts how many sibling digests two proofs share, starting from the root end.
    fn shared_levels(a: &MerkleProof, b: &MerkleProof) -> usize {
        a.digests
            .iter()
            .rev()
            .zip(b.digests.iter().rev())
            .position(|(a_elem, b_elem)| a_elem != b_elem)
            .unwrap_or(std::cmp::min(a.digests.len(), b.digests.len()))
    }

    #[test]
    fn basics() {
        let digest1 = Digest::new([1, 2, 3, 4, 5, 6, 7, 8]);
        let digest2 = Digest::new([9, 10, 11, 12, 13, 14, 15, 16]);
        let digest3 = Digest::new([17, 18, 19, 20, 21, 22, 23, 24]);

        let suite = Poseidon2HashSuite::new_suite();
        let hashfn = suite.hashfn.as_ref();

        let grp = MerkleGroup {
            depth: 4,
            leaves: Vec::from([digest1, digest2, digest3]),
        };
        let root = grp.calc_root(hashfn);
        tracing::trace!("Root: {root:?}");
        let proof1 = grp.get_proof_by_index(0, hashfn);
        tracing::trace!("Proof1: {proof1:?}");
        let proof2 = grp.get_proof_by_index(1, hashfn);
        tracing::trace!("Proof2: {proof2:?}");
        let proof3 = grp.get_proof_by_index(2, hashfn);
        tracing::trace!("Proof3: {proof3:?}");

        // Every proof must verify against its own leaf...
        proof1.verify(&digest1, &root, hashfn).unwrap();
        proof2.verify(&digest2, &root, hashfn).unwrap();
        proof3.verify(&digest3, &root, hashfn).unwrap();

        // ...and must be rejected for a different leaf.
        assert!(proof1.verify(&digest2, &root, hashfn).is_err());
        assert!(proof3.verify(&digest1, &root, hashfn).is_err());

        // Looking a leaf up by digest yields the same proof as looking it up by index.
        assert_eq!(grp.get_proof(&digest3, hashfn).unwrap(), proof3);

        // Leaves 0 and 1 are siblings, so their proofs differ only at the bottom level;
        // leaf 2 diverges from both one level higher.
        assert_eq!(shared_levels(&proof1, &proof2), 3);
        assert_eq!(shared_levels(&proof2, &proof3), 2);
        assert_eq!(shared_levels(&proof1, &proof3), 2);
    }
}