use alloc::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
vec::Vec,
};
use core::ops::{Index, IndexMut};
use miden_core::{
Felt, Word,
advice::AdviceMap,
mast::{
AsmOpId, BasicBlockNode, BasicBlockNodeBuilder, CallNodeBuilder, DecoratorFingerprint,
DecoratorId, DynNodeBuilder, ExternalNodeBuilder, JoinNodeBuilder, MastForest,
MastForestContributor, MastForestError, MastNode, MastNodeBuilder, MastNodeExt,
MastNodeFingerprint, MastNodeId, Remapping, SubtreeIterator,
},
operations::{AssemblyOp, Decorator, DecoratorList, Operation},
};
use super::{GlobalItemIndex, LinkerError, Procedure};
use crate::{
diagnostics::{IntoDiagnostic, Report, WrapErr},
report,
};
/// Upper bound (exclusive) on the number of operation batches a basic block may have and still be
/// merged into a neighbouring block when that basic block is a procedure root (see
/// `should_merge`). Procedure-root blocks at or above this size are kept as standalone blocks;
/// non-root blocks are merged regardless of size.
const PROCEDURE_INLINING_THRESHOLD: usize = 32;
/// Incrementally builds a [`MastForest`] for the assembler.
///
/// Nodes and decorators are deduplicated by fingerprint, contiguous basic blocks can be merged,
/// procedures from statically linked libraries can be copied in, and assembly-op debug mappings
/// are queued and only registered with the forest in [`MastForestBuilder::build`].
#[derive(Clone, Debug, Default)]
pub struct MastForestBuilder {
/// The forest under construction; returned by `build`.
pub(crate) mast_forest: MastForest,
/// Procedures registered through `insert_procedure`, keyed by global item index.
procedures: BTreeMap<GlobalItemIndex, Procedure>,
/// MAST root -> gid of the procedure most recently inserted with that root.
proc_gid_by_mast_root: BTreeMap<Word, GlobalItemIndex>,
/// Node fingerprint -> node id; lets `ensure_node*` reuse an existing node instead of adding
/// a duplicate.
node_id_by_fingerprint: BTreeMap<MastNodeFingerprint, MastNodeId>,
/// Node id -> fingerprint of that node; consulted when fingerprinting nodes that reference it
/// as a child.
hash_by_node_id: BTreeMap<MastNodeId, MastNodeFingerprint>,
/// Decorator fingerprint -> decorator id; used by `ensure_decorator` to deduplicate.
decorator_id_by_fingerprint: BTreeMap<DecoratorFingerprint, DecoratorId>,
/// Every basic block that was passed through `merge_basic_blocks`. In `build`, those that are
/// neither procedure roots nor referenced as a child of any node are removed.
merged_basic_block_ids: BTreeSet<MastNodeId>,
/// The merge of all static libraries passed to `new`.
statically_linked_mast: Arc<MastForest>,
/// Node ids in `statically_linked_mast` -> ids of their copies in `mast_forest`.
statically_linked_mast_remapping: Remapping,
/// Decorator ids in `statically_linked_mast` -> ids in `mast_forest`; cleared and rebuilt for
/// each subtree copied by `ensure_external_link`.
statically_linked_decorator_remapping: BTreeMap<DecoratorId, DecoratorId>,
/// Queued `(node id, [(op index, AsmOpId)])` entries, registered with the forest's debug info
/// in `build`.
pending_asm_op_mappings: Vec<(MastNodeId, Vec<(usize, AsmOpId)>)>,
}
impl MastForestBuilder {
    /// Creates a builder whose statically linked MAST is the merge of `static_libraries`.
    ///
    /// The forest under construction starts out empty except for a copy of the merged libraries'
    /// advice map. The node remapping produced by the merge is not kept.
    ///
    /// # Errors
    /// Returns an error if the static libraries cannot be merged.
    pub fn new<'a>(
        static_libraries: impl IntoIterator<Item = &'a MastForest>,
    ) -> Result<Self, Report> {
        let (linked_forest, _) =
            MastForest::merge(static_libraries.into_iter()).into_diagnostic()?;

        let mut forest = MastForest::default();
        *forest.advice_map_mut() = linked_forest.advice_map().clone();

        Ok(Self {
            mast_forest: forest,
            statically_linked_mast: Arc::new(linked_forest),
            ..Self::default()
        })
    }

    /// Returns the forest built so far.
    pub fn mast_forest(&self) -> &MastForest {
        &self.mast_forest
    }

    /// Finalizes the builder.
    ///
    /// First registers the queued assembly-op mappings (one list per node) with the forest's
    /// debug info. Then it removes merged basic blocks that are neither procedure roots nor
    /// referenced as children.
    ///
    /// Returns the forest together with the node-id remapping produced by that removal.
    ///
    /// # Panics
    /// Panics if registering the assembly-op mappings fails.
    pub fn build(mut self) -> (MastForest, BTreeMap<MastNodeId, MastNodeId>) {
        let queued = core::mem::take(&mut self.pending_asm_op_mappings);
        for (node_id, mappings) in deduplicate_asm_op_mappings(queued) {
            let (op_count, adjusted) =
                compute_operations_and_adjust_mappings(&self.mast_forest[node_id], mappings);
            self.mast_forest
                .debug_info_mut()
                .register_asm_ops(node_id, op_count, adjusted)
                .expect("failed to register AssemblyOps - internal ordering error");
        }

        let removable = get_nodes_to_remove(self.merged_basic_block_ids, &self.mast_forest);
        let remapping = self.mast_forest.remove_nodes(&removable);
        (self.mast_forest, remapping)
    }
}
/// Computes the operation count to register for `node` and the asm-op mappings to register.
///
/// Basic blocks report `num_operations()` and have their mappings passed through
/// `BasicBlockNode::adjust_asm_op_indices` against the block's op batches. Any other node
/// reports one past the largest mapped op index (0 when there are no mappings) and keeps the
/// mappings unchanged.
fn compute_operations_and_adjust_mappings(
    node: &MastNode,
    asm_op_mappings: Vec<(usize, AsmOpId)>,
) -> (usize, Vec<(usize, AsmOpId)>) {
    let MastNode::Block(block) = node else {
        let op_count = asm_op_mappings
            .iter()
            .fold(0, |highest, &(op_idx, _)| highest.max(op_idx + 1));
        return (op_count, asm_op_mappings);
    };

    let op_count = block.num_operations() as usize;
    let adjusted = BasicBlockNode::adjust_asm_op_indices(asm_op_mappings, block.op_batches());
    (op_count, adjusted)
}
fn deduplicate_asm_op_mappings(
mut mappings: Vec<(MastNodeId, Vec<(usize, AsmOpId)>)>,
) -> Vec<(MastNodeId, Vec<(usize, AsmOpId)>)> {
mappings.sort_by_key(|(node_id, _)| *node_id);
let mut seen_node_ids = BTreeSet::new();
mappings
.into_iter()
.filter(|(node_id, _)| seen_node_ids.insert(*node_id))
.collect()
}
/// Selects which merged basic blocks can be deleted from `mast_forest`.
///
/// A merged block is removable when it is not a procedure root and no node in the forest
/// lists it as a child.
fn get_nodes_to_remove(
    merged_node_ids: BTreeSet<MastNodeId>,
    mast_forest: &MastForest,
) -> BTreeSet<MastNodeId> {
    let mut removable: BTreeSet<MastNodeId> = merged_node_ids
        .into_iter()
        .filter(|&candidate| !mast_forest.is_procedure_root(candidate))
        .collect();

    // Anything still referenced by some node must stay.
    for node in mast_forest.nodes() {
        node.for_each_child(|child_id| {
            removable.remove(&child_id);
        });
    }

    removable
}
impl MastForestBuilder {
    /// Returns the procedure registered under `gid`, if any.
    #[inline(always)]
    pub fn get_procedure(&self, gid: GlobalItemIndex) -> Option<&Procedure> {
        self.procedures.get(&gid)
    }

