1use alloc::{
2    collections::{BTreeMap, BTreeSet},
3    sync::Arc,
4    vec::Vec,
5};
6use core::{
7    fmt, mem,
8    ops::{Index, IndexMut},
9};
10mod node;
11pub use node::{
12    BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, MastNodeExt,
13    OP_BATCH_SIZE, OP_GROUP_SIZE, OpBatch, OperationOrDecorator, SplitNode,
14};
15
16use crate::{
17    AdviceMap, Decorator, DecoratorList, Felt, LexicographicWord, Operation, Word,
18    crypto::hash::{Blake3_256, Blake3Digest, Digest, Hasher},
19    utils::{ByteWriter, DeserializationError, Serializable},
20};
21
22mod serialization;
23
24mod merger;
25pub(crate) use merger::MastForestMerger;
26pub use merger::MastForestRootMap;
27
28mod multi_forest_node_iterator;
29pub(crate) use multi_forest_node_iterator::*;
30
31mod node_fingerprint;
32pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
33
34#[cfg(test)]
35mod tests;
36
37#[derive(Clone, Debug, Default, PartialEq, Eq)]
45pub struct MastForest {
46    nodes: Vec<MastNode>,
48
49    roots: Vec<MastNodeId>,
51
52    decorators: Vec<Decorator>,
54
55    advice_map: AdviceMap,
57
58    error_codes: BTreeMap<u64, Arc<str>>,
62}
63
64impl MastForest {
67    pub fn new() -> Self {
69        Self::default()
70    }
71}
72
73impl MastForest {
76    const MAX_NODES: usize = (1 << 30) - 1;
78
79    pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
81        if self.decorators.len() >= u32::MAX as usize {
82            return Err(MastForestError::TooManyDecorators);
83        }
84
85        let new_decorator_id = DecoratorId(self.decorators.len() as u32);
86        self.decorators.push(decorator);
87
88        Ok(new_decorator_id)
89    }
90
91    pub fn add_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
95        if self.nodes.len() == Self::MAX_NODES {
96            return Err(MastForestError::TooManyNodes);
97        }
98
99        let new_node_id = MastNodeId(self.nodes.len() as u32);
100        self.nodes.push(node);
101
102        Ok(new_node_id)
103    }
104
105    pub fn add_block(
107        &mut self,
108        operations: Vec<Operation>,
109        decorators: Option<DecoratorList>,
110    ) -> Result<MastNodeId, MastForestError> {
111        let block = MastNode::new_basic_block(operations, decorators)?;
112        self.add_node(block)
113    }
114
115    pub fn add_join(
117        &mut self,
118        left_child: MastNodeId,
119        right_child: MastNodeId,
120    ) -> Result<MastNodeId, MastForestError> {
121        let join = MastNode::new_join(left_child, right_child, self)?;
122        self.add_node(join)
123    }
124
125    pub fn add_split(
127        &mut self,
128        if_branch: MastNodeId,
129        else_branch: MastNodeId,
130    ) -> Result<MastNodeId, MastForestError> {
131        let split = MastNode::new_split(if_branch, else_branch, self)?;
132        self.add_node(split)
133    }
134
135    pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
137        let loop_node = MastNode::new_loop(body, self)?;
138        self.add_node(loop_node)
139    }
140
141    pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
143        let call = MastNode::new_call(callee, self)?;
144        self.add_node(call)
145    }
146
147    pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
149        let syscall = MastNode::new_syscall(callee, self)?;
150        self.add_node(syscall)
151    }
152
153    pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
155        self.add_node(MastNode::new_dyn())
156    }
157
158    pub fn add_dyncall(&mut self) -> Result<MastNodeId, MastForestError> {
160        self.add_node(MastNode::new_dyncall())
161    }
162
163    pub fn add_external(&mut self, mast_root: Word) -> Result<MastNodeId, MastForestError> {
165        self.add_node(MastNode::new_external(mast_root))
166    }
167
168    pub fn make_root(&mut self, new_root_id: MastNodeId) {
176        assert!((new_root_id.0 as usize) < self.nodes.len());
177
178        if !self.roots.contains(&new_root_id) {
179            self.roots.push(new_root_id);
180        }
181    }
182
183    pub fn remove_nodes(
191        &mut self,
192        nodes_to_remove: &BTreeSet<MastNodeId>,
193    ) -> BTreeMap<MastNodeId, MastNodeId> {
194        if nodes_to_remove.is_empty() {
195            return BTreeMap::new();
196        }
197
198        let old_nodes = mem::take(&mut self.nodes);
199        let old_root_ids = mem::take(&mut self.roots);
200        let (retained_nodes, id_remappings) = remove_nodes(old_nodes, nodes_to_remove);
201
202        self.remap_and_add_nodes(retained_nodes, &id_remappings);
203        self.remap_and_add_roots(old_root_ids, &id_remappings);
204        id_remappings
205    }
206
207    pub fn append_before_enter(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
208        self[node_id].append_before_enter(decorator_ids)
209    }
210
211    pub fn append_after_exit(&mut self, node_id: MastNodeId, decorator_ids: &[DecoratorId]) {
212        self[node_id].append_after_exit(decorator_ids)
213    }
214
215    pub fn strip_decorators(&mut self) {
217        for node in self.nodes.iter_mut() {
218            node.remove_decorators();
219        }
220        self.decorators.truncate(0);
221    }
222
223    pub fn merge<'forest>(
273        forests: impl IntoIterator<Item = &'forest MastForest>,
274    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
275        MastForestMerger::merge(forests)
276    }
277
278    #[cfg(test)]
283    pub fn add_block_with_raw_decorators(
284        &mut self,
285        operations: Vec<Operation>,
286        decorators: Vec<(usize, Decorator)>,
287    ) -> Result<MastNodeId, MastForestError> {
288        let block = MastNode::new_basic_block_with_raw_decorators(operations, decorators, self)?;
289        self.add_node(block)
290    }
291}
292
293impl MastForest {
295    fn remap_and_add_nodes(
301        &mut self,
302        nodes_to_add: Vec<MastNode>,
303        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
304    ) {
305        assert!(self.nodes.is_empty());
306
307        for live_node in nodes_to_add {
310            match &live_node {
311                MastNode::Join(join_node) => {
312                    let first_child =
313                        id_remappings.get(&join_node.first()).copied().unwrap_or(join_node.first());
314                    let second_child = id_remappings
315                        .get(&join_node.second())
316                        .copied()
317                        .unwrap_or(join_node.second());
318
319                    self.add_join(first_child, second_child).unwrap();
320                },
321                MastNode::Split(split_node) => {
322                    let on_true_child = id_remappings
323                        .get(&split_node.on_true())
324                        .copied()
325                        .unwrap_or(split_node.on_true());
326                    let on_false_child = id_remappings
327                        .get(&split_node.on_false())
328                        .copied()
329                        .unwrap_or(split_node.on_false());
330
331                    self.add_split(on_true_child, on_false_child).unwrap();
332                },
333                MastNode::Loop(loop_node) => {
334                    let body_id =
335                        id_remappings.get(&loop_node.body()).copied().unwrap_or(loop_node.body());
336
337                    self.add_loop(body_id).unwrap();
338                },
339                MastNode::Call(call_node) => {
340                    let callee_id = id_remappings
341                        .get(&call_node.callee())
342                        .copied()
343                        .unwrap_or(call_node.callee());
344
345                    if call_node.is_syscall() {
346                        self.add_syscall(callee_id).unwrap();
347                    } else {
348                        self.add_call(callee_id).unwrap();
349                    }
350                },
351                MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => {
352                    self.add_node(live_node).unwrap();
353                },
354            }
355        }
356    }
357
358    fn remap_and_add_roots(
363        &mut self,
364        old_root_ids: Vec<MastNodeId>,
365        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
366    ) {
367        assert!(self.roots.is_empty());
368
369        for old_root_id in old_root_ids {
370            let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
371            self.make_root(new_root_id);
372        }
373    }
374}
375
376fn remove_nodes(
379    mast_nodes: Vec<MastNode>,
380    nodes_to_remove: &BTreeSet<MastNodeId>,
381) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
382    assert!(mast_nodes.len() < u32::MAX as usize);
384
385    let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
386    let mut id_remappings = BTreeMap::new();
387
388    for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
389        let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
390
391        if !nodes_to_remove.contains(&old_node_id) {
392            let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
393            id_remappings.insert(old_node_id, new_node_id);
394
395            retained_nodes.push(old_node);
396        }
397    }
398
399    (retained_nodes, id_remappings)
400}
401
402impl MastForest {
406    #[inline(always)]
411    pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
412        let idx = decorator_id.0 as usize;
413
414        self.decorators.get(idx)
415    }
416
417    #[inline(always)]
422    pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
423        let idx = node_id.0 as usize;
424
425        self.nodes.get(idx)
426    }
427
428    #[inline(always)]
430    pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
431        self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
432    }
433
434    pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
436        self.roots.contains(&node_id)
437    }
438
439    pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
441        self.roots.iter().map(|&root_id| self[root_id].digest())
442    }
443
444    pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
448        self.roots.iter().filter_map(|&root_id| {
449            let node = &self[root_id];
450            if node.is_external() { None } else { Some(node.digest()) }
451        })
452    }
453
454    pub fn procedure_roots(&self) -> &[MastNodeId] {
456        &self.roots
457    }
458
459    pub fn num_procedures(&self) -> u32 {
461        self.roots
462            .len()
463            .try_into()
464            .expect("MAST forest contains more than 2^32 procedures.")
465    }
466
467    pub fn compute_nodes_commitment<'a>(
472        &self,
473        node_ids: impl IntoIterator<Item = &'a MastNodeId>,
474    ) -> Word {
475        let mut digests: Vec<Word> = node_ids.into_iter().map(|&id| self[id].digest()).collect();
476        digests.sort_unstable_by_key(|word| LexicographicWord::from(*word));
477        miden_crypto::hash::rpo::Rpo256::merge_many(&digests)
478    }
479
480    pub fn num_nodes(&self) -> u32 {
482        self.nodes.len() as u32
483    }
484
485    pub fn nodes(&self) -> &[MastNode] {
487        &self.nodes
488    }
489
490    pub fn decorators(&self) -> &[Decorator] {
491        &self.decorators
492    }
493
494    pub fn advice_map(&self) -> &AdviceMap {
495        &self.advice_map
496    }
497
498    pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
499        &mut self.advice_map
500    }
501
502    pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
505        let code: Felt = error_code_from_msg(&msg);
506        self.error_codes.insert(code.as_int(), msg);
508        code
509    }
510
511    pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
513        let key = u64::from(code);
514        self.error_codes.get(&key).cloned()
515    }
516}
517
518impl Index<MastNodeId> for MastForest {
519    type Output = MastNode;
520
521    #[inline(always)]
522    fn index(&self, node_id: MastNodeId) -> &Self::Output {
523        let idx = node_id.0 as usize;
524
525        &self.nodes[idx]
526    }
527}
528
529impl IndexMut<MastNodeId> for MastForest {
530    #[inline(always)]
531    fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
532        let idx = node_id.0 as usize;
533
534        &mut self.nodes[idx]
535    }
536}
537
538impl Index<DecoratorId> for MastForest {
539    type Output = Decorator;
540
541    #[inline(always)]
542    fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
543        let idx = decorator_id.0 as usize;
544
545        &self.decorators[idx]
546    }
547}
548
549impl IndexMut<DecoratorId> for MastForest {
550    #[inline(always)]
551    fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
552        let idx = decorator_id.0 as usize;
553        &mut self.decorators[idx]
554    }
555}
556
/// Identifier of a node inside a [`MastForest`].
///
/// Wraps the `u32` index of the node in the forest's node vector, so an id is only meaningful
/// relative to the forest it was issued for. Ids compare and order by their index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MastNodeId(u32);
568
569pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
571
572impl MastNodeId {
573    pub fn from_u32_safe(
578        value: u32,
579        mast_forest: &MastForest,
580    ) -> Result<Self, DeserializationError> {
581        Self::from_u32_with_node_count(value, mast_forest.nodes.len())
582    }
583
584    pub fn from_usize_safe(
588        node_id: usize,
589        mast_forest: &MastForest,
590    ) -> Result<Self, DeserializationError> {
591        let node_id: u32 = node_id.try_into().map_err(|_| {
592            DeserializationError::InvalidValue(format!(
593                "node id '{node_id}' does not fit into a u32"
594            ))
595        })?;
596        MastNodeId::from_u32_safe(node_id, mast_forest)
597    }
598
599    pub(crate) fn new_unchecked(value: u32) -> Self {
601        Self(value)
602    }
603
604    pub(super) fn from_u32_with_node_count(
618        id: u32,
619        node_count: usize,
620    ) -> Result<Self, DeserializationError> {
621        if (id as usize) < node_count {
622            Ok(Self(id))
623        } else {
624            Err(DeserializationError::InvalidValue(format!(
625                "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
626            )))
627        }
628    }
629
630    pub fn as_usize(&self) -> usize {
631        self.0 as usize
632    }
633
634    pub fn as_u32(&self) -> u32 {
635        self.0
636    }
637
638    pub fn remap(&self, remapping: &Remapping) -> Self {
640        *remapping.get(self).unwrap_or(self)
641    }
642}
643
644impl From<MastNodeId> for usize {
645    fn from(value: MastNodeId) -> Self {
646        value.0 as usize
647    }
648}
649
650impl From<MastNodeId> for u32 {
651    fn from(value: MastNodeId) -> Self {
652        value.0
653    }
654}
655
656impl From<&MastNodeId> for u32 {
657    fn from(value: &MastNodeId) -> Self {
658        value.0
659    }
660}
661
662impl fmt::Display for MastNodeId {
663    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
664        write!(f, "MastNodeId({})", self.0)
665    }
666}
667
668pub struct SubtreeIterator<'a> {
673    forest: &'a MastForest,
674    discovered: Vec<MastNodeId>,
675    unvisited: Vec<MastNodeId>,
676}
677impl<'a> SubtreeIterator<'a> {
678    pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
679        let discovered = vec![];
680        let unvisited = vec![*root];
681        SubtreeIterator { forest, discovered, unvisited }
682    }
683}
684impl Iterator for SubtreeIterator<'_> {
685    type Item = MastNodeId;
686    fn next(&mut self) -> Option<MastNodeId> {
687        while let Some(id) = self.unvisited.pop() {
688            let node = &self.forest[id];
689            if !node.has_children() {
690                return Some(id);
691            } else {
692                self.discovered.push(id);
693                node.append_children_to(&mut self.unvisited);
694            }
695        }
696        self.discovered.pop()
697    }
698}
699
/// Identifier of a decorator inside a [`MastForest`].
///
/// Wraps the `u32` index of the decorator in the forest's decorator vector, so an id is only
/// meaningful relative to the forest it was issued for. Ids compare and order by their index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DecoratorId(u32);
707
708impl DecoratorId {
709    pub fn from_u32_safe(
714        value: u32,
715        mast_forest: &MastForest,
716    ) -> Result<Self, DeserializationError> {
717        if (value as usize) < mast_forest.decorators.len() {
718            Ok(Self(value))
719        } else {
720            Err(DeserializationError::InvalidValue(format!(
721                "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest",
722                value,
723                mast_forest.nodes.len(),
724            )))
725        }
726    }
727
728    pub(crate) fn new_unchecked(value: u32) -> Self {
730        Self(value)
731    }
732
733    pub fn as_usize(&self) -> usize {
734        self.0 as usize
735    }
736
737    pub fn as_u32(&self) -> u32 {
738        self.0
739    }
740}
741
742impl From<DecoratorId> for usize {
743    fn from(value: DecoratorId) -> Self {
744        value.0 as usize
745    }
746}
747
748impl From<DecoratorId> for u32 {
749    fn from(value: DecoratorId) -> Self {
750        value.0
751    }
752}
753
754impl From<&DecoratorId> for u32 {
755    fn from(value: &DecoratorId) -> Self {
756        value.0
757    }
758}
759
760impl fmt::Display for DecoratorId {
761    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
762        write!(f, "DecoratorId({})", self.0)
763    }
764}
765
766impl Serializable for DecoratorId {
767    fn write_into<W: ByteWriter>(&self, target: &mut W) {
768        self.0.write_into(target)
769    }
770}
771
772pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
775    let digest: Blake3Digest<32> = Blake3_256::hash(msg.as_ref().as_bytes());
776    let mut digest_bytes: [u8; 8] = [0; 8];
777    digest_bytes.copy_from_slice(&digest.as_bytes()[0..8]);
778    let code = u64::from_le_bytes(digest_bytes);
779    Felt::new(code)
780}
781
782#[derive(Debug, thiserror::Error, PartialEq)]
787pub enum MastForestError {
788    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
789    TooManyDecorators,
790    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
791    TooManyNodes,
792    #[error("node id {0} is greater than or equal to forest length {1}")]
793    NodeIdOverflow(MastNodeId, usize),
794    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
795    DecoratorIdOverflow(DecoratorId, usize),
796    #[error("basic block cannot be created from an empty list of operations")]
797    EmptyBasicBlock,
798    #[error(
799        "decorator root of child with node id {0} is missing but is required for fingerprint computation"
800    )]
801    ChildFingerprintMissing(MastNodeId),
802    #[error("advice map key {0} already exists when merging forests")]
803    AdviceMapKeyCollisionOnMerge(Word),
804}