use alloc::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
vec::Vec,
};
use core::{
fmt,
ops::{Index, IndexMut},
};
pub use miden_utils_indexing::{IndexVec, IndexedVecError};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
mod node;
#[cfg(any(test, feature = "arbitrary"))]
pub use node::arbitrary;
pub use node::{
BasicBlockNode, CallNode, DecoratedOpLink, DecoratorOpLinkIterator, DynNode, ExternalNode,
JoinNode, LoopNode, MastNode, MastNodeErrorContext, MastNodeExt, OP_BATCH_SIZE, OP_GROUP_SIZE,
OpBatch, OperationOrDecorator, SplitNode,
};
use crate::{
AdviceMap, Decorator, DecoratorList, Felt, Idx, LexicographicWord, Operation, Word,
crypto::hash::Hasher,
utils::{ByteWriter, DeserializationError, Serializable, hash_string_to_word},
};
mod serialization;
mod merger;
pub(crate) use merger::MastForestMerger;
pub use merger::MastForestRootMap;
mod multi_forest_node_iterator;
pub(crate) use multi_forest_node_iterator::*;
mod node_fingerprint;
pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
#[cfg(test)]
mod tests;
/// A collection of MAST nodes together with the procedure roots, decorators, advice data and
/// registered error messages that accompany them.
///
/// Nodes refer to other nodes by [`MastNodeId`] and to decorators by [`DecoratorId`]; both are
/// indices into the tables stored here.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct MastForest {
    /// All nodes of the forest, addressed by [`MastNodeId`].
    nodes: IndexVec<MastNodeId, MastNode>,
    /// Ids of the nodes that are procedure roots (`make_root` keeps this free of duplicates).
    roots: Vec<MastNodeId>,
    /// All decorators of the forest, addressed by [`DecoratorId`].
    decorators: IndexVec<DecoratorId, Decorator>,
    /// Advice data associated with the forest.
    advice_map: AdviceMap,
    /// Registered error messages, keyed by the `u64` value of their error code
    /// (see `register_error` / `resolve_error_message`).
    error_codes: BTreeMap<u64, Arc<str>>,
}
impl MastForest {
    /// Creates an empty MAST forest.
    ///
    /// The forest has no nodes, no procedure roots, no decorators, an empty advice map and no
    /// registered error messages.
    pub fn new() -> Self {
        let nodes = IndexVec::new();
        let decorators = IndexVec::new();

        Self {
            nodes,
            roots: Vec::new(),
            decorators,
            advice_map: AdviceMap::default(),
            error_codes: BTreeMap::new(),
        }
    }
}
impl MastForest {
    /// The maximum number of nodes that can be stored in a single MAST forest.
    const MAX_NODES: usize = (1 << 30) - 1;

