1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 vec::Vec,
4};
5use core::{
6 fmt, mem,
7 ops::{Index, IndexMut},
8};
9
10use miden_crypto::hash::rpo::RpoDigest;
11
12mod node;
13pub use node::{
14 BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, OP_BATCH_SIZE,
15 OP_GROUP_SIZE, OpBatch, OperationOrDecorator, SplitNode,
16};
17use winter_utils::{ByteWriter, DeserializationError, Serializable};
18
19use crate::{AdviceMap, Decorator, DecoratorList, Operation};
20
21mod serialization;
22
23mod merger;
24pub(crate) use merger::MastForestMerger;
25pub use merger::MastForestRootMap;
26
27mod multi_forest_node_iterator;
28pub(crate) use multi_forest_node_iterator::*;
29
30mod node_fingerprint;
31pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
32
33#[cfg(test)]
34mod tests;
35
/// A collection of MAST nodes together with the procedure roots, decorators, and advice map
/// associated with them.
///
/// Nodes and decorators are stored in flat vectors and addressed by index through
/// [`MastNodeId`] and [`DecoratorId`] respectively.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MastForest {
    /// All nodes in the forest; a node's [`MastNodeId`] is its index in this vector.
    nodes: Vec<MastNode>,

    /// Ids of the nodes that are procedure roots, in insertion order and without duplicates
    /// (`make_root` refuses to push an id that is already present).
    roots: Vec<MastNodeId>,

    /// All decorators in the forest; a [`DecoratorId`] is an index into this vector.
    decorators: Vec<Decorator>,

    /// Advice map associated with this forest.
    advice_map: AdviceMap,
}
57
58impl MastForest {
61 pub fn new() -> Self {
63 Self::default()
64 }
65}
66
67impl MastForest {
70 const MAX_NODES: usize = (1 << 30) - 1;
72
73 pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
75 if self.decorators.len() >= u32::MAX as usize {
76 return Err(MastForestError::TooManyDecorators);
77 }
78
79 let new_decorator_id = DecoratorId(self.decorators.len() as u32);
80 self.decorators.push(decorator);
81
82 Ok(new_decorator_id)
83 }
84
85 pub fn add_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
89 if self.nodes.len() == Self::MAX_NODES {
90 return Err(MastForestError::TooManyNodes);
91 }
92
93 let new_node_id = MastNodeId(self.nodes.len() as u32);
94 self.nodes.push(node);
95
96 Ok(new_node_id)
97 }
98
99 pub fn add_block(
101 &mut self,
102 operations: Vec<Operation>,
103 decorators: Option<DecoratorList>,
104 ) -> Result<MastNodeId, MastForestError> {
105 let block = MastNode::new_basic_block(operations, decorators)?;
106 self.add_node(block)
107 }
108
109 pub fn add_join(
111 &mut self,
112 left_child: MastNodeId,
113 right_child: MastNodeId,
114 ) -> Result<MastNodeId, MastForestError> {
115 let join = MastNode::new_join(left_child, right_child, self)?;
116 self.add_node(join)
117 }
118
119 pub fn add_split(
121 &mut self,
122 if_branch: MastNodeId,
123 else_branch: MastNodeId,
124 ) -> Result<MastNodeId, MastForestError> {
125 let split = MastNode::new_split(if_branch, else_branch, self)?;
126 self.add_node(split)
127 }
128
129 pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
131 let loop_node = MastNode::new_loop(body, self)?;
132 self.add_node(loop_node)
133 }
134
135 pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
137 let call = MastNode::new_call(callee, self)?;
138 self.add_node(call)
139 }
140
141 pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
143 let syscall = MastNode::new_syscall(callee, self)?;
144 self.add_node(syscall)
145 }
146
147 pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
149 self.add_node(MastNode::new_dyn())
150 }
151
152 pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
154 self.add_node(MastNode::new_dyncall())
155 }
156
157 pub fn add_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
159 self.add_node(MastNode::new_external(mast_root))
160 }
161
162 pub fn make_root(&mut self, new_root_id: MastNodeId) {
170 assert!((new_root_id.0 as usize) < self.nodes.len());
171
172 if !self.roots.contains(&new_root_id) {
173 self.roots.push(new_root_id);
174 }
175 }
176
177 pub fn remove_nodes(
185 &mut self,
186 nodes_to_remove: &BTreeSet<MastNodeId>,
187 ) -> BTreeMap<MastNodeId, MastNodeId> {
188 if nodes_to_remove.is_empty() {
189 return BTreeMap::new();
190 }
191
192 let old_nodes = mem::take(&mut self.nodes);
193 let old_root_ids = mem::take(&mut self.roots);
194 let (retained_nodes, id_remappings) = remove_nodes(old_nodes, nodes_to_remove);
195
196 self.remap_and_add_nodes(retained_nodes, &id_remappings);
197 self.remap_and_add_roots(old_root_ids, &id_remappings);
198 id_remappings
199 }
200
201 pub fn set_before_enter(&mut self, node_id: MastNodeId, decorator_ids: Vec<DecoratorId>) {
202 self[node_id].set_before_enter(decorator_ids)
203 }
204
205 pub fn set_after_exit(&mut self, node_id: MastNodeId, decorator_ids: Vec<DecoratorId>) {
206 self[node_id].set_after_exit(decorator_ids)
207 }
208
209 pub fn merge<'forest>(
259 forests: impl IntoIterator<Item = &'forest MastForest>,
260 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
261 MastForestMerger::merge(forests)
262 }
263
264 #[cfg(test)]
269 pub fn add_block_with_raw_decorators(
270 &mut self,
271 operations: Vec<Operation>,
272 decorators: Vec<(usize, Decorator)>,
273 ) -> Result<MastNodeId, MastForestError> {
274 let block = MastNode::new_basic_block_with_raw_decorators(operations, decorators, self)?;
275 self.add_node(block)
276 }
277}
278
279impl MastForest {
281 fn remap_and_add_nodes(
287 &mut self,
288 nodes_to_add: Vec<MastNode>,
289 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
290 ) {
291 assert!(self.nodes.is_empty());
292
293 for live_node in nodes_to_add {
296 match &live_node {
297 MastNode::Join(join_node) => {
298 let first_child =
299 id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
300 let second_child = id_remappings
301 .get(&join_node.second())
302 .copied()
303 .unwrap_or(join_node.second());
304
305 self.add_join(first_child, second_child).unwrap();
306 },
307 MastNode::Split(split_node) => {
308 let on_true_child = id_remappings
309 .get(&split_node.on_true())
310 .copied()
311 .unwrap_or(split_node.on_true());
312 let on_false_child = id_remappings
313 .get(&split_node.on_false())
314 .copied()
315 .unwrap_or(split_node.on_false());
316
317 self.add_split(on_true_child, on_false_child).unwrap();
318 },
319 MastNode::Loop(loop_node) => {
320 let body_id =
321 id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
322
323 self.add_loop(body_id).unwrap();
324 },
325 MastNode::Call(call_node) => {
326 let callee_id = id_remappings
327 .get(&call_node.callee())
328 .copied()
329 .unwrap_or(call_node.callee());
330
331 if call_node.is_syscall() {
332 self.add_syscall(callee_id).unwrap();
333 } else {
334 self.add_call(callee_id).unwrap();
335 }
336 },
337 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
338 self.add_node(live_node).unwrap();
339 },
340 }
341 }
342 }
343
344 fn remap_and_add_roots(
349 &mut self,
350 old_root_ids: Vec<MastNodeId>,
351 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
352 ) {
353 assert!(self.roots.is_empty());
354
355 for old_root_id in old_root_ids {
356 let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
357 self.make_root(new_root_id);
358 }
359 }
360}
361
362fn remove_nodes(
365 mast_nodes: Vec<MastNode>,
366 nodes_to_remove: &BTreeSet<MastNodeId>,
367) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
368 assert!(mast_nodes.len() < u32::MAX as usize);
370
371 let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
372 let mut id_remappings = BTreeMap::new();
373
374 for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
375 let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
376
377 if !nodes_to_remove.contains(&old_node_id) {
378 let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
379 id_remappings.insert(old_node_id, new_node_id);
380
381 retained_nodes.push(old_node);
382 }
383 }
384
385 (retained_nodes, id_remappings)
386}
387
388impl MastForest {
392 #[inline(always)]
397 pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
398 let idx = decorator_id.0 as usize;
399
400 self.decorators.get(idx)
401 }
402
403 #[inline(always)]
408 pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
409 let idx = node_id.0 as usize;
410
411 self.nodes.get(idx)
412 }
413
414 #[inline(always)]
416 pub fn find_procedure_root(&self, digest: RpoDigest) -> Option<MastNodeId> {
417 self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
418 }
419
420 pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
422 self.roots.contains(&node_id)
423 }
424
425 pub fn procedure_digests(&self) -> impl Iterator<Item = RpoDigest> + '_ {
427 self.roots.iter().map(|&root_id| self[root_id].digest())
428 }
429
430 pub fn local_procedure_digests(&self) -> impl Iterator<Item = RpoDigest> + '_ {
434 self.roots.iter().filter_map(|&root_id| {
435 let node = &self[root_id];
436 if node.is_external() { None } else { Some(node.digest()) }
437 })
438 }
439
440 pub fn procedure_roots(&self) -> &[MastNodeId] {
442 &self.roots
443 }
444
445 pub fn num_procedures(&self) -> u32 {
447 self.roots
448 .len()
449 .try_into()
450 .expect("MAST forest contains more than 2^32 procedures.")
451 }
452
453 pub fn num_nodes(&self) -> u32 {
455 self.nodes.len() as u32
456 }
457
458 pub fn nodes(&self) -> &[MastNode] {
460 &self.nodes
461 }
462
463 pub fn decorators(&self) -> &[Decorator] {
464 &self.decorators
465 }
466
467 pub fn advice_map(&self) -> &AdviceMap {
468 &self.advice_map
469 }
470
471 pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
472 &mut self.advice_map
473 }
474}
475
476impl Index<MastNodeId> for MastForest {
477 type Output = MastNode;
478
479 #[inline(always)]
480 fn index(&self, node_id: MastNodeId) -> &Self::Output {
481 let idx = node_id.0 as usize;
482
483 &self.nodes[idx]
484 }
485}
486
487impl IndexMut<MastNodeId> for MastForest {
488 #[inline(always)]
489 fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
490 let idx = node_id.0 as usize;
491
492 &mut self.nodes[idx]
493 }
494}
495
496impl Index<DecoratorId> for MastForest {
497 type Output = Decorator;
498
499 #[inline(always)]
500 fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
501 let idx = decorator_id.0 as usize;
502
503 &self.decorators[idx]
504 }
505}
506
507impl IndexMut<DecoratorId> for MastForest {
508 #[inline(always)]
509 fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
510 let idx = decorator_id.0 as usize;
511 &mut self.decorators[idx]
512 }
513}
514
/// Identifier of a node in a [`MastForest`]: the node's index in the forest's node list.
///
/// The inner value is private to this module. Outside callers obtain ids from the forest, or
/// through the checked constructors `from_u32_safe` / `from_usize_safe`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MastNodeId(u32);
526
/// Map from old [`MastNodeId`]s to new ones, e.g. the map returned by
/// `MastForest::remove_nodes`.
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
529
530impl MastNodeId {
531 pub fn from_u32_safe(
536 value: u32,
537 mast_forest: &MastForest,
538 ) -> Result<Self, DeserializationError> {
539 Self::from_u32_with_node_count(value, mast_forest.nodes.len())
540 }
541
542 pub fn from_usize_safe(
546 node_id: usize,
547 mast_forest: &MastForest,
548 ) -> Result<Self, DeserializationError> {
549 let node_id: u32 = node_id.try_into().map_err(|_| {
550 DeserializationError::InvalidValue(format!(
551 "node id '{node_id}' does not fit into a u32"
552 ))
553 })?;
554 MastNodeId::from_u32_safe(node_id, mast_forest)
555 }
556
557 pub(crate) fn new_unchecked(value: u32) -> Self {
559 Self(value)
560 }
561
562 pub(super) fn from_u32_with_node_count(
576 id: u32,
577 node_count: usize,
578 ) -> Result<Self, DeserializationError> {
579 if (id as usize) < node_count {
580 Ok(Self(id))
581 } else {
582 Err(DeserializationError::InvalidValue(format!(
583 "Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest",
584 id, node_count,
585 )))
586 }
587 }
588
589 pub fn as_usize(&self) -> usize {
590 self.0 as usize
591 }
592
593 pub fn as_u32(&self) -> u32 {
594 self.0
595 }
596
597 pub fn remap(&self, remapping: &Remapping) -> Self {
599 *remapping.get(self).unwrap_or(self)
600 }
601}
602
603impl From<MastNodeId> for usize {
604 fn from(value: MastNodeId) -> Self {
605 value.0 as usize
606 }
607}
608
609impl From<MastNodeId> for u32 {
610 fn from(value: MastNodeId) -> Self {
611 value.0
612 }
613}
614
615impl From<&MastNodeId> for u32 {
616 fn from(value: &MastNodeId) -> Self {
617 value.0
618 }
619}
620
621impl fmt::Display for MastNodeId {
622 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
623 write!(f, "MastNodeId({})", self.0)
624 }
625}
626
/// Iterator over the ids of the nodes reachable from a root node of a [`MastForest`].
///
/// Yields every reachable node without children first. It then yields the nodes with children
/// in reverse order of expansion, so each such node comes after the nodes it expanded into.
/// There is no visited set, so a node reachable along several paths is yielded once per path.
pub struct SubtreeIterator<'a> {
    /// Forest that the traversed node ids index into.
    forest: &'a MastForest,
    /// Nodes with children that have already been expanded; drained once `unvisited` is empty.
    discovered: Vec<MastNodeId>,
    /// Stack of node ids still to be examined.
    unvisited: Vec<MastNodeId>,
}
636impl<'a> SubtreeIterator<'a> {
637 pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
638 let discovered = vec![];
639 let unvisited = vec![*root];
640 SubtreeIterator { forest, discovered, unvisited }
641 }
642}
643impl Iterator for SubtreeIterator<'_> {
644 type Item = MastNodeId;
645 fn next(&mut self) -> Option<MastNodeId> {
646 while let Some(id) = self.unvisited.pop() {
647 let node = &self.forest[id];
648 if !node.has_children() {
649 return Some(id);
650 } else {
651 self.discovered.push(id);
652 node.append_children_to(&mut self.unvisited);
653 }
654 }
655 self.discovered.pop()
656 }
657}
658
/// Identifier of a decorator in a [`MastForest`]: the decorator's index in the forest's
/// decorator list.
///
/// The inner value is private to this module. Outside callers obtain ids from
/// `MastForest::add_decorator` or the checked constructor `from_u32_safe`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DecoratorId(u32);
666
667impl DecoratorId {
668 pub fn from_u32_safe(
673 value: u32,
674 mast_forest: &MastForest,
675 ) -> Result<Self, DeserializationError> {
676 if (value as usize) < mast_forest.decorators.len() {
677 Ok(Self(value))
678 } else {
679 Err(DeserializationError::InvalidValue(format!(
680 "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
681 value,
682 mast_forest.nodes.len(),
683 )))
684 }
685 }
686
687 pub(crate) fn new_unchecked(value: u32) -> Self {
689 Self(value)
690 }
691
692 pub fn as_usize(&self) -> usize {
693 self.0 as usize
694 }
695
696 pub fn as_u32(&self) -> u32 {
697 self.0
698 }
699}
700
701impl From<DecoratorId> for usize {
702 fn from(value: DecoratorId) -> Self {
703 value.0 as usize
704 }
705}
706
707impl From<DecoratorId> for u32 {
708 fn from(value: DecoratorId) -> Self {
709 value.0
710 }
711}
712
713impl From<&DecoratorId> for u32 {
714 fn from(value: &DecoratorId) -> Self {
715 value.0
716 }
717}
718
719impl fmt::Display for DecoratorId {
720 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
721 write!(f, "DecoratorId({})", self.0)
722 }
723}
724
725impl Serializable for DecoratorId {
726 fn write_into<W: ByteWriter>(&self, target: &mut W) {
727 self.0.write_into(target)
728 }
729}
730
/// Errors raised while building, validating, or merging a [`MastForest`].
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
    /// The decorator list is already at its `u32::MAX` capacity.
    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
    TooManyDecorators,
    /// The node list is already at its `MastForest::MAX_NODES` capacity.
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node id (first field) is not below the forest's node count (second field).
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A decorator id (first field) is not below the forest's decorator count (second field).
    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
    DecoratorIdOverflow(DecoratorId, usize),
    /// A basic block was requested with no operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// The fingerprint of the child with the given id was not available when computing a
    /// parent's fingerprint.
    #[error(
        "decorator root of child with node id {0} is missing but is required for fingerprint computation"
    )]
    ChildFingerprintMissing(MastNodeId),
    /// The same advice-map key was found in more than one of the forests being merged.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(RpoDigest),
}