Skip to main content

miden_core/mast/node/
join_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8use crate::{
9    Felt, Word,
10    chiplets::hasher,
11    mast::{
12        DecoratorId, DecoratorStore, MastForest, MastForestError, MastNode, MastNodeFingerprint,
13        MastNodeId,
14    },
15    operations::OPCODE_JOIN,
16    prettier::PrettyPrint,
17    utils::{Idx, LookupByIdx},
18};
19
20// JOIN NODE
21// ================================================================================================
22
23/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the
24/// first child first and the second child second.
25#[derive(Debug, Clone, PartialEq, Eq)]
26#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
27#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
28pub struct JoinNode {
29    children: [MastNodeId; 2],
30    digest: Word,
31    decorator_store: DecoratorStore,
32}
33
34/// Constants
35impl JoinNode {
36    /// The domain of the join block (used for control block hashing).
37    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
38}
39
40/// Public accessors
41impl JoinNode {
42    /// Returns the ID of the node that is to be executed first.
43    pub fn first(&self) -> MastNodeId {
44        self.children[0]
45    }
46
47    /// Returns the ID of the node that is to be executed after the execution of the program
48    /// defined by the first node completes.
49    pub fn second(&self) -> MastNodeId {
50        self.children[1]
51    }
52}
53
54// PRETTY PRINTING
55// ================================================================================================
56
57impl JoinNode {
58    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
59        JoinNodePrettyPrint { join_node: self, mast_forest }
60    }
61
62    pub(super) fn to_pretty_print<'a>(
63        &'a self,
64        mast_forest: &'a MastForest,
65    ) -> impl PrettyPrint + 'a {
66        JoinNodePrettyPrint { join_node: self, mast_forest }
67    }
68}
69
70struct JoinNodePrettyPrint<'a> {
71    join_node: &'a JoinNode,
72    mast_forest: &'a MastForest,
73}
74
75impl PrettyPrint for JoinNodePrettyPrint<'_> {
76    #[rustfmt::skip]
77    fn render(&self) -> crate::prettier::Document {
78        use crate::prettier::*;
79
80        let pre_decorators = {
81            let mut pre_decorators = self
82                .join_node
83                .before_enter(self.mast_forest)
84                .iter()
85                .map(|&decorator_id| self.mast_forest[decorator_id].render())
86                .reduce(|acc, doc| acc + const_text(" ") + doc)
87                .unwrap_or_default();
88            if !pre_decorators.is_empty() {
89                pre_decorators += nl();
90            }
91
92            pre_decorators
93        };
94
95        let post_decorators = {
96            let mut post_decorators = self
97                .join_node
98                .after_exit(self.mast_forest)
99                .iter()
100                .map(|&decorator_id| self.mast_forest[decorator_id].render())
101                .reduce(|acc, doc| acc + const_text(" ") + doc)
102                .unwrap_or_default();
103            if !post_decorators.is_empty() {
104                post_decorators = nl() + post_decorators;
105            }
106
107            post_decorators
108        };
109
110        let first_child =
111            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
112        let second_child =
113            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
114
115        pre_decorators
116        + indent(
117            4,
118            const_text("join")
119            + nl()
120            + first_child.render()
121            + nl()
122            + second_child.render(),
123        ) + nl() + const_text("end")
124        + post_decorators
125    }
126}
127
128impl fmt::Display for JoinNodePrettyPrint<'_> {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        use crate::prettier::PrettyPrint;
131        self.pretty_print(f)
132    }
133}
134
// SEMANTIC EQUALITY (FOR TESTING)
// ================================================================================================

