Skip to main content

miden_core/mast/
mod.rs

1//! MAST forest: a collection of procedures represented as Merkle trees.
2//!
3//! # Deserializing from untrusted sources
4//!
5//! When loading a `MastForest` from bytes you don't fully trust (network, user upload, etc.),
6//! use [`UntrustedMastForest`] instead of calling `MastForest::read_from_bytes` directly:
7//!
8//! ```ignore
9//! use miden_core::mast::UntrustedMastForest;
10//!
11//! let forest = UntrustedMastForest::read_from_bytes(&bytes)?
12//!     .validate()?;
13//! ```
14//!
15//! For maximum protection against denial-of-service attacks from malicious input, use
16//! [`UntrustedMastForest::read_from_bytes_with_budget`] which limits memory consumption:
17//!
18//! ```ignore
19//! use miden_core::mast::UntrustedMastForest;
20//!
21//! // Budget limits pre-allocation sizes and total bytes consumed
22//! let forest = UntrustedMastForest::read_from_bytes_with_budget(&bytes, bytes.len())?
23//!     .validate()?;
24//! ```
25//!
26//! This recomputes all node hashes and checks structural invariants before returning a usable
27//! `MastForest`. Direct deserialization via `MastForest::read_from_bytes` trusts the serialized
28//! hashes and should only be used for data from trusted sources (e.g. compiled locally).
29
30use alloc::{
31    collections::{BTreeMap, BTreeSet},
32    string::String,
33    sync::Arc,
34    vec::Vec,
35};
36use core::{
37    fmt,
38    ops::{Index, IndexMut},
39};
40
41use miden_utils_sync::OnceLockCompat;
42#[cfg(feature = "serde")]
43use serde::{Deserialize, Serialize};
44
45mod node;
46#[cfg(any(test, feature = "arbitrary"))]
47pub use node::arbitrary;
48pub use node::{
49    BasicBlockNode, BasicBlockNodeBuilder, CallNode, CallNodeBuilder, DecoratedOpLink,
50    DecoratorOpLinkIterator, DecoratorStore, DynNode, DynNodeBuilder, ExternalNode,
51    ExternalNodeBuilder, JoinNode, JoinNodeBuilder, LoopNode, LoopNodeBuilder,
52    MastForestContributor, MastNode, MastNodeBuilder, MastNodeExt, OP_BATCH_SIZE, OP_GROUP_SIZE,
53    OpBatch, OperationOrDecorator, SplitNode, SplitNodeBuilder,
54};
55
56use crate::{
57    Felt, LexicographicWord, Word,
58    advice::AdviceMap,
59    operations::{AssemblyOp, DebugVarInfo, Decorator},
60    serde::{
61        BudgetedReader, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
62        SliceReader,
63    },
64    utils::{Idx, IndexVec, hash_string_to_word},
65};
66
67mod debuginfo;
68pub use debuginfo::{
69    AsmOpIndexError, DebugInfo, DebugVarId, DecoratedLinks, DecoratedLinksIter,
70    DecoratorIndexError, NodeToDecoratorIds, OpToAsmOpId, OpToDebugVarIds, OpToDecoratorIds,
71};
72
73mod serialization;
74
75mod merger;
76pub(crate) use merger::MastForestMerger;
77pub use merger::MastForestRootMap;
78
79mod multi_forest_node_iterator;
80pub(crate) use multi_forest_node_iterator::*;
81
82mod node_fingerprint;
83pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint};
84
85mod node_builder_utils;
86pub use node_builder_utils::build_node_with_remapped_ids;
87
88#[cfg(test)]
89mod tests;
90
91// MAST FOREST
92// ================================================================================================
93
94/// Represents one or more procedures, represented as a collection of [`MastNode`]s.
95///
96/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A
97/// [`crate::program::Program`] can be built from a [`MastForest`] to specify an entrypoint.
98#[derive(Clone, Debug, Default)]
99pub struct MastForest {
100    /// All of the nodes local to the trees comprising the MAST forest.
101    nodes: IndexVec<MastNodeId, MastNode>,
102
103    /// Roots of procedures defined within this MAST forest.
104    roots: Vec<MastNodeId>,
105
106    /// Advice map to be loaded into the VM prior to executing procedures from this MAST forest.
107    advice_map: AdviceMap,
108
109    /// Debug information including decorators and error codes.
110    /// Always present (as per issue #1821), but can be empty for stripped builds.
111    debug_info: DebugInfo,
112
113    /// Cached commitment to this MAST forest (commitment to all roots).
114    /// This is computed lazily on first access and invalidated on any mutation.
115    commitment_cache: OnceLockCompat<Word>,
116}
117
118// ------------------------------------------------------------------------------------------------
119/// Constructors
120impl MastForest {
121    /// Creates a new empty [`MastForest`].
122    pub fn new() -> Self {
123        Self {
124            nodes: IndexVec::new(),
125            roots: Vec::new(),
126            advice_map: AdviceMap::default(),
127            debug_info: DebugInfo::new(),
128            commitment_cache: OnceLockCompat::new(),
129        }
130    }
131}
132
133// ------------------------------------------------------------------------------------------------
134/// Equality implementations
135impl PartialEq for MastForest {
136    fn eq(&self, other: &Self) -> bool {
137        // Compare all fields except commitment_cache, which is derived data
138        self.nodes == other.nodes
139            && self.roots == other.roots
140            && self.advice_map == other.advice_map
141            && self.debug_info == other.debug_info
142    }
143}
144
145impl Eq for MastForest {}
146
147// ------------------------------------------------------------------------------------------------
148/// State mutators
149impl MastForest {
150    /// The maximum number of nodes that can be stored in a single MAST forest.
151    const MAX_NODES: usize = (1 << 30) - 1;
152
153    /// Marks the given [`MastNodeId`] as being the root of a procedure.
154    ///
155    /// If the specified node is already marked as a root, this will have no effect.
156    ///
157    /// # Panics
158    /// - if `new_root_id`'s internal index is larger than the number of nodes in this forest (i.e.
159    ///   clearly doesn't belong to this MAST forest).
160    pub fn make_root(&mut self, new_root_id: MastNodeId) {
161        assert!(new_root_id.to_usize() < self.nodes.len());
162
163        if !self.roots.contains(&new_root_id) {
164            self.roots.push(new_root_id);
165            // Invalidate the cached commitment since we modified the roots
166            self.commitment_cache.take();
167        }
168    }
169
170    /// Removes all nodes in the provided set from the MAST forest. The nodes MUST be orphaned (i.e.
171    /// have no parent). Otherwise, this parent's reference is considered "dangling" after the
172    /// removal (i.e. will point to an incorrect node after the removal), and this removal operation
173    /// would result in an invalid [`MastForest`].
174    ///
175    /// It also returns the map from old node IDs to new node IDs. Any [`MastNodeId`] used in
176    /// reference to the old [`MastForest`] should be remapped using this map.
177    pub fn remove_nodes(
178        &mut self,
179        nodes_to_remove: &BTreeSet<MastNodeId>,
180    ) -> BTreeMap<MastNodeId, MastNodeId> {
181        if nodes_to_remove.is_empty() {
182            return BTreeMap::new();
183        }
184
185        let old_nodes = core::mem::replace(&mut self.nodes, IndexVec::new());
186        let old_root_ids = core::mem::take(&mut self.roots);
187        let (retained_nodes, id_remappings) = remove_nodes(old_nodes.into_inner(), nodes_to_remove);
188
189        self.remap_and_add_nodes(retained_nodes, &id_remappings);
190        self.remap_and_add_roots(old_root_ids, &id_remappings);
191
192        // Remap the asm_op_storage to use the new node IDs
193        self.debug_info.remap_asm_op_storage(&id_remappings);
194
195        // Invalidate the cached commitment since we modified the forest structure
196        self.commitment_cache.take();
197
198        id_remappings
199    }
200
201    /// Clears all [`DebugInfo`] from this forest: decorators, error codes, and procedure names.
202    ///
203    /// ```
204    /// # use miden_core::mast::MastForest;
205    /// let mut forest = MastForest::new();
206    /// forest.clear_debug_info();
207    /// assert!(forest.decorators().is_empty());
208    /// ```
209    pub fn clear_debug_info(&mut self) {
210        self.debug_info = DebugInfo::empty_for_nodes(self.nodes.len());
211    }
212
213    /// Compacts the forest by merging duplicate nodes.
214    ///
215    /// This operation performs node deduplication by merging the forest with itself.
216    /// The method assumes that debug info has already been cleared if that is desired.
217    /// This method consumes the forest and returns a new compacted forest.
218    ///
219    /// The process works by:
220    /// 1. Merging the forest with itself to deduplicate identical nodes
221    /// 2. Updating internal node references and remappings
222    /// 3. Returning the compacted forest and root map
223    ///
224    /// # Examples
225    ///
226    /// ```rust
227    /// use miden_core::mast::MastForest;
228    ///
229    /// let mut forest = MastForest::new();
230    /// // Add nodes to the forest
231    ///
232    /// // First clear debug info if needed
233    /// forest.clear_debug_info();
234    ///
235    /// // Then compact the forest (consumes the original)
236    /// let (compacted_forest, root_map) = forest.compact();
237    ///
238    /// // compacted_forest is now compacted with duplicate nodes merged
239    /// ```
240    pub fn compact(self) -> (MastForest, MastForestRootMap) {
241        // Merge with itself to deduplicate nodes
242        // Note: This cannot fail for a self-merge under normal conditions.
243        // The only possible failures (TooManyNodes, TooManyDecorators) would require the
244        // original forest to be at capacity limits, at which point compaction wouldn't help.
245        MastForest::merge([&self])
246            .expect("Failed to compact MastForest: this should never happen during self-merge")
247    }
248
249    /// Merges all `forests` into a new [`MastForest`].
250    ///
251    /// Merging two forests means combining all their constituent parts, i.e. [`MastNode`]s,
252    /// [`Decorator`]s and roots. During this process, any duplicate or
253    /// unreachable nodes are removed. Additionally, [`MastNodeId`]s of nodes as well as
254    /// [`DecoratorId`]s of decorators may change and references to them are remapped to their new
255    /// location.
256    ///
257    /// For example, consider this representation of a forest's nodes with all of these nodes being
258    /// roots:
259    ///
260    /// ```text
261    /// [Block(foo), Block(bar)]
262    /// ```
263    ///
264    /// If we merge another forest into it:
265    ///
266    /// ```text
267    /// [Block(bar), Call(0)]
268    /// ```
269    ///
270    /// then we would expect this forest:
271    ///
272    /// ```text
273    /// [Block(foo), Block(bar), Call(1)]
274    /// ```
275    ///
276    /// - The `Call` to the `bar` block was remapped to its new index (now 1, previously 0).
277    /// - The `Block(bar)` was deduplicated any only exists once in the merged forest.
278    ///
279    /// The function also returns a vector of [`MastForestRootMap`]s, whose length equals the number
280    /// of passed `forests`. The indices in the vector correspond to the ones in `forests`. The map
281    /// of a given forest contains the new locations of its roots in the merged forest. To
282    /// illustrate, the above example would return a vector of two maps:
283    ///
284    /// ```text
285    /// vec![{0 -> 0, 1 -> 1}
286    ///      {0 -> 1, 1 -> 2}]
287    /// ```
288    ///
289    /// - The root locations of the original forest are unchanged.
290    /// - For the second forest, the `bar` block has moved from index 0 to index 1 in the merged
291    ///   forest, and the `Call` has moved from index 1 to 2.
292    ///
293    /// If any forest being merged contains an `External(qux)` node and another forest contains a
294    /// node whose digest is `qux`, then the external node will be replaced with the `qux` node,
295    /// which is effectively deduplication. Decorators are ignored when it comes to merging
296    /// External nodes. This means that an External node with decorators may be replaced by a node
297    /// without decorators or vice versa.
298    pub fn merge<'forest>(
299        forests: impl IntoIterator<Item = &'forest MastForest>,
300    ) -> Result<(MastForest, MastForestRootMap), MastForestError> {
301        MastForestMerger::merge(forests)
302    }
303}
304
305// ------------------------------------------------------------------------------------------------
306/// Helpers
307impl MastForest {
308    /// Adds all provided nodes to the internal set of nodes, remapping all [`MastNodeId`]
309    /// references in those nodes.
310    ///
311    /// # Panics
312    /// - Panics if the internal set of nodes is not empty.
313    fn remap_and_add_nodes(
314        &mut self,
315        nodes_to_add: Vec<MastNode>,
316        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
317    ) {
318        assert!(self.nodes.is_empty());
319        // extract decorator information from the nodes by converting them into builders
320        let node_builders =
321            nodes_to_add.into_iter().map(|node| node.to_builder(self)).collect::<Vec<_>>();
322
323        // Clear decorator storage after extracting builders (builders contain decorator data)
324        self.debug_info.clear_mappings();
325
326        // Add each node to the new MAST forest, making sure to rewrite any outdated internal
327        // `MastNodeId`s
328        for live_node_builder in node_builders {
329            live_node_builder.remap_children(id_remappings).add_to_forest(self).unwrap();
330        }
331    }
332
333    /// Remaps and adds all old root ids to the internal set of roots.
334    ///
335    /// # Panics
336    /// - Panics if the internal set of roots is not empty.
337    fn remap_and_add_roots(
338        &mut self,
339        old_root_ids: Vec<MastNodeId>,
340        id_remappings: &BTreeMap<MastNodeId, MastNodeId>,
341    ) {
342        assert!(self.roots.is_empty());
343
344        for old_root_id in old_root_ids {
345            let new_root_id = id_remappings.get(&old_root_id).copied().unwrap_or(old_root_id);
346            self.make_root(new_root_id);
347        }
348    }
349}
350
351/// Returns the set of nodes that are live, as well as the mapping from "old ID" to "new ID" for all
352/// live nodes.
353fn remove_nodes(
354    mast_nodes: Vec<MastNode>,
355    nodes_to_remove: &BTreeSet<MastNodeId>,
356) -> (Vec<MastNode>, BTreeMap<MastNodeId, MastNodeId>) {
357    // Note: this allows us to safely use `usize as u32`, guaranteeing that it won't wrap around.
358    assert!(mast_nodes.len() < u32::MAX as usize);
359
360    let mut retained_nodes = Vec::with_capacity(mast_nodes.len());
361    let mut id_remappings = BTreeMap::new();
362
363    for (old_node_index, old_node) in mast_nodes.into_iter().enumerate() {
364        let old_node_id: MastNodeId = MastNodeId(old_node_index as u32);
365
366        if !nodes_to_remove.contains(&old_node_id) {
367            let new_node_id: MastNodeId = MastNodeId(retained_nodes.len() as u32);
368            id_remappings.insert(old_node_id, new_node_id);
369
370            retained_nodes.push(old_node);
371        }
372    }
373
374    (retained_nodes, id_remappings)
375}
376
377// ------------------------------------------------------------------------------------------------
378/// Public accessors
379impl MastForest {
380    /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else
381    /// `None`.
382    ///
383    /// This is the fallible version of indexing (e.g. `mast_forest[node_id]`).
384    #[inline(always)]
385    pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> {
386        self.nodes.get(node_id)
387    }
388
389    /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any.
390    #[inline(always)]
391    pub fn find_procedure_root(&self, digest: Word) -> Option<MastNodeId> {
392        self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied()
393    }
394
395    /// Returns true if a node with the specified ID is a root of a procedure in this MAST forest.
396    pub fn is_procedure_root(&self, node_id: MastNodeId) -> bool {
397        self.roots.contains(&node_id)
398    }
399
400    /// Returns an iterator over the digests of all procedures in this MAST forest.
401    pub fn procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
402        self.roots.iter().map(|&root_id| self[root_id].digest())
403    }
404
405    /// Returns an iterator over the digests of local procedures in this MAST forest.
406    ///
407    /// A local procedure is defined as a procedure which is not a single external node.
408    pub fn local_procedure_digests(&self) -> impl Iterator<Item = Word> + '_ {
409        self.roots.iter().filter_map(|&root_id| {
410            let node = &self[root_id];
411            if node.is_external() { None } else { Some(node.digest()) }
412        })
413    }
414
415    /// Returns an iterator over the IDs of the procedures in this MAST forest.
416    pub fn procedure_roots(&self) -> &[MastNodeId] {
417        &self.roots
418    }
419
420    /// Returns the number of procedures in this MAST forest.
421    pub fn num_procedures(&self) -> u32 {
422        self.roots
423            .len()
424            .try_into()
425            .expect("MAST forest contains more than 2^32 procedures.")
426    }
427
428    /// Returns the [Word] representing the content hash of a subset of [`MastNodeId`]s.
429    ///
430    /// # Panics
431    /// This function panics if any `node_ids` is not a node of this forest.
432    pub fn compute_nodes_commitment<'a>(
433        &self,
434        node_ids: impl IntoIterator<Item = &'a MastNodeId>,
435    ) -> Word {
436        let mut digests: Vec<Word> = node_ids.into_iter().map(|&id| self[id].digest()).collect();
437        digests.sort_unstable_by_key(|word| LexicographicWord::from(*word));
438        miden_crypto::hash::poseidon2::Poseidon2::merge_many(&digests)
439    }
440
441    /// Returns the commitment to this MAST forest.
442    ///
443    /// The commitment is computed as the sequential hash of all procedure roots in the forest.
444    /// This value is cached after the first computation and reused for subsequent calls,
445    /// unless the forest is mutated (in which case the cache is invalidated).
446    ///
447    /// The commitment uniquely identifies the forest's structure, as each root's digest
448    /// transitively includes all of its descendants. Therefore, a commitment to all roots
449    /// is a commitment to the entire forest.
450    pub fn commitment(&self) -> Word {
451        *self.commitment_cache.get_or_init(|| self.compute_nodes_commitment(&self.roots))
452    }
453
454    /// Returns the number of nodes in this MAST forest.
455    pub fn num_nodes(&self) -> u32 {
456        self.nodes.len() as u32
457    }
458
459    /// Returns the underlying nodes in this MAST forest.
460    pub fn nodes(&self) -> &[MastNode] {
461        self.nodes.as_slice()
462    }
463
464    pub fn advice_map(&self) -> &AdviceMap {
465        &self.advice_map
466    }
467
468    pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
469        &mut self.advice_map
470    }
471
472    // SERIALIZATION
473    // --------------------------------------------------------------------------------------------
474
475    /// Serializes this MastForest without debug information.
476    ///
477    /// This produces a smaller output by omitting decorators, error codes, and procedure names.
478    /// The resulting bytes can be deserialized with the standard [`Deserializable`] impl,
479    /// which auto-detects the format and creates an empty [`DebugInfo`].
480    ///
481    /// Use this for production builds where debug info is not needed.
482    ///
483    /// # Example
484    ///
485    /// ```
486    /// use miden_core::{mast::MastForest, serde::Serializable};
487    ///
488    /// let forest = MastForest::new();
489    ///
490    /// // Full serialization (with debug info)
491    /// let full_bytes = forest.to_bytes();
492    ///
493    /// // Stripped serialization (without debug info)
494    /// let mut stripped_bytes = Vec::new();
495    /// forest.write_stripped(&mut stripped_bytes);
496    ///
497    /// // Both can be deserialized the same way
498    /// // let restored = MastForest::read_from_bytes(&stripped_bytes).unwrap();
499    /// ```
500    pub fn write_stripped<W: ByteWriter>(&self, target: &mut W) {
501        use serialization::StrippedMastForest;
502        StrippedMastForest(self).write_into(target);
503    }
504}
505
506// ------------------------------------------------------------------------------------------------
507/// Decorator methods
508impl MastForest {
509    /// Returns a list of all decorators contained in this [MastForest].
510    pub fn decorators(&self) -> &[Decorator] {
511        self.debug_info.decorators()
512    }
513
514    /// Returns the [`Decorator`] associated with the provided [`DecoratorId`] if valid, or else
515    /// `None`.
516    ///
517    /// This is the fallible version of indexing (e.g. `mast_forest[decorator_id]`).
518    #[inline]
519    pub fn decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> {
520        self.debug_info.decorator(decorator_id)
521    }
522
523    /// Returns decorator indices for a specific operation within a node.
524    ///
525    /// This is the primary accessor for reading decorators from the centralized storage.
526    /// Returns a slice of decorator IDs for the given operation.
527    #[inline]
528    pub(crate) fn decorator_indices_for_op(
529        &self,
530        node_id: MastNodeId,
531        local_op_idx: usize,
532    ) -> &[DecoratorId] {
533        self.debug_info.decorators_for_operation(node_id, local_op_idx)
534    }
535
536    /// Returns an iterator over decorator references for a specific operation within a node.
537    ///
538    /// This is the preferred method for accessing decorators, as it provides direct
539    /// references to the decorator objects.
540    #[inline]
541    pub fn decorators_for_op<'a>(
542        &'a self,
543        node_id: MastNodeId,
544        local_op_idx: usize,
545    ) -> impl Iterator<Item = &'a Decorator> + 'a {
546        self.decorator_indices_for_op(node_id, local_op_idx)
547            .iter()
548            .map(move |&decorator_id| &self[decorator_id])
549    }
550
551    /// Returns the decorators to be executed before this node is executed.
552    #[inline]
553    pub fn before_enter_decorators(&self, node_id: MastNodeId) -> &[DecoratorId] {
554        self.debug_info.before_enter_decorators(node_id)
555    }
556
557    /// Returns the decorators to be executed after this node is executed.
558    #[inline]
559    pub fn after_exit_decorators(&self, node_id: MastNodeId) -> &[DecoratorId] {
560        self.debug_info.after_exit_decorators(node_id)
561    }
562
563    /// Returns decorator links for a node, including operation indices.
564    ///
565    /// This provides a flattened view of all decorators for a node with their operation indices.
566    #[inline]
567    pub(crate) fn decorator_links_for_node<'a>(
568        &'a self,
569        node_id: MastNodeId,
570    ) -> Result<DecoratedLinks<'a>, DecoratorIndexError> {
571        self.debug_info.decorator_links_for_node(node_id)
572    }
573
574    /// Adds a decorator to the forest, and returns the associated [`DecoratorId`].
575    pub fn add_decorator(&mut self, decorator: Decorator) -> Result<DecoratorId, MastForestError> {
576        self.debug_info.add_decorator(decorator)
577    }
578
579    /// Adds a debug variable to the forest, and returns the associated [`DebugVarId`].
580    pub fn add_debug_var(
581        &mut self,
582        debug_var: DebugVarInfo,
583    ) -> Result<DebugVarId, MastForestError> {
584        self.debug_info.add_debug_var(debug_var)
585    }
586
587    /// Returns debug variable IDs for a specific operation within a node.
588    pub fn debug_vars_for_operation(
589        &self,
590        node_id: MastNodeId,
591        local_op_idx: usize,
592    ) -> &[DebugVarId] {
593        self.debug_info.debug_vars_for_operation(node_id, local_op_idx)
594    }
595
596    /// Returns the debug variable with the given ID, if it exists.
597    pub fn debug_var(&self, debug_var_id: DebugVarId) -> Option<&DebugVarInfo> {
598        self.debug_info.debug_var(debug_var_id)
599    }
600
601    /// Adds decorator IDs for a node to the storage.
602    ///
603    /// Used when building nodes for efficient decorator access during execution.
604    ///
605    /// # Note
606    /// This method does not validate decorator IDs immediately. Validation occurs during
607    /// operations that need to access the actual decorator data (e.g., merging, serialization).
608    #[inline]
609    pub(crate) fn register_node_decorators(
610        &mut self,
611        node_id: MastNodeId,
612        before_enter: &[DecoratorId],
613        after_exit: &[DecoratorId],
614    ) {
615        self.debug_info.register_node_decorators(node_id, before_enter, after_exit);
616    }
617
618    /// Returns the [`AssemblyOp`] associated with a node.
619    ///
620    /// For basic block nodes with a `target_op_idx`, returns the AssemblyOp for that operation.
621    /// For other nodes or when no `target_op_idx` is provided, returns the first AssemblyOp.
622    pub fn get_assembly_op(
623        &self,
624        node_id: MastNodeId,
625        target_op_idx: Option<usize>,
626    ) -> Option<&AssemblyOp> {
627        match target_op_idx {
628            Some(op_idx) => self.debug_info.asm_op_for_operation(node_id, op_idx),
629            None => self.debug_info.first_asm_op_for_node(node_id),
630        }
631    }
632}
633
634// ------------------------------------------------------------------------------------------------
635/// Validation methods
636impl MastForest {
637    /// Validates that all BasicBlockNodes in this forest satisfy the core invariants:
638    /// 1. Power-of-two number of groups in each batch
639    /// 2. No operation group ends with an operation requiring an immediate value
640    /// 3. The last operation group in a batch cannot contain operations requiring immediate values
641    /// 4. OpBatch structural consistency (num_groups <= BATCH_SIZE, group size <= GROUP_SIZE,
642    ///    indptr integrity, bounds checking)
643    ///
644    /// This addresses the gap created by PR 2094, where padding NOOPs are now inserted
645    /// at assembly time rather than dynamically during execution, and adds comprehensive
646    /// structural validation to prevent deserialization-time panics.
647    pub fn validate(&self) -> Result<(), MastForestError> {
648        // Validate basic block batch invariants
649        for (node_id_idx, node) in self.nodes.iter().enumerate() {
650            let node_id =
651                MastNodeId::new_unchecked(node_id_idx.try_into().expect("too many nodes"));
652            if let MastNode::Block(basic_block) = node {
653                basic_block.validate_batch_invariants().map_err(|error_msg| {
654                    MastForestError::InvalidBatchPadding(node_id, error_msg)
655                })?;
656            }
657        }
658
659        // Validate that all procedure name digests correspond to procedure roots in the forest
660        for (digest, _) in self.debug_info.procedure_names() {
661            if self.find_procedure_root(digest).is_none() {
662                return Err(MastForestError::InvalidProcedureNameDigest(digest));
663            }
664        }
665
666        Ok(())
667    }
668
669    /// Validates topological ordering of nodes and recomputes all node hashes.
670    ///
671    /// This method iterates through all nodes in index order, verifying:
672    /// 1. All child references point to nodes with smaller indices (topological order)
673    /// 2. Each node's recomputed digest matches its stored digest
674    ///
675    /// # Errors
676    ///
677    /// Returns `MastForestError::ForwardReference` if any node references a child that
678    /// appears later in the forest.
679    ///
680    /// Returns `MastForestError::HashMismatch` if any node's recomputed digest doesn't
681    /// match its stored digest.
682    fn validate_node_hashes(&self) -> Result<(), MastForestError> {
683        use crate::chiplets::hasher;
684
685        /// Checks that child_id references a node that appears before node_id in topological order.
686        fn check_no_forward_ref(
687            node_id: MastNodeId,
688            child_id: MastNodeId,
689        ) -> Result<(), MastForestError> {
690            if child_id.0 >= node_id.0 {
691                return Err(MastForestError::ForwardReference(node_id, child_id));
692            }
693            Ok(())
694        }
695
696        for (node_idx, node) in self.nodes.iter().enumerate() {
697            let node_id = MastNodeId::new_unchecked(node_idx as u32);
698
699            // Check topological ordering and compute expected digest
700            let computed_digest = match node {
701                MastNode::Block(block) => {
702                    let op_groups: Vec<Felt> =
703                        block.op_batches().iter().flat_map(|batch| *batch.groups()).collect();
704                    hasher::hash_elements(&op_groups)
705                },
706                MastNode::Join(join) => {
707                    let left_id = join.first();
708                    let right_id = join.second();
709                    check_no_forward_ref(node_id, left_id)?;
710                    check_no_forward_ref(node_id, right_id)?;
711
712                    let left_digest = self.nodes[left_id].digest();
713                    let right_digest = self.nodes[right_id].digest();
714                    hasher::merge_in_domain(&[left_digest, right_digest], JoinNode::DOMAIN)
715                },
716                MastNode::Split(split) => {
717                    let true_id = split.on_true();
718                    let false_id = split.on_false();
719                    check_no_forward_ref(node_id, true_id)?;
720                    check_no_forward_ref(node_id, false_id)?;
721
722                    let true_digest = self.nodes[true_id].digest();
723                    let false_digest = self.nodes[false_id].digest();
724                    hasher::merge_in_domain(&[true_digest, false_digest], SplitNode::DOMAIN)
725                },
726                MastNode::Loop(loop_node) => {
727                    let body_id = loop_node.body();
728                    check_no_forward_ref(node_id, body_id)?;
729
730                    let body_digest = self.nodes[body_id].digest();
731                    hasher::merge_in_domain(&[body_digest, Word::default()], LoopNode::DOMAIN)
732                },
733                MastNode::Call(call) => {
734                    let callee_id = call.callee();
735                    check_no_forward_ref(node_id, callee_id)?;
736
737                    let callee_digest = self.nodes[callee_id].digest();
738                    let domain = if call.is_syscall() {
739                        CallNode::SYSCALL_DOMAIN
740                    } else {
741                        CallNode::CALL_DOMAIN
742                    };
743                    hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
744                },
745                MastNode::Dyn(dyn_node) => {
746                    if dyn_node.is_dyncall() {
747                        DynNode::DYNCALL_DEFAULT_DIGEST
748                    } else {
749                        DynNode::DYN_DEFAULT_DIGEST
750                    }
751                },
752                MastNode::External(_) => {
753                    // External nodes have externally-provided digests that cannot be recomputed
754                    continue;
755                },
756            };
757
758            let stored_digest = node.digest();
759            if computed_digest != stored_digest {
760                return Err(MastForestError::HashMismatch {
761                    node_id,
762                    expected: stored_digest,
763                    computed: computed_digest,
764                });
765            }
766        }
767
768        Ok(())
769    }
770}
771
772// ------------------------------------------------------------------------------------------------
773/// Error message methods
774impl MastForest {
775    /// Given an error code as a Felt, resolves it to its corresponding error message.
776    pub fn resolve_error_message(&self, code: Felt) -> Option<Arc<str>> {
777        let key = code.as_canonical_u64();
778        self.debug_info.error_message(key)
779    }
780
781    /// Registers an error message in the MAST Forest and returns the corresponding error code as a
782    /// Felt.
783    pub fn register_error(&mut self, msg: Arc<str>) -> Felt {
784        let code: Felt = error_code_from_msg(&msg);
785        // we use u64 as keys for the map
786        self.debug_info.insert_error_code(code.as_canonical_u64(), msg);
787        code
788    }
789}
790
791// ------------------------------------------------------------------------------------------------
792/// Procedure name methods
793impl MastForest {
794    /// Returns the procedure name for the given MAST root digest, if present.
795    pub fn procedure_name(&self, digest: &Word) -> Option<&str> {
796        self.debug_info.procedure_name(digest)
797    }
798
799    /// Returns an iterator over all (digest, name) pairs of procedure names.
800    pub fn procedure_names(&self) -> impl Iterator<Item = (Word, &Arc<str>)> {
801        self.debug_info.procedure_names()
802    }
803
804    /// Inserts a procedure name for the given MAST root digest.
805    pub fn insert_procedure_name(&mut self, digest: Word, name: Arc<str>) {
806        assert!(
807            self.find_procedure_root(digest).is_some(),
808            "attempted to insert procedure name for digest that is not a procedure root"
809        );
810        self.debug_info.insert_procedure_name(digest, name);
811    }
812
813    /// Returns a reference to the debug info for this forest.
814    pub fn debug_info(&self) -> &DebugInfo {
815        &self.debug_info
816    }
817
818    /// Returns a mutable reference to the debug info.
819    ///
820    /// This is intended for use by the assembler to register AssemblyOps and other debug
821    /// information during compilation.
822    pub fn debug_info_mut(&mut self) -> &mut DebugInfo {
823        &mut self.debug_info
824    }
825}
826
827// TEST HELPERS
828// ================================================================================================
829
#[cfg(test)]
impl MastForest {
    /// Gathers every decorator attached to `node_id` as `(position, DecoratorId)` pairs.
    ///
    /// The returned list contains, in this order:
    /// - the before-enter decorators, all at position 0;
    /// - for basic blocks only, the operation-indexed decorators at their recorded positions;
    /// - the after-exit decorators, placed at `num_operations()` for basic blocks and at position
    ///   1 for every other node kind.
    ///
    /// Test-only helper: it allocates freely and is not meant for performance-sensitive code.
    pub fn all_decorators(&self, node_id: MastNodeId) -> Vec<(usize, DecoratorId)> {
        let node = &self[node_id];
        let mut collected: Vec<(usize, DecoratorId)> = Vec::new();

        // Before-enter decorators sit at position 0 regardless of the node kind.
        collected.extend(self.before_enter_decorators(node_id).iter().map(|&id| (0, id)));

        // Only basic blocks carry operation-indexed decorators; they also determine where the
        // after-exit decorators are placed (just past the last operation).
        let exit_position = if node.is_basic_block() {
            collected.extend(self.decorator_links_for_node(node_id).unwrap().into_iter());
            node.unwrap_basic_block().num_operations() as usize
        } else {
            1
        };

        let exit_decorators = self.after_exit_decorators(node_id);
        collected.extend(exit_decorators.iter().map(|&id| (exit_position, id)));
        collected
    }
}
884
885// MAST FOREST INDEXING
886// ------------------------------------------------------------------------------------------------
887
888impl Index<MastNodeId> for MastForest {
889    type Output = MastNode;
890
891    #[inline(always)]
892    fn index(&self, node_id: MastNodeId) -> &Self::Output {
893        &self.nodes[node_id]
894    }
895}
896
897impl IndexMut<MastNodeId> for MastForest {
898    #[inline(always)]
899    fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output {
900        &mut self.nodes[node_id]
901    }
902}
903
904impl Index<DecoratorId> for MastForest {
905    type Output = Decorator;
906
907    #[inline(always)]
908    fn index(&self, decorator_id: DecoratorId) -> &Self::Output {
909        self.debug_info.decorator(decorator_id).expect("DecoratorId out of bounds")
910    }
911}
912
913impl IndexMut<DecoratorId> for MastForest {
914    #[inline(always)]
915    fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output {
916        self.debug_info.decorator_mut(decorator_id).expect("DecoratorId out of bounds")
917    }
918}
919
920// MAST NODE ID
921// ================================================================================================
922
/// Opaque handle identifying a [`MastNode`] inside a [`MastForest`].
///
/// A handle only has meaning for the forest it was created for; pairing a [`MastNodeId`] with
/// the matching [`MastForest`] is the caller's responsibility.
///
/// The [`MastForest`] does *not* guarantee that equal [`MastNode`]s share the same
/// [`MastNodeId`], so comparing handles must not be used to decide whether the underlying nodes
/// are equal.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct MastNodeId(u32);

