1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6 BATCH_SIZE as OP_BATCH_SIZE, BasicBlockNode, GROUP_SIZE as OP_GROUP_SIZE, OpBatch,
7 OperationOrDecorator,
8};
9
10mod call_node;
11pub use call_node::CallNode;
12
13mod dyn_node;
14pub use dyn_node::DynNode;
15
16mod external;
17pub use external::ExternalNode;
18
19mod join_node;
20pub use join_node::JoinNode;
21
22mod split_node;
23use miden_crypto::{Felt, Word};
24use miden_formatting::prettier::{Document, PrettyPrint};
25pub use split_node::SplitNode;
26
27mod loop_node;
28pub use loop_node::LoopNode;
29
30use super::{DecoratorId, MastForestError};
31use crate::{
32 AssemblyOp, Decorator, DecoratorList, Operation,
33 mast::{MastForest, MastNodeId, Remapping},
34};
35
36#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum MastNode {
41 Block(BasicBlockNode),
42 Join(JoinNode),
43 Split(SplitNode),
44 Loop(LoopNode),
45 Call(CallNode),
46 Dyn(DynNode),
47 External(ExternalNode),
48}
49
50impl MastNode {
53 pub fn new_basic_block(
54 operations: Vec<Operation>,
55 decorators: Option<DecoratorList>,
56 ) -> Result<Self, MastForestError> {
57 let block = BasicBlockNode::new(operations, decorators)?;
58 Ok(Self::Block(block))
59 }
60
61 pub fn new_join(
62 left_child: MastNodeId,
63 right_child: MastNodeId,
64 mast_forest: &MastForest,
65 ) -> Result<Self, MastForestError> {
66 let join = JoinNode::new([left_child, right_child], mast_forest)?;
67 Ok(Self::Join(join))
68 }
69
70 pub fn new_split(
71 if_branch: MastNodeId,
72 else_branch: MastNodeId,
73 mast_forest: &MastForest,
74 ) -> Result<Self, MastForestError> {
75 let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
76 Ok(Self::Split(split))
77 }
78
79 pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
80 let loop_node = LoopNode::new(body, mast_forest)?;
81 Ok(Self::Loop(loop_node))
82 }
83
84 pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
85 let call = CallNode::new(callee, mast_forest)?;
86 Ok(Self::Call(call))
87 }
88
89 pub fn new_syscall(
90 callee: MastNodeId,
91 mast_forest: &MastForest,
92 ) -> Result<Self, MastForestError> {
93 let syscall = CallNode::new_syscall(callee, mast_forest)?;
94 Ok(Self::Call(syscall))
95 }
96
97 pub fn new_dyn() -> Self {
98 Self::Dyn(DynNode::new_dyn())
99 }
100 pub fn new_dyncall() -> Self {
101 Self::Dyn(DynNode::new_dyncall())
102 }
103
104 pub fn new_external(mast_root: Word) -> Self {
105 Self::External(ExternalNode::new(mast_root))
106 }
107
108 #[cfg(test)]
109 pub fn new_basic_block_with_raw_decorators(
110 operations: Vec<Operation>,
111 decorators: Vec<(usize, crate::Decorator)>,
112 mast_forest: &mut MastForest,
113 ) -> Result<Self, MastForestError> {
114 let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?;
115 Ok(Self::Block(block))
116 }
117}
118
119impl MastNode {
122 pub fn is_external(&self) -> bool {
124 matches!(self, MastNode::External(_))
125 }
126
127 pub fn is_dyn(&self) -> bool {
129 matches!(self, MastNode::Dyn(_))
130 }
131
132 pub fn is_basic_block(&self) -> bool {
134 matches!(self, Self::Block(_))
135 }
136
137 pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
140 match self {
141 MastNode::Block(basic_block_node) => Some(basic_block_node),
142 _ => None,
143 }
144 }
145
146 pub fn unwrap_basic_block(&self) -> &BasicBlockNode {
152 match self {
153 Self::Block(basic_block_node) => basic_block_node,
154 other => unwrap_failed(other, "basic block"),
155 }
156 }
157
158 pub fn unwrap_join(&self) -> &JoinNode {
163 match self {
164 Self::Join(join_node) => join_node,
165 other => unwrap_failed(other, "join"),
166 }
167 }
168
169 pub fn unwrap_split(&self) -> &SplitNode {
174 match self {
175 Self::Split(split_node) => split_node,
176 other => unwrap_failed(other, "split"),
177 }
178 }
179
180 pub fn unwrap_loop(&self) -> &LoopNode {
185 match self {
186 Self::Loop(loop_node) => loop_node,
187 other => unwrap_failed(other, "loop"),
188 }
189 }
190
191 pub fn unwrap_call(&self) -> &CallNode {
196 match self {
197 Self::Call(call_node) => call_node,
198 other => unwrap_failed(other, "call"),
199 }
200 }
201
202 pub fn unwrap_dyn(&self) -> &DynNode {
207 match self {
208 Self::Dyn(dyn_node) => dyn_node,
209 other => unwrap_failed(other, "dyn"),
210 }
211 }
212
213 pub fn unwrap_external(&self) -> &ExternalNode {
219 match self {
220 Self::External(external_node) => external_node,
221 other => unwrap_failed(other, "external"),
222 }
223 }
224
225 pub fn remap_children(&self, remapping: &Remapping) -> Self {
227 use MastNode::*;
228 match self {
229 Join(join_node) => Join(join_node.remap_children(remapping)),
230 Split(split_node) => Split(split_node.remap_children(remapping)),
231 Loop(loop_node) => Loop(loop_node.remap_children(remapping)),
232 Call(call_node) => Call(call_node.remap_children(remapping)),
233 Block(_) | Dyn(_) | External(_) => self.clone(),
234 }
235 }
236
237 pub fn has_children(&self) -> bool {
239 match &self {
240 MastNode::Join(_) | MastNode::Split(_) | MastNode::Loop(_) | MastNode::Call(_) => true,
241 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => false,
242 }
243 }
244
245 pub fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
247 match &self {
248 MastNode::Join(join_node) => {
249 target.push(join_node.first());
250 target.push(join_node.second())
251 },
252 MastNode::Split(split_node) => {
253 target.push(split_node.on_true());
254 target.push(split_node.on_false())
255 },
256 MastNode::Loop(loop_node) => target.push(loop_node.body()),
257 MastNode::Call(call_node) => target.push(call_node.callee()),
258 MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => (),
259 }
260 }
261
262 pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a {
263 match self {
264 MastNode::Block(basic_block_node) => {
265 MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest)))
266 },
267 MastNode::Join(join_node) => {
268 MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest)))
269 },
270 MastNode::Split(split_node) => {
271 MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest)))
272 },
273 MastNode::Loop(loop_node) => {
274 MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest)))
275 },
276 MastNode::Call(call_node) => {
277 MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest)))
278 },
279 MastNode::Dyn(dyn_node) => {
280 MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest)))
281 },
282 MastNode::External(external_node) => {
283 MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest)))
284 },
285 }
286 }
287
288 pub fn domain(&self) -> Felt {
289 match self {
290 MastNode::Block(_) => BasicBlockNode::DOMAIN,
291 MastNode::Join(_) => JoinNode::DOMAIN,
292 MastNode::Split(_) => SplitNode::DOMAIN,
293 MastNode::Loop(_) => LoopNode::DOMAIN,
294 MastNode::Call(call_node) => call_node.domain(),
295 MastNode::Dyn(dyn_node) => dyn_node.domain(),
296 MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
297 }
298 }
299
300 pub fn digest(&self) -> Word {
301 match self {
302 MastNode::Block(node) => node.digest(),
303 MastNode::Join(node) => node.digest(),
304 MastNode::Split(node) => node.digest(),
305 MastNode::Loop(node) => node.digest(),
306 MastNode::Call(node) => node.digest(),
307 MastNode::Dyn(node) => node.digest(),
308 MastNode::External(node) => node.digest(),
309 }
310 }
311
312 pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
313 match self {
314 MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
315 MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
316 MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
317 MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
318 MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
319 MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
320 MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
321 }
322 }
323
324 pub fn before_enter(&self) -> &[DecoratorId] {
326 use MastNode::*;
327 match self {
328 Block(_) => &[],
329 Join(node) => node.before_enter(),
330 Split(node) => node.before_enter(),
331 Loop(node) => node.before_enter(),
332 Call(node) => node.before_enter(),
333 Dyn(node) => node.before_enter(),
334 External(node) => node.before_enter(),
335 }
336 }
337
338 pub fn after_exit(&self) -> &[DecoratorId] {
340 use MastNode::*;
341 match self {
342 Block(_) => &[],
343 Join(node) => node.after_exit(),
344 Split(node) => node.after_exit(),
345 Loop(node) => node.after_exit(),
346 Call(node) => node.after_exit(),
347 Dyn(node) => node.after_exit(),
348 External(node) => node.after_exit(),
349 }
350 }
351}
352
353impl MastNode {
356 pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
358 match self {
359 MastNode::Block(node) => node.prepend_decorators(decorator_ids),
360 MastNode::Join(node) => node.append_before_enter(decorator_ids),
361 MastNode::Split(node) => node.append_before_enter(decorator_ids),
362 MastNode::Loop(node) => node.append_before_enter(decorator_ids),
363 MastNode::Call(node) => node.append_before_enter(decorator_ids),
364 MastNode::Dyn(node) => node.append_before_enter(decorator_ids),
365 MastNode::External(node) => node.append_before_enter(decorator_ids),
366 }
367 }
368
369 pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
371 match self {
372 MastNode::Block(node) => node.append_decorators(decorator_ids),
373 MastNode::Join(node) => node.append_after_exit(decorator_ids),
374 MastNode::Split(node) => node.append_after_exit(decorator_ids),
375 MastNode::Loop(node) => node.append_after_exit(decorator_ids),
376 MastNode::Call(node) => node.append_after_exit(decorator_ids),
377 MastNode::Dyn(node) => node.append_after_exit(decorator_ids),
378 MastNode::External(node) => node.append_after_exit(decorator_ids),
379 }
380 }
381
382 pub fn remove_decorators(&mut self) {
383 match self {
384 MastNode::Block(node) => node.remove_decorators(),
385 MastNode::Join(node) => node.remove_decorators(),
386 MastNode::Split(node) => node.remove_decorators(),
387 MastNode::Loop(node) => node.remove_decorators(),
388 MastNode::Call(node) => node.remove_decorators(),
389 MastNode::Dyn(node) => node.remove_decorators(),
390 MastNode::External(node) => node.remove_decorators(),
391 }
392 }
393}
394
395struct MastNodePrettyPrint<'a> {
399 node_pretty_print: Box<dyn PrettyPrint + 'a>,
400}
401
402impl<'a> MastNodePrettyPrint<'a> {
403 pub fn new(node_pretty_print: Box<dyn PrettyPrint + 'a>) -> Self {
404 Self { node_pretty_print }
405 }
406}
407
408impl PrettyPrint for MastNodePrettyPrint<'_> {
409 fn render(&self) -> Document {
410 self.node_pretty_print.render()
411 }
412}
413
/// Type-erased wrapper around a node-specific [`fmt::Display`] implementation.
///
/// Lets `MastNode::to_display` return a single concrete `impl fmt::Display` type regardless of
/// which variant produced the displayable value.
struct MastNodeDisplay<'a> {
    node_display: Box<dyn fmt::Display + 'a>,
}

