miden_core/mast/node/mod.rs
1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6 BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, GROUP_SIZE as OP_GROUP_SIZE, OpBatch,
7 OperationOrDecorator,
8};
9use enum_dispatch::enum_dispatch;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
13mod call_node;
14pub use call_node::CallNode;
15
16mod dyn_node;
17pub use dyn_node::DynNode;
18
19mod external;
20pub use external::ExternalNode;
21
22mod join_node;
23pub use join_node::JoinNode;
24
25mod split_node;
26use miden_crypto::{Felt, Word};
27use miden_formatting::prettier::PrettyPrint;
28pub use split_node::SplitNode;
29
30mod loop_node;
31pub use loop_node::LoopNode;
32
33use super::DecoratorId;
34use crate::{
35 AssemblyOp, Decorator,
36 mast::{MastForest, MastNodeId, Remapping},
37};
38
39#[enum_dispatch]
40pub trait MastNodeExt {
41 /// Returns a commitment/hash of the node.
42 fn digest(&self) -> Word;
43
44 /// Returns the decorators to be executed before this node is executed.
45 fn before_enter(&self) -> &[DecoratorId];
46
47 /// Returns the decorators to be executed after this node is executed.
48 fn after_exit(&self) -> &[DecoratorId];
49
50 /// Sets the list of decorators to be executed before this node.
51 fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]);
52
53 /// Sets the list of decorators to be executed after this node.
54 fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]);
55
56 /// Removes all decorators from this node.
57 fn remove_decorators(&mut self);
58
59 /// Returns a display formatter for this node.
60 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a>;
61
62 /// Returns a pretty printer for this node.
63 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a>;
64
65 /// Remap the node children to their new positions indicated by the given [`Remapping`].
66 fn remap_children(&self, remapping: &Remapping) -> Self;
67
68 /// Returns true if the this node has children.
69 fn has_children(&self) -> bool;
70
71 /// Appends the NodeIds of the children of this node, if any, to the vector.
72 fn append_children_to(&self, target: &mut Vec<MastNodeId>);
73
74 /// Returns the domain of this node.
75 fn domain(&self) -> Felt;
76}
77
78// MAST NODE
79// ================================================================================================
80
/// A node of a [`MastForest`], wrapping one of the concrete node types.
///
/// [`MastNodeExt`] is implemented for this enum via `enum_dispatch`, which forwards each call to
/// the wrapped node. Serde support is available behind the `serde` feature.
///
/// NOTE: the variant order determines the enum's discriminants (and, for non-self-describing serde
/// formats, the encoded variant index), so it should not be changed casually.
#[enum_dispatch(MastNodeExt)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MastNode {
    /// Wraps a [`BasicBlockNode`].
    Block(BasicBlockNode),
    /// Wraps a [`JoinNode`].
    Join(JoinNode),
    /// Wraps a [`SplitNode`].
    Split(SplitNode),
    /// Wraps a [`LoopNode`].
    Loop(LoopNode),
    /// Wraps a [`CallNode`].
    Call(CallNode),
    /// Wraps a [`DynNode`].
    Dyn(DynNode),
    /// Wraps an [`ExternalNode`].
    External(ExternalNode),
}
93
94// ------------------------------------------------------------------------------------------------
95/// Public accessors
96impl MastNode {
97 /// Returns true if this node is an external node.
98 pub fn is_external(&self) -> bool {
99 matches!(self, MastNode::External(_))
100 }
101
102 /// Returns true if this node is a Dyn node.
103 pub fn is_dyn(&self) -> bool {
104 matches!(self, MastNode::Dyn(_))
105 }
106
107 /// Returns true if this node is a basic block.
108 pub fn is_basic_block(&self) -> bool {
109 matches!(self, Self::Block(_))
110 }
111
112 /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
113 /// otherwise.
114 pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
115 match self {
116 MastNode::Block(basic_block_node) => Some(basic_block_node),
117 _ => None,
118 }
119 }
120
121 /// Unwraps the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; panics
122 /// otherwise.
123 ///
124 /// # Panics
125 /// Panics if the [`MastNode`] does not wrap a [`BasicBlockNode`].
126 pub fn unwrap_basic_block(&self) -> &BasicBlockNode {
127 match self {
128 Self::Block(basic_block_node) => basic_block_node,
129 other => unwrap_failed(other, "basic block"),
130 }
131 }
132
133 /// Unwraps the inner join node if the [`MastNode`] wraps a [`JoinNode`]; panics otherwise.
134 ///
135 /// # Panics
136 /// - if the [`MastNode`] does not wrap a [`JoinNode`].
137 pub fn unwrap_join(&self) -> &JoinNode {
138 match self {
139 Self::Join(join_node) => join_node,
140 other => unwrap_failed(other, "join"),
141 }
142 }
143
144 /// Unwraps the inner split node if the [`MastNode`] wraps a [`SplitNode`]; panics otherwise.
145 ///
146 /// # Panics
147 /// - if the [`MastNode`] does not wrap a [`SplitNode`].
148 pub fn unwrap_split(&self) -> &SplitNode {
149 match self {
150 Self::Split(split_node) => split_node,
151 other => unwrap_failed(other, "split"),
152 }
153 }
154
155 /// Unwraps the inner loop node if the [`MastNode`] wraps a [`LoopNode`]; panics otherwise.
156 ///
157 /// # Panics
158 /// - if the [`MastNode`] does not wrap a [`LoopNode`].
159 pub fn unwrap_loop(&self) -> &LoopNode {
160 match self {
161 Self::Loop(loop_node) => loop_node,
162 other => unwrap_failed(other, "loop"),
163 }
164 }
165
166 /// Unwraps the inner call node if the [`MastNode`] wraps a [`CallNode`]; panics otherwise.
167 ///
168 /// # Panics
169 /// - if the [`MastNode`] does not wrap a [`CallNode`].
170 pub fn unwrap_call(&self) -> &CallNode {
171 match self {
172 Self::Call(call_node) => call_node,
173 other => unwrap_failed(other, "call"),
174 }
175 }
176
177 /// Unwraps the inner dynamic node if the [`MastNode`] wraps a [`DynNode`]; panics otherwise.
178 ///
179 /// # Panics
180 /// - if the [`MastNode`] does not wrap a [`DynNode`].
181 pub fn unwrap_dyn(&self) -> &DynNode {
182 match self {
183 Self::Dyn(dyn_node) => dyn_node,
184 other => unwrap_failed(other, "dyn"),
185 }
186 }
187
188 /// Unwraps the inner external node if the [`MastNode`] wraps a [`ExternalNode`]; panics
189 /// otherwise.
190 ///
191 /// # Panics
192 /// - if the [`MastNode`] does not wrap a [`ExternalNode`].
193 pub fn unwrap_external(&self) -> &ExternalNode {
194 match self {
195 Self::External(external_node) => external_node,
196 other => unwrap_failed(other, "external"),
197 }
198 }
199}
200
// MAST NODE ERROR CONTEXT
202// ===============================================================================================
203
204/// A trait for extending the functionality of all [`MastNode`]s.
205pub trait MastNodeErrorContext: Send + Sync {
206 // REQUIRED METHODS
207 // -------------------------------------------------------------------------------------------
208
209 /// The list of decorators tied to this node, along with their associated index.
210 ///
211 /// The index is only meaningful for [`BasicBlockNode`]s, where it corresponds to the index of
212 /// the operation in the basic block to which the decorator is attached.
213 fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)>;
214
215 // PROVIDED METHODS
216 // -------------------------------------------------------------------------------------------
217
218 /// Returns the [`AssemblyOp`] associated with this node and operation (if provided), if any.
219 ///
220 /// If the `target_op_idx` is provided, the method treats the wrapped node as a basic block will
221 /// return the assembly op associated with the operation at the corresponding index in the basic
222 /// block. If no `target_op_idx` is provided, the method will return the first assembly op found
223 /// (effectively assuming that the node has at most one associated [`AssemblyOp`]).
224 fn get_assembly_op<'m>(
225 &self,
226 mast_forest: &'m MastForest,
227 target_op_idx: Option<usize>,
228 ) -> Option<&'m AssemblyOp> {
229 match target_op_idx {
230 // If a target operation index is provided, return the assembly op associated with that
231 // operation.
232 Some(target_op_idx) => {
233 for (op_idx, decorator_id) in self.decorators() {
234 if let Some(Decorator::AsmOp(assembly_op)) =
235 mast_forest.get_decorator_by_id(decorator_id)
236 {
237 // when an instruction compiles down to multiple operations, only the first
238 // operation is associated with the assembly op. We need to check if the
239 // target operation index falls within the range of operations associated
240 // with the assembly op.
241 if target_op_idx >= op_idx
242 && target_op_idx < op_idx + assembly_op.num_cycles() as usize
243 {
244 return Some(assembly_op);
245 }
246 }
247 }
248 },
249 // If no target operation index is provided, return the first assembly op found.
250 None => {
251 for (_, decorator_id) in self.decorators() {
252 if let Some(Decorator::AsmOp(assembly_op)) =
253 mast_forest.get_decorator_by_id(decorator_id)
254 {
255 return Some(assembly_op);
256 }
257 }
258 },
259 }
260
261 None
262 }
263}
264
265// HELPERS
266// ===============================================================================================
267
268/// This function is analogous to the `unwrap_failed()` function used in the implementation of
269/// `core::result::Result` `unwrap_*()` methods.
270#[cold]
271#[inline(never)]
272#[track_caller]
273fn unwrap_failed(node: &MastNode, expected: &str) -> ! {
274 let actual = match node {
275 MastNode::Block(_) => "basic block",
276 MastNode::Join(_) => "join",
277 MastNode::Split(_) => "split",
278 MastNode::Loop(_) => "loop",
279 MastNode::Call(_) => "call",
280 MastNode::Dyn(_) => "dynamic",
281 MastNode::External(_) => "external",
282 };
283 panic!("tried to unwrap {expected} node, but got {actual}");
284}