Skip to main content

miden_core/mast/node/
join_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8#[cfg(debug_assertions)]
9use crate::mast::MastNode;
10use crate::{
11    Felt, Word,
12    chiplets::hasher,
13    mast::{
14        DecoratorId, DecoratorStore, MastForest, MastForestError, MastNodeFingerprint, MastNodeId,
15    },
16    operations::opcodes,
17    prettier::PrettyPrint,
18    utils::{Idx, LookupByIdx},
19};
20
// JOIN NODE
// ================================================================================================

/// A Join node describes sequential execution. When the VM encounters a Join node, it executes
/// the first child first and the second child second.
///
/// Note that the derived `PartialEq` compares `decorator_store` structurally; use
/// `semantic_eq` (test builds) to compare the decorators a node actually resolves to.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct JoinNode {
    /// IDs of the child nodes, in execution order: `[first, second]`.
    children: [MastNodeId; 2],
    /// Commitment to this node: the hash of the children's digests in [`Self::DOMAIN`], unless a
    /// digest was forced through the builder.
    digest: Word,
    /// Decorators attached to this node, either owned directly by the node or linked to the
    /// forest's centralized decorator storage.
    decorator_store: DecoratorStore,
}
34
/// Constants
impl JoinNode {
    /// The domain of the join block (used for control block hashing).
    ///
    /// Derived from the `JOIN` opcode and passed to `hasher::merge_in_domain` when a join node's
    /// digest is computed from its children's digests.
    pub const DOMAIN: Felt = Felt::new_unchecked(opcodes::JOIN as u64);
}
40
41/// Public accessors
42impl JoinNode {
43    /// Returns the ID of the node that is to be executed first.
44    pub fn first(&self) -> MastNodeId {
45        self.children[0]
46    }
47
48    /// Returns the ID of the node that is to be executed after the execution of the program
49    /// defined by the first node completes.
50    pub fn second(&self) -> MastNodeId {
51        self.children[1]
52    }
53}
54
55// PRETTY PRINTING
56// ================================================================================================
57
58impl JoinNode {
59    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
60        JoinNodePrettyPrint { join_node: self, mast_forest }
61    }
62
63    pub(super) fn to_pretty_print<'a>(
64        &'a self,
65        mast_forest: &'a MastForest,
66    ) -> impl PrettyPrint + 'a {
67        JoinNodePrettyPrint { join_node: self, mast_forest }
68    }
69}
70
71struct JoinNodePrettyPrint<'a> {
72    join_node: &'a JoinNode,
73    mast_forest: &'a MastForest,
74}
75
76impl PrettyPrint for JoinNodePrettyPrint<'_> {
77    #[rustfmt::skip]
78    fn render(&self) -> crate::prettier::Document {
79        use crate::prettier::*;
80
81        let pre_decorators = {
82            let mut pre_decorators = self
83                .join_node
84                .before_enter(self.mast_forest)
85                .iter()
86                .map(|&decorator_id| self.mast_forest[decorator_id].render())
87                .reduce(|acc, doc| acc + const_text(" ") + doc)
88                .unwrap_or_default();
89            if !pre_decorators.is_empty() {
90                pre_decorators += nl();
91            }
92
93            pre_decorators
94        };
95
96        let post_decorators = {
97            let mut post_decorators = self
98                .join_node
99                .after_exit(self.mast_forest)
100                .iter()
101                .map(|&decorator_id| self.mast_forest[decorator_id].render())
102                .reduce(|acc, doc| acc + const_text(" ") + doc)
103                .unwrap_or_default();
104            if !post_decorators.is_empty() {
105                post_decorators = nl() + post_decorators;
106            }
107
108            post_decorators
109        };
110
111        let first_child =
112            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
113        let second_child =
114            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
115
116        pre_decorators
117        + indent(
118            4,
119            const_text("join")
120            + nl()
121            + first_child.render()
122            + nl()
123            + second_child.render(),
124        ) + nl() + const_text("end")
125        + post_decorators
126    }
127}
128
129impl fmt::Display for JoinNodePrettyPrint<'_> {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        use crate::prettier::PrettyPrint;
132        self.pretty_print(f)
133    }
134}
135
// SEMANTIC EQUALITY (FOR TESTING)
// ================================================================================================

