miden_core/mast/node/split_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, hash::rpo::RpoDigest};
5use miden_formatting::prettier::PrettyPrint;
6
7use super::MastNodeExt;
8use crate::{
9    OPCODE_SPLIT,
10    chiplets::hasher,
11    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
12};
13
14// SPLIT NODE
15// ================================================================================================
16
17/// A Split node defines conditional execution. When the VM encounters a Split node it executes
18/// either the `on_true` child or `on_false` child.
19///
20/// Which child is executed is determined based on the top of the stack. If the value is `1`, then
21/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If
22/// the value is neither `0` nor `1`, the execution fails.
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct SplitNode {
25    branches: [MastNodeId; 2],
26    digest: RpoDigest,
27    before_enter: Vec<DecoratorId>,
28    after_exit: Vec<DecoratorId>,
29}
30
31/// Constants
32impl SplitNode {
33    /// The domain of the split node (used for control block hashing).
34    pub const DOMAIN: Felt = Felt::new(OPCODE_SPLIT as u64);
35}
36
37/// Constructors
38impl SplitNode {
39    pub fn new(
40        branches: [MastNodeId; 2],
41        mast_forest: &MastForest,
42    ) -> Result<Self, MastForestError> {
43        let forest_len = mast_forest.nodes.len();
44        if branches[0].as_usize() >= forest_len {
45            return Err(MastForestError::NodeIdOverflow(branches[0], forest_len));
46        } else if branches[1].as_usize() >= forest_len {
47            return Err(MastForestError::NodeIdOverflow(branches[1], forest_len));
48        }
49        let digest = {
50            let if_branch_hash = mast_forest[branches[0]].digest();
51            let else_branch_hash = mast_forest[branches[1]].digest();
52
53            hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN)
54        };
55
56        Ok(Self {
57            branches,
58            digest,
59            before_enter: Vec::new(),
60            after_exit: Vec::new(),
61        })
62    }
63
64    /// Returns a new [`SplitNode`] from values that are assumed to be correct.
65    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
66    pub fn new_unsafe(branches: [MastNodeId; 2], digest: RpoDigest) -> Self {
67        Self {
68            branches,
69            digest,
70            before_enter: Vec::new(),
71            after_exit: Vec::new(),
72        }
73    }
74}
75
76/// Public accessors
77impl SplitNode {
78    /// Returns a commitment to this Split node.
79    ///
80    /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the
81    /// domain defined by [Self::DOMAIN] - i..e,:
82    /// ```
83    /// # use miden_core::mast::SplitNode;
84    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
85    /// # let on_true_digest = Digest::default();
86    /// # let on_false_digest = Digest::default();
87    /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN);
88    /// ```
89    pub fn digest(&self) -> RpoDigest {
90        self.digest
91    }
92
93    /// Returns the ID of the node which is to be executed if the top of the stack is `1`.
94    pub fn on_true(&self) -> MastNodeId {
95        self.branches[0]
96    }
97
98    /// Returns the ID of the node which is to be executed if the top of the stack is `0`.
99    pub fn on_false(&self) -> MastNodeId {
100        self.branches[1]
101    }
102
103    /// Returns the decorators to be executed before this node is executed.
104    pub fn before_enter(&self) -> &[DecoratorId] {
105        &self.before_enter
106    }
107
108    /// Returns the decorators to be executed after this node is executed.
109    pub fn after_exit(&self) -> &[DecoratorId] {
110        &self.after_exit
111    }
112}
113
114/// Mutators
115impl SplitNode {
116    pub fn remap_children(&self, remapping: &Remapping) -> Self {
117        let mut node = self.clone();
118        node.branches[0] = node.branches[0].remap(remapping);
119        node.branches[1] = node.branches[1].remap(remapping);
120        node
121    }
122
123    /// Sets the list of decorators to be executed before this node.
124    pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
125        self.before_enter.extend_from_slice(decorator_ids);
126    }
127
128    /// Sets the list of decorators to be executed after this node.
129    pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
130        self.after_exit.extend_from_slice(decorator_ids);
131    }
132}
133
134impl MastNodeExt for SplitNode {
135    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
136        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
137    }
138}
139
140// PRETTY PRINTING
141// ================================================================================================
142
143impl SplitNode {
144    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
145        SplitNodePrettyPrint { split_node: self, mast_forest }
146    }
147
148    pub(super) fn to_pretty_print<'a>(
149        &'a self,
150        mast_forest: &'a MastForest,
151    ) -> impl PrettyPrint + 'a {
152        SplitNodePrettyPrint { split_node: self, mast_forest }
153    }
154}
155
156struct SplitNodePrettyPrint<'a> {
157    split_node: &'a SplitNode,
158    mast_forest: &'a MastForest,
159}
160
161impl PrettyPrint for SplitNodePrettyPrint<'_> {
162    #[rustfmt::skip]
163    fn render(&self) -> crate::prettier::Document {
164        use crate::prettier::*;
165
166        let pre_decorators = {
167            let mut pre_decorators = self
168                .split_node
169                .before_enter()
170                .iter()
171                .map(|&decorator_id| self.mast_forest[decorator_id].render())
172                .reduce(|acc, doc| acc + const_text(" ") + doc)
173                .unwrap_or_default();
174            if !pre_decorators.is_empty() {
175                pre_decorators += nl();
176            }
177
178            pre_decorators
179        };
180
181        let post_decorators = {
182            let mut post_decorators = self
183                .split_node
184                .after_exit()
185                .iter()
186                .map(|&decorator_id| self.mast_forest[decorator_id].render())
187                .reduce(|acc, doc| acc + const_text(" ") + doc)
188                .unwrap_or_default();
189            if !post_decorators.is_empty() {
190                post_decorators = nl() + post_decorators;
191            }
192
193            post_decorators
194        };
195
196        let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest);
197        let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest);
198
199        let mut doc = pre_decorators;
200        doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl();
201        doc += indent(4, const_text("else") + nl() + false_branch.render());
202        doc += nl() + const_text("end");
203        doc + post_decorators
204    }
205}
206
207impl fmt::Display for SplitNodePrettyPrint<'_> {
208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        use crate::prettier::PrettyPrint;
210        self.pretty_print(f)
211    }
212}