miden_core/mast/node/
loop_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{hash::rpo::RpoDigest, Felt};
5use miden_formatting::prettier::PrettyPrint;
6
7use crate::{
8    chiplets::hasher,
9    mast::{DecoratorId, MastForest, MastForestError, MastNodeId},
10    OPCODE_LOOP,
11};
12
13// LOOP NODE
14// ================================================================================================
15
16/// A Loop node defines condition-controlled iterative execution. When the VM encounters a Loop
17/// node, it will keep executing the body of the loop as long as the top of the stack is `1``.
18///
19/// The loop is exited when at the end of executing the loop body the top of the stack is `0``.
20/// If the top of the stack is neither `0` nor `1` when the condition is checked, the execution
21/// fails.
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct LoopNode {
24    body: MastNodeId,
25    digest: RpoDigest,
26    before_enter: Vec<DecoratorId>,
27    after_exit: Vec<DecoratorId>,
28}
29
/// Constants
impl LoopNode {
    /// The domain of the loop node (used for control block hashing).
    ///
    /// This is the `OPCODE_LOOP` value lifted into a field element. It is passed as the domain
    /// argument to `merge_in_domain` when a loop node's digest is computed. Presumably this keeps
    /// loop commitments distinct from other control blocks hashed over the same inputs; confirm
    /// against the other node kinds.
    pub const DOMAIN: Felt = Felt::new(OPCODE_LOOP as u64);
}
35
36/// Constructors
37impl LoopNode {
38    /// Returns a new [`LoopNode`] instantiated with the specified body node.
39    pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
40        if body.as_usize() >= mast_forest.nodes.len() {
41            return Err(MastForestError::NodeIdOverflow(body, mast_forest.nodes.len()));
42        }
43        let digest = {
44            let body_hash = mast_forest[body].digest();
45
46            hasher::merge_in_domain(&[body_hash, RpoDigest::default()], Self::DOMAIN)
47        };
48
49        Ok(Self {
50            body,
51            digest,
52            before_enter: Vec::new(),
53            after_exit: Vec::new(),
54        })
55    }
56
57    /// Returns a new [`LoopNode`] from values that are assumed to be correct.
58    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
59    pub fn new_unsafe(body: MastNodeId, digest: RpoDigest) -> Self {
60        Self {
61            body,
62            digest,
63            before_enter: Vec::new(),
64            after_exit: Vec::new(),
65        }
66    }
67}
68
69impl LoopNode {
70    /// Returns a commitment to this Loop node.
71    ///
72    /// The commitment is computed as a hash of the loop body and an empty word ([ZERO; 4]) in
73    /// the domain defined by [Self::DOMAIN] - i..e,:
74    /// ```
75    /// # use miden_core::mast::LoopNode;
76    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
77    /// # let body_digest = Digest::default();
78    /// Hasher::merge_in_domain(&[body_digest, Digest::default()], LoopNode::DOMAIN);
79    /// ```
80    pub fn digest(&self) -> RpoDigest {
81        self.digest
82    }
83
84    /// Returns the ID of the node presenting the body of the loop.
85    pub fn body(&self) -> MastNodeId {
86        self.body
87    }
88
89    /// Returns the decorators to be executed before this node is executed.
90    pub fn before_enter(&self) -> &[DecoratorId] {
91        &self.before_enter
92    }
93
94    /// Returns the decorators to be executed after this node is executed.
95    pub fn after_exit(&self) -> &[DecoratorId] {
96        &self.after_exit
97    }
98}
99
100/// Mutators
101impl LoopNode {
102    /// Sets the list of decorators to be executed before this node.
103    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
104        self.before_enter = decorator_ids;
105    }
106
107    /// Sets the list of decorators to be executed after this node.
108    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
109        self.after_exit = decorator_ids;
110    }
111}
112
113// PRETTY PRINTING
114// ================================================================================================
115
116impl LoopNode {
117    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
118        LoopNodePrettyPrint { loop_node: self, mast_forest }
119    }
120
121    pub(super) fn to_pretty_print<'a>(
122        &'a self,
123        mast_forest: &'a MastForest,
124    ) -> impl PrettyPrint + 'a {
125        LoopNodePrettyPrint { loop_node: self, mast_forest }
126    }
127}
128
129struct LoopNodePrettyPrint<'a> {
130    loop_node: &'a LoopNode,
131    mast_forest: &'a MastForest,
132}
133
134impl crate::prettier::PrettyPrint for LoopNodePrettyPrint<'_> {
135    fn render(&self) -> crate::prettier::Document {
136        use crate::prettier::*;
137
138        let pre_decorators = {
139            let mut pre_decorators = self
140                .loop_node
141                .before_enter()
142                .iter()
143                .map(|&decorator_id| self.mast_forest[decorator_id].render())
144                .reduce(|acc, doc| acc + const_text(" ") + doc)
145                .unwrap_or_default();
146            if !pre_decorators.is_empty() {
147                pre_decorators += nl();
148            }
149
150            pre_decorators
151        };
152
153        let post_decorators = {
154            let mut post_decorators = self
155                .loop_node
156                .after_exit()
157                .iter()
158                .map(|&decorator_id| self.mast_forest[decorator_id].render())
159                .reduce(|acc, doc| acc + const_text(" ") + doc)
160                .unwrap_or_default();
161            if !post_decorators.is_empty() {
162                post_decorators = nl() + post_decorators;
163            }
164
165            post_decorators
166        };
167
168        let loop_body = self.mast_forest[self.loop_node.body].to_pretty_print(self.mast_forest);
169
170        pre_decorators
171            + indent(4, const_text("while.true") + nl() + loop_body.render())
172            + nl()
173            + const_text("end")
174            + post_decorators
175    }
176}
177
178impl fmt::Display for LoopNodePrettyPrint<'_> {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        use crate::prettier::PrettyPrint;
181        self.pretty_print(f)
182    }
183}