    /// Adds a decorator to the forest and returns its [`DecoratorId`].
    ///
    /// # Errors
    /// Returns [`MastForestError::TooManyDecorators`] if the decorator table rejects the push.
    pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
        self.decorators.push(decorator).map_err(|_| MastForestError::TooManyDecorators)
    }

    /// Adds a node to the forest and returns its [`MastNodeId`].
    ///
    /// Nodes are not deduplicated: adding the same node twice yields two distinct ids.
    ///
    /// # Errors
    /// Returns [`MastForestError::TooManyNodes`] if the forest already holds `MAX_NODES`
    /// (2^30 - 1) nodes, or if the node table rejects the push.
    pub fn add_node(&mut self, node: impl Into<MastNode>) -> Result<MastNodeId, MastForestError> {
        // `TooManyNodes` advertises `MAX_NODES` as the limit, but nothing here indicates that
        // `IndexVec::push` enforces that (much smaller than `u32::MAX`) bound, so check it
        // explicitly.
        if self.nodes.len() >= Self::MAX_NODES {
            return Err(MastForestError::TooManyNodes);
        }
        self.nodes.push(node.into()).map_err(|_| MastForestError::TooManyNodes)
    }

    /// Builds a basic block from `operations` and `decorators` and adds it to the forest.
    ///
    /// # Errors
    /// Propagates any error from [`BasicBlockNode::new`] (presumably including
    /// [`MastForestError::EmptyBasicBlock`] for an empty operation list), and any error from
    /// [`Self::add_node`].
    pub fn add_block(
        &mut self,
        operations: Vec<Operation>,
        decorators: DecoratorList,
    ) -> Result<MastNodeId, MastForestError> {
        let block = BasicBlockNode::new(operations, decorators)?;
        self.add_node(block)
    }

    /// Adds a join node with children `left_child` and `right_child`.
    ///
    /// # Errors
    /// Propagates errors from [`JoinNode::new`], which receives this forest (presumably to
    /// validate the child ids), and from [`Self::add_node`].
    pub fn add_join(
        &mut self,
        left_child: MastNodeId,
        right_child: MastNodeId,
    ) -> Result<MastNodeId, MastForestError> {
        let join = JoinNode::new([left_child, right_child], self)?;
        self.add_node(join)
    }

    /// Adds a split node whose two branches are `if_branch` and `else_branch`.
    ///
    /// # Errors
    /// Propagates errors from [`SplitNode::new`] and from [`Self::add_node`].
    pub fn add_split(
        &mut self,
        if_branch: MastNodeId,
        else_branch: MastNodeId,
    ) -> Result<MastNodeId, MastForestError> {
        let split = SplitNode::new([if_branch, else_branch], self)?;
        self.add_node(split)
    }

    /// Adds a loop node whose body is `body`.
    ///
    /// # Errors
    /// Propagates errors from [`LoopNode::new`] and from [`Self::add_node`].
    pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
        let loop_node = LoopNode::new(body, self)?;
        self.add_node(loop_node)
    }

    /// Adds a call node targeting `callee`.
    ///
    /// # Errors
    /// Propagates errors from [`CallNode::new`] and from [`Self::add_node`].
    pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
        let call = CallNode::new(callee, self)?;
        self.add_node(call)
    }

    /// Adds a syscall node (built with [`CallNode::new_syscall`]) targeting `callee`.
    ///
    /// # Errors
    /// Propagates errors from [`CallNode::new_syscall`] and from [`Self::add_node`].
    pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
        let syscall = CallNode::new_syscall(callee, self)?;
        self.add_node(syscall)
    }

    /// Adds a node built with [`DynNode::new_dyn`].
    pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
        self.add_node(DynNode::new_dyn())
    }

    /// Adds a node built with [`DynNode::new_dyncall`].
    pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
        self.add_node(DynNode::new_dyncall())
    }

    /// Adds an [`ExternalNode`] built from the MAST root digest `mast_root`.
    pub fn add_external(&mut self, mast_root: Word) -> Result<MastNodeId, MastForestError> {
        self.add_node(ExternalNode::new(mast_root))
    }

    /// Marks `new_root_id` as a procedure root. Does nothing if it already is one.
    ///
    /// # Panics
    /// Panics if `new_root_id` is not the id of a node in this forest.
    pub fn make_root(&mut self, new_root_id: MastNodeId) {
        assert!(new_root_id.to_usize() < self.nodes.len());

        if !self.roots.contains(&new_root_id) {
            self.roots.push(new_root_id);
        }
    }

    /// Removes `nodes_to_remove` from the forest and compacts the ids of the remaining nodes.
    ///
    /// Returns the mapping from old to new id for every retained node. An empty mapping is
    /// returned, and the forest is left untouched, when `nodes_to_remove` is empty.
    ///
    /// The caller must ensure that no retained node, and no procedure root, refers to a removed
    /// node. Ids without an entry in the mapping are kept unchanged (see [`MastNodeId::remap`]),
    /// so such references would end up pointing at an unrelated node or out of bounds.
    pub fn remove_nodes(
        &mut self,
        nodes_to_remove: &BTreeSet<MastNodeId>,
    ) -> BTreeMap<MastNodeId, MastNodeId> {
        if nodes_to_remove.is_empty() {
            return BTreeMap::new();
        }

        // Detach the current nodes and roots; both are rebuilt below with compacted ids.
        let old_nodes = core::mem::take(&mut self.nodes);
        let old_root_ids = core::mem::take(&mut self.roots);
        let (retained_nodes, id_remappings) = remove_nodes(old_nodes.into_inner(), nodes_to_remove);

        self.remap_and_add_nodes(retained_nodes, &id_remappings);
        self.remap_and_add_roots(old_root_ids, &id_remappings);

        id_remappings
    }

    /// Appends `decorator_ids` to the before-enter decorators of node `node_id`.
    ///
    /// Indexing by `node_id` is expected to panic if the id is out of range.
    pub fn append_before_enter(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
        self[node_id].append_before_enter(decorator_ids)
    }

    /// Appends `decorator_ids` to the after-exit decorators of node `node_id`.
    ///
    /// Indexing by `node_id` is expected to panic if the id is out of range.
    pub fn append_after_exit(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
        self[node_id].append_after_exit(decorator_ids)
    }

    /// Removes the decorators of every node and empties the forest's decorator table.
    pub fn strip_decorators(&mut self) {
        for node in self.nodes.iter_mut() {
            node.remove_decorators();
        }

        self.decorators = IndexVec::new();
    }

    /// Merges `forests` into a single forest, returning it together with a
    /// [`MastForestRootMap`].
    ///
    /// # Errors
    /// Propagates any error reported by `MastForestMerger::merge`.
    pub fn merge<'forest>(
        forests: impl IntoIterator<Item = &'forest MastForest>,
    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
        MastForestMerger::merge(forests)
    }

    /// Test helper: builds a basic block from `operations` and raw `(index, decorator)` pairs
    /// via `BasicBlockNode::new_with_raw_decorators`, then adds it to the forest.
    #[cfg(test)]
    pub fn add_block_with_raw_decorators(
        &mut self,
        operations: Vec<Operation>,
        decorators: Vec<(usize, Decorator)>,
    ) -> Result<MastNodeId, MastForestError> {
        let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, self)?;
        self.add_node(block)
    }
}
impl MastForest {
    /// Rebuilds the (empty) node table from `nodes_to_add`, rewriting child references through
    /// `id_remappings`.
    ///
    /// Join, split, loop and call/syscall nodes are re-created through the matching `add_*`
    /// constructor with remapped children. The original node's before-enter and after-exit
    /// decorators are then copied onto the rebuilt node. Basic block, dyn and external nodes
    /// contain no child ids and are re-added unchanged. Child ids without an entry in
    /// `id_remappings` are kept as-is (see [`MastNodeId::remap`]).
    ///
    /// # Panics
    /// Panics if the node table is not empty on entry, or if re-adding a node fails, e.g. when
    /// a node refers to a child that has not been added yet.
    fn remap_and_add_nodes(
        &mut self,
        nodes_to_add: Vec<MastNode>,
        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
    ) {
        assert!(self.nodes.is_empty());

        for live_node in nodes_to_add {
            let rebuilt_id = match &live_node {
                MastNode::Join(join_node) => self.add_join(
                    join_node.first().remap(id_remappings),
                    join_node.second().remap(id_remappings),
                ),
                MastNode::Split(split_node) => self.add_split(
                    split_node.on_true().remap(id_remappings),
                    split_node.on_false().remap(id_remappings),
                ),
                MastNode::Loop(loop_node) => self.add_loop(loop_node.body().remap(id_remappings)),
                MastNode::Call(call_node) => {
                    let callee_id = call_node.callee().remap(id_remappings);
                    if call_node.is_syscall() {
                        self.add_syscall(callee_id)
                    } else {
                        self.add_call(callee_id)
                    }
                },
                MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
                    // No child ids to rewrite: re-add the node as-is, decorators included.
                    self.add_node(live_node).unwrap();
                    continue;
                },
            }
            .unwrap();

            // The `add_*` constructors above take no decorators, so the rebuilt node starts
            // undecorated. Carry the original decorators over; otherwise removing unrelated
            // nodes would silently drop them.
            self.append_before_enter(rebuilt_id, live_node.before_enter());
            self.append_after_exit(rebuilt_id, live_node.after_exit());
        }
    }

    /// Rebuilds the (empty) procedure-root list from `old_root_ids`, translating each id
    /// through `id_remappings`. Ids without an entry are kept unchanged.
    ///
    /// # Panics
    /// Panics if the root list is not empty on entry, or (via [`Self::make_root`]) if a
    /// translated id is not a valid node id of this forest.
    fn remap_and_add_roots(
        &mut self,
        old_root_ids: Vec<MastNodeId>,
        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
    ) {
        assert!(self.roots.is_empty());

        for old_root_id in old_root_ids {
            self.make_root(old_root_id.remap(id_remappings));
        }
    }
}
/// Splits `mast_nodes` into the nodes that survive removal and an old-id → new-id mapping.
///
/// A node survives when its id (its position in `mast_nodes`) is not in `nodes_to_remove`.
/// Survivors keep their relative order and are renumbered densely from zero. Only surviving
/// nodes appear in the returned mapping.
///
/// # Panics
/// Panics if `mast_nodes` has `u32::MAX` or more elements, since ids are stored as `u32`.
fn remove_nodes(
    mast_nodes: Vec<MastNode>,
    nodes_to_remove: &BTreeSet<MastNodeId>,
) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
    assert!(mast_nodes.len() < u32::MAX as usize);

    let mut kept: Vec<MastNode> = Vec::with_capacity(mast_nodes.len());
    let mut remapping: BTreeMap<MastNodeId, MastNodeId> = BTreeMap::new();

    let survivors = mast_nodes
        .into_iter()
        .enumerate()
        .map(|(position, node)| (MastNodeId(position as u32), node))
        .filter(|(old_id, _)| !nodes_to_remove.contains(old_id));

    for (old_id, node) in survivors {
        // The next free slot in `kept` is the survivor's new, compacted id.
        remapping.insert(old_id, MastNodeId(kept.len() as u32));
        kept.push(node);
    }

    (kept, remapping)
}
impl MastForest {
#[inline(always)]
pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
self.decorators.get(decorator_id)
}
#[inline(always)]
pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
self.nodes.get(node_id)
}
#[inline(always)]
pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
}
pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
self.roots.contains(&node_id)
}
pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
self.roots.iter().map(|&root_id| self[root_id].digest())
}
pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
self.roots.iter().filter_map(|&root_id| {
let node = &self[root_id];
if node.is_external() { None } else { Some(node.digest()) }
})
}
pub fn procedure_roots(&self) -> &[MastNodeId] {
&self.roots
}
pub fn num_procedures(&self) -> u32 {
self.roots
.len()
.try_into()
.expect("MAST forest contains more than 2^32 procedures.")
}
pub fn compute_nodes_commitment<'a>(
&self,
node_ids: impl IntoIterator<Item = &'a MastNodeId>,
) -> Word {
let mut digests: Vec<Word> = node_ids.into_iter().map(|&id| self[id].digest()).collect();
digests.sort_unstable_by_key(|word| LexicographicWord::from(*word));
miden_crypto::hash::rpo::Rpo256::merge_many(&digests)
}
pub fn num_nodes(&self) -> u32 {
self.nodes.len() as u32
}
pub fn nodes(&self) -> &[MastNode] {
self.nodes.as_slice()
}
pub fn decorators(&self) -> &[Decorator] {
self.decorators.as_slice()
}
pub fn advice_map(&self) -> &AdviceMap {
&self.advice_map
}
pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
&mut self.advice_map
}
pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
let code: Felt = error_code_from_msg(&msg);
self.error_codes.insert(code.as_int(), msg);
code
}
pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
let key = u64::from(code);
self.error_codes.get(&key).cloned()
}
}
impl Index<MastNodeId> for MastForest {
type Output = MastNode;
#[inline(always)]
fn index(&self, node_id: MastNodeId) -> &Self::Output {
&self.nodes[node_id]
}
}
impl IndexMut<MastNodeId> for MastForest {
#[inline(always)]
fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
&mut self.nodes[node_id]
}
}
impl Index<DecoratorId> for MastForest {
type Output = Decorator;
#[inline(always)]
fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
&self.decorators[decorator_id]
}
}
impl IndexMut<DecoratorId> for MastForest {
#[inline(always)]
fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
&mut self.decorators[decorator_id]
}
}
/// The index of a node within a [`MastForest`].
///
/// Ids created through `from_u32_safe` / `from_usize_safe` are checked against a forest's node
/// count. `From<u32>` and `new_unchecked` perform no such check.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct MastNodeId(u32);
/// A mapping from old node ids to new node ids, such as the one returned by
/// `MastForest::remove_nodes`.
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
impl MastNodeId {
    /// Converts `value` into a [`MastNodeId`] after checking that it refers to a node of
    /// `mast_forest`.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `value` is not less than the number of
    /// nodes in `mast_forest`.
    pub fn from_u32_safe(
        value: u32,
        mast_forest: &MastForest,
    ) -> Result<Self, DeserializationError> {
        let node_count = mast_forest.nodes.len();
        Self::from_u32_with_node_count(value, node_count)
    }

