miden_core/mast/node/
split_node.rs1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::prettier::PrettyPrint;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use super::{MastNodeErrorContext, MastNodeExt};
10use crate::{
11 Idx, OPCODE_SPLIT,
12 chiplets::hasher,
13 mast::{DecoratedOpLink, DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
14};
15
16#[derive(Debug, Clone, PartialEq, Eq)]
26#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
27pub struct SplitNode {
28 branches: [MastNodeId; 2],
29 digest: Word,
30 #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
31 before_enter: Vec<DecoratorId>,
32 #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
33 after_exit: Vec<DecoratorId>,
34}
35
36impl SplitNode {
38 pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
40}
41
42impl SplitNode {
44 pub fn new(
45 branches: [MastNodeId; 2],
46 mast_forest: &MastForest,
47 ) -> Result<Self, MastForestError> {
48 let forest_len = mast_forest.nodes.len();
49 if branches[0].to_usize() >= forest_len {
50 return Err(MastForestError::NodeIdOverflow(branches[0], forest_len));
51 } else if branches[1].to_usize() >= forest_len {
52 return Err(MastForestError::NodeIdOverflow(branches[1], forest_len));
53 }
54 let digest = {
55 let if_branch_hash = mast_forest[branches[0]].digest();
56 let else_branch_hash = mast_forest[branches[1]].digest();
57
58 hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN)
59 };
60
61 Ok(Self {
62 branches,
63 digest,
64 before_enter: Vec::new(),
65 after_exit: Vec::new(),
66 })
67 }
68
69 pub fn new_unsafe(branches: [MastNodeId; 2], digest: Word) -> Self {
72 Self {
73 branches,
74 digest,
75 before_enter: Vec::new(),
76 after_exit: Vec::new(),
77 }
78 }
79}
80
81impl SplitNode {
83 pub fn on_true(&self) -> MastNodeId {
85 self.branches[0]
86 }
87
88 pub fn on_false(&self) -> MastNodeId {
90 self.branches[1]
91 }
92}
93
94impl MastNodeErrorContext for SplitNode {
95 fn decorators(&self) -> impl Iterator<Item = DecoratedOpLink> {
96 self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
97 }
98}
99
100impl SplitNode {
104 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
105 SplitNodePrettyPrint { split_node: self, mast_forest }
106 }
107
108 pub(super) fn to_pretty_print<'a>(
109 &'a self,
110 mast_forest: &'a MastForest,
111 ) -> impl PrettyPrint + 'a {
112 SplitNodePrettyPrint { split_node: self, mast_forest }
113 }
114}
115
116struct SplitNodePrettyPrint<'a> {
117 split_node: &'a SplitNode,
118 mast_forest: &'a MastForest,
119}
120
121impl PrettyPrint for SplitNodePrettyPrint<'_> {
122 #[rustfmt::skip]
123 fn render(&self) -> crate::prettier::Document {
124 use crate::prettier::*;
125
126 let pre_decorators = {
127 let mut pre_decorators = self
128 .split_node
129 .before_enter()
130 .iter()
131 .map(|&decorator_id| self.mast_forest[decorator_id].render())
132 .reduce(|acc, doc| acc + const_text(" ") + doc)
133 .unwrap_or_default();
134 if !pre_decorators.is_empty() {
135 pre_decorators += nl();
136 }
137
138 pre_decorators
139 };
140
141 let post_decorators = {
142 let mut post_decorators = self
143 .split_node
144 .after_exit()
145 .iter()
146 .map(|&decorator_id| self.mast_forest[decorator_id].render())
147 .reduce(|acc, doc| acc + const_text(" ") + doc)
148 .unwrap_or_default();
149 if !post_decorators.is_empty() {
150 post_decorators = nl() + post_decorators;
151 }
152
153 post_decorators
154 };
155
156 let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
157 let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
158
159 let mut doc = pre_decorators;
160 doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
161 doc += indent(4, const_text("else") + nl() + false_branch.render());
162 doc += nl() + const_text("end");
163 doc + post_decorators
164 }
165}
166
167impl fmt::Display for SplitNodePrettyPrint<'_> {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 use crate::prettier::PrettyPrint;
170 self.pretty_print(f)
171 }
172}
173
174impl MastNodeExt for SplitNode {
178 fn digest(&self) -> Word {
190 self.digest
191 }
192
193 fn before_enter(&self) -> &[DecoratorId] {
195 &self.before_enter
196 }
197
198 fn after_exit(&self) -> &[DecoratorId] {
200 &self.after_exit
201 }
202 fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
204 self.before_enter.extend_from_slice(decorator_ids);
205 }
206
207 fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
209 self.after_exit.extend_from_slice(decorator_ids);
210 }
211
212 fn remove_decorators(&mut self) {
214 self.before_enter.truncate(0);
215 self.after_exit.truncate(0);
216 }
217
218 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
219 Box::new(SplitNode::to_display(self, mast_forest))
220 }
221
222 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
223 Box::new(SplitNode::to_pretty_print(self, mast_forest))
224 }
225
226 fn remap_children(&self, remapping: &Remapping) -> Self {
227 let mut node = self.clone();
228 node.branches[0] = node.branches[0].remap(remapping);
229 node.branches[1] = node.branches[1].remap(remapping);
230 node
231 }
232
233 fn has_children(&self) -> bool {
234 true
235 }
236
237 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
238 target.push(self.on_true());
239 target.push(self.on_false());
240 }
241
242 fn for_each_child<F>(&self, mut f: F)
243 where
244 F: FnMut(MastNodeId),
245 {
246 f(self.on_true());
247 f(self.on_false());
248 }
249
250 fn domain(&self) -> Felt {
251 Self::DOMAIN
252 }
253}