miden_core/mast/node/
split_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, hash::rpo::RpoDigest};
5use miden_formatting::prettier::PrettyPrint;
6
7use crate::{
8    OPCODE_SPLIT,
9    chiplets::hasher,
10    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
11};
12
13// SPLIT NODE
14// ================================================================================================
15
16/// A Split node defines conditional execution. When the VM encounters a Split node it executes
17/// either the `on_true` child or `on_false` child.
18///
19/// Which child is executed is determined based on the top of the stack. If the value is `1`, then
20/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If
21/// the value is neither `0` nor `1`, the execution fails.
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct SplitNode {
24    branches: [MastNodeId; 2],
25    digest: RpoDigest,
26    before_enter: Vec<DecoratorId>,
27    after_exit: Vec<DecoratorId>,
28}
29
30/// Constants
31impl SplitNode {
32    /// The domain of the split node (used for control block hashing).
33    pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
34}
35
36/// Constructors
37impl SplitNode {
38    pub fn new(
39        branches: [MastNodeId; 2],
40        mast_forest: &MastForest,
41    ) -> Result<Self, MastForestError> {
42        let forest_len = mast_forest.nodes.len();
43        if branches[0].as_usize() >= forest_len {
44            return Err(MastForestError::NodeIdOverflow(branches[0], forest_len));
45        } else if branches[1].as_usize() >= forest_len {
46            return Err(MastForestError::NodeIdOverflow(branches[1], forest_len));
47        }
48        let digest = {
49            let if_branch_hash = mast_forest[branches[0]].digest();
50            let else_branch_hash = mast_forest[branches[1]].digest();
51
52            hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN)
53        };
54
55        Ok(Self {
56            branches,
57            digest,
58            before_enter: Vec::new(),
59            after_exit: Vec::new(),
60        })
61    }
62
63    /// Returns a new [`SplitNode`] from values that are assumed to be correct.
64    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
65    pub fn new_unsafe(branches: [MastNodeId; 2], digest: RpoDigest) -> Self {
66        Self {
67            branches,
68            digest,
69            before_enter: Vec::new(),
70            after_exit: Vec::new(),
71        }
72    }
73}
74
75/// Public accessors
76impl SplitNode {
77    /// Returns a commitment to this Split node.
78    ///
79    /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the
80    /// domain defined by [Self::DOMAIN] - i..e,:
81    /// ```
82    /// # use miden_core::mast::SplitNode;
83    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
84    /// # let on_true_digest = Digest::default();
85    /// # let on_false_digest = Digest::default();
86    /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN);
87    /// ```
88    pub fn digest(&self) -> RpoDigest {
89        self.digest
90    }
91
92    /// Returns the ID of the node which is to be executed if the top of the stack is `1`.
93    pub fn on_true(&self) -> MastNodeId {
94        self.branches[0]
95    }
96
97    /// Returns the ID of the node which is to be executed if the top of the stack is `0`.
98    pub fn on_false(&self) -> MastNodeId {
99        self.branches[1]
100    }
101
102    /// Returns the decorators to be executed before this node is executed.
103    pub fn before_enter(&self) -> &[DecoratorId] {
104        &self.before_enter
105    }
106
107    /// Returns the decorators to be executed after this node is executed.
108    pub fn after_exit(&self) -> &[DecoratorId] {
109        &self.after_exit
110    }
111}
112
113/// Mutators
114impl SplitNode {
115    pub fn remap_children(&self, remapping: &Remapping) -> Self {
116        let mut node = self.clone();
117        node.branches[0] = node.branches[0].remap(remapping);
118        node.branches[1] = node.branches[1].remap(remapping);
119        node
120    }
121
122    /// Sets the list of decorators to be executed before this node.
123    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
124        self.before_enter = decorator_ids;
125    }
126
127    /// Sets the list of decorators to be executed after this node.
128    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
129        self.after_exit = decorator_ids;
130    }
131}
132
133// PRETTY PRINTING
134// ================================================================================================
135
136impl SplitNode {
137    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
138        SplitNodePrettyPrint { split_node: self, mast_forest }
139    }
140
141    pub(super) fn to_pretty_print<'a>(
142        &'a self,
143        mast_forest: &'a MastForest,
144    ) -> impl PrettyPrint + 'a {
145        SplitNodePrettyPrint { split_node: self, mast_forest }
146    }
147}
148
149struct SplitNodePrettyPrint<'a> {
150    split_node: &'a SplitNode,
151    mast_forest: &'a MastForest,
152}
153
154impl PrettyPrint for SplitNodePrettyPrint<'_> {
155    #[rustfmt::skip]
156    fn render(&self) -> crate::prettier::Document {
157        use crate::prettier::*;
158
159        let pre_decorators = {
160            let mut pre_decorators = self
161                .split_node
162                .before_enter()
163                .iter()
164                .map(|&decorator_id| self.mast_forest[decorator_id].render())
165                .reduce(|acc, doc| acc + const_text(" ") + doc)
166                .unwrap_or_default();
167            if !pre_decorators.is_empty() {
168                pre_decorators += nl();
169            }
170
171            pre_decorators
172        };
173
174        let post_decorators = {
175            let mut post_decorators = self
176                .split_node
177                .after_exit()
178                .iter()
179                .map(|&decorator_id| self.mast_forest[decorator_id].render())
180                .reduce(|acc, doc| acc + const_text(" ") + doc)
181                .unwrap_or_default();
182            if !post_decorators.is_empty() {
183                post_decorators = nl() + post_decorators;
184            }
185
186            post_decorators
187        };
188
189        let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
190        let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
191
192        let mut doc = pre_decorators;
193        doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
194        doc += indent(4, const_text("else") + nl() + false_branch.render());
195        doc += nl() + const_text("end");
196        doc + post_decorators
197    }
198}
199
200impl fmt::Display for SplitNodePrettyPrint<'_> {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        use crate::prettier::PrettyPrint;
203        self.pretty_print(f)
204    }
205}