#[cfg(test)]
use alloc::collections::BTreeSet;
use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
use core::{fmt, ops::Index};
#[cfg(any(test, feature = "arbitrary"))]
use proptest::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
mod node;
#[cfg(any(test, feature = "arbitrary"))]
pub use node::arbitrary;
pub(crate) use node::collect_immediate_placements;
pub use node::{
BasicBlockNode, BasicBlockNodeBuilder, CallNode, CallNodeBuilder, DynNode, DynNodeBuilder,
ExternalNode, ExternalNodeBuilder, JoinNode, JoinNodeBuilder, LoopNode, LoopNodeBuilder,
MastForestContributor, MastNode, MastNodeBuilder, MastNodeExt, OP_BATCH_SIZE, OP_GROUP_SIZE,
OpBatch, SplitNode, SplitNodeBuilder,
};
#[cfg(feature = "serde")]
use crate::serde::{Deserializable, Serializable, SliceReader};
use crate::{
Felt, Word,
advice::AdviceMap,
serde::{ByteWriter, DeserializationError},
utils::{Idx, IndexVec, hash_string_to_word},
};
mod serialization;
pub use serialization::{
AdviceMapView, AdviceValueView, MastForestReadMode, MastForestReadView, MastForestView,
MastForestWireView, MastNodeEntry, MastNodeInfo,
};
mod untrusted;
pub use untrusted::{UntrustedMastForest, UntrustedMastForestReadOptions};
mod merger;
pub(crate) use merger::MastForestMerger;
pub use merger::MastForestRootMap;
mod multi_forest_node_iterator;
pub(crate) use multi_forest_node_iterator::*;
mod node_builder_utils;
pub use node_builder_utils::build_node_with_remapped_ids;
mod sparse;
pub use sparse::{MastForestId, SparseMastForest, SparseMastForestBuilder, VisitKind};
#[cfg(test)]
mod tests;
#[derive(Clone, Debug, Default)]
#[cfg_attr(
all(feature = "arbitrary", test),
miden_test_serde_macros::serde_test(binary_serde(true))
)]
pub struct MastForest {
nodes: IndexVec<MastNodeId, MastNode>,
roots: Vec<MastNodeId>,
advice_map: AdviceMap,
commitment: Word,
}
/// The unvalidated components of a [`MastForest`].
///
/// They are consumed by `MastForest::from_parts` (full validation) and by
/// `MastForest::from_trusted_deserialization_parts` (structural checks only). The commitment is
/// not part of this struct; it is recomputed from `nodes` and `roots` on construction.
pub(crate) struct MastForestParts {
    /// All nodes of the forest, in forest order.
    pub nodes: IndexVec<MastNodeId, MastNode>,
    /// IDs of the procedure-root nodes; each must be below `nodes.len()`.
    pub roots: Vec<MastNodeId>,
    /// Advice data to attach to the forest.
    pub advice_map: AdviceMap,
}
impl MastForest {
    /// Creates an empty forest: no nodes, no procedure roots and an empty advice map.
    ///
    /// The commitment is initialised to the commitment of an empty root set.
    pub fn new() -> Self {
        Self {
            nodes: IndexVec::new(),
            roots: Vec::new(),
            advice_map: AdviceMap::default(),
            commitment: empty_mast_forest_commitment(),
        }
    }

    /// Builds a fully validated forest from its raw components.
    ///
    /// This is a thin public wrapper around `MastForest::from_parts`; it performs the same checks
    /// and returns the same errors.
    #[doc(hidden)]
    pub fn from_raw_parts(
        nodes: IndexVec<MastNodeId, MastNode>,
        roots: Vec<MastNodeId>,
        advice_map: AdviceMap,
    ) -> Result<Self, MastForestError> {
        Self::from_parts(MastForestParts { nodes, roots, advice_map })
    }