    /// Returns the procedure whose gid was most recently recorded for `mast_root` by
    /// `insert_procedure`, if any.
    #[inline(always)]
    pub fn find_procedure_by_mast_root(&self, mast_root: &Word) -> Option<&Procedure> {
        let gid = *self.proc_gid_by_mast_root.get(mast_root)?;
        self.get_procedure(gid)
    }

    /// Returns the node with the given id from the forest under construction, if it exists.
    pub fn get_mast_node(&self, id: MastNodeId) -> Option<&MastNode> {
        let forest = &self.mast_forest;
        forest.get_node_by_id(id)
    }
}
impl MastForestBuilder {
    /// Registers `procedure` under `gid` and makes its body node a root of the forest.
    ///
    /// Does nothing if a procedure is already registered under `gid`. If a procedure with the
    /// same MAST root was registered before, the two must declare the same number of locals
    /// unless either declares none. On success, `gid` becomes the gid recorded for that root.
    ///
    /// # Errors
    /// Returns an error when an earlier procedure with the same MAST root declares a different,
    /// non-zero number of locals than `procedure` (which must also be non-zero).
    pub fn insert_procedure(
        &mut self,
        gid: GlobalItemIndex,
        procedure: Procedure,
    ) -> Result<(), Report> {
        if self.procedures.contains_key(&gid) {
            return Ok(());
        }

        let mast_root = procedure.mast_root();
        if let Some(existing) = self.find_procedure_by_mast_root(&mast_root) {
            let existing_locals = existing.num_locals();
            let new_locals = procedure.num_locals();
            // Differing local counts are tolerated as long as one side declares no locals.
            let conflicting = existing_locals != new_locals
                && core::cmp::min(existing_locals, new_locals) != 0;
            if conflicting {
                return Err(report!(
                    "two procedures found with same mast root, but conflicting definitions ('{}' and '{}')",
                    existing.path(),
                    procedure.path()
                ));
            }
        }

        self.mast_forest.make_root(procedure.body_node_id());
        self.proc_gid_by_mast_root.insert(mast_root, gid);
        self.procedures.insert(gid, procedure);
        Ok(())
    }
}
impl MastForestBuilder {
    /// Combines `node_ids` into a single node and returns its id.
    ///
    /// Contiguous runs of basic blocks are merged first (see `merge_contiguous_basic_blocks`).
    /// The remaining nodes are then paired into JOIN nodes, one level at a time, until a single
    /// node remains. When a level has an odd number of nodes, the last one is carried over
    /// unchanged to the next level.
    ///
    /// If `asm_op` is given, it is registered at op index 0 on every JOIN node returned by
    /// `ensure_join` here, including reused ones.
    ///
    /// `node_ids` must not be empty; this is only checked with `debug_assert!`.
    ///
    /// # Errors
    /// Propagates errors from merging basic blocks, creating JOIN nodes, or registering `asm_op`.
    pub fn join_nodes(
        &mut self,
        node_ids: Vec<MastNodeId>,
        asm_op: Option<AssemblyOp>,
    ) -> Result<MastNodeId, Report> {
        debug_assert!(!node_ids.is_empty(), "cannot combine empty MAST node id list");

        let mut node_ids = self.merge_contiguous_basic_blocks(node_ids)?;

        while node_ids.len() > 1 {
            // With an odd count, set the trailing node aside so the rest pair up evenly.
            let last_mast_node_id = if node_ids.len().is_multiple_of(2) {
                None
            } else {
                node_ids.pop()
            };

            // Move the current level out; `node_ids` collects the next level.
            let source_node_ids = core::mem::take(&mut node_ids);
            let mut source_mast_node_iter = source_node_ids.into_iter();
            while let (Some(left), Some(right)) =
                (source_mast_node_iter.next(), source_mast_node_iter.next())
            {
                let join_mast_node_id = self.ensure_join(left, right, vec![], vec![])?;
                if let Some(ref asm_op) = asm_op {
                    self.register_node_asm_op(join_mast_node_id, asm_op.clone())?;
                }
                node_ids.push(join_mast_node_id);
            }

            if let Some(mast_node_id) = last_mast_node_id {
                node_ids.push(mast_node_id);
            }
        }

        Ok(node_ids.remove(0))
    }

