1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt, fingerprint_with_child_fingerprints};
8use crate::{
9 Felt, Word,
10 chiplets::hasher,
11 mast::{MastForest, MastForestError, MastNodeId},
12 operations::opcodes,
13 prettier::PrettyPrint,
14 utils::{Idx, LookupByIdx},
15};
16
/// A MAST node with two child branches, `on_true` and `on_false`.
///
/// When pretty-printed it renders as an `if.true … else … end` block. The stored `digest` is
/// either a value forced through `SplitNodeBuilder::with_digest`, or by default the two
/// branch digests merged in [`SplitNode::DOMAIN`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct SplitNode {
    /// Child node ids: index 0 is the `on_true` branch, index 1 is the `on_false` branch.
    branches: [MastNodeId; 2],
    /// Digest of this node (see the type-level docs for how it is obtained).
    digest: Word,
}
33
impl SplitNode {
    /// Domain separator for split-node hashing: the `SPLIT` opcode as a field element.
    ///
    /// It is passed to `hasher::merge_in_domain` together with the two branch digests
    /// whenever a split node's digest is computed rather than forced.
    pub const DOMAIN: Felt = Felt::new_unchecked(opcodes::SPLIT as u64);
}
39
40impl SplitNode {
42 pub fn on_true(&self) -> MastNodeId {
44 self.branches[0]
45 }
46
47 pub fn on_false(&self) -> MastNodeId {
49 self.branches[1]
50 }
51}
52
53impl SplitNode {
57 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
58 SplitNodePrettyPrint { split_node: self, mast_forest }
59 }
60
61 pub(super) fn to_pretty_print<'a>(
62 &'a self,
63 mast_forest: &'a MastForest,
64 ) -> impl PrettyPrint + 'a {
65 SplitNodePrettyPrint { split_node: self, mast_forest }
66 }
67}
68
/// Printing adapter for a [`SplitNode`].
///
/// It borrows the node together with the [`MastForest`] in which its branch ids are looked up.
struct SplitNodePrettyPrint<'a> {
    /// The node being printed.
    split_node: &'a SplitNode,
    /// Forest used to resolve `split_node`'s branch ids when rendering them.
    mast_forest: &'a MastForest,
}
73
74impl PrettyPrint for SplitNodePrettyPrint<'_> {
75 #[rustfmt::skip]
76 fn render(&self) -> crate::prettier::Document {
77 use crate::prettier::*;
78
79 let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
80 let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
81
82 let mut doc = Document::Empty;
83 doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
84 doc += indent(4, const_text("else") + nl() + false_branch.render());
85 doc += nl() + const_text("end");
86 doc
87 }
88}
89
90impl fmt::Display for SplitNodePrettyPrint<'_> {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 use crate::prettier::PrettyPrint;
93 self.pretty_print(f)
94 }
95}
96
97impl MastNodeExt for SplitNode {
101 fn digest(&self) -> Word {
113 self.digest
114 }
115
116 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
117 Box::new(SplitNode::to_display(self, mast_forest))
118 }
119
120 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
121 Box::new(SplitNode::to_pretty_print(self, mast_forest))
122 }
123
124 fn has_children(&self) -> bool {
125 true
126 }
127
128 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
129 target.push(self.on_true());
130 target.push(self.on_false());
131 }
132
133 fn for_each_child<F>(&self, mut f: F)
134 where
135 F: FnMut(MastNodeId),
136 {
137 f(self.on_true());
138 f(self.on_false());
139 }
140
141 fn domain(&self) -> Felt {
142 Self::DOMAIN
143 }
144
145 type Builder = SplitNodeBuilder;
146
147 fn to_builder(self, _forest: &MastForest) -> Self::Builder {
148 SplitNodeBuilder::new(self.branches).with_digest(self.digest)
149 }
150}
151
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for SplitNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a split node with two arbitrary branch ids and an arbitrary digest.
    ///
    /// The digest is built from four random `u64` limbs and is not recomputed from the
    /// branches. Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(on_true, on_false, limbs)| {
                let digest = Word::from(limbs.map(Felt::new_unchecked));
                SplitNode { branches: [on_true, on_false], digest }
            })
            .no_shrink()
            .boxed()
    }
}
181
/// Builder for [`SplitNode`].
///
/// It holds the two branch ids and an optional forced digest. When no digest is forced, the
/// digest is computed from the branch digests in the target forest at build time.
#[derive(Debug)]
pub struct SplitNodeBuilder {
    /// Child node ids: index 0 is the `on_true` branch, index 1 is the `on_false` branch.
    branches: [MastNodeId; 2],
    /// Digest to use instead of computing one (set via `with_digest`).
    digest: Option<Word>,
}
189
190impl SplitNodeBuilder {
191 pub fn new(branches: [MastNodeId; 2]) -> Self {
193 Self { branches, digest: None }
194 }
195
196 pub fn build(self, mast_forest: &MastForest) -> Result<SplitNode, MastForestError> {
198 let forest_len = mast_forest.nodes.len();
199 if self.branches[0].to_usize() >= forest_len {
200 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
201 } else if self.branches[1].to_usize() >= forest_len {
202 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
203 }
204
205 let digest = if let Some(forced_digest) = self.digest {
207 forced_digest
208 } else {
209 let true_branch_hash = mast_forest[self.branches[0]].digest();
210 let false_branch_hash = mast_forest[self.branches[1]].digest();
211
212 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
213 };
214
215 Ok(SplitNode { branches: self.branches, digest })
216 }
217
218 pub(in crate::mast) fn build_linked(self) -> Result<SplitNode, MastForestError> {
219 Ok(SplitNode {
220 branches: self.branches,
221 digest: self.digest.ok_or(MastForestError::DigestRequiredForDeserialization)?,
222 })
223 }
224}
225
226impl MastForestContributor for SplitNodeBuilder {
227 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
228 let forest_len = forest.nodes.len();
230 if self.branches[0].to_usize() >= forest_len {
231 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
232 } else if self.branches[1].to_usize() >= forest_len {
233 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
234 }
235
236 let digest = if let Some(forced_digest) = self.digest {
238 forced_digest
239 } else {
240 let true_branch_hash = forest[self.branches[0]].digest();
241 let false_branch_hash = forest[self.branches[1]].digest();
242
243 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
244 };
245
246 let node_id = forest
249 .nodes
250 .push(SplitNode { branches: self.branches, digest }.into())
251 .map_err(|_| MastForestError::TooManyNodes)?;
252
253 Ok(node_id)
254 }
255
256 fn fingerprint_for_node(
257 &self,
258 forest: &MastForest,
259 hash_by_node_id: &impl LookupByIdx<MastNodeId, Word>,
260 ) -> Result<Word, MastForestError> {
261 let node_digest = if let Some(forced_digest) = self.digest {
262 forced_digest
263 } else {
264 let if_branch_hash = forest[self.branches[0]].digest();
265 let else_branch_hash = forest[self.branches[1]].digest();
266
267 hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], SplitNode::DOMAIN)
268 };
269
270 fingerprint_with_child_fingerprints(node_digest, &self.branches, forest, hash_by_node_id)
271 }
272
273 fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
274 SplitNodeBuilder {
275 branches: [
276 *remapping.get(self.branches[0]).unwrap_or(&self.branches[0]),
277 *remapping.get(self.branches[1]).unwrap_or(&self.branches[1]),
278 ],
279 digest: self.digest,
280 }
281 }
282
283 fn with_digest(mut self, digest: Word) -> Self {
284 self.digest = Some(digest);
285 self
286 }
287}
288
289impl SplitNodeBuilder {
290 pub(in crate::mast) fn add_to_forest_relaxed(
300 self,
301 forest: &mut MastForest,
302 ) -> Result<MastNodeId, MastForestError> {
303 let Some(digest) = self.digest else {
306 return Err(MastForestError::DigestRequiredForDeserialization);
307 };
308
309 let node_id = forest
312 .nodes
313 .push(SplitNode { branches: self.branches, digest }.into())
314 .map_err(|_| MastForestError::TooManyNodes)?;
315
316 Ok(node_id)
317 }
318}
319
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for SplitNodeBuilder {
    type Parameters = SplitNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder over two arbitrary branch ids with no forced digest.
    ///
    /// The parameters carry no options yet and are ignored.
    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        any::<[MastNodeId; 2]>().prop_map(SplitNodeBuilder::new).boxed()
    }
}
332
#[cfg(any(test, feature = "arbitrary"))]
/// Parameters for generating arbitrary [`SplitNodeBuilder`] values. It has no fields yet.
#[derive(Clone, Debug, Default)]
pub struct SplitNodeBuilderParams {}