1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 sync::Arc,
4 vec::Vec,
5};
6use core::{
7 fmt, mem,
8 ops::{Index, IndexMut},
9};
10
11use miden_crypto::hash::rpo::RpoDigest;
12
13use crate::crypto::hash::{Blake3_256, Blake3Digest, Digest};
14
15mod node;
16pub use node::{
17 BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, MastNodeExt,
18 OP_BATCH_SIZE, OP_GROUP_SIZE, OpBatch, OperationOrDecorator, SplitNode,
19};
20use winter_utils::{ByteWriter, DeserializationError, Serializable};
21
22use crate::{AdviceMap, Decorator, DecoratorList, Felt, Operation};
23
24mod serialization;
25
26mod merger;
27pub(crate) use merger::MastForestMerger;
28pub use merger::MastForestRootMap;
29
30mod multi_forest_node_iterator;
31pub(crate) use multi_forest_node_iterator::*;
32
33mod node_fingerprint;
34pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
35
36#[cfg(test)]
37mod tests;
38
39#[derive(Clone, Debug, Default, PartialEq, Eq)]
47pub struct MastForest {
48 nodes: Vec<MastNode>,
50
51 roots: Vec<MastNodeId>,
53
54 decorators: Vec<Decorator>,
56
57 advice_map: AdviceMap,
59
60 error_codes: BTreeMap<u64, Arc<str>>,
64}
65
66impl MastForest {
69 pub fn new() -> Self {
71 Self::default()
72 }
73}
74
75impl MastForest {
78 const MAX_NODES: usize = (1 << 30) - 1;
80
81 pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
83 if self.decorators.len() >= u32::MAX as usize {
84 return Err(MastForestError::TooManyDecorators);
85 }
86
87 let new_decorator_id = DecoratorId(self.decorators.len() as u32);
88 self.decorators.push(decorator);
89
90 Ok(new_decorator_id)
91 }
92
93 pub fn add_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
97 if self.nodes.len() == Self::MAX_NODES {
98 return Err(MastForestError::TooManyNodes);
99 }
100
101 let new_node_id = MastNodeId(self.nodes.len() as u32);
102 self.nodes.push(node);
103
104 Ok(new_node_id)
105 }
106
107 pub fn add_block(
109 &mut self,
110 operations: Vec<Operation>,
111 decorators: Option<DecoratorList>,
112 ) -> Result<MastNodeId, MastForestError> {
113 let block = MastNode::new_basic_block(operations, decorators)?;
114 self.add_node(block)
115 }
116
117 pub fn add_join(
119 &mut self,
120 left_child: MastNodeId,
121 right_child: MastNodeId,
122 ) -> Result<MastNodeId, MastForestError> {
123 let join = MastNode::new_join(left_child, right_child, self)?;
124 self.add_node(join)
125 }
126
127 pub fn add_split(
129 &mut self,
130 if_branch: MastNodeId,
131 else_branch: MastNodeId,
132 ) -> Result<MastNodeId, MastForestError> {
133 let split = MastNode::new_split(if_branch, else_branch, self)?;
134 self.add_node(split)
135 }
136
137 pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
139 let loop_node = MastNode::new_loop(body, self)?;
140 self.add_node(loop_node)
141 }
142
143 pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
145 let call = MastNode::new_call(callee, self)?;
146 self.add_node(call)
147 }
148
149 pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
151 let syscall = MastNode::new_syscall(callee, self)?;
152 self.add_node(syscall)
153 }
154
155 pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
157 self.add_node(MastNode::new_dyn())
158 }
159
160 pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
162 self.add_node(MastNode::new_dyncall())
163 }
164
165 pub fn add_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
167 self.add_node(MastNode::new_external(mast_root))
168 }
169
170 pub fn make_root(&mut self, new_root_id: MastNodeId) {
178 assert!((new_root_id.0 as usize) < self.nodes.len());
179
180 if !self.roots.contains(&new_root_id) {
181 self.roots.push(new_root_id);
182 }
183 }
184
185 pub fn remove_nodes(
193 &mut self,
194 nodes_to_remove: &BTreeSet<MastNodeId>,
195 ) -> BTreeMap<MastNodeId, MastNodeId> {
196 if nodes_to_remove.is_empty() {
197 return BTreeMap::new();
198 }
199
200 let old_nodes = mem::take(&mut self.nodes);
201 let old_root_ids = mem::take(&mut self.roots);
202 let (retained_nodes, id_remappings) = remove_nodes(old_nodes, nodes_to_remove);
203
204 self.remap_and_add_nodes(retained_nodes, &id_remappings);
205 self.remap_and_add_roots(old_root_ids, &id_remappings);
206 id_remappings
207 }
208
209 pub fn append_before_enter(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
210 self[node_id].append_before_enter(decorator_ids)
211 }
212
213 pub fn append_after_exit(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
214 self[node_id].append_after_exit(decorator_ids)
215 }
216
217 pub fn merge<'forest>(
267 forests: impl IntoIterator<Item = &'forest MastForest>,
268 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
269 MastForestMerger::merge(forests)
270 }
271
272 #[cfg(test)]
277 pub fn add_block_with_raw_decorators(
278 &mut self,
279 operations: Vec<Operation>,
280 decorators: Vec<(usize, Decorator)>,
281 ) -> Result<MastNodeId, MastForestError> {
282 let block = MastNode::new_basic_block_with_raw_decorators(operations, decorators, self)?;
283 self.add_node(block)
284 }
285}
286
287impl MastForest {
289 fn remap_and_add_nodes(
295 &mut self,
296 nodes_to_add: Vec<MastNode>,
297 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
298 ) {
299 assert!(self.nodes.is_empty());
300
301 for live_node in nodes_to_add {
304 match &live_node {
305 MastNode::Join(join_node) => {
306 let first_child =
307 id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
308 let second_child = id_remappings
309 .get(&join_node.second())
310 .copied()
311 .unwrap_or(join_node.second());
312
313 self.add_join(first_child, second_child).unwrap();
314 },
315 MastNode::Split(split_node) => {
316 let on_true_child = id_remappings
317 .get(&split_node.on_true())
318 .copied()
319 .unwrap_or(split_node.on_true());
320 let on_false_child = id_remappings
321 .get(&split_node.on_false())
322 .copied()
323 .unwrap_or(split_node.on_false());
324
325 self.add_split(on_true_child, on_false_child).unwrap();
326 },
327 MastNode::Loop(loop_node) => {
328 let body_id =
329 id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
330
331 self.add_loop(body_id).unwrap();
332 },
333 MastNode::Call(call_node) => {
334 let callee_id = id_remappings
335 .get(&call_node.callee())
336 .copied()
337 .unwrap_or(call_node.callee());
338
339 if call_node.is_syscall() {
340 self.add_syscall(callee_id).unwrap();
341 } else {
342 self.add_call(callee_id).unwrap();
343 }
344 },
345 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
346 self.add_node(live_node).unwrap();
347 },
348 }
349 }
350 }
351
352 fn remap_and_add_roots(
357 &mut self,
358 old_root_ids: Vec<MastNodeId>,
359 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
360 ) {
361 assert!(self.roots.is_empty());
362
363 for old_root_id in old_root_ids {
364 let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
365 self.make_root(new_root_id);
366 }
367 }
368}
369
370fn remove_nodes(
373 mast_nodes: Vec<MastNode>,
374 nodes_to_remove: &BTreeSet<MastNodeId>,
375) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
376 assert!(mast_nodes.len() < u32::MAX as usize);
378
379 let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
380 let mut id_remappings = BTreeMap::new();
381
382 for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
383 let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
384
385 if !nodes_to_remove.contains(&old_node_id) {
386 let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
387 id_remappings.insert(old_node_id, new_node_id);
388
389 retained_nodes.push(old_node);
390 }
391 }
392
393 (retained_nodes, id_remappings)
394}
395
396impl MastForest {
400 #[inline(always)]
405 pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
406 let idx = decorator_id.0 as usize;
407
408 self.decorators.get(idx)
409 }
410
411 #[inline(always)]
416 pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
417 let idx = node_id.0 as usize;
418
419 self.nodes.get(idx)
420 }
421
422 #[inline(always)]
424 pub fn find_procedure_root(&self, digest: RpoDigest) -> Option<MastNodeId> {
425 self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
426 }
427
428 pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
430 self.roots.contains(&node_id)
431 }
432
433 pub fn procedure_digests(&self) -> impl Iterator<Item = RpoDigest> + '_ {
435 self.roots.iter().map(|&root_id| self[root_id].digest())
436 }
437
438 pub fn local_procedure_digests(&self) -> impl Iterator<Item = RpoDigest> + '_ {
442 self.roots.iter().filter_map(|&root_id| {
443 let node = &self[root_id];
444 if node.is_external() { None } else { Some(node.digest()) }
445 })
446 }
447
448 pub fn procedure_roots(&self) -> &[MastNodeId] {
450 &self.roots
451 }
452
453 pub fn num_procedures(&self) -> u32 {
455 self.roots
456 .len()
457 .try_into()
458 .expect("MAST forest contains more than 2^32 procedures.")
459 }
460
461 pub fn num_nodes(&self) -> u32 {
463 self.nodes.len() as u32
464 }
465
466 pub fn nodes(&self) -> &[MastNode] {
468 &self.nodes
469 }
470
471 pub fn decorators(&self) -> &[Decorator] {
472 &self.decorators
473 }
474
475 pub fn advice_map(&self) -> &AdviceMap {
476 &self.advice_map
477 }
478
479 pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
480 &mut self.advice_map
481 }
482
483 pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
486 let code: Felt = error_code_from_msg(&msg);
487 let code_key = u64::from(code);
488 self.error_codes.insert(code_key, msg);
490 code
491 }
492
493 pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
496 let key = u64::from(code);
497 self.error_codes.get(&key).cloned()
498 }
499}
500
501impl Index<MastNodeId> for MastForest {
502 type Output = MastNode;
503
504 #[inline(always)]
505 fn index(&self, node_id: MastNodeId) -> &Self::Output {
506 let idx = node_id.0 as usize;
507
508 &self.nodes[idx]
509 }
510}
511
512impl IndexMut<MastNodeId> for MastForest {
513 #[inline(always)]
514 fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
515 let idx = node_id.0 as usize;
516
517 &mut self.nodes[idx]
518 }
519}
520
521impl Index<DecoratorId> for MastForest {
522 type Output = Decorator;
523
524 #[inline(always)]
525 fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
526 let idx = decorator_id.0 as usize;
527
528 &self.decorators[idx]
529 }
530}
531
532impl IndexMut<DecoratorId> for MastForest {
533 #[inline(always)]
534 fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
535 let idx = decorator_id.0 as usize;
536 &mut self.decorators[idx]
537 }
538}
539
540#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
550pub struct MastNodeId(u32);
551
552pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
554
555impl MastNodeId {
556 pub fn from_u32_safe(
561 value: u32,
562 mast_forest: &MastForest,
563 ) -> Result<Self, DeserializationError> {
564 Self::from_u32_with_node_count(value, mast_forest.nodes.len())
565 }
566
567 pub fn from_usize_safe(
571 node_id: usize,
572 mast_forest: &MastForest,
573 ) -> Result<Self, DeserializationError> {
574 let node_id: u32 = node_id.try_into().map_err(|_| {
575 DeserializationError::InvalidValue(format!(
576 "node id '{node_id}' does not fit into a u32"
577 ))
578 })?;
579 MastNodeId::from_u32_safe(node_id, mast_forest)
580 }
581
582 pub(crate) fn new_unchecked(value: u32) -> Self {
584 Self(value)
585 }
586
587 pub(super) fn from_u32_with_node_count(
601 id: u32,
602 node_count: usize,
603 ) -> Result<Self, DeserializationError> {
604 if (id as usize) < node_count {
605 Ok(Self(id))
606 } else {
607 Err(DeserializationError::InvalidValue(format!(
608 "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
609 )))
610 }
611 }
612
613 pub fn as_usize(&self) -> usize {
614 self.0 as usize
615 }
616
617 pub fn as_u32(&self) -> u32 {
618 self.0
619 }
620
621 pub fn remap(&self, remapping: &Remapping) -> Self {
623 *remapping.get(self).unwrap_or(self)
624 }
625}
626
627impl From<MastNodeId> for usize {
628 fn from(value: MastNodeId) -> Self {
629 value.0 as usize
630 }
631}
632
633impl From<MastNodeId> for u32 {
634 fn from(value: MastNodeId) -> Self {
635 value.0
636 }
637}
638
639impl From<&MastNodeId> for u32 {
640 fn from(value: &MastNodeId) -> Self {
641 value.0
642 }
643}
644
645impl fmt::Display for MastNodeId {
646 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
647 write!(f, "MastNodeId({})", self.0)
648 }
649}
650
651pub struct SubtreeIterator<'a> {
656 forest: &'a MastForest,
657 discovered: Vec<MastNodeId>,
658 unvisited: Vec<MastNodeId>,
659}
660impl<'a> SubtreeIterator<'a> {
661 pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
662 let discovered = vec![];
663 let unvisited = vec![*root];
664 SubtreeIterator { forest, discovered, unvisited }
665 }
666}
667impl Iterator for SubtreeIterator<'_> {
668 type Item = MastNodeId;
669 fn next(&mut self) -> Option<MastNodeId> {
670 while let Some(id) = self.unvisited.pop() {
671 let node = &self.forest[id];
672 if !node.has_children() {
673 return Some(id);
674 } else {
675 self.discovered.push(id);
676 node.append_children_to(&mut self.unvisited);
677 }
678 }
679 self.discovered.pop()
680 }
681}
682
683#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
689pub struct DecoratorId(u32);
690
691impl DecoratorId {
692 pub fn from_u32_safe(
697 value: u32,
698 mast_forest: &MastForest,
699 ) -> Result<Self, DeserializationError> {
700 if (value as usize) < mast_forest.decorators.len() {
701 Ok(Self(value))
702 } else {
703 Err(DeserializationError::InvalidValue(format!(
704 "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
705 value,
706 mast_forest.nodes.len(),
707 )))
708 }
709 }
710
711 pub(crate) fn new_unchecked(value: u32) -> Self {
713 Self(value)
714 }
715
716 pub fn as_usize(&self) -> usize {
717 self.0 as usize
718 }
719
720 pub fn as_u32(&self) -> u32 {
721 self.0
722 }
723}
724
725impl From<DecoratorId> for usize {
726 fn from(value: DecoratorId) -> Self {
727 value.0 as usize
728 }
729}
730
731impl From<DecoratorId> for u32 {
732 fn from(value: DecoratorId) -> Self {
733 value.0
734 }
735}
736
737impl From<&DecoratorId> for u32 {
738 fn from(value: &DecoratorId) -> Self {
739 value.0
740 }
741}
742
743impl fmt::Display for DecoratorId {
744 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
745 write!(f, "DecoratorId({})", self.0)
746 }
747}
748
749impl Serializable for DecoratorId {
750 fn write_into<W: ByteWriter>(&self, target: &mut W) {
751 self.0.write_into(target)
752 }
753}
754
755pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
758 let digest: Blake3Digest<32> = Blake3_256::hash(msg.as_ref().as_bytes());
759 let mut digest_bytes: [u8; 8] = [0; 8];
760 digest_bytes.copy_from_slice(&digest.as_bytes()[0..8]);
761 let code = u64::from_le_bytes(digest_bytes);
762 Felt::new(code)
763}
764
765#[derive(Debug, thiserror::Error, PartialEq)]
770pub enum MastForestError {
771 #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
772 TooManyDecorators,
773 #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
774 TooManyNodes,
775 #[error("node id {0} is greater than or equal to forest length {1}")]
776 NodeIdOverflow(MastNodeId, usize),
777 #[error("decorator id {0} is greater than or equal to decorator count {1}")]
778 DecoratorIdOverflow(DecoratorId, usize),
779 #[error("basic block cannot be created from an empty list of operations")]
780 EmptyBasicBlock,
781 #[error(
782 "decorator root of child with node id {0} is missing but is required for fingerprint computation"
783 )]
784 ChildFingerprintMissing(MastNodeId),
785 #[error("advice map key {0} already exists when merging forests")]
786 AdviceMapKeyCollisionOnMerge(RpoDigest),
787}