miden_core/mast/node/
join_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
8use super::{MastNodeErrorContext, MastNodeExt};
9use crate::{
10    OPCODE_JOIN,
11    chiplets::hasher,
12    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
13    prettier::PrettyPrint,
14};
15
// JOIN NODE
// ================================================================================================

19/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the
20/// first child first and the second child second.
21#[derive(Debug, Clone, PartialEq, Eq)]
22#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
23pub struct JoinNode {
24    children: [MastNodeId; 2],
25    digest: Word,
26    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
27    before_enter: Vec<DecoratorId>,
28    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
29    after_exit: Vec<DecoratorId>,
30}
31
32/// Constants
33impl JoinNode {
34    /// The domain of the join block (used for control block hashing).
35    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
36}
37
38/// Constructors
39impl JoinNode {
40    /// Returns a new [`JoinNode`] instantiated with the specified children nodes.
41    pub fn new(
42        children: [MastNodeId; 2],
43        mast_forest: &MastForest,
44    ) -> Result<Self, MastForestError> {
45        let forest_len = mast_forest.nodes.len();
46        if children[0].as_usize() >= forest_len {
47            return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
48        } else if children[1].as_usize() >= forest_len {
49            return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
50        }
51        let digest = {
52            let left_child_hash = mast_forest[children[0]].digest();
53            let right_child_hash = mast_forest[children[1]].digest();
54
55            hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
56        };
57
58        Ok(Self {
59            children,
60            digest,
61            before_enter: Vec::new(),
62            after_exit: Vec::new(),
63        })
64    }
65
66    /// Returns a new [`JoinNode`] from values that are assumed to be correct.
67    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
68    pub fn new_unsafe(children: [MastNodeId; 2], digest: Word) -> Self {
69        Self {
70            children,
71            digest,
72            before_enter: Vec::new(),
73            after_exit: Vec::new(),
74        }
75    }
76}
77
78/// Public accessors
79impl JoinNode {
80    /// Returns the ID of the node that is to be executed first.
81    pub fn first(&self) -> MastNodeId {
82        self.children[0]
83    }
84
85    /// Returns the ID of the node that is to be executed after the execution of the program
86    /// defined by the first node completes.
87    pub fn second(&self) -> MastNodeId {
88        self.children[1]
89    }
90}
91
92impl MastNodeErrorContext for JoinNode {
93    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
94        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
95    }
96}
97
// PRETTY PRINTING
// ================================================================================================

101impl JoinNode {
102    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
103        JoinNodePrettyPrint { join_node: self, mast_forest }
104    }
105
106    pub(super) fn to_pretty_print<'a>(
107        &'a self,
108        mast_forest: &'a MastForest,
109    ) -> impl PrettyPrint + 'a {
110        JoinNodePrettyPrint { join_node: self, mast_forest }
111    }
112}
113
114struct JoinNodePrettyPrint<'a> {
115    join_node: &'a JoinNode,
116    mast_forest: &'a MastForest,
117}
118
119impl PrettyPrint for JoinNodePrettyPrint<'_> {
120    #[rustfmt::skip]
121    fn render(&self) -> crate::prettier::Document {
122        use crate::prettier::*;
123
124        let pre_decorators = {
125            let mut pre_decorators = self
126                .join_node
127                .before_enter()
128                .iter()
129                .map(|&decorator_id| self.mast_forest[decorator_id].render())
130                .reduce(|acc, doc| acc + const_text(" ") + doc)
131                .unwrap_or_default();
132            if !pre_decorators.is_empty() {
133                pre_decorators += nl();
134            }
135
136            pre_decorators
137        };
138
139        let post_decorators = {
140            let mut post_decorators = self
141                .join_node
142                .after_exit()
143                .iter()
144                .map(|&decorator_id| self.mast_forest[decorator_id].render())
145                .reduce(|acc, doc| acc + const_text(" ") + doc)
146                .unwrap_or_default();
147            if !post_decorators.is_empty() {
148                post_decorators = nl() + post_decorators;
149            }
150
151            post_decorators
152        };
153
154        let first_child =
155            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
156        let second_child =
157            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
158
159        pre_decorators
160        + indent(
161            4,
162            const_text("join")
163            + nl()
164            + first_child.render()
165            + nl()
166            + second_child.render(),
167        ) + nl() + const_text("end")
168        + post_decorators
169    }
170}
171
172impl fmt::Display for JoinNodePrettyPrint<'_> {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        use crate::prettier::PrettyPrint;
175        self.pretty_print(f)
176    }
177}
178
// MAST NODE TRAIT IMPLEMENTATION
// ================================================================================================

182impl MastNodeExt for JoinNode {
183    /// Returns a commitment to this Join node.
184    ///
185    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
186    /// defined by [Self::DOMAIN] - i.e.,:
187    /// ```
188    /// # use miden_core::mast::JoinNode;
189    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
190    /// # let first_child_digest = Word::default();
191    /// # let second_child_digest = Word::default();
192    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
193    /// ```
194    fn digest(&self) -> Word {
195        self.digest
196    }
197
198    /// Returns the decorators to be executed before this node is executed.
199    fn before_enter(&self) -> &[DecoratorId] {
200        &self.before_enter
201    }
202
203    /// Returns the decorators to be executed after this node is executed.
204    fn after_exit(&self) -> &[DecoratorId] {
205        &self.after_exit
206    }
207    /// Sets the list of decorators to be executed before this node.
208    fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
209        self.before_enter.extend_from_slice(decorator_ids);
210    }
211
212    /// Sets the list of decorators to be executed after this node.
213    fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
214        self.after_exit.extend_from_slice(decorator_ids);
215    }
216
217    /// Removes all decorators from this node.
218    fn remove_decorators(&mut self) {
219        self.before_enter.truncate(0);
220        self.after_exit.truncate(0);
221    }
222
223    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
224        Box::new(JoinNode::to_display(self, mast_forest))
225    }
226
227    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
228        Box::new(JoinNode::to_pretty_print(self, mast_forest))
229    }
230
231    fn remap_children(&self, remapping: &Remapping) -> Self {
232        let mut node = self.clone();
233        node.children[0] = node.children[0].remap(remapping);
234        node.children[1] = node.children[1].remap(remapping);
235        node
236    }
237
238    fn has_children(&self) -> bool {
239        true
240    }
241
242    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
243        target.push(self.first());
244        target.push(self.second());
245    }
246
247    fn domain(&self) -> Felt {
248        Self::DOMAIN
249    }
250}