miden_core/mast/node/
join_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, hash::rpo::RpoDigest};
5
6use crate::{
7    OPCODE_JOIN,
8    chiplets::hasher,
9    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
10    prettier::PrettyPrint,
11};
12
13// JOIN NODE
14// ================================================================================================
15
16/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the
17/// first child first and the second child second.
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct JoinNode {
20    children: [MastNodeId; 2],
21    digest: RpoDigest,
22    before_enter: Vec<DecoratorId>,
23    after_exit: Vec<DecoratorId>,
24}
25
26/// Constants
27impl JoinNode {
28    /// The domain of the join block (used for control block hashing).
29    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
30}
31
32/// Constructors
33impl JoinNode {
34    /// Returns a new [`JoinNode`] instantiated with the specified children nodes.
35    pub fn new(
36        children: [MastNodeId; 2],
37        mast_forest: &MastForest,
38    ) -> Result<Self, MastForestError> {
39        let forest_len = mast_forest.nodes.len();
40        if children[0].as_usize() >= forest_len {
41            return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
42        } else if children[1].as_usize() >= forest_len {
43            return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
44        }
45        let digest = {
46            let left_child_hash = mast_forest[children[0]].digest();
47            let right_child_hash = mast_forest[children[1]].digest();
48
49            hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
50        };
51
52        Ok(Self {
53            children,
54            digest,
55            before_enter: Vec::new(),
56            after_exit: Vec::new(),
57        })
58    }
59
60    /// Returns a new [`JoinNode`] from values that are assumed to be correct.
61    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
62    pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
63        Self {
64            children,
65            digest,
66            before_enter: Vec::new(),
67            after_exit: Vec::new(),
68        }
69    }
70}
71
72/// Public accessors
73impl JoinNode {
74    /// Returns a commitment to this Join node.
75    ///
76    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
77    /// defined by [Self::DOMAIN] - i.e.,:
78    /// ```
79    /// # use miden_core::mast::JoinNode;
80    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
81    /// # let first_child_digest = Digest::default();
82    /// # let second_child_digest = Digest::default();
83    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
84    /// ```
85    pub fn digest(&self) -> RpoDigest {
86        self.digest
87    }
88
89    /// Returns the ID of the node that is to be executed first.
90    pub fn first(&self) -> MastNodeId {
91        self.children[0]
92    }
93
94    /// Returns the ID of the node that is to be executed after the execution of the program
95    /// defined by the first node completes.
96    pub fn second(&self) -> MastNodeId {
97        self.children[1]
98    }
99
100    /// Returns the decorators to be executed before this node is executed.
101    pub fn before_enter(&self) -> &[DecoratorId] {
102        &self.before_enter
103    }
104
105    /// Returns the decorators to be executed after this node is executed.
106    pub fn after_exit(&self) -> &[DecoratorId] {
107        &self.after_exit
108    }
109}
110
111/// Mutators
112impl JoinNode {
113    pub fn remap_children(&self, remapping: &Remapping) -> Self {
114        let mut node = self.clone();
115        node.children[0] = node.children[0].remap(remapping);
116        node.children[1] = node.children[1].remap(remapping);
117        node
118    }
119
120    /// Sets the list of decorators to be executed before this node.
121    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
122        self.before_enter = decorator_ids;
123    }
124
125    /// Sets the list of decorators to be executed after this node.
126    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
127        self.after_exit = decorator_ids;
128    }
129}
130
131// PRETTY PRINTING
132// ================================================================================================
133
134impl JoinNode {
135    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
136        JoinNodePrettyPrint { join_node: self, mast_forest }
137    }
138
139    pub(super) fn to_pretty_print<'a>(
140        &'a self,
141        mast_forest: &'a MastForest,
142    ) -> impl PrettyPrint + 'a {
143        JoinNodePrettyPrint { join_node: self, mast_forest }
144    }
145}
146
147struct JoinNodePrettyPrint<'a> {
148    join_node: &'a JoinNode,
149    mast_forest: &'a MastForest,
150}
151
152impl PrettyPrint for JoinNodePrettyPrint<'_> {
153    #[rustfmt::skip]
154    fn render(&self) -> crate::prettier::Document {
155        use crate::prettier::*;
156
157        let pre_decorators = {
158            let mut pre_decorators = self
159                .join_node
160                .before_enter()
161                .iter()
162                .map(|&decorator_id| self.mast_forest[decorator_id].render())
163                .reduce(|acc, doc| acc + const_text(" ") + doc)
164                .unwrap_or_default();
165            if !pre_decorators.is_empty() {
166                pre_decorators += nl();
167            }
168
169            pre_decorators
170        };
171
172        let post_decorators = {
173            let mut post_decorators = self
174                .join_node
175                .after_exit()
176                .iter()
177                .map(|&decorator_id| self.mast_forest[decorator_id].render())
178                .reduce(|acc, doc| acc + const_text(" ") + doc)
179                .unwrap_or_default();
180            if !post_decorators.is_empty() {
181                post_decorators = nl() + post_decorators;
182            }
183
184            post_decorators
185        };
186
187        let first_child =
188            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
189        let second_child =
190            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
191
192        pre_decorators
193        + indent(
194            4,
195            const_text("join")
196            + nl()
197            + first_child.render()
198            + nl()
199            + second_child.render(),
200        ) + nl() + const_text("end")
201        + post_decorators
202    }
203}
204
205impl fmt::Display for JoinNodePrettyPrint<'_> {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        use crate::prettier::PrettyPrint;
208        self.pretty_print(f)
209    }
210}