    /// Replaces every maximal run of consecutive basic blocks in `node_ids` with the result of
    /// `merge_basic_blocks` on that run.
    ///
    /// Non-block nodes are kept in their original positions.
    fn merge_contiguous_basic_blocks(
        &mut self,
        node_ids: Vec<MastNodeId>,
    ) -> Result<Vec<MastNodeId>, Report> {
        let mut merged_node_ids = Vec::with_capacity(node_ids.len());
        let mut contiguous_basic_block_ids: Vec<MastNodeId> = Vec::new();

        for mast_node_id in node_ids {
            if self.mast_forest[mast_node_id].is_basic_block() {
                contiguous_basic_block_ids.push(mast_node_id);
            } else {
                merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?);
                contiguous_basic_block_ids.clear();
                merged_node_ids.push(mast_node_id);
            }
        }

        merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?);

        Ok(merged_node_ids)
    }

    /// Merges a run of contiguous basic blocks into as few blocks as possible.
    ///
    /// Blocks accepted by `should_merge` are concatenated. This includes their raw operations,
    /// their decorators (indices re-based onto the concatenated op list), and their pending
    /// asm-op mappings.
    ///
    /// A rejected block (a procedure root with too many op batches) is kept as-is and splits the
    /// run: the blocks accumulated before it are emitted first as one merged block.
    ///
    /// Runs of length 0 or 1 are returned unchanged. Otherwise every input id is recorded in
    /// `merged_basic_block_ids`, so `build` can drop those that end up neither procedure roots
    /// nor referenced as children.
    fn merge_basic_blocks(
        &mut self,
        contiguous_basic_block_ids: &[MastNodeId],
    ) -> Result<Vec<MastNodeId>, Report> {
        if contiguous_basic_block_ids.is_empty() {
            return Ok(Vec::new());
        }
        if contiguous_basic_block_ids.len() == 1 {
            return Ok(contiguous_basic_block_ids.to_vec());
        }

        let mut operations: Vec<Operation> = Vec::new();
        let mut decorators = DecoratorList::new();
        let mut merged_asm_ops: Vec<(usize, AsmOpId)> = Vec::new();

        let mut merged_basic_blocks: Vec<MastNodeId> = Vec::new();

        for &basic_block_id in contiguous_basic_block_ids {
            if should_merge(
                self.mast_forest.is_procedure_root(basic_block_id),
                self.mast_forest[basic_block_id]
                    .get_basic_block()
                    .expect("merge_basic_blocks: expected BasicBlockNode")
                    .num_op_batches(),
            ) {
                let (block_decorators, block_ops) = {
                    let basic_block_node =
                        self.mast_forest[basic_block_id].get_basic_block().unwrap();
                    let block_decorators: Vec<_> =
                        basic_block_node.raw_decorator_iter(&self.mast_forest).collect();
                    let block_ops: Vec<Operation> = basic_block_node
                        .op_batches()
                        .iter()
                        .flat_map(|b| b.raw_ops().copied())
                        .collect();
                    (block_decorators, block_ops)
                };

                // Everything taken from this block is shifted by the ops accumulated so far.
                let ops_offset = operations.len();
                self.transfer_asm_ops_for_merge(basic_block_id, ops_offset, &mut merged_asm_ops);

                for (op_idx, decorator) in block_decorators {
                    decorators.push((op_idx + ops_offset, decorator));
                }
                operations.extend(block_ops);
            } else {
                // Flush what has been accumulated so far, then keep this block standalone.
                if !operations.is_empty() {
                    let block_ops = core::mem::take(&mut operations);
                    let block_decorators = core::mem::take(&mut decorators);
                    let block_asm_ops = core::mem::take(&mut merged_asm_ops);

                    let merged_basic_block_id = self.ensure_block_with_asm_op_ids(
                        block_ops,
                        block_decorators,
                        block_asm_ops,
                        vec![],
                        vec![],
                    )?;

                    merged_basic_blocks.push(merged_basic_block_id);
                }
                merged_basic_blocks.push(basic_block_id);
            }
        }

        self.merged_basic_block_ids.extend(contiguous_basic_block_ids.iter());

        if !operations.is_empty() || !decorators.is_empty() {
            let merged_basic_block = self.ensure_block_with_asm_op_ids(
                operations,
                decorators,
                merged_asm_ops,
                vec![],
                vec![],
            )?;
            merged_basic_blocks.push(merged_basic_block);
        }

        Ok(merged_basic_blocks)
    }

    /// Moves every pending asm-op mapping of `source_block_id` into `merged_asm_ops`, shifting
    /// each op index by `ops_offset`.
    ///
    /// The source block's entries are removed from `pending_asm_op_mappings`.
    fn transfer_asm_ops_for_merge(
        &mut self,
        source_block_id: MastNodeId,
        ops_offset: usize,
        merged_asm_ops: &mut Vec<(usize, AsmOpId)>,
    ) {
        let (matched, rest): (Vec<_>, Vec<_>) = core::mem::take(&mut self.pending_asm_op_mappings)
            .into_iter()
            .partition(|(node_id, _)| *node_id == source_block_id);
        self.pending_asm_op_mappings = rest;

        for (_, asm_ops) in matched {
            merged_asm_ops.extend(
                asm_ops.into_iter().map(|(op_idx, asm_op_id)| (op_idx + ops_offset, asm_op_id)),
            );
        }
    }

    /// Ensures a basic block with the given operations and decorators exists and returns its id.
    ///
    /// `asm_op_ids` (already stored in the forest's debug info) are queued so they are attached
    /// to the block in `build`.
    ///
    /// The mappings are queued even when an identical block already existed. Callers
    /// (`merge_basic_blocks`) have already moved the source blocks' mappings out of the pending
    /// list. When the merge re-creates an existing block, for example a lone mergeable block
    /// next to a non-mergeable one that fingerprints to itself, skipping the push would lose that
    /// block's assembly-op info entirely.
    ///
    /// If the existing block still has its own earlier entry queued, `deduplicate_asm_op_mappings`
    /// keeps that earlier entry and discards this one.
    fn ensure_block_with_asm_op_ids(
        &mut self,
        operations: Vec<Operation>,
        decorators: DecoratorList,
        asm_op_ids: Vec<(usize, AsmOpId)>,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        let block = BasicBlockNodeBuilder::new(operations, decorators)
            .with_before_enter(before_enter)
            .with_after_exit(after_exit);
        let node_id = self.ensure_node(block)?;

        if !asm_op_ids.is_empty() {
            self.pending_asm_op_mappings.push((node_id, asm_op_ids));
        }

        Ok(node_id)
    }
}
impl MastForestBuilder {
    /// Returns the id of a decorator equal to `decorator`.
    ///
    /// The decorator is added to the forest only if no decorator with the same fingerprint was
    /// added through this builder before.
    ///
    /// # Errors
    /// Returns an error if the forest fails to add the new decorator.
    pub fn ensure_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, Report> {
        let decorator_hash = decorator.fingerprint();
        if let Some(decorator_id) = self.decorator_id_by_fingerprint.get(&decorator_hash) {
            Ok(*decorator_id)
        } else {
            let new_decorator_id = self
                .mast_forest
                .add_decorator(decorator)
                .into_diagnostic()
                .wrap_err("assembler failed to add new decorator")?;
            self.decorator_id_by_fingerprint.insert(decorator_hash, new_decorator_id);
            Ok(new_decorator_id)
        }
    }

