1mod basic_block_node;
2use alloc::{boxed::Box, vec::Vec};
3use core::fmt;
4
5pub use basic_block_node::{
6 BasicBlockNode, OpBatch, OperationOrDecorator, BATCH_SIZE as OP_BATCH_SIZE,
7 GROUP_SIZE as OP_GROUP_SIZE,
8};
9
10mod call_node;
11pub use call_node::CallNode;
12
13mod dyn_node;
14pub use dyn_node::DynNode;
15
16mod external;
17pub use external::ExternalNode;
18
19mod join_node;
20pub use join_node::JoinNode;
21
22mod split_node;
23use miden_crypto::{hash::rpo::RpoDigest, Felt};
24use miden_formatting::prettier::{Document, PrettyPrint};
25pub use split_node::SplitNode;
26
27mod loop_node;
28pub use loop_node::LoopNode;
29
30use super::{DecoratorId, MastForestError};
31use crate::{
32 mast::{MastForest, MastNodeId},
33 DecoratorList, Operation,
34};
35
36#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum MastNode {
41 Block(BasicBlockNode),
42 Join(JoinNode),
43 Split(SplitNode),
44 Loop(LoopNode),
45 Call(CallNode),
46 Dyn(DynNode),
47 External(ExternalNode),
48}
49
50impl MastNode {
53 pub fn new_basic_block(
54 operations: Vec<Operation>,
55 decorators: Option<DecoratorList>,
56 ) -> Result<Self, MastForestError> {
57 let block = BasicBlockNode::new(operations, decorators)?;
58 Ok(Self::Block(block))
59 }
60
61 pub fn new_join(
62 left_child: MastNodeId,
63 right_child: MastNodeId,
64 mast_forest: &MastForest,
65 ) -> Result<Self, MastForestError> {
66 let join = JoinNode::new([left_child, right_child], mast_forest)?;
67 Ok(Self::Join(join))
68 }
69
70 pub fn new_split(
71 if_branch: MastNodeId,
72 else_branch: MastNodeId,
73 mast_forest: &MastForest,
74 ) -> Result<Self, MastForestError> {
75 let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
76 Ok(Self::Split(split))
77 }
78
79 pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
80 let loop_node = LoopNode::new(body, mast_forest)?;
81 Ok(Self::Loop(loop_node))
82 }
83
84 pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
85 let call = CallNode::new(callee, mast_forest)?;
86 Ok(Self::Call(call))
87 }
88
89 pub fn new_syscall(
90 callee: MastNodeId,
91 mast_forest: &MastForest,
92 ) -> Result<Self, MastForestError> {
93 let syscall = CallNode::new_syscall(callee, mast_forest)?;
94 Ok(Self::Call(syscall))
95 }
96
97 pub fn new_dyn() -> Self {
98 Self::Dyn(DynNode::new_dyn())
99 }
100 pub fn new_dyncall() -> Self {
101 Self::Dyn(DynNode::new_dyncall())
102 }
103
104 pub fn new_external(mast_root: RpoDigest) -> Self {
105 Self::External(ExternalNode::new(mast_root))
106 }
107
108 #[cfg(test)]
109 pub fn new_basic_block_with_raw_decorators(
110 operations: Vec<Operation>,
111 decorators: Vec<(usize, crate::Decorator)>,
112 mast_forest: &mut MastForest,
113 ) -> Result<Self, MastForestError> {
114 let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?;
115 Ok(Self::Block(block))
116 }
117}
118
119impl MastNode {
122 pub fn is_external(&self) -> bool {
124 matches!(self, MastNode::External(_))
125 }
126
127 pub fn is_dyn(&self) -> bool {
129 matches!(self, MastNode::Dyn(_))
130 }
131
132 pub fn is_basic_block(&self) -> bool {
134 matches!(self, Self::Block(_))
135 }
136
137 pub fn get_basic_block(&self) -> Option<&BasicBlockNode> {
140 match self {
141 MastNode::Block(basic_block_node) => Some(basic_block_node),
142 _ => None,
143 }
144 }
145
146 pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a {
147 match self {
148 MastNode::Block(basic_block_node) => {
149 MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest)))
150 },
151 MastNode::Join(join_node) => {
152 MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest)))
153 },
154 MastNode::Split(split_node) => {
155 MastNodePrettyPrint::new(Box::new(split_node.to_pretty_print(mast_forest)))
156 },
157 MastNode::Loop(loop_node) => {
158 MastNodePrettyPrint::new(Box::new(loop_node.to_pretty_print(mast_forest)))
159 },
160 MastNode::Call(call_node) => {
161 MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest)))
162 },
163 MastNode::Dyn(dyn_node) => {
164 MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest)))
165 },
166 MastNode::External(external_node) => {
167 MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest)))
168 },
169 }
170 }
171
172 pub fn domain(&self) -> Felt {
173 match self {
174 MastNode::Block(_) => BasicBlockNode::DOMAIN,
175 MastNode::Join(_) => JoinNode::DOMAIN,
176 MastNode::Split(_) => SplitNode::DOMAIN,
177 MastNode::Loop(_) => LoopNode::DOMAIN,
178 MastNode::Call(call_node) => call_node.domain(),
179 MastNode::Dyn(dyn_node) => dyn_node.domain(),
180 MastNode::External(_) => panic!("Can't fetch domain for an `External` node."),
181 }
182 }
183
184 pub fn digest(&self) -> RpoDigest {
185 match self {
186 MastNode::Block(node) => node.digest(),
187 MastNode::Join(node) => node.digest(),
188 MastNode::Split(node) => node.digest(),
189 MastNode::Loop(node) => node.digest(),
190 MastNode::Call(node) => node.digest(),
191 MastNode::Dyn(node) => node.digest(),
192 MastNode::External(node) => node.digest(),
193 }
194 }
195
196 pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
197 match self {
198 MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
199 MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
200 MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
201 MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
202 MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
203 MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
204 MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)),
205 }
206 }
207
208 pub fn before_enter(&self) -> &[DecoratorId] {
210 use MastNode::*;
211 match self {
212 Block(_) => &[],
213 Join(node) => node.before_enter(),
214 Split(node) => node.before_enter(),
215 Loop(node) => node.before_enter(),
216 Call(node) => node.before_enter(),
217 Dyn(node) => node.before_enter(),
218 External(node) => node.before_enter(),
219 }
220 }
221
222 pub fn after_exit(&self) -> &[DecoratorId] {
224 use MastNode::*;
225 match self {
226 Block(_) => &[],
227 Join(node) => node.after_exit(),
228 Split(node) => node.after_exit(),
229 Loop(node) => node.after_exit(),
230 Call(node) => node.after_exit(),
231 Dyn(node) => node.after_exit(),
232 External(node) => node.after_exit(),
233 }
234 }
235}
236
237impl MastNode {
239 pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
241 match self {
242 MastNode::Block(node) => node.prepend_decorators(decorator_ids),
243 MastNode::Join(node) => node.set_before_enter(decorator_ids),
244 MastNode::Split(node) => node.set_before_enter(decorator_ids),
245 MastNode::Loop(node) => node.set_before_enter(decorator_ids),
246 MastNode::Call(node) => node.set_before_enter(decorator_ids),
247 MastNode::Dyn(node) => node.set_before_enter(decorator_ids),
248 MastNode::External(node) => node.set_before_enter(decorator_ids),
249 }
250 }
251
252 pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
254 match self {
255 MastNode::Block(node) => node.append_decorators(decorator_ids),
256 MastNode::Join(node) => node.set_after_exit(decorator_ids),
257 MastNode::Split(node) => node.set_after_exit(decorator_ids),
258 MastNode::Loop(node) => node.set_after_exit(decorator_ids),
259 MastNode::Call(node) => node.set_after_exit(decorator_ids),
260 MastNode::Dyn(node) => node.set_after_exit(decorator_ids),
261 MastNode::External(node) => node.set_after_exit(decorator_ids),
262 }
263 }
264}
265
266struct MastNodePrettyPrint<'a> {
270 node_pretty_print: Box<dyn PrettyPrint + 'a>,
271}
272
273impl<'a> MastNodePrettyPrint<'a> {
274 pub fn new(node_pretty_print: Box<dyn PrettyPrint + 'a>) -> Self {
275 Self { node_pretty_print }
276 }
277}
278
279impl PrettyPrint for MastNodePrettyPrint<'_> {
280 fn render(&self) -> Document {
281 self.node_pretty_print.render()
282 }
283}
284
/// Type-erased [`fmt::Display`] wrapper.
///
/// Lets a function return one concrete type while the displayable value it holds differs from
/// call to call.
struct MastNodeDisplay<'a> {
    /// The boxed value that all formatting is forwarded to.
    node_display: Box<dyn fmt::Display + 'a>,
}

impl<'a> MastNodeDisplay<'a> {
    /// Boxes `node` and wraps it.
    pub fn new(node: impl fmt::Display + 'a) -> Self {
        let boxed: Box<dyn fmt::Display + 'a> = Box::new(node);
        Self { node_display: boxed }
    }
}

impl fmt::Display for MastNodeDisplay<'_> {
    /// Formats by forwarding to the wrapped value, passing the formatter through unchanged so
    /// that width, precision and similar flags still apply.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&*self.node_display, f)
    }
}