    /// Like [`Self::from_u32_safe`], but takes a `usize` id.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `node_id` does not fit into a `u32`, or
    /// if it is not less than the number of nodes in `mast_forest`.
    pub fn from_usize_safe(
        node_id: usize,
        mast_forest: &MastForest,
    ) -> Result<Self, DeserializationError> {
        match u32::try_from(node_id) {
            Ok(narrowed) => MastNodeId::from_u32_safe(narrowed, mast_forest),
            Err(_) => Err(DeserializationError::InvalidValue(format!(
                "node id '{node_id}' does not fit into a u32"
            ))),
        }
    }

    /// Wraps `value` without any bounds check.
    pub(crate) fn new_unchecked(value: u32) -> Self {
        Self(value)
    }

    /// Converts `id` into a [`MastNodeId`], checking it against an explicit `node_count`
    /// instead of a forest.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `id` is not less than `node_count`.
    pub(super) fn from_u32_with_node_count(
        id: u32,
        node_count: usize,
    ) -> Result<Self, DeserializationError> {
        let in_range = (id as usize) < node_count;
        if !in_range {
            return Err(DeserializationError::InvalidValue(format!(
                "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
            )));
        }
        Ok(Self(id))
    }

    /// Returns the id this one maps to in `remapping`, or `self` unchanged if it has no entry.
    pub fn remap(&self, remapping: &Remapping) -> Self {
        match remapping.get(self) {
            Some(&new_id) => new_id,
            None => *self,
        }
    }
}
impl From<u32> for MastNodeId {
fn from(value: u32) -> Self {
MastNodeId::new_unchecked(value)
}
}
impl Idx for MastNodeId {}
impl From<MastNodeId> for u32 {
fn from(value: MastNodeId) -> Self {
value.0
}
}
impl fmt::Display for MastNodeId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "MastNodeId({})", self.0)
}
}
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for MastNodeId {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates ids drawn from the full `u32` range; they are not tied to any forest.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        any::<u32>().prop_map(MastNodeId::new_unchecked).boxed()
    }
}
/// Iterates over the ids of the nodes in the subtree rooted at a given node, yielding each node
/// only after every node below it (so the root comes last).
///
/// Leaves are yielded as soon as they are popped from the work stack. Nodes with children are
/// held back and yielded, most recently discovered first, once the work stack is exhausted.
/// There is no deduplication: a node reachable along several paths is yielded once per path.
pub struct SubtreeIterator<'a> {
    forest: &'a MastForest,
    /// Nodes with children that were already expanded and are waiting to be yielded.
    discovered: Vec<MastNodeId>,
    /// Work stack of nodes not examined yet.
    unvisited: Vec<MastNodeId>,
}