    /// Builds a forest from `parts` and validates it in full.
    ///
    /// # Errors
    /// - [`MastForestError::TooManyNodes`] if there are more than `MAX_NODES` nodes.
    /// - [`MastForestError::NodeIdOverflow`] if a root ID does not refer to a node.
    /// - Any error from [`MastForest::validate`] (basic-block batch invariants).
    /// - [`MastForestError::ForwardReference`] or [`MastForestError::HashMismatch`] if a node
    ///   references a child that does not precede it, or if its stored digest differs from the
    ///   recomputed one.
    pub(crate) fn from_parts(parts: MastForestParts) -> Result<Self, MastForestError> {
        let forest = Self::assemble_unvalidated(parts)?;
        forest.validate()?;
        forest.validate_node_hashes()?;
        Ok(forest)
    }

    /// Builds a forest from parts produced by trusted deserialization.
    ///
    /// Only the structural checks (node count and root-ID bounds) are performed. Batch invariants
    /// and node digests are *not* re-verified, so the caller must trust the origin of `parts`.
    ///
    /// # Errors
    /// [`MastForestError::TooManyNodes`] or [`MastForestError::NodeIdOverflow`].
    pub(in crate::mast) fn from_trusted_deserialization_parts(
        parts: MastForestParts,
    ) -> Result<Self, MastForestError> {
        Self::assemble_unvalidated(parts)
    }

    /// Shared constructor core: checks the node count and root-ID bounds, then computes the
    /// commitment over the roots.
    ///
    /// The root bounds check must come before the commitment computation, because the latter
    /// indexes `nodes` by every root ID.
    ///
    /// # Errors
    /// - [`MastForestError::TooManyNodes`] if `parts.nodes` holds more than `MAX_NODES` nodes.
    /// - [`MastForestError::NodeIdOverflow`] for the first root ID that is not below the node
    ///   count.
    fn assemble_unvalidated(parts: MastForestParts) -> Result<Self, MastForestError> {
        let node_count = parts.nodes.len();
        if node_count > Self::MAX_NODES {
            return Err(MastForestError::TooManyNodes);
        }
        if let Some(&bad_root) =
            parts.roots.iter().find(|root_id| root_id.to_usize() >= node_count)
        {
            return Err(MastForestError::NodeIdOverflow(bad_root, node_count));
        }
        Ok(Self {
            commitment: compute_nodes_commitment(&parts.nodes, &parts.roots),
            nodes: parts.nodes,
            roots: parts.roots,
            advice_map: parts.advice_map,
        })
    }
}
impl PartialEq for MastForest {
    /// Two forests are equal when their nodes, procedure roots and advice maps all match.
    ///
    /// The cached commitment is deliberately left out of the comparison; it is recomputed from
    /// the nodes and roots whenever those change.
    fn eq(&self, other: &Self) -> bool {
        let same_nodes = self.nodes == other.nodes;
        let same_roots = || self.roots == other.roots;
        let same_advice = || self.advice_map == other.advice_map;
        same_nodes && same_roots() && same_advice()
    }
}

// Marker impl: the field-wise comparison above is treated as a full equivalence relation.
impl Eq for MastForest {}
impl MastForest {
    /// Upper bound on the number of nodes a single forest may hold (2^30 - 1).
    const MAX_NODES: usize = (1 << 30) - 1;

    /// Registers `new_root_id` as a procedure root. If it is not already a root, the cached
    /// commitment is refreshed.
    ///
    /// # Panics
    /// Panics if `new_root_id` is not below the number of nodes in the forest.
    fn mark_root(&mut self, new_root_id: MastNodeId) {
        assert!(new_root_id.to_usize() < self.nodes.len());
        if self.roots.contains(&new_root_id) {
            return;
        }
        self.roots.push(new_root_id);
        // The commitment covers every root, so it has to be rebuilt once the set grows.
        self.commitment = self.compute_nodes_commitment(&self.roots);
    }

    /// Public wrapper over `mark_root`, available to tests and the `arbitrary` feature.
    #[cfg(any(test, feature = "arbitrary"))]
    pub fn make_root(&mut self, new_root_id: MastNodeId) {
        self.mark_root(new_root_id);
    }

