// miden_core/mast/node/join_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, hash::rpo::RpoDigest};
5
6use super::MastNodeExt;
7use crate::{
8    OPCODE_JOIN,
9    chiplets::hasher,
10    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
11    prettier::PrettyPrint,
12};
13
14// JOIN NODE
15// ================================================================================================
16
17/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the
18/// first child first and the second child second.
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub struct JoinNode {
21    children: [MastNodeId; 2],
22    digest: RpoDigest,
23    before_enter: Vec<DecoratorId>,
24    after_exit: Vec<DecoratorId>,
25}
26
/// Constants
impl JoinNode {
    /// The domain of the join block (used for control block hashing).
    ///
    /// Built from the `OPCODE_JOIN` opcode. [`JoinNode::new`] passes this value as the domain
    /// argument of `hasher::merge_in_domain` when computing the node's digest.
    pub const DOMAIN: Felt = Felt::new(OPCODE_JOIN as u64);
}
32
33/// Constructors
34impl JoinNode {
35    /// Returns a new [`JoinNode`] instantiated with the specified children nodes.
36    pub fn new(
37        children: [MastNodeId; 2],
38        mast_forest: &MastForest,
39    ) -> Result<Self, MastForestError> {
40        let forest_len = mast_forest.nodes.len();
41        if children[0].as_usize() >= forest_len {
42            return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
43        } else if children[1].as_usize() >= forest_len {
44            return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
45        }
46        let digest = {
47            let left_child_hash = mast_forest[children[0]].digest();
48            let right_child_hash = mast_forest[children[1]].digest();
49
50            hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
51        };
52
53        Ok(Self {
54            children,
55            digest,
56            before_enter: Vec::new(),
57            after_exit: Vec::new(),
58        })
59    }
60
61    /// Returns a new [`JoinNode`] from values that are assumed to be correct.
62    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
63    pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
64        Self {
65            children,
66            digest,
67            before_enter: Vec::new(),
68            after_exit: Vec::new(),
69        }
70    }
71}
72
73/// Public accessors
74impl JoinNode {
75    /// Returns a commitment to this Join node.
76    ///
77    /// The commitment is computed as a hash of the `first` and `second` child node in the domain
78    /// defined by [Self::DOMAIN] - i.e.,:
79    /// ```
80    /// # use miden_core::mast::JoinNode;
81    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
82    /// # let first_child_digest = Digest::default();
83    /// # let second_child_digest = Digest::default();
84    /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN);
85    /// ```
86    pub fn digest(&self) -> RpoDigest {
87        self.digest
88    }
89
90    /// Returns the ID of the node that is to be executed first.
91    pub fn first(&self) -> MastNodeId {
92        self.children[0]
93    }
94
95    /// Returns the ID of the node that is to be executed after the execution of the program
96    /// defined by the first node completes.
97    pub fn second(&self) -> MastNodeId {
98        self.children[1]
99    }
100
101    /// Returns the decorators to be executed before this node is executed.
102    pub fn before_enter(&self) -> &[DecoratorId] {
103        &self.before_enter
104    }
105
106    /// Returns the decorators to be executed after this node is executed.
107    pub fn after_exit(&self) -> &[DecoratorId] {
108        &self.after_exit
109    }
110}
111
112/// Mutators
113impl JoinNode {
114    pub fn remap_children(&self, remapping: &Remapping) -> Self {
115        let mut node = self.clone();
116        node.children[0] = node.children[0].remap(remapping);
117        node.children[1] = node.children[1].remap(remapping);
118        node
119    }
120
121    /// Sets the list of decorators to be executed before this node.
122    pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
123        self.before_enter.extend_from_slice(decorator_ids);
124    }
125
126    /// Sets the list of decorators to be executed after this node.
127    pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
128        self.after_exit.extend_from_slice(decorator_ids);
129    }
130}
131
132impl MastNodeExt for JoinNode {
133    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
134        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
135    }
136}
137
138// PRETTY PRINTING
139// ================================================================================================
140
141impl JoinNode {
142    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
143        JoinNodePrettyPrint { join_node: self, mast_forest }
144    }
145
146    pub(super) fn to_pretty_print<'a>(
147        &'a self,
148        mast_forest: &'a MastForest,
149    ) -> impl PrettyPrint + 'a {
150        JoinNodePrettyPrint { join_node: self, mast_forest }
151    }
152}
153
154struct JoinNodePrettyPrint<'a> {
155    join_node: &'a JoinNode,
156    mast_forest: &'a MastForest,
157}
158
159impl PrettyPrint for JoinNodePrettyPrint<'_> {
160    #[rustfmt::skip]
161    fn render(&self) -> crate::prettier::Document {
162        use crate::prettier::*;
163
164        let pre_decorators = {
165            let mut pre_decorators = self
166                .join_node
167                .before_enter()
168                .iter()
169                .map(|&decorator_id| self.mast_forest[decorator_id].render())
170                .reduce(|acc, doc| acc + const_text(" ") + doc)
171                .unwrap_or_default();
172            if !pre_decorators.is_empty() {
173                pre_decorators += nl();
174            }
175
176            pre_decorators
177        };
178
179        let post_decorators = {
180            let mut post_decorators = self
181                .join_node
182                .after_exit()
183                .iter()
184                .map(|&decorator_id| self.mast_forest[decorator_id].render())
185                .reduce(|acc, doc| acc + const_text(" ") + doc)
186                .unwrap_or_default();
187            if !post_decorators.is_empty() {
188                post_decorators = nl() + post_decorators;
189            }
190
191            post_decorators
192        };
193
194        let first_child =
195            self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest);
196        let second_child =
197            self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest);
198
199        pre_decorators
200        + indent(
201            4,
202            const_text("join")
203            + nl()
204            + first_child.render()
205            + nl()
206            + second_child.render(),
207        ) + nl() + const_text("end")
208        + post_decorators
209    }
210}
211
212impl fmt::Display for JoinNodePrettyPrint<'_> {
213    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214        use crate::prettier::PrettyPrint;
215        self.pretty_print(f)
216    }
217}