Skip to main content

miden_core/mast/node/
join_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt};
8#[cfg(debug_assertions)]
9use crate::mast::MastNode;
10use crate::{
11    Felt, Word,
12    chiplets::hasher,
13    mast::{
14        DecoratorId, DecoratorStore, MastForest, MastForestError, MastNodeFingerprint, MastNodeId,
15    },
16    operations::opcodes,
17    prettier::PrettyPrint,
18    utils::{Idx, LookupByIdx},
19};
20
21// JOIN NODE
22// ================================================================================================
23
/// A Join node describes sequential execution. When the VM encounters a Join node, it executes
/// the first child first and the second child second.
///
/// Note that the derived `PartialEq` compares `decorator_store` structurally; tests that need to
/// compare nodes regardless of how their decorators are stored use `semantic_eq` instead.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct JoinNode {
    /// IDs of the child nodes in execution order: `children[0]` runs before `children[1]`.
    children: [MastNodeId; 2],
    /// Commitment to this node; see [`MastNodeExt::digest`] for how it is derived.
    digest: Word,
    /// Node-level decorators, either held inline (`Owned`) or stored in the containing forest
    /// and referenced by node ID (`Linked`).
    decorator_store: DecoratorStore,
}
34
/// Constants
impl JoinNode {
    /// The domain of the join block (used for control block hashing).
    ///
    /// Derived from the `JOIN` opcode; it is the domain passed to `merge_in_domain` when the
    /// node digest is computed from the digests of its two children.
    pub const DOMAIN: Felt = Felt::new(opcodes::JOIN as u64);
}
40
41/// Public accessors
42impl JoinNode {
43    /// Returns the ID of the node that is to be executed first.
44    pub fn first(&self) -> MastNodeId {
45        self.children[0]
46    }
47
48    /// Returns the ID of the node that is to be executed after the execution of the program
49    /// defined by the first node completes.
50    pub fn second(&self) -> MastNodeId {
51        self.children[1]
52    }
53}
54
55// PRETTY PRINTING
56// ================================================================================================
57
58impl JoinNode {
59    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
60        JoinNodePrettyPrint { join_node: self, mast_forest }
61    }
62
63    pub(super) fn to_pretty_print<'a>(
64        &'a self,
65        mast_forest: &'a MastForest,
66    ) -> impl PrettyPrint + 'a {
67        JoinNodePrettyPrint { join_node: self, mast_forest }
68    }
69}
70
71struct JoinNodePrettyPrint<'a> {
72    join_node: &'a JoinNode,
73    mast_forest: &'a MastForest,
74}
75
76impl PrettyPrint for JoinNodePrettyPrint<'_> {
77    #[rustfmt::skip]
78    fn render(&self) -> crate::prettier::Document {
79        use crate::prettier::*;
80
81        let pre_decorators = {
82            let mut pre_decorators = self
83                .join_node
84                .before_enter(self.mast_forest)
85                .iter()
86                .map(|&decorator_id| self.mast_forest[decorator_id].render())
87                .reduce(|acc, doc| acc + const_text(" ") + doc)
88                .unwrap_or_default();
89            if !pre_decorators.is_empty() {
90                pre_decorators += nl();
91            }
92
93            pre_decorators
94        };
95
96        let post_decorators = {
97            let mut post_decorators = self
98                .join_node
99                .after_exit(self.mast_forest)
100                .iter()
101                .map(|&decorator_id| self.mast_forest[decorator_id].render())
102                .reduce(|acc, doc| acc + const_text(" ") + doc)
103                .unwrap_or_default();
104            if !post_decorators.is_empty() {
105                post_decorators = nl() + post_decorators;
106            }
107
108            post_decorators
109        };
110
111        let first_child =
112            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
113        let second_child =
114            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
115
116        pre_decorators
117        + indent(
118            4,
119            const_text("join")
120            + nl()
121            + first_child.render()
122            + nl()
123            + second_child.render(),
124        ) + nl() + const_text("end")
125        + post_decorators
126    }
127}
128
129impl fmt::Display for JoinNodePrettyPrint<'_> {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        use crate::prettier::PrettyPrint;
132        self.pretty_print(f)
133    }
134}
135
136// SEMANTIC EQUALITY (FOR TESTING)
137// ================================================================================================
138
#[cfg(test)]
impl JoinNode {
    /// Checks whether two `JoinNode`s are semantically equal, i.e. represent the same join
    /// operation.
    ///
    /// Unlike the derived `PartialEq`, this resolves decorators through `forest`, so it gives the
    /// same answer whether each node keeps its decorators inline (owned) or in the forest
    /// (linked).
    ///
    /// Comparison proceeds in this order and stops at the first difference: children, digest,
    /// before-enter decorators, after-exit decorators.
    pub fn semantic_eq(&self, other: &JoinNode, forest: &MastForest) -> bool {
        let same_children = self.first() == other.first() && self.second() == other.second();

        same_children
            && self.digest() == other.digest()
            && self.before_enter(forest) == other.before_enter(forest)
            && self.after_exit(forest) == other.after_exit(forest)
    }
}
171
172// MAST NODE TRAIT IMPLEMENTATION
173// ================================================================================================
174
175impl MastNodeExt for JoinNode {
176    /// Returns a commitment to this Join node.
177    ///
178    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
179    /// defined by [Self::DOMAIN] - i.e.,:
180    /// ```
181    /// # use miden_core::mast::JoinNode;
182    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
183    /// # let first_child_digest = Word::default();
184    /// # let second_child_digest = Word::default();
185    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
186    /// ```
187    fn digest(&self) -> Word {
188        self.digest
189    }
190
191    /// Returns the decorators to be executed before this node is executed.
192    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
193        #[cfg(debug_assertions)]
194        self.verify_node_in_forest(forest);
195        self.decorator_store.before_enter(forest)
196    }
197
198    /// Returns the decorators to be executed after this node is executed.
199    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
200        #[cfg(debug_assertions)]
201        self.verify_node_in_forest(forest);
202        self.decorator_store.after_exit(forest)
203    }
204
205    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
206        Box::new(JoinNode::to_display(self, mast_forest))
207    }
208
209    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
210        Box::new(JoinNode::to_pretty_print(self, mast_forest))
211    }
212
213    fn has_children(&self) -> bool {
214        true
215    }
216
217    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
218        target.push(self.first());
219        target.push(self.second());
220    }
221
222    fn for_each_child<F>(&self, mut f: F)
223    where
224        F: FnMut(MastNodeId),
225    {
226        f(self.first());
227        f(self.second());
228    }
229
230    fn domain(&self) -> Felt {
231        Self::DOMAIN
232    }
233
234    type Builder = JoinNodeBuilder;
235
236    fn to_builder(self, forest: &MastForest) -> Self::Builder {
237        // Extract decorators from decorator_store if in Owned state
238        match self.decorator_store {
239            DecoratorStore::Owned { before_enter, after_exit, .. } => {
240                let mut builder = JoinNodeBuilder::new(self.children);
241                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
242                builder
243            },
244            DecoratorStore::Linked { id } => {
245                // Extract decorators from forest storage when in Linked state
246                let before_enter = forest.before_enter_decorators(id).to_vec();
247                let after_exit = forest.after_exit_decorators(id).to_vec();
248                let mut builder = JoinNodeBuilder::new(self.children);
249                builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
250                builder
251            },
252        }
253    }
254
255    #[cfg(debug_assertions)]
256    fn verify_node_in_forest(&self, forest: &MastForest) {
257        if let Some(id) = self.decorator_store.linked_id() {
258            // Verify that this node is the one stored at the given ID in the forest
259            let self_ptr = self as *const Self;
260            let forest_node = &forest.nodes[id];
261            let forest_node_ptr = match forest_node {
262                MastNode::Join(join_node) => join_node as *const JoinNode as *const (),
263                _ => panic!("Node type mismatch at {:?}", id),
264            };
265            let self_as_void = self_ptr as *const ();
266            debug_assert_eq!(
267                self_as_void, forest_node_ptr,
268                "Node pointer mismatch: expected node at {:?} to be self",
269                id
270            );
271        }
272    }
273}
274
275// ARBITRARY IMPLEMENTATION
276// ================================================================================================
277
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates a `JoinNode` with two random child IDs, a random digest and default decorator
    /// storage.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Built by hand rather than through `JoinNodeBuilder::build`, since random child IDs are
        // not nodes of any `MastForest` and would fail its validation.
        (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>())
            .prop_map(|(first, second, raw_digest)| JoinNode {
                children: [first, second],
                digest: Word::from(raw_digest.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            // The values are uniformly random, so shrinking would not yield simpler cases.
            .no_shrink()
            .boxed()
    }
}
305
// ------------------------------------------------------------------------------------------------
/// Builder for creating [`JoinNode`] instances with decorators.
///
/// The builder either produces a standalone node that owns its decorators
/// ([`JoinNodeBuilder::build`]) or inserts the node directly into a [`MastForest`]
/// ([`MastForestContributor::add_to_forest`]).
#[derive(Debug)]
pub struct JoinNodeBuilder {
    /// Child node IDs in execution order.
    children: [MastNodeId; 2],
    /// Decorators to be executed before the node.
    before_enter: Vec<DecoratorId>,
    /// Decorators to be executed after the node.
    after_exit: Vec<DecoratorId>,
    /// Forced digest. When `None`, the digest is computed from the children's digests.
    digest: Option<Word>,
}
315
316impl JoinNodeBuilder {
317    /// Creates a new builder for a JoinNode with the specified children.
318    pub fn new(children: [MastNodeId; 2]) -> Self {
319        Self {
320            children,
321            before_enter: Vec::new(),
322            after_exit: Vec::new(),
323            digest: None,
324        }
325    }
326
327    /// Builds the JoinNode with the specified decorators.
328    pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
329        let forest_len = mast_forest.nodes.len();
330        if self.children[0].to_usize() >= forest_len {
331            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
332        } else if self.children[1].to_usize() >= forest_len {
333            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
334        }
335
336        // Use the forced digest if provided, otherwise compute the digest
337        let digest = if let Some(forced_digest) = self.digest {
338            forced_digest
339        } else {
340            let left_child_hash = mast_forest[self.children[0]].digest();
341            let right_child_hash = mast_forest[self.children[1]].digest();
342
343            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
344        };
345
346        Ok(JoinNode {
347            children: self.children,
348            digest,
349            decorator_store: DecoratorStore::new_owned_with_decorators(
350                self.before_enter,
351                self.after_exit,
352            ),
353        })
354    }
355}
356
357impl MastForestContributor for JoinNodeBuilder {
358    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
359        // Validate child node IDs
360        let forest_len = forest.nodes.len();
361        if self.children[0].to_usize() >= forest_len {
362            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
363        } else if self.children[1].to_usize() >= forest_len {
364            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
365        }
366
367        // Use the forced digest if provided, otherwise compute the digest
368        let digest = if let Some(forced_digest) = self.digest {
369            forced_digest
370        } else {
371            let left_child_hash = forest[self.children[0]].digest();
372            let right_child_hash = forest[self.children[1]].digest();
373
374            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
375        };
376
377        // Determine the node ID that will be assigned
378        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
379
380        // Store node-level decorators in the centralized NodeToDecoratorIds for efficient access
381        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
382
383        // Create the node in the forest with Linked variant from the start
384        // Move the data directly without intermediate cloning
385        let node_id = forest
386            .nodes
387            .push(
388                JoinNode {
389                    children: self.children,
390                    digest,
391                    decorator_store: DecoratorStore::Linked { id: future_node_id },
392                }
393                .into(),
394            )
395            .map_err(|_| MastForestError::TooManyNodes)?;
396
397        Ok(node_id)
398    }
399
400    fn fingerprint_for_node(
401        &self,
402        forest: &MastForest,
403        hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
404    ) -> Result<MastNodeFingerprint, MastForestError> {
405        // Use the fingerprint_from_parts helper function
406        crate::mast::node_fingerprint::fingerprint_from_parts(
407            forest,
408            hash_by_node_id,
409            &self.before_enter,
410            &self.after_exit,
411            &self.children,
412            // Use the forced digest if available, otherwise compute the digest
413            if let Some(forced_digest) = self.digest {
414                forced_digest
415            } else {
416                let left_child_hash = forest[self.children[0]].digest();
417                let right_child_hash = forest[self.children[1]].digest();
418
419                crate::chiplets::hasher::merge_in_domain(
420                    &[left_child_hash, right_child_hash],
421                    JoinNode::DOMAIN,
422                )
423            },
424        )
425    }
426
427    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
428        JoinNodeBuilder {
429            children: [
430                *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
431                *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
432            ],
433            before_enter: self.before_enter,
434            after_exit: self.after_exit,
435            digest: self.digest,
436        }
437    }
438
439    fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
440        self.before_enter = decorators.into();
441        self
442    }
443
444    fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
445        self.after_exit = decorators.into();
446        self
447    }
448
449    fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
450        self.before_enter.extend(decorators);
451    }
452
453    fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
454        self.after_exit.extend(decorators);
455    }
456
457    fn with_digest(mut self, digest: crate::Word) -> Self {
458        self.digest = Some(digest);
459        self
460    }
461}
462
463impl JoinNodeBuilder {
464    /// Add this node to a forest using relaxed validation.
465    ///
466    /// This method is used during deserialization where nodes may reference child nodes
467    /// that haven't been added to the forest yet. The child node IDs have already been
468    /// validated against the expected final node count during the `try_into_mast_node_builder`
469    /// step, so we can safely skip validation here.
470    ///
471    /// Note: This is not part of the `MastForestContributor` trait because it's only
472    /// intended for internal use during deserialization.
473    pub(in crate::mast) fn add_to_forest_relaxed(
474        self,
475        forest: &mut MastForest,
476    ) -> Result<MastNodeId, MastForestError> {
477        // Use the forced digest if provided, otherwise use a default digest
478        // The actual digest computation will be handled when the forest is complete
479        let Some(digest) = self.digest else {
480            return Err(MastForestError::DigestRequiredForDeserialization);
481        };
482
483        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
484
485        // Create the node in the forest with Linked variant from the start
486        // Move the data directly without intermediate cloning
487        let node_id = forest
488            .nodes
489            .push(
490                JoinNode {
491                    children: self.children,
492                    digest,
493                    decorator_store: DecoratorStore::Linked { id: future_node_id },
494                }
495                .into(),
496            )
497            .map_err(|_| MastForestError::TooManyNodes)?;
498
499        Ok(node_id)
500    }
501}
502
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates a builder with random children and two independent decorator lists, each of
    /// length `0..=params.max_decorators`, whose IDs come from
    /// `decorator_id_strategy(params.max_decorator_id_u32)`.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Produces a fresh strategy for one decorator list.
        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<[MastNodeId; 2]>(), decorator_list(), decorator_list())
            .prop_map(|(children, before_enter, after_exit)| {
                JoinNodeBuilder::new(children)
                    .with_before_enter(before_enter)
                    .with_after_exit(after_exit)
            })
            .boxed()
    }
}
528
/// Parameters for generating JoinNodeBuilder instances
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct JoinNodeBuilderParams {
    /// Maximum length (inclusive) of each generated decorator list.
    pub max_decorators: usize,
    /// Forwarded to `decorator_id_strategy` to bound the generated decorator IDs (presumably
    /// the largest allowed ID value; confirm against that helper).
    pub max_decorator_id_u32: u32,
}
536
#[cfg(any(test, feature = "arbitrary"))]
impl Default for JoinNodeBuilderParams {
    /// Small bounds: at most 4 decorators per list, with `max_decorator_id_u32` set to 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        JoinNodeBuilderParams { max_decorators, max_decorator_id_u32 }
    }
}