    /// Test helper: removes the given nodes and renumbers the survivors densely.
    ///
    /// Returns the mapping from each surviving node's old ID to its new ID. Roots whose nodes are
    /// removed are dropped.
    ///
    /// # Panics
    /// Panics if a retained node still references a node slated for removal.
    #[cfg(test)]
    pub fn remove_nodes(
        &mut self,
        nodes_to_remove: &BTreeSet<MastNodeId>,
    ) -> BTreeMap<MastNodeId, MastNodeId> {
        if nodes_to_remove.is_empty() {
            return BTreeMap::new();
        }
        self.assert_nodes_to_remove_are_orphaned(nodes_to_remove);

        let previous_nodes = core::mem::replace(&mut self.nodes, IndexVec::new());
        let previous_roots = core::mem::take(&mut self.roots);

        let (kept_nodes, remapping) = remove_nodes(previous_nodes.into_inner(), nodes_to_remove);
        self.remap_and_add_nodes(kept_nodes, &remapping);
        self.remap_and_add_roots(previous_roots, &remapping);
        self.commitment = self.compute_nodes_commitment(&self.roots);

        remapping
    }

    /// Returns a compacted copy of this forest, obtained by merging it with itself, together
    /// with the map from old root IDs to new ones.
    ///
    /// # Panics
    /// Panics if the self-merge fails, which is not expected to happen.
    pub fn compact(self) -> (MastForest, MastForestRootMap) {
        let only_self = [&self];
        MastForest::merge(only_self)
            .expect("Failed to compact MastForest: this should never happen during self-merge")
    }

    /// Merges `forests` into a single forest by delegating to `MastForestMerger::merge`.
    ///
    /// # Errors
    /// Propagates any [`MastForestError`] reported by the merger.
    pub fn merge<'forest>(
        forests: impl IntoIterator<Item = &'forest MastForest>,
    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
        MastForestMerger::merge(forests)
    }
}
impl MastForest {
    /// Test helper: panics if any node that is *not* being removed still lists one of
    /// `nodes_to_remove` as a child.
    #[cfg(test)]
    fn assert_nodes_to_remove_are_orphaned(&self, nodes_to_remove: &BTreeSet<MastNodeId>) {
        let retained = self
            .nodes
            .iter()
            .enumerate()
            .map(|(idx, node)| {
                (MastNodeId::new_unchecked(idx.try_into().expect("too many nodes")), node)
            })
            .filter(|(id, _)| !nodes_to_remove.contains(id));

        for (node_id, node) in retained {
            node.for_each_child(|child_id| {
                assert!(
                    !nodes_to_remove.contains(&child_id),
                    "cannot remove node {child_id:?}; retained node {node_id:?} references it"
                );
            });
        }
    }

    /// Test helper: re-inserts `nodes_to_add` into this (currently empty) forest, rewriting child
    /// references through `id_remappings`.
    ///
    /// # Panics
    /// Panics if the forest is not empty or if adding a node fails.
    #[cfg(test)]
    fn remap_and_add_nodes(
        &mut self,
        nodes_to_add: Vec<MastNode>,
        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
    ) {
        assert!(self.nodes.is_empty());

        // Create a builder for every node before inserting any of them, so that `to_builder`
        // sees the forest in the same state for every node.
        let mut builders = Vec::with_capacity(nodes_to_add.len());
        for node in nodes_to_add {
            builders.push(node.to_builder(self));
        }

        for builder in builders {
            builder.remap_children(id_remappings).add_to_forest(self).unwrap();
        }
    }