    /// Adds `debug_var` to the forest and returns its id.
    ///
    /// No deduplication is done by the builder.
    ///
    /// # Errors
    /// Returns an error if the forest fails to add the debug variable.
    pub fn add_debug_var(
        &mut self,
        debug_var: miden_core::operations::DebugVarInfo,
    ) -> Result<miden_core::mast::DebugVarId, Report> {
        self.mast_forest
            .add_debug_var(debug_var)
            .into_diagnostic()
            .wrap_err("assembler failed to add debug variable")
    }

    /// Returns the id of a node equal to the one described by `builder`, adding it if needed.
    pub(crate) fn ensure_node(
        &mut self,
        builder: impl MastForestContributor,
    ) -> Result<MastNodeId, Report> {
        let (node_id, _is_new) = self.ensure_node_exists(builder)?;
        Ok(node_id)
    }

    /// Like `ensure_node`, but also reports whether the node was newly added.
    ///
    /// The flag is `true` for a new node and `false` when an existing node with the same
    /// fingerprint was reused.
    ///
    /// # Panics
    /// Panics if the node cannot be fingerprinted. The `expect` assumes this only happens when a
    /// child's fingerprint is missing from `hash_by_node_id`.
    fn ensure_node_exists(
        &mut self,
        builder: impl MastForestContributor,
    ) -> Result<(MastNodeId, bool), Report> {
        let node_fingerprint = builder
            .fingerprint_for_node(&self.mast_forest, &self.hash_by_node_id)
            .expect("hash_by_node_id should contain the fingerprints of all children of `node`");

        if let Some(node_id) = self.node_id_by_fingerprint.get(&node_fingerprint) {
            Ok((*node_id, false))
        } else {
            let new_node_id = builder
                .add_to_forest(&mut self.mast_forest)
                .into_diagnostic()
                .wrap_err("assembler failed to add new node")?;
            self.node_id_by_fingerprint.insert(node_fingerprint, new_node_id);
            self.hash_by_node_id.insert(new_node_id, node_fingerprint);
            Ok((new_node_id, true))
        }
    }

