miden_core/mast/node/
mod.rs

1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6    BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, GROUP_SIZE as OP_GROUP_SIZE, OpBatch,
7    OperationOrDecorator,
8};
9use enum_dispatch::enum_dispatch;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
13mod call_node;
14pub use call_node::CallNode;
15
16mod dyn_node;
17pub use dyn_node::DynNode;
18
19mod external;
20pub use external::ExternalNode;
21
22mod join_node;
23pub use join_node::JoinNode;
24
25mod split_node;
26use miden_crypto::{Felt, Word};
27use miden_formatting::prettier::PrettyPrint;
28pub use split_node::SplitNode;
29
30mod loop_node;
31pub use loop_node::LoopNode;
32
33use super::DecoratorId;
34use crate::{
35    AssemblyOp, Decorator,
36    mast::{MastForest, MastNodeId, Remapping},
37};
38
39#[enum_dispatch]
40pub trait MastNodeExt {
41    /// Returns a commitment/hash of the node.
42    fn digest(&self) -> Word;
43
44    /// Returns the decorators to be executed before this node is executed.
45    fn before_enter(&self) -> &[DecoratorId];
46
47    /// Returns the decorators to be executed after this node is executed.
48    fn after_exit(&self) -> &[DecoratorId];
49
50    /// Sets the list of decorators to be executed before this node.
51    fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]);
52
53    /// Sets the list of decorators to be executed after this node.
54    fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]);
55
56    /// Removes all decorators from this node.
57    fn remove_decorators(&mut self);
58
59    /// Returns a display formatter for this node.
60    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a>;
61
62    /// Returns a pretty printer for this node.
63    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a>;
64
65    /// Remap the node children to their new positions indicated by the given [`Remapping`].
66    fn remap_children(&self, remapping: &Remapping) -> Self;
67
68    /// Returns true if the this node has children.
69    fn has_children(&self) -> bool;
70
71    /// Appends the NodeIds of the children of this node, if any, to the vector.
72    fn append_children_to(&self, target: &mut Vec<MastNodeId>);
73
74    /// Returns the domain of this node.
75    fn domain(&self) -> Felt;
76}
77
78// MAST NODE
79// ================================================================================================
80
81#[enum_dispatch(MastNodeExt)]
82#[derive(Debug, Clone, PartialEq, Eq)]
83#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
84pub enum MastNode {
85    Block(BasicBlockNode),
86    Join(JoinNode),
87    Split(SplitNode),
88    Loop(LoopNode),
89    Call(CallNode),
90    Dyn(DynNode),
91    External(ExternalNode),
92}
93
94// ------------------------------------------------------------------------------------------------
95/// Public accessors
96impl MastNode {
97    /// Returns true if this node is an external node.
98    pub fn is_external(&self) -> bool {
99        matches!(self, MastNode::External(_))
100    }
101
102    /// Returns true if this node is a Dyn node.
103    pub fn is_dyn(&self) -> bool {
104        matches!(self, MastNode::Dyn(_))
105    }
106
107    /// Returns true if this node is a basic block.
108    pub fn is_basic_block(&self) -> bool {
109        matches!(self, Self::Block(_))
110    }
111
112    /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
113    /// otherwise.
114    pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
115        match self {
116            MastNode::Block(basic_block_node) => Some(basic_block_node),
117            _ => None,
118        }
119    }
120
121    /// Unwraps the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; panics
122    /// otherwise.
123    ///
124    /// # Panics
125    /// Panics if the [`MastNode`] does not wrap a [`BasicBlockNode`].
126    pub fn unwrap_basic_block(&self) -> &BasicBlockNode {
127        match self {
128            Self::Block(basic_block_node) => basic_block_node,
129            other => unwrap_failed(other, "basic block"),
130        }
131    }
132
133    /// Unwraps the inner join node if the [`MastNode`] wraps a [`JoinNode`]; panics otherwise.
134    ///
135    /// # Panics
136    /// - if the [`MastNode`] does not wrap a [`JoinNode`].
137    pub fn unwrap_join(&self) -> &JoinNode {
138        match self {
139            Self::Join(join_node) => join_node,
140            other => unwrap_failed(other, "join"),
141        }
142    }
143
144    /// Unwraps the inner split node if the [`MastNode`] wraps a [`SplitNode`]; panics otherwise.
145    ///
146    /// # Panics
147    /// - if the [`MastNode`] does not wrap a [`SplitNode`].
148    pub fn unwrap_split(&self) -> &SplitNode {
149        match self {
150            Self::Split(split_node) => split_node,
151            other => unwrap_failed(other, "split"),
152        }
153    }
154
155    /// Unwraps the inner loop node if the [`MastNode`] wraps a [`LoopNode`]; panics otherwise.
156    ///
157    /// # Panics
158    /// - if the [`MastNode`] does not wrap a [`LoopNode`].
159    pub fn unwrap_loop(&self) -> &LoopNode {
160        match self {
161            Self::Loop(loop_node) => loop_node,
162            other => unwrap_failed(other, "loop"),
163        }
164    }
165
166    /// Unwraps the inner call node if the [`MastNode`] wraps a [`CallNode`]; panics otherwise.
167    ///
168    /// # Panics
169    /// - if the [`MastNode`] does not wrap a [`CallNode`].
170    pub fn unwrap_call(&self) -> &CallNode {
171        match self {
172            Self::Call(call_node) => call_node,
173            other => unwrap_failed(other, "call"),
174        }
175    }
176
177    /// Unwraps the inner dynamic node if the [`MastNode`] wraps a [`DynNode`]; panics otherwise.
178    ///
179    /// # Panics
180    /// - if the [`MastNode`] does not wrap a [`DynNode`].
181    pub fn unwrap_dyn(&self) -> &DynNode {
182        match self {
183            Self::Dyn(dyn_node) => dyn_node,
184            other => unwrap_failed(other, "dyn"),
185        }
186    }
187
188    /// Unwraps the inner external node if the [`MastNode`] wraps a [`ExternalNode`]; panics
189    /// otherwise.
190    ///
191    /// # Panics
192    /// - if the [`MastNode`] does not wrap a [`ExternalNode`].
193    pub fn unwrap_external(&self) -> &ExternalNode {
194        match self {
195            Self::External(external_node) => external_node,
196            other => unwrap_failed(other, "external"),
197        }
198    }
199}
200
201// MAST INNER NODE EXT
202// ===============================================================================================
203
204/// A trait for extending the functionality of all [`MastNode`]s.
205pub trait MastNodeErrorContext: Send + Sync {
206    // REQUIRED METHODS
207    // -------------------------------------------------------------------------------------------
208
209    /// The list of decorators tied to this node, along with their associated index.
210    ///
211    /// The index is only meaningful for [`BasicBlockNode`]s, where it corresponds to the index of
212    /// the operation in the basic block to which the decorator is attached.
213    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)>;
214
215    // PROVIDED METHODS
216    // -------------------------------------------------------------------------------------------
217
218    /// Returns the [`AssemblyOp`] associated with this node and operation (if provided), if any.
219    ///
220    /// If the `target_op_idx` is provided, the method treats the wrapped node as a basic block will
221    /// return the assembly op associated with the operation at the corresponding index in the basic
222    /// block. If no `target_op_idx` is provided, the method will return the first assembly op found
223    /// (effectively assuming that the node has at most one associated [`AssemblyOp`]).
224    fn get_assembly_op<'m>(
225        &self,
226        mast_forest: &'m MastForest,
227        target_op_idx: Option<usize>,
228    ) -> Option<&'m AssemblyOp> {
229        match target_op_idx {
230            // If a target operation index is provided, return the assembly op associated with that
231            // operation.
232            Some(target_op_idx) => {
233                for (op_idx, decorator_id) in self.decorators() {
234                    if let Some(Decorator::AsmOp(assembly_op)) =
235                        mast_forest.get_decorator_by_id(decorator_id)
236                    {
237                        // when an instruction compiles down to multiple operations, only the first
238                        // operation is associated with the assembly op. We need to check if the
239                        // target operation index falls within the range of operations associated
240                        // with the assembly op.
241                        if target_op_idx >= op_idx
242                            && target_op_idx < op_idx + assembly_op.num_cycles() as usize
243                        {
244                            return Some(assembly_op);
245                        }
246                    }
247                }
248            },
249            // If no target operation index is provided, return the first assembly op found.
250            None => {
251                for (_, decorator_id) in self.decorators() {
252                    if let Some(Decorator::AsmOp(assembly_op)) =
253                        mast_forest.get_decorator_by_id(decorator_id)
254                    {
255                        return Some(assembly_op);
256                    }
257                }
258            },
259        }
260
261        None
262    }
263}
264
265// HELPERS
266// ===============================================================================================
267
268/// This function is analogous to the `unwrap_failed()` function used in the implementation of
269/// `core::result::Result` `unwrap_*()` methods.
270#[cold]
271#[inline(never)]
272#[track_caller]
273fn unwrap_failed(node: &MastNode, expected: &str) -> ! {
274    let actual = match node {
275        MastNode::Block(_) => "basic block",
276        MastNode::Join(_) => "join",
277        MastNode::Split(_) => "split",
278        MastNode::Loop(_) => "loop",
279        MastNode::Call(_) => "call",
280        MastNode::Dyn(_) => "dynamic",
281        MastNode::External(_) => "external",
282    };
283    panic!("tried to unwrap {expected} node, but got {actual}");
284}