1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 sync::Arc,
4 vec::Vec,
5};
6use core::{
7 fmt, mem,
8 ops::{Index, IndexMut},
9};
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14mod node;
15pub use node::{
16 BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode,
17 MastNodeErrorContext, MastNodeExt, OP_BATCH_SIZE, OP_GROUP_SIZE, OpBatch, OperationOrDecorator,
18 SplitNode,
19};
20
21use crate::{
22 AdviceMap, Decorator, DecoratorList, Felt, LexicographicWord, Operation, Word,
23 crypto::hash::Hasher,
24 utils::{ByteWriter, DeserializationError, Serializable, hash_string_to_word},
25};
26
27mod serialization;
28
29mod merger;
30pub(crate) use merger::MastForestMerger;
31pub use merger::MastForestRootMap;
32
33mod multi_forest_node_iterator;
34pub(crate) use multi_forest_node_iterator::*;
35
36mod node_fingerprint;
37pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
38
39#[cfg(test)]
40mod tests;
41
/// A collection of MAST nodes together with the subset of them that are procedure roots, the
/// decorators referenced by those nodes, an advice map, and a registry of error messages.
///
/// Nodes and decorators are addressed by index: a [`MastNodeId`] indexes `nodes` and a
/// [`DecoratorId`] indexes `decorators`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MastForest {
    /// All nodes of the forest; a [`MastNodeId`] is an index into this vector.
    nodes: Vec<MastNode>,

    /// IDs of the nodes that are procedure roots. `make_root` keeps this list duplicate-free.
    roots: Vec<MastNodeId>,

    /// All decorators of the forest; a [`DecoratorId`] is an index into this vector.
    decorators: Vec<Decorator>,

    /// Advice map carried alongside the forest.
    advice_map: AdviceMap,

    /// Error messages keyed by the integer value of their error code, as stored by
    /// `register_error` and looked up by `resolve_error_message`.
    error_codes: BTreeMap<u64, Arc<str>>,
}
69
70impl MastForest {
73 pub fn new() -> Self {
75 Self::default()
76 }
77}
78
79impl MastForest {
82 const MAX_NODES: usize = (1 << 30) - 1;
84
85 pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
87 if self.decorators.len() >= u32::MAX as usize {
88 return Err(MastForestError::TooManyDecorators);
89 }
90
91 let new_decorator_id = DecoratorId(self.decorators.len() as u32);
92 self.decorators.push(decorator);
93
94 Ok(new_decorator_id)
95 }
96
97 pub fn add_node(&mut self, node: impl Into<MastNode>) -> Result<MastNodeId, MastForestError> {
101 if self.nodes.len() == Self::MAX_NODES {
102 return Err(MastForestError::TooManyNodes);
103 }
104
105 let new_node_id = MastNodeId(self.nodes.len() as u32);
106 self.nodes.push(node.into());
107
108 Ok(new_node_id)
109 }
110
111 pub fn add_block(
113 &mut self,
114 operations: Vec<Operation>,
115 decorators: Option<DecoratorList>,
116 ) -> Result<MastNodeId, MastForestError> {
117 let block = BasicBlockNode::new(operations, decorators)?;
118 self.add_node(block)
119 }
120
121 pub fn add_join(
123 &mut self,
124 left_child: MastNodeId,
125 right_child: MastNodeId,
126 ) -> Result<MastNodeId, MastForestError> {
127 let join = JoinNode::new([left_child, right_child], self)?;
128 self.add_node(join)
129 }
130
131 pub fn add_split(
133 &mut self,
134 if_branch: MastNodeId,
135 else_branch: MastNodeId,
136 ) -> Result<MastNodeId, MastForestError> {
137 let split = SplitNode::new([if_branch, else_branch], self)?;
138 self.add_node(split)
139 }
140
141 pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
143 let loop_node = LoopNode::new(body, self)?;
144 self.add_node(loop_node)
145 }
146
147 pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
149 let call = CallNode::new(callee, self)?;
150 self.add_node(call)
151 }
152
153 pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
155 let syscall = CallNode::new_syscall(callee, self)?;
156 self.add_node(syscall)
157 }
158
159 pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
161 self.add_node(DynNode::new_dyn())
162 }
163
164 pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
166 self.add_node(DynNode::new_dyncall())
167 }
168
169 pub fn add_external(&mut self, mast_root: Word) -> Result<MastNodeId, MastForestError> {
171 self.add_node(ExternalNode::new(mast_root))
172 }
173
174 pub fn make_root(&mut self, new_root_id: MastNodeId) {
182 assert!((new_root_id.0 as usize) < self.nodes.len());
183
184 if !self.roots.contains(&new_root_id) {
185 self.roots.push(new_root_id);
186 }
187 }
188
189 pub fn remove_nodes(
197 &mut self,
198 nodes_to_remove: &BTreeSet<MastNodeId>,
199 ) -> BTreeMap<MastNodeId, MastNodeId> {
200 if nodes_to_remove.is_empty() {
201 return BTreeMap::new();
202 }
203
204 let old_nodes = mem::take(&mut self.nodes);
205 let old_root_ids = mem::take(&mut self.roots);
206 let (retained_nodes, id_remappings) = remove_nodes(old_nodes, nodes_to_remove);
207
208 self.remap_and_add_nodes(retained_nodes, &id_remappings);
209 self.remap_and_add_roots(old_root_ids, &id_remappings);
210 id_remappings
211 }
212
213 pub fn append_before_enter(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
214 self[node_id].append_before_enter(decorator_ids)
215 }
216
217 pub fn append_after_exit(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
218 self[node_id].append_after_exit(decorator_ids)
219 }
220
221 pub fn strip_decorators(&mut self) {
223 for node in self.nodes.iter_mut() {
224 node.remove_decorators();
225 }
226 self.decorators.truncate(0);
227 }
228
229 pub fn merge<'forest>(
279 forests: impl IntoIterator<Item = &'forest MastForest>,
280 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
281 MastForestMerger::merge(forests)
282 }
283
284 #[cfg(test)]
289 pub fn add_block_with_raw_decorators(
290 &mut self,
291 operations: Vec<Operation>,
292 decorators: Vec<(usize, Decorator)>,
293 ) -> Result<MastNodeId, MastForestError> {
294 let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, self)?;
295 self.add_node(block)
296 }
297}
298
299impl MastForest {
301 fn remap_and_add_nodes(
307 &mut self,
308 nodes_to_add: Vec<MastNode>,
309 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
310 ) {
311 assert!(self.nodes.is_empty());
312
313 for live_node in nodes_to_add {
316 match &live_node {
317 MastNode::Join(join_node) => {
318 let first_child =
319 id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
320 let second_child = id_remappings
321 .get(&join_node.second())
322 .copied()
323 .unwrap_or(join_node.second());
324
325 self.add_join(first_child, second_child).unwrap();
326 },
327 MastNode::Split(split_node) => {
328 let on_true_child = id_remappings
329 .get(&split_node.on_true())
330 .copied()
331 .unwrap_or(split_node.on_true());
332 let on_false_child = id_remappings
333 .get(&split_node.on_false())
334 .copied()
335 .unwrap_or(split_node.on_false());
336
337 self.add_split(on_true_child, on_false_child).unwrap();
338 },
339 MastNode::Loop(loop_node) => {
340 let body_id =
341 id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
342
343 self.add_loop(body_id).unwrap();
344 },
345 MastNode::Call(call_node) => {
346 let callee_id = id_remappings
347 .get(&call_node.callee())
348 .copied()
349 .unwrap_or(call_node.callee());
350
351 if call_node.is_syscall() {
352 self.add_syscall(callee_id).unwrap();
353 } else {
354 self.add_call(callee_id).unwrap();
355 }
356 },
357 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
358 self.add_node(live_node).unwrap();
359 },
360 }
361 }
362 }
363
364 fn remap_and_add_roots(
369 &mut self,
370 old_root_ids: Vec<MastNodeId>,
371 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
372 ) {
373 assert!(self.roots.is_empty());
374
375 for old_root_id in old_root_ids {
376 let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
377 self.make_root(new_root_id);
378 }
379 }
380}
381
382fn remove_nodes(
385 mast_nodes: Vec<MastNode>,
386 nodes_to_remove: &BTreeSet<MastNodeId>,
387) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
388 assert!(mast_nodes.len() < u32::MAX as usize);
390
391 let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
392 let mut id_remappings = BTreeMap::new();
393
394 for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
395 let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
396
397 if !nodes_to_remove.contains(&old_node_id) {
398 let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
399 id_remappings.insert(old_node_id, new_node_id);
400
401 retained_nodes.push(old_node);
402 }
403 }
404
405 (retained_nodes, id_remappings)
406}
407
408impl MastForest {
412 #[inline(always)]
417 pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
418 let idx = decorator_id.0 as usize;
419
420 self.decorators.get(idx)
421 }
422
423 #[inline(always)]
428 pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
429 let idx = node_id.0 as usize;
430
431 self.nodes.get(idx)
432 }
433
434 #[inline(always)]
436 pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
437 self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
438 }
439
440 pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
442 self.roots.contains(&node_id)
443 }
444
445 pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
447 self.roots.iter().map(|&root_id| self[root_id].digest())
448 }
449
450 pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
454 self.roots.iter().filter_map(|&root_id| {
455 let node = &self[root_id];
456 if node.is_external() { None } else { Some(node.digest()) }
457 })
458 }
459
460 pub fn procedure_roots(&self) -> &[MastNodeId] {
462 &self.roots
463 }
464
465 pub fn num_procedures(&self) -> u32 {
467 self.roots
468 .len()
469 .try_into()
470 .expect("MAST forest contains more than 2^32 procedures.")
471 }
472
473 pub fn compute_nodes_commitment<'a>(
478 &self,
479 node_ids: impl IntoIterator<Item = &'a MastNodeId>,
480 ) -> Word {
481 let mut digests: Vec<Word> = node_ids.into_iter().map(|&id| self[id].digest()).collect();
482 digests.sort_unstable_by_key(|word| LexicographicWord::from(*word));
483 miden_crypto::hash::rpo::Rpo256::merge_many(&digests)
484 }
485
486 pub fn num_nodes(&self) -> u32 {
488 self.nodes.len() as u32
489 }
490
491 pub fn nodes(&self) -> &[MastNode] {
493 &self.nodes
494 }
495
496 pub fn decorators(&self) -> &[Decorator] {
497 &self.decorators
498 }
499
500 pub fn advice_map(&self) -> &AdviceMap {
501 &self.advice_map
502 }
503
504 pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
505 &mut self.advice_map
506 }
507
508 pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
511 let code: Felt = error_code_from_msg(&msg);
512 self.error_codes.insert(code.as_int(), msg);
514 code
515 }
516
517 pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
519 let key = u64::from(code);
520 self.error_codes.get(&key).cloned()
521 }
522}
523
524impl Index<MastNodeId> for MastForest {
525 type Output = MastNode;
526
527 #[inline(always)]
528 fn index(&self, node_id: MastNodeId) -> &Self::Output {
529 let idx = node_id.0 as usize;
530
531 &self.nodes[idx]
532 }
533}
534
535impl IndexMut<MastNodeId> for MastForest {
536 #[inline(always)]
537 fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
538 let idx = node_id.0 as usize;
539
540 &mut self.nodes[idx]
541 }
542}
543
544impl Index<DecoratorId> for MastForest {
545 type Output = Decorator;
546
547 #[inline(always)]
548 fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
549 let idx = decorator_id.0 as usize;
550
551 &self.decorators[idx]
552 }
553}
554
555impl IndexMut<DecoratorId> for MastForest {
556 #[inline(always)]
557 fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
558 let idx = decorator_id.0 as usize;
559 &mut self.decorators[idx]
560 }
561}
562
/// Index of a node within a [`MastForest`]'s node list.
///
/// Bounds-checked construction is available through `from_u32_safe` and `from_usize_safe`.
/// With the `serde` feature, the ID is serialized transparently as its inner `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[cfg_attr(
    all(feature = "serde", feature = "arbitrary", test),
    miden_serde_test_macros::serde_test
)]
pub struct MastNodeId(u32);
580
/// Map from old node IDs to new node IDs.
///
/// It has the same shape as the map returned by `MastForest::remove_nodes` and is consumed by
/// `MastNodeId::remap`.
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
583
584impl MastNodeId {
585 pub fn from_u32_safe(
590 value: u32,
591 mast_forest: &MastForest,
592 ) -> Result<Self, DeserializationError> {
593 Self::from_u32_with_node_count(value, mast_forest.nodes.len())
594 }
595
596 pub fn from_usize_safe(
600 node_id: usize,
601 mast_forest: &MastForest,
602 ) -> Result<Self, DeserializationError> {
603 let node_id: u32 = node_id.try_into().map_err(|_| {
604 DeserializationError::InvalidValue(format!(
605 "node id '{node_id}' does not fit into a u32"
606 ))
607 })?;
608 MastNodeId::from_u32_safe(node_id, mast_forest)
609 }
610
611 pub(crate) fn new_unchecked(value: u32) -> Self {
613 Self(value)
614 }
615
616 pub(super) fn from_u32_with_node_count(
630 id: u32,
631 node_count: usize,
632 ) -> Result<Self, DeserializationError> {
633 if (id as usize) < node_count {
634 Ok(Self(id))
635 } else {
636 Err(DeserializationError::InvalidValue(format!(
637 "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
638 )))
639 }
640 }
641
642 pub fn as_usize(&self) -> usize {
643 self.0 as usize
644 }
645
646 pub fn as_u32(&self) -> u32 {
647 self.0
648 }
649
650 pub fn remap(&self, remapping: &Remapping) -> Self {
652 *remapping.get(self).unwrap_or(self)
653 }
654}
655
656impl From<MastNodeId> for usize {
657 fn from(value: MastNodeId) -> Self {
658 value.0 as usize
659 }
660}
661
662impl From<MastNodeId> for u32 {
663 fn from(value: MastNodeId) -> Self {
664 value.0
665 }
666}
667
668impl From<&MastNodeId> for u32 {
669 fn from(value: &MastNodeId) -> Self {
670 value.0
671 }
672}
673
674impl fmt::Display for MastNodeId {
675 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
676 write!(f, "MastNodeId({})", self.0)
677 }
678}
679
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for MastNodeId {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates IDs from arbitrary `u32` values. They are not bounded by any forest, so they
    /// may be out of range for a particular forest.
    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;
        any::<u32>().prop_map(MastNodeId::new_unchecked).boxed()
    }
}
691
/// Iterator over the IDs of every node reachable from a root in a [`MastForest`], including
/// the root itself.
///
/// Childless nodes are yielded as the walk reaches them. Once the walk is exhausted, the nodes
/// with children are yielded in reverse order of discovery, so each node comes after all of
/// its descendants. No visited set is kept, so a node reachable along several paths is yielded
/// once per path.
pub struct SubtreeIterator<'a> {
    forest: &'a MastForest,
    // Nodes with children that have already been expanded, in discovery order.
    discovered: Vec<MastNodeId>,
    // Stack of nodes still waiting to be examined.
    unvisited: Vec<MastNodeId>,
}
701impl<'a> SubtreeIterator<'a> {
702 pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
703 let discovered = vec![];
704 let unvisited = vec![*root];
705 SubtreeIterator { forest, discovered, unvisited }
706 }
707}
708impl Iterator for SubtreeIterator<'_> {
709 type Item = MastNodeId;
710 fn next(&mut self) -> Option<MastNodeId> {
711 while let Some(id) = self.unvisited.pop() {
712 let node = &self.forest[id];
713 if !node.has_children() {
714 return Some(id);
715 } else {
716 self.discovered.push(id);
717 node.append_children_to(&mut self.unvisited);
718 }
719 }
720 self.discovered.pop()
721 }
722}
723
/// Index of a decorator within a [`MastForest`]'s decorator list.
///
/// Bounds-checked construction is available through `from_u32_safe`. With the `serde` feature,
/// the ID is serialized transparently as its inner `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct DecoratorId(u32);
733
734impl DecoratorId {
735 pub fn from_u32_safe(
740 value: u32,
741 mast_forest: &MastForest,
742 ) -> Result<Self, DeserializationError> {
743 if (value as usize) < mast_forest.decorators.len() {
744 Ok(Self(value))
745 } else {
746 Err(DeserializationError::InvalidValue(format!(
747 "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
748 value,
749 mast_forest.nodes.len(),
750 )))
751 }
752 }
753
754 pub(crate) fn new_unchecked(value: u32) -> Self {
756 Self(value)
757 }
758
759 pub fn as_usize(&self) -> usize {
760 self.0 as usize
761 }
762
763 pub fn as_u32(&self) -> u32 {
764 self.0
765 }
766}
767
768impl From<DecoratorId> for usize {
769 fn from(value: DecoratorId) -> Self {
770 value.0 as usize
771 }
772}
773
774impl From<DecoratorId> for u32 {
775 fn from(value: DecoratorId) -> Self {
776 value.0
777 }
778}
779
780impl From<&DecoratorId> for u32 {
781 fn from(value: &DecoratorId) -> Self {
782 value.0
783 }
784}
785
786impl fmt::Display for DecoratorId {
787 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
788 write!(f, "DecoratorId({})", self.0)
789 }
790}
791
792impl Serializable for DecoratorId {
793 fn write_into<W: ByteWriter>(&self, target: &mut W) {
794 self.0.write_into(target)
795 }
796}
797
798pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
801 hash_string_to_word(msg.as_ref())[0]
803}
804
/// Errors produced while building or merging a [`MastForest`].
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
    /// Returned by `MastForest::add_decorator` when the forest already holds `u32::MAX`
    /// decorators.
    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
    TooManyDecorators,
    /// Returned by `MastForest::add_node` when the forest already holds `MAX_NODES` nodes.
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node ID (first field) is not less than the forest's node count (second field).
    /// Presumably raised by node constructors that validate child IDs.
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A decorator ID (first field) is not less than the forest's decorator count (second
    /// field).
    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
    DecoratorIdOverflow(DecoratorId, usize),
    /// A basic block was requested with no operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// The decorator root of the child with the given node ID is unavailable when a
    /// fingerprint is computed.
    #[error(
        "decorator root of child with node id {0} is missing but is required for fingerprint computation"
    )]
    ChildFingerprintMissing(MastNodeId),
    /// Two forests being merged both define the given advice-map key.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(Word),
}