miden_core/mast/node/
mod.rs

1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6    BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, GROUP_SIZE as OP_GROUP_SIZE, OpBatch,
7    OperationOrDecorator,
8};
9
10mod call_node;
11pub use call_node::CallNode;
12
13mod dyn_node;
14pub use dyn_node::DynNode;
15
16mod external;
17pub use external::ExternalNode;
18
19mod join_node;
20pub use join_node::JoinNode;
21
22mod split_node;
23use miden_crypto::{Felt, Word};
24use miden_formatting::prettier::{Document, PrettyPrint};
25pub use split_node::SplitNode;
26
27mod loop_node;
28pub use loop_node::LoopNode;
29
30use super::{DecoratorId, MastForestError};
31use crate::{
32    AssemblyOp, Decorator, DecoratorList, Operation,
33    mast::{MastForest, MastNodeId, Remapping},
34};
35
36// MAST NODE
37// ================================================================================================
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum MastNode {
41    Block(BasicBlockNode),
42    Join(JoinNode),
43    Split(SplitNode),
44    Loop(LoopNode),
45    Call(CallNode),
46    Dyn(DynNode),
47    External(ExternalNode),
48}
49
50// ------------------------------------------------------------------------------------------------
51/// Constructors
52impl MastNode {
53    pub fn new_basic_block(
54        operations: Vec<Operation>,
55        decorators: Option<DecoratorList>,
56    ) -> Result<Self, MastForestError> {
57        let block = BasicBlockNode::new(operations, decorators)?;
58        Ok(Self::Block(block))
59    }
60
61    pub fn new_join(
62        left_child: MastNodeId,
63        right_child: MastNodeId,
64        mast_forest: &MastForest,
65    ) -> Result<Self, MastForestError> {
66        let join = JoinNode::new([left_child, right_child], mast_forest)?;
67        Ok(Self::Join(join))
68    }
69
70    pub fn new_split(
71        if_branch: MastNodeId,
72        else_branch: MastNodeId,
73        mast_forest: &MastForest,
74    ) -> Result<Self, MastForestError> {
75        let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
76        Ok(Self::Split(split))
77    }
78
79    pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
80        let loop_node = LoopNode::new(body, mast_forest)?;
81        Ok(Self::Loop(loop_node))
82    }
83
84    pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
85        let call = CallNode::new(callee, mast_forest)?;
86        Ok(Self::Call(call))
87    }
88
89    pub fn new_syscall(
90        callee: MastNodeId,
91        mast_forest: &MastForest,
92    ) -> Result<Self, MastForestError> {
93        let syscall = CallNode::new_syscall(callee, mast_forest)?;
94        Ok(Self::Call(syscall))
95    }
96
97    pub fn new_dyn() -> Self {
98        Self::Dyn(DynNode::new_dyn())
99    }
100    pub fn new_dyncall() -> Self {
101        Self::Dyn(DynNode::new_dyncall())
102    }
103
104    pub fn new_external(mast_root: Word) -> Self {
105        Self::External(ExternalNode::new(mast_root))
106    }
107
108    #[cfg(test)]
109    pub fn new_basic_block_with_raw_decorators(
110        operations: Vec<Operation>,
111        decorators: Vec<(usize, crate::Decorator)>,
112        mast_forest: &mut MastForest,
113    ) -> Result<Self, MastForestError> {
114        let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?;
115        Ok(Self::Block(block))
116    }
117}
118
119// ------------------------------------------------------------------------------------------------
120/// Public accessors
121impl MastNode {
122    /// Returns true if this node is an external node.
123    pub fn is_external(&self) -> bool {
124        matches!(self, MastNode::External(_))
125    }
126
127    /// Returns true if this node is a Dyn node.
128    pub fn is_dyn(&self) -> bool {
129        matches!(self, MastNode::Dyn(_))
130    }
131
132    /// Returns true if this node is a basic block.
133    pub fn is_basic_block(&self) -> bool {
134        matches!(self, Self::Block(_))
135    }
136
137    /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
138    /// otherwise.
139    pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
140        match self {
141            MastNode::Block(basic_block_node) => Some(basic_block_node),
142            _ => None,
143        }
144    }
145
146    /// Unwraps the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; panics
147    /// otherwise.
148    ///
149    /// # Panics
150    /// Panics if the [`MastNode`] does not wrap a [`BasicBlockNode`].
151    pub fn unwrap_basic_block(&self) -> &BasicBlockNode {
152        match self {
153            Self::Block(basic_block_node) => basic_block_node,
154            other => unwrap_failed(other, "basic block"),
155        }
156    }
157
158    /// Unwraps the inner join node if the [`MastNode`] wraps a [`JoinNode`]; panics otherwise.
159    ///
160    /// # Panics
161    /// - if the [`MastNode`] does not wrap a [`JoinNode`].
162    pub fn unwrap_join(&self) -> &JoinNode {
163        match self {
164            Self::Join(join_node) => join_node,
165            other => unwrap_failed(other, "join"),
166        }
167    }
168
169    /// Unwraps the inner split node if the [`MastNode`] wraps a [`SplitNode`]; panics otherwise.
170    ///
171    /// # Panics
172    /// - if the [`MastNode`] does not wrap a [`SplitNode`].
173    pub fn unwrap_split(&self) -> &SplitNode {
174        match self {
175            Self::Split(split_node) => split_node,
176            other => unwrap_failed(other, "split"),
177        }
178    }
179
180    /// Unwraps the inner loop node if the [`MastNode`] wraps a [`LoopNode`]; panics otherwise.
181    ///
182    /// # Panics
183    /// - if the [`MastNode`] does not wrap a [`LoopNode`].
184    pub fn unwrap_loop(&self) -> &LoopNode {
185        match self {
186            Self::Loop(loop_node) => loop_node,
187            other => unwrap_failed(other, "loop"),
188        }
189    }
190
191    /// Unwraps the inner call node if the [`MastNode`] wraps a [`CallNode`]; panics otherwise.
192    ///
193    /// # Panics
194    /// - if the [`MastNode`] does not wrap a [`CallNode`].
195    pub fn unwrap_call(&self) -> &CallNode {
196        match self {
197            Self::Call(call_node) => call_node,
198            other => unwrap_failed(other, "call"),
199        }
200    }
201
202    /// Unwraps the inner dynamic node if the [`MastNode`] wraps a [`DynNode`]; panics otherwise.
203    ///
204    /// # Panics
205    /// - if the [`MastNode`] does not wrap a [`DynNode`].
206    pub fn unwrap_dyn(&self) -> &DynNode {
207        match self {
208            Self::Dyn(dyn_node) => dyn_node,
209            other => unwrap_failed(other, "dyn"),
210        }
211    }
212
213    /// Unwraps the inner external node if the [`MastNode`] wraps a [`ExternalNode`]; panics
214    /// otherwise.
215    ///
216    /// # Panics
217    /// - if the [`MastNode`] does not wrap a [`ExternalNode`].
218    pub fn unwrap_external(&self) -> &ExternalNode {
219        match self {
220            Self::External(external_node) => external_node,
221            other => unwrap_failed(other, "external"),
222        }
223    }
224
225    /// Remap the node children to their new positions indicated by the given [`Remapping`].
226    pub fn remap_children(&self, remapping: &Remapping) -> Self {
227        use MastNode::*;
228        match self {
229            Join(join_node) => Join(join_node.remap_children(remapping)),
230            Split(split_node) => Split(split_node.remap_children(remapping)),
231            Loop(loop_node) => Loop(loop_node.remap_children(remapping)),
232            Call(call_node) => Call(call_node.remap_children(remapping)),
233            Block(_) | Dyn(_) | External(_) => self.clone(),
234        }
235    }
236
237    /// Returns true if the this node has children.
238    pub fn has_children(&self) -> bool {
239        match &self {
240            MastNode::Join(_) | MastNode::Split(_) | MastNode::Loop(_) | MastNode::Call(_) => true,
241            MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => false,
242        }
243    }
244
245    /// Appends the NodeIds of the children of this node, if any, to the vector.
246    pub fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
247        match &self {
248            MastNode::Join(join_node) => {
249                target.push(join_node.first());
250                target.push(join_node.second())
251            },
252            MastNode::Split(split_node) => {
253                target.push(split_node.on_true());
254                target.push(split_node.on_false())
255            },
256            MastNode::Loop(loop_node) => target.push(loop_node.body()),
257            MastNode::Call(call_node) => target.push(call_node.callee()),
258            MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => (),
259        }
260    }
261
262    pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a {
263        match self {
264            MastNode::Block(basic_block_node) => {
265                MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest)))
266            },
267            MastNode::Join(join_node) => {
268                MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest)))
269            },
270            MastNode::Split(split_node) => {
271                MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest)))
272            },
273            MastNode::Loop(loop_node) => {
274                MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest)))
275            },
276            MastNode::Call(call_node) => {
277                MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest)))
278            },
279            MastNode::Dyn(dyn_node) => {
280                MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest)))
281            },
282            MastNode::External(external_node) => {
283                MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest)))
284            },
285        }
286    }
287
288    pub fn domain(&self) -> Felt {
289        match self {
290            MastNode::Block(_) => BasicBlockNode::DOMAIN,
291            MastNode::Join(_) => JoinNode::DOMAIN,
292            MastNode::Split(_) => SplitNode::DOMAIN,
293            MastNode::Loop(_) => LoopNode::DOMAIN,
294            MastNode::Call(call_node) => call_node.domain(),
295            MastNode::Dyn(dyn_node) => dyn_node.domain(),
296            MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
297        }
298    }
299
300    pub fn digest(&self) -> Word {
301        match self {
302            MastNode::Block(node) => node.digest(),
303            MastNode::Join(node) => node.digest(),
304            MastNode::Split(node) => node.digest(),
305            MastNode::Loop(node) => node.digest(),
306            MastNode::Call(node) => node.digest(),
307            MastNode::Dyn(node) => node.digest(),
308            MastNode::External(node) => node.digest(),
309        }
310    }
311
312    pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
313        match self {
314            MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
315            MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
316            MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
317            MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
318            MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
319            MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
320            MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
321        }
322    }
323
324    /// Returns the decorators to be executed before this node is executed.
325    pub fn before_enter(&self) -> &[DecoratorId] {
326        use MastNode::*;
327        match self {
328            Block(_) => &[],
329            Join(node) => node.before_enter(),
330            Split(node) => node.before_enter(),
331            Loop(node) => node.before_enter(),
332            Call(node) => node.before_enter(),
333            Dyn(node) => node.before_enter(),
334            External(node) => node.before_enter(),
335        }
336    }
337
338    /// Returns the decorators to be executed after this node is executed.
339    pub fn after_exit(&self) -> &[DecoratorId] {
340        use MastNode::*;
341        match self {
342            Block(_) => &[],
343            Join(node) => node.after_exit(),
344            Split(node) => node.after_exit(),
345            Loop(node) => node.after_exit(),
346            Call(node) => node.after_exit(),
347            Dyn(node) => node.after_exit(),
348            External(node) => node.after_exit(),
349        }
350    }
351}
352
353//-------------------------------------------------------------------------------------------------
354/// Mutators
355impl MastNode {
356    /// Sets the list of decorators to be executed before this node.
357    pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
358        match self {
359            MastNode::Block(node) => node.prepend_decorators(decorator_ids),
360            MastNode::Join(node) => node.append_before_enter(decorator_ids),
361            MastNode::Split(node) => node.append_before_enter(decorator_ids),
362            MastNode::Loop(node) => node.append_before_enter(decorator_ids),
363            MastNode::Call(node) => node.append_before_enter(decorator_ids),
364            MastNode::Dyn(node) => node.append_before_enter(decorator_ids),
365            MastNode::External(node) => node.append_before_enter(decorator_ids),
366        }
367    }
368
369    /// Sets the list of decorators to be executed after this node.
370    pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
371        match self {
372            MastNode::Block(node) => node.append_decorators(decorator_ids),
373            MastNode::Join(node) => node.append_after_exit(decorator_ids),
374            MastNode::Split(node) => node.append_after_exit(decorator_ids),
375            MastNode::Loop(node) => node.append_after_exit(decorator_ids),
376            MastNode::Call(node) => node.append_after_exit(decorator_ids),
377            MastNode::Dyn(node) => node.append_after_exit(decorator_ids),
378            MastNode::External(node) => node.append_after_exit(decorator_ids),
379        }
380    }
381
382    pub fn remove_decorators(&mut self) {
383        match self {
384            MastNode::Block(node) => node.remove_decorators(),
385            MastNode::Join(node) => node.remove_decorators(),
386            MastNode::Split(node) => node.remove_decorators(),
387            MastNode::Loop(node) => node.remove_decorators(),
388            MastNode::Call(node) => node.remove_decorators(),
389            MastNode::Dyn(node) => node.remove_decorators(),
390            MastNode::External(node) => node.remove_decorators(),
391        }
392    }
393}
394
395// PRETTY PRINTING
396// ================================================================================================
397
398struct MastNodePrettyPrint<'a> {
399    node_pretty_print: Box<dyn PrettyPrint + 'a>,
400}
401
402impl<'a> MastNodePrettyPrint<'a> {
403    pub fn new(node_pretty_print: Box<dyn PrettyPrint + 'a>) -> Self {
404        Self { node_pretty_print }
405    }
406}
407
408impl PrettyPrint for MastNodePrettyPrint<'_> {
409    fn render(&self) -> Document {
410        self.node_pretty_print.render()
411    }
412}
413
414struct MastNodeDisplay<'a> {
415    node_display: Box<dyn fmt::Display + 'a>,
416}
417
418impl<'a> MastNodeDisplay<'a> {
419    pub fn new(node: impl fmt::Display + 'a) -> Self {
420        Self { node_display: Box::new(node) }
421    }
422}
423
424impl fmt::Display for MastNodeDisplay<'_> {
425    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426        self.node_display.fmt(f)
427    }
428}
429
430// MAST INNER NODE EXT
431// ===============================================================================================
432
433/// A trait for extending the functionality of all [`MastNode`]s.
434pub trait MastNodeExt: Send + Sync {
435    // REQUIRED METHODS
436    // -------------------------------------------------------------------------------------------
437
438    /// The list of decorators tied to this node, along with their associated index.
439    ///
440    /// The index is only meaningful for [`BasicBlockNode`]s, where it corresponds to the index of
441    /// the operation in the basic block to which the decorator is attached.
442    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)>;
443
444    // PROVIDED METHODS
445    // -------------------------------------------------------------------------------------------
446
447    /// Returns the [`AssemblyOp`] associated with this node and operation (if provided), if any.
448    ///
449    /// If the `target_op_idx` is provided, the method treats the wrapped node as a basic block will
450    /// return the assembly op associated with the operation at the corresponding index in the basic
451    /// block. If no `target_op_idx` is provided, the method will return the first assembly op found
452    /// (effectively assuming that the node has at most one associated [`AssemblyOp`]).
453    fn get_assembly_op<'m>(
454        &self,
455        mast_forest: &'m MastForest,
456        target_op_idx: Option<usize>,
457    ) -> Option<&'m AssemblyOp> {
458        match target_op_idx {
459            // If a target operation index is provided, return the assembly op associated with that
460            // operation.
461            Some(target_op_idx) => {
462                for (op_idx, decorator_id) in self.decorators() {
463                    if let Some(Decorator::AsmOp(assembly_op)) =
464                        mast_forest.get_decorator_by_id(decorator_id)
465                    {
466                        // when an instruction compiles down to multiple operations, only the first
467                        // operation is associated with the assembly op. We need to check if the
468                        // target operation index falls within the range of operations associated
469                        // with the assembly op.
470                        if target_op_idx >= op_idx
471                            && target_op_idx < op_idx + assembly_op.num_cycles() as usize
472                        {
473                            return Some(assembly_op);
474                        }
475                    }
476                }
477            },
478            // If no target operation index is provided, return the first assembly op found.
479            None => {
480                for (_, decorator_id) in self.decorators() {
481                    if let Some(Decorator::AsmOp(assembly_op)) =
482                        mast_forest.get_decorator_by_id(decorator_id)
483                    {
484                        return Some(assembly_op);
485                    }
486                }
487            },
488        }
489
490        None
491    }
492}
493
494// HELPERS
495// ===============================================================================================
496
497/// This function is analogous to the `unwrap_failed()` function used in the implementation of
498/// `core::result::Result` `unwrap_*()` methods.
499#[cold]
500#[inline(never)]
501#[track_caller]
502fn unwrap_failed(node: &MastNode, expected: &str) -> ! {
503    let actual = match node {
504        MastNode::Block(_) => "basic block",
505        MastNode::Join(_) => "join",
506        MastNode::Split(_) => "split",
507        MastNode::Loop(_) => "loop",
508        MastNode::Call(_) => "call",
509        MastNode::Dyn(_) => "dynamic",
510        MastNode::External(_) => "external",
511    };
512    panic!("tried to unwrap {expected} node, but got {actual}");
513}