1#![deny(rustdoc::broken_intra_doc_links)]
16#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
17
18extern crate alloc;
19
20use alloc::vec::Vec;
21use core::borrow::Borrow;
22
23use alloy_primitives::{uint, Keccak256, U256};
24use risc0_zkvm::{
25 sha::{Digest, DIGEST_BYTES},
26 ReceiptClaim,
27};
28use serde::{Deserialize, Serialize};
29
30#[cfg(feature = "verify")]
31mod receipt;
32
33#[cfg(feature = "verify")]
34pub use receipt::{
35 decode_set_inclusion_seal, RecursionVerifierParameters, SetInclusionDecodingError,
36 SetInclusionEncodingError, SetInclusionReceipt, SetInclusionReceiptVerifierParameters,
40 VerificationError,
41};
42
alloy_sol_types::sol! {
    #[sol(all_derives)]
    /// ABI-encodable seal declared through the `sol!` macro.
    ///
    /// NOTE(review): judging by the field names this pairs a Merkle inclusion path with the seal
    /// of the set root — confirm against `decode_set_inclusion_seal` in the `receipt` module.
    struct Seal {
        /// Sequence of 32-byte values; presumably the Merkle path siblings (see `merkle_path`).
        bytes32[] path;
        /// Opaque seal bytes; presumably the seal attesting to the set root — confirm.
        bytes root_seal;
    }
}
53
54#[derive(Clone, Debug, Deserialize, Serialize)]
56pub struct GuestInput {
57 pub state: GuestState,
62 pub claims: Vec<ReceiptClaim>,
65 pub finalize: bool,
71}
72
73#[derive(Clone, Debug, Deserialize, Serialize)]
74pub struct GuestState {
75 pub self_image_id: Digest,
81 pub mmr: MerkleMountainRange,
84}
85
86impl GuestState {
87 pub fn initial(self_image_id: impl Into<Digest>) -> Self {
89 Self {
90 self_image_id: self_image_id.into(),
91 mmr: MerkleMountainRange::empty(),
92 }
93 }
94
95 pub fn is_initial(&self) -> bool {
97 self.mmr.is_empty()
98 }
99
100 pub fn encode(&self) -> Vec<u8> {
103 [self.self_image_id.as_bytes(), &self.mmr.encode()].concat()
104 }
105
106 pub fn decode(bytes: impl AsRef<[u8]>) -> Result<Self, DecodingError> {
109 let (chunk, bytes) = bytes
111 .as_ref()
112 .split_at_checked(U256::BYTES)
113 .ok_or(DecodingError::UnexpectedEnd)?;
114 let self_image_id = Digest::try_from(chunk).unwrap();
115 let mmr = MerkleMountainRange::decode(bytes)?;
116 Ok(Self { self_image_id, mmr })
117 }
118
119 pub fn into_input(
126 self,
127 claims: Vec<ReceiptClaim>,
128 finalize: bool,
129 ) -> Result<GuestInput, Error> {
130 if self.mmr.is_finalized() {
131 return Err(Error::FinalizedError);
132 }
133 Ok(GuestInput {
134 state: self,
135 claims,
136 finalize,
137 })
138 }
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
150#[cfg_attr(test, derive(PartialEq, Eq))]
151pub struct MerkleMountainRange(Vec<Peak>);
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154#[cfg_attr(test, derive(PartialEq, Eq))]
155struct Peak {
156 digest: Digest,
158 max_depth: u8,
165}
166
167#[derive(Debug, thiserror::Error)]
168#[non_exhaustive]
169pub enum Error {
170 #[error("Merkle mountain range is finalized")]
171 FinalizedError,
172 #[error("Merkle mountain range is empty")]
173 EmptyError,
174 #[error("decoding error: {0}")]
175 DecodingError(#[from] DecodingError),
176}
177
178#[derive(Debug, thiserror::Error)]
179#[non_exhaustive]
180pub enum DecodingError {
181 #[error("invalid bitmap")]
182 InvalidBitmap,
183 #[error("unexpected end of byte stream")]
184 UnexpectedEnd,
185 #[error("trailing bytes")]
186 TrailingBytes,
187}
188
189impl MerkleMountainRange {
190 pub fn empty() -> Self {
192 Self(Vec::new())
193 }
194
195 pub fn new_finalized(root: Digest) -> Self {
197 Self(vec![Peak {
198 max_depth: u8::MAX,
199 digest: root,
200 }])
201 }
202
203 pub fn push(&mut self, leaf: Digest) -> Result<(), Error> {
205 self.push_peak(Peak {
206 digest: leaf,
207 max_depth: 0,
208 })
209 }
210
211 fn push_peak(&mut self, new_peak: Peak) -> Result<(), Error> {
212 if self.is_finalized() {
218 return Err(Error::FinalizedError);
219 }
220 match self.0.last() {
221 None => self.0.push(new_peak),
223 Some(peak) if peak.max_depth > new_peak.max_depth => {
225 self.0.push(new_peak);
226 }
227 Some(peak) if peak.max_depth == new_peak.max_depth => {
229 let peak = self.0.pop().unwrap();
231 self.push_peak(Peak {
232 digest: commutative_keccak256(&peak.digest, &new_peak.digest),
233 max_depth: peak.max_depth.checked_add(1).expect(
234 "violation of invariant on the finalization of the Merkle mountain range",
235 ),
236 })?;
237 }
238 Some(_) => {
239 unreachable!("violation of ordering invariant in Merkle mountain range builder")
240 }
241 };
242 Ok(())
243 }
244
245 pub fn finalize(&mut self) -> Result<(), Error> {
248 let root = self.0.iter().rev().fold(None, |root, peak| {
249 Some(match root {
250 Some(root) => commutative_keccak256(&root, &peak.digest),
251 None => peak.digest,
252 })
253 });
254 let Some(root) = root else {
255 return Err(Error::EmptyError);
256 };
257 self.0.clear();
258 self.0.push(Peak {
259 digest: root,
260 max_depth: u8::MAX,
261 });
262 Ok(())
263 }
264
265 pub fn finalized_root(mut self) -> Option<Digest> {
268 match self.is_empty() {
269 true => None,
270 false => {
271 self.finalize().unwrap();
273 Some(self.0[0].digest)
274 }
275 }
276 }
277
278 pub fn is_finalized(&self) -> bool {
281 self.0
287 .first()
288 .map_or(false, |peak| peak.max_depth == u8::MAX)
289 }
290
291 pub fn is_empty(&self) -> bool {
293 self.0.is_empty()
294 }
295
296 pub fn encode(&self) -> Vec<u8> {
300 let mut bitmap = U256::ZERO;
303 let mut peaks = Vec::<Digest>::with_capacity(self.0.len());
304 for peak in self.0.iter() {
306 bitmap.set_bit(peak.max_depth as usize, true);
307 peaks.push(peak.digest);
308 }
309 [
310 &bitmap.to_be_bytes::<{ U256::BYTES }>(),
311 bytemuck::cast_slice(&peaks),
312 ]
313 .concat()
314 }
315
316 pub fn decode(bytes: impl AsRef<[u8]>) -> Result<Self, DecodingError> {
318 let (mut chunk, mut bytes) = bytes
320 .as_ref()
321 .split_at_checked(U256::BYTES)
322 .ok_or(DecodingError::UnexpectedEnd)?;
323 let bitmap = U256::from_be_slice(chunk);
324 if bitmap > (uint!(1_U256 << u8::MAX)) {
325 return Err(DecodingError::InvalidBitmap);
327 }
328
329 let mut peaks = Vec::<Peak>::with_capacity(bitmap.count_ones());
331 for i in (0..=u8::MAX).rev() {
332 if !bitmap.bit(i as usize) {
333 continue;
334 }
335 (chunk, bytes) = bytes
336 .split_at_checked(DIGEST_BYTES)
337 .ok_or(DecodingError::UnexpectedEnd)?;
338 peaks.push(Peak {
339 digest: Digest::try_from(chunk).unwrap(),
340 max_depth: i,
341 });
342 }
343 if !bytes.is_empty() {
344 return Err(DecodingError::TrailingBytes);
345 }
346
347 Ok(Self(peaks))
348 }
349}
350
351impl<D: Borrow<Digest>> Extend<D> for MerkleMountainRange {
352 fn extend<T: IntoIterator<Item = D>>(&mut self, leaves: T) {
354 for leaf in leaves {
355 self.push(*leaf.borrow())
356 .expect("attempted to extend a finalized MerkleMountainRange");
357 }
358 }
359}
360
361impl<D: Borrow<Digest>> FromIterator<D> for MerkleMountainRange {
362 fn from_iter<T: IntoIterator<Item = D>>(leaves: T) -> Self {
364 let mut mmr = Self::empty();
365 mmr.extend(leaves);
366 mmr
367 }
368}
369
370pub fn merkle_root(leaves: &[Digest]) -> Digest {
374 match leaves {
375 [] => panic!("digest list is empty, cannot compute Merkle root"),
376 _ => MerkleMountainRange::from_iter(leaves)
377 .finalized_root()
378 .unwrap(),
379 }
380}
381
382pub fn merkle_path(leaves: &[Digest], index: usize) -> Vec<Digest> {
389 assert!(
390 index < leaves.len(),
391 "no leaf with index {index} in tree of size {}",
392 leaves.len()
393 );
394
395 if leaves.len() == 1 {
396 return Vec::new(); }
398
399 let mut path = Vec::new();
400 let mut current_leaves = leaves;
401 let mut current_index = index;
402
403 while current_leaves.len() > 1 {
404 let mid = current_leaves.len().next_power_of_two() / 2;
406 let (left, right) = current_leaves.split_at(mid);
407
408 if current_index < mid {
410 path.push(merkle_root(right));
411 current_leaves = left;
412 } else {
413 path.push(merkle_root(left));
414 current_leaves = right;
415 current_index -= mid;
416 }
417 }
418
419 path.reverse();
420 path
421}
422
423pub fn merkle_path_root(
427 leaf: &Digest,
428 path: impl IntoIterator<Item = impl Borrow<Digest>>,
429) -> Digest {
430 path.into_iter()
431 .fold(*leaf, |a, b| commutative_keccak256(a.borrow(), b.borrow()))
432}
433
434fn commutative_keccak256(a: &Digest, b: &Digest) -> Digest {
436 let mut hasher = Keccak256::new();
437 if a.as_bytes() < b.as_bytes() {
438 hasher.update(a.as_bytes());
439 hasher.update(b.as_bytes());
440 } else {
441 hasher.update(b.as_bytes());
442 hasher.update(a.as_bytes());
443 }
444 hasher.finalize().0.into()
445}
446
#[cfg(test)]
mod tests {
    // Unit tests for Merkle root/path computation, MMR encoding, and the error paths of
    // building, finalizing and decoding.
    use super::*;
    use hex::FromHex;

    fn assert_merkle_root(digests: &[Digest], expected_root: Digest) {
        let root = merkle_root(digests);
        assert_eq!(root, expected_root);
    }

    #[test]
    fn test_root_manual() {
        let digests = vec![
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
        ];

        assert_merkle_root(
            &digests,
            Digest::from_hex("e004c72e4cb697fa97669508df099edbc053309343772a25e56412fc7db8ebef")
                .unwrap(),
        );
    }

    #[test]
    fn test_merkle_root() {
        let digests = vec![Digest::from([0u8; 32])];
        assert_merkle_root(&digests, digests[0]);

        let digests = vec![
            Digest::from([0u8; 32]),
            Digest::from([1u8; 32]),
            Digest::from([2u8; 32]),
        ];
        assert_merkle_root(
            &digests,
            commutative_keccak256(
                &commutative_keccak256(&digests[0], &digests[1]),
                &digests[2],
            ),
        );

        let digests = vec![
            Digest::from([0u8; 32]),
            Digest::from([1u8; 32]),
            Digest::from([2u8; 32]),
            Digest::from([3u8; 32]),
        ];
        assert_merkle_root(
            &digests,
            commutative_keccak256(
                &commutative_keccak256(&digests[0], &digests[1]),
                &commutative_keccak256(&digests[2], &digests[3]),
            ),
        );
    }

    #[test]
    fn test_consistency() {
        for length in 1..=128 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let root = merkle_root(&digests);

            for i in 0..length {
                let path = merkle_path(&digests, i);
                assert_eq!(merkle_path_root(&digests[i], &path), root);
            }
        }
    }

    #[test]
    fn test_single_leaf_path() {
        let leaf = Digest::from([5u8; 32]);
        assert!(merkle_path(&[leaf], 0).is_empty());
        assert_eq!(merkle_root(&[leaf]), leaf);
    }

    #[test]
    fn test_encode_decode() {
        for length in 0..=128 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let mmr = MerkleMountainRange::from_iter(digests);

            assert_eq!(mmr, MerkleMountainRange::decode(mmr.encode()).unwrap());
        }
    }

    #[test]
    fn test_encode_decode_finalized() {
        for length in 1..=16 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let mut mmr = MerkleMountainRange::from_iter(&digests);
            mmr.finalize().unwrap();

            assert!(mmr.is_finalized());
            assert_eq!(mmr, MerkleMountainRange::decode(mmr.encode()).unwrap());
            // Finalizing is idempotent and yields the same root as `merkle_root`.
            assert_eq!(mmr.clone().finalized_root(), Some(merkle_root(&digests)));
        }
    }

    #[test]
    fn test_decode_errors() {
        // No bitmap at all.
        assert!(matches!(
            MerkleMountainRange::decode([0u8; 0]),
            Err(DecodingError::UnexpectedEnd)
        ));

        // Bitmap announces a depth-0 peak, but no digest follows.
        let mut bytes = [0u8; 32];
        bytes[31] = 0x01;
        assert!(matches!(
            MerkleMountainRange::decode(bytes),
            Err(DecodingError::UnexpectedEnd)
        ));

        // Finalized marker (bit 255) combined with another peak (bit 0).
        let mut bytes = vec![0u8; 32 + 2 * DIGEST_BYTES];
        bytes[0] = 0x80;
        bytes[31] = 0x01;
        assert!(matches!(
            MerkleMountainRange::decode(bytes),
            Err(DecodingError::InvalidBitmap)
        ));

        // Extra data after a well-formed (empty) encoding.
        let mut bytes = MerkleMountainRange::empty().encode();
        bytes.push(0);
        assert!(matches!(
            MerkleMountainRange::decode(bytes),
            Err(DecodingError::TrailingBytes)
        ));
    }

    #[test]
    fn test_finalize_errors() {
        assert!(matches!(
            MerkleMountainRange::empty().finalize(),
            Err(Error::EmptyError)
        ));

        let mut mmr = MerkleMountainRange::from_iter([Digest::from([0u8; 32])]);
        mmr.finalize().unwrap();
        assert!(matches!(
            mmr.push(Digest::from([1u8; 32])),
            Err(Error::FinalizedError)
        ));
    }

    #[test]
    fn test_guest_state_encode_decode() {
        let mut state = GuestState::initial(Digest::from([7u8; 32]));
        assert!(state.is_initial());

        state.mmr.push(Digest::from([1u8; 32])).unwrap();
        let decoded = GuestState::decode(state.encode()).unwrap();
        assert_eq!(decoded.self_image_id, state.self_image_id);
        assert_eq!(decoded.mmr, state.mmr);
        assert!(!decoded.is_initial());

        state.mmr.finalize().unwrap();
        assert!(matches!(
            state.into_input(Vec::new(), false),
            Err(Error::FinalizedError)
        ));
    }
}