miden_core/mast/node/
split_node.rs1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{hash::rpo::RpoDigest, Felt};
5use miden_formatting::prettier::PrettyPrint;
6
7use crate::{
8 chiplets::hasher,
9 mast::{DecoratorId, MastForest, MastForestError, MastNodeId},
10 OPCODE_SPLIT,
11};
12
13#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct SplitNode {
24 branches: [MastNodeId; 2],
25 digest: RpoDigest,
26 before_enter: Vec<DecoratorId>,
27 after_exit: Vec<DecoratorId>,
28}
29
30impl SplitNode {
32 pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
34}
35
36impl SplitNode {
38 pub fn new(
39 branches: [MastNodeId; 2],
40 mast_forest: &MastForest,
41 ) -> Result<Self, MastForestError> {
42 let forest_len = mast_forest.nodes.len();
43 if branches[0].as_usize() >= forest_len {
44 return Err(MastForestError::NodeIdOverflow(branches[0], forest_len));
45 } else if branches[1].as_usize() >= forest_len {
46 return Err(MastForestError::NodeIdOverflow(branches[1], forest_len));
47 }
48 let digest = {
49 let if_branch_hash = mast_forest[branches[0]].digest();
50 let else_branch_hash = mast_forest[branches[1]].digest();
51
52 hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN)
53 };
54
55 Ok(Self {
56 branches,
57 digest,
58 before_enter: Vec::new(),
59 after_exit: Vec::new(),
60 })
61 }
62
63 pub fn new_unsafe(branches: [MastNodeId; 2], digest: RpoDigest) -> Self {
66 Self {
67 branches,
68 digest,
69 before_enter: Vec::new(),
70 after_exit: Vec::new(),
71 }
72 }
73}
74
75impl SplitNode {
77 pub fn digest(&self) -> RpoDigest {
89 self.digest
90 }
91
92 pub fn on_true(&self) -> MastNodeId {
94 self.branches[0]
95 }
96
97 pub fn on_false(&self) -> MastNodeId {
99 self.branches[1]
100 }
101
102 pub fn before_enter(&self) -> &[DecoratorId] {
104 &self.before_enter
105 }
106
107 pub fn after_exit(&self) -> &[DecoratorId] {
109 &self.after_exit
110 }
111}
112
113impl SplitNode {
115 pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
117 self.before_enter = decorator_ids;
118 }
119
120 pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
122 self.after_exit = decorator_ids;
123 }
124}
125
126impl SplitNode {
130 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
131 SplitNodePrettyPrint { split_node: self, mast_forest }
132 }
133
134 pub(super) fn to_pretty_print<'a>(
135 &'a self,
136 mast_forest: &'a MastForest,
137 ) -> impl PrettyPrint + 'a {
138 SplitNodePrettyPrint { split_node: self, mast_forest }
139 }
140}
141
142struct SplitNodePrettyPrint<'a> {
143 split_node: &'a SplitNode,
144 mast_forest: &'a MastForest,
145}
146
147impl PrettyPrint for SplitNodePrettyPrint<'_> {
148 #[rustfmt::skip]
149 fn render(&self) -> crate::prettier::Document {
150 use crate::prettier::*;
151
152 let pre_decorators = {
153 let mut pre_decorators = self
154 .split_node
155 .before_enter()
156 .iter()
157 .map(|&decorator_id| self.mast_forest[decorator_id].render())
158 .reduce(|acc, doc| acc + const_text(" ") + doc)
159 .unwrap_or_default();
160 if !pre_decorators.is_empty() {
161 pre_decorators += nl();
162 }
163
164 pre_decorators
165 };
166
167 let post_decorators = {
168 let mut post_decorators = self
169 .split_node
170 .after_exit()
171 .iter()
172 .map(|&decorator_id| self.mast_forest[decorator_id].render())
173 .reduce(|acc, doc| acc + const_text(" ") + doc)
174 .unwrap_or_default();
175 if !post_decorators.is_empty() {
176 post_decorators = nl() + post_decorators;
177 }
178
179 post_decorators
180 };
181
182 let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
183 let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
184
185 let mut doc = pre_decorators;
186 doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
187 doc += indent(4, const_text("else") + nl() + false_branch.render());
188 doc += nl() + const_text("end");
189 doc + post_decorators
190 }
191}
192
193impl fmt::Display for SplitNodePrettyPrint<'_> {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 use crate::prettier::PrettyPrint;
196 self.pretty_print(f)
197 }
198}