    /// Test helper: re-registers the old roots under their remapped IDs. Roots with no entry in
    /// `id_remappings` (i.e. removed nodes) are dropped.
    ///
    /// # Panics
    /// Panics if the forest already has roots.
    #[cfg(test)]
    fn remap_and_add_roots(
        &mut self,
        old_root_ids: Vec<MastNodeId>,
        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
    ) {
        assert!(self.roots.is_empty());
        old_root_ids
            .iter()
            .filter_map(|old_id| id_remappings.get(old_id).copied())
            .for_each(|new_root_id| self.mark_root(new_root_id));
    }
}
/// Test helper: drops every node whose old ID is in `nodes_to_remove` and assigns the survivors
/// new, dense IDs in their original order.
///
/// Returns the surviving nodes and the map from each survivor's old ID to its new ID.
///
/// # Panics
/// Panics if `mast_nodes.len()` is not below `u32::MAX`.
#[cfg(test)]
fn remove_nodes(
    mast_nodes: Vec<MastNode>,
    nodes_to_remove: &BTreeSet<MastNodeId>,
) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
    assert!(mast_nodes.len() < u32::MAX as usize);

    let mut kept = Vec::with_capacity(mast_nodes.len());
    let mut remapping = BTreeMap::new();

    let survivors = mast_nodes
        .into_iter()
        .enumerate()
        .map(|(idx, node)| (MastNodeId(idx as u32), node))
        .filter(|(old_id, _)| !nodes_to_remove.contains(old_id));

    for (old_id, node) in survivors {
        // The next free slot in `kept` becomes this node's new ID.
        remapping.insert(old_id, MastNodeId(kept.len() as u32));
        kept.push(node);
    }

    (kept, remapping)
}
/// Returns the commitment of a forest with no procedure roots: the Poseidon2 `merge_many` of an
/// empty digest list.
fn empty_mast_forest_commitment() -> Word {
    use miden_crypto::hash::poseidon2::Poseidon2;

    let no_digests: [Word; 0] = [];
    Poseidon2::merge_many(&no_digests)
}
/// Computes a commitment over the digests of the nodes identified by `node_ids`.
///
/// The digests are sorted before being merged with Poseidon2, so the result does not depend on
/// the order of `node_ids`. Every ID must index into `nodes`.
fn compute_nodes_commitment(
    nodes: &IndexVec<MastNodeId, MastNode>,
    node_ids: &[MastNodeId],
) -> Word {
    use miden_crypto::hash::poseidon2::Poseidon2;

    let mut digests = Vec::with_capacity(node_ids.len());
    for &id in node_ids {
        digests.push(nodes[id].digest());
    }
    digests.sort_unstable();
    Poseidon2::merge_many(&digests)
}
impl MastForest {
    /// Returns the node with the given ID, or `None` if `node_id` is not in this forest.
    #[inline(always)]
    pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
        self.nodes.get(node_id)
    }

    /// Returns the first procedure root (in root order) whose digest equals `digest`.
    #[inline(always)]
    pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
        self.roots.iter().copied().find(|&root_id| self[root_id].digest() == digest)
    }

    /// Returns `true` if `node_id` is registered as a procedure root.
    pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
        self.roots.iter().any(|&root_id| root_id == node_id)
    }

    /// Returns `true` if `node_id` is a procedure root and its digest equals `digest`.
    pub fn is_procedure_root_with_exact_digest(&self, node_id: MastNodeId, digest: Word) -> bool {
        if !self.is_procedure_root(node_id) {
            return false;
        }
        self[node_id].digest() == digest
    }

    /// Iterates over the digests of all procedure roots, in root order.
    pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
        self.roots.iter().map(move |&root_id| self[root_id].digest())
    }

    /// Iterates over the digests of the procedure roots that are not external nodes.
    pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
        self.roots
            .iter()
            .map(move |&root_id| &self[root_id])
            .filter(|node| !node.is_external())
            .map(|node| node.digest())
    }

    /// Returns the IDs of all procedure roots.
    pub fn procedure_roots(&self) -> &[MastNodeId] {
        self.roots.as_slice()
    }

    /// Returns the number of procedure roots.
    ///
    /// # Panics
    /// Panics if the root count does not fit in a `u32`.
    pub fn num_procedures(&self) -> u32 {
        u32::try_from(self.roots.len()).expect("MAST forest contains more than 2^32 procedures.")
    }

    /// Computes the commitment over the digests of the given nodes (sorted, then merged).
    ///
    /// Every ID must refer to a node of this forest; out-of-range IDs presumably panic via
    /// `IndexVec` indexing.
    pub fn compute_nodes_commitment<'a>(
        &self,
        node_ids: impl IntoIterator<Item = &'a MastNodeId>,
    ) -> Word {
        let ids: Vec<MastNodeId> = node_ids.into_iter().copied().collect();
        compute_nodes_commitment(&self.nodes, &ids)
    }

    /// Returns the cached commitment over the procedure roots.
    pub fn commitment(&self) -> Word {
        self.commitment
    }

    /// Returns the number of nodes in the forest.
    pub fn num_nodes(&self) -> u32 {
        self.nodes.len() as u32
    }

    /// Returns all nodes of the forest as a slice, in forest order.
    pub fn nodes(&self) -> &[MastNode] {
        self.nodes.as_slice()
    }

    /// Returns the advice map bundled with the forest.
    pub fn advice_map(&self) -> &AdviceMap {
        &self.advice_map
    }

    /// Consumes the forest and returns it with `advice_map` merged into its advice map via
    /// `AdviceMap::extend`.
    pub fn with_advice_map(self, advice_map: AdviceMap) -> Self {
        let mut forest = self;
        forest.advice_map.extend(advice_map);
        forest
    }

    /// Test-only mutable access to the advice map.
    #[cfg(test)]
    pub(crate) fn advice_map_mut(&mut self) -> &mut AdviceMap {
        &mut self.advice_map
    }

    /// Writes the forest to `target` in the "hashless" encoding produced by
    /// `serialization::write_hashless_into`.
    pub fn write_hashless<W: ByteWriter>(&self, target: &mut W) {
        serialization::write_hashless_into(self, target);
    }
}
impl MastForest {
fn validate_basic_block_invariants(&self) -> Result<(), MastForestError> {
for (node_id_idx, node) in self.nodes.iter().enumerate() {
let node_id =
MastNodeId::new_unchecked(node_id_idx.try_into().expect("too many nodes"));
if let MastNode::Block(basic_block) = node {
basic_block.validate_batch_invariants().map_err(|error_msg| {
MastForestError::InvalidBatchPadding(node_id, error_msg)
})?;
}
}
Ok(())
}
pub fn validate(&self) -> Result<(), MastForestError> {
self.validate_basic_block_invariants()?;
Ok(())
}
fn validate_node_hashes(&self) -> Result<(), MastForestError> {
let computed_hashes = self.compute_node_hashes()?;
for (node_idx, (node, computed_digest)) in
self.nodes.iter().zip(computed_hashes).enumerate()
{
let expected_digest = node.digest();
if expected_digest != computed_digest {
return Err(MastForestError::HashMismatch {
node_id: MastNodeId::new_unchecked(node_idx as u32),
expected: expected_digest,
computed: computed_digest,
});
}
}
Ok(())
}
fn compute_node_hashes(&self) -> Result<Vec<Word>, MastForestError> {
use crate::chiplets::hasher;
fn check_no_forward_ref(
node_id: MastNodeId,
child_id: MastNodeId,
) -> Result<(), MastForestError> {
if child_id.0 >= node_id.0 {
return Err(MastForestError::ForwardReference(node_id, child_id));
}
Ok(())
}
let mut computed_hashes = Vec::with_capacity(self.nodes.len());
for (node_idx, node) in self.nodes.iter().enumerate() {
let node_id = MastNodeId::new_unchecked(node_idx as u32);
let computed_digest = match node {
MastNode::Block(block) => {
let op_groups: Vec<Felt> =
block.op_batches().iter().flat_map(|batch| *batch.groups()).collect();
hasher::hash_elements(&op_groups)
},
MastNode::Join(join) => {
let left_id = join.first();
let right_id = join.second();
check_no_forward_ref(node_id, left_id)?;
check_no_forward_ref(node_id, right_id)?;
let left_digest = computed_hashes[left_id.0 as usize];
let right_digest = computed_hashes[right_id.0 as usize];
hasher::merge_in_domain(&[left_digest, right_digest], JoinNode::DOMAIN)
},
MastNode::Split(split) => {
let true_id = split.on_true();
let false_id = split.on_false();
check_no_forward_ref(node_id, true_id)?;
check_no_forward_ref(node_id, false_id)?;
let true_digest = computed_hashes[true_id.0 as usize];
let false_digest = computed_hashes[false_id.0 as usize];
hasher::merge_in_domain(&[true_digest, false_digest], SplitNode::DOMAIN)
},
MastNode::Loop(loop_node) => {
let body_id = loop_node.body();
check_no_forward_ref(node_id, body_id)?;
let body_digest = computed_hashes[body_id.0 as usize];
hasher::merge_in_domain(&[body_digest, Word::default()], LoopNode::DOMAIN)
},
MastNode::Call(call) => {
let callee_id = call.callee();
check_no_forward_ref(node_id, callee_id)?;
let callee_digest = computed_hashes[callee_id.0 as usize];
let domain = if call.is_syscall() {
CallNode::SYSCALL_DOMAIN
} else {
CallNode::CALL_DOMAIN
};
hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
},
MastNode::Dyn(dyn_node) => {
if dyn_node.is_dyncall() {
DynNode::DYNCALL_DEFAULT_DIGEST
} else {
DynNode::DYN_DEFAULT_DIGEST
}
},
MastNode::External(_) => {
node.digest()
},
};
computed_hashes.push(computed_digest);
}
Ok(computed_hashes)
}
}
impl Index<MastNodeId> for MastForest {
    type Output = MastNode;