#[cfg(test)]
impl JoinNode {
    /// Returns `true` when `self` and `other` represent the same join operation.
    ///
    /// The nodes match when their children (in order), digests, before-enter decorators, and
    /// after-exit decorators are all equal. Decorators are read through `forest`, so this works
    /// for both owned and linked decorator storage, unlike the derived `PartialEq`.
    pub fn semantic_eq(&self, other: &JoinNode, forest: &MastForest) -> bool {
        let same_children = self.first() == other.first() && self.second() == other.second();

        same_children
            && self.digest() == other.digest()
            && self.before_enter(forest) == other.before_enter(forest)
            && self.after_exit(forest) == other.after_exit(forest)
    }
}
171
172// MAST NODE TRAIT IMPLEMENTATION
173// ================================================================================================
174
175impl MastNodeExt for JoinNode {
176    /// Returns a commitment to this Join node.
177    ///
178    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
179    /// defined by [Self::DOMAIN] - i.e.,:
180    /// ```
181    /// # use miden_core::mast::JoinNode;
182    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
183    /// # let first_child_digest = Word::default();
184    /// # let second_child_digest = Word::default();
185    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
186    /// ```
187    fn digest(&self) -> Word {
188        self.digest
189    }
190
191    /// Returns the decorators to be executed before this node is executed.
192    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
193        #[cfg(debug_assertions)]
194        self.verify_node_in_forest(forest);
195        self.decorator_store.before_enter(forest)
196    }
197
198    /// Returns the decorators to be executed after this node is executed.
199    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
200        #[cfg(debug_assertions)]
201        self.verify_node_in_forest(forest);
202        self.decorator_store.after_exit(forest)
203    }
204
205    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
206        Box::new(JoinNode::to_display(self, mast_forest))
207    }
208
209    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
210        Box::new(JoinNode::to_pretty_print(self, mast_forest))
211    }
212
213    fn has_children(&self) -> bool {
214        true
215    }
216
217    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
218        target.push(self.first());
219        target.push(self.second());
220    }
221
222    fn for_each_child<F>(&self, mut f: F)
223    where
224        F: FnMut(MastNodeId),
225    {
226        f(self.first());
227        f(self.second());
228    }
229
230    fn domain(&self) -> Felt {
231        Self::DOMAIN
232    }
233
234    type Builder = JoinNodeBuilder;
235
236    fn to_builder(self, forest: &MastForest) -> Self::Builder {
237        // Extract decorators from decorator_store if in Owned state
238        match self.decorator_store {
239            DecoratorStore::Owned { before_enter, after_exit, .. } => {
240                let mut builder = JoinNodeBuilder::new(self.children);
241                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
242                builder
243            },
244            DecoratorStore::Linked { id } => {
245                // Extract decorators from forest storage when in Linked state
246                let before_enter = forest.before_enter_decorators(id).to_vec();
247                let after_exit = forest.after_exit_decorators(id).to_vec();
248                let mut builder = JoinNodeBuilder::new(self.children);
249                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
250                builder
251            },
252        }
253    }
254
255    #[cfg(debug_assertions)]
256    fn verify_node_in_forest(&self, forest: &MastForest) {
257        if let Some(id) = self.decorator_store.linked_id() {
258            // Verify that this node is the one stored at the given ID in the forest
259            let self_ptr = self as *const Self;
260            let forest_node = &forest.nodes[id];
261            let forest_node_ptr = match forest_node {
262                MastNode::Join(join_node) => join_node as *const JoinNode as *const (),
263                _ => panic!("Node type mismatch at {id:?}"),
264            };
265            let self_as_void = self_ptr as *const ();
266            debug_assert_eq!(
267                self_as_void, forest_node_ptr,
268                "Node pointer mismatch: expected node at {id:?} to be self"
269            );
270        }
271    }
272}
273
// ARBITRARY IMPLEMENTATION
// ================================================================================================

