miden_core/mast/node/
split_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::prettier::PrettyPrint;
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9use super::{MastNodeErrorContext, MastNodeExt};
10use crate::{
11    Idx, OPCODE_SPLIT,
12    chiplets::hasher,
13    mast::{DecoratedOpLink, DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
14};
15
16// SPLIT NODE
17// ================================================================================================
18
19/// A Split node defines conditional execution. When the VM encounters a Split node it executes
20/// either the `on_true` child or `on_false` child.
21///
22/// Which child is executed is determined based on the top of the stack. If the value is `1`, then
23/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If
24/// the value is neither `0` nor `1`, the execution fails.
25#[derive(Debug, Clone, PartialEq, Eq)]
26#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
27pub struct SplitNode {
28    branches: [MastNodeId; 2],
29    digest: Word,
30    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
31    before_enter: Vec<DecoratorId>,
32    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
33    after_exit: Vec<DecoratorId>,
34}
35
36/// Constants
37impl SplitNode {
38    /// The domain of the split node (used for control block hashing).
39    pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
40}
41
42/// Constructors
43impl SplitNode {
44    pub fn new(
45        branches: [MastNodeId; 2],
46        mast_forest: &MastForest,
47    ) -> Result<Self, MastForestError> {
48        let forest_len = mast_forest.nodes.len();
49        if branches[0].to_usize() >= forest_len {
50            return Err(MastForestError::NodeIdOverflow(branches[0], forest_len));
51        } else if branches[1].to_usize() >= forest_len {
52            return Err(MastForestError::NodeIdOverflow(branches[1], forest_len));
53        }
54        let digest = {
55            let if_branch_hash = mast_forest[branches[0]].digest();
56            let else_branch_hash = mast_forest[branches[1]].digest();
57
58            hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN)
59        };
60
61        Ok(Self {
62            branches,
63            digest,
64            before_enter: Vec::new(),
65            after_exit: Vec::new(),
66        })
67    }
68
69    /// Returns a new [`SplitNode`] from values that are assumed to be correct.
70    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
71    pub fn new_unsafe(branches: [MastNodeId; 2], digest: Word) -> Self {
72        Self {
73            branches,
74            digest,
75            before_enter: Vec::new(),
76            after_exit: Vec::new(),
77        }
78    }
79}
80
81/// Public accessors
82impl SplitNode {
83    /// Returns the ID of the node which is to be executed if the top of the stack is `1`.
84    pub fn on_true(&self) -> MastNodeId {
85        self.branches[0]
86    }
87
88    /// Returns the ID of the node which is to be executed if the top of the stack is `0`.
89    pub fn on_false(&self) -> MastNodeId {
90        self.branches[1]
91    }
92}
93
94impl MastNodeErrorContext for SplitNode {
95    fn decorators(&self) -> impl Iterator<Item = DecoratedOpLink> {
96        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
97    }
98}
99
100// PRETTY PRINTING
101// ================================================================================================
102
103impl SplitNode {
104    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
105        SplitNodePrettyPrint { split_node: self, mast_forest }
106    }
107
108    pub(super) fn to_pretty_print<'a>(
109        &'a self,
110        mast_forest: &'a MastForest,
111    ) -> impl PrettyPrint + 'a {
112        SplitNodePrettyPrint { split_node: self, mast_forest }
113    }
114}
115
116struct SplitNodePrettyPrint<'a> {
117    split_node: &'a SplitNode,
118    mast_forest: &'a MastForest,
119}
120
121impl PrettyPrint for SplitNodePrettyPrint<'_> {
122    #[rustfmt::skip]
123    fn render(&self) -> crate::prettier::Document {
124        use crate::prettier::*;
125
126        let pre_decorators = {
127            let mut pre_decorators = self
128                .split_node
129                .before_enter()
130                .iter()
131                .map(|&decorator_id| self.mast_forest[decorator_id].render())
132                .reduce(|acc, doc| acc + const_text(" ") + doc)
133                .unwrap_or_default();
134            if !pre_decorators.is_empty() {
135                pre_decorators += nl();
136            }
137
138            pre_decorators
139        };
140
141        let post_decorators = {
142            let mut post_decorators = self
143                .split_node
144                .after_exit()
145                .iter()
146                .map(|&decorator_id| self.mast_forest[decorator_id].render())
147                .reduce(|acc, doc| acc + const_text(" ") + doc)
148                .unwrap_or_default();
149            if !post_decorators.is_empty() {
150                post_decorators = nl() + post_decorators;
151            }
152
153            post_decorators
154        };
155
156        let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
157        let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
158
159        let mut doc = pre_decorators;
160        doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
161        doc += indent(4, const_text("else") + nl() + false_branch.render());
162        doc += nl() + const_text("end");
163        doc + post_decorators
164    }
165}
166
167impl fmt::Display for SplitNodePrettyPrint<'_> {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        use crate::prettier::PrettyPrint;
170        self.pretty_print(f)
171    }
172}
173
174// MAST NODE TRAIT IMPLEMENTATION
175// ================================================================================================
176
177impl MastNodeExt for SplitNode {
178    /// Returns a commitment to this Split node.
179    ///
180    /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the
181    /// domain defined by [Self::DOMAIN] - i..e,:
182    /// ```
183    /// # use miden_core::mast::SplitNode;
184    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
185    /// # let on_true_digest = Word::default();
186    /// # let on_false_digest = Word::default();
187    /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN);
188    /// ```
189    fn digest(&self) -> Word {
190        self.digest
191    }
192
193    /// Returns the decorators to be executed before this node is executed.
194    fn before_enter(&self) -> &[DecoratorId] {
195        &self.before_enter
196    }
197
198    /// Returns the decorators to be executed after this node is executed.
199    fn after_exit(&self) -> &[DecoratorId] {
200        &self.after_exit
201    }
202    /// Sets the list of decorators to be executed before this node.
203    fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
204        self.before_enter.extend_from_slice(decorator_ids);
205    }
206
207    /// Sets the list of decorators to be executed after this node.
208    fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
209        self.after_exit.extend_from_slice(decorator_ids);
210    }
211
212    /// Removes all decorators from this node.
213    fn remove_decorators(&mut self) {
214        self.before_enter.truncate(0);
215        self.after_exit.truncate(0);
216    }
217
218    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
219        Box::new(SplitNode::to_display(self, mast_forest))
220    }
221
222    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
223        Box::new(SplitNode::to_pretty_print(self, mast_forest))
224    }
225
226    fn remap_children(&self, remapping: &Remapping) -> Self {
227        let mut node = self.clone();
228        node.branches[0] = node.branches[0].remap(remapping);
229        node.branches[1] = node.branches[1].remap(remapping);
230        node
231    }
232
233    fn has_children(&self) -> bool {
234        true
235    }
236
237    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
238        target.push(self.on_true());
239        target.push(self.on_false());
240    }
241
242    fn for_each_child<F>(&self, mut f: F)
243    where
244        F: FnMut(MastNodeId),
245    {
246        f(self.on_true());
247        f(self.on_false());
248    }
249
250    fn domain(&self) -> Felt {
251        Self::DOMAIN
252    }
253}