// miden_core/mast/node/join_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
8use super::{MastForestContributor, MastNodeErrorContext, MastNodeExt};
9use crate::{
10    Idx, OPCODE_JOIN,
11    chiplets::hasher,
12    mast::{
13        DecoratedOpLink, DecoratorId, DecoratorStore, MastForest, MastForestError, MastNode,
14        MastNodeId,
15    },
16    prettier::PrettyPrint,
17};
18
19// JOIN NODE
20// ================================================================================================
21
22/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the
23/// first child first and the second child second.
24#[derive(Debug, Clone, PartialEq, Eq)]
25#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
26#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
27pub struct JoinNode {
28    children: [MastNodeId; 2],
29    digest: Word,
30    decorator_store: DecoratorStore,
31}
32
33/// Constants
34impl JoinNode {
35    /// The domain of the join block (used for control block hashing).
36    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
37}
38
39/// Public accessors
40impl JoinNode {
41    /// Returns the ID of the node that is to be executed first.
42    pub fn first(&self) -> MastNodeId {
43        self.children[0]
44    }
45
46    /// Returns the ID of the node that is to be executed after the execution of the program
47    /// defined by the first node completes.
48    pub fn second(&self) -> MastNodeId {
49        self.children[1]
50    }
51}
52
53impl MastNodeErrorContext for JoinNode {
54    fn decorators<'a>(
55        &'a self,
56        forest: &'a MastForest,
57    ) -> impl Iterator<Item = DecoratedOpLink> + 'a {
58        // Use the decorator_store for efficient O(1) decorator access
59        let before_enter = self.decorator_store.before_enter(forest);
60        let after_exit = self.decorator_store.after_exit(forest);
61
62        // Convert decorators to DecoratedOpLink tuples
63        before_enter
64            .iter()
65            .map(|&deco_id| (0, deco_id))
66            .chain(after_exit.iter().map(|&deco_id| (1, deco_id)))
67    }
68}
69
70// PRETTY PRINTING
71// ================================================================================================
72
73impl JoinNode {
74    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
75        JoinNodePrettyPrint { join_node: self, mast_forest }
76    }
77
78    pub(super) fn to_pretty_print<'a>(
79        &'a self,
80        mast_forest: &'a MastForest,
81    ) -> impl PrettyPrint + 'a {
82        JoinNodePrettyPrint { join_node: self, mast_forest }
83    }
84}
85
86struct JoinNodePrettyPrint<'a> {
87    join_node: &'a JoinNode,
88    mast_forest: &'a MastForest,
89}
90
91impl PrettyPrint for JoinNodePrettyPrint<'_> {
92    #[rustfmt::skip]
93    fn render(&self) -> crate::prettier::Document {
94        use crate::prettier::*;
95
96        let pre_decorators = {
97            let mut pre_decorators = self
98                .join_node
99                .before_enter(self.mast_forest)
100                .iter()
101                .map(|&decorator_id| self.mast_forest[decorator_id].render())
102                .reduce(|acc, doc| acc + const_text(" ") + doc)
103                .unwrap_or_default();
104            if !pre_decorators.is_empty() {
105                pre_decorators += nl();
106            }
107
108            pre_decorators
109        };
110
111        let post_decorators = {
112            let mut post_decorators = self
113                .join_node
114                .after_exit(self.mast_forest)
115                .iter()
116                .map(|&decorator_id| self.mast_forest[decorator_id].render())
117                .reduce(|acc, doc| acc + const_text(" ") + doc)
118                .unwrap_or_default();
119            if !post_decorators.is_empty() {
120                post_decorators = nl() + post_decorators;
121            }
122
123            post_decorators
124        };
125
126        let first_child =
127            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
128        let second_child =
129            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
130
131        pre_decorators
132        + indent(
133            4,
134            const_text("join")
135            + nl()
136            + first_child.render()
137            + nl()
138            + second_child.render(),
139        ) + nl() + const_text("end")
140        + post_decorators
141    }
142}
143
144impl fmt::Display for JoinNodePrettyPrint<'_> {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        use crate::prettier::PrettyPrint;
147        self.pretty_print(f)
148    }
149}
150
// SEMANTIC EQUALITY (FOR TESTING)
// ================================================================================================