#[cfg(test)]
impl JoinNode {
    /// Returns `true` when `self` and `other` represent the same join operation in `forest`.
    ///
    /// Two join nodes are semantically equal when they have the same children, the same
    /// digest, and the same before-enter and after-exit decorators. Decorators are read
    /// through `forest`, so the comparison works whether a node owns its decorators or links
    /// to the forest's storage. The derived `PartialEq`, by contrast, compares the
    /// `decorator_store` field as stored.
    pub fn semantic_eq(&self, other: &JoinNode, forest: &MastForest) -> bool {
        self.children == other.children
            && self.digest() == other.digest()
            && self.before_enter(forest) == other.before_enter(forest)
            && self.after_exit(forest) == other.after_exit(forest)
    }
}
170
171// MAST NODE TRAIT IMPLEMENTATION
172// ================================================================================================
173
174impl MastNodeExt for JoinNode {
175    /// Returns a commitment to this Join node.
176    ///
177    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
178    /// defined by [Self::DOMAIN] - i.e.,:
179    /// ```
180    /// # use miden_core::mast::JoinNode;
181    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
182    /// # let first_child_digest = Word::default();
183    /// # let second_child_digest = Word::default();
184    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
185    /// ```
186    fn digest(&self) -> Word {
187        self.digest
188    }
189
190    /// Returns the decorators to be executed before this node is executed.
191    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
192        #[cfg(debug_assertions)]
193        self.verify_node_in_forest(forest);
194        self.decorator_store.before_enter(forest)
195    }
196
197    /// Returns the decorators to be executed after this node is executed.
198    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
199        #[cfg(debug_assertions)]
200        self.verify_node_in_forest(forest);
201        self.decorator_store.after_exit(forest)
202    }
203
204    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
205        Box::new(JoinNode::to_display(self, mast_forest))
206    }
207
208    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
209        Box::new(JoinNode::to_pretty_print(self, mast_forest))
210    }
211
212    fn has_children(&self) -> bool {
213        true
214    }
215
216    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
217        target.push(self.first());
218        target.push(self.second());
219    }
220
221    fn for_each_child<F>(&self, mut f: F)
222    where
223        F: FnMut(MastNodeId),
224    {
225        f(self.first());
226        f(self.second());
227    }
228
229    fn domain(&self) -> Felt {
230        Self::DOMAIN
231    }
232
233    type Builder = JoinNodeBuilder;
234
235    fn to_builder(self, forest: &MastForest) -> Self::Builder {
236        // Extract decorators from decorator_store if in Owned state
237        match self.decorator_store {
238            DecoratorStore::Owned { before_enter, after_exit, .. } => {
239                let mut builder = JoinNodeBuilder::new(self.children);
240                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
241                builder
242            },
243            DecoratorStore::Linked { id } => {
244                // Extract decorators from forest storage when in Linked state
245                let before_enter = forest.before_enter_decorators(id).to_vec();
246                let after_exit = forest.after_exit_decorators(id).to_vec();
247                let mut builder = JoinNodeBuilder::new(self.children);
248                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
249                builder
250            },
251        }
252    }
253
254    #[cfg(debug_assertions)]
255    fn verify_node_in_forest(&self, forest: &MastForest) {
256        if let Some(id) = self.decorator_store.linked_id() {
257            // Verify that this node is the one stored at the given ID in the forest
258            let self_ptr = self as *const Self;
259            let forest_node = &forest.nodes[id];
260            let forest_node_ptr = match forest_node {
261                MastNode::Join(join_node) => join_node as *const JoinNode as *const (),
262                _ => panic!("Node type mismatch at {:?}", id),
263            };
264            let self_as_void = self_ptr as *const ();
265            debug_assert_eq!(
266                self_as_void, forest_node_ptr,
267                "Node pointer mismatch: expected node at {:?} to be self",
268                id
269            );
270        }
271    }
272}
273
// ARBITRARY IMPLEMENTATION
// ================================================================================================