/// Old-id to new-id mapping, as commonly produced by operations that rewrite a MAST.
///
/// Ids without an entry are meant to keep their value (see [`MastNodeId::remap`]).
pub type Remapping = BTreeMap<MastNodeId, MastNodeId>;
937
938impl MastNodeId {
939    /// Returns a new `MastNodeId` with the provided inner value, or an error if the provided
940    /// `value` is greater than the number of nodes in the forest.
941    ///
942    /// For use in deserialization.
943    pub fn from_u32_safe(
944        value: u32,
945        mast_forest: &MastForest,
946    ) -> Result<Self, DeserializationError> {
947        Self::from_u32_with_node_count(value, mast_forest.nodes.len())
948    }
949
950    /// Returns a new [`MastNodeId`] with the provided `node_id`, or an error if `node_id` is
951    /// greater than the number of nodes in the [`MastForest`] for which this ID is being
952    /// constructed.
953    pub fn from_usize_safe(
954        node_id: usize,
955        mast_forest: &MastForest,
956    ) -> Result<Self, DeserializationError> {
957        let node_id: u32 = node_id.try_into().map_err(|_| {
958            DeserializationError::InvalidValue(format!(
959                "node id '{node_id}' does not fit into a u32"
960            ))
961        })?;
962        MastNodeId::from_u32_safe(node_id, mast_forest)
963    }
964
965    /// Returns a new [`MastNodeId`] from the given `value` without checking its validity.
966    pub fn new_unchecked(value: u32) -> Self {
967        Self(value)
968    }
969
970    /// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal
971    /// to `node_count`. The `node_count` is the total number of nodes in the [`MastForest`] for
972    /// which this ID is being constructed.
973    ///
974    /// This function can be used when deserializing an id whose corresponding node is not yet in
975    /// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids
976    /// referenced by the Join node in this forest:
977    ///
978    /// ```text
979    /// [Join(1, 2), Block(foo), Block(bar)]
980    /// ```
981    ///
982    /// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public.
983    pub(super) fn from_u32_with_node_count(
984        id: u32,
985        node_count: usize,
986    ) -> Result<Self, DeserializationError> {
987        if (id as usize) < node_count {
988            Ok(Self(id))
989        } else {
990            Err(DeserializationError::InvalidValue(format!(
991                "Invalid deserialized MAST node ID '{id}', but {node_count} is the number of nodes in the forest",
992            )))
993        }
994    }
995
996    /// Remap the NodeId to its new position using the given [`Remapping`].
997    pub fn remap(&self, remapping: &Remapping) -> Self {
998        *remapping.get(self).unwrap_or(self)
999    }
1000}
1001
1002impl From<u32> for MastNodeId {
1003    fn from(value: u32) -> Self {
1004        MastNodeId::new_unchecked(value)
1005    }
1006}
1007
1008impl Idx for MastNodeId {}
1009
1010impl From<MastNodeId> for u32 {
1011    fn from(value: MastNodeId) -> Self {
1012        value.0
1013    }
1014}
1015
1016impl fmt::Display for MastNodeId {
1017    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1018        write!(f, "MastNodeId({})", self.0)
1019    }
1020}
1021
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for MastNodeId {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates ids across the whole `u32` range; they are not constrained to the node count
    /// of any particular forest.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let raw_ids = any::<u32>();
        raw_ids.prop_map(MastNodeId).boxed()
    }
}
1033
1034// ITERATOR
1035
1036/// Iterates over all the nodes a root depends on, in pre-order. The iteration can include other
1037/// roots in the same forest.
1038pub struct SubtreeIterator<'a> {
1039    forest: &'a MastForest,
1040    discovered: Vec<MastNodeId>,
1041    unvisited: Vec<MastNodeId>,
1042}
1043impl<'a> SubtreeIterator<'a> {
1044    pub fn new(root: &MastNodeId, forest: &'a MastForest) -> Self {
1045        let discovered = vec![];
1046        let unvisited = vec![*root];
1047        SubtreeIterator { forest, discovered, unvisited }
1048    }
1049}
1050impl Iterator for SubtreeIterator<'_> {
1051    type Item = MastNodeId;
1052    fn next(&mut self) -> Option<MastNodeId> {
1053        while let Some(id) = self.unvisited.pop() {
1054            let node = &self.forest[id];
1055            if !node.has_children() {
1056                return Some(id);
1057            } else {
1058                self.discovered.push(id);
1059                node.append_children_to(&mut self.unvisited);
1060            }
1061        }
1062        self.discovered.pop()
1063    }
1064}
1065
1066// DECORATOR ID
1067// ================================================================================================
1068
/// Opaque handle identifying a [`Decorator`] inside a [`MastForest`].
///
/// A [`DecoratorId`] is only meaningful for the forest it belongs to; using it with the right
/// [`MastForest`] is the caller's responsibility.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct DecoratorId(u32);
1075
1076impl DecoratorId {
1077    /// Returns a new `DecoratorId` with the provided inner value, or an error if the provided
1078    /// `value` is greater than the number of nodes in the forest.
1079    ///
1080    /// For use in deserialization.
1081    pub fn from_u32_safe(
1082        value: u32,
1083        mast_forest: &MastForest,
1084    ) -> Result<Self, DeserializationError> {
1085        Self::from_u32_bounded(value, mast_forest.debug_info.num_decorators())
1086    }
1087
1088    /// Returns a new `DecoratorId` with the provided inner value, or an error if the provided
1089    /// `value` is greater than or equal to `bound`.
1090    ///
1091    /// For use in deserialization when the bound is known without needing the full MastForest.
1092    pub fn from_u32_bounded(value: u32, bound: usize) -> Result<Self, DeserializationError> {
1093        if (value as usize) < bound {
1094            Ok(Self(value))
1095        } else {
1096            Err(DeserializationError::InvalidValue(format!(
1097                "Invalid deserialized MAST decorator id '{}', but allows only {} decorators",
1098                value, bound,
1099            )))
1100        }
1101    }
1102
1103    /// Creates a new [`DecoratorId`] without checking its validity.
1104    pub(crate) fn new_unchecked(value: u32) -> Self {
1105        Self(value)
1106    }
1107}
1108
1109impl From<u32> for DecoratorId {
1110    fn from(value: u32) -> Self {
1111        DecoratorId::new_unchecked(value)
1112    }
1113}
1114
1115impl Idx for DecoratorId {}
1116
1117impl From<DecoratorId> for u32 {
1118    fn from(value: DecoratorId) -> Self {
1119        value.0
1120    }
1121}
1122
1123impl fmt::Display for DecoratorId {
1124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1125        write!(f, "DecoratorId({})", self.0)
1126    }
1127}
1128
1129impl Serializable for DecoratorId {
1130    fn write_into<W: ByteWriter>(&self, target: &mut W) {
1131        self.0.write_into(target)
1132    }
1133}
1134
1135impl Deserializable for DecoratorId {
1136    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
1137        let value = u32::read_from(source)?;
1138        Ok(Self(value))
1139    }
1140}
1141
1142// ASM OP ID
1143// ================================================================================================
1144
/// Identifies an [`AssemblyOp`] within a [`MastForest`].
///
/// AssemblyOps, unlike decorators (which run at execution time), are pure metadata consumed only
/// by error reporting and debugging tools.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct AsmOpId(u32);
1153
1154impl AsmOpId {
1155    /// Creates a new [`AsmOpId`] with the provided inner value.
1156    pub const fn new(value: u32) -> Self {
1157        Self(value)
1158    }
1159}
1160
1161impl From<u32> for AsmOpId {
1162    fn from(value: u32) -> Self {
1163        AsmOpId::new(value)
1164    }
1165}
1166
1167impl Idx for AsmOpId {}
1168
1169impl From<AsmOpId> for u32 {
1170    fn from(id: AsmOpId) -> Self {
1171        id.0
1172    }
1173}
1174
1175impl fmt::Display for AsmOpId {
1176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1177        write!(f, "AsmOpId({})", self.0)
1178    }
1179}
1180
1181impl Serializable for AsmOpId {
1182    fn write_into<W: ByteWriter>(&self, target: &mut W) {
1183        self.0.write_into(target)
1184    }
1185}
1186
1187impl Deserializable for AsmOpId {
1188    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
1189        let value = u32::read_from(source)?;
1190        Ok(Self(value))
1191    }
1192}
1193
1194/// Derives an error code from an error message by hashing the message and returning the 0th element
1195/// of the resulting [`Word`].
1196pub fn error_code_from_msg(msg: impl AsRef<str>) -> Felt {
1197    // hash the message and return 0th felt of the resulting Word
1198    hash_string_to_word(msg.as_ref())[0]
1199}
1200
1201// MAST FOREST ERROR
1202// ================================================================================================
1203
/// Represents the types of errors that can occur when dealing with MAST forest.
///
/// This covers size limits, out-of-range node/decorator ids, merge conflicts, and the failures
/// reported while validating a forest (see [`UntrustedMastForest::validate`]).
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
    /// The number of decorators exceeds `u32::MAX`.
    #[error("MAST forest decorator count exceeds the maximum of {} decorators", u32::MAX)]
    TooManyDecorators,
    /// The number of nodes exceeds [`MastForest::MAX_NODES`].
    #[error("MAST forest node count exceeds the maximum of {} nodes", MastForest::MAX_NODES)]
    TooManyNodes,
    /// A node id (first field) is not smaller than the forest length (second field).
    #[error("node id {0} is greater than or equal to forest length {1}")]
    NodeIdOverflow(MastNodeId, usize),
    /// A decorator id (first field) is not smaller than the decorator count (second field).
    #[error("decorator id {0} is greater than or equal to decorator count {1}")]
    DecoratorIdOverflow(DecoratorId, usize),
    /// A basic block was requested from an empty list of operations.
    #[error("basic block cannot be created from an empty list of operations")]
    EmptyBasicBlock,
    /// The decorator root of the given child node is missing, but fingerprint computation
    /// needs it.
    #[error(
        "decorator root of child with node id {0} is missing but is required for fingerprint computation"
    )]
    ChildFingerprintMissing(MastNodeId),
    /// The given advice map key already exists when merging forests.
    #[error("advice map key {0} already exists when merging forests")]
    AdviceMapKeyCollisionOnMerge(Word),
    /// Wraps an error reported by the decorator storage.
    #[error("decorator storage error: {0}")]
    DecoratorError(DecoratorIndexError),
    /// A digest was required during deserialization.
    #[error("digest is required for deserialization")]
    DigestRequiredForDeserialization,
    /// A basic block node (first field) contains an invalid batch; the second field describes
    /// the problem.
    #[error("invalid batch in basic block node {0:?}: {1}")]
    InvalidBatchPadding(MastNodeId, String),
    /// A procedure name is attached to a digest that is not a procedure root.
    #[error("procedure name references digest that is not a procedure root: {0:?}")]
    InvalidProcedureNameDigest(Word),
    /// A node (first field) references a child (second field) that appears later in the forest.
    #[error(
        "node {0:?} references child {1:?} which comes after it in the forest (forward reference)"
    )]
    ForwardReference(MastNodeId, MastNodeId),
    /// The digest recomputed for `node_id` differs from the digest stored in the forest.
    #[error("hash mismatch for node {node_id:?}: expected {expected:?}, computed {computed:?}")]
    HashMismatch {
        /// The node whose digest did not match.
        node_id: MastNodeId,
        /// The digest stored in the forest.
        expected: Word,
        /// The digest recomputed from the node's contents.
        computed: Word,
    },
}
1242
1243// Custom serde implementations for MastForest that handle linked decorators properly
1244// by delegating to the existing miden-crypto serialization which already handles
1245// the conversion between linked and owned decorator formats.
#[cfg(feature = "serde")]
impl serde::Serialize for MastForest {
    /// Serializes the forest as one byte string containing its [`Serializable`] encoding.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(&Serializable::to_bytes(self))
    }
}
1257
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for MastForest {
    /// Deserializes the forest from the byte string written by the `Serialize` impl.
    ///
    /// Like `MastForest::read_from_bytes`, this goes through the [`Deserializable`]
    /// implementation, which trusts the serialized node hashes instead of recomputing them.
    /// Untrusted input should go through [`UntrustedMastForest`] and its `validate()` method.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let encoded = Vec::<u8>::deserialize(deserializer)?;
        let mut reader = SliceReader::new(&encoded);
        <Self as Deserializable>::read_from(&mut reader).map_err(serde::de::Error::custom)
    }
}
1270
1271// UNTRUSTED MAST FOREST
1272// ================================================================================================
1273
1274/// A [`MastForest`] deserialized from untrusted input that has not yet been validated.
1275///
1276/// This type wraps a `MastForest` that was deserialized from bytes but has not had its
1277/// node hashes verified. Before using the forest, callers must call [`validate()`](Self::validate)
1278/// to verify structural integrity and recompute all node hashes.
1279///
1280/// # Usage
1281///
1282/// ```ignore
1283/// // Deserialize from untrusted bytes
1284/// let untrusted = UntrustedMastForest::read_from_bytes(&bytes)?;
1285///
1286/// // Validate structure and hashes
1287/// let forest = untrusted.validate()?;
1288///
1289/// // Now safe to use
1290/// let root = forest.procedure_roots()[0];
1291/// ```
1292///
1293/// # Security
1294///
1295/// This type exists to provide type-level safety for untrusted deserialization. The validation
1296/// performed by [`validate()`](Self::validate) includes:
1297///
1298/// 1. **Structural validation**: Checks that basic block batch invariants are satisfied and
1299///    procedure names reference valid roots.
1300/// 2. **Topological ordering**: Verifies that all node references point to nodes that appear
1301///    earlier in the forest (no forward references).
1302/// 3. **Hash recomputation**: Recomputes the digest for every node and verifies it matches the
1303///    stored digest.
1304#[derive(Debug, Clone)]
1305pub struct UntrustedMastForest(MastForest);
1306
1307impl UntrustedMastForest {
1308    /// Validates the forest by checking structural invariants and recomputing all node hashes.
1309    ///
1310    /// This method performs a complete validation of the deserialized forest:
1311    ///
1312    /// 1. Validates structural invariants (batch padding, procedure names)
1313    /// 2. Validates topological ordering (no forward references)
1314    /// 3. Recomputes all node hashes and compares against stored digests
1315    ///
1316    /// # Returns
1317    ///
1318    /// - `Ok(MastForest)` if validation succeeds
1319    /// - `Err(MastForestError)` with details about the first validation failure
1320    ///
1321    /// # Errors
1322    ///
1323    /// Returns an error if:
1324    /// - Any basic block has invalid batch structure ([`MastForestError::InvalidBatchPadding`])
1325    /// - Any procedure name references a non-root digest
1326    ///   ([`MastForestError::InvalidProcedureNameDigest`])
1327    /// - Any node references a child that appears later in the forest
1328    ///   ([`MastForestError::ForwardReference`])
1329    /// - Any node's recomputed hash doesn't match its stored digest
1330    ///   ([`MastForestError::HashMismatch`])
1331    pub fn validate(self) -> Result<MastForest, MastForestError> {
1332        let forest = self.0;
1333
1334        // Step 1: Validate structural invariants (existing validate() checks)
1335        forest.validate()?;
1336
1337        // Step 2: Validate topological ordering and recompute hashes
1338        forest.validate_node_hashes()?;
1339
1340        Ok(forest)
1341    }
1342
1343    /// Deserializes an [`UntrustedMastForest`] from bytes.
1344    ///
1345    /// This method uses a [`BudgetedReader`] with a budget equal to the input size to protect
1346    /// against denial-of-service attacks from malicious input.
1347    ///
1348    /// For stricter limits, use
1349    /// [`read_from_bytes_with_budget`](Self::read_from_bytes_with_budget) with a custom budget.
1350    ///
1351    /// # Example
1352    ///
1353    /// ```ignore
1354    /// // Read from untrusted source
1355    /// let untrusted = UntrustedMastForest::read_from_bytes(&bytes)?;
1356    ///
1357    /// // Validate before use
1358    /// let forest = untrusted.validate()?;
1359    /// ```
1360    pub fn read_from_bytes(bytes: &[u8]) -> Result<Self, DeserializationError> {
1361        Self::read_from_bytes_with_budget(bytes, bytes.len())
1362    }
1363
1364    /// Deserializes an [`UntrustedMastForest`] from bytes with a byte budget.
1365    ///
1366    /// This method uses a [`BudgetedReader`] to limit memory consumption during deserialization,
1367    /// protecting against denial-of-service attacks from malicious input that claims to contain
1368    /// an excessive number of elements.
1369    ///
1370    /// # Arguments
1371    ///
1372    /// * `bytes` - The serialized forest bytes
1373    /// * `budget` - Maximum bytes to consume during deserialization. Set this to `bytes.len()` for
1374    ///   typical use cases, or lower to enforce stricter limits.
1375    ///
1376    /// # Example
1377    ///
1378    /// ```ignore
1379    /// // Read from untrusted source with budget equal to input size
1380    /// let untrusted = UntrustedMastForest::read_from_bytes_with_budget(&bytes, bytes.len())?;
1381    ///
1382    /// // Validate before use
1383    /// let forest = untrusted.validate()?;
1384    /// ```
1385    ///
1386    /// # Security
1387    ///
1388    /// The budget limits:
1389    /// - Pre-allocation sizes when deserializing collections (via `max_alloc`)
1390    /// - Total bytes consumed during deserialization
1391    ///
1392    /// This prevents attacks where malicious input claims an unrealistic number of elements
1393    /// (e.g., `len = 2^60`), causing excessive memory allocation before any data is read.
1394    pub fn read_from_bytes_with_budget(
1395        bytes: &[u8],
1396        budget: usize,
1397    ) -> Result<Self, DeserializationError> {
1398        let mut reader = BudgetedReader::new(SliceReader::new(bytes), budget);
1399        Self::read_from(&mut reader)
1400    }
1401}