1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 sync::Arc,
4 vec::Vec,
5};
6use core::{
7 fmt, mem,
8 ops::{Index, IndexMut},
9};
10mod node;
11pub use node::{
12 BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, MastNodeExt,
13 OP_BATCH_SIZE, OP_GROUP_SIZE, OpBatch, OperationOrDecorator, SplitNode,
14};
15
16use crate::{
17 AdviceMap, Decorator, DecoratorList, Felt, LexicographicWord, Operation, Word,
18 crypto::hash::{Blake3_256, Blake3Digest, Digest, Hasher},
19 utils::{ByteWriter, DeserializationError, Serializable},
20};
21
22mod serialization;
23
24mod merger;
25pub(crate) use merger::MastForestMerger;
26pub use merger::MastForestRootMap;
27
28mod multi_forest_node_iterator;
29pub(crate) use multi_forest_node_iterator::*;
30
31mod node_fingerprint;
32pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
33
34#[cfg(test)]
35mod tests;
36
/// A forest of MAST nodes together with the data those nodes share.
///
/// Nodes live in a single vector and are addressed by [`MastNodeId`], which is an index into
/// that vector. A subset of the nodes is designated as procedure roots.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MastForest {
    /// All nodes of the forest; a [`MastNodeId`] is an index into this vector.
    nodes: Vec<MastNode>,

    /// Ids of the nodes that are procedure roots (`make_root` skips ids already present).
    roots: Vec<MastNodeId>,

    /// All decorators of the forest; a [`DecoratorId`] is an index into this vector.
    decorators: Vec<Decorator>,

    /// Advice map carried by this forest.
    advice_map: AdviceMap,

    /// Error messages registered via `register_error`, keyed by the `u64` value of the error
    /// code derived from each message.
    error_codes: BTreeMap<u64, Arc<str>>,
}
63
64impl MastForest {
67 pub fn new() -> Self {
69 Self::default()
70 }
71}
72
73impl MastForest {
76 const MAX_NODES: usize = (1 << 30) - 1;
78
79 pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
81 if self.decorators.len() >= u32::MAX as usize {
82 return Err(MastForestError::TooManyDecorators);
83 }
84
85 let new_decorator_id = DecoratorId(self.decorators.len() as u32);
86 self.decorators.push(decorator);
87
88 Ok(new_decorator_id)
89 }
90
91 pub fn add_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
95 if self.nodes.len() == Self::MAX_NODES {
96 return Err(MastForestError::TooManyNodes);
97 }
98
99 let new_node_id = MastNodeId(self.nodes.len() as u32);
100 self.nodes.push(node);
101
102 Ok(new_node_id)
103 }
104
105 pub fn add_block(
107 &mut self,
108 operations: Vec<Operation>,
109 decorators: Option<DecoratorList>,
110 ) -> Result<MastNodeId, MastForestError> {
111 let block = MastNode::new_basic_block(operations, decorators)?;
112 self.add_node(block)
113 }
114
115 pub fn add_join(
117 &mut self,
118 left_child: MastNodeId,
119 right_child: MastNodeId,
120 ) -> Result<MastNodeId, MastForestError> {
121 let join = MastNode::new_join(left_child, right_child, self)?;
122 self.add_node(join)
123 }
124
125 pub fn add_split(
127 &mut self,
128 if_branch: MastNodeId,
129 else_branch: MastNodeId,
130 ) -> Result<MastNodeId, MastForestError> {
131 let split = MastNode::new_split(if_branch, else_branch, self)?;
132 self.add_node(split)
133 }
134
135 pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
137 let loop_node = MastNode::new_loop(body, self)?;
138 self.add_node(loop_node)
139 }
140
141 pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
143 let call = MastNode::new_call(callee, self)?;
144 self.add_node(call)
145 }
146
147 pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
149 let syscall = MastNode::new_syscall(callee, self)?;
150 self.add_node(syscall)
151 }
152
153 pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
155 self.add_node(MastNode::new_dyn())
156 }
157
158 pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
160 self.add_node(MastNode::new_dyncall())
161 }
162
163 pub fn add_external(&mut self, mast_root: Word) -> Result<MastNodeId, MastForestError> {
165 self.add_node(MastNode::new_external(mast_root))
166 }
167
168 pub fn make_root(&mut self, new_root_id: MastNodeId) {
176 assert!((new_root_id.0 as usize) < self.nodes.len());
177
178 if !self.roots.contains(&new_root_id) {
179 self.roots.push(new_root_id);
180 }
181 }
182
183 pub fn remove_nodes(
191 &mut self,
192 nodes_to_remove: &BTreeSet<MastNodeId>,
193 ) -> BTreeMap<MastNodeId, MastNodeId> {
194 if nodes_to_remove.is_empty() {
195 return BTreeMap::new();
196 }
197
198 let old_nodes = mem::take(&mut self.nodes);
199 let old_root_ids = mem::take(&mut self.roots);
200 let (retained_nodes, id_remappings) = remove_nodes(old_nodes, nodes_to_remove);
201
202 self.remap_and_add_nodes(retained_nodes, &id_remappings);
203 self.remap_and_add_roots(old_root_ids, &id_remappings);
204 id_remappings
205 }
206
207 pub fn append_before_enter(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
208 self[node_id].append_before_enter(decorator_ids)
209 }
210
211 pub fn append_after_exit(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
212 self[node_id].append_after_exit(decorator_ids)
213 }
214
215 pub fn strip_decorators(&mut self) {
217 for node in self.nodes.iter_mut() {
218 node.remove_decorators();
219 }
220 self.decorators.truncate(0);
221 }
222
223 pub fn merge<'forest>(
273 forests: impl IntoIterator<Item = &'forest MastForest>,
274 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
275 MastForestMerger::merge(forests)
276 }
277
278 #[cfg(test)]
283 pub fn add_block_with_raw_decorators(
284 &mut self,
285 operations: Vec<Operation>,
286 decorators: Vec<(usize, Decorator)>,
287 ) -> Result<MastNodeId, MastForestError> {
288 let block = MastNode::new_basic_block_with_raw_decorators(operations, decorators, self)?;
289 self.add_node(block)
290 }
291}
292
293impl MastForest {
295 fn remap_and_add_nodes(
301 &mut self,
302 nodes_to_add: Vec<MastNode>,
303 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
304 ) {
305 assert!(self.nodes.is_empty());
306
307 for live_node in nodes_to_add {
310 match &live_node {
311 MastNode::Join(join_node) => {
312 let first_child =
313 id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
314 let second_child = id_remappings
315 .get(&join_node.second())
316 .copied()
317 .unwrap_or(join_node.second());
318
319 self.add_join(first_child, second_child).unwrap();
320 },
321 MastNode::Split(split_node) => {
322 let on_true_child = id_remappings
323 .get(&split_node.on_true())
324 .copied()
325 .unwrap_or(split_node.on_true());
326 let on_false_child = id_remappings
327 .get(&split_node.on_false())
328 .copied()
329 .unwrap_or(split_node.on_false());
330
331 self.add_split(on_true_child, on_false_child).unwrap();
332 },
333 MastNode::Loop(loop_node) => {
334 let body_id =
335 id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
336
337 self.add_loop(body_id).unwrap();
338 },
339 MastNode::Call(call_node) => {
340 let callee_id = id_remappings
341 .get(&call_node.callee())
342 .copied()
343 .unwrap_or(call_node.callee());
344
345 if call_node.is_syscall() {
346 self.add_syscall(callee_id).unwrap();
347 } else {
348 self.add_call(callee_id).unwrap();
349 }
350 },
351 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
352 self.add_node(live_node).unwrap();
353 },
354 }
355 }
356 }
357
358 fn remap_and_add_roots(
363 &mut self,
364 old_root_ids: Vec<MastNodeId>,
365 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
366 ) {
367 assert!(self.roots.is_empty());
368
369 for old_root_id in old_root_ids {
370 let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
371 self.make_root(new_root_id);
372 }
373 }
374}
375
376fn remove_nodes(
379 mast_nodes: Vec<MastNode>,
380 nodes_to_remove: &BTreeSet<MastNodeId>,
381) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
382 assert!(mast_nodes.len() < u32::MAX as usize);
384
385 let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
386 let mut id_remappings = BTreeMap::new();
387
388 for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
389 let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
390
391 if !nodes_to_remove.contains(&old_node_id) {
392 let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
393 id_remappings.insert(old_node_id, new_node_id);
394
395 retained_nodes.push(old_node);
396 }
397 }
398
399 (retained_nodes, id_remappings)
400}
401
402impl MastForest {
406 #[inline(always)]
411 pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
412 let idx = decorator_id.0 as usize;
413
414 self.decorators.get(idx)
415 }
416
417 #[inline(always)]
422 pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
423 let idx = node_id.0 as usize;
424
425 self.nodes.get(idx)
426 }
427
428 #[inline(always)]
430 pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
431 self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
432 }
433
434 pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
436 self.roots.contains(&node_id)
437 }
438
439 pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
441 self.roots.iter().map(|&root_id| self[root_id].digest())
442 }
443
444 pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
448 self.roots.iter().filter_map(|&root_id| {
449 let node = &self[root_id];
450 if node.is_external() { None } else { Some(node.digest()) }
451 })
452 }
453
454 pub fn procedure_roots(&self) -> &[MastNodeId] {
456 &self.roots
457 }
458
459 pub fn num_procedures(&self) -> u32 {
461 self.roots
462 .len()
463 .try_into()
464 .expect("MAST forest contains more than 2^32 procedures.")
465 }
466
467 pub fn compute_nodes_commitment<'a>(
472 &self,
473 node_ids: impl IntoIterator<Item = &'a MastNodeId>,
474 ) -> Word {
475 let mut digests: Vec<Word> = node_ids.into_iter().map(|&id| self[id].digest()).collect();
476 digests.sort_unstable_by_key(|word| LexicographicWord::from(*word));
477 miden_crypto::hash::rpo::Rpo256::merge_many(&digests)
478 }
479
480 pub fn num_nodes(&self) -> u32 {
482 self.nodes.len() as u32
483 }
484
485 pub fn nodes(&self) -> &[MastNode] {
487 &self.nodes
488 }
489
490 pub fn decorators(&self) -> &[Decorator] {
491 &self.decorators
492 }
493
494 pub fn advice_map(&self) -> &AdviceMap {
495 &self.advice_map
496 }
497
498 pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
499 &mut self.advice_map
500 }
501
502 pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
505 let code: Felt = error_code_from_msg(&msg);
506 self.error_codes.insert(code.as_int(), msg);
508 code
509 }
510
511 pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
513 let key = u64::from(code);
514 self.error_codes.get(&key).cloned()
515 }
516}
517
518impl Index<MastNodeId> for MastForest {
519 type Output = MastNode;
520
521 #[inline(always)]
522 fn index(&self, node_id: MastNodeId) -> &Self::Output {
523 let idx = node_id.0 as usize;
524
525 &self.nodes[idx]
526 }
527}
528
529impl IndexMut<MastNodeId> for MastForest {
530 #[inline(always)]
531 fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
532 let idx = node_id.0 as usize;
533
534 &mut self.nodes[idx]
535 }
536}
537
538impl Index<DecoratorId> for MastForest {
539 type Output = Decorator;
540
541 #[inline(always)]
542 fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
543 let idx = decorator_id.0 as usize;
544
545 &self.decorators[idx]
546 }
547}
548
549impl IndexMut<DecoratorId> for MastForest {
550 #[inline(always)]
551 fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
552 let idx = decorator_id.0 as usize;
553 &mut self.decorators[idx]
554 }
555}
556
/// Identifier of a node in a [`MastForest`]: the node's index in the forest's node list.
///
/// An id is only meaningful for the forest it was created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MastNodeId(u32);
568
/// A mapping from old node ids to new node ids, as applied by [`MastNodeId::remap`].
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
571
572impl MastNodeId {
573 pub fn from_u32_safe(
578 value: u32,
579 mast_forest: &MastForest,
580 ) -> Result<Self, DeserializationError> {
581 Self::from_u32_with_node_count(value, mast_forest.nodes.len())
582 }
583
584 pub fn from_usize_safe(
588 node_id: usize,
589 mast_forest: &MastForest,
590 ) -> Result<Self, DeserializationError> {
591 let node_id: u32 = node_id.try_into().map_err(|_| {
592 DeserializationError::InvalidValue(format!(
593 "node id '{node_id}' does not fit into a u32"
594 ))
595 })?;
596 MastNodeId::from_u32_safe(node_id, mast_forest)
597 }
598
599 pub(crate) fn new_unchecked(value: u32) -> Self {
601 Self(value)
602 }
603
604 pub(super) fn from_u32_with_node_count(
618 id: u32,
619 node_count: usize,
620 ) -> Result<Self, DeserializationError> {
621 if (id as usize) < node_count {
622 Ok(Self(id))
623 } else {
624 Err(DeserializationError::InvalidValue(format!(
625 "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
626 )))
627 }
628 }
629
630 pub fn as_usize(&self) -> usize {
631 self.0 as usize
632 }
633
634 pub fn as_u32(&self) -> u32 {
635 self.0
636 }
637
638 pub fn remap(&self, remapping: &Remapping) -> Self {
640 *remapping.get(self).unwrap_or(self)
641 }
642}
643
644impl From<MastNodeId> for usize {
645 fn from(value: MastNodeId) -> Self {
646 value.0 as usize
647 }
648}
649
650impl From<MastNodeId> for u32 {
651 fn from(value: MastNodeId) -> Self {
652 value.0
653 }
654}
655
656impl From<&MastNodeId> for u32 {
657 fn from(value: &MastNodeId) -> Self {
658 value.0
659 }
660}
661
662impl fmt::Display for MastNodeId {
663 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
664 write!(f, "MastNodeId({})", self.0)
665 }
666}
667
/// Iterator over the ids of the nodes in the subtree rooted at a given node.
///
/// Leaves are yielded as they are reached. Inner nodes are yielded afterwards, in reverse order
/// of discovery, so a node is never yielded before the descendants reached through it. There is
/// no visited set, so a node reachable along several paths may be yielded more than once.
pub struct SubtreeIterator<'a> {
    /// Forest the subtree belongs to.
    forest: &'a MastForest,
    /// Inner nodes whose children have already been queued; popped once `unvisited` is empty.
    discovered: Vec<MastNodeId>,
    /// Nodes still to be examined, used as a stack.
    unvisited: Vec<MastNodeId>,
}
677impl<'a> SubtreeIterator<'a> {
678 pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
679 let discovered = vec![];
680 let unvisited = vec![*root];
681 SubtreeIterator { forest, discovered, unvisited }
682 }
683}
684impl Iterator for SubtreeIterator<'_> {
685 type Item = MastNodeId;
686 fn next(&mut self) -> Option<MastNodeId> {
687 while let Some(id) = self.unvisited.pop() {
688 let node = &self.forest[id];
689 if !node.has_children() {
690 return Some(id);
691 } else {
692 self.discovered.push(id);
693 node.append_children_to(&mut self.unvisited);
694 }
695 }
696 self.discovered.pop()
697 }
698}
699
/// Identifier of a decorator in a [`MastForest`]: the decorator's index in the forest's
/// decorator list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DecoratorId(u32);
707
708impl DecoratorId {
709 pub fn from_u32_safe(
714 value: u32,
715 mast_forest: &MastForest,
716 ) -> Result<Self, DeserializationError> {
717 if (value as usize) < mast_forest.decorators.len() {
718 Ok(Self(value))
719 } else {
720 Err(DeserializationError::InvalidValue(format!(
721 "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
722 value,
723 mast_forest.nodes.len(),
724 )))
725 }
726 }
727
728 pub(crate) fn new_unchecked(value: u32) -> Self {
730 Self(value)
731 }
732
733 pub fn as_usize(&self) -> usize {
734 self.0 as usize
735 }
736
737 pub fn as_u32(&self) -> u32 {
738 self.0
739 }
740}
741
742impl From<DecoratorId> for usize {
743 fn from(value: DecoratorId) -> Self {
744 value.0 as usize
745 }
746}
747
748impl From<DecoratorId> for u32 {
749 fn from(value: DecoratorId) -> Self {
750 value.0
751 }
752}
753
754impl From<&DecoratorId> for u32 {
755 fn from(value: &DecoratorId) -> Self {
756 value.0
757 }
758}
759
760impl fmt::Display for DecoratorId {
761 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
762 write!(f, "DecoratorId({})", self.0)
763 }
764}
765
766impl Serializable for DecoratorId {
767 fn write_into<W: ByteWriter>(&self, target: &mut W) {
768 self.0.write_into(target)
769 }
770}
771
772pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
775 let digest: Blake3Digest<32> = Blake3_256::hash(msg.as_ref().as_bytes());
776 let mut digest_bytes: [u8; 8] = [0; 8];
777 digest_bytes.copy_from_slice(&digest.as_bytes()[0..8]);
778 let code = u64::from_le_bytes(digest_bytes);
779 Felt::new(code)
780}
781
/// Errors produced while building, validating or merging a [`MastForest`].
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
    /// The forest already holds the maximum number of decorators (ids are `u32`).
    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
    TooManyDecorators,
    /// The forest already holds `MastForest::MAX_NODES` nodes.
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node id was not smaller than the forest's node count (the `usize` field).
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A decorator id was not smaller than the forest's decorator count (the `usize` field).
    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
    DecoratorIdOverflow(DecoratorId, usize),
    /// A basic block was requested with no operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// The decorator root of the child with the given id was needed to compute a fingerprint
    /// but was not available.
    #[error(
        "decorator root of child with node id {0} is missing but is required for fingerprint computation"
    )]
    ChildFingerprintMissing(MastNodeId),
    /// Two forests being merged both define the given advice-map key.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(Word),
}
802 #[error("advice map key {0} already exists when merging forests")]
803 AdviceMapKeyCollisionOnMerge(Word),
804}