    /// Ensures a basic block with the given operations and decorators exists and returns its id.
    ///
    /// `decorators` are op-indexed; `before_enter`/`after_exit` are node-level decorators.
    /// `asm_ops` are `(op index, AssemblyOp)` pairs. They are stored in the forest's debug info
    /// and queued for registration in `build` only when the block is newly created.
    pub fn ensure_block(
        &mut self,
        operations: Vec<Operation>,
        decorators: DecoratorList,
        asm_ops: Vec<(usize, AssemblyOp)>,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        let block = BasicBlockNodeBuilder::new(operations, decorators)
            .with_before_enter(before_enter)
            .with_after_exit(after_exit);
        let (node_id, is_new) = self.ensure_node_exists(block)?;

        if is_new && !asm_ops.is_empty() {
            let mut asm_op_mappings = Vec::with_capacity(asm_ops.len());
            for (op_idx, asm_op) in asm_ops {
                let asm_op_id = self
                    .mast_forest
                    .debug_info_mut()
                    .add_asm_op(asm_op)
                    .into_diagnostic()
                    .wrap_err("failed to add AssemblyOp")?;
                asm_op_mappings.push((op_idx, asm_op_id));
            }
            self.pending_asm_op_mappings.push((node_id, asm_op_mappings));
        }

        Ok(node_id)
    }

    /// Ensures a JOIN node over `left_child` and `right_child` exists and returns its id.
    pub fn ensure_join(
        &mut self,
        left_child: MastNodeId,
        right_child: MastNodeId,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        let join = JoinNodeBuilder::new([left_child, right_child])
            .with_before_enter(before_enter)
            .with_after_exit(after_exit);
        self.ensure_node(join)
    }

    /// Ensures a CALL node targeting `callee` exists and returns its id.
    pub fn ensure_call(
        &mut self,
        callee: MastNodeId,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        let call = CallNodeBuilder::new(callee)
            .with_before_enter(before_enter)
            .with_after_exit(after_exit);
        self.ensure_node(call)
    }

    /// Ensures a SPLIT node over the two branches exists and returns its id (tests only).
    #[cfg(all(test, feature = "std"))]
    pub fn ensure_split(
        &mut self,
        left_child: MastNodeId,
        right_child: MastNodeId,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        use miden_core::mast::SplitNodeBuilder;

        let split = SplitNodeBuilder::new([left_child, right_child])
            .with_before_enter(before_enter)
            .with_after_exit(after_exit);
        self.ensure_node(split)
    }

    /// Ensures a LOOP node with the given body exists and returns its id (tests only).
    #[cfg(all(test, feature = "std"))]
    pub fn ensure_loop(
        &mut self,
        body: MastNodeId,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        use miden_core::mast::LoopNodeBuilder;

        let loop_node = LoopNodeBuilder::new(body)
            .with_before_enter(before_enter)
            .with_after_exit(after_exit);
        self.ensure_node(loop_node)
    }

    /// Ensures a SYSCALL node targeting `callee` exists and returns its id.
    pub fn ensure_syscall(
        &mut self,
        callee: MastNodeId,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        let syscall = CallNodeBuilder::new_syscall(callee)
            .with_after_exit(after_exit)
            .with_before_enter(before_enter);
        self.ensure_node(syscall)
    }

    /// Ensures a DYN node exists and returns its id.
    pub fn ensure_dyn(
        &mut self,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        self.ensure_node(
            DynNodeBuilder::new_dyn()
                .with_after_exit(after_exit)
                .with_before_enter(before_enter),
        )
    }

    /// Ensures a DYNCALL node exists and returns its id.
    pub fn ensure_dyncall(
        &mut self,
        before_enter: Vec<DecoratorId>,
        after_exit: Vec<DecoratorId>,
    ) -> Result<MastNodeId, Report> {
        self.ensure_node(
            DynNodeBuilder::new_dyncall()
                .with_after_exit(after_exit)
                .with_before_enter(before_enter),
        )
    }