    /// Returns the node stored under `node_id`.
    ///
    /// Indexing is delegated to the underlying `IndexVec`; use `get_node_by_id` for a lookup
    /// that returns `Option`.
    #[inline(always)]
    fn index(&self, node_id: MastNodeId) -> &MastNode {
        let nodes = &self.nodes;
        &nodes[node_id]
    }
}
/// Minimal read-only interface over a MAST forest: node lookup, digest lookup, procedure-root
/// search and advice-map access.
///
/// Implemented for [`MastForest`] and for any `Arc` wrapping an implementor.
pub trait ExecutableMastForest {
    /// Returns the node with the given ID, or `None` if there is no such node.
    fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode>;
    /// Returns the digest of the node with the given ID, or `None` if there is no such node.
    fn get_digest_by_id(&self, node_id: MastNodeId) -> Option<Word>;
    /// Returns the ID of a procedure root whose digest equals `digest`, if any.
    fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId>;
    /// Returns the advice map associated with the forest.
    fn advice_map(&self) -> &AdviceMap;
}
// Each method calls the inherent method of the same name; `Self::name` paths resolve to the
// inherent item first, so none of these calls recurse into the trait.
impl ExecutableMastForest for MastForest {
    #[inline(always)]
    fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
        Self::get_node_by_id(self, node_id)
    }

    #[inline(always)]
    fn get_digest_by_id(&self, node_id: MastNodeId) -> Option<Word> {
        Self::get_node_by_id(self, node_id).map(|node| node.digest())
    }

    #[inline(always)]
    fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
        Self::find_procedure_root(self, digest)
    }

    #[inline(always)]
    fn advice_map(&self) -> &AdviceMap {
        Self::advice_map(self)
    }
}
/// Lets an `Arc` of any node-indexable forest be indexed directly by [`MastNodeId`].
impl<T> Index<MastNodeId> for Arc<T>
where
    T: Index<MastNodeId, Output = MastNode> + ?Sized,
{
    type Output = MastNode;

    /// Forwards the lookup to the value shared by the `Arc`.
    #[inline(always)]
    fn index(&self, node_id: MastNodeId) -> &MastNode {
        let inner: &T = self;
        <T as Index<MastNodeId>>::index(inner, node_id)
    }
}
/// Forwards every [`ExecutableMastForest`] method through the `Arc` to the shared value.
impl<T: ExecutableMastForest + ?Sized> ExecutableMastForest for Arc<T> {
    #[inline(always)]
    fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
        (**self).get_node_by_id(node_id)
    }

    #[inline(always)]
    fn get_digest_by_id(&self, node_id: MastNodeId) -> Option<Word> {
        (**self).get_digest_by_id(node_id)
    }

    #[inline(always)]
    fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
        (**self).find_procedure_root(digest)
    }

    #[inline(always)]
    fn advice_map(&self) -> &AdviceMap {
        (**self).advice_map()
    }
}
/// Identifier of a node within a [`MastForest`]: the node's position in the forest's node
/// storage, held as a `u32`.
///
/// The constructors `new_unchecked` and `From<u32>` perform no bounds check; `from_u32_safe`
/// validates the ID against a forest's node count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct MastNodeId(u32);
/// Mapping from old node IDs to new node IDs, as consumed by [`MastNodeId::remap`].
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
impl MastNodeId {
    /// Converts `value` into a node ID, checking that it refers to a node of `mast_forest`.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `value` is not below the number of nodes
    /// in `mast_forest`.
    pub fn from_u32_safe(
        value: u32,
        mast_forest: &MastForest,
    ) -> Result<Self, DeserializationError> {
        let node_count = mast_forest.nodes.len();
        Self::from_u32_with_node_count(value, node_count)
    }