impl<'a> MastNodeDisplay<'a> {
    /// Boxes `node` behind a `dyn fmt::Display` trait object.
    pub fn new(node: impl fmt::Display + 'a) -> Self {
        let node_display: Box<dyn fmt::Display + 'a> = Box::new(node);
        Self { node_display }
    }
}

impl fmt::Display for MastNodeDisplay<'_> {
    /// Forwards to the wrapped value's `Display` impl with the same formatter, so width,
    /// precision and other flags are preserved.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&*self.node_display, f)
    }
}
429
430pub trait MastNodeExt: Send + Sync {
435 fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)>;
443
444 fn get_assembly_op<'m>(
454 &self,
455 mast_forest: &'m MastForest,
456 target_op_idx: Option<usize>,
457 ) -> Option<&'m AssemblyOp> {
458 match target_op_idx {
459 Some(target_op_idx) => {
462 for (op_idx, decorator_id) in self.decorators() {
463 if let Some(Decorator::AsmOp(assembly_op)) =
464 mast_forest.get_decorator_by_id(decorator_id)
465 {
466 if target_op_idx >= op_idx
471 && target_op_idx < op_idx + assembly_op.num_cycles() as usize
472 {
473 return Some(assembly_op);
474 }
475 }
476 }
477 },
478 None => {
480 for (_, decorator_id) in self.decorators() {
481 if let Some(Decorator::AsmOp(assembly_op)) =
482 mast_forest.get_decorator_by_id(decorator_id)
483 {
484 return Some(assembly_op);
485 }
486 }
487 },
488 }
489
490 None
491 }
492}
493
494#[cold]
500#[inline(never)]
501#[track_caller]
502fn unwrap_failed(node: &MastNode, expected: &str) -> ! {
503 let actual = match node {
504 MastNode::Block(_) => "basic block",
505 MastNode::Join(_) => "join",
506 MastNode::Split(_) => "split",
507 MastNode::Loop(_) => "loop",
508 MastNode::Call(_) => "call",
509 MastNode::Dyn(_) => "dynamic",
510 MastNode::External(_) => "external",
511 };
512 panic!("tried to unwrap {expected} node, but got {actual}");
513}