1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 sync::Arc,
4 vec::Vec,
5};
6use core::{
7 fmt, mem,
8 ops::{Index, IndexMut},
9};
10mod node;
11pub use node::{
12 BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, MastNodeExt,
13 OP_BATCH_SIZE, OP_GROUP_SIZE, OpBatch, OperationOrDecorator, SplitNode,
14};
15
16use crate::{
17 AdviceMap, Decorator, DecoratorList, Felt, LexicographicWord, Operation, Word,
18 crypto::hash::{Blake3_256, Blake3Digest, Digest, Hasher},
19 utils::{ByteWriter, DeserializationError, Serializable},
20};
21
22mod serialization;
23
24mod merger;
25pub(crate) use merger::MastForestMerger;
26pub use merger::MastForestRootMap;
27
28mod multi_forest_node_iterator;
29pub(crate) use multi_forest_node_iterator::*;
30
31mod node_fingerprint;
32pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
33
34#[cfg(test)]
35mod tests;
36
/// A forest of MAST nodes, together with the procedure roots, decorators, advice data and
/// registered error messages that belong to it.
///
/// Nodes and decorators are stored in flat vectors and addressed by index through
/// [`MastNodeId`] and [`DecoratorId`] respectively (see the `Index`/`IndexMut` impls on this
/// type).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MastForest {
    /// All nodes of the forest; a node's [`MastNodeId`] is its position in this vector.
    nodes: Vec<MastNode>,

    /// IDs of the nodes marked as procedure roots, in the order they were marked.
    /// [`MastForest::make_root`] never inserts the same ID twice.
    roots: Vec<MastNodeId>,

    /// All decorators of the forest; a decorator's [`DecoratorId`] is its position in this vector.
    decorators: Vec<Decorator>,

    /// Advice data carried alongside the forest (exposed via `advice_map`/`advice_map_mut`).
    advice_map: AdviceMap,

    /// Error messages registered via [`MastForest::register_error`], keyed by the integer value
    /// of the error code computed for them.
    error_codes: BTreeMap<u64, Arc<str>>,
}
63
64impl MastForest {
67 pub fn new() -> Self {
69 Self::default()
70 }
71}
72
73impl MastForest {
76 const MAX_NODES: usize = (1 << 30) - 1;
78
79 pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
81 if self.decorators.len() >= u32::MAX as usize {
82 return Err(MastForestError::TooManyDecorators);
83 }
84
85 let new_decorator_id = DecoratorId(self.decorators.len() as u32);
86 self.decorators.push(decorator);
87
88 Ok(new_decorator_id)
89 }
90
91 pub fn add_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
95 if self.nodes.len() == Self::MAX_NODES {
96 return Err(MastForestError::TooManyNodes);
97 }
98
99 let new_node_id = MastNodeId(self.nodes.len() as u32);
100 self.nodes.push(node);
101
102 Ok(new_node_id)
103 }
104
105 pub fn add_block(
107 &mut self,
108 operations: Vec<Operation>,
109 decorators: Option<DecoratorList>,
110 ) -> Result<MastNodeId, MastForestError> {
111 let block = MastNode::new_basic_block(operations, decorators)?;
112 self.add_node(block)
113 }
114
115 pub fn add_join(
117 &mut self,
118 left_child: MastNodeId,
119 right_child: MastNodeId,
120 ) -> Result<MastNodeId, MastForestError> {
121 let join = MastNode::new_join(left_child, right_child, self)?;
122 self.add_node(join)
123 }
124
125 pub fn add_split(
127 &mut self,
128 if_branch: MastNodeId,
129 else_branch: MastNodeId,
130 ) -> Result<MastNodeId, MastForestError> {
131 let split = MastNode::new_split(if_branch, else_branch, self)?;
132 self.add_node(split)
133 }
134
135 pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
137 let loop_node = MastNode::new_loop(body, self)?;
138 self.add_node(loop_node)
139 }
140
141 pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
143 let call = MastNode::new_call(callee, self)?;
144 self.add_node(call)
145 }
146
147 pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
149 let syscall = MastNode::new_syscall(callee, self)?;
150 self.add_node(syscall)
151 }
152
153 pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
155 self.add_node(MastNode::new_dyn())
156 }
157
158 pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
160 self.add_node(MastNode::new_dyncall())
161 }
162
163 pub fn add_external(&mut self, mast_root: Word) -> Result<MastNodeId, MastForestError> {
165 self.add_node(MastNode::new_external(mast_root))
166 }
167
168 pub fn make_root(&mut self, new_root_id: MastNodeId) {
176 assert!((new_root_id.0 as usize) < self.nodes.len());
177
178 if !self.roots.contains(&new_root_id) {
179 self.roots.push(new_root_id);
180 }
181 }
182
183 pub fn remove_nodes(
191 &mut self,
192 nodes_to_remove: &BTreeSet<MastNodeId>,
193 ) -> BTreeMap<MastNodeId, MastNodeId> {
194 if nodes_to_remove.is_empty() {
195 return BTreeMap::new();
196 }
197
198 let old_nodes = mem::take(&mut self.nodes);
199 let old_root_ids = mem::take(&mut self.roots);
200 let (retained_nodes, id_remappings) = remove_nodes(old_nodes, nodes_to_remove);
201
202 self.remap_and_add_nodes(retained_nodes, &id_remappings);
203 self.remap_and_add_roots(old_root_ids, &id_remappings);
204 id_remappings
205 }
206
207 pub fn append_before_enter(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
208 self[node_id].append_before_enter(decorator_ids)
209 }
210
211 pub fn append_after_exit(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
212 self[node_id].append_after_exit(decorator_ids)
213 }
214
215 pub fn merge<'forest>(
265 forests: impl IntoIterator<Item = &'forest MastForest>,
266 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
267 MastForestMerger::merge(forests)
268 }
269
270 #[cfg(test)]
275 pub fn add_block_with_raw_decorators(
276 &mut self,
277 operations: Vec<Operation>,
278 decorators: Vec<(usize, Decorator)>,
279 ) -> Result<MastNodeId, MastForestError> {
280 let block = MastNode::new_basic_block_with_raw_decorators(operations, decorators, self)?;
281 self.add_node(block)
282 }
283}
284
285impl MastForest {
287 fn remap_and_add_nodes(
293 &mut self,
294 nodes_to_add: Vec<MastNode>,
295 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
296 ) {
297 assert!(self.nodes.is_empty());
298
299 for live_node in nodes_to_add {
302 match &live_node {
303 MastNode::Join(join_node) => {
304 let first_child =
305 id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
306 let second_child = id_remappings
307 .get(&join_node.second())
308 .copied()
309 .unwrap_or(join_node.second());
310
311 self.add_join(first_child, second_child).unwrap();
312 },
313 MastNode::Split(split_node) => {
314 let on_true_child = id_remappings
315 .get(&split_node.on_true())
316 .copied()
317 .unwrap_or(split_node.on_true());
318 let on_false_child = id_remappings
319 .get(&split_node.on_false())
320 .copied()
321 .unwrap_or(split_node.on_false());
322
323 self.add_split(on_true_child, on_false_child).unwrap();
324 },
325 MastNode::Loop(loop_node) => {
326 let body_id =
327 id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
328
329 self.add_loop(body_id).unwrap();
330 },
331 MastNode::Call(call_node) => {
332 let callee_id = id_remappings
333 .get(&call_node.callee())
334 .copied()
335 .unwrap_or(call_node.callee());
336
337 if call_node.is_syscall() {
338 self.add_syscall(callee_id).unwrap();
339 } else {
340 self.add_call(callee_id).unwrap();
341 }
342 },
343 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
344 self.add_node(live_node).unwrap();
345 },
346 }
347 }
348 }
349
350 fn remap_and_add_roots(
355 &mut self,
356 old_root_ids: Vec<MastNodeId>,
357 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
358 ) {
359 assert!(self.roots.is_empty());
360
361 for old_root_id in old_root_ids {
362 let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
363 self.make_root(new_root_id);
364 }
365 }
366}
367
368fn remove_nodes(
371 mast_nodes: Vec<MastNode>,
372 nodes_to_remove: &BTreeSet<MastNodeId>,
373) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
374 assert!(mast_nodes.len() < u32::MAX as usize);
376
377 let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
378 let mut id_remappings = BTreeMap::new();
379
380 for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
381 let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
382
383 if !nodes_to_remove.contains(&old_node_id) {
384 let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
385 id_remappings.insert(old_node_id, new_node_id);
386
387 retained_nodes.push(old_node);
388 }
389 }
390
391 (retained_nodes, id_remappings)
392}
393
394impl MastForest {
398 #[inline(always)]
403 pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
404 let idx = decorator_id.0 as usize;
405
406 self.decorators.get(idx)
407 }
408
409 #[inline(always)]
414 pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
415 let idx = node_id.0 as usize;
416
417 self.nodes.get(idx)
418 }
419
420 #[inline(always)]
422 pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
423 self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
424 }
425
426 pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
428 self.roots.contains(&node_id)
429 }
430
431 pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
433 self.roots.iter().map(|&root_id| self[root_id].digest())
434 }
435
436 pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
440 self.roots.iter().filter_map(|&root_id| {
441 let node = &self[root_id];
442 if node.is_external() { None } else { Some(node.digest()) }
443 })
444 }
445
446 pub fn procedure_roots(&self) -> &[MastNodeId] {
448 &self.roots
449 }
450
451 pub fn num_procedures(&self) -> u32 {
453 self.roots
454 .len()
455 .try_into()
456 .expect("MAST forest contains more than 2^32 procedures.")
457 }
458
459 pub fn compute_nodes_commitment<'a>(
464 &self,
465 node_ids: impl IntoIterator<Item = &'a MastNodeId>,
466 ) -> Word {
467 let mut digests: Vec<Word> = node_ids.into_iter().map(|&id| self[id].digest()).collect();
468 digests.sort_unstable_by_key(|word| LexicographicWord::from(*word));
469 miden_crypto::hash::rpo::Rpo256::merge_many(&digests)
470 }
471
472 pub fn num_nodes(&self) -> u32 {
474 self.nodes.len() as u32
475 }
476
477 pub fn nodes(&self) -> &[MastNode] {
479 &self.nodes
480 }
481
482 pub fn decorators(&self) -> &[Decorator] {
483 &self.decorators
484 }
485
486 pub fn advice_map(&self) -> &AdviceMap {
487 &self.advice_map
488 }
489
490 pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
491 &mut self.advice_map
492 }
493
494 pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
497 let code: Felt = error_code_from_msg(&msg);
498 self.error_codes.insert(code.as_int(), msg);
500 code
501 }
502
503 pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
505 let key = u64::from(code);
506 self.error_codes.get(&key).cloned()
507 }
508}
509
510impl Index<MastNodeId> for MastForest {
511 type Output = MastNode;
512
513 #[inline(always)]
514 fn index(&self, node_id: MastNodeId) -> &Self::Output {
515 let idx = node_id.0 as usize;
516
517 &self.nodes[idx]
518 }
519}
520
521impl IndexMut<MastNodeId> for MastForest {
522 #[inline(always)]
523 fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
524 let idx = node_id.0 as usize;
525
526 &mut self.nodes[idx]
527 }
528}
529
530impl Index<DecoratorId> for MastForest {
531 type Output = Decorator;
532
533 #[inline(always)]
534 fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
535 let idx = decorator_id.0 as usize;
536
537 &self.decorators[idx]
538 }
539}
540
541impl IndexMut<DecoratorId> for MastForest {
542 #[inline(always)]
543 fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
544 let idx = decorator_id.0 as usize;
545 &mut self.decorators[idx]
546 }
547}
548
/// Identifies a node within a [`MastForest`].
///
/// The wrapped `u32` is the node's index into the forest's node list, so the same value may
/// refer to a different node (or to none) in another forest. Bounds-checked IDs can be obtained
/// via [`MastNodeId::from_u32_safe`] or [`MastNodeId::from_usize_safe`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MastNodeId(u32);
560
/// A mapping from old [`MastNodeId`]s to new ones, applied by [`MastNodeId::remap`]; IDs that
/// are absent from the map are left unchanged.
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
563
564impl MastNodeId {
565 pub fn from_u32_safe(
570 value: u32,
571 mast_forest: &MastForest,
572 ) -> Result<Self, DeserializationError> {
573 Self::from_u32_with_node_count(value, mast_forest.nodes.len())
574 }
575
576 pub fn from_usize_safe(
580 node_id: usize,
581 mast_forest: &MastForest,
582 ) -> Result<Self, DeserializationError> {
583 let node_id: u32 = node_id.try_into().map_err(|_| {
584 DeserializationError::InvalidValue(format!(
585 "node id '{node_id}' does not fit into a u32"
586 ))
587 })?;
588 MastNodeId::from_u32_safe(node_id, mast_forest)
589 }
590
591 pub(crate) fn new_unchecked(value: u32) -> Self {
593 Self(value)
594 }
595
596 pub(super) fn from_u32_with_node_count(
610 id: u32,
611 node_count: usize,
612 ) -> Result<Self, DeserializationError> {
613 if (id as usize) < node_count {
614 Ok(Self(id))
615 } else {
616 Err(DeserializationError::InvalidValue(format!(
617 "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
618 )))
619 }
620 }
621
622 pub fn as_usize(&self) -> usize {
623 self.0 as usize
624 }
625
626 pub fn as_u32(&self) -> u32 {
627 self.0
628 }
629
630 pub fn remap(&self, remapping: &Remapping) -> Self {
632 *remapping.get(self).unwrap_or(self)
633 }
634}
635
636impl From<MastNodeId> for usize {
637 fn from(value: MastNodeId) -> Self {
638 value.0 as usize
639 }
640}
641
642impl From<MastNodeId> for u32 {
643 fn from(value: MastNodeId) -> Self {
644 value.0
645 }
646}
647
648impl From<&MastNodeId> for u32 {
649 fn from(value: &MastNodeId) -> Self {
650 value.0
651 }
652}
653
654impl fmt::Display for MastNodeId {
655 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
656 write!(f, "MastNodeId({})", self.0)
657 }
658}
659
/// Iterator over the IDs of the nodes in the subtree rooted at a given node.
///
/// Each node is yielded only after every node below it, so the root comes last. No visited set
/// is kept, so a node reachable along several paths is yielded once per path.
pub struct SubtreeIterator<'a> {
    /// The forest the subtree lives in.
    forest: &'a MastForest,
    /// Stack of nodes with children whose children have been queued but which have not been
    /// yielded yet.
    discovered: Vec<MastNodeId>,
    /// Stack of nodes that still have to be examined.
    unvisited: Vec<MastNodeId>,
}
669impl<'a> SubtreeIterator<'a> {
670 pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
671 let discovered = vec![];
672 let unvisited = vec![*root];
673 SubtreeIterator { forest, discovered, unvisited }
674 }
675}
676impl Iterator for SubtreeIterator<'_> {
677 type Item = MastNodeId;
678 fn next(&mut self) -> Option<MastNodeId> {
679 while let Some(id) = self.unvisited.pop() {
680 let node = &self.forest[id];
681 if !node.has_children() {
682 return Some(id);
683 } else {
684 self.discovered.push(id);
685 node.append_children_to(&mut self.unvisited);
686 }
687 }
688 self.discovered.pop()
689 }
690}
691
/// Identifies a [`Decorator`] within a [`MastForest`].
///
/// The wrapped `u32` is the decorator's index into the forest's decorator list. A bounds-checked
/// ID can be obtained via [`DecoratorId::from_u32_safe`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DecoratorId(u32);
699
700impl DecoratorId {
701 pub fn from_u32_safe(
706 value: u32,
707 mast_forest: &MastForest,
708 ) -> Result<Self, DeserializationError> {
709 if (value as usize) < mast_forest.decorators.len() {
710 Ok(Self(value))
711 } else {
712 Err(DeserializationError::InvalidValue(format!(
713 "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
714 value,
715 mast_forest.nodes.len(),
716 )))
717 }
718 }
719
720 pub(crate) fn new_unchecked(value: u32) -> Self {
722 Self(value)
723 }
724
725 pub fn as_usize(&self) -> usize {
726 self.0 as usize
727 }
728
729 pub fn as_u32(&self) -> u32 {
730 self.0
731 }
732}
733
734impl From<DecoratorId> for usize {
735 fn from(value: DecoratorId) -> Self {
736 value.0 as usize
737 }
738}
739
740impl From<DecoratorId> for u32 {
741 fn from(value: DecoratorId) -> Self {
742 value.0
743 }
744}
745
746impl From<&DecoratorId> for u32 {
747 fn from(value: &DecoratorId) -> Self {
748 value.0
749 }
750}
751
752impl fmt::Display for DecoratorId {
753 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
754 write!(f, "DecoratorId({})", self.0)
755 }
756}
757
758impl Serializable for DecoratorId {
759 fn write_into<W: ByteWriter>(&self, target: &mut W) {
760 self.0.write_into(target)
761 }
762}
763
764pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
767 let digest: Blake3Digest<32> = Blake3_256::hash(msg.as_ref().as_bytes());
768 let mut digest_bytes: [u8; 8] = [0; 8];
769 digest_bytes.copy_from_slice(&digest.as_bytes()[0..8]);
770 let code = u64::from_le_bytes(digest_bytes);
771 Felt::new(code)
772}
773
/// Errors reported while building or merging a [`MastForest`].
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
    /// The forest already holds the maximum number of decorators (`u32::MAX`).
    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
    TooManyDecorators,
    /// The forest already holds the maximum number of nodes (`MastForest::MAX_NODES`).
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node ID is not smaller than the number of nodes in the forest: `(id, node count)`.
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A decorator ID is not smaller than the number of decorators: `(id, decorator count)`.
    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
    DecoratorIdOverflow(DecoratorId, usize),
    /// A basic block was requested with an empty list of operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// The fingerprint data of the child with the given ID was needed but is missing.
    #[error(
        "decorator root of child with node id {0} is missing but is required for fingerprint computation"
    )]
    ChildFingerprintMissing(MastNodeId),
    /// An advice map key collided while merging forests.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(Word),
}