Skip to main content

miden_core/mast/
mod.rs

1//! MAST forest: a collection of procedures represented as Merkle trees.
2//!
3//! # Deserializing from untrusted sources
4//!
5//! When loading a `MastForest` from bytes you don't fully trust (network, user upload, etc.),
6//! use [`UntrustedMastForest`] instead of calling `MastForest::read_from_bytes` directly:
7//!
8//! ```ignore
9//! use miden_core::mast::UntrustedMastForest;
10//!
11//! let forest = UntrustedMastForest::read_from_bytes(&bytes)?
12//!     .validate()?;
13//! ```
14//!
15//! [`UntrustedMastForest::read_from_bytes`] applies default parsing and validation budgets derived
16//! from the input size. Use [`UntrustedMastForest::read_from_bytes_with_options`] with
17//! [`UntrustedMastForestReadOptions`] to tune the wire byte budget. This limits allocations driven
18//! directly by wire counts while reading the payload. A separate validation helper budget is
19//! derived from it for later allocations needed to materialize and check hashless payloads.
20//!
21//! ```ignore
22//! use miden_core::mast::{UntrustedMastForest, UntrustedMastForestReadOptions};
23//!
24//! let options = UntrustedMastForestReadOptions::new()
25//!     .with_wire_byte_budget(bytes.len());
26//! let forest = UntrustedMastForest::read_from_bytes_with_options(&bytes, options)?
27//!     .validate()?;
28//! ```
29//!
30//! Validation recomputes all node hashes and checks structural invariants before returning a
31//! usable `MastForest`. Direct deserialization via `MastForest::read_from_bytes` trusts the
32//! serialized hashes and should only be used for data from trusted sources (e.g. compiled locally).
33//!
34//! In practice, the public entry points split into three policies:
35//! - [`MastForest::read_from_bytes`]: trusted full deserialization; rejects hashless payloads and
36//!   trusts serialized non-external digests.
37//! - [`MastForestWireView::new`]: trusted wire-backed cache access; scans only the layout needed
38//!   for random access and rejects hashless payloads.
39//! - [`UntrustedMastForest::read_from_bytes`] and
40//!   [`UntrustedMastForest::read_from_bytes_with_options`]: untrusted paths; parse with bounded
41//!   readers and require [`UntrustedMastForest::validate`] before use.
42
43#[cfg(test)]
44use alloc::collections::BTreeSet;
45use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec};
46use core::{fmt, ops::Index};
47
48#[cfg(any(test, feature = "arbitrary"))]
49use proptest::prelude::*;
50#[cfg(feature = "serde")]
51use serde::{Deserialize, Serialize};
52
53mod node;
54#[cfg(any(test, feature = "arbitrary"))]
55pub use node::arbitrary;
56pub(crate) use node::collect_immediate_placements;
57pub use node::{
58    BasicBlockNode, BasicBlockNodeBuilder, CallNode, CallNodeBuilder, DynNode, DynNodeBuilder,
59    ExternalNode, ExternalNodeBuilder, JoinNode, JoinNodeBuilder, LoopNode, LoopNodeBuilder,
60    MastForestContributor, MastNode, MastNodeBuilder, MastNodeExt, OP_BATCH_SIZE, OP_GROUP_SIZE,
61    OpBatch, SplitNode, SplitNodeBuilder,
62};
63
64#[cfg(feature = "serde")]
65use crate::serde::{Deserializable, Serializable, SliceReader};
66use crate::{
67    Felt, Word,
68    advice::AdviceMap,
69    serde::{ByteWriter, DeserializationError},
70    utils::{Idx, IndexVec, hash_string_to_word},
71};
72
73mod serialization;
74pub use serialization::{
75    AdviceMapView, AdviceValueView, MastForestReadMode, MastForestReadView, MastForestView,
76    MastForestWireView, MastNodeEntry, MastNodeInfo,
77};
78
79mod untrusted;
80pub use untrusted::{UntrustedMastForest, UntrustedMastForestReadOptions};
81
82mod merger;
83pub(crate) use merger::MastForestMerger;
84pub use merger::MastForestRootMap;
85
86mod multi_forest_node_iterator;
87pub(crate) use multi_forest_node_iterator::*;
88
89mod node_builder_utils;
90pub use node_builder_utils::build_node_with_remapped_ids;
91
92mod sparse;
93pub use sparse::{MastForestId, SparseMastForest, SparseMastForestBuilder, VisitKind};
94
95#[cfg(test)]
96mod tests;
97
98// MAST FOREST
99// ================================================================================================
100
101/// Represents one or more procedures, represented as a collection of [`MastNode`]s.
102///
103/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A
104/// [`crate::program::Program`] can be built from a [`MastForest`] to specify an entrypoint.
///
/// Equality (`PartialEq`) compares `nodes`, `roots` and `advice_map`; the cached `commitment` is
/// not part of the comparison.
///
/// NOTE(review): `Default` is derived, so `MastForest::default()` stores `Word::default()` as the
/// commitment, while [`MastForest::new`] stores the commitment of an empty root set — confirm
/// that the two values agree.
105#[derive(Clone, Debug, Default)]
106#[cfg_attr(
107    all(feature = "arbitrary", test),
108    miden_test_serde_macros::serde_test(binary_serde(true))
109)]
110pub struct MastForest {
111    /// All of the nodes local to the trees comprising the MAST forest.
    ///
    /// Nodes are addressed by [`MastNodeId`].
112    nodes: IndexVec<MastNodeId, MastNode>,
113
114    /// Roots of procedures defined within this MAST forest.
    ///
    /// `mark_root` only appends ids that are not already present.
115    roots: Vec<MastNodeId>,
116
117    /// Advice map to be loaded into the VM prior to executing procedures from this MAST forest.
118    advice_map: AdviceMap,
119
120    /// Commitment to this MAST forest (commitment to all roots).
    ///
    /// The constructors and root mutators in this file compute it with `compute_nodes_commitment`
    /// over `roots`, i.e. as a hash over the sorted root digests.
121    commitment: Word,
122}
123
124/// Complete parts needed to construct a finalized [`MastForest`].
///
/// Unlike [`MastForest`] itself, this bundle carries no commitment: the constructors that consume
/// it compute the commitment from `nodes` and `roots`.
125pub(crate) struct MastForestParts {
    /// Nodes of the forest, addressed by [`MastNodeId`].
126    pub nodes: IndexVec<MastNodeId, MastNode>,
    /// Ids of the procedure roots; the constructors require each to index into `nodes`.
127    pub roots: Vec<MastNodeId>,
    /// Advice map stored alongside the forest.
128    pub advice_map: AdviceMap,
129}
130
131// ------------------------------------------------------------------------------------------------
132/// Constructors
133impl MastForest {
134    /// Creates a new empty [`MastForest`].
135    pub fn new() -> Self {
136        Self {
137            nodes: IndexVec::new(),
138            roots: Vec::new(),
139            advice_map: AdviceMap::default(),
140            commitment: empty_mast_forest_commitment(),
141        }
142    }
143
144    /// Builds a [`MastForest`] from raw parts and validates local structure.
145    #[doc(hidden)]
146    pub fn from_raw_parts(
147        nodes: IndexVec<MastNodeId, MastNode>,
148        roots: Vec<MastNodeId>,
149        advice_map: AdviceMap,
150    ) -> Result<Self, MastForestError> {
151        Self::from_parts(MastForestParts { nodes, roots, advice_map })
152    }
153
154    /// Builds a [`MastForest`] from completed parts.
155    pub(crate) fn from_parts(parts: MastForestParts) -> Result<Self, MastForestError> {
156        if parts.nodes.len() > Self::MAX_NODES {
157            return Err(MastForestError::TooManyNodes);
158        }
159
160        let node_count = parts.nodes.len();
161        for &root_id in &parts.roots {
162            if root_id.to_usize() >= node_count {
163                return Err(MastForestError::NodeIdOverflow(root_id, node_count));
164            }
165        }
166
167        let forest = Self {
168            commitment: compute_nodes_commitment(&parts.nodes, &parts.roots),
169            nodes: parts.nodes,
170            roots: parts.roots,
171            advice_map: parts.advice_map,
172        };
173
174        forest.validate()?;
175        forest.validate_node_hashes()?;
176        Ok(forest)
177    }
178
179    pub(in crate::mast) fn from_trusted_deserialization_parts(
180        parts: MastForestParts,
181    ) -> Result<Self, MastForestError> {
182        if parts.nodes.len() > Self::MAX_NODES {
183            return Err(MastForestError::TooManyNodes);
184        }
185
186        let node_count = parts.nodes.len();
187        for &root_id in &parts.roots {
188            if root_id.to_usize() >= node_count {
189                return Err(MastForestError::NodeIdOverflow(root_id, node_count));
190            }
191        }
192        Ok(Self {
193            commitment: compute_nodes_commitment(&parts.nodes, &parts.roots),
194            nodes: parts.nodes,
195            roots: parts.roots,
196            advice_map: parts.advice_map,
197        })
198    }
199}
200
201// ------------------------------------------------------------------------------------------------
202/// Equality implementations
203impl PartialEq for MastForest {
204    fn eq(&self, other: &Self) -> bool {
205        self.nodes == other.nodes
206            && self.roots == other.roots
207            && self.advice_map == other.advice_map
208    }
209}
210
211impl Eq for MastForest {}
212
213// ------------------------------------------------------------------------------------------------
214/// State mutators
215impl MastForest {
216    /// The maximum number of nodes that can be stored in a single MAST forest (`2^30 - 1`).
    ///
    /// The constructors in this file reject larger node sets with
    /// [`MastForestError::TooManyNodes`].
217    const MAX_NODES: usize = (1 << 30) - 1;
218
219    // Kept private so callers cannot mutate roots arbitrarily, but shared with the merger so it
220    // can rebuild the root set while remapping nodes into the merged forest.
221    fn mark_root(&mut self, new_root_id: MastNodeId) {
222        assert!(new_root_id.to_usize() < self.nodes.len());
223
224        if !self.roots.contains(&new_root_id) {
225            self.roots.push(new_root_id);
226            self.commitment = self.compute_nodes_commitment(&self.roots);
227        }
228    }
229
230    /// Marks the given [`MastNodeId`] as being the root of a procedure.
231    ///
232    /// If the specified node is already marked as a root, this will have no effect.
233    ///
234    /// # Panics
235    /// - if `new_root_id`'s internal index is larger than the number of nodes in this forest (i.e.
236    ///   clearly doesn't belong to this MAST forest).
237    #[cfg(any(test, feature = "arbitrary"))]
238    pub fn make_root(&mut self, new_root_id: MastNodeId) {
239        self.mark_root(new_root_id);
240    }
241
242    /// Removes all nodes in the provided set from the MAST forest. The nodes MUST be orphaned (i.e.
243    /// have no parent). Otherwise, this parent's reference is considered "dangling" after the
244    /// removal (i.e. will point to an incorrect node after the removal), and this removal operation
245    /// would result in an invalid [`MastForest`].
246    ///
247    /// It also returns the map from old node IDs to new node IDs. Any [`MastNodeId`] used in
248    /// reference to the old [`MastForest`] should be remapped using this map.
249    #[cfg(test)]
250    pub fn remove_nodes(
251        &mut self,
252        nodes_to_remove: &BTreeSet<MastNodeId>,
253    ) -> BTreeMap<MastNodeId, MastNodeId> {
254        if nodes_to_remove.is_empty() {
255            return BTreeMap::new();
256        }
257
258        self.assert_nodes_to_remove_are_orphaned(nodes_to_remove);
259
260        let old_nodes = core::mem::replace(&mut self.nodes, IndexVec::new());
261        let old_root_ids = core::mem::take(&mut self.roots);
262        let (retained_nodes, id_remappings) = remove_nodes(old_nodes.into_inner(), nodes_to_remove);
263
264        self.remap_and_add_nodes(retained_nodes, &id_remappings);
265        self.remap_and_add_roots(old_root_ids, &id_remappings);
266
267        self.commitment = self.compute_nodes_commitment(&self.roots);
268
269        id_remappings
270    }
271
272    /// Compacts the forest by merging duplicate nodes.
273    ///
274    /// This operation performs node deduplication by merging the forest with itself.
275    /// This method consumes the forest and returns a new compacted forest.
276    ///
277    /// The process works by:
278    /// 1. Merging the forest with itself to deduplicate identical nodes
279    /// 2. Updating internal node references and remappings
280    /// 3. Returning the compacted forest and root map
281    ///
282    /// # Examples
283    ///
284    /// ```rust
285    /// use miden_core::mast::MastForest;
286    ///
287    /// let forest = MastForest::new();
288    /// // Add nodes to the forest
289    ///
290    /// // Compact the forest (consumes the original)
291    /// let (compacted_forest, root_map) = forest.compact();
292    ///
293    /// // compacted_forest is now compacted with duplicate nodes merged
294    /// ```
295    pub fn compact(self) -> (MastForest, MastForestRootMap) {
296        // Merge with itself to deduplicate nodes
297        // Note: This cannot fail for a self-merge under normal conditions.
298        // The only possible failure (TooManyNodes) would require the original forest to be at a
299        // capacity limit, at which point compaction wouldn't help.
300        MastForest::merge([&self])
301            .expect("Failed to compact MastForest: this should never happen during self-merge")
302    }
303
304    /// Merges all `forests` into a new [`MastForest`].
305    ///
306    /// Merging two forests means combining all their constituent parts, i.e. [`MastNode`]s and
307    /// roots. During this process, any duplicate or unreachable nodes are removed. Additionally,
308    /// [`MastNodeId`]s of nodes may change and references to them are remapped to their new
309    /// location.
310    ///
311    /// For example, consider this representation of a forest's nodes with all of these nodes being
312    /// roots:
313    ///
314    /// ```text
315    /// [Block(foo), Block(bar)]
316    /// ```
317    ///
318    /// If we merge another forest into it:
319    ///
320    /// ```text
321    /// [Block(bar), Call(0)]
322    /// ```
323    ///
324    /// then we would expect this forest:
325    ///
326    /// ```text
327    /// [Block(foo), Block(bar), Call(1)]
328    /// ```
329    ///
330    /// - The `Call` to the `bar` block was remapped to its new index (now 1, previously 0).
331    /// - The `Block(bar)` was deduplicated any only exists once in the merged forest.
332    ///
333    /// The function also returns a vector of [`MastForestRootMap`]s, whose length equals the number
334    /// of passed `forests`. The indices in the vector correspond to the ones in `forests`. The map
335    /// of a given forest contains the new locations of its roots in the merged forest. To
336    /// illustrate, the above example would return a vector of two maps:
337    ///
338    /// ```text
339    /// vec![{0 -> 0, 1 -> 1}
340    ///      {0 -> 1, 1 -> 2}]
341    /// ```
342    ///
343    /// - The root locations of the original forest are unchanged.
344    /// - For the second forest, the `bar` block has moved from index 0 to index 1 in the merged
345    ///   forest, and the `Call` has moved from index 1 to 2.
346    ///
347    /// If any forest being merged contains an `External(qux)` node and another forest contains a
348    /// node whose digest is `qux`, then the external node will be replaced with the `qux` node,
349    /// which is effectively deduplication.
350    pub fn merge<'forest>(
351        forests: impl IntoIterator<Item = &'forest MastForest>,
352    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
353        MastForestMerger::merge(forests)
354    }
355}
356
357// ------------------------------------------------------------------------------------------------
358/// Helpers
359impl MastForest {
360    #[cfg(test)]
361    fn assert_nodes_to_remove_are_orphaned(&self, nodes_to_remove: &BTreeSet<MastNodeId>) {
362        for (node_idx, node) in self.nodes.iter().enumerate() {
363            let node_id = MastNodeId::new_unchecked(node_idx.try_into().expect("too many nodes"));
364            if nodes_to_remove.contains(&node_id) {
365                continue;
366            }
367
368            node.for_each_child(|child_id| {
369                assert!(
370                    !nodes_to_remove.contains(&child_id),
371                    "cannot remove node {child_id:?}; retained node {node_id:?} references it"
372                );
373            });
374        }
375    }
376
377    /// Adds all provided nodes to the internal set of nodes, remapping all [`MastNodeId`]
378    /// references in those nodes.
379    ///
380    /// # Panics
381    /// - Panics if the internal set of nodes is not empty.
382    #[cfg(test)]
383    fn remap_and_add_nodes(
384        &mut self,
385        nodes_to_add: Vec<MastNode>,
386        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
387    ) {
388        assert!(self.nodes.is_empty());
389        let node_builders =
390            nodes_to_add.into_iter().map(|node| node.to_builder(self)).collect::<Vec<_>>();
391
392        // Add each node to the new MAST forest, making sure to rewrite any outdated internal
393        // `MastNodeId`s
394        for live_node_builder in node_builders {
395            live_node_builder.remap_children(id_remappings).add_to_forest(self).unwrap();
396        }
397    }
398
399    /// Remaps and adds all old root ids to the internal set of roots.
400    ///
401    /// # Panics
402    /// - Panics if the internal set of roots is not empty.
403    #[cfg(test)]
404    fn remap_and_add_roots(
405        &mut self,
406        old_root_ids: Vec<MastNodeId>,
407        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
408    ) {
409        assert!(self.roots.is_empty());
410
411        for old_root_id in old_root_ids {
412            if let Some(new_root_id) = id_remappings.get(&old_root_id).copied() {
413                self.mark_root(new_root_id);
414            }
415        }
416    }
417}
418
/// Splits off the nodes listed in `nodes_to_remove`.
///
/// Returns the retained nodes in their original relative order, together with the mapping from
/// each retained node's old id to its new id (its position in the returned vector).
///
/// # Panics
/// - if the number of nodes is not smaller than `u32::MAX`.
#[cfg(test)]
fn remove_nodes(
    mast_nodes: Vec<MastNode>,
    nodes_to_remove: &BTreeSet<MastNodeId>,
) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
    // This bound is what makes the `usize as u32` casts below safe from wrapping around.
    assert!(mast_nodes.len() < u32::MAX as usize);

    let mut id_remappings = BTreeMap::new();
    let mut retained_nodes = Vec::with_capacity(mast_nodes.len());

    for (old_index, node) in mast_nodes.into_iter().enumerate() {
        let old_id = MastNodeId(old_index as u32);
        if nodes_to_remove.contains(&old_id) {
            continue;
        }

        // The new id is the slot the node is about to occupy in the retained list.
        let new_id = MastNodeId(retained_nodes.len() as u32);
        id_remappings.insert(old_id, new_id);
        retained_nodes.push(node);
    }

    (retained_nodes, id_remappings)
}
445
446fn empty_mast_forest_commitment() -> Word {
447    miden_crypto::hash::poseidon2::Poseidon2::merge_many(&[])
448}
449
450fn compute_nodes_commitment(
451    nodes: &IndexVec<MastNodeId, MastNode>,
452    node_ids: &[MastNodeId],
453) -> Word {
454    let mut digests: Vec<Word> = node_ids.iter().map(|&id| nodes[id].digest()).collect();
455    digests.sort_unstable();
456    miden_crypto::hash::poseidon2::Poseidon2::merge_many(&digests)
457}
458
459// ------------------------------------------------------------------------------------------------
460/// Public accessors
461impl MastForest {
462    /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else
463    /// `None`.
464    ///
465    /// This is the fallible version of indexing (e.g. `mast_forest[node_id]`).
466    #[inline(always)]
467    pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
468        self.nodes.get(node_id)
469    }
470
471    /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any.
472    #[inline(always)]
473    pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
474        self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
475    }
476
477    /// Returns true if a node with the specified ID is a root of a procedure in this MAST forest.
478    pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
479        self.roots.contains(&node_id)
480    }
481
482    /// Returns true if a node with the specified ID is a root of a procedure in this MAST forest,
483    /// and the digest of that procedure is `digest`.
484    ///
485    /// This is primarily intended for use in confirming that procedure exports of a package,
486    /// which declare their MAST node and digest, actually exist in the MAST.
487    pub fn is_procedure_root_with_exact_digest(&self, node_id: MastNodeId, digest: Word) -> bool {
488        self.is_procedure_root(node_id) && self[node_id].digest() == digest
489    }
490
491    /// Returns an iterator over the digests of all procedures in this MAST forest.
492    pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
493        self.roots.iter().map(|&root_id| self[root_id].digest())
494    }
495
496    /// Returns an iterator over the digests of local procedures in this MAST forest.
497    ///
498    /// A local procedure is defined as a procedure which is not a single external node.
499    pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
500        self.roots.iter().filter_map(|&root_id| {
501            let node = &self[root_id];
502            if node.is_external() { None } else { Some(node.digest()) }
503        })
504    }
505
506    /// Returns an iterator over the IDs of the procedures in this MAST forest.
507    pub fn procedure_roots(&self) -> &[MastNodeId] {
508        &self.roots
509    }
510
511    /// Returns the number of procedures in this MAST forest.
512    pub fn num_procedures(&self) -> u32 {
513        self.roots
514            .len()
515            .try_into()
516            .expect("MAST forest contains more than 2^32 procedures.")
517    }
518
519    /// Returns the [Word] representing the content hash of a subset of [`MastNodeId`]s.
520    ///
521    /// # Panics
522    /// This function panics if any `node_ids` is not a node of this forest.
523    pub fn compute_nodes_commitment<'a>(
524        &self,
525        node_ids: impl IntoIterator<Item = &'a MastNodeId>,
526    ) -> Word {
527        let node_ids = node_ids.into_iter().copied().collect::<Vec<_>>();
528        compute_nodes_commitment(&self.nodes, &node_ids)
529    }
530
531    /// Returns the commitment to this MAST forest.
532    ///
533    /// The commitment is computed as the sequential hash of all procedure roots in the forest.
534    ///
535    /// The commitment uniquely identifies the forest's structure, as each root's digest
536    /// transitively includes all of its descendants. Therefore, a commitment to all roots
537    /// is a commitment to the entire forest.
538    pub fn commitment(&self) -> Word {
539        self.commitment
540    }
541
542    /// Returns the number of nodes in this MAST forest.
543    pub fn num_nodes(&self) -> u32 {
544        self.nodes.len() as u32
545    }
546
547    /// Returns the underlying nodes in this MAST forest.
548    pub fn nodes(&self) -> &[MastNode] {
549        self.nodes.as_slice()
550    }
551
552    pub fn advice_map(&self) -> &AdviceMap {
553        &self.advice_map
554    }
555
556    /// Returns this forest with `advice_map` entries added.
557    pub fn with_advice_map(mut self, advice_map: AdviceMap) -> Self {
558        self.advice_map.extend(advice_map);
559        self
560    }
561
562    #[cfg(test)]
563    pub(crate) fn advice_map_mut(&mut self) -> &mut AdviceMap {
564        &mut self.advice_map
565    }
566
567    // SERIALIZATION
568    // --------------------------------------------------------------------------------------------
569
570    /// Serializes this MastForest with the HASHLESS flag set.
571    ///
572    /// Hashless forest bytes omit rebuildable internal node hashes. External node digests stay on
573    /// the wire because they cannot be rebuilt from local structure. Trusted deserialization
574    /// rejects this flag.
575    ///
576    /// Use this when producing data for untrusted validation.
577    pub fn write_hashless<W: ByteWriter>(&self, target: &mut W) {
578        serialization::write_hashless_into(self, target);
579    }
580}
581
582/// Validation methods
583impl MastForest {
584    fn validate_basic_block_invariants(&self) -> Result<(), MastForestError> {
585        for (node_id_idx, node) in self.nodes.iter().enumerate() {
586            let node_id =
587                MastNodeId::new_unchecked(node_id_idx.try_into().expect("too many nodes"));
588            if let MastNode::Block(basic_block) = node {
589                basic_block.validate_batch_invariants().map_err(|error_msg| {
590                    MastForestError::InvalidBatchPadding(node_id, error_msg)
591                })?;
592            }
593        }
594
595        Ok(())
596    }
597
598    /// Validates that all BasicBlockNodes in this forest satisfy the core invariants:
599    /// 1. Power-of-two number of groups in each batch
600    /// 2. No operation group ends with an operation requiring an immediate value
601    /// 3. The last operation group in a batch cannot contain operations requiring immediate values
602    /// 4. OpBatch structural consistency (num_groups <= BATCH_SIZE, group size <= GROUP_SIZE,
603    ///    indptr integrity, bounds checking)
604    ///
605    /// This addresses the gap created by PR 2094, where padding NOOPs are now inserted
606    /// at assembly time rather than dynamically during execution, and adds comprehensive
607    /// structural validation to prevent deserialization-time panics.
608    pub fn validate(&self) -> Result<(), MastForestError> {
609        self.validate_basic_block_invariants()?;
610        Ok(())
611    }
612
613    /// Validates that stored node digests match the hashes implied by local structure.
614    ///
615    /// For `External` nodes the digest is accepted as-is because it is externally provided and
616    /// cannot be reconstructed from local structure alone.
617    fn validate_node_hashes(&self) -> Result<(), MastForestError> {
618        let computed_hashes = self.compute_node_hashes()?;
619        for (node_idx, (node, computed_digest)) in
620            self.nodes.iter().zip(computed_hashes).enumerate()
621        {
622            let expected_digest = node.digest();
623            if expected_digest != computed_digest {
624                return Err(MastForestError::HashMismatch {
625                    node_id: MastNodeId::new_unchecked(node_idx as u32),
626                    expected: expected_digest,
627                    computed: computed_digest,
628                });
629            }
630        }
631
632        Ok(())
633    }
634
635    /// Computes node hashes in topological order.
636    ///
637    /// The returned vector is aligned with node indices, so `digests[node_id as usize]` is the
638    /// digest of that node.
639    ///
640    /// For `External` nodes, the existing digest is returned unchanged.
641    ///
642    /// Returns [`MastForestError::ForwardReference`] if nodes are not in topological order.
643    fn compute_node_hashes(&self) -> Result<Vec<Word>, MastForestError> {
644        use crate::chiplets::hasher;
645
646        /// Checks that child_id references a node that appears before node_id in topological order.
647        fn check_no_forward_ref(
648            node_id: MastNodeId,
649            child_id: MastNodeId,
650        ) -> Result<(), MastForestError> {
651            if child_id.0 >= node_id.0 {
652                return Err(MastForestError::ForwardReference(node_id, child_id));
653            }
654            Ok(())
655        }
656
657        let mut computed_hashes = Vec::with_capacity(self.nodes.len());
658        for (node_idx, node) in self.nodes.iter().enumerate() {
659            let node_id = MastNodeId::new_unchecked(node_idx as u32);
660
661            // Check topological ordering and compute digest.
662            let computed_digest = match node {
663                MastNode::Block(block) => {
664                    let op_groups: Vec<Felt> =
665                        block.op_batches().iter().flat_map(|batch| *batch.groups()).collect();
666                    hasher::hash_elements(&op_groups)
667                },
668                MastNode::Join(join) => {
669                    let left_id = join.first();
670                    let right_id = join.second();
671                    check_no_forward_ref(node_id, left_id)?;
672                    check_no_forward_ref(node_id, right_id)?;
673
674                    let left_digest = computed_hashes[left_id.0 as usize];
675                    let right_digest = computed_hashes[right_id.0 as usize];
676                    hasher::merge_in_domain(&[left_digest, right_digest], JoinNode::DOMAIN)
677                },
678                MastNode::Split(split) => {
679                    let true_id = split.on_true();
680                    let false_id = split.on_false();
681                    check_no_forward_ref(node_id, true_id)?;
682                    check_no_forward_ref(node_id, false_id)?;
683
684                    let true_digest = computed_hashes[true_id.0 as usize];
685                    let false_digest = computed_hashes[false_id.0 as usize];
686                    hasher::merge_in_domain(&[true_digest, false_digest], SplitNode::DOMAIN)
687                },
688                MastNode::Loop(loop_node) => {
689                    let body_id = loop_node.body();
690                    check_no_forward_ref(node_id, body_id)?;
691
692                    let body_digest = computed_hashes[body_id.0 as usize];
693                    hasher::merge_in_domain(&[body_digest, Word::default()], LoopNode::DOMAIN)
694                },
695                MastNode::Call(call) => {
696                    let callee_id = call.callee();
697                    check_no_forward_ref(node_id, callee_id)?;
698
699                    let callee_digest = computed_hashes[callee_id.0 as usize];
700                    let domain = if call.is_syscall() {
701                        CallNode::SYSCALL_DOMAIN
702                    } else {
703                        CallNode::CALL_DOMAIN
704                    };
705                    hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
706                },
707                MastNode::Dyn(dyn_node) => {
708                    if dyn_node.is_dyncall() {
709                        DynNode::DYNCALL_DEFAULT_DIGEST
710                    } else {
711                        DynNode::DYN_DEFAULT_DIGEST
712                    }
713                },
714                MastNode::External(_) => {
715                    // External nodes have externally-provided digests that cannot be recomputed.
716                    node.digest()
717                },
718            };
719
720            computed_hashes.push(computed_digest);
721        }
722
723        Ok(computed_hashes)
724    }
725}
726
727// MAST FOREST INDEXING
728// ------------------------------------------------------------------------------------------------
729
730impl Index<MastNodeId> for MastForest {
731    type Output = MastNode;
732
733    #[inline(always)]
734    fn index(&self, node_id: MastNodeId) -> &Self::Output {
735        &self.nodes[node_id]
736    }
737}
738
739// EXECUTABLE MAST FOREST
740// ================================================================================================
741
742/// A MAST forest that can be used as the source of nodes during program execution.
743///
744/// Implemented by both [`MastForest`] (a dense forest containing all nodes) and
745/// [`SparseMastForest`] (a sparse subset of a forest containing only the nodes visited during
746/// some prior execution). The latter preserves the original [`MastNodeId`]s of its source forest,
747/// which allows it to stand in for the dense forest during re-execution.
///
/// The [`MastForest`] implementation in this file answers every method through the forest's
/// inherent accessors.
748pub trait ExecutableMastForest {
749    /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if present, or else
750    /// `None`.
751    fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode>;
752
753    /// Returns the digest of the node associated with the provided [`MastNodeId`] if present, or
754    /// else `None`.
755    ///
756    /// For dense forests this is equivalent to `get_node_by_id(id).map(|n| n.digest())`. For
757    /// [`SparseMastForest`], it additionally consults the digest-only entries — nodes that were
758    /// referenced (but not entered) during execution and which were therefore stored as digest
759    /// only. Use this method whenever only the digest of a referenced node is needed (e.g. when
760    /// populating the hasher state of a parent's trace row).
761    fn get_digest_by_id(&self, node_id: MastNodeId) -> Option<Word>;
762
763    /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any.
    ///
    /// Returns `None` when no procedure root with that digest is known to the forest.
764    fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId>;
765
766    /// Returns the advice map associated with this forest.
767    fn advice_map(&self) -> &AdviceMap;
768}
769
770impl ExecutableMastForest for MastForest {
771    #[inline(always)]
772    fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
773        MastForest::get_node_by_id(self, node_id)
774    }
775
776    #[inline(always)]
777    fn get_digest_by_id(&self, node_id: MastNodeId) -> Option<Word> {
778        MastForest::get_node_by_id(self, node_id).map(MastNodeExt::digest)
779    }
780
781    #[inline(always)]
782    fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
783        MastForest::find_procedure_root(self, digest)
784    }
785
786    #[inline(always)]
787    fn advice_map(&self) -> &AdviceMap {
788        MastForest::advice_map(self)
789    }
790}
791
792// Blanket impl: an `Arc<T>` is an `ExecutableMastForest` whenever the underlying `T` is, which
793// allows the executor and tracer plumbing to be generic over a forest type while the live
794// (`Arc<MastForest>`) and replay (`Arc<SparseMastForest>`) paths each pick a concrete instance.
795impl<T> Index<MastNodeId> for Arc<T>
796where
797    T: Index<MastNodeId, Output = MastNode> + ?Sized,
798{
799    type Output = MastNode;
800
801    #[inline(always)]
802    fn index(&self, node_id: MastNodeId) -> &Self::Output {
803        &(**self)[node_id]
804    }
805}
806
807impl<T: ExecutableMastForest + ?Sized> ExecutableMastForest for Arc<T> {
808    #[inline(always)]
809    fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
810        T::get_node_by_id(self, node_id)
811    }
812
813    #[inline(always)]
814    fn get_digest_by_id(&self, node_id: MastNodeId) -> Option<Word> {
815        T::get_digest_by_id(self, node_id)
816    }
817
818    #[inline(always)]
819    fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
820        T::find_procedure_root(self, digest)
821    }
822
823    #[inline(always)]
824    fn advice_map(&self) -> &AdviceMap {
825        T::advice_map(self)
826    }
827}
828
829// MAST NODE ID
830// ================================================================================================
831
/// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user
/// to use a given [`MastNodeId`] with the corresponding [`MastForest`].
///
/// The handle wraps a `u32` and carries no reference to the forest it was created for, so using
/// it with a different forest is not detected by the type system.
///
/// Note that the [`MastForest`] does *not* ensure that equal [`MastNode`]s have equal
/// [`MastNodeId`] handles. Hence, [`MastNodeId`] equality must not be used to test for equality of
/// the underlying [`MastNode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct MastNodeId(u32);
843
/// Operations that mutate a MAST often produce this mapping between old and new NodeIds.
///
/// Keys are the old ids and values are the ids that replace them; see [`MastNodeId::remap`],
/// which leaves ids that are absent from the map unchanged.
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
846
847impl MastNodeId {
848    /// Returns a new `MastNodeId` with the provided inner value, or an error if the provided
849    /// `value` is greater than the number of nodes in the forest.
850    ///
851    /// For use in deserialization.
852    pub fn from_u32_safe(
853        value: u32,
854        mast_forest: &MastForest,
855    ) -> Result<Self, DeserializationError> {
856        Self::from_u32_with_node_count(value, mast_forest.nodes.len())
857    }
858
859    /// Returns a new [`MastNodeId`] from the given `value` without checking its validity.
860    pub fn new_unchecked(value: u32) -> Self {
861        Self(value)
862    }
863
864    /// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal
865    /// to `node_count`. The `node_count` is the total number of nodes in the [`MastForest`] for
866    /// which this ID is being constructed.
867    ///
868    /// This function can be used when deserializing an id whose corresponding node is not yet in
869    /// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids
870    /// referenced by the Join node in this forest:
871    ///
872    /// ```text
873    /// [Join(1, 2), Block(foo), Block(bar)]
874    /// ```
875    ///
876    /// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public.
877    pub(super) fn from_u32_with_node_count(
878        id: u32,
879        node_count: usize,
880    ) -> Result<Self, DeserializationError> {
881        if (id as usize) < node_count {
882            Ok(Self(id))
883        } else {
884            Err(DeserializationError::InvalidValue(format!(
885                "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
886            )))
887        }
888    }
889
890    /// Remap the NodeId to its new position using the given [`Remapping`].
891    pub fn remap(&self, remapping: &Remapping) -> Self {
892        *remapping.get(self).unwrap_or(self)
893    }
894}
895
896impl From<u32> for MastNodeId {
897    fn from(value: u32) -> Self {
898        MastNodeId::new_unchecked(value)
899    }
900}
901
// Marker impl with no items: registers `MastNodeId` as an `Idx` type.
// NOTE(review): presumably this is what lets the forest's node storage be indexed by
// `MastNodeId` directly; confirm against the definition of `Idx`.
impl Idx for MastNodeId {}
903
904impl From<MastNodeId> for u32 {
905    fn from(value: MastNodeId) -> Self {
906        value.0
907    }
908}
909
910impl fmt::Display for MastNodeId {
911    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
912        write!(f, "MastNodeId({})", self.0)
913    }
914}
915
#[cfg(any(test, feature = "arbitrary"))]
impl Arbitrary for MastNodeId {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Produces ids wrapping arbitrary `u32` values. The ids are not tied to any forest, so they
    /// are not guaranteed to be valid for one.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let raw_ids = any::<u32>();
        raw_ids.prop_map(MastNodeId).boxed()
    }
}
927
928// ITERATOR
929
/// Iterates over all the nodes a root depends on, including the root itself. The iteration can
/// include other roots in the same forest.
///
/// Nodes are yielded dependencies-first rather than in pre-order: all childless nodes reachable
/// from the root come first, followed by the nodes that have children in reverse order of
/// discovery, so the starting root is yielded last.
///
/// NOTE(review): no set of already-seen ids is kept, so a node reachable through several parents
/// appears to be yielded once per path; confirm whether callers rely on or tolerate this.
pub struct SubtreeIterator<'a> {
    /// Forest in which node ids are resolved.
    forest: &'a MastForest,
    /// Stack of nodes with children that have been expanded but not yet yielded.
    discovered: Vec<MastNodeId>,
    /// Stack of nodes that still have to be examined.
    unvisited: Vec<MastNodeId>,
}
937impl<'a> SubtreeIterator<'a> {
938    pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
939        let discovered = vec![];
940        let unvisited = vec![*root];
941        SubtreeIterator { forest, discovered, unvisited }
942    }
943}
944impl Iterator for SubtreeIterator<'_> {
945    type Item = MastNodeId;
946    fn next(&mut self) -> Option<MastNodeId> {
947        while let Some(id) = self.unvisited.pop() {
948            let node = &self.forest[id];
949            if !node.has_children() {
950                return Some(id);
951            } else {
952                self.discovered.push(id);
953                node.append_children_to(&mut self.unvisited);
954            }
955        }
956        self.discovered.pop()
957    }
958}
959
960/// Derives an error code from an error message by hashing the message and returning the 0th element
961/// of the resulting [`Word`].
962pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
963    // hash the message and return 0th felt of the resulting Word
964    hash_string_to_word(msg.as_ref())[0]
965}
966
967// MAST FOREST ERROR
968// ================================================================================================
969
/// Represents the types of errors that can occur when dealing with a MAST forest.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum MastForestError {
    /// The forest would contain more than [`MastForest::MAX_NODES`] nodes.
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node id is out of bounds; carries the offending id and the forest length.
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// An attempt was made to build a basic block without any operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// Two forests being merged both define the advice map key carried by this variant.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(Word),
    /// Deserialization needed a digest that was not available.
    #[error("digest is required for deserialization")]
    DigestRequiredForDeserialization,
    /// A basic block node contains an invalid batch; carries the node id and a description of
    /// the problem.
    #[error("invalid batch in basic block node {0:?}: {1}")]
    InvalidBatchPadding(MastNodeId, String),
    /// A node (first field) references a child (second field) that is stored after it in the
    /// forest.
    #[error(
        "node {0:?} references child {1:?} which comes after it in the forest (forward reference)"
    )]
    ForwardReference(MastNodeId, MastNodeId),
    /// The digest computed for a node differs from the expected one.
    #[error("hash mismatch for node {node_id:?}: expected {expected:?}, computed {computed:?}")]
    HashMismatch {
        /// Node whose digest did not match.
        node_id: MastNodeId,
        /// Digest that was expected for the node.
        expected: Word,
        /// Digest that was actually computed for the node.
        computed: Word,
    },
    /// Wraps a lower-level [`DeserializationError`]. The inner error is only rendered into the
    /// message; it is not marked as the error source.
    #[error("deserialization failed: {0}")]
    Deserialization(DeserializationError),
}
998
// Custom serde implementation for MastForest delegates to the binary serialization format: the
// forest is first encoded with `Serializable` and the resulting bytes are handed to serde.
#[cfg(feature = "serde")]
impl Serialize for MastForest {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(&Serializable::to_bytes(self))
    }
}
1010
// Counterpart of the `Serialize` impl: reads the raw bytes through serde and decodes them with
// the binary `Deserializable` implementation.
// NOTE(review): this goes through the plain `Deserializable` path, which presumably trusts the
// serialized hashes like `MastForest::read_from_bytes` does; confirm before using it on
// untrusted input.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for MastForest {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Deserialize bytes, then use miden-crypto Deserializable
        let encoded: Vec<u8> = Deserialize::deserialize(deserializer)?;
        let mut reader = SliceReader::new(&encoded);
        <Self as Deserializable>::read_from(&mut reader).map_err(serde::de::Error::custom)
    }
}