#[cfg(test)]
impl JoinNode {
    /// Checks if two JoinNodes are semantically equal (i.e., they represent the same join
    /// operation).
    ///
    /// Two nodes are equal when all of the following match:
    /// - the children, in the same order;
    /// - the digest;
    /// - the before-enter decorators;
    /// - the after-exit decorators.
    ///
    /// The checks are evaluated in that order and stop at the first mismatch.
    ///
    /// Unlike the derived PartialEq, this method works correctly with both owned and linked
    /// decorator storage by accessing the actual decorator data from the forest when needed.
    pub fn semantic_eq(&self, other: &JoinNode, forest: &MastForest) -> bool {
        self.first() == other.first()
            && self.second() == other.second()
            && self.digest() == other.digest()
            && self.before_enter(forest) == other.before_enter(forest)
            && self.after_exit(forest) == other.after_exit(forest)
    }
}
186
187// MAST NODE TRAIT IMPLEMENTATION
188// ================================================================================================
189
190impl MastNodeExt for JoinNode {
191    /// Returns a commitment to this Join node.
192    ///
193    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
194    /// defined by [Self::DOMAIN] - i.e.,:
195    /// ```
196    /// # use miden_core::mast::JoinNode;
197    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
198    /// # let first_child_digest = Word::default();
199    /// # let second_child_digest = Word::default();
200    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
201    /// ```
202    fn digest(&self) -> Word {
203        self.digest
204    }
205
206    /// Returns the decorators to be executed before this node is executed.
207    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
208        #[cfg(debug_assertions)]
209        self.verify_node_in_forest(forest);
210        self.decorator_store.before_enter(forest)
211    }
212
213    /// Returns the decorators to be executed after this node is executed.
214    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
215        #[cfg(debug_assertions)]
216        self.verify_node_in_forest(forest);
217        self.decorator_store.after_exit(forest)
218    }
219
220    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
221        Box::new(JoinNode::to_display(self, mast_forest))
222    }
223
224    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
225        Box::new(JoinNode::to_pretty_print(self, mast_forest))
226    }
227
228    fn has_children(&self) -> bool {
229        true
230    }
231
232    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
233        target.push(self.first());
234        target.push(self.second());
235    }
236
237    fn for_each_child<F>(&self, mut f: F)
238    where
239        F: FnMut(MastNodeId),
240    {
241        f(self.first());
242        f(self.second());
243    }
244
245    fn domain(&self) -> Felt {
246        Self::DOMAIN
247    }
248
249    type Builder = JoinNodeBuilder;
250
251    fn to_builder(self, forest: &MastForest) -> Self::Builder {
252        // Extract decorators from decorator_store if in Owned state
253        match self.decorator_store {
254            DecoratorStore::Owned { before_enter, after_exit, .. } => {
255                let mut builder = JoinNodeBuilder::new(self.children);
256                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
257                builder
258            },
259            DecoratorStore::Linked { id } => {
260                // Extract decorators from forest storage when in Linked state
261                let before_enter = forest.before_enter_decorators(id).to_vec();
262                let after_exit = forest.after_exit_decorators(id).to_vec();
263                let mut builder = JoinNodeBuilder::new(self.children);
264                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
265                builder
266            },
267        }
268    }
269
270    #[cfg(debug_assertions)]
271    fn verify_node_in_forest(&self, forest: &MastForest) {
272        if let Some(id) = self.decorator_store.linked_id() {
273            // Verify that this node is the one stored at the given ID in the forest
274            let self_ptr = self as *const Self;
275            let forest_node = &forest.nodes[id];
276            let forest_node_ptr = match forest_node {
277                MastNode::Join(join_node) => join_node as *const JoinNode as *const (),
278                _ => panic!("Node type mismatch at {:?}", id),
279            };
280            let self_as_void = self_ptr as *const ();
281            debug_assert_eq!(
282                self_as_void, forest_node_ptr,
283                "Node pointer mismatch: expected node at {:?} to be self",
284                id
285            );
286        }
287    }
288}
289
// ARBITRARY IMPLEMENTATION
// ================================================================================================

