1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::prettier::PrettyPrint;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use super::{MastForestContributor, MastNodeErrorContext, MastNodeExt};
10use crate::{
11 Idx, OPCODE_SPLIT,
12 chiplets::hasher,
13 mast::{
14 DecoratedOpLink, DecoratorId, DecoratorStore, MastForest, MastForestError, MastNode,
15 MastNodeId,
16 },
17};
18
/// A MAST node with exactly two child nodes ("branches").
///
/// The node stores the IDs of both branches, its own digest, and a [`DecoratorStore`] through
/// which its `before_enter` / `after_exit` decorators are looked up.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct SplitNode {
    /// Child node IDs: index 0 is what `on_true` returns, index 1 is what `on_false` returns.
    branches: [MastNodeId; 2],
    /// Digest of this node, as returned by `digest()`.
    digest: Word,
    /// Storage through which this node's `before_enter` and `after_exit` decorators are resolved.
    decorator_store: DecoratorStore,
}
36
/// Constants
impl SplitNode {
    /// Domain value derived from `OPCODE_SPLIT`; it is the domain argument passed to
    /// `hasher::merge_in_domain` when a split node's digest is computed from its children.
    pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
}
42
43impl SplitNode {
45 pub fn on_true(&self) -> MastNodeId {
47 self.branches[0]
48 }
49
50 pub fn on_false(&self) -> MastNodeId {
52 self.branches[1]
53 }
54}
55
56impl MastNodeErrorContext for SplitNode {
57 fn decorators<'a>(
58 &'a self,
59 forest: &'a MastForest,
60 ) -> impl Iterator<Item = DecoratedOpLink> + 'a {
61 let before_enter = self.decorator_store.before_enter(forest);
63 let after_exit = self.decorator_store.after_exit(forest);
64
65 before_enter
67 .iter()
68 .map(|&deco_id| (0, deco_id))
69 .chain(after_exit.iter().map(|&deco_id| (1, deco_id)))
70 }
71}
72
73impl SplitNode {
77 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
78 SplitNodePrettyPrint { split_node: self, mast_forest }
79 }
80
81 pub(super) fn to_pretty_print<'a>(
82 &'a self,
83 mast_forest: &'a MastForest,
84 ) -> impl PrettyPrint + 'a {
85 SplitNodePrettyPrint { split_node: self, mast_forest }
86 }
87}
88
/// Pairs a [`SplitNode`] with the [`MastForest`] used to resolve its children and decorators
/// while the node is being rendered.
struct SplitNodePrettyPrint<'a> {
    /// The node being rendered.
    split_node: &'a SplitNode,
    /// The forest indexed to look up the node's branches and decorators.
    mast_forest: &'a MastForest,
}
93
94impl PrettyPrint for SplitNodePrettyPrint<'_> {
95 #[rustfmt::skip]
96 fn render(&self) -> crate::prettier::Document {
97 use crate::prettier::*;
98
99 let pre_decorators = {
100 let mut pre_decorators = self
101 .split_node
102 .before_enter(self.mast_forest)
103 .iter()
104 .map(|&decorator_id| self.mast_forest[decorator_id].render())
105 .reduce(|acc, doc| acc + const_text(" ") + doc)
106 .unwrap_or_default();
107 if !pre_decorators.is_empty() {
108 pre_decorators += nl();
109 }
110
111 pre_decorators
112 };
113
114 let post_decorators = {
115 let mut post_decorators = self
116 .split_node
117 .after_exit(self.mast_forest)
118 .iter()
119 .map(|&decorator_id| self.mast_forest[decorator_id].render())
120 .reduce(|acc, doc| acc + const_text(" ") + doc)
121 .unwrap_or_default();
122 if !post_decorators.is_empty() {
123 post_decorators = nl() + post_decorators;
124 }
125
126 post_decorators
127 };
128
129 let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
130 let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
131
132 let mut doc = pre_decorators;
133 doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
134 doc += indent(4, const_text("else") + nl() + false_branch.render());
135 doc += nl() + const_text("end");
136 doc + post_decorators
137 }
138}
139
140impl fmt::Display for SplitNodePrettyPrint<'_> {
141 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142 use crate::prettier::PrettyPrint;
143 self.pretty_print(f)
144 }
145}
146
147impl MastNodeExt for SplitNode {
151 fn digest(&self) -> Word {
163 self.digest
164 }
165
166 fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
168 #[cfg(debug_assertions)]
169 self.verify_node_in_forest(forest);
170 self.decorator_store.before_enter(forest)
171 }
172
173 fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
175 #[cfg(debug_assertions)]
176 self.verify_node_in_forest(forest);
177 self.decorator_store.after_exit(forest)
178 }
179
180 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
181 Box::new(SplitNode::to_display(self, mast_forest))
182 }
183
184 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
185 Box::new(SplitNode::to_pretty_print(self, mast_forest))
186 }
187
188 fn has_children(&self) -> bool {
189 true
190 }
191
192 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
193 target.push(self.on_true());
194 target.push(self.on_false());
195 }
196
197 fn for_each_child<F>(&self, mut f: F)
198 where
199 F: FnMut(MastNodeId),
200 {
201 f(self.on_true());
202 f(self.on_false());
203 }
204
205 fn domain(&self) -> Felt {
206 Self::DOMAIN
207 }
208
209 type Builder = SplitNodeBuilder;
210
211 fn to_builder(self, forest: &MastForest) -> Self::Builder {
212 match self.decorator_store {
214 DecoratorStore::Owned { before_enter, after_exit, .. } => {
215 let mut builder = SplitNodeBuilder::new(self.branches);
216 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
217 builder
218 },
219 DecoratorStore::Linked { id } => {
220 let before_enter = forest.before_enter_decorators(id).to_vec();
222 let after_exit = forest.after_exit_decorators(id).to_vec();
223 let mut builder = SplitNodeBuilder::new(self.branches);
224 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
225 builder
226 },
227 }
228 }
229
230 #[cfg(debug_assertions)]
231 fn verify_node_in_forest(&self, forest: &MastForest) {
232 if let Some(id) = self.decorator_store.linked_id() {
233 let self_ptr = self as *const Self;
235 let forest_node = &forest.nodes[id];
236 let forest_node_ptr = match forest_node {
237 MastNode::Split(split_node) => split_node as *const SplitNode as *const (),
238 _ => panic!("Node type mismatch at {:?}", id),
239 };
240 let self_as_void = self_ptr as *const ();
241 debug_assert_eq!(
242 self_as_void, forest_node_ptr,
243 "Node pointer mismatch: expected node at {:?} to be self",
244 id
245 );
246 }
247 }
248}
249
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for SplitNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a `SplitNode` from two arbitrary branch IDs and an arbitrary digest built from
    /// four `u64` values. The decorator store is the default one, and the digest is NOT derived
    /// from the branches. Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        use crate::Felt;

        let parts = (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>());

        parts
            .prop_map(|(on_true, on_false, raw_digest)| SplitNode {
                branches: [on_true, on_false],
                digest: Word::from(raw_digest.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
280
/// Builder for [`SplitNode`].
///
/// Collects the two branch IDs, the decorators to attach, and an optional forced digest before
/// the node is either built standalone (`build`) or inserted into a forest (`add_to_forest`).
#[derive(Debug)]
pub struct SplitNodeBuilder {
    /// Child node IDs; index 0 becomes the node's `on_true` branch, index 1 its `on_false`.
    branches: [MastNodeId; 2],
    /// Decorator IDs to attach as the node's `before_enter` decorators.
    before_enter: Vec<DecoratorId>,
    /// Decorator IDs to attach as the node's `after_exit` decorators.
    after_exit: Vec<DecoratorId>,
    /// When `Some`, used as the node's digest instead of computing it from the branches.
    digest: Option<Word>,
}
290
291impl SplitNodeBuilder {
292 pub fn new(branches: [MastNodeId; 2]) -> Self {
294 Self {
295 branches,
296 before_enter: Vec::new(),
297 after_exit: Vec::new(),
298 digest: None,
299 }
300 }
301
302 pub fn build(self, mast_forest: &MastForest) -> Result<SplitNode, MastForestError> {
304 let forest_len = mast_forest.nodes.len();
305 if self.branches[0].to_usize() >= forest_len {
306 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
307 } else if self.branches[1].to_usize() >= forest_len {
308 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
309 }
310
311 let digest = if let Some(forced_digest) = self.digest {
313 forced_digest
314 } else {
315 let true_branch_hash = mast_forest[self.branches[0]].digest();
316 let false_branch_hash = mast_forest[self.branches[1]].digest();
317
318 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
319 };
320
321 Ok(SplitNode {
322 branches: self.branches,
323 digest,
324 decorator_store: DecoratorStore::new_owned_with_decorators(
325 self.before_enter,
326 self.after_exit,
327 ),
328 })
329 }
330}
331
332impl MastForestContributor for SplitNodeBuilder {
333 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
334 let forest_len = forest.nodes.len();
336 if self.branches[0].to_usize() >= forest_len {
337 return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
338 } else if self.branches[1].to_usize() >= forest_len {
339 return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
340 }
341
342 let digest = if let Some(forced_digest) = self.digest {
344 forced_digest
345 } else {
346 let true_branch_hash = forest[self.branches[0]].digest();
347 let false_branch_hash = forest[self.branches[1]].digest();
348
349 hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
350 };
351
352 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
354
355 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
357
358 let node_id = forest
361 .nodes
362 .push(
363 SplitNode {
364 branches: self.branches,
365 digest,
366 decorator_store: DecoratorStore::Linked { id: future_node_id },
367 }
368 .into(),
369 )
370 .map_err(|_| MastForestError::TooManyNodes)?;
371
372 Ok(node_id)
373 }
374
375 fn fingerprint_for_node(
376 &self,
377 forest: &MastForest,
378 hash_by_node_id: &impl crate::LookupByIdx<MastNodeId, crate::mast::MastNodeFingerprint>,
379 ) -> Result<crate::mast::MastNodeFingerprint, MastForestError> {
380 crate::mast::node_fingerprint::fingerprint_from_parts(
382 forest,
383 hash_by_node_id,
384 &self.before_enter,
385 &self.after_exit,
386 &self.branches,
387 if let Some(forced_digest) = self.digest {
389 forced_digest
390 } else {
391 let if_branch_hash = forest[self.branches[0]].digest();
392 let else_branch_hash = forest[self.branches[1]].digest();
393
394 crate::chiplets::hasher::merge_in_domain(
395 &[if_branch_hash, else_branch_hash],
396 SplitNode::DOMAIN,
397 )
398 },
399 )
400 }
401
402 fn remap_children(
403 self,
404 remapping: &impl crate::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeId>,
405 ) -> Self {
406 SplitNodeBuilder {
407 branches: [
408 *remapping.get(self.branches[0]).unwrap_or(&self.branches[0]),
409 *remapping.get(self.branches[1]).unwrap_or(&self.branches[1]),
410 ],
411 before_enter: self.before_enter,
412 after_exit: self.after_exit,
413 digest: self.digest,
414 }
415 }
416
417 fn with_before_enter(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
418 self.before_enter = decorators.into();
419 self
420 }
421
422 fn with_after_exit(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
423 self.after_exit = decorators.into();
424 self
425 }
426
427 fn append_before_enter(
428 &mut self,
429 decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
430 ) {
431 self.before_enter.extend(decorators);
432 }
433
434 fn append_after_exit(
435 &mut self,
436 decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
437 ) {
438 self.after_exit.extend(decorators);
439 }
440
441 fn with_digest(mut self, digest: crate::Word) -> Self {
442 self.digest = Some(digest);
443 self
444 }
445}
446
447impl SplitNodeBuilder {
448 pub(in crate::mast) fn add_to_forest_relaxed(
458 self,
459 forest: &mut MastForest,
460 ) -> Result<MastNodeId, MastForestError> {
461 let Some(digest) = self.digest else {
464 return Err(MastForestError::DigestRequiredForDeserialization);
465 };
466
467 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
468
469 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
471
472 let node_id = forest
475 .nodes
476 .push(
477 SplitNode {
478 branches: self.branches,
479 digest,
480 decorator_store: DecoratorStore::Linked { id: future_node_id },
481 }
482 .into(),
483 )
484 .map_err(|_| MastForestError::TooManyNodes)?;
485
486 Ok(node_id)
487 }
488}
489
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for SplitNodeBuilder {
    type Parameters = SplitNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder with two arbitrary branch IDs and two independently generated
    /// decorator lists (`before_enter`, `after_exit`), each holding between 0 and
    /// `params.max_decorators` IDs drawn from `decorator_id_strategy(params.max_decorator_id_u32)`.
    /// No digest is forced.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Both decorator lists are drawn from the same kind of strategy.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        let parts = (any::<[crate::mast::MastNodeId; 2]>(), decorator_list(), decorator_list());

        parts
            .prop_map(|(branches, entering, exiting)| {
                Self::new(branches).with_before_enter(entering).with_after_exit(exiting)
            })
            .boxed()
    }
}
515
/// Parameters controlling how arbitrary [`SplitNodeBuilder`] values are generated.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct SplitNodeBuilderParams {
    /// Inclusive upper bound on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Value passed to `decorator_id_strategy` when generating decorator IDs
    /// (presumably the maximum decorator ID — TODO confirm against that helper).
    pub max_decorator_id_u32: u32,
}
523
#[cfg(any(test, feature = "arbitrary"))]
impl Default for SplitNodeBuilderParams {
    /// Defaults to at most 4 decorators per list and a `max_decorator_id_u32` of 10.
    fn default() -> Self {
        let (max_decorators, max_decorator_id_u32) = (4, 10);

        Self { max_decorators, max_decorator_id_u32 }
    }
}