#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates join nodes with random child IDs and a random digest unrelated to the children.
    ///
    /// The node is assembled directly rather than through [`JoinNodeBuilder`], so no
    /// `MastForest` is needed to validate the children or compute the digest.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let parts = (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>());
        parts
            .prop_map(|(first, second, raw_digest)| JoinNode {
                children: [first, second],
                digest: Word::from(raw_digest.map(Felt::new_unchecked)),
                decorator_store: DecoratorStore::default(),
            })
            // Values are purely random, so shrinking them has no meaningful pattern to follow.
            .no_shrink()
            .boxed()
    }
}
304
// ------------------------------------------------------------------------------------------------
/// Builder for creating [`JoinNode`] instances with decorators.
///
/// Unless a digest is forced with `with_digest`, the node's digest is computed from the
/// children's digests when the node is built or added to a forest.
#[derive(Debug)]
pub struct JoinNodeBuilder {
    /// IDs of the child nodes, in execution order: `[first, second]`.
    children: [MastNodeId; 2],
    /// Decorators to run before the node executes.
    before_enter: Vec<DecoratorId>,
    /// Decorators to run after the node executes.
    after_exit: Vec<DecoratorId>,
    /// Digest to use instead of computing one from the children, if set.
    digest: Option<Word>,
}
314
315impl JoinNodeBuilder {
316    /// Creates a new builder for a JoinNode with the specified children.
317    pub fn new(children: [MastNodeId; 2]) -> Self {
318        Self {
319            children,
320            before_enter: Vec::new(),
321            after_exit: Vec::new(),
322            digest: None,
323        }
324    }
325
326    /// Builds the JoinNode with the specified decorators.
327    pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
328        let forest_len = mast_forest.nodes.len();
329        if self.children[0].to_usize() >= forest_len {
330            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
331        } else if self.children[1].to_usize() >= forest_len {
332            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
333        }
334
335        // Use the forced digest if provided, otherwise compute the digest
336        let digest = if let Some(forced_digest) = self.digest {
337            forced_digest
338        } else {
339            let left_child_hash = mast_forest[self.children[0]].digest();
340            let right_child_hash = mast_forest[self.children[1]].digest();
341
342            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
343        };
344
345        Ok(JoinNode {
346            children: self.children,
347            digest,
348            decorator_store: DecoratorStore::new_owned_with_decorators(
349                self.before_enter,
350                self.after_exit,
351            ),
352        })
353    }
354}
355
356impl MastForestContributor for JoinNodeBuilder {
357    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
358        // Validate child node IDs
359        let forest_len = forest.nodes.len();
360        if self.children[0].to_usize() >= forest_len {
361            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
362        } else if self.children[1].to_usize() >= forest_len {
363            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
364        }
365
366        // Use the forced digest if provided, otherwise compute the digest
367        let digest = if let Some(forced_digest) = self.digest {
368            forced_digest
369        } else {
370            let left_child_hash = forest[self.children[0]].digest();
371            let right_child_hash = forest[self.children[1]].digest();
372
373            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
374        };
375
376        // Determine the node ID that will be assigned
377        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
378
379        // Store node-level decorators in the centralized NodeToDecoratorIds for efficient access
380        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
381
382        // Create the node in the forest with Linked variant from the start
383        // Move the data directly without intermediate cloning
384        let node_id = forest
385            .nodes
386            .push(
387                JoinNode {
388                    children: self.children,
389                    digest,
390                    decorator_store: DecoratorStore::Linked { id: future_node_id },
391                }
392                .into(),
393            )
394            .map_err(|_| MastForestError::TooManyNodes)?;
395
396        Ok(node_id)
397    }
398
399    fn fingerprint_for_node(
400        &self,
401        forest: &MastForest,
402        hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
403    ) -> Result<MastNodeFingerprint, MastForestError> {
404        // Use the fingerprint_from_parts helper function
405        crate::mast::node_fingerprint::fingerprint_from_parts(
406            forest,
407            hash_by_node_id,
408            &self.before_enter,
409            &self.after_exit,
410            &self.children,
411            // Use the forced digest if available, otherwise compute the digest
412            if let Some(forced_digest) = self.digest {
413                forced_digest
414            } else {
415                let left_child_hash = forest[self.children[0]].digest();
416                let right_child_hash = forest[self.children[1]].digest();
417
418                hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
419            },
420        )
421    }
422
423    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
424        JoinNodeBuilder {
425            children: [
426                *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
427                *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
428            ],
429            before_enter: self.before_enter,
430            after_exit: self.after_exit,
431            digest: self.digest,
432        }
433    }
434
435    fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
436        self.before_enter = decorators.into();
437        self
438    }
439
440    fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
441        self.after_exit = decorators.into();
442        self
443    }
444
445    fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
446        self.before_enter.extend(decorators);
447    }
448
449    fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
450        self.after_exit.extend(decorators);
451    }
452
453    fn with_digest(mut self, digest: Word) -> Self {
454        self.digest = Some(digest);
455        self
456    }
457}
458
459impl JoinNodeBuilder {
460    /// Add this node to a forest using relaxed validation.
461    ///
462    /// This method is used during deserialization where nodes may reference child nodes
463    /// that haven't been added to the forest yet. The child node IDs have already been
464    /// validated against the expected final node count during the `try_into_mast_node_builder`
465    /// step, so we can safely skip validation here.
466    ///
467    /// Note: This is not part of the `MastForestContributor` trait because it's only
468    /// intended for internal use during deserialization.
469    pub(in crate::mast) fn add_to_forest_relaxed(
470        self,
471        forest: &mut MastForest,
472    ) -> Result<MastNodeId, MastForestError> {
473        // Use the forced digest if provided, otherwise use a default digest
474        // The actual digest computation will be handled when the forest is complete
475        let Some(digest) = self.digest else {
476            return Err(MastForestError::DigestRequiredForDeserialization);
477        };
478
479        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
480
481        // Create the node in the forest with Linked variant from the start
482        // Move the data directly without intermediate cloning
483        let node_id = forest
484            .nodes
485            .push(
486                JoinNode {
487                    children: self.children,
488                    digest,
489                    decorator_store: DecoratorStore::Linked { id: future_node_id },
490                }
491                .into(),
492            )
493            .map_err(|_| MastForestError::TooManyNodes)?;
494
495        Ok(node_id)
496    }
497}
498
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with random children and, for both the before-enter and after-exit
    /// lists, between `0` and `params.max_decorators` decorators (inclusive). The decorator IDs
    /// are drawn from `decorator_id_strategy(params.max_decorator_id_u32)`.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Strategy for one decorator list; called once for each of the two lists.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(children, before_enter, after_exit)| {
                Self::new(children).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
524
/// Parameters for generating JoinNodeBuilder instances
/// through its proptest `Arbitrary` implementation.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct JoinNodeBuilderParams {
    /// Maximum number of decorators (inclusive) generated for each of the before-enter and
    /// after-exit lists.
    pub max_decorators: usize,
    /// Bound passed to `decorator_id_strategy` when generating decorator IDs.
    pub max_decorator_id_u32: u32,
}
532
#[cfg(any(test, feature = "arbitrary"))]
impl Default for JoinNodeBuilderParams {
    /// Returns parameters with at most 4 decorators per list and a `max_decorator_id_u32` of 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        Self { max_decorators, max_decorator_id_u32 }
    }
}