miden_core/mast/node/
mod.rs

1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6    BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, GROUP_SIZE as OP_GROUP_SIZE, OpBatch,
7    OperationOrDecorator,
8};
9
10mod call_node;
11pub use call_node::CallNode;
12
13mod dyn_node;
14pub use dyn_node::DynNode;
15
16mod external;
17pub use external::ExternalNode;
18
19mod join_node;
20pub use join_node::JoinNode;
21
22mod split_node;
23use miden_crypto::{Felt, Word};
24use miden_formatting::prettier::{Document, PrettyPrint};
25pub use split_node::SplitNode;
26
27mod loop_node;
28pub use loop_node::LoopNode;
29
30use super::{DecoratorId, MastForestError};
31use crate::{
32    AssemblyOp, Decorator, DecoratorList, Operation,
33    mast::{MastForest, MastNodeId, Remapping},
34};
35
36// MAST NODE
37// ================================================================================================
38
/// A node of a [`MastForest`].
///
/// Each variant wraps one concrete node type. The `new_*` constructors defined on this type
/// (e.g. `new_basic_block`, `new_join`, `new_syscall`) build the corresponding variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MastNode {
    /// A basic block: a sequence of operations with optional decorators.
    Block(BasicBlockNode),
    /// A join node with two children (left, then right).
    Join(JoinNode),
    /// A split node with an if-branch child and an else-branch child.
    Split(SplitNode),
    /// A loop node with a single body child.
    Loop(LoopNode),
    /// A call node targeting a callee; also used for syscalls (see `new_syscall`).
    Call(CallNode),
    /// A dynamic node; built either as `dyn` or `dyncall` (see `new_dyn` / `new_dyncall`).
    Dyn(DynNode),
    /// An external node, created from a MAST root [`Word`] (see `new_external`).
    External(ExternalNode),
}
49
50// ------------------------------------------------------------------------------------------------
51/// Constructors
52impl MastNode {
53    pub fn new_basic_block(
54        operations: Vec<Operation>,
55        decorators: Option<DecoratorList>,
56    ) -> Result<Self, MastForestError> {
57        let block = BasicBlockNode::new(operations, decorators)?;
58        Ok(Self::Block(block))
59    }
60
61    pub fn new_join(
62        left_child: MastNodeId,
63        right_child: MastNodeId,
64        mast_forest: &MastForest,
65    ) -> Result<Self, MastForestError> {
66        let join = JoinNode::new([left_child, right_child], mast_forest)?;
67        Ok(Self::Join(join))
68    }
69
70    pub fn new_split(
71        if_branch: MastNodeId,
72        else_branch: MastNodeId,
73        mast_forest: &MastForest,
74    ) -> Result<Self, MastForestError> {
75        let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
76        Ok(Self::Split(split))
77    }
78
79    pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
80        let loop_node = LoopNode::new(body, mast_forest)?;
81        Ok(Self::Loop(loop_node))
82    }
83
84    pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
85        let call = CallNode::new(callee, mast_forest)?;
86        Ok(Self::Call(call))
87    }
88
89    pub fn new_syscall(
90        callee: MastNodeId,
91        mast_forest: &MastForest,
92    ) -> Result<Self, MastForestError> {
93        let syscall = CallNode::new_syscall(callee, mast_forest)?;
94        Ok(Self::Call(syscall))
95    }
96
97    pub fn new_dyn() -> Self {
98        Self::Dyn(DynNode::new_dyn())
99    }
100    pub fn new_dyncall() -> Self {
101        Self::Dyn(DynNode::new_dyncall())
102    }
103
104    pub fn new_external(mast_root: Word) -> Self {
105        Self::External(ExternalNode::new(mast_root))
106    }
107
108    #[cfg(test)]
109    pub fn new_basic_block_with_raw_decorators(
110        operations: Vec<Operation>,
111        decorators: Vec<(usize, crate::Decorator)>,
112        mast_forest: &mut MastForest,
113    ) -> Result<Self, MastForestError> {
114        let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?;
115        Ok(Self::Block(block))
116    }
117}
118
119// ------------------------------------------------------------------------------------------------
120/// Public accessors
121impl MastNode {
122    /// Returns true if this node is an external node.
123    pub fn is_external(&self) -> bool {
124        matches!(self, MastNode::External(_))
125    }
126
127    /// Returns true if this node is a Dyn node.
128    pub fn is_dyn(&self) -> bool {
129        matches!(self, MastNode::Dyn(_))
130    }
131
132    /// Returns true if this node is a basic block.
133    pub fn is_basic_block(&self) -> bool {
134        matches!(self, Self::Block(_))
135    }
136
137    /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
138    /// otherwise.
139    pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
140        match self {
141            MastNode::Block(basic_block_node) => Some(basic_block_node),
142            _ => None,
143        }
144    }
145
146    /// Unwraps the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; panics
147    /// otherwise.
148    ///
149    /// # Panics
150    /// Panics if the [`MastNode`] does not wrap a [`BasicBlockNode`].
151    pub fn unwrap_basic_block(&self) -> &BasicBlockNode {
152        match self {
153            Self::Block(basic_block_node) => basic_block_node,
154            other => unwrap_failed(other, "basic block"),
155        }
156    }
157
158    /// Unwraps the inner join node if the [`MastNode`] wraps a [`JoinNode`]; panics otherwise.
159    ///
160    /// # Panics
161    /// - if the [`MastNode`] does not wrap a [`JoinNode`].
162    pub fn unwrap_join(&self) -> &JoinNode {
163        match self {
164            Self::Join(join_node) => join_node,
165            other => unwrap_failed(other, "join"),
166        }
167    }
168
169    /// Unwraps the inner split node if the [`MastNode`] wraps a [`SplitNode`]; panics otherwise.
170    ///
171    /// # Panics
172    /// - if the [`MastNode`] does not wrap a [`SplitNode`].
173    pub fn unwrap_split(&self) -> &SplitNode {
174        match self {
175            Self::Split(split_node) => split_node,
176            other => unwrap_failed(other, "split"),
177        }
178    }
179
180    /// Unwraps the inner loop node if the [`MastNode`] wraps a [`LoopNode`]; panics otherwise.
181    ///
182    /// # Panics
183    /// - if the [`MastNode`] does not wrap a [`LoopNode`].
184    pub fn unwrap_loop(&self) -> &LoopNode {
185        match self {
186            Self::Loop(loop_node) => loop_node,
187            other => unwrap_failed(other, "loop"),
188        }
189    }
190
191    /// Unwraps the inner call node if the [`MastNode`] wraps a [`CallNode`]; panics otherwise.
192    ///
193    /// # Panics
194    /// - if the [`MastNode`] does not wrap a [`CallNode`].
195    pub fn unwrap_call(&self) -> &CallNode {
196        match self {
197            Self::Call(call_node) => call_node,
198            other => unwrap_failed(other, "call"),
199        }
200    }
201
202    /// Unwraps the inner dynamic node if the [`MastNode`] wraps a [`DynNode`]; panics otherwise.
203    ///
204    /// # Panics
205    /// - if the [`MastNode`] does not wrap a [`DynNode`].
206    pub fn unwrap_dyn(&self) -> &DynNode {
207        match self {
208            Self::Dyn(dyn_node) => dyn_node,
209            other => unwrap_failed(other, "dyn"),
210        }
211    }
212
213    /// Unwraps the inner external node if the [`MastNode`] wraps a [`ExternalNode`]; panics
214    /// otherwise.
215    ///
216    /// # Panics
217    /// - if the [`MastNode`] does not wrap a [`ExternalNode`].
218    pub fn unwrap_external(&self) -> &ExternalNode {
219        match self {
220            Self::External(external_node) => external_node,
221            other => unwrap_failed(other, "external"),
222        }
223    }
224
225    /// Remap the node children to their new positions indicated by the given [`Remapping`].
226    pub fn remap_children(&self, remapping: &Remapping) -> Self {
227        use MastNode::*;
228        match self {
229            Join(join_node) => Join(join_node.remap_children(remapping)),
230            Split(split_node) => Split(split_node.remap_children(remapping)),
231            Loop(loop_node) => Loop(loop_node.remap_children(remapping)),
232            Call(call_node) => Call(call_node.remap_children(remapping)),
233            Block(_) | Dyn(_) | External(_) => self.clone(),
234        }
235    }
236
237    /// Returns true if the this node has children.
238    pub fn has_children(&self) -> bool {
239        match &self {
240            MastNode::Join(_) | MastNode::Split(_) | MastNode::Loop(_) | MastNode::Call(_) => true,
241            MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => false,
242        }
243    }
244
245    /// Appends the NodeIds of the children of this node, if any, to the vector.
246    pub fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
247        match &self {
248            MastNode::Join(join_node) => {
249                target.push(join_node.first());
250                target.push(join_node.second())
251            },
252            MastNode::Split(split_node) => {
253                target.push(split_node.on_true());
254                target.push(split_node.on_false())
255            },
256            MastNode::Loop(loop_node) => target.push(loop_node.body()),
257            MastNode::Call(call_node) => target.push(call_node.callee()),
258            MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => (),
259        }
260    }
261
262    pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a {
263        match self {
264            MastNode::Block(basic_block_node) => {
265                MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest)))
266            },
267            MastNode::Join(join_node) => {
268                MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest)))
269            },
270            MastNode::Split(split_node) => {
271                MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest)))
272            },
273            MastNode::Loop(loop_node) => {
274                MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest)))
275            },
276            MastNode::Call(call_node) => {
277                MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest)))
278            },
279            MastNode::Dyn(dyn_node) => {
280                MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest)))
281            },
282            MastNode::External(external_node) => {
283                MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest)))
284            },
285        }
286    }
287
288    pub fn domain(&self) -> Felt {
289        match self {
290            MastNode::Block(_) => BasicBlockNode::DOMAIN,
291            MastNode::Join(_) => JoinNode::DOMAIN,
292            MastNode::Split(_) => SplitNode::DOMAIN,
293            MastNode::Loop(_) => LoopNode::DOMAIN,
294            MastNode::Call(call_node) => call_node.domain(),
295            MastNode::Dyn(dyn_node) => dyn_node.domain(),
296            MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
297        }
298    }
299
300    pub fn digest(&self) -> Word {
301        match self {
302            MastNode::Block(node) => node.digest(),
303            MastNode::Join(node) => node.digest(),
304            MastNode::Split(node) => node.digest(),
305            MastNode::Loop(node) => node.digest(),
306            MastNode::Call(node) => node.digest(),
307            MastNode::Dyn(node) => node.digest(),
308            MastNode::External(node) => node.digest(),
309        }
310    }
311
312    pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
313        match self {
314            MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
315            MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
316            MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
317            MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
318            MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
319            MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
320            MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
321        }
322    }
323
324    /// Returns the decorators to be executed before this node is executed.
325    pub fn before_enter(&self) -> &[DecoratorId] {
326        use MastNode::*;
327        match self {
328            Block(_) => &[],
329            Join(node) => node.before_enter(),
330            Split(node) => node.before_enter(),
331            Loop(node) => node.before_enter(),
332            Call(node) => node.before_enter(),
333            Dyn(node) => node.before_enter(),
334            External(node) => node.before_enter(),
335        }
336    }
337
338    /// Returns the decorators to be executed after this node is executed.
339    pub fn after_exit(&self) -> &[DecoratorId] {
340        use MastNode::*;
341        match self {
342            Block(_) => &[],
343            Join(node) => node.after_exit(),
344            Split(node) => node.after_exit(),
345            Loop(node) => node.after_exit(),
346            Call(node) => node.after_exit(),
347            Dyn(node) => node.after_exit(),
348            External(node) => node.after_exit(),
349        }
350    }
351}
352
353/// Mutators
354impl MastNode {
355    /// Sets the list of decorators to be executed before this node.
356    pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
357        match self {
358            MastNode::Block(node) => node.prepend_decorators(decorator_ids),
359            MastNode::Join(node) => node.append_before_enter(decorator_ids),
360            MastNode::Split(node) => node.append_before_enter(decorator_ids),
361            MastNode::Loop(node) => node.append_before_enter(decorator_ids),
362            MastNode::Call(node) => node.append_before_enter(decorator_ids),
363            MastNode::Dyn(node) => node.append_before_enter(decorator_ids),
364            MastNode::External(node) => node.append_before_enter(decorator_ids),
365        }
366    }
367
368    /// Sets the list of decorators to be executed after this node.
369    pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
370        match self {
371            MastNode::Block(node) => node.append_decorators(decorator_ids),
372            MastNode::Join(node) => node.append_after_exit(decorator_ids),
373            MastNode::Split(node) => node.append_after_exit(decorator_ids),
374            MastNode::Loop(node) => node.append_after_exit(decorator_ids),
375            MastNode::Call(node) => node.append_after_exit(decorator_ids),
376            MastNode::Dyn(node) => node.append_after_exit(decorator_ids),
377            MastNode::External(node) => node.append_after_exit(decorator_ids),
378        }
379    }
380}
381
382// PRETTY PRINTING
383// ================================================================================================
384
385struct MastNodePrettyPrint<'a> {
386    node_pretty_print: Box<dyn PrettyPrint + 'a>,
387}
388
389impl<'a> MastNodePrettyPrint<'a> {
390    pub fn new(node_pretty_print: Box<dyn PrettyPrint + 'a>) -> Self {
391        Self { node_pretty_print }
392    }
393}
394
395impl PrettyPrint for MastNodePrettyPrint<'_> {
396    fn render(&self) -> Document {
397        self.node_pretty_print.render()
398    }
399}
400
/// Type-erased wrapper around a node-specific [`fmt::Display`] implementation, so that
/// `MastNode::to_display` can return one concrete type for every variant.
struct MastNodeDisplay<'a> {
    node_display: Box<dyn fmt::Display + 'a>,
}

