miden_core/mast/node/
join_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{hash::rpo::RpoDigest, Felt};
5
6use crate::{
7    chiplets::hasher,
8    mast::{DecoratorId, MastForest, MastForestError, MastNodeId},
9    prettier::PrettyPrint,
10    OPCODE_JOIN,
11};
12
13// JOIN NODE
14// ================================================================================================
15
16/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the
17/// first child first and the second child second.
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct JoinNode {
20    children: [MastNodeId; 2],
21    digest: RpoDigest,
22    before_enter: Vec<DecoratorId>,
23    after_exit: Vec<DecoratorId>,
24}
25
26/// Constants
27impl JoinNode {
28    /// The domain of the join block (used for control block hashing).
29    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
30}
31
32/// Constructors
33impl JoinNode {
34    /// Returns a new [`JoinNode`] instantiated with the specified children nodes.
35    pub fn new(
36        children: [MastNodeId; 2],
37        mast_forest: &MastForest,
38    ) -> Result<Self, MastForestError> {
39        let forest_len = mast_forest.nodes.len();
40        if children[0].as_usize() >= forest_len {
41            return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
42        } else if children[1].as_usize() >= forest_len {
43            return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
44        }
45        let digest = {
46            let left_child_hash = mast_forest[children[0]].digest();
47            let right_child_hash = mast_forest[children[1]].digest();
48
49            hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
50        };
51
52        Ok(Self {
53            children,
54            digest,
55            before_enter: Vec::new(),
56            after_exit: Vec::new(),
57        })
58    }
59
60    /// Returns a new [`JoinNode`] from values that are assumed to be correct.
61    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
62    pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
63        Self {
64            children,
65            digest,
66            before_enter: Vec::new(),
67            after_exit: Vec::new(),
68        }
69    }
70}
71
72/// Public accessors
73impl JoinNode {
74    /// Returns a commitment to this Join node.
75    ///
76    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
77    /// defined by [Self::DOMAIN] - i.e.,:
78    /// ```
79    /// # use miden_core::mast::JoinNode;
80    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
81    /// # let first_child_digest = Digest::default();
82    /// # let second_child_digest = Digest::default();
83    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
84    /// ```
85    pub fn digest(&self) -> RpoDigest {
86        self.digest
87    }
88
89    /// Returns the ID of the node that is to be executed first.
90    pub fn first(&self) -> MastNodeId {
91        self.children[0]
92    }
93
94    /// Returns the ID of the node that is to be executed after the execution of the program
95    /// defined by the first node completes.
96    pub fn second(&self) -> MastNodeId {
97        self.children[1]
98    }
99
100    /// Returns the decorators to be executed before this node is executed.
101    pub fn before_enter(&self) -> &[DecoratorId] {
102        &self.before_enter
103    }
104
105    /// Returns the decorators to be executed after this node is executed.
106    pub fn after_exit(&self) -> &[DecoratorId] {
107        &self.after_exit
108    }
109}
110
111/// Mutators
112impl JoinNode {
113    /// Sets the list of decorators to be executed before this node.
114    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
115        self.before_enter = decorator_ids;
116    }
117
118    /// Sets the list of decorators to be executed after this node.
119    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
120        self.after_exit = decorator_ids;
121    }
122}
123
124// PRETTY PRINTING
125// ================================================================================================
126
127impl JoinNode {
128    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
129        JoinNodePrettyPrint { join_node: self, mast_forest }
130    }
131
132    pub(super) fn to_pretty_print<'a>(
133        &'a self,
134        mast_forest: &'a MastForest,
135    ) -> impl PrettyPrint + 'a {
136        JoinNodePrettyPrint { join_node: self, mast_forest }
137    }
138}
139
140struct JoinNodePrettyPrint<'a> {
141    join_node: &'a JoinNode,
142    mast_forest: &'a MastForest,
143}
144
145impl PrettyPrint for JoinNodePrettyPrint<'_> {
146    #[rustfmt::skip]
147    fn render(&self) -> crate::prettier::Document {
148        use crate::prettier::*;
149
150        let pre_decorators = {
151            let mut pre_decorators = self
152                .join_node
153                .before_enter()
154                .iter()
155                .map(|&decorator_id| self.mast_forest[decorator_id].render())
156                .reduce(|acc, doc| acc + const_text(" ") + doc)
157                .unwrap_or_default();
158            if !pre_decorators.is_empty() {
159                pre_decorators += nl();
160            }
161
162            pre_decorators
163        };
164
165        let post_decorators = {
166            let mut post_decorators = self
167                .join_node
168                .after_exit()
169                .iter()
170                .map(|&decorator_id| self.mast_forest[decorator_id].render())
171                .reduce(|acc, doc| acc + const_text(" ") + doc)
172                .unwrap_or_default();
173            if !post_decorators.is_empty() {
174                post_decorators = nl() + post_decorators;
175            }
176
177            post_decorators
178        };
179
180        let first_child =
181            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
182        let second_child =
183            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
184
185        pre_decorators
186        + indent(
187            4,
188            const_text("join")
189            + nl()
190            + first_child.render()
191            + nl()
192            + second_child.render(),
193        ) + nl() + const_text("end")
194        + post_decorators
195    }
196}
197
198impl fmt::Display for JoinNodePrettyPrint<'_> {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        use crate::prettier::PrettyPrint;
201        self.pretty_print(f)
202    }
203}