Skip to main content

miden_core/mast/node/
split_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8use crate::{
9    Felt, Word,
10    chiplets::hasher,
11    mast::{
12        DecoratorId, DecoratorStore, MastForest, MastForestError, MastNode, MastNodeFingerprint,
13        MastNodeId,
14    },
15    operations::OPCODE_SPLIT,
16    prettier::PrettyPrint,
17    utils::{Idx, LookupByIdx},
18};
19
20// SPLIT NODE
21// ================================================================================================
22
23/// A Split node defines conditional execution. When the VM encounters a Split node it executes
24/// either the `on_true` child or `on_false` child.
25///
26/// Which child is executed is determined based on the top of the stack. If the value is `1`, then
27/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If
28/// the value is neither `0` nor `1`, the execution fails.
29#[derive(Debug, Clone, PartialEq, Eq)]
30#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
31#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
32pub struct SplitNode {
33    branches: [MastNodeId; 2],
34    digest: Word,
35    decorator_store: DecoratorStore,
36}
37
38/// Constants
39impl SplitNode {
40    /// The domain of the split node (used for control block hashing).
41    pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
42}
43
44/// Public accessors
45impl SplitNode {
46    /// Returns the ID of the node which is to be executed if the top of the stack is `1`.
47    pub fn on_true(&self) -> MastNodeId {
48        self.branches[0]
49    }
50
51    /// Returns the ID of the node which is to be executed if the top of the stack is `0`.
52    pub fn on_false(&self) -> MastNodeId {
53        self.branches[1]
54    }
55}
56
57// PRETTY PRINTING
58// ================================================================================================
59
60impl SplitNode {
61    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
62        SplitNodePrettyPrint { split_node: self, mast_forest }
63    }
64
65    pub(super) fn to_pretty_print<'a>(
66        &'a self,
67        mast_forest: &'a MastForest,
68    ) -> impl PrettyPrint + 'a {
69        SplitNodePrettyPrint { split_node: self, mast_forest }
70    }
71}
72
73struct SplitNodePrettyPrint<'a> {
74    split_node: &'a SplitNode,
75    mast_forest: &'a MastForest,
76}
77
78impl PrettyPrint for SplitNodePrettyPrint<'_> {
79    #[rustfmt::skip]
80    fn render(&self) -> crate::prettier::Document {
81        use crate::prettier::*;
82
83        let pre_decorators = {
84            let mut pre_decorators = self
85                .split_node
86                .before_enter(self.mast_forest)
87                .iter()
88                .map(|&decorator_id| self.mast_forest[decorator_id].render())
89                .reduce(|acc, doc| acc + const_text(" ") + doc)
90                .unwrap_or_default();
91            if !pre_decorators.is_empty() {
92                pre_decorators += nl();
93            }
94
95            pre_decorators
96        };
97
98        let post_decorators = {
99            let mut post_decorators = self
100                .split_node
101                .after_exit(self.mast_forest)
102                .iter()
103                .map(|&decorator_id| self.mast_forest[decorator_id].render())
104                .reduce(|acc, doc| acc + const_text(" ") + doc)
105                .unwrap_or_default();
106            if !post_decorators.is_empty() {
107                post_decorators = nl() + post_decorators;
108            }
109
110            post_decorators
111        };
112
113        let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
114        let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
115
116        let mut doc = pre_decorators;
117        doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
118        doc += indent(4, const_text("else") + nl() + false_branch.render());
119        doc += nl() + const_text("end");
120        doc + post_decorators
121    }
122}
123
124impl fmt::Display for SplitNodePrettyPrint<'_> {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        use crate::prettier::PrettyPrint;
127        self.pretty_print(f)
128    }
129}
130
131// MAST NODE TRAIT IMPLEMENTATION
132// ================================================================================================
133
134impl MastNodeExt for SplitNode {
135    /// Returns a commitment to this Split node.
136    ///
137    /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the
138    /// domain defined by [Self::DOMAIN] - i..e,:
139    /// ```
140    /// # use miden_core::mast::SplitNode;
141    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
142    /// # let on_true_digest = Word::default();
143    /// # let on_false_digest = Word::default();
144    /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN);
145    /// ```
146    fn digest(&self) -> Word {
147        self.digest
148    }
149
150    /// Returns the decorators to be executed before this node is executed.
151    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
152        #[cfg(debug_assertions)]
153        self.verify_node_in_forest(forest);
154        self.decorator_store.before_enter(forest)
155    }
156
157    /// Returns the decorators to be executed after this node is executed.
158    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
159        #[cfg(debug_assertions)]
160        self.verify_node_in_forest(forest);
161        self.decorator_store.after_exit(forest)
162    }
163
164    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
165        Box::new(SplitNode::to_display(self, mast_forest))
166    }
167
168    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
169        Box::new(SplitNode::to_pretty_print(self, mast_forest))
170    }
171
172    fn has_children(&self) -> bool {
173        true
174    }
175
176    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
177        target.push(self.on_true());
178        target.push(self.on_false());
179    }
180
181    fn for_each_child<F>(&self, mut f: F)
182    where
183        F: FnMut(MastNodeId),
184    {
185        f(self.on_true());
186        f(self.on_false());
187    }
188
189    fn domain(&self) -> Felt {
190        Self::DOMAIN
191    }
192
193    type Builder = SplitNodeBuilder;
194
195    fn to_builder(self, forest: &MastForest) -> Self::Builder {
196        // Extract decorators from decorator_store if in Owned state
197        match self.decorator_store {
198            DecoratorStore::Owned { before_enter, after_exit, .. } => {
199                let mut builder = SplitNodeBuilder::new(self.branches);
200                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
201                builder
202            },
203            DecoratorStore::Linked { id } => {
204                // Extract decorators from forest storage when in Linked state
205                let before_enter = forest.before_enter_decorators(id).to_vec();
206                let after_exit = forest.after_exit_decorators(id).to_vec();
207                let mut builder = SplitNodeBuilder::new(self.branches);
208                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
209                builder
210            },
211        }
212    }
213
214    #[cfg(debug_assertions)]
215    fn verify_node_in_forest(&self, forest: &MastForest) {
216        if let Some(id) = self.decorator_store.linked_id() {
217            // Verify that this node is the one stored at the given ID in the forest
218            let self_ptr = self as *const Self;
219            let forest_node = &forest.nodes[id];
220            let forest_node_ptr = match forest_node {
221                MastNode::Split(split_node) => split_node as *const SplitNode as *const (),
222                _ => panic!("Node type mismatch at {:?}", id),
223            };
224            let self_as_void = self_ptr as *const ();
225            debug_assert_eq!(
226                self_as_void, forest_node_ptr,
227                "Node pointer mismatch: expected node at {:?} to be self",
228                id
229            );
230        }
231    }
232}
233
234// ARBITRARY IMPLEMENTATION
235// ================================================================================================
236
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for SplitNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a `SplitNode` with two random child IDs and a random digest.
    ///
    /// The node is assembled directly from its fields instead of through `SplitNodeBuilder`, so
    /// no `MastForest` validation happens. The child IDs need not exist and the digest is
    /// unrelated to them. Shrinking is disabled because the values are pure noise.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let on_true = any::<MastNodeId>();
        let on_false = any::<MastNodeId>();
        let limbs = any::<[u64; 4]>();

        (on_true, on_false, limbs)
            .prop_map(|(on_true, on_false, limbs)| {
                let digest = Word::from(limbs.map(Felt::new));
                SplitNode {
                    branches: [on_true, on_false],
                    digest,
                    decorator_store: DecoratorStore::default(),
                }
            })
            .no_shrink()
            .boxed()
    }
}
264
265// ------------------------------------------------------------------------------------------------
266/// Builder for creating [`SplitNode`] instances with decorators.
267#[derive(Debug)]
268pub struct SplitNodeBuilder {
269    branches: [MastNodeId; 2],
270    before_enter: Vec<DecoratorId>,
271    after_exit: Vec<DecoratorId>,
272    digest: Option<Word>,
273}
274
275impl SplitNodeBuilder {
276    /// Creates a new builder for a SplitNode with the specified branches.
277    pub fn new(branches: [MastNodeId; 2]) -> Self {
278        Self {
279            branches,
280            before_enter: Vec::new(),
281            after_exit: Vec::new(),
282            digest: None,
283        }
284    }
285
286    /// Builds the SplitNode with the specified decorators.
287    pub fn build(self, mast_forest: &MastForest) -> Result<SplitNode, MastForestError> {
288        let forest_len = mast_forest.nodes.len();
289        if self.branches[0].to_usize() >= forest_len {
290            return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
291        } else if self.branches[1].to_usize() >= forest_len {
292            return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
293        }
294
295        // Use the forced digest if provided, otherwise compute the digest
296        let digest = if let Some(forced_digest) = self.digest {
297            forced_digest
298        } else {
299            let true_branch_hash = mast_forest[self.branches[0]].digest();
300            let false_branch_hash = mast_forest[self.branches[1]].digest();
301
302            hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
303        };
304
305        Ok(SplitNode {
306            branches: self.branches,
307            digest,
308            decorator_store: DecoratorStore::new_owned_with_decorators(
309                self.before_enter,
310                self.after_exit,
311            ),
312        })
313    }
314}
315
316impl MastForestContributor for SplitNodeBuilder {
317    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
318        // Validate branch node IDs
319        let forest_len = forest.nodes.len();
320        if self.branches[0].to_usize() >= forest_len {
321            return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
322        } else if self.branches[1].to_usize() >= forest_len {
323            return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
324        }
325
326        // Use the forced digest if provided, otherwise compute the digest
327        let digest = if let Some(forced_digest) = self.digest {
328            forced_digest
329        } else {
330            let true_branch_hash = forest[self.branches[0]].digest();
331            let false_branch_hash = forest[self.branches[1]].digest();
332
333            hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
334        };
335
336        // Determine the node ID that will be assigned
337        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
338
339        // Store node-level decorators in the centralized NodeToDecoratorIds for efficient access
340        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
341
342        // Create the node in the forest with Linked variant from the start
343        // Move the data directly without intermediate cloning
344        let node_id = forest
345            .nodes
346            .push(
347                SplitNode {
348                    branches: self.branches,
349                    digest,
350                    decorator_store: DecoratorStore::Linked { id: future_node_id },
351                }
352                .into(),
353            )
354            .map_err(|_| MastForestError::TooManyNodes)?;
355
356        Ok(node_id)
357    }
358
359    fn fingerprint_for_node(
360        &self,
361        forest: &MastForest,
362        hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
363    ) -> Result<MastNodeFingerprint, MastForestError> {
364        // Use the fingerprint_from_parts helper function
365        crate::mast::node_fingerprint::fingerprint_from_parts(
366            forest,
367            hash_by_node_id,
368            &self.before_enter,
369            &self.after_exit,
370            &self.branches,
371            // Use the forced digest if available, otherwise compute the digest
372            if let Some(forced_digest) = self.digest {
373                forced_digest
374            } else {
375                let if_branch_hash = forest[self.branches[0]].digest();
376                let else_branch_hash = forest[self.branches[1]].digest();
377
378                crate::chiplets::hasher::merge_in_domain(
379                    &[if_branch_hash, else_branch_hash],
380                    SplitNode::DOMAIN,
381                )
382            },
383        )
384    }
385
386    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
387        SplitNodeBuilder {
388            branches: [
389                *remapping.get(self.branches[0]).unwrap_or(&self.branches[0]),
390                *remapping.get(self.branches[1]).unwrap_or(&self.branches[1]),
391            ],
392            before_enter: self.before_enter,
393            after_exit: self.after_exit,
394            digest: self.digest,
395        }
396    }
397
398    fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
399        self.before_enter = decorators.into();
400        self
401    }
402
403    fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
404        self.after_exit = decorators.into();
405        self
406    }
407
408    fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
409        self.before_enter.extend(decorators);
410    }
411
412    fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
413        self.after_exit.extend(decorators);
414    }
415
416    fn with_digest(mut self, digest: crate::Word) -> Self {
417        self.digest = Some(digest);
418        self
419    }
420}
421
422impl SplitNodeBuilder {
423    /// Add this node to a forest using relaxed validation.
424    ///
425    /// This method is used during deserialization where nodes may reference child nodes
426    /// that haven't been added to the forest yet. The child node IDs have already been
427    /// validated against the expected final node count during the `try_into_mast_node_builder`
428    /// step, so we can safely skip validation here.
429    ///
430    /// Note: This is not part of the `MastForestContributor` trait because it's only
431    /// intended for internal use during deserialization.
432    pub(in crate::mast) fn add_to_forest_relaxed(
433        self,
434        forest: &mut MastForest,
435    ) -> Result<MastNodeId, MastForestError> {
436        // Use the forced digest if provided, otherwise use a default digest
437        // The actual digest computation will be handled when the forest is complete
438        let Some(digest) = self.digest else {
439            return Err(MastForestError::DigestRequiredForDeserialization);
440        };
441
442        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
443
444        // Create the node in the forest with Linked variant from the start
445        // Move the data directly without intermediate cloning
446        let node_id = forest
447            .nodes
448            .push(
449                SplitNode {
450                    branches: self.branches,
451                    digest,
452                    decorator_store: DecoratorStore::Linked { id: future_node_id },
453                }
454                .into(),
455            )
456            .map_err(|_| MastForestError::TooManyNodes)?;
457
458        Ok(node_id)
459    }
460}
461
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for SplitNodeBuilder {
    type Parameters = SplitNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder with random branch IDs and independent before-enter / after-exit
    /// decorator lists, each holding `0..=params.max_decorators` entries. No digest is forced.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Both decorator lists are drawn from the same distribution.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(branches, before_enter, after_exit)| {
                Self::new(branches).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
487
/// Parameters controlling how arbitrary [`SplitNodeBuilder`] instances are generated.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct SplitNodeBuilderParams {
    /// Inclusive upper bound on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Bound passed to `decorator_id_strategy` for generated decorator IDs (presumably
    /// inclusive — confirm against that strategy).
    pub max_decorator_id_u32: u32,
}
495
#[cfg(any(test, feature = "arbitrary"))]
impl Default for SplitNodeBuilderParams {
    /// At most 4 decorators per list, with decorator IDs bounded by 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        Self { max_decorators, max_decorator_id_u32 }
    }
}