#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a `JoinNode` with the following contents:
    /// - two random child IDs;
    /// - a random digest built from four random `u64` limbs;
    /// - default (empty) decorator storage.
    ///
    /// The node is constructed directly, so no `MastForest` validation is performed. Shrinking
    /// is disabled because the values are purely random.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        use crate::Felt;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(head, tail, limbs)| JoinNode {
                children: [head, tail],
                digest: Word::from(limbs.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
320
321// ------------------------------------------------------------------------------------------------
322/// Builder for creating [`JoinNode`] instances with decorators.
323#[derive(Debug)]
324pub struct JoinNodeBuilder {
325    children: [MastNodeId; 2],
326    before_enter: Vec<DecoratorId>,
327    after_exit: Vec<DecoratorId>,
328    digest: Option<Word>,
329}
330
331impl JoinNodeBuilder {
332    /// Creates a new builder for a JoinNode with the specified children.
333    pub fn new(children: [MastNodeId; 2]) -> Self {
334        Self {
335            children,
336            before_enter: Vec::new(),
337            after_exit: Vec::new(),
338            digest: None,
339        }
340    }
341
342    /// Builds the JoinNode with the specified decorators.
343    pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
344        let forest_len = mast_forest.nodes.len();
345        if self.children[0].to_usize() >= forest_len {
346            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
347        } else if self.children[1].to_usize() >= forest_len {
348            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
349        }
350
351        // Use the forced digest if provided, otherwise compute the digest
352        let digest = if let Some(forced_digest) = self.digest {
353            forced_digest
354        } else {
355            let left_child_hash = mast_forest[self.children[0]].digest();
356            let right_child_hash = mast_forest[self.children[1]].digest();
357
358            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
359        };
360
361        Ok(JoinNode {
362            children: self.children,
363            digest,
364            decorator_store: DecoratorStore::new_owned_with_decorators(
365                self.before_enter,
366                self.after_exit,
367            ),
368        })
369    }
370}
371
372impl MastForestContributor for JoinNodeBuilder {
373    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
374        // Validate child node IDs
375        let forest_len = forest.nodes.len();
376        if self.children[0].to_usize() >= forest_len {
377            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
378        } else if self.children[1].to_usize() >= forest_len {
379            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
380        }
381
382        // Use the forced digest if provided, otherwise compute the digest
383        let digest = if let Some(forced_digest) = self.digest {
384            forced_digest
385        } else {
386            let left_child_hash = forest[self.children[0]].digest();
387            let right_child_hash = forest[self.children[1]].digest();
388
389            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
390        };
391
392        // Determine the node ID that will be assigned
393        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
394
395        // Store node-level decorators in the centralized NodeToDecoratorIds for efficient access
396        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
397
398        // Create the node in the forest with Linked variant from the start
399        // Move the data directly without intermediate cloning
400        let node_id = forest
401            .nodes
402            .push(
403                JoinNode {
404                    children: self.children,
405                    digest,
406                    decorator_store: DecoratorStore::Linked { id: future_node_id },
407                }
408                .into(),
409            )
410            .map_err(|_| MastForestError::TooManyNodes)?;
411
412        Ok(node_id)
413    }
414
415    fn fingerprint_for_node(
416        &self,
417        forest: &MastForest,
418        hash_by_node_id: &impl crate::LookupByIdx<MastNodeId, crate::mast::MastNodeFingerprint>,
419    ) -> Result<crate::mast::MastNodeFingerprint, MastForestError> {
420        // Use the fingerprint_from_parts helper function
421        crate::mast::node_fingerprint::fingerprint_from_parts(
422            forest,
423            hash_by_node_id,
424            &self.before_enter,
425            &self.after_exit,
426            &self.children,
427            // Use the forced digest if available, otherwise compute the digest
428            if let Some(forced_digest) = self.digest {
429                forced_digest
430            } else {
431                let left_child_hash = forest[self.children[0]].digest();
432                let right_child_hash = forest[self.children[1]].digest();
433
434                crate::chiplets::hasher::merge_in_domain(
435                    &[left_child_hash, right_child_hash],
436                    JoinNode::DOMAIN,
437                )
438            },
439        )
440    }
441
442    fn remap_children(
443        self,
444        remapping: &impl crate::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeId>,
445    ) -> Self {
446        JoinNodeBuilder {
447            children: [
448                *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
449                *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
450            ],
451            before_enter: self.before_enter,
452            after_exit: self.after_exit,
453            digest: self.digest,
454        }
455    }
456
457    fn with_before_enter(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
458        self.before_enter = decorators.into();
459        self
460    }
461
462    fn with_after_exit(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
463        self.after_exit = decorators.into();
464        self
465    }
466
467    fn append_before_enter(
468        &mut self,
469        decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
470    ) {
471        self.before_enter.extend(decorators);
472    }
473
474    fn append_after_exit(
475        &mut self,
476        decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
477    ) {
478        self.after_exit.extend(decorators);
479    }
480
481    fn with_digest(mut self, digest: crate::Word) -> Self {
482        self.digest = Some(digest);
483        self
484    }
485}
486
487impl JoinNodeBuilder {
488    /// Add this node to a forest using relaxed validation.
489    ///
490    /// This method is used during deserialization where nodes may reference child nodes
491    /// that haven't been added to the forest yet. The child node IDs have already been
492    /// validated against the expected final node count during the `try_into_mast_node_builder`
493    /// step, so we can safely skip validation here.
494    ///
495    /// Note: This is not part of the `MastForestContributor` trait because it's only
496    /// intended for internal use during deserialization.
497    pub(in crate::mast) fn add_to_forest_relaxed(
498        self,
499        forest: &mut MastForest,
500    ) -> Result<MastNodeId, MastForestError> {
501        // Use the forced digest if provided, otherwise use a default digest
502        // The actual digest computation will be handled when the forest is complete
503        let Some(digest) = self.digest else {
504            return Err(MastForestError::DigestRequiredForDeserialization);
505        };
506
507        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
508
509        // Store node-level decorators in the centralized NodeToDecoratorIds for efficient access
510        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
511
512        // Create the node in the forest with Linked variant from the start
513        // Move the data directly without intermediate cloning
514        let node_id = forest
515            .nodes
516            .push(
517                JoinNode {
518                    children: self.children,
519                    digest,
520                    decorator_store: DecoratorStore::Linked { id: future_node_id },
521                }
522                .into(),
523            )
524            .map_err(|_| MastForestError::TooManyNodes)?;
525
526        Ok(node_id)
527    }
528}
529
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with random children and two independent decorator lists.
    ///
    /// Each list (before-enter and after-exit) holds between `0` and `params.max_decorators`
    /// decorators, inclusive. The decorator IDs are drawn from
    /// `decorator_id_strategy(params.max_decorator_id_u32)`.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<[crate::mast::MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(kids, entering, exiting)| {
                Self::new(kids).with_before_enter(entering).with_after_exit(exiting)
            })
            .boxed()
    }
}
555
/// Parameters for generating JoinNodeBuilder instances
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct JoinNodeBuilderParams {
    /// Upper bound (inclusive) on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Bound handed to the decorator-ID strategy when generating decorator IDs.
    pub max_decorator_id_u32: u32,
}

#[cfg(any(test, feature = "arbitrary"))]
impl Default for JoinNodeBuilderParams {
    /// Allows up to 4 decorators per list, with a decorator-ID bound of 10.
    fn default() -> Self {
        Self { max_decorator_id_u32: 10, max_decorators: 4 }
    }
}