1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6 BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, GROUP_SIZE as OP_GROUP_SIZE, OpBatch,
7 OperationOrDecorator,
8};
9
10mod call_node;
11pub use call_node::CallNode;
12
13mod dyn_node;
14pub use dyn_node::DynNode;
15
16mod external;
17pub use external::ExternalNode;
18
19mod join_node;
20pub use join_node::JoinNode;
21
22mod split_node;
23use miden_crypto::{Felt, hash::rpo::RpoDigest};
24use miden_formatting::prettier::{Document, PrettyPrint};
25pub use split_node::SplitNode;
26
27mod loop_node;
28pub use loop_node::LoopNode;
29
30use super::{DecoratorId, MastForestError};
31use crate::{
32 DecoratorList, Operation,
33 mast::{MastForest, MastNodeId, Remapping},
34};
35
36#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum MastNode {
41 Block(BasicBlockNode),
42 Join(JoinNode),
43 Split(SplitNode),
44 Loop(LoopNode),
45 Call(CallNode),
46 Dyn(DynNode),
47 External(ExternalNode),
48}
49
50impl MastNode {
53 pub fn new_basic_block(
54 operations: Vec<Operation>,
55 decorators: Option<DecoratorList>,
56 ) -> Result<Self, MastForestError> {
57 let block = BasicBlockNode::new(operations, decorators)?;
58 Ok(Self::Block(block))
59 }
60
61 pub fn new_join(
62 left_child: MastNodeId,
63 right_child: MastNodeId,
64 mast_forest: &MastForest,
65 ) -> Result<Self, MastForestError> {
66 let join = JoinNode::new([left_child, right_child], mast_forest)?;
67 Ok(Self::Join(join))
68 }
69
70 pub fn new_split(
71 if_branch: MastNodeId,
72 else_branch: MastNodeId,
73 mast_forest: &MastForest,
74 ) -> Result<Self, MastForestError> {
75 let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
76 Ok(Self::Split(split))
77 }
78
79 pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
80 let loop_node = LoopNode::new(body, mast_forest)?;
81 Ok(Self::Loop(loop_node))
82 }
83
84 pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
85 let call = CallNode::new(callee, mast_forest)?;
86 Ok(Self::Call(call))
87 }
88
89 pub fn new_syscall(
90 callee: MastNodeId,
91 mast_forest: &MastForest,
92 ) -> Result<Self, MastForestError> {
93 let syscall = CallNode::new_syscall(callee, mast_forest)?;
94 Ok(Self::Call(syscall))
95 }
96
97 pub fn new_dyn() -> Self {
98 Self::Dyn(DynNode::new_dyn())
99 }
100 pub fn new_dyncall() -> Self {
101 Self::Dyn(DynNode::new_dyncall())
102 }
103
104 pub fn new_external(mast_root: RpoDigest) -> Self {
105 Self::External(ExternalNode::new(mast_root))
106 }
107
108 #[cfg(test)]
109 pub fn new_basic_block_with_raw_decorators(
110 operations: Vec<Operation>,
111 decorators: Vec<(usize, crate::Decorator)>,
112 mast_forest: &mut MastForest,
113 ) -> Result<Self, MastForestError> {
114 let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?;
115 Ok(Self::Block(block))
116 }
117}
118
119impl MastNode {
122 pub fn is_external(&self) -> bool {
124 matches!(self, MastNode::External(_))
125 }
126
127 pub fn is_dyn(&self) -> bool {
129 matches!(self, MastNode::Dyn(_))
130 }
131
132 pub fn is_basic_block(&self) -> bool {
134 matches!(self, Self::Block(_))
135 }
136
137 pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
140 match self {
141 MastNode::Block(basic_block_node) => Some(basic_block_node),
142 _ => None,
143 }
144 }
145
146 pub fn remap_children(&self, remapping: &Remapping) -> Self {
148 use MastNode::*;
149 match self {
150 Join(join_node) => Join(join_node.remap_children(remapping)),
151 Split(split_node) => Split(split_node.remap_children(remapping)),
152 Loop(loop_node) => Loop(loop_node.remap_children(remapping)),
153 Call(call_node) => Call(call_node.remap_children(remapping)),
154 Block(_) | Dyn(_) | External(_) => self.clone(),
155 }
156 }
157
158 pub fn has_children(&self) -> bool {
160 match &self {
161 MastNode::Join(_) | MastNode::Split(_) | MastNode::Loop(_) | MastNode::Call(_) => true,
162 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => false,
163 }
164 }
165
166 pub fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
168 match &self {
169 MastNode::Join(join_node) => {
170 target.push(join_node.first());
171 target.push(join_node.second())
172 },
173 MastNode::Split(split_node) => {
174 target.push(split_node.on_true());
175 target.push(split_node.on_false())
176 },
177 MastNode::Loop(loop_node) => target.push(loop_node.body()),
178 MastNode::Call(call_node) => target.push(call_node.callee()),
179 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => (),
180 }
181 }
182
183 pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a {
184 match self {
185 MastNode::Block(basic_block_node) => {
186 MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest)))
187 },
188 MastNode::Join(join_node) => {
189 MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest)))
190 },
191 MastNode::Split(split_node) => {
192 MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest)))
193 },
194 MastNode::Loop(loop_node) => {
195 MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest)))
196 },
197 MastNode::Call(call_node) => {
198 MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest)))
199 },
200 MastNode::Dyn(dyn_node) => {
201 MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest)))
202 },
203 MastNode::External(external_node) => {
204 MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest)))
205 },
206 }
207 }
208
209 pub fn domain(&self) -> Felt {
210 match self {
211 MastNode::Block(_) => BasicBlockNode::DOMAIN,
212 MastNode::Join(_) => JoinNode::DOMAIN,
213 MastNode::Split(_) => SplitNode::DOMAIN,
214 MastNode::Loop(_) => LoopNode::DOMAIN,
215 MastNode::Call(call_node) => call_node.domain(),
216 MastNode::Dyn(dyn_node) => dyn_node.domain(),
217 MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
218 }
219 }
220
221 pub fn digest(&self) -> RpoDigest {
222 match self {
223 MastNode::Block(node) => node.digest(),
224 MastNode::Join(node) => node.digest(),
225 MastNode::Split(node) => node.digest(),
226 MastNode::Loop(node) => node.digest(),
227 MastNode::Call(node) => node.digest(),
228 MastNode::Dyn(node) => node.digest(),
229 MastNode::External(node) => node.digest(),
230 }
231 }
232
233 pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
234 match self {
235 MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
236 MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
237 MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
238 MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
239 MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
240 MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
241 MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
242 }
243 }
244
245 pub fn before_enter(&self) -> &[DecoratorId] {
247 use MastNode::*;
248 match self {
249 Block(_) => &[],
250 Join(node) => node.before_enter(),
251 Split(node) => node.before_enter(),
252 Loop(node) => node.before_enter(),
253 Call(node) => node.before_enter(),
254 Dyn(node) => node.before_enter(),
255 External(node) => node.before_enter(),
256 }
257 }
258
259 pub fn after_exit(&self) -> &[DecoratorId] {
261 use MastNode::*;
262 match self {
263 Block(_) => &[],
264 Join(node) => node.after_exit(),
265 Split(node) => node.after_exit(),
266 Loop(node) => node.after_exit(),
267 Call(node) => node.after_exit(),
268 Dyn(node) => node.after_exit(),
269 External(node) => node.after_exit(),
270 }
271 }
272}
273
274impl MastNode {
276 pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
278 match self {
279 MastNode::Block(node) => node.prepend_decorators(decorator_ids),
280 MastNode::Join(node) => node.set_before_enter(decorator_ids),
281 MastNode::Split(node) => node.set_before_enter(decorator_ids),
282 MastNode::Loop(node) => node.set_before_enter(decorator_ids),
283 MastNode::Call(node) => node.set_before_enter(decorator_ids),
284 MastNode::Dyn(node) => node.set_before_enter(decorator_ids),
285 MastNode::External(node) => node.set_before_enter(decorator_ids),
286 }
287 }
288
289 pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
291 match self {
292 MastNode::Block(node) => node.append_decorators(decorator_ids),
293 MastNode::Join(node) => node.set_after_exit(decorator_ids),
294 MastNode::Split(node) => node.set_after_exit(decorator_ids),
295 MastNode::Loop(node) => node.set_after_exit(decorator_ids),
296 MastNode::Call(node) => node.set_after_exit(decorator_ids),
297 MastNode::Dyn(node) => node.set_after_exit(decorator_ids),
298 MastNode::External(node) => node.set_after_exit(decorator_ids),
299 }
300 }
301}
302
303struct MastNodePrettyPrint<'a> {
307 node_pretty_print: Box<dyn PrettyPrint + 'a>,
308}
309
310impl<'a> MastNodePrettyPrint<'a> {
311 pub fn new(node_pretty_print: Box<dyn PrettyPrint + 'a>) -> Self {
312 Self { node_pretty_print }
313 }
314}
315
316impl PrettyPrint for MastNodePrettyPrint<'_> {
317 fn render(&self) -> Document {
318 self.node_pretty_print.render()
319 }
320}
321
/// Type-erasing wrapper around a boxed [`fmt::Display`] value.
///
/// It gives `MastNode::to_display` one concrete return type regardless of which node variant
/// produced the displayable value.
struct MastNodeDisplay<'a> {
    /// The displayable value of the underlying node, stored as a trait object.
    node_display: Box<dyn fmt::Display + 'a>,
}

impl<'a> MastNodeDisplay<'a> {
    /// Boxes `node` and wraps it.
    pub fn new(node: impl fmt::Display + 'a) -> Self {
        let node_display: Box<dyn fmt::Display + 'a> = Box::new(node);
        MastNodeDisplay { node_display }
    }
}

impl fmt::Display for MastNodeDisplay<'_> {
    /// Formats by forwarding to the wrapped value with the same formatter, so any width,
    /// fill or precision flags requested by the caller reach the wrapped value untouched.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&*self.node_display, f)
    }
}