// miden_core/mast/node/join_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use super::{MastForestContributor, MastNodeExt, fingerprint_with_child_fingerprints};
8use crate::{
9    Felt, Word,
10    chiplets::hasher,
11    mast::{MastForest, MastForestError, MastNodeId},
12    operations::opcodes,
13    prettier::PrettyPrint,
14    utils::{Idx, LookupByIdx},
15};
16
// JOIN NODE
// ================================================================================================
19
20/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the
21/// first child first and the second child second.
22#[derive(Debug, Clone, PartialEq, Eq)]
23#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
24#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
25pub struct JoinNode {
26    children: [MastNodeId; 2],
27    digest: Word,
28}
29
30/// Constants
31impl JoinNode {
32    /// The domain of the join block (used for control block hashing).
33    pub const DOMAIN: Felt = Felt::new_unchecked(opcodes::JOIN as u64);
34}
35
36/// Public accessors
37impl JoinNode {
38    /// Returns the ID of the node that is to be executed first.
39    pub fn first(&self) -> MastNodeId {
40        self.children[0]
41    }
42
43    /// Returns the ID of the node that is to be executed after the execution of the program
44    /// defined by the first node completes.
45    pub fn second(&self) -> MastNodeId {
46        self.children[1]
47    }
48}
49
// PRETTY PRINTING
// ================================================================================================
52
53impl JoinNode {
54    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
55        JoinNodePrettyPrint { join_node: self, mast_forest }
56    }
57
58    pub(super) fn to_pretty_print<'a>(
59        &'a self,
60        mast_forest: &'a MastForest,
61    ) -> impl PrettyPrint + 'a {
62        JoinNodePrettyPrint { join_node: self, mast_forest }
63    }
64}
65
66struct JoinNodePrettyPrint<'a> {
67    join_node: &'a JoinNode,
68    mast_forest: &'a MastForest,
69}
70
71impl PrettyPrint for JoinNodePrettyPrint<'_> {
72    #[rustfmt::skip]
73    fn render(&self) -> crate::prettier::Document {
74        use crate::prettier::*;
75
76        let first_child =
77            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
78        let second_child =
79            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
80
81        indent(
82            4,
83            const_text("join")
84            + nl()
85            + first_child.render()
86            + nl()
87            + second_child.render(),
88        ) + nl() + const_text("end")
89    }
90}
91
92impl fmt::Display for JoinNodePrettyPrint<'_> {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        use crate::prettier::PrettyPrint;
95        self.pretty_print(f)
96    }
97}
98
// SEMANTIC EQUALITY (FOR TESTING)
// ================================================================================================
101
#[cfg(test)]
impl JoinNode {
    /// Checks if two JoinNodes are semantically equal (i.e., they represent the same join
    /// operation).
    ///
    /// Two nodes are considered equal when they reference the same first and second child IDs
    /// and carry the same digest. The forest argument is accepted but not consulted.
    pub fn semantic_eq(&self, other: &JoinNode, _forest: &MastForest) -> bool {
        let same_children = self.first() == other.first() && self.second() == other.second();
        let same_digest = self.digest() == other.digest();

        same_children && same_digest
    }
}
124
// MAST NODE TRAIT IMPLEMENTATION
// ================================================================================================
127
128impl MastNodeExt for JoinNode {
129    /// Returns a commitment to this Join node.
130    ///
131    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
132    /// defined by [Self::DOMAIN] - i.e.,:
133    /// ```
134    /// # use miden_core::mast::JoinNode;
135    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
136    /// # let first_child_digest = Word::default();
137    /// # let second_child_digest = Word::default();
138    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
139    /// ```
140    fn digest(&self) -> Word {
141        self.digest
142    }
143
144    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
145        Box::new(JoinNode::to_display(self, mast_forest))
146    }
147
148    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
149        Box::new(JoinNode::to_pretty_print(self, mast_forest))
150    }
151
152    fn has_children(&self) -> bool {
153        true
154    }
155
156    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
157        target.push(self.first());
158        target.push(self.second());
159    }
160
161    fn for_each_child<F>(&self, mut f: F)
162    where
163        F: FnMut(MastNodeId),
164    {
165        f(self.first());
166        f(self.second());
167    }
168
169    fn domain(&self) -> Felt {
170        Self::DOMAIN
171    }
172
173    type Builder = JoinNodeBuilder;
174
175    fn to_builder(self, _forest: &MastForest) -> Self::Builder {
176        JoinNodeBuilder::new(self.children).with_digest(self.digest)
177    }
178}
179
// ARBITRARY IMPLEMENTATION
// ================================================================================================
182
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for JoinNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Produces Join nodes from arbitrary child IDs and an arbitrary digest.
    ///
    /// The node is assembled field by field, so no forest validation is involved and the
    /// digest is not derived from the children.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        // Raw ingredients: two child IDs plus the four `u64` limbs of the digest.
        let ingredients = (any::<MastNodeId>(), any::<MastNodeId>(), any::<[u64; 4]>());

        ingredients
            .prop_map(|(first_child, second_child, digest_limbs)| JoinNode {
                children: [first_child, second_child],
                digest: Word::from(digest_limbs.map(Felt::new_unchecked)),
            })
            // Pure random values, no meaningful shrinking pattern
            .no_shrink()
            .boxed()
    }
}
209
// ------------------------------------------------------------------------------------------------
211/// Builder for creating [`JoinNode`] instances.
212#[derive(Debug)]
213pub struct JoinNodeBuilder {
214    children: [MastNodeId; 2],
215    digest: Option<Word>,
216}
217
218impl JoinNodeBuilder {
219    /// Creates a new builder for a JoinNode with the specified children.
220    pub fn new(children: [MastNodeId; 2]) -> Self {
221        Self { children, digest: None }
222    }
223
224    /// Builds the JoinNode.
225    pub fn build(self, mast_forest: &MastForest) -> Result<JoinNode, MastForestError> {
226        let forest_len = mast_forest.nodes.len();
227        if self.children[0].to_usize() >= forest_len {
228            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
229        } else if self.children[1].to_usize() >= forest_len {
230            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
231        }
232
233        // Use the forced digest if provided, otherwise compute the digest
234        let digest = if let Some(forced_digest) = self.digest {
235            forced_digest
236        } else {
237            let left_child_hash = mast_forest[self.children[0]].digest();
238            let right_child_hash = mast_forest[self.children[1]].digest();
239
240            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
241        };
242
243        Ok(JoinNode { children: self.children, digest })
244    }
245
246    pub(in crate::mast) fn build_linked(self) -> Result<JoinNode, MastForestError> {
247        Ok(JoinNode {
248            children: self.children,
249            digest: self.digest.ok_or(MastForestError::DigestRequiredForDeserialization)?,
250        })
251    }
252}
253
254impl MastForestContributor for JoinNodeBuilder {
255    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
256        // Validate child node IDs
257        let forest_len = forest.nodes.len();
258        if self.children[0].to_usize() >= forest_len {
259            return Err(MastForestError::NodeIdOverflow(self.children[0], forest_len));
260        } else if self.children[1].to_usize() >= forest_len {
261            return Err(MastForestError::NodeIdOverflow(self.children[1], forest_len));
262        }
263
264        // Use the forced digest if provided, otherwise compute the digest
265        let digest = if let Some(forced_digest) = self.digest {
266            forced_digest
267        } else {
268            let left_child_hash = forest[self.children[0]].digest();
269            let right_child_hash = forest[self.children[1]].digest();
270
271            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
272        };
273
274        // Create the node in the forest with Linked variant from the start
275        // Move the data directly without intermediate cloning
276        let node_id = forest
277            .nodes
278            .push(JoinNode { children: self.children, digest }.into())
279            .map_err(|_| MastForestError::TooManyNodes)?;
280
281        Ok(node_id)
282    }
283
284    fn fingerprint_for_node(
285        &self,
286        forest: &MastForest,
287        hash_by_node_id: &impl LookupByIdx<MastNodeId, Word>,
288    ) -> Result<Word, MastForestError> {
289        let node_digest = if let Some(forced_digest) = self.digest {
290            forced_digest
291        } else {
292            let left_child_hash = forest[self.children[0]].digest();
293            let right_child_hash = forest[self.children[1]].digest();
294
295            hasher::merge_in_domain(&[left_child_hash, right_child_hash], JoinNode::DOMAIN)
296        };
297
298        fingerprint_with_child_fingerprints(node_digest, &self.children, forest, hash_by_node_id)
299    }
300
301    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
302        JoinNodeBuilder {
303            children: [
304                *remapping.get(self.children[0]).unwrap_or(&self.children[0]),
305                *remapping.get(self.children[1]).unwrap_or(&self.children[1]),
306            ],
307            digest: self.digest,
308        }
309    }
310
311    fn with_digest(mut self, digest: Word) -> Self {
312        self.digest = Some(digest);
313        self
314    }
315}
316
317impl JoinNodeBuilder {
318    /// Add this node to a forest using relaxed validation.
319    ///
320    /// This method is used during deserialization where nodes may reference child nodes
321    /// that haven't been added to the forest yet. The child node IDs have already been
322    /// validated against the expected final node count during the `try_into_mast_node_builder`
323    /// step, so we can safely skip validation here.
324    ///
325    /// Note: This is not part of the `MastForestContributor` trait because it's only
326    /// intended for internal use during deserialization.
327    pub(in crate::mast) fn add_to_forest_relaxed(
328        self,
329        forest: &mut MastForest,
330    ) -> Result<MastNodeId, MastForestError> {
331        // Use the forced digest if provided, otherwise use a default digest
332        // The actual digest computation will be handled when the forest is complete
333        let Some(digest) = self.digest else {
334            return Err(MastForestError::DigestRequiredForDeserialization);
335        };
336
337        // Create the node in the forest with Linked variant from the start
338        // Move the data directly without intermediate cloning
339        let node_id = forest
340            .nodes
341            .push(JoinNode { children: self.children, digest }.into())
342            .map_err(|_| MastForestError::TooManyNodes)?;
343
344        Ok(node_id)
345    }
346}
347
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for JoinNodeBuilder {
    type Parameters = JoinNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Produces builders with two arbitrary child IDs and no forced digest.
    ///
    /// The parameters are accepted but currently unused.
    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        any::<[MastNodeId; 2]>()
            .prop_map(|children| JoinNodeBuilder { children, digest: None })
            .boxed()
    }
}
360
/// Parameters for generating [`JoinNodeBuilder`] instances.
///
/// The struct has no fields.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug, Default)]
pub struct JoinNodeBuilderParams {}