miden_core/mast/node/
split_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::prettier::PrettyPrint;
6
7use super::MastNodeExt;
8use crate::{
9    OPCODE_SPLIT,
10    chiplets::hasher,
11    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
12};
13
14// SPLIT NODE
15// ================================================================================================
16
/// A Split node defines conditional execution. When the VM encounters a Split node it executes
/// either the `on_true` child or `on_false` child.
///
/// Which child is executed is determined based on the top of the stack. If the value is `1`, then
/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If
/// the value is neither `0` nor `1`, the execution fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SplitNode {
    /// Child node IDs: index `0` is the `on_true` branch, index `1` is the `on_false` branch.
    branches: [MastNodeId; 2],
    /// Commitment to this node: computed from the children's digests in [`SplitNode::new`], or
    /// supplied by a trusted caller via [`SplitNode::new_unsafe`].
    digest: Word,
    /// Decorators to be executed before this node is executed.
    before_enter: Vec<DecoratorId>,
    /// Decorators to be executed after this node is executed.
    after_exit: Vec<DecoratorId>,
}
30
/// Constants
impl SplitNode {
    /// The domain of the split node (used for control block hashing).
    ///
    /// Built from the `OPCODE_SPLIT` opcode value and passed as the domain argument to
    /// `hasher::merge_in_domain` when [`SplitNode::new`] computes the node digest.
    pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
}
36
37/// Constructors
38impl SplitNode {
39    pub fn new(
40        branches: [MastNodeId; 2],
41        mast_forest: &MastForest,
42    ) -> Result<Self, MastForestError> {
43        let forest_len = mast_forest.nodes.len();
44        if branches[0].as_usize() >= forest_len {
45            return Err(MastForestError::NodeIdOverflow(branches[0], forest_len));
46        } else if branches[1].as_usize() >= forest_len {
47            return Err(MastForestError::NodeIdOverflow(branches[1], forest_len));
48        }
49        let digest = {
50            let if_branch_hash = mast_forest[branches[0]].digest();
51            let else_branch_hash = mast_forest[branches[1]].digest();
52
53            hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN)
54        };
55
56        Ok(Self {
57            branches,
58            digest,
59            before_enter: Vec::new(),
60            after_exit: Vec::new(),
61        })
62    }
63
64    /// Returns a new [`SplitNode`] from values that are assumed to be correct.
65    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
66    pub fn new_unsafe(branches: [MastNodeId; 2], digest: Word) -> Self {
67        Self {
68            branches,
69            digest,
70            before_enter: Vec::new(),
71            after_exit: Vec::new(),
72        }
73    }
74}
75
76/// Public accessors
77impl SplitNode {
78    /// Returns a commitment to this Split node.
79    ///
80    /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the
81    /// domain defined by [Self::DOMAIN] - i..e,:
82    /// ```
83    /// # use miden_core::mast::SplitNode;
84    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
85    /// # let on_true_digest = Word::default();
86    /// # let on_false_digest = Word::default();
87    /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN);
88    /// ```
89    pub fn digest(&self) -> Word {
90        self.digest
91    }
92
93    /// Returns the ID of the node which is to be executed if the top of the stack is `1`.
94    pub fn on_true(&self) -> MastNodeId {
95        self.branches[0]
96    }
97
98    /// Returns the ID of the node which is to be executed if the top of the stack is `0`.
99    pub fn on_false(&self) -> MastNodeId {
100        self.branches[1]
101    }
102
103    /// Returns the decorators to be executed before this node is executed.
104    pub fn before_enter(&self) -> &[DecoratorId] {
105        &self.before_enter
106    }
107
108    /// Returns the decorators to be executed after this node is executed.
109    pub fn after_exit(&self) -> &[DecoratorId] {
110        &self.after_exit
111    }
112}
113
114//-------------------------------------------------------------------------------------------------
115/// Mutators
116impl SplitNode {
117    pub fn remap_children(&self, remapping: &Remapping) -> Self {
118        let mut node = self.clone();
119        node.branches[0] = node.branches[0].remap(remapping);
120        node.branches[1] = node.branches[1].remap(remapping);
121        node
122    }
123
124    /// Sets the list of decorators to be executed before this node.
125    pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
126        self.before_enter.extend_from_slice(decorator_ids);
127    }
128
129    /// Sets the list of decorators to be executed after this node.
130    pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
131        self.after_exit.extend_from_slice(decorator_ids);
132    }
133
134    /// Removes all decorators from this node.
135    pub fn remove_decorators(&mut self) {
136        self.before_enter.truncate(0);
137        self.after_exit.truncate(0);
138    }
139}
140
141impl MastNodeExt for SplitNode {
142    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
143        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
144    }
145}
146
147// PRETTY PRINTING
148// ================================================================================================
149
150impl SplitNode {
151    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
152        SplitNodePrettyPrint { split_node: self, mast_forest }
153    }
154
155    pub(super) fn to_pretty_print<'a>(
156        &'a self,
157        mast_forest: &'a MastForest,
158    ) -> impl PrettyPrint + 'a {
159        SplitNodePrettyPrint { split_node: self, mast_forest }
160    }
161}
162
/// Rendering adapter that pairs a [`SplitNode`] with the [`MastForest`] used to look up its
/// child nodes and decorators while printing.
struct SplitNodePrettyPrint<'a> {
    /// The node being rendered.
    split_node: &'a SplitNode,
    /// Forest used to resolve the node's child IDs and decorator IDs.
    mast_forest: &'a MastForest,
}
167
168impl PrettyPrint for SplitNodePrettyPrint<'_> {
169    #[rustfmt::skip]
170    fn render(&self) -> crate::prettier::Document {
171        use crate::prettier::*;
172
173        let pre_decorators = {
174            let mut pre_decorators = self
175                .split_node
176                .before_enter()
177                .iter()
178                .map(|&decorator_id| self.mast_forest[decorator_id].render())
179                .reduce(|acc, doc| acc + const_text(" ") + doc)
180                .unwrap_or_default();
181            if !pre_decorators.is_empty() {
182                pre_decorators += nl();
183            }
184
185            pre_decorators
186        };
187
188        let post_decorators = {
189            let mut post_decorators = self
190                .split_node
191                .after_exit()
192                .iter()
193                .map(|&decorator_id| self.mast_forest[decorator_id].render())
194                .reduce(|acc, doc| acc + const_text(" ") + doc)
195                .unwrap_or_default();
196            if !post_decorators.is_empty() {
197                post_decorators = nl() + post_decorators;
198            }
199
200            post_decorators
201        };
202
203        let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
204        let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
205
206        let mut doc = pre_decorators;
207        doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
208        doc += indent(4, const_text("else") + nl() + false_branch.render());
209        doc += nl() + const_text("end");
210        doc + post_decorators
211    }
212}
213
214impl fmt::Display for SplitNodePrettyPrint<'_> {
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        use crate::prettier::PrettyPrint;
217        self.pretty_print(f)
218    }
219}