1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8#[cfg(debug_assertions)]
9use crate::mast::MastNode;
10use crate::{
11 Felt, Word,
12 chiplets::hasher,
13 mast::{
14 DecoratorId, DecoratorStore, MastForest, MastForestError, MastNodeFingerprint, MastNodeId,
15 },
16 operations::opcodes,
17 prettier::PrettyPrint,
18 utils::{Idx, LookupByIdx},
19};
20
/// A MAST node with two children: a true branch and a false branch.
///
/// `branches[0]` is the child returned by [`SplitNode::on_true`] and `branches[1]` is the child
/// returned by [`SplitNode::on_false`]. When pretty-printed, the node is rendered as an
/// `if.true ... else ... end` block.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct SplitNode {
    /// Child node IDs: index 0 is the true branch, index 1 is the false branch.
    branches: [MastNodeId; 2],
    /// The node digest. The builders set it to
    /// `merge_in_domain([true_digest, false_digest], SplitNode::DOMAIN)` unless a digest was
    /// explicitly forced on the builder.
    digest: Word,
    /// The node's `before_enter` / `after_exit` decorators. They are either owned inline
    /// (`DecoratorStore::Owned`) or linked (`DecoratorStore::Linked`) to decorator storage in the
    /// `MastForest`, keyed by this node's ID.
    decorator_store: DecoratorStore,
}
38
39impl SplitNode {
41 pub const DOMAIN: Felt = Felt::new(opcodes::SPLIT as u64);
43}
44
45impl SplitNode {
47 pub fn on_true(&self) -> MastNodeId {
49 self.branches[0]
50 }
51
52 pub fn on_false(&self) -> MastNodeId {
54 self.branches[1]
55 }
56}
57
58impl SplitNode {
62 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
63 SplitNodePrettyPrint { split_node: self, mast_forest }
64 }
65
66 pub(super) fn to_pretty_print<'a>(
67 &'a self,
68 mast_forest: &'a MastForest,
69 ) -> impl PrettyPrint + 'a {
70 SplitNodePrettyPrint { split_node: self, mast_forest }
71 }
72}
73
/// Rendering adapter that pairs a [`SplitNode`] with the [`MastForest`] used to resolve its
/// child nodes and decorator IDs.
struct SplitNodePrettyPrint<'a> {
    /// The split node being rendered.
    split_node: &'a SplitNode,
    /// Forest that is indexed by the node's child IDs and decorator IDs during rendering.
    mast_forest: &'a MastForest,
}
78
79impl PrettyPrint for SplitNodePrettyPrint<'_> {
80 #[rustfmt::skip]
81 fn render(&self) -> crate::prettier::Document {
82 use crate::prettier::*;
83
84 let pre_decorators = {
85 let mut pre_decorators = self
86 .split_node
87 .before_enter(self.mast_forest)
88 .iter()
89 .map(|&decorator_id| self.mast_forest[decorator_id].render())
90 .reduce(|acc, doc| acc + const_text(" ") + doc)
91 .unwrap_or_default();
92 if !pre_decorators.is_empty() {
93 pre_decorators += nl();
94 }
95
96 pre_decorators
97 };
98
99 let post_decorators = {
100 let mut post_decorators = self
101 .split_node
102 .after_exit(self.mast_forest)
103 .iter()
104 .map(|&decorator_id| self.mast_forest[decorator_id].render())
105 .reduce(|acc, doc| acc + const_text(" ") + doc)
106 .unwrap_or_default();
107 if !post_decorators.is_empty() {
108 post_decorators = nl() + post_decorators;
109 }
110
111 post_decorators
112 };
113
114 let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
115 let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
116
117 let mut doc = pre_decorators;
118 doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
119 doc += indent(4, const_text("else") + nl() + false_branch.render());
120 doc += nl() + const_text("end");
121 doc + post_decorators
122 }
123}
124
125impl fmt::Display for SplitNodePrettyPrint<'_> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 use crate::prettier::PrettyPrint;
128 self.pretty_print(f)
129 }
130}
131
132impl MastNodeExt for SplitNode {
136 fn digest(&self) -> Word {
148 self.digest
149 }
150
151 fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
153 #[cfg(debug_assertions)]
154 self.verify_node_in_forest(forest);
155 self.decorator_store.before_enter(forest)
156 }
157
158 fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
160 #[cfg(debug_assertions)]
161 self.verify_node_in_forest(forest);
162 self.decorator_store.after_exit(forest)
163 }
164
165 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
166 Box::new(SplitNode::to_display(self, mast_forest))
167 }
168
169 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
170 Box::new(SplitNode::to_pretty_print(self, mast_forest))
171 }
172
173 fn has_children(&self) -> bool {
174 true
175 }
176
177 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
178 target.push(self.on_true());
179 target.push(self.on_false());
180 }
181
182 fn for_each_child<F>(&self, mut f: F)
183 where
184 F: FnMut(MastNodeId),
185 {
186 f(self.on_true());
187 f(self.on_false());
188 }
189
190 fn domain(&self) -> Felt {
191 Self::DOMAIN
192 }
193
194 type Builder = SplitNodeBuilder;
195
196 fn to_builder(self, forest: &MastForest) -> Self::Builder {
197 match self.decorator_store {
199 DecoratorStore::Owned { before_enter, after_exit, .. } => {
200 let mut builder = SplitNodeBuilder::new(self.branches);
201 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
202 builder
203 },
204 DecoratorStore::Linked { id } => {
205 let before_enter = forest.before_enter_decorators(id).to_vec();
207 let after_exit = forest.after_exit_decorators(id).to_vec();
208 let mut builder = SplitNodeBuilder::new(self.branches);
209 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
210 builder
211 },
212 }
213 }
214
215 #[cfg(debug_assertions)]
216 fn verify_node_in_forest(&self, forest: &MastForest) {
217 if let Some(id) = self.decorator_store.linked_id() {
218 let self_ptr = self as *const Self;
220 let forest_node = &forest.nodes[id];
221 let forest_node_ptr = match forest_node {
222 MastNode::Split(split_node) => split_node as *const SplitNode as *const (),
223 _ => panic!("Node type mismatch at {:?}", id),
224 };
225 let self_as_void = self_ptr as *const ();
226 debug_assert_eq!(
227 self_as_void, forest_node_ptr,
228 "Node pointer mismatch: expected node at {:?} to be self",
229 id
230 );
231 }
232 }
233}
234
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for SplitNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a split node with arbitrary branch IDs, an arbitrary digest built from four
    /// `u64` limbs, and a default decorator store. Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(on_true, on_false, limbs)| SplitNode {
                branches: [on_true, on_false],
                digest: Word::from(limbs.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
265
/// Builder for [`SplitNode`].
///
/// Collects the two branch IDs, optional `before_enter` / `after_exit` decorators, and an
/// optional forced digest. A node is produced either standalone via [`SplitNodeBuilder::build`]
/// or directly inside a forest via `MastForestContributor::add_to_forest`.
#[derive(Debug)]
pub struct SplitNodeBuilder {
    /// Child node IDs: index 0 is the true branch, index 1 is the false branch.
    branches: [MastNodeId; 2],
    /// Decorators attached as the node's `before_enter` list.
    before_enter: Vec<DecoratorId>,
    /// Decorators attached as the node's `after_exit` list.
    after_exit: Vec<DecoratorId>,
    /// When `Some`, used as the node digest instead of hashing the two branch digests.
    digest: Option<Word>,
}
275
276impl SplitNodeBuilder {
277 pub fn new(branches: [MastNodeId; 2]) -> Self {
279 Self {
280 branches,
281 before_enter: Vec::new(),
282 after_exit: Vec::new(),
283 digest: None,
284 }
285 }
286
287 pub fn build(self, mast_forest: &MastForest) -> Result<SplitNode, MastForestError> {
289 let forest_len = mast_forest.nodes.len();
290 if self.branches[0].to_usize() >= forest_len {
291 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
292 } else if self.branches[1].to_usize() >= forest_len {
293 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
294 }
295
296 let digest = if let Some(forced_digest) = self.digest {
298 forced_digest
299 } else {
300 let true_branch_hash = mast_forest[self.branches[0]].digest();
301 let false_branch_hash = mast_forest[self.branches[1]].digest();
302
303 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
304 };
305
306 Ok(SplitNode {
307 branches: self.branches,
308 digest,
309 decorator_store: DecoratorStore::new_owned_with_decorators(
310 self.before_enter,
311 self.after_exit,
312 ),
313 })
314 }
315}
316
317impl MastForestContributor for SplitNodeBuilder {
318 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
319 let forest_len = forest.nodes.len();
321 if self.branches[0].to_usize() >= forest_len {
322 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
323 } else if self.branches[1].to_usize() >= forest_len {
324 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
325 }
326
327 let digest = if let Some(forced_digest) = self.digest {
329 forced_digest
330 } else {
331 let true_branch_hash = forest[self.branches[0]].digest();
332 let false_branch_hash = forest[self.branches[1]].digest();
333
334 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
335 };
336
337 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
339
340 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
342
343 let node_id = forest
346 .nodes
347 .push(
348 SplitNode {
349 branches: self.branches,
350 digest,
351 decorator_store: DecoratorStore::Linked { id: future_node_id },
352 }
353 .into(),
354 )
355 .map_err(|_| MastForestError::TooManyNodes)?;
356
357 Ok(node_id)
358 }
359
360 fn fingerprint_for_node(
361 &self,
362 forest: &MastForest,
363 hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
364 ) -> Result<MastNodeFingerprint, MastForestError> {
365 crate::mast::node_fingerprint::fingerprint_from_parts(
367 forest,
368 hash_by_node_id,
369 &self.before_enter,
370 &self.after_exit,
371 &self.branches,
372 if let Some(forced_digest) = self.digest {
374 forced_digest
375 } else {
376 let if_branch_hash = forest[self.branches[0]].digest();
377 let else_branch_hash = forest[self.branches[1]].digest();
378
379 crate::chiplets::hasher::merge_in_domain(
380 &[if_branch_hash, else_branch_hash],
381 SplitNode::DOMAIN,
382 )
383 },
384 )
385 }
386
387 fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
388 SplitNodeBuilder {
389 branches: [
390 *remapping.get(self.branches[0]).unwrap_or(&self.branches[0]),
391 *remapping.get(self.branches[1]).unwrap_or(&self.branches[1]),
392 ],
393 before_enter: self.before_enter,
394 after_exit: self.after_exit,
395 digest: self.digest,
396 }
397 }
398
399 fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
400 self.before_enter = decorators.into();
401 self
402 }
403
404 fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
405 self.after_exit = decorators.into();
406 self
407 }
408
409 fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
410 self.before_enter.extend(decorators);
411 }
412
413 fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
414 self.after_exit.extend(decorators);
415 }
416
417 fn with_digest(mut self, digest: crate::Word) -> Self {
418 self.digest = Some(digest);
419 self
420 }
421}
422
423impl SplitNodeBuilder {
424 pub(in crate::mast) fn add_to_forest_relaxed(
434 self,
435 forest: &mut MastForest,
436 ) -> Result<MastNodeId, MastForestError> {
437 let Some(digest) = self.digest else {
440 return Err(MastForestError::DigestRequiredForDeserialization);
441 };
442
443 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
444
445 let node_id = forest
448 .nodes
449 .push(
450 SplitNode {
451 branches: self.branches,
452 digest,
453 decorator_store: DecoratorStore::Linked { id: future_node_id },
454 }
455 .into(),
456 )
457 .map_err(|_| MastForestError::TooManyNodes)?;
458
459 Ok(node_id)
460 }
461}
462
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for SplitNodeBuilder {
    type Parameters = SplitNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder with:
    /// - arbitrary branch IDs;
    /// - a `before_enter` list and an `after_exit` list, each holding between 0 and
    ///   `params.max_decorators` decorator IDs from
    ///   `decorator_id_strategy(params.max_decorator_id_u32)`;
    /// - no forced digest.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let SplitNodeBuilderParams { max_decorators, max_decorator_id_u32 } = params;

        // Strategy for one decorator list; called once for each of the two lists.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(max_decorator_id_u32),
                0..=max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(branches, before_enter, after_exit)| {
                Self::new(branches).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
488
/// Parameters for generating arbitrary [`SplitNodeBuilder`] values.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct SplitNodeBuilderParams {
    /// Inclusive upper bound on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Value passed to the decorator ID strategy.
    /// NOTE(review): presumably the maximum generated decorator ID — confirm against
    /// `decorator_id_strategy`.
    pub max_decorator_id_u32: u32,
}

#[cfg(any(test, feature = "arbitrary"))]
impl Default for SplitNodeBuilderParams {
    /// Returns `max_decorators = 4` and `max_decorator_id_u32 = 10`.
    fn default() -> Self {
        const DEFAULT_MAX_DECORATORS: usize = 4;
        const DEFAULT_MAX_DECORATOR_ID: u32 = 10;

        Self {
            max_decorator_id_u32: DEFAULT_MAX_DECORATOR_ID,
            max_decorators: DEFAULT_MAX_DECORATORS,
        }
    }
}