1#![deny(rustdoc::broken_intra_doc_links)]
16#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
17
18extern crate alloc;
19
20use alloc::vec::Vec;
21use core::borrow::Borrow;
22
23use alloy_primitives::{uint, Keccak256, U256};
24use risc0_zkvm::{
25 sha::{Digest, DIGEST_BYTES},
26 ReceiptClaim,
27};
28use serde::{Deserialize, Serialize};
29
30#[cfg(feature = "verify")]
31mod receipt;
32
33#[cfg(feature = "verify")]
34pub use receipt::{
35 decode_set_inclusion_seal, RecursionVerifierParameters, SetInclusionDecodingError,
36 SetInclusionEncodingError, SetInclusionReceipt, SetInclusionReceiptVerifierParameters,
40 VerificationError,
41};
42
alloy_sol_types::sol! {
    #[sol(all_derives)]
    /// Solidity ABI struct made of a Merkle path plus a seal for the path's root.
    ///
    /// NOTE(review): presumably encoded/decoded by the set-inclusion receipt code in the
    /// `receipt` module, which is not visible here — confirm.
    struct Seal {
        /// Sibling digests along a Merkle path — NOTE(review): presumably in the order produced
        /// by `merkle_path` (leaf level first); confirm against the encoder.
        bytes32[] path;
        /// Seal attesting to the Merkle root — NOTE(review): the proof system that produced it
        /// is not visible from this file.
        bytes root_seal;
    }
}
53
/// Input for the guest program (as the name suggests).
///
/// Built by [`GuestState::into_input`], which refuses states whose Merkle mountain range is
/// already finalized.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GuestInput {
    /// State to continue from.
    pub state: GuestState,
    /// Receipt claims supplied for this run — NOTE(review): presumably each claim is verified
    /// and appended to `state.mmr` as a leaf; confirm against the guest implementation.
    pub claims: Vec<ReceiptClaim>,
    /// Whether the range should be finalized after processing `claims` — NOTE(review): inferred
    /// from the field name; confirm against the guest implementation.
    pub finalize: bool,
}
72
/// Guest state: the guest's own image ID and the Merkle mountain range accumulated so far.
///
/// NOTE(review): presumably carried between successive guest runs; confirm against the guest.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GuestState {
    /// Image ID this state is bound to; serialized first by `GuestState::encode`.
    pub self_image_id: Digest,
    /// Merkle mountain range accumulated so far; empty in the initial state.
    pub mmr: MerkleMountainRange,
}
85
86impl GuestState {
87 pub fn initial(self_image_id: impl Into<Digest>) -> Self {
89 Self {
90 self_image_id: self_image_id.into(),
91 mmr: MerkleMountainRange::empty(),
92 }
93 }
94
95 pub fn is_initial(&self) -> bool {
97 self.mmr.is_empty()
98 }
99
100 pub fn encode(&self) -> Vec<u8> {
103 [self.self_image_id.as_bytes(), &self.mmr.encode()].concat()
104 }
105
106 pub fn decode(bytes: impl AsRef<[u8]>) -> Result<Self, DecodingError> {
109 let (chunk, bytes) = bytes
111 .as_ref()
112 .split_at_checked(U256::BYTES)
113 .ok_or(DecodingError::UnexpectedEnd)?;
114 let self_image_id = Digest::try_from(chunk).unwrap();
115 let mmr = MerkleMountainRange::decode(bytes)?;
116 Ok(Self { self_image_id, mmr })
117 }
118
119 pub fn into_input(
126 self,
127 claims: Vec<ReceiptClaim>,
128 finalize: bool,
129 ) -> Result<GuestInput, Error> {
130 if self.mmr.is_finalized() {
131 return Err(Error::FinalizedError);
132 }
133 Ok(GuestInput {
134 state: self,
135 claims,
136 finalize,
137 })
138 }
139}
140
/// An incrementally built Merkle mountain range: a list of "peaks", each the root of a perfect
/// binary subtree of leaves, from which a single Merkle root can be derived.
///
/// Peaks are kept in strictly decreasing order of depth, so there is at most one peak per depth.
/// A range is *finalized* when its first peak carries `max_depth == u8::MAX`; see
/// [`MerkleMountainRange::finalize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct MerkleMountainRange(Vec<Peak>);
152
/// A single peak of a [`MerkleMountainRange`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq, Eq))]
struct Peak {
    /// Root digest of the subtree this peak represents.
    digest: Digest,
    /// Depth of the subtree. A freshly pushed leaf has depth 0, and merging two peaks of depth
    /// `d` yields depth `d + 1`. The value `u8::MAX` is reserved to mark a finalized range.
    max_depth: u8,
}
166
/// Errors returned by [`MerkleMountainRange`] and [`GuestState`] operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The range is already finalized, so it cannot be extended or turned into a guest input.
    #[error("Merkle mountain range is finalized")]
    FinalizedError,
    /// The range has no peaks, so no root can be computed.
    #[error("Merkle mountain range is empty")]
    EmptyError,
    /// Wraps a [`DecodingError`].
    #[error("decoding error: {0}")]
    DecodingError(#[from] DecodingError),
}
177
/// Errors produced while decoding a [`MerkleMountainRange`] or a [`GuestState`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DecodingError {
    /// The finalized-range bit (bit 255) was set together with other bits.
    #[error("invalid bitmap")]
    InvalidBitmap,
    /// The input ended before the header or a declared digest could be read.
    #[error("unexpected end of byte stream")]
    UnexpectedEnd,
    /// Bytes remained after every digest declared by the bitmap had been read.
    #[error("trailing bytes")]
    TrailingBytes,
}
188
189impl MerkleMountainRange {
190 pub fn empty() -> Self {
192 Self(Vec::new())
193 }
194
195 pub fn new_finalized(root: Digest) -> Self {
197 Self(vec![Peak {
198 max_depth: u8::MAX,
199 digest: root,
200 }])
201 }
202
203 pub fn push(&mut self, value: impl Borrow<Digest>) -> Result<(), Error> {
205 self.push_peak(Peak {
206 digest: hash_leaf(value.borrow()),
207 max_depth: 0,
208 })
209 }
210
211 fn push_peak(&mut self, new_peak: Peak) -> Result<(), Error> {
212 if self.is_finalized() {
218 return Err(Error::FinalizedError);
219 }
220 match self.0.last() {
221 None => self.0.push(new_peak),
223 Some(peak) if peak.max_depth > new_peak.max_depth => {
225 self.0.push(new_peak);
226 }
227 Some(peak) if peak.max_depth == new_peak.max_depth => {
229 let peak = self.0.pop().unwrap();
231 self.push_peak(Peak {
232 digest: commutative_keccak256(&peak.digest, &new_peak.digest),
233 max_depth: peak.max_depth.checked_add(1).expect(
234 "violation of invariant on the finalization of the Merkle mountain range",
235 ),
236 })?;
237 }
238 Some(_) => {
239 unreachable!("violation of ordering invariant in Merkle mountain range builder")
240 }
241 };
242 Ok(())
243 }
244
245 pub fn finalize(&mut self) -> Result<(), Error> {
248 let root = self.0.iter().rev().fold(None, |root, peak| {
249 Some(match root {
250 Some(root) => commutative_keccak256(&root, &peak.digest),
251 None => peak.digest,
252 })
253 });
254 let Some(root) = root else {
255 return Err(Error::EmptyError);
256 };
257 self.0.clear();
258 self.0.push(Peak {
259 digest: root,
260 max_depth: u8::MAX,
261 });
262 Ok(())
263 }
264
265 pub fn finalized_root(mut self) -> Option<Digest> {
268 match self.is_empty() {
269 true => None,
270 false => {
271 self.finalize().unwrap();
273 Some(self.0[0].digest)
274 }
275 }
276 }
277
278 pub fn is_finalized(&self) -> bool {
281 self.0.first().is_some_and(|peak| peak.max_depth == u8::MAX)
287 }
288
289 pub fn is_empty(&self) -> bool {
291 self.0.is_empty()
292 }
293
294 pub fn encode(&self) -> Vec<u8> {
298 let mut bitmap = U256::ZERO;
301 let mut peaks = Vec::<Digest>::with_capacity(self.0.len());
302 for peak in self.0.iter() {
304 bitmap.set_bit(peak.max_depth as usize, true);
305 peaks.push(peak.digest);
306 }
307 [
308 &bitmap.to_be_bytes::<{ U256::BYTES }>(),
309 bytemuck::cast_slice(&peaks),
310 ]
311 .concat()
312 }
313
314 pub fn decode(bytes: impl AsRef<[u8]>) -> Result<Self, DecodingError> {
316 let (mut chunk, mut bytes) = bytes
318 .as_ref()
319 .split_at_checked(U256::BYTES)
320 .ok_or(DecodingError::UnexpectedEnd)?;
321 let bitmap = U256::from_be_slice(chunk);
322 if bitmap > (uint!(1_U256 << u8::MAX)) {
323 return Err(DecodingError::InvalidBitmap);
325 }
326
327 let mut peaks = Vec::<Peak>::with_capacity(bitmap.count_ones());
329 for i in (0..=u8::MAX).rev() {
330 if !bitmap.bit(i as usize) {
331 continue;
332 }
333 (chunk, bytes) = bytes
334 .split_at_checked(DIGEST_BYTES)
335 .ok_or(DecodingError::UnexpectedEnd)?;
336 peaks.push(Peak {
337 digest: Digest::try_from(chunk).unwrap(),
338 max_depth: i,
339 });
340 }
341 if !bytes.is_empty() {
342 return Err(DecodingError::TrailingBytes);
343 }
344
345 Ok(Self(peaks))
346 }
347}
348
349impl<D: Borrow<Digest>> Extend<D> for MerkleMountainRange {
350 fn extend<T: IntoIterator<Item = D>>(&mut self, values: T) {
352 for value in values {
353 self.push(value)
354 .expect("attempted to extend a finalized MerkleMountainRange");
355 }
356 }
357}
358
359impl<D: Borrow<Digest>> FromIterator<D> for MerkleMountainRange {
360 fn from_iter<T: IntoIterator<Item = D>>(values: T) -> Self {
362 let mut mmr = Self::empty();
363 mmr.extend(values);
364 mmr
365 }
366}
367
368pub fn merkle_root(leaves: &[Digest]) -> Digest {
372 match leaves {
373 [] => panic!("digest list is empty, cannot compute Merkle root"),
374 _ => MerkleMountainRange::from_iter(leaves)
375 .finalized_root()
376 .unwrap(),
377 }
378}
379
380pub fn merkle_path(leaves: &[Digest], index: usize) -> Vec<Digest> {
387 assert!(
388 index < leaves.len(),
389 "no leaf with index {index} in tree of size {}",
390 leaves.len()
391 );
392
393 if leaves.len() == 1 {
394 return Vec::new(); }
396
397 let mut path = Vec::new();
398 let mut current_leaves = leaves;
399 let mut current_index = index;
400
401 while current_leaves.len() > 1 {
402 let mid = current_leaves.len().next_power_of_two() / 2;
404 let (left, right) = current_leaves.split_at(mid);
405
406 if current_index < mid {
408 path.push(merkle_root(right));
409 current_leaves = left;
410 } else {
411 path.push(merkle_root(left));
412 current_leaves = right;
413 current_index -= mid;
414 }
415 }
416
417 path.reverse();
418 path
419}
420
421pub fn merkle_path_root(
425 leaf_value: impl Borrow<Digest>,
426 path: impl IntoIterator<Item = impl Borrow<Digest>>,
427) -> Digest {
428 let leaf = hash_leaf(leaf_value.borrow());
429 path.into_iter()
430 .fold(leaf, |a, b| commutative_keccak256(a.borrow(), b.borrow()))
431}
432
/// Prefix hashed in front of every leaf value by [`hash_leaf`].
///
/// Leaf preimages are therefore 8 + 32 bytes, while internal-node preimages built by
/// [`commutative_keccak256`] are 64 bytes, which keeps the two kinds of hash distinct.
const LEAF_TAG: &[u8; 8] = b"LEAF_TAG";
437
438fn hash_leaf(value: &Digest) -> Digest {
443 let mut hasher = Keccak256::new();
444 hasher.update(LEAF_TAG);
445 hasher.update(value.as_bytes());
446 hasher.finalize().0.into()
447}
448
449fn commutative_keccak256(a: &Digest, b: &Digest) -> Digest {
451 let mut hasher = Keccak256::new();
452 if a.as_bytes() < b.as_bytes() {
453 hasher.update(a.as_bytes());
454 hasher.update(b.as_bytes());
455 } else {
456 hasher.update(b.as_bytes());
457 hasher.update(a.as_bytes());
458 }
459 hasher.finalize().0.into()
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use hex::FromHex;

    /// Asserts that `merkle_root(digests)` equals `expected_root`.
    fn assert_merkle_root(digests: &[Digest], expected_root: Digest) {
        let root = merkle_root(digests);
        assert_eq!(root, expected_root);
    }

    #[test]
    fn test_root_manual() {
        let digests = vec![
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
        ];

        assert_merkle_root(
            &digests,
            Digest::from_hex("bd792a6858270b233a6b399c1cbc60c5b1046a5b43758b9abc46ba32d23c7352")
                .unwrap(),
        );
    }

    #[test]
    fn test_merkle_root() {
        let digests = vec![Digest::from([0u8; 32])];
        assert_merkle_root(&digests, hash_leaf(&digests[0]));

        let digests = vec![
            Digest::from([0u8; 32]),
            Digest::from([1u8; 32]),
            Digest::from([2u8; 32]),
        ];
        assert_merkle_root(
            &digests,
            commutative_keccak256(
                &commutative_keccak256(&hash_leaf(&digests[0]), &hash_leaf(&digests[1])),
                &hash_leaf(&digests[2]),
            ),
        );

        let digests = vec![
            Digest::from([0u8; 32]),
            Digest::from([1u8; 32]),
            Digest::from([2u8; 32]),
            Digest::from([3u8; 32]),
        ];
        assert_merkle_root(
            &digests,
            commutative_keccak256(
                &commutative_keccak256(&hash_leaf(&digests[0]), &hash_leaf(&digests[1])),
                &commutative_keccak256(&hash_leaf(&digests[2]), &hash_leaf(&digests[3])),
            ),
        );
    }

    #[test]
    fn test_consistency() {
        for length in 1..=128 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let root = merkle_root(&digests);

            for i in 0..length {
                let path = merkle_path(&digests, i);
                assert_eq!(merkle_path_root(digests[i], &path), root);
            }
        }
    }

    #[test]
    fn test_encode_decode() {
        for length in 0..=128 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let mmr = MerkleMountainRange::from_iter(digests);

            assert_eq!(mmr, MerkleMountainRange::decode(mmr.encode()).unwrap());
        }
    }

    /// A finalized range survives an encode/decode round trip and keeps its root.
    #[test]
    fn test_finalized_encode_decode() {
        for length in 1..=16 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let mut mmr = MerkleMountainRange::from_iter(&digests);
            mmr.finalize().unwrap();
            assert!(mmr.is_finalized());
            assert_eq!(mmr.clone().finalized_root(), Some(merkle_root(&digests)));

            let decoded = MerkleMountainRange::decode(mmr.encode()).unwrap();
            assert!(decoded.is_finalized());
            assert_eq!(decoded, mmr);
        }
    }

    /// Pushing into a finalized range is rejected.
    #[test]
    fn test_finalized_rejects_push() {
        let mut mmr = MerkleMountainRange::new_finalized(Digest::from([0u8; 32]));
        assert!(mmr.is_finalized());
        assert!(matches!(
            mmr.push(Digest::from([1u8; 32])),
            Err(Error::FinalizedError)
        ));
    }

    /// An empty range cannot be finalized and has no root.
    #[test]
    fn test_finalize_empty() {
        let mut mmr = MerkleMountainRange::empty();
        assert!(matches!(mmr.finalize(), Err(Error::EmptyError)));
        assert!(mmr.is_empty());
        assert_eq!(MerkleMountainRange::empty().finalized_root(), None);
    }

    /// Every `DecodingError` variant is reachable from malformed input.
    #[test]
    fn test_decode_errors() {
        // Shorter than the 32-byte bitmap.
        assert!(matches!(
            MerkleMountainRange::decode([0u8; 31]),
            Err(DecodingError::UnexpectedEnd)
        ));

        // Finalized bit (255) combined with bit 0.
        let mut bitmap = [0u8; 32];
        bitmap[0] = 0x80;
        bitmap[31] = 0x01;
        assert!(matches!(
            MerkleMountainRange::decode(bitmap),
            Err(DecodingError::InvalidBitmap)
        ));

        // Bitmap declares one peak, but no digest bytes follow.
        let mut bitmap = [0u8; 32];
        bitmap[31] = 0x01;
        assert!(matches!(
            MerkleMountainRange::decode(bitmap),
            Err(DecodingError::UnexpectedEnd)
        ));

        // Empty bitmap followed by one extra byte.
        assert!(matches!(
            MerkleMountainRange::decode([0u8; 33]),
            Err(DecodingError::TrailingBytes)
        ));
    }

    /// `GuestState` round-trips through its byte encoding and guards `into_input`.
    #[test]
    fn test_guest_state_encode_decode() {
        let mut state = GuestState::initial(Digest::from([7u8; 32]));
        assert!(state.is_initial());
        state
            .mmr
            .extend([Digest::from([1u8; 32]), Digest::from([2u8; 32])]);
        assert!(!state.is_initial());

        let decoded = GuestState::decode(state.encode()).unwrap();
        assert_eq!(decoded.self_image_id, state.self_image_id);
        assert_eq!(decoded.mmr, state.mmr);

        assert!(matches!(
            GuestState::decode([0u8; 16]),
            Err(DecodingError::UnexpectedEnd)
        ));

        let mut finalized = state.clone();
        finalized.mmr.finalize().unwrap();
        assert!(matches!(
            finalized.into_input(Vec::new(), false),
            Err(Error::FinalizedError)
        ));
        assert!(state.into_input(Vec::new(), true).is_ok());
    }
}