Skip to main content

miden_core/mast/node/
split_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8#[cfg(debug_assertions)]
9use crate::mast::MastNode;
10use crate::{
11    Felt, Word,
12    chiplets::hasher,
13    mast::{
14        DecoratorId, DecoratorStore, MastForest, MastForestError, MastNodeFingerprint, MastNodeId,
15    },
16    operations::opcodes,
17    prettier::PrettyPrint,
18    utils::{Idx, LookupByIdx},
19};
20
21// SPLIT NODE
22// ================================================================================================
23
24/// A Split node defines conditional execution. When the VM encounters a Split node it executes
25/// either the `on_true` child or `on_false` child.
26///
27/// Which child is executed is determined based on the top of the stack. If the value is `1`, then
28/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If
29/// the value is neither `0` nor `1`, the execution fails.
30#[derive(Debug, Clone, PartialEq, Eq)]
31#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
32#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
33pub struct SplitNode {
34    branches: [MastNodeId; 2],
35    digest: Word,
36    decorator_store: DecoratorStore,
37}
38
39/// Constants
40impl SplitNode {
41    /// The domain of the split node (used for control block hashing).
42    pub const DOMAIN: Felt = Felt::new(opcodes::SPLIT as u64);
43}
44
45/// Public accessors
46impl SplitNode {
47    /// Returns the ID of the node which is to be executed if the top of the stack is `1`.
48    pub fn on_true(&self) -> MastNodeId {
49        self.branches[0]
50    }
51
52    /// Returns the ID of the node which is to be executed if the top of the stack is `0`.
53    pub fn on_false(&self) -> MastNodeId {
54        self.branches[1]
55    }
56}
57
58// PRETTY PRINTING
59// ================================================================================================
60
61impl SplitNode {
62    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
63        SplitNodePrettyPrint { split_node: self, mast_forest }
64    }
65
66    pub(super) fn to_pretty_print<'a>(
67        &'a self,
68        mast_forest: &'a MastForest,
69    ) -> impl PrettyPrint + 'a {
70        SplitNodePrettyPrint { split_node: self, mast_forest }
71    }
72}
73
74struct SplitNodePrettyPrint<'a> {
75    split_node: &'a SplitNode,
76    mast_forest: &'a MastForest,
77}
78
79impl PrettyPrint for SplitNodePrettyPrint<'_> {
80    #[rustfmt::skip]
81    fn render(&self) -> crate::prettier::Document {
82        use crate::prettier::*;
83
84        let pre_decorators = {
85            let mut pre_decorators = self
86                .split_node
87                .before_enter(self.mast_forest)
88                .iter()
89                .map(|&decorator_id| self.mast_forest[decorator_id].render())
90                .reduce(|acc, doc| acc + const_text(" ") + doc)
91                .unwrap_or_default();
92            if !pre_decorators.is_empty() {
93                pre_decorators += nl();
94            }
95
96            pre_decorators
97        };
98
99        let post_decorators = {
100            let mut post_decorators = self
101                .split_node
102                .after_exit(self.mast_forest)
103                .iter()
104                .map(|&decorator_id| self.mast_forest[decorator_id].render())
105                .reduce(|acc, doc| acc + const_text(" ") + doc)
106                .unwrap_or_default();
107            if !post_decorators.is_empty() {
108                post_decorators = nl() + post_decorators;
109            }
110
111            post_decorators
112        };
113
114        let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
115        let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
116
117        let mut doc = pre_decorators;
118        doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
119        doc += indent(4, const_text("else") + nl() + false_branch.render());
120        doc += nl() + const_text("end");
121        doc + post_decorators
122    }
123}
124
125impl fmt::Display for SplitNodePrettyPrint<'_> {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        use crate::prettier::PrettyPrint;
128        self.pretty_print(f)
129    }
130}
131
132// MAST NODE TRAIT IMPLEMENTATION
133// ================================================================================================
134
135impl MastNodeExt for SplitNode {
136    /// Returns a commitment to this Split node.
137    ///
138    /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the
139    /// domain defined by [Self::DOMAIN] - i..e,:
140    /// ```
141    /// # use miden_core::mast::SplitNode;
142    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
143    /// # let on_true_digest = Word::default();
144    /// # let on_false_digest = Word::default();
145    /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN);
146    /// ```
147    fn digest(&self) -> Word {
148        self.digest
149    }
150
151    /// Returns the decorators to be executed before this node is executed.
152    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
153        #[cfg(debug_assertions)]
154        self.verify_node_in_forest(forest);
155        self.decorator_store.before_enter(forest)
156    }
157
158    /// Returns the decorators to be executed after this node is executed.
159    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
160        #[cfg(debug_assertions)]
161        self.verify_node_in_forest(forest);
162        self.decorator_store.after_exit(forest)
163    }
164
165    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
166        Box::new(SplitNode::to_display(self, mast_forest))
167    }
168
169    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
170        Box::new(SplitNode::to_pretty_print(self, mast_forest))
171    }
172
173    fn has_children(&self) -> bool {
174        true
175    }
176
177    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
178        target.push(self.on_true());
179        target.push(self.on_false());
180    }
181
182    fn for_each_child<F>(&self, mut f: F)
183    where
184        F: FnMut(MastNodeId),
185    {
186        f(self.on_true());
187        f(self.on_false());
188    }
189
190    fn domain(&self) -> Felt {
191        Self::DOMAIN
192    }
193
194    type Builder = SplitNodeBuilder;
195
196    fn to_builder(self, forest: &MastForest) -> Self::Builder {
197        // Extract decorators from decorator_store if in Owned state
198        match self.decorator_store {
199            DecoratorStore::Owned { before_enter, after_exit, .. } => {
200                let mut builder = SplitNodeBuilder::new(self.branches);
201                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
202                builder
203            },
204            DecoratorStore::Linked { id } => {
205                // Extract decorators from forest storage when in Linked state
206                let before_enter = forest.before_enter_decorators(id).to_vec();
207                let after_exit = forest.after_exit_decorators(id).to_vec();
208                let mut builder = SplitNodeBuilder::new(self.branches);
209                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
210                builder
211            },
212        }
213    }
214
215    #[cfg(debug_assertions)]
216    fn verify_node_in_forest(&self, forest: &MastForest) {
217        if let Some(id) = self.decorator_store.linked_id() {
218            // Verify that this node is the one stored at the given ID in the forest
219            let self_ptr = self as *const Self;
220            let forest_node = &forest.nodes[id];
221            let forest_node_ptr = match forest_node {
222                MastNode::Split(split_node) => split_node as *const SplitNode as *const (),
223                _ => panic!("Node type mismatch at {:?}", id),
224            };
225            let self_as_void = self_ptr as *const ();
226            debug_assert_eq!(
227                self_as_void, forest_node_ptr,
228                "Node pointer mismatch: expected node at {:?} to be self",
229                id
230            );
231        }
232    }
233}
234
235// ARBITRARY IMPLEMENTATION
236// ================================================================================================
237
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for SplitNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Produces split nodes with arbitrary branch IDs and a random digest.
    ///
    /// The node is assembled directly rather than through `SplitNodeBuilder`, so no forest is
    /// needed. The branch IDs and digest are not required to be consistent with each other.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let on_true_ids = any::<MastNodeId>();
        let on_false_ids = any::<MastNodeId>();
        let raw_digests = any::<[u64; 4]>();

        (on_true_ids, on_false_ids, raw_digests)
            .prop_map(|(on_true, on_false, limbs)| SplitNode {
                branches: [on_true, on_false],
                digest: Word::from(limbs.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            // Inputs are uniformly random; shrinking would not yield simpler failing cases.
            .no_shrink()
            .boxed()
    }
}
265
266// ------------------------------------------------------------------------------------------------
267/// Builder for creating [`SplitNode`] instances with decorators.
268#[derive(Debug)]
269pub struct SplitNodeBuilder {
270    branches: [MastNodeId; 2],
271    before_enter: Vec<DecoratorId>,
272    after_exit: Vec<DecoratorId>,
273    digest: Option<Word>,
274}
275
276impl SplitNodeBuilder {
277    /// Creates a new builder for a SplitNode with the specified branches.
278    pub fn new(branches: [MastNodeId; 2]) -> Self {
279        Self {
280            branches,
281            before_enter: Vec::new(),
282            after_exit: Vec::new(),
283            digest: None,
284        }
285    }
286
287    /// Builds the SplitNode with the specified decorators.
288    pub fn build(self, mast_forest: &MastForest) -> Result<SplitNode, MastForestError> {
289        let forest_len = mast_forest.nodes.len();
290        if self.branches[0].to_usize() >= forest_len {
291            return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
292        } else if self.branches[1].to_usize() >= forest_len {
293            return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
294        }
295
296        // Use the forced digest if provided, otherwise compute the digest
297        let digest = if let Some(forced_digest) = self.digest {
298            forced_digest
299        } else {
300            let true_branch_hash = mast_forest[self.branches[0]].digest();
301            let false_branch_hash = mast_forest[self.branches[1]].digest();
302
303            hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
304        };
305
306        Ok(SplitNode {
307            branches: self.branches,
308            digest,
309            decorator_store: DecoratorStore::new_owned_with_decorators(
310                self.before_enter,
311                self.after_exit,
312            ),
313        })
314    }
315}
316
317impl MastForestContributor for SplitNodeBuilder {
318    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
319        // Validate branch node IDs
320        let forest_len = forest.nodes.len();
321        if self.branches[0].to_usize() >= forest_len {
322            return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
323        } else if self.branches[1].to_usize() >= forest_len {
324            return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
325        }
326
327        // Use the forced digest if provided, otherwise compute the digest
328        let digest = if let Some(forced_digest) = self.digest {
329            forced_digest
330        } else {
331            let true_branch_hash = forest[self.branches[0]].digest();
332            let false_branch_hash = forest[self.branches[1]].digest();
333
334            hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
335        };
336
337        // Determine the node ID that will be assigned
338        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
339
340        // Store node-level decorators in the centralized NodeToDecoratorIds for efficient access
341        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
342
343        // Create the node in the forest with Linked variant from the start
344        // Move the data directly without intermediate cloning
345        let node_id = forest
346            .nodes
347            .push(
348                SplitNode {
349                    branches: self.branches,
350                    digest,
351                    decorator_store: DecoratorStore::Linked { id: future_node_id },
352                }
353                .into(),
354            )
355            .map_err(|_| MastForestError::TooManyNodes)?;
356
357        Ok(node_id)
358    }
359
360    fn fingerprint_for_node(
361        &self,
362        forest: &MastForest,
363        hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
364    ) -> Result<MastNodeFingerprint, MastForestError> {
365        // Use the fingerprint_from_parts helper function
366        crate::mast::node_fingerprint::fingerprint_from_parts(
367            forest,
368            hash_by_node_id,
369            &self.before_enter,
370            &self.after_exit,
371            &self.branches,
372            // Use the forced digest if available, otherwise compute the digest
373            if let Some(forced_digest) = self.digest {
374                forced_digest
375            } else {
376                let if_branch_hash = forest[self.branches[0]].digest();
377                let else_branch_hash = forest[self.branches[1]].digest();
378
379                crate::chiplets::hasher::merge_in_domain(
380                    &[if_branch_hash, else_branch_hash],
381                    SplitNode::DOMAIN,
382                )
383            },
384        )
385    }
386
387    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
388        SplitNodeBuilder {
389            branches: [
390                *remapping.get(self.branches[0]).unwrap_or(&self.branches[0]),
391                *remapping.get(self.branches[1]).unwrap_or(&self.branches[1]),
392            ],
393            before_enter: self.before_enter,
394            after_exit: self.after_exit,
395            digest: self.digest,
396        }
397    }
398
399    fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
400        self.before_enter = decorators.into();
401        self
402    }
403
404    fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
405        self.after_exit = decorators.into();
406        self
407    }
408
409    fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
410        self.before_enter.extend(decorators);
411    }
412
413    fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
414        self.after_exit.extend(decorators);
415    }
416
417    fn with_digest(mut self, digest: crate::Word) -> Self {
418        self.digest = Some(digest);
419        self
420    }
421}
422
423impl SplitNodeBuilder {
424    /// Add this node to a forest using relaxed validation.
425    ///
426    /// This method is used during deserialization where nodes may reference child nodes
427    /// that haven't been added to the forest yet. The child node IDs have already been
428    /// validated against the expected final node count during the `try_into_mast_node_builder`
429    /// step, so we can safely skip validation here.
430    ///
431    /// Note: This is not part of the `MastForestContributor` trait because it's only
432    /// intended for internal use during deserialization.
433    pub(in crate::mast) fn add_to_forest_relaxed(
434        self,
435        forest: &mut MastForest,
436    ) -> Result<MastNodeId, MastForestError> {
437        // Use the forced digest if provided, otherwise use a default digest
438        // The actual digest computation will be handled when the forest is complete
439        let Some(digest) = self.digest else {
440            return Err(MastForestError::DigestRequiredForDeserialization);
441        };
442
443        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
444
445        // Create the node in the forest with Linked variant from the start
446        // Move the data directly without intermediate cloning
447        let node_id = forest
448            .nodes
449            .push(
450                SplitNode {
451                    branches: self.branches,
452                    digest,
453                    decorator_store: DecoratorStore::Linked { id: future_node_id },
454                }
455                .into(),
456            )
457            .map_err(|_| MastForestError::TooManyNodes)?;
458
459        Ok(node_id)
460    }
461}
462
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for SplitNodeBuilder {
    type Parameters = SplitNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with arbitrary branch IDs and independently generated `before_enter`
    /// and `after_exit` decorator lists.
    ///
    /// Each list holds between `0` and `params.max_decorators` entries (inclusive). No digest
    /// is forced.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Both decorator lists have the same shape, so build each one from the same recipe.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(branches, before_enter, after_exit)| {
                Self::new(branches).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
488
/// Parameters for generating SplitNodeBuilder instances
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct SplitNodeBuilderParams {
    /// Inclusive upper bound on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Bound passed to the decorator-ID strategy when generating decorator IDs.
    pub max_decorator_id_u32: u32,
}

#[cfg(any(test, feature = "arbitrary"))]
impl Default for SplitNodeBuilderParams {
    /// Allows up to 4 decorators per list, with decorator IDs generated under a bound of 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        Self { max_decorators, max_decorator_id_u32 }
    }
}