    /// Rebuilds `statically_linked_decorator_remapping` for the subtree rooted at `root_id` in
    /// the statically linked forest.
    ///
    /// Every decorator referenced by a node of that subtree is ensured in the forest under
    /// construction. This covers before-enter and after-exit decorators and, for basic blocks,
    /// op-indexed decorators.
    fn collect_decorators_from_subtree(&mut self, root_id: &MastNodeId) -> Result<(), Report> {
        self.statically_linked_decorator_remapping.clear();

        // Iterate over a cloned `Arc` so `self` can be mutated inside the loop.
        for node_id in SubtreeIterator::new(root_id, &self.statically_linked_mast.clone()) {
            let decorator_ids: Vec<DecoratorId> = {
                let mut ids = Vec::new();
                ids.extend(self.statically_linked_mast.before_enter_decorators(node_id));
                ids.extend(self.statically_linked_mast.after_exit_decorators(node_id));

                if let MastNode::Block(block_node) = &self.statically_linked_mast[node_id] {
                    for (_idx, decorator_id) in
                        block_node.indexed_decorator_iter(&self.statically_linked_mast)
                    {
                        ids.push(decorator_id);
                    }
                }
                ids
            };

            for old_decorator_id in decorator_ids {
                if !self.statically_linked_decorator_remapping.contains_key(&old_decorator_id) {
                    let decorator = self.statically_linked_mast[old_decorator_id].clone();
                    let new_decorator_id = self.ensure_decorator(decorator)?;
                    self.statically_linked_decorator_remapping
                        .insert(old_decorator_id, new_decorator_id);
                }
            }
        }

        Ok(())
    }

    /// Turns the statically linked node `node` (id `node_id`) into a builder whose ids refer to
    /// the forest under construction.
    ///
    /// Delegates to `build_node_with_remapped_ids`, passing the current node and decorator
    /// remappings.
    fn build_with_remapped_ids(
        &self,
        node_id: MastNodeId,
        node: MastNode,
    ) -> Result<MastNodeBuilder, Report> {
        miden_core::mast::build_node_with_remapped_ids(
            node_id,
            node,
            &self.statically_linked_mast,
            &self.statically_linked_mast_remapping,
            &self.statically_linked_decorator_remapping,
        )
        .into_diagnostic()
    }

    /// Returns a node standing for the procedure with MAST root `mast_root`.
    ///
    /// If the root is a procedure root in the statically linked forest, its subtree (nodes and
    /// decorators) is copied into the forest and the copied root's id is returned. Otherwise an
    /// External node referencing `mast_root` is ensured.
    ///
    /// NOTE(review): the copy loop relies on `SubtreeIterator` yielding children before their
    /// parents, so children are already remapped when a parent is rebuilt.
    pub fn ensure_external_link(&mut self, mast_root: Word) -> Result<MastNodeId, Report> {
        if let Some(root_id) = self.statically_linked_mast.find_procedure_root(mast_root) {
            self.collect_decorators_from_subtree(&root_id)?;

            for old_id in SubtreeIterator::new(&root_id, &self.statically_linked_mast.clone()) {
                let node = self.statically_linked_mast[old_id].clone();
                let builder = self.build_with_remapped_ids(old_id, node)?;
                let new_id = self.ensure_node(builder)?;
                self.statically_linked_mast_remapping.insert(old_id, new_id);
            }

            Ok(root_id.remap(&self.statically_linked_mast_remapping))
        } else {
            self.ensure_node(ExternalNodeBuilder::new(mast_root))
        }
    }

    /// Appends `decorator_ids` to the before-enter decorators of node `node_id`, in place.
    ///
    /// The node keeps its id, so every parent that already references it sees the new
    /// decorators. The deduplication index is updated to the node's new fingerprint.
    pub fn append_before_enter(
        &mut self,
        node_id: MastNodeId,
        decorator_ids: Vec<DecoratorId>,
    ) -> Result<(), MastForestError> {
        let mut decorated_builder = self.mast_forest[node_id].clone().to_builder(&self.mast_forest);
        decorated_builder.append_before_enter(decorator_ids);

        let new_node_fingerprint =
            decorated_builder.fingerprint_for_node(&self.mast_forest, &self.hash_by_node_id)?;

        self.mast_forest[node_id] = decorated_builder.build(&self.mast_forest)?;
        self.update_node_fingerprint(node_id, new_node_fingerprint);

        Ok(())
    }

    /// Appends `decorator_ids` to the after-exit decorators of node `node_id`, in place.
    ///
    /// The node keeps its id, so every parent that already references it sees the new
    /// decorators. The deduplication index is updated to the node's new fingerprint.
    pub fn append_after_exit(
        &mut self,
        node_id: MastNodeId,
        decorator_ids: Vec<DecoratorId>,
    ) -> Result<(), MastForestError> {
        let mut decorated_builder = self.mast_forest[node_id].clone().to_builder(&self.mast_forest);
        decorated_builder.append_after_exit(decorator_ids);

        let new_node_fingerprint =
            decorated_builder.fingerprint_for_node(&self.mast_forest, &self.hash_by_node_id)?;

        self.mast_forest[node_id] = decorated_builder.build(&self.mast_forest)?;
        self.update_node_fingerprint(node_id, new_node_fingerprint);

        Ok(())
    }

