1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8use crate::{
9 Felt, Word,
10 chiplets::hasher,
11 mast::{
12 DecoratorId, DecoratorStore, MastForest, MastForestError, MastNode, MastNodeFingerprint,
13 MastNodeId,
14 },
15 operations::OPCODE_SPLIT,
16 prettier::PrettyPrint,
17 utils::{Idx, LookupByIdx},
18};
19
/// A MAST node that refers to exactly two child nodes.
///
/// The node stores the ids of its two branches, its own digest, and a [`DecoratorStore`] through
/// which its before-enter and after-exit decorators are looked up. When pretty-printed, the two
/// branches are emitted under the `if.true` and `else` keywords respectively.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct SplitNode {
    /// Child node ids: index 0 is returned by `on_true`, index 1 by `on_false`.
    branches: [MastNodeId; 2],
    /// Digest of this node, as returned by `MastNodeExt::digest`.
    digest: Word,
    /// Source of this node's before-enter and after-exit decorators.
    decorator_store: DecoratorStore,
}
37
/// Constants
impl SplitNode {
    /// Domain passed to `hasher::merge_in_domain` when a split node's digest is computed from
    /// the digests of its two branches. Its value is `OPCODE_SPLIT` converted to a [`Felt`].
    pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
}
43
44impl SplitNode {
46 pub fn on_true(&self) -> MastNodeId {
48 self.branches[0]
49 }
50
51 pub fn on_false(&self) -> MastNodeId {
53 self.branches[1]
54 }
55}
56
/// Formatting helpers
impl SplitNode {
    /// Pairs this node with `mast_forest` in a wrapper that implements [`fmt::Display`].
    ///
    /// The wrapper borrows both the node and the forest for `'a`; the forest is what the
    /// printer indexes to resolve branch ids and decorator ids.
    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
        SplitNodePrettyPrint { split_node: self, mast_forest }
    }

    /// Pairs this node with `mast_forest` in a wrapper that implements [`PrettyPrint`].
    ///
    /// This builds the same wrapper as `to_display`, exposed through the pretty-printing trait
    /// instead of [`fmt::Display`].
    pub(super) fn to_pretty_print<'a>(
        &'a self,
        mast_forest: &'a MastForest,
    ) -> impl PrettyPrint + 'a {
        SplitNodePrettyPrint { split_node: self, mast_forest }
    }
}
72
/// Borrowed view of a [`SplitNode`] together with the [`MastForest`] used to resolve the ids it
/// refers to; this is the type that actually implements the printing traits.
struct SplitNodePrettyPrint<'a> {
    /// The node being printed.
    split_node: &'a SplitNode,
    /// Forest indexed to look up the node's branches and decorators.
    mast_forest: &'a MastForest,
}
77
78impl PrettyPrint for SplitNodePrettyPrint<'_> {
79 #[rustfmt::skip]
80 fn render(&self) -> crate::prettier::Document {
81 use crate::prettier::*;
82
83 let pre_decorators = {
84 let mut pre_decorators = self
85 .split_node
86 .before_enter(self.mast_forest)
87 .iter()
88 .map(|&decorator_id| self.mast_forest[decorator_id].render())
89 .reduce(|acc, doc| acc + const_text(" ") + doc)
90 .unwrap_or_default();
91 if !pre_decorators.is_empty() {
92 pre_decorators += nl();
93 }
94
95 pre_decorators
96 };
97
98 let post_decorators = {
99 let mut post_decorators = self
100 .split_node
101 .after_exit(self.mast_forest)
102 .iter()
103 .map(|&decorator_id| self.mast_forest[decorator_id].render())
104 .reduce(|acc, doc| acc + const_text(" ") + doc)
105 .unwrap_or_default();
106 if !post_decorators.is_empty() {
107 post_decorators = nl() + post_decorators;
108 }
109
110 post_decorators
111 };
112
113 let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
114 let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
115
116 let mut doc = pre_decorators;
117 doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
118 doc += indent(4, const_text("else") + nl() + false_branch.render());
119 doc += nl() + const_text("end");
120 doc + post_decorators
121 }
122}
123
124impl fmt::Display for SplitNodePrettyPrint<'_> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 use crate::prettier::PrettyPrint;
127 self.pretty_print(f)
128 }
129}
130
131impl MastNodeExt for SplitNode {
135 fn digest(&self) -> Word {
147 self.digest
148 }
149
150 fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
152 #[cfg(debug_assertions)]
153 self.verify_node_in_forest(forest);
154 self.decorator_store.before_enter(forest)
155 }
156
157 fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
159 #[cfg(debug_assertions)]
160 self.verify_node_in_forest(forest);
161 self.decorator_store.after_exit(forest)
162 }
163
164 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
165 Box::new(SplitNode::to_display(self, mast_forest))
166 }
167
168 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
169 Box::new(SplitNode::to_pretty_print(self, mast_forest))
170 }
171
172 fn has_children(&self) -> bool {
173 true
174 }
175
176 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
177 target.push(self.on_true());
178 target.push(self.on_false());
179 }
180
181 fn for_each_child<F>(&self, mut f: F)
182 where
183 F: FnMut(MastNodeId),
184 {
185 f(self.on_true());
186 f(self.on_false());
187 }
188
189 fn domain(&self) -> Felt {
190 Self::DOMAIN
191 }
192
193 type Builder = SplitNodeBuilder;
194
195 fn to_builder(self, forest: &MastForest) -> Self::Builder {
196 match self.decorator_store {
198 DecoratorStore::Owned { before_enter, after_exit, .. } => {
199 let mut builder = SplitNodeBuilder::new(self.branches);
200 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
201 builder
202 },
203 DecoratorStore::Linked { id } => {
204 let before_enter = forest.before_enter_decorators(id).to_vec();
206 let after_exit = forest.after_exit_decorators(id).to_vec();
207 let mut builder = SplitNodeBuilder::new(self.branches);
208 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
209 builder
210 },
211 }
212 }
213
214 #[cfg(debug_assertions)]
215 fn verify_node_in_forest(&self, forest: &MastForest) {
216 if let Some(id) = self.decorator_store.linked_id() {
217 let self_ptr = self as *const Self;
219 let forest_node = &forest.nodes[id];
220 let forest_node_ptr = match forest_node {
221 MastNode::Split(split_node) => split_node as *const SplitNode as *const (),
222 _ => panic!("Node type mismatch at {:?}", id),
223 };
224 let self_as_void = self_ptr as *const ();
225 debug_assert_eq!(
226 self_as_void, forest_node_ptr,
227 "Node pointer mismatch: expected node at {:?} to be self",
228 id
229 );
230 }
231 }
232}
233
/// Generates split nodes from two arbitrary branch ids and an arbitrary digest.
///
/// The digest is built from four arbitrary `u64` values rather than computed from the branches,
/// and the decorator store is always the default one.
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for SplitNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        use crate::Felt;

        let raw_parts = (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>());
        raw_parts
            .prop_map(|(on_true, on_false, digest_limbs)| SplitNode {
                branches: [on_true, on_false],
                digest: Word::from(digest_limbs.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            // Shrinking is switched off for the generated nodes.
            .no_shrink()
            .boxed()
    }
}
264
/// Builder for [`SplitNode`]s.
///
/// Collects the two branch ids, the node-level decorators and an optional forced digest; the
/// node itself is produced by `build` or added directly to a forest by `add_to_forest`.
#[derive(Debug)]
pub struct SplitNodeBuilder {
    /// Branch ids in `[on_true, on_false]` order.
    branches: [MastNodeId; 2],
    /// Before-enter decorators to attach to the node.
    before_enter: Vec<DecoratorId>,
    /// After-exit decorators to attach to the node.
    after_exit: Vec<DecoratorId>,
    /// When set, used as the node's digest instead of computing it from the branches.
    digest: Option<Word>,
}
274
275impl SplitNodeBuilder {
276 pub fn new(branches: [MastNodeId; 2]) -> Self {
278 Self {
279 branches,
280 before_enter: Vec::new(),
281 after_exit: Vec::new(),
282 digest: None,
283 }
284 }
285
286 pub fn build(self, mast_forest: &MastForest) -> Result<SplitNode, MastForestError> {
288 let forest_len = mast_forest.nodes.len();
289 if self.branches[0].to_usize() >= forest_len {
290 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
291 } else if self.branches[1].to_usize() >= forest_len {
292 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
293 }
294
295 let digest = if let Some(forced_digest) = self.digest {
297 forced_digest
298 } else {
299 let true_branch_hash = mast_forest[self.branches[0]].digest();
300 let false_branch_hash = mast_forest[self.branches[1]].digest();
301
302 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
303 };
304
305 Ok(SplitNode {
306 branches: self.branches,
307 digest,
308 decorator_store: DecoratorStore::new_owned_with_decorators(
309 self.before_enter,
310 self.after_exit,
311 ),
312 })
313 }
314}
315
316impl MastForestContributor for SplitNodeBuilder {
317 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
318 let forest_len = forest.nodes.len();
320 if self.branches[0].to_usize() >= forest_len {
321 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
322 } else if self.branches[1].to_usize() >= forest_len {
323 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
324 }
325
326 let digest = if let Some(forced_digest) = self.digest {
328 forced_digest
329 } else {
330 let true_branch_hash = forest[self.branches[0]].digest();
331 let false_branch_hash = forest[self.branches[1]].digest();
332
333 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
334 };
335
336 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
338
339 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
341
342 let node_id = forest
345 .nodes
346 .push(
347 SplitNode {
348 branches: self.branches,
349 digest,
350 decorator_store: DecoratorStore::Linked { id: future_node_id },
351 }
352 .into(),
353 )
354 .map_err(|_| MastForestError::TooManyNodes)?;
355
356 Ok(node_id)
357 }
358
359 fn fingerprint_for_node(
360 &self,
361 forest: &MastForest,
362 hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
363 ) -> Result<MastNodeFingerprint, MastForestError> {
364 crate::mast::node_fingerprint::fingerprint_from_parts(
366 forest,
367 hash_by_node_id,
368 &self.before_enter,
369 &self.after_exit,
370 &self.branches,
371 if let Some(forced_digest) = self.digest {
373 forced_digest
374 } else {
375 let if_branch_hash = forest[self.branches[0]].digest();
376 let else_branch_hash = forest[self.branches[1]].digest();
377
378 crate::chiplets::hasher::merge_in_domain(
379 &[if_branch_hash, else_branch_hash],
380 SplitNode::DOMAIN,
381 )
382 },
383 )
384 }
385
386 fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
387 SplitNodeBuilder {
388 branches: [
389 *remapping.get(self.branches[0]).unwrap_or(&self.branches[0]),
390 *remapping.get(self.branches[1]).unwrap_or(&self.branches[1]),
391 ],
392 before_enter: self.before_enter,
393 after_exit: self.after_exit,
394 digest: self.digest,
395 }
396 }
397
398 fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
399 self.before_enter = decorators.into();
400 self
401 }
402
403 fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
404 self.after_exit = decorators.into();
405 self
406 }
407
408 fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
409 self.before_enter.extend(decorators);
410 }
411
412 fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
413 self.after_exit.extend(decorators);
414 }
415
416 fn with_digest(mut self, digest: crate::Word) -> Self {
417 self.digest = Some(digest);
418 self
419 }
420}
421
422impl SplitNodeBuilder {
423 pub(in crate::mast) fn add_to_forest_relaxed(
433 self,
434 forest: &mut MastForest,
435 ) -> Result<MastNodeId, MastForestError> {
436 let Some(digest) = self.digest else {
439 return Err(MastForestError::DigestRequiredForDeserialization);
440 };
441
442 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
443
444 let node_id = forest
447 .nodes
448 .push(
449 SplitNode {
450 branches: self.branches,
451 digest,
452 decorator_store: DecoratorStore::Linked { id: future_node_id },
453 }
454 .into(),
455 )
456 .map_err(|_| MastForestError::TooManyNodes)?;
457
458 Ok(node_id)
459 }
460}
461
/// Generates builders with arbitrary branch ids and two independently generated decorator
/// lists, both bounded by the supplied [`SplitNodeBuilderParams`]. No digest is forced.
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for SplitNodeBuilder {
    type Parameters = SplitNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let SplitNodeBuilderParams { max_decorators, max_decorator_id_u32 } = params;

        // Strategy for one decorator list: between 0 and `max_decorators` ids (inclusive).
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(max_decorator_id_u32),
                0..=max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(branches, before_enter, after_exit)| {
                Self::new(branches).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
487
/// Parameters controlling how arbitrary [`SplitNodeBuilder`]s are generated.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct SplitNodeBuilderParams {
    /// Upper bound (inclusive) on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Bound handed to `decorator_id_strategy` when generating decorator ids.
    pub max_decorator_id_u32: u32,
}
495
#[cfg(any(test, feature = "arbitrary"))]
impl Default for SplitNodeBuilderParams {
    /// Defaults to at most 4 decorators per list and a decorator id bound of 10.
    fn default() -> Self {
        Self {
            max_decorators: 4,
            max_decorator_id_u32: 10,
        }
    }
}