miden_core/mast/node/
mod.rs

1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6    BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, BasicBlockNodeBuilder, DecoratorOpLinkIterator,
7    GROUP_SIZE as OP_GROUP_SIZE, OpBatch, OperationOrDecorator,
8};
9use derive_more::From;
10use miden_utils_core_derive::MastNodeExt;
11
12mod call_node;
13pub use call_node::{CallNode, CallNodeBuilder};
14
15mod dyn_node;
16pub use dyn_node::{DynNode, DynNodeBuilder};
17
18mod external;
19pub use external::{ExternalNode, ExternalNodeBuilder};
20
21mod join_node;
22pub use join_node::{JoinNode, JoinNodeBuilder};
23
24mod split_node;
25use miden_crypto::{Felt, Word};
26use miden_formatting::prettier::PrettyPrint;
27pub use split_node::{SplitNode, SplitNodeBuilder};
28
29mod loop_node;
30#[cfg(any(test, feature = "arbitrary"))]
31pub use basic_block_node::arbitrary;
32pub use loop_node::{LoopNode, LoopNodeBuilder};
33
34mod mast_forest_contributor;
35pub use mast_forest_contributor::{MastForestContributor, MastNodeBuilder};
36
37mod decorator_store;
38pub use decorator_store::DecoratorStore;
39
40use super::DecoratorId;
41use crate::{
42    AssemblyOp, Decorator,
43    mast::{MastForest, MastNodeId},
44};
45
46pub trait MastNodeExt {
47    /// Returns a commitment/hash of the node.
48    fn digest(&self) -> Word;
49
50    /// Returns the decorators to be executed before this node is executed.
51    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId];
52
53    /// Returns the decorators to be executed after this node is executed.
54    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId];
55
56    /// Returns a display formatter for this node.
57    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a>;
58
59    /// Returns a pretty printer for this node.
60    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a>;
61
62    /// Returns true if the this node has children.
63    fn has_children(&self) -> bool;
64
65    /// Appends the NodeIds of the children of this node, if any, to the vector.
66    fn append_children_to(&self, target: &mut Vec<MastNodeId>);
67
68    /// Executes the given closure for each child of this node.
69    fn for_each_child<F>(&self, f: F)
70    where
71        F: FnMut(MastNodeId);
72
73    /// Returns the domain of this node.
74    fn domain(&self) -> Felt;
75
76    /// Verifies that this node is stored at the ID in its decorators field in the forest.
77    #[cfg(debug_assertions)]
78    fn verify_node_in_forest(&self, forest: &MastForest);
79
80    /// Converts this node into its corresponding builder, reusing allocated data where possible.
81    type Builder: MastForestContributor;
82
83    fn to_builder(self, forest: &MastForest) -> Self::Builder;
84}
85
86// MAST NODE
87// ================================================================================================
88
/// A node of a [`MastForest`].
///
/// Each variant wraps exactly one concrete node type. `derive_more::From` provides a conversion
/// from each wrapped node type into the corresponding variant. The `MastNodeExt` derive generates
/// the `MastNodeExt` implementation for this enum, with `MastNodeBuilder` named as its builder
/// type; NOTE(review): presumably it delegates each method to the wrapped node — confirm against
/// `miden_utils_core_derive`.
#[derive(Debug, Clone, PartialEq, Eq, From, MastNodeExt)]
#[mast_node_ext(builder = "MastNodeBuilder")]
pub enum MastNode {
    /// Wraps a [`BasicBlockNode`].
    Block(BasicBlockNode),
    /// Wraps a [`JoinNode`].
    Join(JoinNode),
    /// Wraps a [`SplitNode`].
    Split(SplitNode),
    /// Wraps a [`LoopNode`].
    Loop(LoopNode),
    /// Wraps a [`CallNode`].
    Call(CallNode),
    /// Wraps a [`DynNode`].
    Dyn(DynNode),
    /// Wraps an [`ExternalNode`].
    External(ExternalNode),
}
100
101// ------------------------------------------------------------------------------------------------
102/// Public accessors
103impl MastNode {
104    /// Returns true if this node is an external node.
105    pub fn is_external(&self) -> bool {
106        matches!(self, MastNode::External(_))
107    }
108
109    /// Returns true if this node is a Dyn node.
110    pub fn is_dyn(&self) -> bool {
111        matches!(self, MastNode::Dyn(_))
112    }
113
114    /// Returns true if this node is a basic block.
115    pub fn is_basic_block(&self) -> bool {
116        matches!(self, Self::Block(_))
117    }
118
119    /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
120    /// otherwise.
121    pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
122        match self {
123            MastNode::Block(basic_block_node) => Some(basic_block_node),
124            _ => None,
125        }
126    }
127
128    /// Unwraps the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; panics
129    /// otherwise.
130    ///
131    /// # Panics
132    /// Panics if the [`MastNode`] does not wrap a [`BasicBlockNode`].
133    pub fn unwrap_basic_block(&self) -> &BasicBlockNode {
134        match self {
135            Self::Block(basic_block_node) => basic_block_node,
136            other => unwrap_failed(other, "basic block"),
137        }
138    }
139
140    /// Unwraps the inner join node if the [`MastNode`] wraps a [`JoinNode`]; panics otherwise.
141    ///
142    /// # Panics
143    /// - if the [`MastNode`] does not wrap a [`JoinNode`].
144    pub fn unwrap_join(&self) -> &JoinNode {
145        match self {
146            Self::Join(join_node) => join_node,
147            other => unwrap_failed(other, "join"),
148        }
149    }
150
151    /// Unwraps the inner split node if the [`MastNode`] wraps a [`SplitNode`]; panics otherwise.
152    ///
153    /// # Panics
154    /// - if the [`MastNode`] does not wrap a [`SplitNode`].
155    pub fn unwrap_split(&self) -> &SplitNode {
156        match self {
157            Self::Split(split_node) => split_node,
158            other => unwrap_failed(other, "split"),
159        }
160    }
161
162    /// Unwraps the inner loop node if the [`MastNode`] wraps a [`LoopNode`]; panics otherwise.
163    ///
164    /// # Panics
165    /// - if the [`MastNode`] does not wrap a [`LoopNode`].
166    pub fn unwrap_loop(&self) -> &LoopNode {
167        match self {
168            Self::Loop(loop_node) => loop_node,
169            other => unwrap_failed(other, "loop"),
170        }
171    }
172
173    /// Unwraps the inner call node if the [`MastNode`] wraps a [`CallNode`]; panics otherwise.
174    ///
175    /// # Panics
176    /// - if the [`MastNode`] does not wrap a [`CallNode`].
177    pub fn unwrap_call(&self) -> &CallNode {
178        match self {
179            Self::Call(call_node) => call_node,
180            other => unwrap_failed(other, "call"),
181        }
182    }
183
184    /// Unwraps the inner dynamic node if the [`MastNode`] wraps a [`DynNode`]; panics otherwise.
185    ///
186    /// # Panics
187    /// - if the [`MastNode`] does not wrap a [`DynNode`].
188    pub fn unwrap_dyn(&self) -> &DynNode {
189        match self {
190            Self::Dyn(dyn_node) => dyn_node,
191            other => unwrap_failed(other, "dyn"),
192        }
193    }
194
195    /// Unwraps the inner external node if the [`MastNode`] wraps a [`ExternalNode`]; panics
196    /// otherwise.
197    ///
198    /// # Panics
199    /// - if the [`MastNode`] does not wrap a [`ExternalNode`].
200    pub fn unwrap_external(&self) -> &ExternalNode {
201        match self {
202            Self::External(external_node) => external_node,
203            other => unwrap_failed(other, "external"),
204        }
205    }
206}
207
// MAST NODE ERROR CONTEXT
209// ===============================================================================================
210
211/// A trait for extending the functionality of all [`MastNode`]s.
212pub trait MastNodeErrorContext: Send + Sync {
213    // REQUIRED METHODS
214    // -------------------------------------------------------------------------------------------
215
216    /// The list of decorators tied to this node, along with their associated index.
217    ///
218    /// The index is only meaningful for [`BasicBlockNode`]s, where it corresponds to the index of
219    /// the operation in the basic block to which the decorator is attached.
220    fn decorators<'a>(
221        &'a self,
222        forest: &'a MastForest,
223    ) -> impl Iterator<Item = DecoratedOpLink> + 'a;
224
225    // PROVIDED METHODS
226    // -------------------------------------------------------------------------------------------
227
228    /// Returns the [`AssemblyOp`] associated with this node and operation (if provided), if any.
229    ///
230    /// If the `target_op_idx` is provided, the method treats the wrapped node as a basic block will
231    /// return the assembly op associated with the operation at the corresponding index in the basic
232    /// block. If no `target_op_idx` is provided, the method will return the first assembly op found
233    /// (effectively assuming that the node has at most one associated [`AssemblyOp`]).
234    fn get_assembly_op<'m>(
235        &self,
236        mast_forest: &'m MastForest,
237        target_op_idx: Option<usize>,
238    ) -> Option<&'m AssemblyOp> {
239        match target_op_idx {
240            // If a target operation index is provided, return the assembly op associated with that
241            // operation.
242            Some(target_op_idx) => {
243                for (op_idx, decorator_id) in self.decorators(mast_forest) {
244                    if let Some(Decorator::AsmOp(assembly_op)) =
245                        mast_forest.decorator_by_id(decorator_id)
246                    {
247                        // when an instruction compiles down to multiple operations, only the first
248                        // operation is associated with the assembly op. We need to check if the
249                        // target operation index falls within the range of operations associated
250                        // with the assembly op.
251                        if target_op_idx >= op_idx
252                            && target_op_idx < op_idx + assembly_op.num_cycles() as usize
253                        {
254                            return Some(assembly_op);
255                        }
256                    }
257                }
258            },
259            // If no target operation index is provided, return the first assembly op found.
260            None => {
261                for (_, decorator_id) in self.decorators(mast_forest) {
262                    if let Some(Decorator::AsmOp(assembly_op)) =
263                        mast_forest.decorator_by_id(decorator_id)
264                    {
265                        return Some(assembly_op);
266                    }
267                }
268            },
269        }
270
271        None
272    }
273}
274
275// Links an operation index in a block to a decoratorid, to be executed right before this
276// operation's position
277pub type DecoratedOpLink = (usize, DecoratorId);
278
279// HELPERS
280// ===============================================================================================
281
282/// This function is analogous to the `unwrap_failed()` function used in the implementation of
283/// `core::result::Result` `unwrap_*()` methods.
284#[cold]
285#[inline(never)]
286#[track_caller]
287fn unwrap_failed(node: &MastNode, expected: &str) -> ! {
288    let actual = match node {
289        MastNode::Block(_) => "basic block",
290        MastNode::Join(_) => "join",
291        MastNode::Split(_) => "split",
292        MastNode::Loop(_) => "loop",
293        MastNode::Call(_) => "call",
294        MastNode::Dyn(_) => "dynamic",
295        MastNode::External(_) => "external",
296    };
297    panic!("tried to unwrap {expected} node, but got {actual}");
298}