impl<'a> SubtreeIterator<'a> {
    /// Creates an iterator over the subtree of `forest` rooted at `root`.
    pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
        Self {
            forest,
            discovered: Vec::new(),
            unvisited: vec![*root],
        }
    }
}

impl Iterator for SubtreeIterator<'_> {
    type Item = MastNodeId;

    fn next(&mut self) -> Option<MastNodeId> {
        loop {
            let Some(id) = self.unvisited.pop() else {
                // Every leaf has been yielded; drain the held-back inner nodes.
                return self.discovered.pop();
            };

            let node = &self.forest[id];
            if node.has_children() {
                self.discovered.push(id);
                node.append_children_to(&mut self.unvisited);
            } else {
                return Some(id);
            }
        }
    }
}
/// The index of a decorator within a [`MastForest`].
///
/// Ids created through `from_u32_safe` are checked against a forest's decorator count.
/// `From<u32>` and `new_unchecked` perform no such check.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct DecoratorId(u32);
impl DecoratorId {
    /// Converts `value` into a [`DecoratorId`] after checking that it refers to a decorator of
    /// `mast_forest`.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `value` is not less than the number of
    /// decorators in `mast_forest`.
    pub fn from_u32_safe(
        value: u32,
        mast_forest: &MastForest,
    ) -> Result<Self, DeserializationError> {
        let decorator_count = mast_forest.decorators.len();
        if (value as usize) >= decorator_count {
            return Err(DeserializationError::InvalidValue(format!(
                "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
                value, decorator_count,
            )));
        }
        Ok(Self(value))
    }

