1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 sync::Arc,
4 vec::Vec,
5};
6use core::{
7 fmt,
8 ops::{Index, IndexMut},
9};
10
11pub use miden_utils_indexing::{IndexVec, IndexedVecError};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
15mod node;
16#[cfg(any(test, feature = "arbitrary"))]
17pub use node::arbitrary;
18pub use node::{
19 BasicBlockNode, CallNode, DecoratedOpLink, DecoratorOpLinkIterator, DynNode, ExternalNode,
20 JoinNode, LoopNode, MastNode, MastNodeErrorContext, MastNodeExt, OP_BATCH_SIZE, OP_GROUP_SIZE,
21 OpBatch, OperationOrDecorator, SplitNode,
22};
23
24use crate::{
25 AdviceMap, Decorator, DecoratorList, Felt, Idx, LexicographicWord, Operation, Word,
26 crypto::hash::Hasher,
27 utils::{ByteWriter, DeserializationError, Serializable, hash_string_to_word},
28};
29
30mod serialization;
31
32mod merger;
33pub(crate) use merger::MastForestMerger;
34pub use merger::MastForestRootMap;
35
36mod multi_forest_node_iterator;
37pub(crate) use multi_forest_node_iterator::*;
38
39mod node_fingerprint;
40pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
41
42#[cfg(test)]
43mod tests;
44
45#[derive(Clone, Debug, Default, PartialEq, Eq)]
53#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
54#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
55pub struct MastForest {
56 nodes: IndexVec<MastNodeId, MastNode>,
58
59 roots: Vec<MastNodeId>,
61
62 decorators: IndexVec<DecoratorId, Decorator>,
64
65 advice_map: AdviceMap,
67
68 error_codes: BTreeMap<u64, Arc<str>>,
72}
73
74impl MastForest {
77 pub fn new() -> Self {
79 Self {
80 nodes: IndexVec::new(),
81 roots: Vec::new(),
82 decorators: IndexVec::new(),
83 advice_map: AdviceMap::default(),
84 error_codes: BTreeMap::new(),
85 }
86 }
87}
88
89impl MastForest {
92 const MAX_NODES: usize = (1 << 30) - 1;
94
95 pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
97 self.decorators.push(decorator).map_err(|_| MastForestError::TooManyDecorators)
98 }
99
100 pub fn add_node(&mut self, node: impl Into<MastNode>) -> Result<MastNodeId, MastForestError> {
104 self.nodes.push(node.into()).map_err(|_| MastForestError::TooManyNodes)
105 }
106
107 pub fn add_block(
109 &mut self,
110 operations: Vec<Operation>,
111 decorators: DecoratorList,
112 ) -> Result<MastNodeId, MastForestError> {
113 let block = BasicBlockNode::new(operations, decorators)?;
114 self.add_node(block)
115 }
116
117 pub fn add_join(
119 &mut self,
120 left_child: MastNodeId,
121 right_child: MastNodeId,
122 ) -> Result<MastNodeId, MastForestError> {
123 let join = JoinNode::new([left_child, right_child], self)?;
124 self.add_node(join)
125 }
126
127 pub fn add_split(
129 &mut self,
130 if_branch: MastNodeId,
131 else_branch: MastNodeId,
132 ) -> Result<MastNodeId, MastForestError> {
133 let split = SplitNode::new([if_branch, else_branch], self)?;
134 self.add_node(split)
135 }
136
137 pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
139 let loop_node = LoopNode::new(body, self)?;
140 self.add_node(loop_node)
141 }
142
143 pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
145 let call = CallNode::new(callee, self)?;
146 self.add_node(call)
147 }
148
149 pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
151 let syscall = CallNode::new_syscall(callee, self)?;
152 self.add_node(syscall)
153 }
154
155 pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
157 self.add_node(DynNode::new_dyn())
158 }
159
160 pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
162 self.add_node(DynNode::new_dyncall())
163 }
164
165 pub fn add_external(&mut self, mast_root: Word) -> Result<MastNodeId, MastForestError> {
167 self.add_node(ExternalNode::new(mast_root))
168 }
169
170 pub fn make_root(&mut self, new_root_id: MastNodeId) {
178 assert!(new_root_id.to_usize() < self.nodes.len());
179
180 if !self.roots.contains(&new_root_id) {
181 self.roots.push(new_root_id);
182 }
183 }
184
185 pub fn remove_nodes(
193 &mut self,
194 nodes_to_remove: &BTreeSet<MastNodeId>,
195 ) -> BTreeMap<MastNodeId, MastNodeId> {
196 if nodes_to_remove.is_empty() {
197 return BTreeMap::new();
198 }
199
200 let old_nodes = core::mem::replace(&mut self.nodes, IndexVec::new());
201 let old_root_ids = core::mem::take(&mut self.roots);
202 let (retained_nodes, id_remappings) = remove_nodes(old_nodes.into_inner(), nodes_to_remove);
203
204 self.remap_and_add_nodes(retained_nodes, &id_remappings);
205 self.remap_and_add_roots(old_root_ids, &id_remappings);
206 id_remappings
207 }
208
209 pub fn append_before_enter(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
210 self[node_id].append_before_enter(decorator_ids)
211 }
212
213 pub fn append_after_exit(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
214 self[node_id].append_after_exit(decorator_ids)
215 }
216
217 pub fn strip_decorators(&mut self) {
219 for node in self.nodes.iter_mut() {
220 node.remove_decorators();
221 }
222 self.decorators = IndexVec::new();
223 }
224
225 pub fn merge<'forest>(
275 forests: impl IntoIterator<Item = &'forest MastForest>,
276 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
277 MastForestMerger::merge(forests)
278 }
279
280 #[cfg(test)]
285 pub fn add_block_with_raw_decorators(
286 &mut self,
287 operations: Vec<Operation>,
288 decorators: Vec<(usize, Decorator)>,
289 ) -> Result<MastNodeId, MastForestError> {
290 let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, self)?;
291 self.add_node(block)
292 }
293}
294
295impl MastForest {
297 fn remap_and_add_nodes(
303 &mut self,
304 nodes_to_add: Vec<MastNode>,
305 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
306 ) {
307 assert!(self.nodes.is_empty());
308
309 for live_node in nodes_to_add {
312 match &live_node {
313 MastNode::Join(join_node) => {
314 let first_child =
315 id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
316 let second_child = id_remappings
317 .get(&join_node.second())
318 .copied()
319 .unwrap_or(join_node.second());
320
321 self.add_join(first_child, second_child).unwrap();
322 },
323 MastNode::Split(split_node) => {
324 let on_true_child = id_remappings
325 .get(&split_node.on_true())
326 .copied()
327 .unwrap_or(split_node.on_true());
328 let on_false_child = id_remappings
329 .get(&split_node.on_false())
330 .copied()
331 .unwrap_or(split_node.on_false());
332
333 self.add_split(on_true_child, on_false_child).unwrap();
334 },
335 MastNode::Loop(loop_node) => {
336 let body_id =
337 id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
338
339 self.add_loop(body_id).unwrap();
340 },
341 MastNode::Call(call_node) => {
342 let callee_id = id_remappings
343 .get(&call_node.callee())
344 .copied()
345 .unwrap_or(call_node.callee());
346
347 if call_node.is_syscall() {
348 self.add_syscall(callee_id).unwrap();
349 } else {
350 self.add_call(callee_id).unwrap();
351 }
352 },
353 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
354 self.add_node(live_node).unwrap();
355 },
356 }
357 }
358 }
359
360 fn remap_and_add_roots(
365 &mut self,
366 old_root_ids: Vec<MastNodeId>,
367 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
368 ) {
369 assert!(self.roots.is_empty());
370
371 for old_root_id in old_root_ids {
372 let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
373 self.make_root(new_root_id);
374 }
375 }
376}
377
378fn remove_nodes(
381 mast_nodes: Vec<MastNode>,
382 nodes_to_remove: &BTreeSet<MastNodeId>,
383) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
384 assert!(mast_nodes.len() < u32::MAX as usize);
386
387 let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
388 let mut id_remappings = BTreeMap::new();
389
390 for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
391 let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
392
393 if !nodes_to_remove.contains(&old_node_id) {
394 let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
395 id_remappings.insert(old_node_id, new_node_id);
396
397 retained_nodes.push(old_node);
398 }
399 }
400
401 (retained_nodes, id_remappings)
402}
403
404impl MastForest {
408 #[inline(always)]
413 pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
414 self.decorators.get(decorator_id)
415 }
416
417 #[inline(always)]
422 pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
423 self.nodes.get(node_id)
424 }
425
426 #[inline(always)]
428 pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
429 self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
430 }
431
432 pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
434 self.roots.contains(&node_id)
435 }
436
437 pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
439 self.roots.iter().map(|&root_id| self[root_id].digest())
440 }
441
442 pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
446 self.roots.iter().filter_map(|&root_id| {
447 let node = &self[root_id];
448 if node.is_external() { None } else { Some(node.digest()) }
449 })
450 }
451
452 pub fn procedure_roots(&self) -> &[MastNodeId] {
454 &self.roots
455 }
456
457 pub fn num_procedures(&self) -> u32 {
459 self.roots
460 .len()
461 .try_into()
462 .expect("MAST forest contains more than 2^32 procedures.")
463 }
464
465 pub fn compute_nodes_commitment<'a>(
470 &self,
471 node_ids: impl IntoIterator<Item = &'a MastNodeId>,
472 ) -> Word {
473 let mut digests: Vec<Word> = node_ids.into_iter().map(|&id| self[id].digest()).collect();
474 digests.sort_unstable_by_key(|word| LexicographicWord::from(*word));
475 miden_crypto::hash::rpo::Rpo256::merge_many(&digests)
476 }
477
478 pub fn num_nodes(&self) -> u32 {
480 self.nodes.len() as u32
481 }
482
483 pub fn nodes(&self) -> &[MastNode] {
485 self.nodes.as_slice()
486 }
487
488 pub fn decorators(&self) -> &[Decorator] {
489 self.decorators.as_slice()
490 }
491
492 pub fn advice_map(&self) -> &AdviceMap {
493 &self.advice_map
494 }
495
496 pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
497 &mut self.advice_map
498 }
499
500 pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
503 let code: Felt = error_code_from_msg(&msg);
504 self.error_codes.insert(code.as_int(), msg);
506 code
507 }
508
509 pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
511 let key = u64::from(code);
512 self.error_codes.get(&key).cloned()
513 }
514}
515
516impl Index<MastNodeId> for MastForest {
517 type Output = MastNode;
518
519 #[inline(always)]
520 fn index(&self, node_id: MastNodeId) -> &Self::Output {
521 &self.nodes[node_id]
522 }
523}
524
525impl IndexMut<MastNodeId> for MastForest {
526 #[inline(always)]
527 fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
528 &mut self.nodes[node_id]
529 }
530}
531
532impl Index<DecoratorId> for MastForest {
533 type Output = Decorator;
534
535 #[inline(always)]
536 fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
537 &self.decorators[decorator_id]
538 }
539}
540
541impl IndexMut<DecoratorId> for MastForest {
542 #[inline(always)]
543 fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
544 &mut self.decorators[decorator_id]
545 }
546}
547
548#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
558#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
559#[cfg_attr(feature = "serde", serde(transparent))]
560#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
561pub struct MastNodeId(u32);
562
563pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
565
566impl MastNodeId {
567 pub fn from_u32_safe(
572 value: u32,
573 mast_forest: &MastForest,
574 ) -> Result<Self, DeserializationError> {
575 Self::from_u32_with_node_count(value, mast_forest.nodes.len())
576 }
577
578 pub fn from_usize_safe(
582 node_id: usize,
583 mast_forest: &MastForest,
584 ) -> Result<Self, DeserializationError> {
585 let node_id: u32 = node_id.try_into().map_err(|_| {
586 DeserializationError::InvalidValue(format!(
587 "node id '{node_id}' does not fit into a u32"
588 ))
589 })?;
590 MastNodeId::from_u32_safe(node_id, mast_forest)
591 }
592
593 pub(crate) fn new_unchecked(value: u32) -> Self {
595 Self(value)
596 }
597
598 pub(super) fn from_u32_with_node_count(
612 id: u32,
613 node_count: usize,
614 ) -> Result<Self, DeserializationError> {
615 if (id as usize) < node_count {
616 Ok(Self(id))
617 } else {
618 Err(DeserializationError::InvalidValue(format!(
619 "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
620 )))
621 }
622 }
623
624 pub fn remap(&self, remapping: &Remapping) -> Self {
626 *remapping.get(self).unwrap_or(self)
627 }
628}
629
630impl From<u32> for MastNodeId {
631 fn from(value: u32) -> Self {
632 MastNodeId::new_unchecked(value)
633 }
634}
635
636impl Idx for MastNodeId {}
637
638impl From<MastNodeId> for u32 {
639 fn from(value: MastNodeId) -> Self {
640 value.0
641 }
642}
643
644impl fmt::Display for MastNodeId {
645 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
646 write!(f, "MastNodeId({})", self.0)
647 }
648}
649
650#[cfg(any(test, feature = "arbitrary"))]
651impl proptest::prelude::Arbitrary for MastNodeId {
652 type Parameters = ();
653
654 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
655 use proptest::prelude::*;
656 any::<u32>().prop_map(MastNodeId).boxed()
657 }
658
659 type Strategy = proptest::prelude::BoxedStrategy<Self>;
660}
661
662pub struct SubtreeIterator<'a> {
667 forest: &'a MastForest,
668 discovered: Vec<MastNodeId>,
669 unvisited: Vec<MastNodeId>,
670}
671impl<'a> SubtreeIterator<'a> {
672 pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
673 let discovered = vec![];
674 let unvisited = vec![*root];
675 SubtreeIterator { forest, discovered, unvisited }
676 }
677}
678impl Iterator for SubtreeIterator<'_> {
679 type Item = MastNodeId;
680 fn next(&mut self) -> Option<MastNodeId> {
681 while let Some(id) = self.unvisited.pop() {
682 let node = &self.forest[id];
683 if !node.has_children() {
684 return Some(id);
685 } else {
686 self.discovered.push(id);
687 node.append_children_to(&mut self.unvisited);
688 }
689 }
690 self.discovered.pop()
691 }
692}
693
694#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
700#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
701#[cfg_attr(feature = "serde", serde(transparent))]
702pub struct DecoratorId(u32);
703
704impl DecoratorId {
705 pub fn from_u32_safe(
710 value: u32,
711 mast_forest: &MastForest,
712 ) -> Result<Self, DeserializationError> {
713 if (value as usize) < mast_forest.decorators.len() {
714 Ok(Self(value))
715 } else {
716 Err(DeserializationError::InvalidValue(format!(
717 "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
718 value,
719 mast_forest.decorators.len(),
720 )))
721 }
722 }
723
724 pub(crate) fn new_unchecked(value: u32) -> Self {
726 Self(value)
727 }
728}
729
730impl From<u32> for DecoratorId {
731 fn from(value: u32) -> Self {
732 DecoratorId::new_unchecked(value)
733 }
734}
735
736impl Idx for DecoratorId {}
737
738impl From<DecoratorId> for u32 {
739 fn from(value: DecoratorId) -> Self {
740 value.0
741 }
742}
743
744impl fmt::Display for DecoratorId {
745 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
746 write!(f, "DecoratorId({})", self.0)
747 }
748}
749
750impl Serializable for DecoratorId {
751 fn write_into<W: ByteWriter>(&self, target: &mut W) {
752 self.0.write_into(target)
753 }
754}
755
756pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
759 hash_string_to_word(msg.as_ref())[0]
761}
762
763#[derive(Debug, thiserror::Error, PartialEq)]
768pub enum MastForestError {
769 #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
770 TooManyDecorators,
771 #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
772 TooManyNodes,
773 #[error("node id {0} is greater than or equal to forest length {1}")]
774 NodeIdOverflow(MastNodeId, usize),
775 #[error("decorator id {0} is greater than or equal to decorator count {1}")]
776 DecoratorIdOverflow(DecoratorId, usize),
777 #[error("basic block cannot be created from an empty list of operations")]
778 EmptyBasicBlock,
779 #[error(
780 "decorator root of child with node id {0} is missing but is required for fingerprint computation"
781 )]
782 ChildFingerprintMissing(MastNodeId),
783 #[error("advice map key {0} already exists when merging forests")]
784 AdviceMapKeyCollisionOnMerge(Word),
785}