    /// Wraps `value` as a node ID without any bounds check.
    pub fn new_unchecked(value: u32) -> Self {
        Self(value)
    }

    /// Converts `id` into a node ID, checking that it is below `node_count`.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `id >= node_count`.
    pub(super) fn from_u32_with_node_count(
        id: u32,
        node_count: usize,
    ) -> Result<Self, DeserializationError> {
        if (id as usize) >= node_count {
            return Err(DeserializationError::InvalidValue(format!(
                "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
            )));
        }
        Ok(Self(id))
    }

    /// Returns the ID this one maps to in `remapping`, or the ID itself if it has no entry.
    pub fn remap(&self, remapping: &Remapping) -> Self {
        match remapping.get(self) {
            Some(&new_id) => new_id,
            None => *self,
        }
    }
}
impl From<u32> for MastNodeId {
fn from(value: u32) -> Self {
MastNodeId::new_unchecked(value)
}
}
// Marks `MastNodeId` as an index type for `IndexVec`, as used by `MastForest`'s node storage.
impl Idx for MastNodeId {}
impl From<MastNodeId> for u32 {
fn from(value: MastNodeId) -> Self {
value.0
}
}
impl fmt::Display for MastNodeId {
    /// Formats the ID as `MastNodeId(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let index = self.0;
        write!(f, "MastNodeId({index})")
    }
}
#[cfg(any(test, feature = "arbitrary"))]
impl Arbitrary for MastNodeId {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Generates IDs from arbitrary `u32` values, with no bounds relative to any forest.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        any::<u32>().prop_map(MastNodeId::new_unchecked).boxed()
    }
}
/// Iterates over the IDs of the nodes reachable from a root node, including the root itself.
///
/// Leaf nodes (nodes without children) are yielded as they are reached. Once all leaves have been
/// yielded, the internal nodes are yielded in reverse discovery order, so every node comes after
/// all of its descendants. No deduplication is done: a node reachable along several paths is
/// yielded once per path.
pub struct SubtreeIterator<'a> {
    forest: &'a MastForest,
    /// Internal nodes whose children have already been queued; drained once `unvisited` is empty.
    discovered: Vec<MastNodeId>,
    /// Stack of nodes still to be examined.
    unvisited: Vec<MastNodeId>,
}