    /// Re-keys `node_id` under `new_fingerprint` after its contents were modified in place.
    ///
    /// The node's previous fingerprint no longer describes its contents. It is therefore removed
    /// from `node_id_by_fingerprint` when it still points at `node_id`. Otherwise a later
    /// `ensure_*` call for an undecorated copy of the node would wrongly resolve to this modified
    /// node.
    fn update_node_fingerprint(&mut self, node_id: MastNodeId, new_fingerprint: MastNodeFingerprint) {
        let previous = self.hash_by_node_id.insert(node_id, new_fingerprint);
        if let Some(stale) =
            previous.filter(|old| self.node_id_by_fingerprint.get(old) == Some(&node_id))
        {
            self.node_id_by_fingerprint.remove(&stale);
        }
        self.node_id_by_fingerprint.insert(new_fingerprint, node_id);
    }
}
impl MastForestBuilder {
    /// Registers `msg` with the forest's error table and returns the value the forest assigns
    /// to it.
    pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
        let forest = &mut self.mast_forest;
        forest.register_error(msg)
    }

    /// Stores `asm_op` in the forest's debug info and queues it for `node_id` at op index 0.
    ///
    /// The mapping is registered with the forest in `build`.
    ///
    /// # Errors
    /// Returns an error if the AssemblyOp cannot be added.
    pub fn register_node_asm_op(
        &mut self,
        node_id: MastNodeId,
        asm_op: AssemblyOp,
    ) -> Result<(), Report> {
        let debug_info = self.mast_forest.debug_info_mut();
        let asm_op_id = debug_info
            .add_asm_op(asm_op)
            .into_diagnostic()
            .wrap_err("failed to add AssemblyOp for control flow node")?;

        let mapping = vec![(0, asm_op_id)];
        self.pending_asm_op_mappings.push((node_id, mapping));
        Ok(())
    }

    /// Registers `(op index, DebugVarId)` pairs for `node_id` directly with the forest's debug
    /// info. Unlike asm-op mappings, these are not queued until `build`.
    ///
    /// # Errors
    /// Returns an error if the forest rejects the registration.
    pub fn register_debug_vars_for_node(
        &mut self,
        node_id: MastNodeId,
        debug_vars: Vec<(usize, miden_core::mast::DebugVarId)>,
    ) -> Result<(), Report> {
        let debug_info = self.mast_forest.debug_info_mut();
        debug_info
            .register_op_indexed_debug_vars(node_id, debug_vars)
            .into_diagnostic()
            .wrap_err("failed to register debug variables for node")
    }
}
impl Index<MastNodeId> for MastForestBuilder {
    type Output = MastNode;

    /// Indexes the forest under construction by node id, delegating to its own `Index` impl.
    #[inline(always)]
    fn index(&self, node_id: MastNodeId) -> &Self::Output {
        Index::index(&self.mast_forest, node_id)
    }
}
impl Index<DecoratorId> for MastForestBuilder {
    type Output = Decorator;