    /// Wraps `value` without any bounds check.
    pub(crate) fn new_unchecked(value: u32) -> Self {
        Self(value)
    }
}
impl From<u32> for DecoratorId {
fn from(value: u32) -> Self {
DecoratorId::new_unchecked(value)
}
}
impl Idx for DecoratorId {}
impl From<DecoratorId> for u32 {
fn from(value: DecoratorId) -> Self {
value.0
}
}
impl fmt::Display for DecoratorId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "DecoratorId({})", self.0)
}
}
impl Serializable for DecoratorId {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.0.write_into(target)
}
}
/// Derives the error code for `msg`: the first element of the word produced by
/// `hash_string_to_word` for that message.
pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
    let text: &str = msg.as_ref();
    let word = hash_string_to_word(text);
    word[0]
}
/// Errors that can occur while building or merging a [`MastForest`].
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
    /// The decorator table cannot hold another decorator (returned by `add_decorator`).
    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
    TooManyDecorators,
    /// The forest cannot hold another node (returned by `add_node`).
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node id is not less than the forest's node count; carries the offending id and that
    /// count.
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A decorator id is not less than the forest's decorator count; carries the offending id
    /// and that count.
    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
    DecoratorIdOverflow(DecoratorId, usize),
    /// A basic block was requested with no operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// The decorator root of the child with the given id is missing but is required to compute
    /// a node fingerprint.
    #[error(
        "decorator root of child with node id {0} is missing but is required for fingerprint computation"
    )]
    ChildFingerprintMissing(MastNodeId),
    /// Two forests being merged both define the given advice map key.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(Word),
}