miden_core/mast/node/mod.rs

1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6    BasicBlockNode, OpBatch, OperationOrDecorator, BATCH_SIZE as OP_BATCH_SIZE,
7    GROUP_SIZE as OP_GROUP_SIZE,
8};
9
10mod call_node;
11pub use call_node::CallNode;
12
13mod dyn_node;
14pub use dyn_node::DynNode;
15
16mod external;
17pub use external::ExternalNode;
18
19mod join_node;
20pub use join_node::JoinNode;
21
22mod split_node;
23use miden_crypto::{hash::rpo::RpoDigest, Felt};
24use miden_formatting::prettier::{Document, PrettyPrint};
25pub use split_node::SplitNode;
26
27mod loop_node;
28pub use loop_node::LoopNode;
29
30use super::{DecoratorId, MastForestError};
31use crate::{
32    mast::{MastForest, MastNodeId},
33    DecoratorList, Operation,
34};
35
36// MAST NODE
37// ================================================================================================
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum MastNode {
41    Block(BasicBlockNode),
42    Join(JoinNode),
43    Split(SplitNode),
44    Loop(LoopNode),
45    Call(CallNode),
46    Dyn(DynNode),
47    External(ExternalNode),
48}
49
50// ------------------------------------------------------------------------------------------------
51/// Constructors
52impl MastNode {
53    pub fn new_basic_block(
54        operations: Vec<Operation>,
55        decorators: Option<DecoratorList>,
56    ) -> Result<Self, MastForestError> {
57        let block = BasicBlockNode::new(operations, decorators)?;
58        Ok(Self::Block(block))
59    }
60
61    pub fn new_join(
62        left_child: MastNodeId,
63        right_child: MastNodeId,
64        mast_forest: &MastForest,
65    ) -> Result<Self, MastForestError> {
66        let join = JoinNode::new([left_child, right_child], mast_forest)?;
67        Ok(Self::Join(join))
68    }
69
70    pub fn new_split(
71        if_branch: MastNodeId,
72        else_branch: MastNodeId,
73        mast_forest: &MastForest,
74    ) -> Result<Self, MastForestError> {
75        let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
76        Ok(Self::Split(split))
77    }
78
79    pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
80        let loop_node = LoopNode::new(body, mast_forest)?;
81        Ok(Self::Loop(loop_node))
82    }
83
84    pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
85        let call = CallNode::new(callee, mast_forest)?;
86        Ok(Self::Call(call))
87    }
88
89    pub fn new_syscall(
90        callee: MastNodeId,
91        mast_forest: &MastForest,
92    ) -> Result<Self, MastForestError> {
93        let syscall = CallNode::new_syscall(callee, mast_forest)?;
94        Ok(Self::Call(syscall))
95    }
96
97    pub fn new_dyn() -> Self {
98        Self::Dyn(DynNode::new_dyn())
99    }
100    pub fn new_dyncall() -> Self {
101        Self::Dyn(DynNode::new_dyncall())
102    }
103
104    pub fn new_external(mast_root: RpoDigest) -> Self {
105        Self::External(ExternalNode::new(mast_root))
106    }
107
108    #[cfg(test)]
109    pub fn new_basic_block_with_raw_decorators(
110        operations: Vec<Operation>,
111        decorators: Vec<(usize, crate::Decorator)>,
112        mast_forest: &mut MastForest,
113    ) -> Result<Self, MastForestError> {
114        let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?;
115        Ok(Self::Block(block))
116    }
117}
118
119// ------------------------------------------------------------------------------------------------
120/// Public accessors
121impl MastNode {
122    /// Returns true if this node is an external node.
123    pub fn is_external(&self) -> bool {
124        matches!(self, MastNode::External(_))
125    }
126
127    /// Returns true if this node is a Dyn node.
128    pub fn is_dyn(&self) -> bool {
129        matches!(self, MastNode::Dyn(_))
130    }
131
132    /// Returns true if this node is a basic block.
133    pub fn is_basic_block(&self) -> bool {
134        matches!(self, Self::Block(_))
135    }
136
137    /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
138    /// otherwise.
139    pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
140        match self {
141            MastNode::Block(basic_block_node) => Some(basic_block_node),
142            _ => None,
143        }
144    }
145
146    pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a {
147        match self {
148            MastNode::Block(basic_block_node) => {
149                MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest)))
150            },
151            MastNode::Join(join_node) => {
152                MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest)))
153            },
154            MastNode::Split(split_node) => {
155                MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest)))
156            },
157            MastNode::Loop(loop_node) => {
158                MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest)))
159            },
160            MastNode::Call(call_node) => {
161                MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest)))
162            },
163            MastNode::Dyn(dyn_node) => {
164                MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest)))
165            },
166            MastNode::External(external_node) => {
167                MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest)))
168            },
169        }
170    }
171
172    pub fn domain(&self) -> Felt {
173        match self {
174            MastNode::Block(_) => BasicBlockNode::DOMAIN,
175            MastNode::Join(_) => JoinNode::DOMAIN,
176            MastNode::Split(_) => SplitNode::DOMAIN,
177            MastNode::Loop(_) => LoopNode::DOMAIN,
178            MastNode::Call(call_node) => call_node.domain(),
179            MastNode::Dyn(dyn_node) => dyn_node.domain(),
180            MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
181        }
182    }
183
184    pub fn digest(&self) -> RpoDigest {
185        match self {
186            MastNode::Block(node) => node.digest(),
187            MastNode::Join(node) => node.digest(),
188            MastNode::Split(node) => node.digest(),
189            MastNode::Loop(node) => node.digest(),
190            MastNode::Call(node) => node.digest(),
191            MastNode::Dyn(node) => node.digest(),
192            MastNode::External(node) => node.digest(),
193        }
194    }
195
196    pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
197        match self {
198            MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
199            MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
200            MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
201            MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
202            MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
203            MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
204            MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
205        }
206    }
207
208    /// Returns the decorators to be executed before this node is executed.
209    pub fn before_enter(&self) -> &[DecoratorId] {
210        use MastNode::*;
211        match self {
212            Block(_) => &[],
213            Join(node) => node.before_enter(),
214            Split(node) => node.before_enter(),
215            Loop(node) => node.before_enter(),
216            Call(node) => node.before_enter(),
217            Dyn(node) => node.before_enter(),
218            External(node) => node.before_enter(),
219        }
220    }
221
222    /// Returns the decorators to be executed after this node is executed.
223    pub fn after_exit(&self) -> &[DecoratorId] {
224        use MastNode::*;
225        match self {
226            Block(_) => &[],
227            Join(node) => node.after_exit(),
228            Split(node) => node.after_exit(),
229            Loop(node) => node.after_exit(),
230            Call(node) => node.after_exit(),
231            Dyn(node) => node.after_exit(),
232            External(node) => node.after_exit(),
233        }
234    }
235}
236
237/// Mutators
238impl MastNode {
239    /// Sets the list of decorators to be executed before this node.
240    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
241        match self {
242            MastNode::Block(node) => node.prepend_decorators(decorator_ids),
243            MastNode::Join(node) => node.set_before_enter(decorator_ids),
244            MastNode::Split(node) => node.set_before_enter(decorator_ids),
245            MastNode::Loop(node) => node.set_before_enter(decorator_ids),
246            MastNode::Call(node) => node.set_before_enter(decorator_ids),
247            MastNode::Dyn(node) => node.set_before_enter(decorator_ids),
248            MastNode::External(node) => node.set_before_enter(decorator_ids),
249        }
250    }
251
252    /// Sets the list of decorators to be executed after this node.
253    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
254        match self {
255            MastNode::Block(node) => node.append_decorators(decorator_ids),
256            MastNode::Join(node) => node.set_after_exit(decorator_ids),
257            MastNode::Split(node) => node.set_after_exit(decorator_ids),
258            MastNode::Loop(node) => node.set_after_exit(decorator_ids),
259            MastNode::Call(node) => node.set_after_exit(decorator_ids),
260            MastNode::Dyn(node) => node.set_after_exit(decorator_ids),
261            MastNode::External(node) => node.set_after_exit(decorator_ids),
262        }
263    }
264}
265
266// PRETTY PRINTING
267// ================================================================================================
268
269struct MastNodePrettyPrint<'a> {
270    node_pretty_print: Box<dyn PrettyPrint + 'a>,
271}
272
273impl<'a> MastNodePrettyPrint<'a> {
274    pub fn new(node_pretty_print: Box<dyn PrettyPrint + 'a>) -> Self {
275        Self { node_pretty_print }
276    }
277}
278
279impl PrettyPrint for MastNodePrettyPrint<'_> {
280    fn render(&self) -> Document {
281        self.node_pretty_print.render()
282    }
283}
284
/// Type-erased [`fmt::Display`] adapter, letting `MastNode::to_display` return a single
/// concrete type for every node variant.
struct MastNodeDisplay<'a> {
    node_display: Box<dyn fmt::Display + 'a>,
}

impl<'a> MastNodeDisplay<'a> {
    /// Boxes `node` and wraps it.
    pub fn new(node: impl fmt::Display + 'a) -> Self {
        let node_display: Box<dyn fmt::Display + 'a> = Box::new(node);
        MastNodeDisplay { node_display }
    }
}

impl fmt::Display for MastNodeDisplay<'_> {
    /// Forwards to the wrapped value, passing the formatter (and its flags) through untouched.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&*self.node_display, f)
    }
}
299}