impl<'a> SubtreeIterator<'a> {
    /// Starts an iteration over the subtree of `forest` rooted at `root`.
    pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
        Self {
            forest,
            discovered: Vec::new(),
            unvisited: vec![*root],
        }
    }
}

impl Iterator for SubtreeIterator<'_> {
    type Item = MastNodeId;

    fn next(&mut self) -> Option<MastNodeId> {
        loop {
            match self.unvisited.pop() {
                // Nothing left to expand: emit the internal nodes, deepest-discovered first.
                None => return self.discovered.pop(),
                Some(id) => {
                    let node = &self.forest[id];
                    if !node.has_children() {
                        return Some(id);
                    }
                    self.discovered.push(id);
                    node.append_children_to(&mut self.unvisited);
                },
            }
        }
    }
}
/// Derives a field-element error code from `msg`: the string is hashed to a word and the first
/// element of that word is returned.
pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
    let text: &str = msg.as_ref();
    let word = hash_string_to_word(text);
    word[0]
}
/// Errors produced while building, validating, merging or deserializing a [`MastForest`].
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum MastForestError {
    /// The forest holds more than `MastForest::MAX_NODES` nodes.
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node ID (first field) is not below the forest's node count (second field).
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A basic block was requested with no operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// Two forests being merged define the same advice-map key.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(Word),
    /// Deserialization needed a digest that was not available.
    #[error("digest is required for deserialization")]
    DigestRequiredForDeserialization,
    /// A basic block's batches failed `validate_batch_invariants`; carries the node ID and the
    /// validator's message.
    #[error("invalid batch in basic block node {0:?}: {1}")]
    InvalidBatchPadding(MastNodeId, String),
    /// A node (first field) references a child (second field) whose index is not strictly
    /// smaller than its own.
    #[error(
        "node {0:?} references child {1:?} which comes after it in the forest (forward reference)"
    )]
    ForwardReference(MastNodeId, MastNodeId),
    /// A node's stored digest differs from the digest recomputed from its contents.
    #[error("hash mismatch for node {node_id:?}: expected {expected:?}, computed {computed:?}")]
    HashMismatch {
        node_id: MastNodeId,
        expected: Word,
        computed: Word,
    },
    /// Wraps a lower-level deserialization error.
    #[error("deserialization failed: {0}")]
    Deserialization(DeserializationError),
}
#[cfg(feature = "serde")]
impl Serialize for MastForest {
    /// Serializes the forest as a single opaque byte string holding its binary encoding.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(&Serializable::to_bytes(self))
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for MastForest {
    /// Deserializes a forest from the byte string written by the `Serialize` impl.
    ///
    /// `Serialize` emits the forest with `serialize_bytes`, so this side requests a byte buffer
    /// (`deserialize_byte_buf`) instead of going through `Vec<u8>`'s sequence-based impl. The
    /// visitor accepts both borrowed and owned bytes and a sequence of `u8`. Formats that keep
    /// byte strings distinct from sequences (e.g. MessagePack, CBOR) therefore round-trip, and
    /// formats that encode bytes as a sequence (e.g. JSON arrays) are still accepted.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Collects the serialized forest bytes from any byte-like serde representation.
        struct MastForestBytesVisitor;

        impl<'v> serde::de::Visitor<'v> for MastForestBytesVisitor {
            type Value = Vec<u8>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("a byte array containing a serialized MAST forest")
            }

            fn visit_bytes<E: serde::de::Error>(self, bytes: &[u8]) -> Result<Self::Value, E> {
                Ok(bytes.to_vec())
            }

            fn visit_byte_buf<E: serde::de::Error>(
                self,
                bytes: Vec<u8>,
            ) -> Result<Self::Value, E> {
                Ok(bytes)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'v>,
            {
                // Cap the pre-allocation: the size hint comes from untrusted input.
                const MAX_PREALLOC: usize = 4096;
                let mut bytes =
                    Vec::with_capacity(seq.size_hint().map_or(0, |n| n.min(MAX_PREALLOC)));
                while let Some(byte) = seq.next_element::<u8>()? {
                    bytes.push(byte);
                }
                Ok(bytes)
            }
        }

        let bytes = deserializer.deserialize_byte_buf(MastForestBytesVisitor)?;
        let mut slice_reader = SliceReader::new(&bytes);
        Deserializable::read_from(&mut slice_reader).map_err(serde::de::Error::custom)
    }
}