impl<'a> MastNodeDisplay<'a> {
    /// Boxes `node` behind a `dyn fmt::Display` trait object.
    pub fn new(node: impl fmt::Display + 'a) -> Self {
        let node_display: Box<dyn fmt::Display + 'a> = Box::new(node);
        MastNodeDisplay { node_display }
    }
}

impl fmt::Display for MastNodeDisplay<'_> {
    /// Forwards formatting, including any width/fill/precision flags on `f`, to the wrapped
    /// value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&*self.node_display, f)
    }
}
416
417// MAST INNER NODE EXT
418// ===============================================================================================
419
420/// A trait for extending the functionality of all [`MastNode`]s.
421pub trait MastNodeExt: Send + Sync {
422    // REQUIRED METHODS
423    // -------------------------------------------------------------------------------------------
424
425    /// The list of decorators tied to this node, along with their associated index.
426    ///
427    /// The index is only meaningful for [`BasicBlockNode`]s, where it corresponds to the index of
428    /// the operation in the basic block to which the decorator is attached.
429    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)>;
430
431    // PROVIDED METHODS
432    // -------------------------------------------------------------------------------------------
433
434    /// Returns the [`AssemblyOp`] associated with this node and operation (if provided), if any.
435    ///
436    /// If the `target_op_idx` is provided, the method treats the wrapped node as a basic block will
437    /// return the assembly op associated with the operation at the corresponding index in the basic
438    /// block. If no `target_op_idx` is provided, the method will return the first assembly op found
439    /// (effectively assuming that the node has at most one associated [`AssemblyOp`]).
440    fn get_assembly_op<'m>(
441        &self,
442        mast_forest: &'m MastForest,
443        target_op_idx: Option<usize>,
444    ) -> Option<&'m AssemblyOp> {
445        match target_op_idx {
446            // If a target operation index is provided, return the assembly op associated with that
447            // operation.
448            Some(target_op_idx) => {
449                for (op_idx, decorator_id) in self.decorators() {
450                    if let Some(Decorator::AsmOp(assembly_op)) =
451                        mast_forest.get_decorator_by_id(decorator_id)
452                    {
453                        // when an instruction compiles down to multiple operations, only the first
454                        // operation is associated with the assembly op. We need to check if the
455                        // target operation index falls within the range of operations associated
456                        // with the assembly op.
457                        if target_op_idx >= op_idx
458                            && target_op_idx < op_idx + assembly_op.num_cycles() as usize
459                        {
460                            return Some(assembly_op);
461                        }
462                    }
463                }
464            },
465            // If no target operation index is provided, return the first assembly op found.
466            None => {
467                for (_, decorator_id) in self.decorators() {
468                    if let Some(Decorator::AsmOp(assembly_op)) =
469                        mast_forest.get_decorator_by_id(decorator_id)
470                    {
471                        return Some(assembly_op);
472                    }
473                }
474            },
475        }
476
477        None
478    }
479}
480
481// HELPERS
482// ===============================================================================================
483
484/// This function is analogous to the `unwrap_failed()` function used in the implementation of
485/// `core::result::Result` `unwrap_*()` methods.
486#[cold]
487#[inline(never)]
488#[track_caller]
489fn unwrap_failed(node: &MastNode, expected: &str) -> ! {
490    let actual = match node {
491        MastNode::Block(_) => "basic block",
492        MastNode::Join(_) => "join",
493        MastNode::Split(_) => "split",
494        MastNode::Loop(_) => "loop",
495        MastNode::Call(_) => "call",
496        MastNode::Dyn(_) => "dynamic",
497        MastNode::External(_) => "external",
498    };
499    panic!("tried to unwrap {expected} node, but got {actual}");
500}