Skip to main content

miden_core/mast/node/
split_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt, fingerprint_with_child_fingerprints};
8use crate::{
9    Felt, Word,
10    chiplets::hasher,
11    mast::{MastForest, MastForestError, MastNodeId},
12    operations::opcodes,
13    prettier::PrettyPrint,
14    utils::{Idx, LookupByIdx},
15};
16
17// SPLIT NODE
18// ================================================================================================
19
/// A Split node defines conditional execution. When the VM encounters a Split node it executes
/// either the `on_true` child or `on_false` child.
///
/// Which child is executed is determined based on the top of the stack. If the value is `1`, then
/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If
/// the value is neither `0` nor `1`, the execution fails.
///
/// The node stores only the IDs of its two children (resolved against the owning [`MastForest`])
/// together with its digest; see [`SplitNodeBuilder`] for how that digest is obtained.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct SplitNode {
    // Child IDs in `[on_true, on_false]` order (see the `on_true` / `on_false` accessors).
    branches: [MastNodeId; 2],
    // Commitment to this node: either computed from the children's digests or supplied
    // explicitly through `SplitNodeBuilder::with_digest`.
    digest: Word,
}
33
/// Constants
impl SplitNode {
    /// The domain of the split node (used for control block hashing).
    ///
    /// Derived from the `SPLIT` opcode. It is the domain argument passed to
    /// `hasher::merge_in_domain` when a Split node's digest is computed from its children.
    pub const DOMAIN: Felt = Felt::new_unchecked(opcodes::SPLIT as u64);
}
39
40/// Public accessors
41impl SplitNode {
42    /// Returns the ID of the node which is to be executed if the top of the stack is `1`.
43    pub fn on_true(&self) -> MastNodeId {
44        self.branches[0]
45    }
46
47    /// Returns the ID of the node which is to be executed if the top of the stack is `0`.
48    pub fn on_false(&self) -> MastNodeId {
49        self.branches[1]
50    }
51}
52
53// PRETTY PRINTING
54// ================================================================================================
55
56impl SplitNode {
57    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
58        SplitNodePrettyPrint { split_node: self, mast_forest }
59    }
60
61    pub(super) fn to_pretty_print<'a>(
62        &'a self,
63        mast_forest: &'a MastForest,
64    ) -> impl PrettyPrint + 'a {
65        SplitNodePrettyPrint { split_node: self, mast_forest }
66    }
67}
68
69struct SplitNodePrettyPrint<'a> {
70    split_node: &'a SplitNode,
71    mast_forest: &'a MastForest,
72}
73
74impl PrettyPrint for SplitNodePrettyPrint<'_> {
75    #[rustfmt::skip]
76    fn render(&self) -> crate::prettier::Document {
77        use crate::prettier::*;
78
79        let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
80        let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
81
82        let mut doc = Document::Empty;
83        doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
84        doc += indent(4, const_text("else") + nl() + false_branch.render());
85        doc += nl() + const_text("end");
86        doc
87    }
88}
89
90impl fmt::Display for SplitNodePrettyPrint<'_> {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        use crate::prettier::PrettyPrint;
93        self.pretty_print(f)
94    }
95}
96
97// MAST NODE TRAIT IMPLEMENTATION
98// ================================================================================================
99
100impl MastNodeExt for SplitNode {
101    /// Returns a commitment to this Split node.
102    ///
103    /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the
104    /// domain defined by [Self::DOMAIN] - i..e,:
105    /// ```
106    /// # use miden_core::mast::SplitNode;
107    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
108    /// # let on_true_digest = Word::default();
109    /// # let on_false_digest = Word::default();
110    /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN);
111    /// ```
112    fn digest(&self) -> Word {
113        self.digest
114    }
115
116    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
117        Box::new(SplitNode::to_display(self, mast_forest))
118    }
119
120    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
121        Box::new(SplitNode::to_pretty_print(self, mast_forest))
122    }
123
124    fn has_children(&self) -> bool {
125        true
126    }
127
128    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
129        target.push(self.on_true());
130        target.push(self.on_false());
131    }
132
133    fn for_each_child<F>(&self, mut f: F)
134    where
135        F: FnMut(MastNodeId),
136    {
137        f(self.on_true());
138        f(self.on_false());
139    }
140
141    fn domain(&self) -> Felt {
142        Self::DOMAIN
143    }
144
145    type Builder = SplitNodeBuilder;
146
147    fn to_builder(self, _forest: &MastForest) -> Self::Builder {
148        SplitNodeBuilder::new(self.branches).with_digest(self.digest)
149    }
150}
151
152// ARBITRARY IMPLEMENTATION
153// ================================================================================================
154
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for SplitNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a `SplitNode` from two arbitrary child IDs and an arbitrary digest.
    ///
    /// The node is assembled field-by-field rather than through [`SplitNodeBuilder`], so the
    /// child IDs are not checked against any forest and the digest is unrelated to them.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(on_true, on_false, raw_digest)| SplitNode {
                branches: [on_true, on_false],
                digest: Word::from(raw_digest.map(Felt::new_unchecked)),
            })
            // The values are purely random, so shrinking has no meaningful direction to explore.
            .no_shrink()
            .boxed()
    }
}
181
182// ------------------------------------------------------------------------------------------------
/// Builder for creating [`SplitNode`] instances.
///
/// The digest is optional. When one is set via `with_digest` it is used as-is. Otherwise `build`
/// and `add_to_forest` compute it from the children's digests. `build_linked` and
/// `add_to_forest_relaxed` instead require it to be set.
#[derive(Debug)]
pub struct SplitNodeBuilder {
    // Child IDs in `[on_true, on_false]` order.
    branches: [MastNodeId; 2],
    // Explicitly supplied ("forced") digest, if any.
    digest: Option<Word>,
}
189
190impl SplitNodeBuilder {
191    /// Creates a new builder for a SplitNode with the specified branches.
192    pub fn new(branches: [MastNodeId; 2]) -> Self {
193        Self { branches, digest: None }
194    }
195
196    /// Builds the SplitNode.
197    pub fn build(self, mast_forest: &MastForest) -> Result<SplitNode, MastForestError> {
198        let forest_len = mast_forest.nodes.len();
199        if self.branches[0].to_usize() >= forest_len {
200            return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
201        } else if self.branches[1].to_usize() >= forest_len {
202            return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
203        }
204
205        // Use the forced digest if provided, otherwise compute the digest
206        let digest = if let Some(forced_digest) = self.digest {
207            forced_digest
208        } else {
209            let true_branch_hash = mast_forest[self.branches[0]].digest();
210            let false_branch_hash = mast_forest[self.branches[1]].digest();
211
212            hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
213        };
214
215        Ok(SplitNode { branches: self.branches, digest })
216    }
217
218    pub(in crate::mast) fn build_linked(self) -> Result<SplitNode, MastForestError> {
219        Ok(SplitNode {
220            branches: self.branches,
221            digest: self.digest.ok_or(MastForestError::DigestRequiredForDeserialization)?,
222        })
223    }
224}
225
226impl MastForestContributor for SplitNodeBuilder {
227    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
228        // Validate branch node IDs
229        let forest_len = forest.nodes.len();
230        if self.branches[0].to_usize() >= forest_len {
231            return Err(MastForestError::NodeIdOverflow(self.branches[0], forest_len));
232        } else if self.branches[1].to_usize() >= forest_len {
233            return Err(MastForestError::NodeIdOverflow(self.branches[1], forest_len));
234        }
235
236        // Use the forced digest if provided, otherwise compute the digest
237        let digest = if let Some(forced_digest) = self.digest {
238            forced_digest
239        } else {
240            let true_branch_hash = forest[self.branches[0]].digest();
241            let false_branch_hash = forest[self.branches[1]].digest();
242
243            hasher::merge_in_domain(&[true_branch_hash, false_branch_hash], SplitNode::DOMAIN)
244        };
245
246        // Create the node in the forest with Linked variant from the start
247        // Move the data directly without intermediate cloning
248        let node_id = forest
249            .nodes
250            .push(SplitNode { branches: self.branches, digest }.into())
251            .map_err(|_| MastForestError::TooManyNodes)?;
252
253        Ok(node_id)
254    }
255
256    fn fingerprint_for_node(
257        &self,
258        forest: &MastForest,
259        hash_by_node_id: &impl LookupByIdx<MastNodeId, Word>,
260    ) -> Result<Word, MastForestError> {
261        let node_digest = if let Some(forced_digest) = self.digest {
262            forced_digest
263        } else {
264            let if_branch_hash = forest[self.branches[0]].digest();
265            let else_branch_hash = forest[self.branches[1]].digest();
266
267            hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], SplitNode::DOMAIN)
268        };
269
270        fingerprint_with_child_fingerprints(node_digest, &self.branches, forest, hash_by_node_id)
271    }
272
273    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
274        SplitNodeBuilder {
275            branches: [
276                *remapping.get(self.branches[0]).unwrap_or(&self.branches[0]),
277                *remapping.get(self.branches[1]).unwrap_or(&self.branches[1]),
278            ],
279            digest: self.digest,
280        }
281    }
282
283    fn with_digest(mut self, digest: Word) -> Self {
284        self.digest = Some(digest);
285        self
286    }
287}
288
289impl SplitNodeBuilder {
290    /// Add this node to a forest using relaxed validation.
291    ///
292    /// This method is used during deserialization where nodes may reference child nodes
293    /// that haven't been added to the forest yet. The child node IDs have already been
294    /// validated against the expected final node count during the `try_into_mast_node_builder`
295    /// step, so we can safely skip validation here.
296    ///
297    /// Note: This is not part of the `MastForestContributor` trait because it's only
298    /// intended for internal use during deserialization.
299    pub(in crate::mast) fn add_to_forest_relaxed(
300        self,
301        forest: &mut MastForest,
302    ) -> Result<MastNodeId, MastForestError> {
303        // Use the forced digest if provided, otherwise use a default digest
304        // The actual digest computation will be handled when the forest is complete
305        let Some(digest) = self.digest else {
306            return Err(MastForestError::DigestRequiredForDeserialization);
307        };
308
309        // Create the node in the forest with Linked variant from the start
310        // Move the data directly without intermediate cloning
311        let node_id = forest
312            .nodes
313            .push(SplitNode { branches: self.branches, digest }.into())
314            .map_err(|_| MastForestError::TooManyNodes)?;
315
316        Ok(node_id)
317    }
318}
319
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for SplitNodeBuilder {
    type Parameters = SplitNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder over two arbitrary branch IDs with no forced digest.
    ///
    /// `SplitNodeBuilderParams` carries no settings, so the parameters are ignored.
    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        any::<[MastNodeId; 2]>()
            .prop_map(SplitNodeBuilder::new)
            .boxed()
    }
}
332
/// Parameters for generating SplitNodeBuilder instances
///
/// Currently empty: the `Arbitrary` implementation for `SplitNodeBuilder` does not read any
/// settings from it.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug, Default)]
pub struct SplitNodeBuilderParams {}