1use alloc::{
2 collections::{BTreeMap, BTreeSet},
3 vec::Vec,
4};
5use core::{
6 fmt, mem,
7 ops::{Index, IndexMut},
8};
9
10use miden_crypto::hash::rpo::RpoDigest;
11
12mod node;
13pub use node::{
14 BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, OpBatch,
15 OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE,
16};
17use winter_utils::{ByteWriter, DeserializationError, Serializable};
18
19use crate::{AdviceMap, Decorator, DecoratorList, Operation};
20
21mod serialization;
22
23mod merger;
24pub(crate) use merger::MastForestMerger;
25pub use merger::MastForestRootMap;
26
27mod multi_forest_node_iterator;
28pub(crate) use multi_forest_node_iterator::*;
29
30mod node_fingerprint;
31pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
32
33#[cfg(test)]
34mod tests;
35
/// A MAST forest: a list of [`MastNode`]s, the subset of them that are procedure roots, the
/// [`Decorator`]s referenced by the nodes, and an [`AdviceMap`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MastForest {
    /// All nodes of the forest; a [`MastNodeId`] is an index into this list.
    nodes: Vec<MastNode>,

    /// Ids of the nodes that are procedure roots (`make_root` never inserts duplicates).
    roots: Vec<MastNodeId>,

    /// All decorators of the forest; a [`DecoratorId`] is an index into this list.
    decorators: Vec<Decorator>,

    /// Advice map associated with this forest (exposed via `advice_map`/`advice_map_mut`).
    advice_map: AdviceMap,
}
57
58impl MastForest {
61 pub fn new() -> Self {
63 Self::default()
64 }
65}
66
67impl MastForest {
70 const MAX_NODES: usize = (1 << 30) - 1;
72
73 pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
75 if self.decorators.len() >= u32::MAX as usize {
76 return Err(MastForestError::TooManyDecorators);
77 }
78
79 let new_decorator_id = DecoratorId(self.decorators.len() as u32);
80 self.decorators.push(decorator);
81
82 Ok(new_decorator_id)
83 }
84
85 pub fn add_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
89 if self.nodes.len() == Self::MAX_NODES {
90 return Err(MastForestError::TooManyNodes);
91 }
92
93 let new_node_id = MastNodeId(self.nodes.len() as u32);
94 self.nodes.push(node);
95
96 Ok(new_node_id)
97 }
98
99 pub fn add_block(
101 &mut self,
102 operations: Vec<Operation>,
103 decorators: Option<DecoratorList>,
104 ) -> Result<MastNodeId, MastForestError> {
105 let block = MastNode::new_basic_block(operations, decorators)?;
106 self.add_node(block)
107 }
108
109 pub fn add_join(
111 &mut self,
112 left_child: MastNodeId,
113 right_child: MastNodeId,
114 ) -> Result<MastNodeId, MastForestError> {
115 let join = MastNode::new_join(left_child, right_child, self)?;
116 self.add_node(join)
117 }
118
119 pub fn add_split(
121 &mut self,
122 if_branch: MastNodeId,
123 else_branch: MastNodeId,
124 ) -> Result<MastNodeId, MastForestError> {
125 let split = MastNode::new_split(if_branch, else_branch, self)?;
126 self.add_node(split)
127 }
128
129 pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
131 let loop_node = MastNode::new_loop(body, self)?;
132 self.add_node(loop_node)
133 }
134
135 pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
137 let call = MastNode::new_call(callee, self)?;
138 self.add_node(call)
139 }
140
141 pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
143 let syscall = MastNode::new_syscall(callee, self)?;
144 self.add_node(syscall)
145 }
146
147 pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
149 self.add_node(MastNode::new_dyn())
150 }
151
152 pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
154 self.add_node(MastNode::new_dyncall())
155 }
156
157 pub fn add_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
159 self.add_node(MastNode::new_external(mast_root))
160 }
161
162 pub fn make_root(&mut self, new_root_id: MastNodeId) {
170 assert!((new_root_id.0 as usize) < self.nodes.len());
171
172 if !self.roots.contains(&new_root_id) {
173 self.roots.push(new_root_id);
174 }
175 }
176
177 pub fn remove_nodes(
186 &mut self,
187 nodes_to_remove: &BTreeSet<MastNodeId>,
188 ) -> Option<BTreeMap<MastNodeId, MastNodeId>> {
189 if nodes_to_remove.is_empty() {
190 return None;
191 }
192
193 let old_nodes = mem::take(&mut self.nodes);
194 let old_root_ids = mem::take(&mut self.roots);
195 let (retained_nodes, id_remappings) = remove_nodes(old_nodes, nodes_to_remove);
196
197 self.remap_and_add_nodes(retained_nodes, &id_remappings);
198 self.remap_and_add_roots(old_root_ids, &id_remappings);
199 Some(id_remappings)
200 }
201
202 pub fn set_before_enter(&mut self, node_id: MastNodeId, decorator_ids: Vec<DecoratorId>) {
203 self[node_id].set_before_enter(decorator_ids)
204 }
205
206 pub fn set_after_exit(&mut self, node_id: MastNodeId, decorator_ids: Vec<DecoratorId>) {
207 self[node_id].set_after_exit(decorator_ids)
208 }
209
210 pub fn merge<'forest>(
260 forests: impl IntoIterator<Item = &'forest MastForest>,
261 ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
262 MastForestMerger::merge(forests)
263 }
264
265 #[cfg(test)]
270 pub fn add_block_with_raw_decorators(
271 &mut self,
272 operations: Vec<Operation>,
273 decorators: Vec<(usize, Decorator)>,
274 ) -> Result<MastNodeId, MastForestError> {
275 let block = MastNode::new_basic_block_with_raw_decorators(operations, decorators, self)?;
276 self.add_node(block)
277 }
278}
279
280impl MastForest {
282 fn remap_and_add_nodes(
288 &mut self,
289 nodes_to_add: Vec<MastNode>,
290 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
291 ) {
292 assert!(self.nodes.is_empty());
293
294 for live_node in nodes_to_add {
297 match &live_node {
298 MastNode::Join(join_node) => {
299 let first_child =
300 id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
301 let second_child = id_remappings
302 .get(&join_node.second())
303 .copied()
304 .unwrap_or(join_node.second());
305
306 self.add_join(first_child, second_child).unwrap();
307 },
308 MastNode::Split(split_node) => {
309 let on_true_child = id_remappings
310 .get(&split_node.on_true())
311 .copied()
312 .unwrap_or(split_node.on_true());
313 let on_false_child = id_remappings
314 .get(&split_node.on_false())
315 .copied()
316 .unwrap_or(split_node.on_false());
317
318 self.add_split(on_true_child, on_false_child).unwrap();
319 },
320 MastNode::Loop(loop_node) => {
321 let body_id =
322 id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
323
324 self.add_loop(body_id).unwrap();
325 },
326 MastNode::Call(call_node) => {
327 let callee_id = id_remappings
328 .get(&call_node.callee())
329 .copied()
330 .unwrap_or(call_node.callee());
331
332 if call_node.is_syscall() {
333 self.add_syscall(callee_id).unwrap();
334 } else {
335 self.add_call(callee_id).unwrap();
336 }
337 },
338 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
339 self.add_node(live_node).unwrap();
340 },
341 }
342 }
343 }
344
345 fn remap_and_add_roots(
350 &mut self,
351 old_root_ids: Vec<MastNodeId>,
352 id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
353 ) {
354 assert!(self.roots.is_empty());
355
356 for old_root_id in old_root_ids {
357 let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
358 self.make_root(new_root_id);
359 }
360 }
361}
362
363fn remove_nodes(
366 mast_nodes: Vec<MastNode>,
367 nodes_to_remove: &BTreeSet<MastNodeId>,
368) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
369 assert!(mast_nodes.len() < u32::MAX as usize);
371
372 let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
373 let mut id_remappings = BTreeMap::new();
374
375 for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
376 let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
377
378 if !nodes_to_remove.contains(&old_node_id) {
379 let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
380 id_remappings.insert(old_node_id, new_node_id);
381
382 retained_nodes.push(old_node);
383 }
384 }
385
386 (retained_nodes, id_remappings)
387}
388
389impl MastForest {
393 #[inline(always)]
398 pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
399 let idx = decorator_id.0 as usize;
400
401 self.decorators.get(idx)
402 }
403
404 #[inline(always)]
409 pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
410 let idx = node_id.0 as usize;
411
412 self.nodes.get(idx)
413 }
414
415 #[inline(always)]
417 pub fn find_procedure_root(&self, digest: RpoDigest) -> Option<MastNodeId> {
418 self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
419 }
420
421 pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
423 self.roots.contains(&node_id)
424 }
425
426 pub fn procedure_digests(&self) -> impl Iterator<Item = RpoDigest> + '_ {
428 self.roots.iter().map(|&root_id| self[root_id].digest())
429 }
430
431 pub fn local_procedure_digests(&self) -> impl Iterator<Item = RpoDigest> + '_ {
435 self.roots.iter().filter_map(|&root_id| {
436 let node = &self[root_id];
437 if node.is_external() {
438 None
439 } else {
440 Some(node.digest())
441 }
442 })
443 }
444
445 pub fn procedure_roots(&self) -> &[MastNodeId] {
447 &self.roots
448 }
449
450 pub fn num_procedures(&self) -> u32 {
452 self.roots
453 .len()
454 .try_into()
455 .expect("MAST forest contains more than 2^32 procedures.")
456 }
457
458 pub fn num_nodes(&self) -> u32 {
460 self.nodes.len() as u32
461 }
462
463 pub fn nodes(&self) -> &[MastNode] {
465 &self.nodes
466 }
467
468 pub fn decorators(&self) -> &[Decorator] {
469 &self.decorators
470 }
471
472 pub fn advice_map(&self) -> &AdviceMap {
473 &self.advice_map
474 }
475
476 pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
477 &mut self.advice_map
478 }
479}
480
481impl Index<MastNodeId> for MastForest {
482 type Output = MastNode;
483
484 #[inline(always)]
485 fn index(&self, node_id: MastNodeId) -> &Self::Output {
486 let idx = node_id.0 as usize;
487
488 &self.nodes[idx]
489 }
490}
491
492impl IndexMut<MastNodeId> for MastForest {
493 #[inline(always)]
494 fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
495 let idx = node_id.0 as usize;
496
497 &mut self.nodes[idx]
498 }
499}
500
501impl Index<DecoratorId> for MastForest {
502 type Output = Decorator;
503
504 #[inline(always)]
505 fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
506 let idx = decorator_id.0 as usize;
507
508 &self.decorators[idx]
509 }
510}
511
512impl IndexMut<DecoratorId> for MastForest {
513 #[inline(always)]
514 fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
515 let idx = decorator_id.0 as usize;
516 &mut self.decorators[idx]
517 }
518}
519
/// Identifier of a [`MastNode`] within a [`MastForest`].
///
/// Wraps the node's `u32` index into the forest's node list. The type does not record which
/// forest an id belongs to; `from_u32_safe`/`from_usize_safe` check a raw value against a
/// given forest's node count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MastNodeId(u32);
531
532impl MastNodeId {
533 pub fn from_u32_safe(
538 value: u32,
539 mast_forest: &MastForest,
540 ) -> Result<Self, DeserializationError> {
541 Self::from_u32_with_node_count(value, mast_forest.nodes.len())
542 }
543
544 pub fn from_usize_safe(
548 node_id: usize,
549 mast_forest: &MastForest,
550 ) -> Result<Self, DeserializationError> {
551 let node_id: u32 = node_id.try_into().map_err(|_| {
552 DeserializationError::InvalidValue(format!(
553 "node id '{node_id}' does not fit into a u32"
554 ))
555 })?;
556 MastNodeId::from_u32_safe(node_id, mast_forest)
557 }
558
559 pub(crate) fn new_unchecked(value: u32) -> Self {
561 Self(value)
562 }
563
564 pub(super) fn from_u32_with_node_count(
578 id: u32,
579 node_count: usize,
580 ) -> Result<Self, DeserializationError> {
581 if (id as usize) < node_count {
582 Ok(Self(id))
583 } else {
584 Err(DeserializationError::InvalidValue(format!(
585 "Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest",
586 id, node_count,
587 )))
588 }
589 }
590
591 pub fn as_usize(&self) -> usize {
592 self.0 as usize
593 }
594
595 pub fn as_u32(&self) -> u32 {
596 self.0
597 }
598}
599
600impl From<MastNodeId> for usize {
601 fn from(value: MastNodeId) -> Self {
602 value.0 as usize
603 }
604}
605
606impl From<MastNodeId> for u32 {
607 fn from(value: MastNodeId) -> Self {
608 value.0
609 }
610}
611
612impl From<&MastNodeId> for u32 {
613 fn from(value: &MastNodeId) -> Self {
614 value.0
615 }
616}
617
618impl fmt::Display for MastNodeId {
619 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
620 write!(f, "MastNodeId({})", self.0)
621 }
622}
623
/// Identifier of a [`Decorator`] within a [`MastForest`].
///
/// Wraps the decorator's `u32` index into the forest's decorator list. `from_u32_safe` checks
/// a raw value against a given forest's decorator count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DecoratorId(u32);
631
632impl DecoratorId {
633 pub fn from_u32_safe(
638 value: u32,
639 mast_forest: &MastForest,
640 ) -> Result<Self, DeserializationError> {
641 if (value as usize) < mast_forest.decorators.len() {
642 Ok(Self(value))
643 } else {
644 Err(DeserializationError::InvalidValue(format!(
645 "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
646 value,
647 mast_forest.nodes.len(),
648 )))
649 }
650 }
651
652 pub(crate) fn new_unchecked(value: u32) -> Self {
654 Self(value)
655 }
656
657 pub fn as_usize(&self) -> usize {
658 self.0 as usize
659 }
660
661 pub fn as_u32(&self) -> u32 {
662 self.0
663 }
664}
665
666impl From<DecoratorId> for usize {
667 fn from(value: DecoratorId) -> Self {
668 value.0 as usize
669 }
670}
671
672impl From<DecoratorId> for u32 {
673 fn from(value: DecoratorId) -> Self {
674 value.0
675 }
676}
677
678impl From<&DecoratorId> for u32 {
679 fn from(value: &DecoratorId) -> Self {
680 value.0
681 }
682}
683
684impl fmt::Display for DecoratorId {
685 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
686 write!(f, "DecoratorId({})", self.0)
687 }
688}
689
690impl Serializable for DecoratorId {
691 fn write_into<W: ByteWriter>(&self, target: &mut W) {
692 self.0.write_into(target)
693 }
694}
695
/// Errors produced while building, validating or merging a [`MastForest`].
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
    /// Returned by `MastForest::add_decorator` when the forest already holds `u32::MAX`
    /// decorators.
    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
    TooManyDecorators,
    /// Returned by `MastForest::add_node` when the forest already holds `MastForest::MAX_NODES`
    /// nodes.
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node id was not smaller than the forest's node count; carries the offending id and
    /// that count.
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A decorator id was not smaller than the forest's decorator count; carries the offending
    /// id and that count.
    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
    DecoratorIdOverflow(DecoratorId, usize),
    /// A basic block was requested with no operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// The fingerprint of the child node with the carried id was needed but not available.
    #[error("decorator root of child with node id {0} is missing but is required for fingerprint computation")]
    ChildFingerprintMissing(MastNodeId),
    /// Two forests being merged both define the carried advice map key.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(RpoDigest),
}