miden_core/mast/node/mod.rs
1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6 BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, DecoratorOpLinkIterator,
7 GROUP_SIZE as OP_GROUP_SIZE, OpBatch, OperationOrDecorator,
8};
9use enum_dispatch::enum_dispatch;
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12
13mod call_node;
14pub use call_node::CallNode;
15
16mod dyn_node;
17pub use dyn_node::DynNode;
18
19mod external;
20pub use external::ExternalNode;
21
22mod join_node;
23pub use join_node::JoinNode;
24
25mod split_node;
26use miden_crypto::{Felt, Word};
27use miden_formatting::prettier::PrettyPrint;
28pub use split_node::SplitNode;
29
30mod loop_node;
31#[cfg(any(test, feature = "arbitrary"))]
32pub use basic_block_node::arbitrary;
33pub use loop_node::LoopNode;
34
35use super::DecoratorId;
36use crate::{
37 AssemblyOp, Decorator,
38 mast::{MastForest, MastNodeId, Remapping},
39};
40
41#[enum_dispatch]
42pub trait MastNodeExt {
43 /// Returns a commitment/hash of the node.
44 fn digest(&self) -> Word;
45
46 /// Returns the decorators to be executed before this node is executed.
47 fn before_enter(&self) -> &[DecoratorId];
48
49 /// Returns the decorators to be executed after this node is executed.
50 fn after_exit(&self) -> &[DecoratorId];
51
52 /// Sets the list of decorators to be executed before this node.
53 fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]);
54
55 /// Sets the list of decorators to be executed after this node.
56 fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]);
57
58 /// Removes all decorators from this node.
59 fn remove_decorators(&mut self);
60
61 /// Returns a display formatter for this node.
62 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a>;
63
64 /// Returns a pretty printer for this node.
65 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a>;
66
67 /// Remap the node children to their new positions indicated by the given [`Remapping`].
68 fn remap_children(&self, remapping: &Remapping) -> Self;
69
70 /// Returns true if the this node has children.
71 fn has_children(&self) -> bool;
72
73 /// Appends the NodeIds of the children of this node, if any, to the vector.
74 fn append_children_to(&self, target: &mut Vec<MastNodeId>);
75
76 /// Executes the given closure for each child of this node.
77 fn for_each_child<F>(&self, f: F)
78 where
79 F: FnMut(MastNodeId);
80
81 /// Returns the domain of this node.
82 fn domain(&self) -> Felt;
83}
84
85// MAST NODE
86// ================================================================================================
87
88#[enum_dispatch(MastNodeExt)]
89#[derive(Debug, Clone, PartialEq, Eq)]
90#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
91pub enum MastNode {
92 Block(BasicBlockNode),
93 Join(JoinNode),
94 Split(SplitNode),
95 Loop(LoopNode),
96 Call(CallNode),
97 Dyn(DynNode),
98 External(ExternalNode),
99}
100
101// ------------------------------------------------------------------------------------------------
102/// Public accessors
103impl MastNode {
104 /// Returns true if this node is an external node.
105 pub fn is_external(&self) -> bool {
106 matches!(self, MastNode::External(_))
107 }
108
109 /// Returns true if this node is a Dyn node.
110 pub fn is_dyn(&self) -> bool {
111 matches!(self, MastNode::Dyn(_))
112 }
113
114 /// Returns true if this node is a basic block.
115 pub fn is_basic_block(&self) -> bool {
116 matches!(self, Self::Block(_))
117 }
118
119 /// Returns the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; `None`
120 /// otherwise.
121 pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
122 match self {
123 MastNode::Block(basic_block_node) => Some(basic_block_node),
124 _ => None,
125 }
126 }
127
128 /// Unwraps the inner basic block node if the [`MastNode`] wraps a [`BasicBlockNode`]; panics
129 /// otherwise.
130 ///
131 /// # Panics
132 /// Panics if the [`MastNode`] does not wrap a [`BasicBlockNode`].
133 pub fn unwrap_basic_block(&self) -> &BasicBlockNode {
134 match self {
135 Self::Block(basic_block_node) => basic_block_node,
136 other => unwrap_failed(other, "basic block"),
137 }
138 }
139
140 /// Unwraps the inner join node if the [`MastNode`] wraps a [`JoinNode`]; panics otherwise.
141 ///
142 /// # Panics
143 /// - if the [`MastNode`] does not wrap a [`JoinNode`].
144 pub fn unwrap_join(&self) -> &JoinNode {
145 match self {
146 Self::Join(join_node) => join_node,
147 other => unwrap_failed(other, "join"),
148 }
149 }
150
151 /// Unwraps the inner split node if the [`MastNode`] wraps a [`SplitNode`]; panics otherwise.
152 ///
153 /// # Panics
154 /// - if the [`MastNode`] does not wrap a [`SplitNode`].
155 pub fn unwrap_split(&self) -> &SplitNode {
156 match self {
157 Self::Split(split_node) => split_node,
158 other => unwrap_failed(other, "split"),
159 }
160 }
161
162 /// Unwraps the inner loop node if the [`MastNode`] wraps a [`LoopNode`]; panics otherwise.
163 ///
164 /// # Panics
165 /// - if the [`MastNode`] does not wrap a [`LoopNode`].
166 pub fn unwrap_loop(&self) -> &LoopNode {
167 match self {
168 Self::Loop(loop_node) => loop_node,
169 other => unwrap_failed(other, "loop"),
170 }
171 }
172
173 /// Unwraps the inner call node if the [`MastNode`] wraps a [`CallNode`]; panics otherwise.
174 ///
175 /// # Panics
176 /// - if the [`MastNode`] does not wrap a [`CallNode`].
177 pub fn unwrap_call(&self) -> &CallNode {
178 match self {
179 Self::Call(call_node) => call_node,
180 other => unwrap_failed(other, "call"),
181 }
182 }
183
184 /// Unwraps the inner dynamic node if the [`MastNode`] wraps a [`DynNode`]; panics otherwise.
185 ///
186 /// # Panics
187 /// - if the [`MastNode`] does not wrap a [`DynNode`].
188 pub fn unwrap_dyn(&self) -> &DynNode {
189 match self {
190 Self::Dyn(dyn_node) => dyn_node,
191 other => unwrap_failed(other, "dyn"),
192 }
193 }
194
195 /// Unwraps the inner external node if the [`MastNode`] wraps a [`ExternalNode`]; panics
196 /// otherwise.
197 ///
198 /// # Panics
199 /// - if the [`MastNode`] does not wrap a [`ExternalNode`].
200 pub fn unwrap_external(&self) -> &ExternalNode {
201 match self {
202 Self::External(external_node) => external_node,
203 other => unwrap_failed(other, "external"),
204 }
205 }
206}
207
// MAST NODE ERROR CONTEXT
209// ===============================================================================================
210
211/// A trait for extending the functionality of all [`MastNode`]s.
212pub trait MastNodeErrorContext: Send + Sync {
213 // REQUIRED METHODS
214 // -------------------------------------------------------------------------------------------
215
216 /// The list of decorators tied to this node, along with their associated index.
217 ///
218 /// The index is only meaningful for [`BasicBlockNode`]s, where it corresponds to the index of
219 /// the operation in the basic block to which the decorator is attached.
220 fn decorators(&self) -> impl Iterator<Item = DecoratedOpLink>;
221
222 // PROVIDED METHODS
223 // -------------------------------------------------------------------------------------------
224
225 /// Returns the [`AssemblyOp`] associated with this node and operation (if provided), if any.
226 ///
227 /// If the `target_op_idx` is provided, the method treats the wrapped node as a basic block will
228 /// return the assembly op associated with the operation at the corresponding index in the basic
229 /// block. If no `target_op_idx` is provided, the method will return the first assembly op found
230 /// (effectively assuming that the node has at most one associated [`AssemblyOp`]).
231 fn get_assembly_op<'m>(
232 &self,
233 mast_forest: &'m MastForest,
234 target_op_idx: Option<usize>,
235 ) -> Option<&'m AssemblyOp> {
236 match target_op_idx {
237 // If a target operation index is provided, return the assembly op associated with that
238 // operation.
239 Some(target_op_idx) => {
240 for (op_idx, decorator_id) in self.decorators() {
241 if let Some(Decorator::AsmOp(assembly_op)) =
242 mast_forest.get_decorator_by_id(decorator_id)
243 {
244 // when an instruction compiles down to multiple operations, only the first
245 // operation is associated with the assembly op. We need to check if the
246 // target operation index falls within the range of operations associated
247 // with the assembly op.
248 if target_op_idx >= op_idx
249 && target_op_idx < op_idx + assembly_op.num_cycles() as usize
250 {
251 return Some(assembly_op);
252 }
253 }
254 }
255 },
256 // If no target operation index is provided, return the first assembly op found.
257 None => {
258 for (_, decorator_id) in self.decorators() {
259 if let Some(Decorator::AsmOp(assembly_op)) =
260 mast_forest.get_decorator_by_id(decorator_id)
261 {
262 return Some(assembly_op);
263 }
264 }
265 },
266 }
267
268 None
269 }
270}
271
272// Links an operation index in a block to a decoratorid, to be executed right before this
273// operation's position
274pub type DecoratedOpLink = (usize, DecoratorId);
275
276// HELPERS
277// ===============================================================================================
278
279/// This function is analogous to the `unwrap_failed()` function used in the implementation of
280/// `core::result::Result` `unwrap_*()` methods.
281#[cold]
282#[inline(never)]
283#[track_caller]
284fn unwrap_failed(node: &MastNode, expected: &str) -> ! {
285 let actual = match node {
286 MastNode::Block(_) => "basic block",
287 MastNode::Join(_) => "join",
288 MastNode::Split(_) => "split",
289 MastNode::Loop(_) => "loop",
290 MastNode::Call(_) => "call",
291 MastNode::Dyn(_) => "dynamic",
292 MastNode::External(_) => "external",
293 };
294 panic!("tried to unwrap {expected} node, but got {actual}");
295}