miden_core/mast/node/
mod.rs

1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6    BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, GROUP_SIZE as OP_GROUP_SIZE, OpBatch,
7    OperationOrDecorator,
8};
9
10mod call_node;
11pub use call_node::CallNode;
12
13mod dyn_node;
14pub use dyn_node::DynNode;
15
16mod external;
17pub use external::ExternalNode;
18
19mod join_node;
20pub use join_node::JoinNode;
21
22mod split_node;
23use miden_crypto::{Felt, hash::rpo::RpoDigest};
24use miden_formatting::prettier::{Document, PrettyPrint};
25pub use split_node::SplitNode;
26
27mod loop_node;
28pub use loop_node::LoopNode;
29
30use super::{DecoratorId, MastForestError};
31use crate::{
32    DecoratorList, Operation,
33    mast::{MastForest, MastNodeId, Remapping},
34};
35
36// MAST NODE
37// ================================================================================================
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum MastNode {
41    Block(BasicBlockNode),
42    Join(JoinNode),
43    Split(SplitNode),
44    Loop(LoopNode),
45    Call(CallNode),
46    Dyn(DynNode),
47    External(ExternalNode),
48}
49
50// ------------------------------------------------------------------------------------------------
51/// Constructors
52impl MastNode {
53    pub fn new_basic_block(
54        operations: Vec<Operation>,
55        decorators: Option<DecoratorList>,
56    ) -> Result<Self, MastForestError> {
57        let block = BasicBlockNode::new(operations, decorators)?;
58        Ok(Self::Block(block))
59    }
60
61    pub fn new_join(
62        left_child: MastNodeId,
63        right_child: MastNodeId,
64        mast_forest: &MastForest,
65    ) -> Result<Self, MastForestError> {
66        let join = JoinNode::new([left_child, right_child], mast_forest)?;
67        Ok(Self::Join(join))
68    }
69
70    pub fn new_split(
71        if_branch: MastNodeId,
72        else_branch: MastNodeId,
73        mast_forest: &MastForest,
74    ) -> Result<Self, MastForestError> {
75        let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
76        Ok(Self::Split(split))
77    }
78
79    pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
80        let loop_node = LoopNode::new(body, mast_forest)?;
81        Ok(Self::Loop(loop_node))
82    }
83
84    pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
85        let call = CallNode::new(callee, mast_forest)?;
86        Ok(Self::Call(call))
87    }
88
89    pub fn new_syscall(
90        callee: MastNodeId,
91        mast_forest: &MastForest,
92    ) -> Result<Self, MastForestError> {
93        let syscall = CallNode::new_syscall(callee, mast_forest)?;
94        Ok(Self::Call(syscall))
95    }
96
97    pub fn new_dyn() -> Self {
98        Self::Dyn(DynNode::new_dyn())
99    }
100    pub fn new_dyncall() -> Self {
101        Self::Dyn(DynNode::new_dyncall())
102    }
103
104    pub fn new_external(mast_root: RpoDigest) -> Self {
105        Self::External(ExternalNode::new(mast_root))
106    }
107
108    #[cfg(test)]
109    pub fn new_basic_block_with_raw_decorators(
110        operations: Vec<Operation>,
111        decorators: Vec<(usize, crate::Decorator)>,
112        mast_forest: &mut MastForest,
113    ) -> Result<Self, MastForestError> {
114        let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?;
115        Ok(Self::Block(block))
116    }
117}
118
119// ------------------------------------------------------------------------------------------------
120/// Public accessors
121impl MastNode {
122    /// Returns true if this node is an external node.
123    pub fn is_external(&self) -> bool {
124        matches!(self, MastNode::External(_))
125    }
126
127    /// Returns true if this node is a Dyn node.
128    pub fn is_dyn(&self) -> bool {
129        matches!(self, MastNode::Dyn(_))
130    }
131
132    /// Returns true if this node is a basic block.
133    pub fn is_basic_block(&self) -> bool {
134        matches!(self, Self::Block(_))
135    }
136
137    /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
138    /// otherwise.
139    pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
140        match self {
141            MastNode::Block(basic_block_node) => Some(basic_block_node),
142            _ => None,
143        }
144    }
145
146    /// Remap the node children to their new positions indicated by the given [`Remapping`].
147    pub fn remap_children(&self, remapping: &Remapping) -> Self {
148        use MastNode::*;
149        match self {
150            Join(join_node) => Join(join_node.remap_children(remapping)),
151            Split(split_node) => Split(split_node.remap_children(remapping)),
152            Loop(loop_node) => Loop(loop_node.remap_children(remapping)),
153            Call(call_node) => Call(call_node.remap_children(remapping)),
154            Block(_) | Dyn(_) | External(_) => self.clone(),
155        }
156    }
157
158    /// Returns true if the this node has children.
159    pub fn has_children(&self) -> bool {
160        match &self {
161            MastNode::Join(_) | MastNode::Split(_) | MastNode::Loop(_) | MastNode::Call(_) => true,
162            MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => false,
163        }
164    }
165
166    /// Appends the NodeIds of the children of this node, if any, to the vector.
167    pub fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
168        match &self {
169            MastNode::Join(join_node) => {
170                target.push(join_node.first());
171                target.push(join_node.second())
172            },
173            MastNode::Split(split_node) => {
174                target.push(split_node.on_true());
175                target.push(split_node.on_false())
176            },
177            MastNode::Loop(loop_node) => target.push(loop_node.body()),
178            MastNode::Call(call_node) => target.push(call_node.callee()),
179            MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => (),
180        }
181    }
182
183    pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a {
184        match self {
185            MastNode::Block(basic_block_node) => {
186                MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest)))
187            },
188            MastNode::Join(join_node) => {
189                MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest)))
190            },
191            MastNode::Split(split_node) => {
192                MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest)))
193            },
194            MastNode::Loop(loop_node) => {
195                MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest)))
196            },
197            MastNode::Call(call_node) => {
198                MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest)))
199            },
200            MastNode::Dyn(dyn_node) => {
201                MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest)))
202            },
203            MastNode::External(external_node) => {
204                MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest)))
205            },
206        }
207    }
208
209    pub fn domain(&self) -> Felt {
210        match self {
211            MastNode::Block(_) => BasicBlockNode::DOMAIN,
212            MastNode::Join(_) => JoinNode::DOMAIN,
213            MastNode::Split(_) => SplitNode::DOMAIN,
214            MastNode::Loop(_) => LoopNode::DOMAIN,
215            MastNode::Call(call_node) => call_node.domain(),
216            MastNode::Dyn(dyn_node) => dyn_node.domain(),
217            MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
218        }
219    }
220
221    pub fn digest(&self) -> RpoDigest {
222        match self {
223            MastNode::Block(node) => node.digest(),
224            MastNode::Join(node) => node.digest(),
225            MastNode::Split(node) => node.digest(),
226            MastNode::Loop(node) => node.digest(),
227            MastNode::Call(node) => node.digest(),
228            MastNode::Dyn(node) => node.digest(),
229            MastNode::External(node) => node.digest(),
230        }
231    }
232
233    pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
234        match self {
235            MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
236            MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
237            MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
238            MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
239            MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
240            MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
241            MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
242        }
243    }
244
245    /// Returns the decorators to be executed before this node is executed.
246    pub fn before_enter(&self) -> &[DecoratorId] {
247        use MastNode::*;
248        match self {
249            Block(_) => &[],
250            Join(node) => node.before_enter(),
251            Split(node) => node.before_enter(),
252            Loop(node) => node.before_enter(),
253            Call(node) => node.before_enter(),
254            Dyn(node) => node.before_enter(),
255            External(node) => node.before_enter(),
256        }
257    }
258
259    /// Returns the decorators to be executed after this node is executed.
260    pub fn after_exit(&self) -> &[DecoratorId] {
261        use MastNode::*;
262        match self {
263            Block(_) => &[],
264            Join(node) => node.after_exit(),
265            Split(node) => node.after_exit(),
266            Loop(node) => node.after_exit(),
267            Call(node) => node.after_exit(),
268            Dyn(node) => node.after_exit(),
269            External(node) => node.after_exit(),
270        }
271    }
272}
273
274/// Mutators
275impl MastNode {
276    /// Sets the list of decorators to be executed before this node.
277    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
278        match self {
279            MastNode::Block(node) => node.prepend_decorators(decorator_ids),
280            MastNode::Join(node) => node.set_before_enter(decorator_ids),
281            MastNode::Split(node) => node.set_before_enter(decorator_ids),
282            MastNode::Loop(node) => node.set_before_enter(decorator_ids),
283            MastNode::Call(node) => node.set_before_enter(decorator_ids),
284            MastNode::Dyn(node) => node.set_before_enter(decorator_ids),
285            MastNode::External(node) => node.set_before_enter(decorator_ids),
286        }
287    }
288
289    /// Sets the list of decorators to be executed after this node.
290    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
291        match self {
292            MastNode::Block(node) => node.append_decorators(decorator_ids),
293            MastNode::Join(node) => node.set_after_exit(decorator_ids),
294            MastNode::Split(node) => node.set_after_exit(decorator_ids),
295            MastNode::Loop(node) => node.set_after_exit(decorator_ids),
296            MastNode::Call(node) => node.set_after_exit(decorator_ids),
297            MastNode::Dyn(node) => node.set_after_exit(decorator_ids),
298            MastNode::External(node) => node.set_after_exit(decorator_ids),
299        }
300    }
301}
302
303// PRETTY PRINTING
304// ================================================================================================
305
306struct MastNodePrettyPrint<'a> {
307    node_pretty_print: Box<dyn PrettyPrint + 'a>,
308}
309
310impl<'a> MastNodePrettyPrint<'a> {
311    pub fn new(node_pretty_print: Box<dyn PrettyPrint + 'a>) -> Self {
312        Self { node_pretty_print }
313    }
314}
315
316impl PrettyPrint for MastNodePrettyPrint<'_> {
317    fn render(&self) -> Document {
318        self.node_pretty_print.render()
319    }
320}
321
/// Type-erased wrapper that lets [`MastNode::to_display`] return a single concrete type for
/// every node variant.
struct MastNodeDisplay<'a> {
    /// The variant-specific displayable value that formatting is forwarded to.
    node_display: Box<dyn fmt::Display + 'a>,
}

impl<'a> MastNodeDisplay<'a> {
    /// Boxes `node` and wraps it.
    pub fn new(node: impl fmt::Display + 'a) -> Self {
        let node_display: Box<dyn fmt::Display + 'a> = Box::new(node);
        MastNodeDisplay { node_display }
    }
}

impl fmt::Display for MastNodeDisplay<'_> {
    /// Formats by handing the same formatter (flags, width, etc.) to the wrapped value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&*self.node_display, f)
    }
}
336}