#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Produces join nodes with two random child IDs and an unrelated random digest.
    ///
    /// The node is assembled directly, without a `MastForest`, so the children are not
    /// validated and the digest is not derived from them. Shrinking is disabled because the
    /// values are unstructured random data.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(first, second, raw_digest)| JoinNode {
                children: [first, second],
                digest: Word::from(raw_digest.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
304
305// ------------------------------------------------------------------------------------------------
306/// Builder for creating [`JoinNode`] instances with decorators.
307#[derive(Debug)]
308pub struct JoinNodeBuilder {
309    children: [MastNodeId; 2],
310    before_enter: Vec<DecoratorId>,
311    after_exit: Vec<DecoratorId>,
312    digest: Option<Word>,
313}
314
315impl JoinNodeBuilder {
316    /// Creates a new builder for a JoinNode with the specified children.
317    pub fn new(children: [MastNodeId; 2]) -> Self {
318        Self {
319            children,
320            before_enter: Vec::new(),
321            after_exit: Vec::new(),
322            digest: None,
323        }
324    }
325
326    /// Builds the JoinNode with the specified decorators.
327    pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
328        let forest_len = mast_forest.nodes.len();
329        if self.children[0].to_usize() >= forest_len {
330            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
331        } else if self.children[1].to_usize() >= forest_len {
332            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
333        }
334
335        // Use the forced digest if provided, otherwise compute the digest
336        let digest = if let Some(forced_digest) = self.digest {
337            forced_digest
338        } else {
339            let left_child_hash = mast_forest[self.children[0]].digest();
340            let right_child_hash = mast_forest[self.children[1]].digest();
341
342            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
343        };
344
345        Ok(JoinNode {
346            children: self.children,
347            digest,
348            decorator_store: DecoratorStore::new_owned_with_decorators(
349                self.before_enter,
350                self.after_exit,
351            ),
352        })
353    }
354}
355
356impl MastForestContributor for JoinNodeBuilder {
357    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
358        // Validate child node IDs
359        let forest_len = forest.nodes.len();
360        if self.children[0].to_usize() >= forest_len {
361            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
362        } else if self.children[1].to_usize() >= forest_len {
363            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
364        }
365
366        // Use the forced digest if provided, otherwise compute the digest
367        let digest = if let Some(forced_digest) = self.digest {
368            forced_digest
369        } else {
370            let left_child_hash = forest[self.children[0]].digest();
371            let right_child_hash = forest[self.children[1]].digest();
372
373            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
374        };
375
376        // Determine the node ID that will be assigned
377        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
378
379        // Store node-level decorators in the centralized NodeToDecoratorIds for efficient access
380        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
381
382        // Create the node in the forest with Linked variant from the start
383        // Move the data directly without intermediate cloning
384        let node_id = forest
385            .nodes
386            .push(
387                JoinNode {
388                    children: self.children,
389                    digest,
390                    decorator_store: DecoratorStore::Linked { id: future_node_id },
391                }
392                .into(),
393            )
394            .map_err(|_| MastForestError::TooManyNodes)?;
395
396        Ok(node_id)
397    }
398
399    fn fingerprint_for_node(
400        &self,
401        forest: &MastForest,
402        hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
403    ) -> Result<MastNodeFingerprint, MastForestError> {
404        // Use the fingerprint_from_parts helper function
405        crate::mast::node_fingerprint::fingerprint_from_parts(
406            forest,
407            hash_by_node_id,
408            &self.before_enter,
409            &self.after_exit,
410            &self.children,
411            // Use the forced digest if available, otherwise compute the digest
412            if let Some(forced_digest) = self.digest {
413                forced_digest
414            } else {
415                let left_child_hash = forest[self.children[0]].digest();
416                let right_child_hash = forest[self.children[1]].digest();
417
418                crate::chiplets::hasher::merge_in_domain(
419                    &[left_child_hash, right_child_hash],
420                    JoinNode::DOMAIN,
421                )
422            },
423        )
424    }
425
426    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
427        JoinNodeBuilder {
428            children: [
429                *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
430                *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
431            ],
432            before_enter: self.before_enter,
433            after_exit: self.after_exit,
434            digest: self.digest,
435        }
436    }
437
438    fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
439        self.before_enter = decorators.into();
440        self
441    }
442
443    fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
444        self.after_exit = decorators.into();
445        self
446    }
447
448    fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
449        self.before_enter.extend(decorators);
450    }
451
452    fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
453        self.after_exit.extend(decorators);
454    }
455
456    fn with_digest(mut self, digest: crate::Word) -> Self {
457        self.digest = Some(digest);
458        self
459    }
460}
461
462impl JoinNodeBuilder {
463    /// Add this node to a forest using relaxed validation.
464    ///
465    /// This method is used during deserialization where nodes may reference child nodes
466    /// that haven't been added to the forest yet. The child node IDs have already been
467    /// validated against the expected final node count during the `try_into_mast_node_builder`
468    /// step, so we can safely skip validation here.
469    ///
470    /// Note: This is not part of the `MastForestContributor` trait because it's only
471    /// intended for internal use during deserialization.
472    pub(in crate::mast) fn add_to_forest_relaxed(
473        self,
474        forest: &mut MastForest,
475    ) -> Result<MastNodeId, MastForestError> {
476        // Use the forced digest if provided, otherwise use a default digest
477        // The actual digest computation will be handled when the forest is complete
478        let Some(digest) = self.digest else {
479            return Err(MastForestError::DigestRequiredForDeserialization);
480        };
481
482        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
483
484        // Create the node in the forest with Linked variant from the start
485        // Move the data directly without intermediate cloning
486        let node_id = forest
487            .nodes
488            .push(
489                JoinNode {
490                    children: self.children,
491                    digest,
492                    decorator_store: DecoratorStore::Linked { id: future_node_id },
493                }
494                .into(),
495            )
496            .map_err(|_| MastForestError::TooManyNodes)?;
497
498        Ok(node_id)
499    }
500}
501
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Produces builders with random children and between `0` and `params.max_decorators`
    /// (inclusive) decorators in each of the before-enter and after-exit lists. No digest is
    /// forced.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Strategy for a single decorator list; instantiated once per list.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(children, before_enter, after_exit)| {
                Self::new(children).with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
527
/// Parameters for generating JoinNodeBuilder instances
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct JoinNodeBuilderParams {
    /// Maximum length (inclusive) of each generated decorator list.
    pub max_decorators: usize,
    /// Bound passed to the decorator-ID strategy when generating decorator IDs.
    pub max_decorator_id_u32: u32,
}
535
#[cfg(any(test, feature = "arbitrary"))]
impl Default for JoinNodeBuilderParams {
    /// Up to 4 decorators per list, with a decorator-ID bound of 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        Self { max_decorators, max_decorator_id_u32 }
    }
}