#![deny(rustdoc::broken_intra_doc_links)]
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
extern crate alloc;
use alloc::vec::Vec;
use core::borrow::Borrow;
use alloy_primitives::{uint, Keccak256, U256};
use risc0_zkvm::{
sha::{Digest, DIGEST_BYTES},
ReceiptClaim,
};
use serde::{Deserialize, Serialize};
#[cfg(feature = "verify")]
mod receipt;
#[cfg(feature = "verify")]
pub use receipt::{
decode_set_inclusion_seal, RecursionVerifierParameters, SetInclusionDecodingError,
SetInclusionEncodingError, SetInclusionReceipt, SetInclusionReceiptVerifierParameters,
VerificationError,
};
alloy_sol_types::sol! {
    // ABI-encodable seal for proving set inclusion on-chain:
    // `path` is the Merkle inclusion path of sibling digests, and
    // `root_seal` is the seal attesting to the aggregated root claim.
    // (Plain `//` comments are used here so the proc macro never sees them.)
    #[sol(all_derives)]
    struct Seal {
        bytes32[] path;
        bytes root_seal;
    }
}
/// Input to the aggregation guest: prior state, new claims to accumulate,
/// and whether to finalize the accumulator afterwards.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GuestInput {
    /// State carried over from a previous invocation (or the initial state).
    pub state: GuestState,
    /// Receipt claims to append as leaves to the Merkle mountain range.
    pub claims: Vec<ReceiptClaim>,
    /// When true, the Merkle mountain range is to be collapsed into a
    /// single finalized root after the claims are added.
    pub finalize: bool,
}
/// State threaded between invocations of the aggregation guest.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GuestState {
    /// Image ID of the guest program itself.
    // NOTE(review): presumably used to verify self-recursive receipts —
    // confirm against the guest implementation.
    pub self_image_id: Digest,
    /// Merkle mountain range accumulating the aggregated claim digests.
    pub mmr: MerkleMountainRange,
}
impl GuestState {
pub fn initial(self_image_id: impl Into<Digest>) -> Self {
Self {
self_image_id: self_image_id.into(),
mmr: MerkleMountainRange::empty(),
}
}
pub fn is_initial(&self) -> bool {
self.mmr.is_empty()
}
pub fn encode(&self) -> Vec<u8> {
[self.self_image_id.as_bytes(), &self.mmr.encode()].concat()
}
pub fn decode(bytes: impl AsRef<[u8]>) -> Result<Self, DecodingError> {
let (chunk, bytes) = bytes
.as_ref()
.split_at_checked(U256::BYTES)
.ok_or(DecodingError::UnexpectedEnd)?;
let self_image_id = Digest::try_from(chunk).unwrap();
let mmr = MerkleMountainRange::decode(bytes)?;
Ok(Self { self_image_id, mmr })
}
pub fn into_input(
self,
claims: Vec<ReceiptClaim>,
finalize: bool,
) -> Result<GuestInput, Error> {
if self.mmr.is_finalized() {
return Err(Error::FinalizedError);
}
Ok(GuestInput {
state: self,
claims,
finalize,
})
}
}
/// A Merkle mountain range: a list of perfect-binary-tree peaks kept in
/// strictly decreasing depth order, supporting incremental leaf insertion.
/// Finalizing collapses all peaks into a single root peak carrying the
/// sentinel depth `u8::MAX`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct MerkleMountainRange(Vec<Peak>);
/// One peak (root of a perfect subtree) in the mountain range.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq, Eq))]
struct Peak {
    /// Digest at the root of this peak's subtree.
    digest: Digest,
    /// Depth of the subtree: 0 for a single leaf. `u8::MAX` is a sentinel
    /// marking a finalized root (see `MerkleMountainRange::finalize`).
    max_depth: u8,
}
/// Errors raised when building or consuming a [`MerkleMountainRange`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The range was already collapsed into a root; no more pushes allowed.
    #[error("Merkle mountain range is finalized")]
    FinalizedError,
    /// The range has no peaks, so there is nothing to finalize.
    #[error("Merkle mountain range is empty")]
    EmptyError,
    /// Wrapper for failures while decoding an encoded range.
    #[error("decoding error: {0}")]
    DecodingError(#[from] DecodingError),
}
/// Errors raised while decoding a byte-encoded [`MerkleMountainRange`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DecodingError {
    /// The peak bitmap combines the finalized sentinel bit with other bits.
    #[error("invalid bitmap")]
    InvalidBitmap,
    /// The buffer ended before all indicated peaks could be read.
    #[error("unexpected end of byte stream")]
    UnexpectedEnd,
    /// Extra bytes remained after all indicated peaks were read.
    #[error("trailing bytes")]
    TrailingBytes,
}
impl MerkleMountainRange {
    /// Creates an empty range with no peaks.
    pub fn empty() -> Self {
        Self(Vec::new())
    }
    /// Creates a range that is already finalized to the given root.
    ///
    /// The finalized state is a single peak with sentinel depth `u8::MAX`.
    pub fn new_finalized(root: Digest) -> Self {
        Self(vec![Peak {
            max_depth: u8::MAX,
            digest: root,
        }])
    }
    /// Appends a leaf value, hashing it with the leaf domain-separation tag
    /// before inserting it as a depth-0 peak.
    ///
    /// # Errors
    ///
    /// Returns [`Error::FinalizedError`] if the range is finalized.
    pub fn push(&mut self, value: impl Borrow<Digest>) -> Result<(), Error> {
        self.push_peak(Peak {
            digest: hash_leaf(value.borrow()),
            max_depth: 0,
        })
    }
    /// Inserts a peak, recursively merging equal-depth peaks so that the
    /// peak list stays strictly decreasing in depth.
    fn push_peak(&mut self, new_peak: Peak) -> Result<(), Error> {
        if self.is_finalized() {
            return Err(Error::FinalizedError);
        }
        match self.0.last() {
            None => self.0.push(new_peak),
            Some(peak) if peak.max_depth > new_peak.max_depth => {
                self.0.push(new_peak);
            }
            Some(peak) if peak.max_depth == new_peak.max_depth => {
                // Two peaks of equal depth: pop the existing one and merge
                // the pair into a peak of depth + 1, which may cascade into
                // further merges. checked_add can only overflow if a real
                // peak reached the u8::MAX sentinel, which is an invariant
                // violation.
                let peak = self.0.pop().unwrap();
                self.push_peak(Peak {
                    digest: commutative_keccak256(&peak.digest, &new_peak.digest),
                    max_depth: peak.max_depth.checked_add(1).expect(
                        "violation of invariant on the finalization of the Merkle mountain range",
                    ),
                })?;
            }
            Some(_) => {
                // A shallower peak above a deeper one cannot occur if pushes
                // go through push_peak.
                unreachable!("violation of ordering invariant in Merkle mountain range builder")
            }
        };
        Ok(())
    }
    /// Collapses all peaks into a single root ("bagging the peaks"),
    /// combining from the shallowest peak towards the deepest, and marks
    /// the range as finalized. No further pushes are accepted afterwards.
    ///
    /// # Errors
    ///
    /// Returns [`Error::EmptyError`] if the range has no peaks.
    pub fn finalize(&mut self) -> Result<(), Error> {
        let root = self.0.iter().rev().fold(None, |root, peak| {
            Some(match root {
                Some(root) => commutative_keccak256(&root, &peak.digest),
                None => peak.digest,
            })
        });
        let Some(root) = root else {
            return Err(Error::EmptyError);
        };
        self.0.clear();
        self.0.push(Peak {
            digest: root,
            max_depth: u8::MAX,
        });
        Ok(())
    }
    /// Consumes the range and returns its finalized root, or `None` if the
    /// range is empty.
    pub fn finalized_root(mut self) -> Option<Digest> {
        match self.is_empty() {
            true => None,
            false => {
                // Cannot fail: the range is non-empty by the check above.
                self.finalize().unwrap();
                Some(self.0[0].digest)
            }
        }
    }
    /// Returns true iff the range has been collapsed to a single root
    /// (detected via the u8::MAX sentinel depth).
    pub fn is_finalized(&self) -> bool {
        self.0.first().is_some_and(|peak| peak.max_depth == u8::MAX)
    }
    /// Returns true iff the range contains no peaks.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
    /// Encodes the range as a 32-byte big-endian bitmap (bit i set means a
    /// peak with max_depth i exists) followed by the peak digests in peak
    /// order, i.e. from deepest to shallowest.
    pub fn encode(&self) -> Vec<u8> {
        let mut bitmap = U256::ZERO;
        let mut peaks = Vec::<Digest>::with_capacity(self.0.len());
        for peak in self.0.iter() {
            bitmap.set_bit(peak.max_depth as usize, true);
            peaks.push(peak.digest);
        }
        [
            &bitmap.to_be_bytes::<{ U256::BYTES }>(),
            bytemuck::cast_slice(&peaks),
        ]
        .concat()
    }
    /// Decodes a range previously produced by [`MerkleMountainRange::encode`].
    ///
    /// # Errors
    ///
    /// Returns a [`DecodingError`] on a truncated buffer, trailing bytes, or
    /// a bitmap combining the finalized sentinel bit with other peaks.
    pub fn decode(bytes: impl AsRef<[u8]>) -> Result<Self, DecodingError> {
        let (mut chunk, mut bytes) = bytes
            .as_ref()
            .split_at_checked(U256::BYTES)
            .ok_or(DecodingError::UnexpectedEnd)?;
        let bitmap = U256::from_be_slice(chunk);
        // A set bit 255 marks a finalized range, which must be the sole
        // peak; any bitmap strictly greater than 1 << 255 has bit 255 plus
        // additional bits set and is therefore invalid.
        if bitmap > (uint!(1_U256 << u8::MAX)) {
            return Err(DecodingError::InvalidBitmap);
        }
        let mut peaks = Vec::<Peak>::with_capacity(bitmap.count_ones());
        // Iterate depths from high to low to match the deepest-first peak
        // ordering that `encode` writes.
        for i in (0..=u8::MAX).rev() {
            if !bitmap.bit(i as usize) {
                continue;
            }
            (chunk, bytes) = bytes
                .split_at_checked(DIGEST_BYTES)
                .ok_or(DecodingError::UnexpectedEnd)?;
            peaks.push(Peak {
                // Cannot fail: chunk is exactly DIGEST_BYTES long.
                digest: Digest::try_from(chunk).unwrap(),
                max_depth: i,
            });
        }
        if !bytes.is_empty() {
            return Err(DecodingError::TrailingBytes);
        }
        Ok(Self(peaks))
    }
}
impl<D: Borrow<Digest>> Extend<D> for MerkleMountainRange {
    /// Pushes every digest yielded by `values` as a new leaf.
    ///
    /// # Panics
    ///
    /// Panics if the range has already been finalized.
    fn extend<T: IntoIterator<Item = D>>(&mut self, values: T) {
        values.into_iter().for_each(|value| {
            self.push(value)
                .expect("attempted to extend a finalized MerkleMountainRange")
        });
    }
}
impl<D: Borrow<Digest>> FromIterator<D> for MerkleMountainRange {
    /// Builds an unfinalized range containing every yielded digest as a leaf.
    ///
    /// # Panics
    ///
    /// Panics via `extend` only if pushing fails, which cannot happen for a
    /// freshly created (empty, unfinalized) range.
    fn from_iter<T: IntoIterator<Item = D>>(values: T) -> Self {
        let mut range = Self::empty();
        range.extend(values);
        range
    }
}
/// Computes the Merkle root committing to the given list of leaf digests.
///
/// # Panics
///
/// Panics when `leaves` is empty; an empty tree has no root.
pub fn merkle_root(leaves: &[Digest]) -> Digest {
    if leaves.is_empty() {
        panic!("digest list is empty, cannot compute Merkle root");
    }
    MerkleMountainRange::from_iter(leaves)
        .finalized_root()
        .unwrap()
}
/// Computes the Merkle inclusion path for the leaf at `index`: the sibling
/// subtree roots ordered from the leaf upward, as consumed by
/// `merkle_path_root`. A single-leaf tree yields an empty path.
///
/// # Panics
///
/// Panics when `index` is out of bounds for `leaves`.
pub fn merkle_path(leaves: &[Digest], index: usize) -> Vec<Digest> {
    assert!(
        index < leaves.len(),
        "no leaf with index {index} in tree of size {}",
        leaves.len()
    );
    let mut siblings = Vec::new();
    let mut slice = leaves;
    let mut pos = index;
    // Descend from the root: at each level split at the largest power of two
    // not exceeding the slice length, record the sibling half's root, and
    // recurse into the half containing the target leaf.
    loop {
        if slice.len() <= 1 {
            break;
        }
        let split = slice.len().next_power_of_two() / 2;
        let (lo, hi) = slice.split_at(split);
        if pos < split {
            siblings.push(merkle_root(hi));
            slice = lo;
        } else {
            siblings.push(merkle_root(lo));
            slice = hi;
            pos -= split;
        }
    }
    // Siblings were gathered top-down; the path is consumed bottom-up.
    siblings.reverse();
    siblings
}
/// Recomputes the Merkle root from a leaf value and its sibling path
/// (ordered from the leaf upward). The leaf is first hashed with the leaf
/// domain-separation tag, then combined with each sibling in turn.
pub fn merkle_path_root(
    leaf_value: impl Borrow<Digest>,
    path: impl IntoIterator<Item = impl Borrow<Digest>>,
) -> Digest {
    let mut digest = hash_leaf(leaf_value.borrow());
    for sibling in path {
        digest = commutative_keccak256(digest.borrow(), sibling.borrow());
    }
    digest
}
/// Domain-separation tag prepended when hashing leaf values, so that leaf
/// hashes can never collide with (untagged) internal-node hashes.
const LEAF_TAG: &[u8; 8] = b"LEAF_TAG";
/// Hashes a leaf value as keccak256(LEAF_TAG || value), domain-separating
/// leaves from internal nodes.
fn hash_leaf(value: &Digest) -> Digest {
    let mut keccak = Keccak256::new();
    keccak.update(LEAF_TAG);
    keccak.update(value.as_bytes());
    keccak.finalize().0.into()
}
/// Keccak-256 of the two digests with the lexicographically smaller one
/// hashed first, so that the result is independent of argument order.
fn commutative_keccak256(a: &Digest, b: &Digest) -> Digest {
    let (first, second) = if a.as_bytes() < b.as_bytes() {
        (a, b)
    } else {
        (b, a)
    };
    let mut hasher = Keccak256::new();
    hasher.update(first.as_bytes());
    hasher.update(second.as_bytes());
    hasher.finalize().0.into()
}
#[cfg(test)]
mod tests {
    use super::*;
    use hex::FromHex;
    // Helper: assert that merkle_root over `digests` equals `expected_root`.
    fn assert_merkle_root(digests: &[Digest], expected_root: Digest) {
        let root = merkle_root(digests);
        assert_eq!(root, expected_root);
    }
    // Pins the root of a 3-leaf tree of identical digests against a
    // precomputed constant, guarding the exact hash construction.
    #[test]
    fn test_root_manual() {
        let digests = vec![
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
            Digest::from_hex("6a428060b5d51f04583182f2ff1b565f9db661da12ee7bdc003e9ab6d5d91ba9")
                .unwrap(),
        ];
        assert_merkle_root(
            &digests,
            Digest::from_hex("bd792a6858270b233a6b399c1cbc60c5b1046a5b43758b9abc46ba32d23c7352")
                .unwrap(),
        );
    }
    // Checks roots for 1, 3, and 4 leaves against manual reconstructions
    // from hash_leaf and commutative_keccak256.
    #[test]
    fn test_merkle_root() {
        let digests = vec![Digest::from([0u8; 32])];
        assert_merkle_root(&digests, hash_leaf(&digests[0]));
        let digests = vec![
            Digest::from([0u8; 32]),
            Digest::from([1u8; 32]),
            Digest::from([2u8; 32]),
        ];
        assert_merkle_root(
            &digests,
            commutative_keccak256(
                &commutative_keccak256(&hash_leaf(&digests[0]), &hash_leaf(&digests[1])),
                &hash_leaf(&digests[2]),
            ),
        );
        let digests = vec![
            Digest::from([0u8; 32]),
            Digest::from([1u8; 32]),
            Digest::from([2u8; 32]),
            Digest::from([3u8; 32]),
        ];
        assert_merkle_root(
            &digests,
            commutative_keccak256(
                &commutative_keccak256(&hash_leaf(&digests[0]), &hash_leaf(&digests[1])),
                &commutative_keccak256(&hash_leaf(&digests[2]), &hash_leaf(&digests[3])),
            ),
        );
    }
    // For random trees of every size up to 128, every leaf's path must
    // recombine with the leaf to the tree root.
    #[test]
    fn test_consistency() {
        for length in 1..=128 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let root = merkle_root(&digests);
            for i in 0..length {
                let path = merkle_path(&digests, i);
                assert_eq!(merkle_path_root(digests[i], &path), root);
            }
        }
    }
    // Round-trips encode/decode for random MMRs of every size up to 128
    // (including the empty range).
    #[test]
    fn test_encode_decode() {
        for length in 0..=128 {
            let digests: Vec<Digest> = (0..length)
                .map(|_| rand::random::<[u8; 32]>().into())
                .collect();
            let mmr = MerkleMountainRange::from_iter(digests);
            assert_eq!(mmr, MerkleMountainRange::decode(mmr.encode()).unwrap());
        }
    }
}