    /// Indexes the forest's decorators by id, delegating to the forest's own `Index` impl.
    #[inline(always)]
    fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
        Index::index(&self.mast_forest, decorator_id)
    }
}
impl IndexMut<DecoratorId> for MastForestBuilder {
    /// Mutably indexes the forest's decorators by id, delegating to the forest's `IndexMut` impl.
    #[inline(always)]
    fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
        IndexMut::index_mut(&mut self.mast_forest, decorator_id)
    }
}
impl MastForestBuilder {
    /// Merges `other` into the advice map of the forest under construction.
    ///
    /// # Errors
    /// When `AdviceMap::merge` reports a key conflict, that conflict is converted into
    /// `LinkerError::AdviceMapKeyAlreadyPresent`. The error carries the key together with the
    /// previously stored and the new values.
    pub fn merge_advice_map(&mut self, other: &AdviceMap) -> Result<(), Report> {
        let outcome = self.mast_forest.advice_map_mut().merge(other);
        outcome
            .map_err(|((key, prev_values), new_values)| {
                LinkerError::AdviceMapKeyAlreadyPresent {
                    key,
                    prev_values: prev_values.to_vec(),
                    new_values: new_values.to_vec(),
                }
            })
            .into_diagnostic()
    }
}
/// Decides whether a basic block may be folded into a neighbouring block.
///
/// Blocks that are not procedure roots are always mergeable. Procedure-root blocks are
/// mergeable only when they have fewer than `PROCEDURE_INLINING_THRESHOLD` operation batches.
fn should_merge(is_procedure: bool, num_op_batches: usize) -> bool {
    !is_procedure || num_op_batches < PROCEDURE_INLINING_THRESHOLD
}
/// Unit tests for basic-block merging in `MastForestBuilder`.
#[cfg(test)]
mod tests {
use miden_core::operations::Operation;
use super::*;
/// Merging two blocks must keep every op-indexed decorator attached to the same operation,
/// even though the first block's padded operation list is longer than its raw list.
#[test]
fn test_merge_basic_blocks_preserves_decorator_links_with_padding() {
let mut builder = MastForestBuilder::new(&[]).unwrap();
// Block 1: nine raw operations; decorators are attached to raw indices 0, 7 and 8.
let block1_ops = vec![
Operation::Push(Felt::new(1)),
Operation::Drop,
Operation::Drop,
Operation::Drop,
Operation::Drop,
Operation::Drop,
Operation::Drop,
Operation::Push(Felt::new(2)),
Operation::Push(Felt::new(3)),
]; let block1_raw_ops_len = block1_ops.len();
let block1_decorator1 = builder.ensure_decorator(Decorator::Trace(1)).unwrap();
let block1_decorator2 = builder.ensure_decorator(Decorator::Trace(2)).unwrap();
let block1_decorator3 = builder.ensure_decorator(Decorator::Trace(3)).unwrap();
// (raw op index, decorator) pairs: Trace(1)->Push(1), Trace(2)->Push(2), Trace(3)->Push(3).
let block1_decorators = vec![
(0, block1_decorator1), (7, block1_decorator2), (8, block1_decorator3), ];
let block1_id = builder
.ensure_block(block1_ops.clone(), block1_decorators, vec![], vec![], vec![])
.unwrap();
let block1 = builder.mast_forest[block1_id].get_basic_block().unwrap().clone();
// Precondition: batching padded block 1 (more padded ops than raw ops).
assert!(block1.operations().count() > block1_raw_ops_len); assert_eq!(block1.raw_operations().count(), block1_raw_ops_len);
// Block 2: Push(4) and Mul, decorated with Trace(4) and Trace(5) respectively.
let block2_ops = vec![Operation::Push(Felt::new(4)), Operation::Mul];
let block2_decorator1 = builder.ensure_decorator(Decorator::Trace(4)).unwrap();
let block2_decorator2 = builder.ensure_decorator(Decorator::Trace(5)).unwrap();
let block2_decorators = vec![
(0, block2_decorator1), (1, block2_decorator2), ];
let block2_id = builder
.ensure_block(block2_ops.clone(), block2_decorators, vec![], vec![], vec![])
.unwrap();
// Neither block is a procedure root, so both must be merged into a single block.
let merged_blocks = builder.merge_basic_blocks(&[block1_id, block2_id]).unwrap();
assert_eq!(merged_blocks.len(), 1);
let merged_block_id = merged_blocks[0];
let merged_block = builder.mast_forest[merged_block_id].get_basic_block().unwrap();
let decorators = merged_block.indexed_decorator_iter(&builder.mast_forest);
let decorator_count = merged_block.indexed_decorator_iter(&builder.mast_forest).count();
assert_eq!(decorator_count, 5);
// Each decorator's (padded) op index must point at the operation it was originally attached
// to; the expected indices below (0, 7, 9, 10, 11) are in the merged block's padded op list.
let mut found_traces = std::collections::HashSet::new();
for (op_idx, decorator_id) in decorators {
let decorator = &builder.mast_forest[decorator_id];
match decorator {
Decorator::Trace(trace_value) => {
found_traces.insert(*trace_value);
let merged_ops: Vec<Operation> = merged_block.operations().cloned().collect();
if op_idx < merged_ops.len() {
match op_idx {
0 => {
match &merged_ops[op_idx] {
Operation::Push(x) if *x == Felt::new(1) => {
assert_eq!(
*trace_value, 1,
"Decorator for Push(1) should have trace value 1"
);
},
_ => panic!("Expected Push operation at index 0"),
}
},
7 => {
match &merged_ops[op_idx] {
Operation::Push(x) if *x == Felt::new(2) => {
assert_eq!(
*trace_value, 2,
"Decorator for Push(2) should have trace value 2"
);
},
_ => panic!("Expected Push operation at index 7"),
}
},
9 => {
match &merged_ops[op_idx] {
Operation::Push(x) if *x == Felt::new(3) => {
assert_eq!(
*trace_value, 3,
"Decorator for Push(3) should have trace value 3"
);
},
_ => panic!("Expected Push operation at index 9"),
}
},
10 => {
match &merged_ops[op_idx] {
Operation::Push(x) if *x == Felt::new(4) => {
assert_eq!(
*trace_value, 4,
"Decorator for Push(4) should have trace value 4"
);
},
_ => panic!("Expected Push operation at index 10"),
}
},
11 => {
match &merged_ops[op_idx] {
Operation::Mul => {
assert_eq!(
*trace_value, 5,
"Decorator for Mul should have trace value 5"
);
},
_ => panic!("Expected Mul operation at index 11"),
}
},
_ => panic!(
"Unexpected operation index {} for {:?} pointing at {:?}",
op_idx, trace_value, merged_ops[op_idx]
),
}
} else {
panic!("Operation index {} is out of bounds", op_idx);
}
},
_ => panic!("Expected Trace decorator"),
}
}
// Every original decorator must survive the merge exactly once.
let expected_traces = [1, 2, 3, 4, 5];
for expected_trace in expected_traces {
assert!(
found_traces.contains(&expected_trace),
"Missing trace value: {}",
expected_trace
);
}
assert_eq!(found_traces.len(), 5, "Should have found exactly 5 trace values");
}
/// A procedure-root block that is too large to merge must be kept as-is. The following small
/// block is then emitted on its own rather than being merged into the large one.
#[test]
fn test_merge_basic_blocks_keeps_non_mergeable_block_standalone() {
let mut builder = MastForestBuilder::new(&[]).unwrap();
// Enough Add operations to exceed PROCEDURE_INLINING_THRESHOLD op batches.
let num_ops = PROCEDURE_INLINING_THRESHOLD * 1024;
let large_ops = vec![Operation::Add; num_ops];
let large_block_id =
builder.ensure_block(large_ops, Vec::new(), vec![], vec![], vec![]).unwrap();
// Making it a procedure root is what makes it non-mergeable (see `should_merge`).
builder.mast_forest.make_root(large_block_id);
let small_block_id = builder
.ensure_block(vec![Operation::Add], Vec::new(), vec![], vec![], vec![])
.unwrap();
let merged_blocks = builder.merge_basic_blocks(&[large_block_id, small_block_id]).unwrap();
// The small block is re-created unchanged, so it resolves to its original id.
assert_eq!(merged_blocks.len(), 2);
assert_eq!(merged_blocks[0], large_block_id);
assert_eq!(merged_blocks[1], small_block_id);
}
}