miden_core/mast/node/
call_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, hash::rpo::RpoDigest};
5use miden_formatting::{
6    hex::ToHex,
7    prettier::{Document, PrettyPrint, const_text, nl, text},
8};
9
10use crate::{
11    OPCODE_CALL, OPCODE_SYSCALL,
12    chiplets::hasher,
13    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
14};
15
16// CALL NODE
17// ================================================================================================
18
19/// A Call node describes a function call such that the callee is executed in a different execution
20/// context from the currently executing code.
21///
22/// A call node can be of two types:
23/// - A simple call: the callee is executed in the new user context.
24/// - A syscall: the callee is executed in the root context.
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct CallNode {
27    callee: MastNodeId,
28    is_syscall: bool,
29    digest: RpoDigest,
30    before_enter: Vec<DecoratorId>,
31    after_exit: Vec<DecoratorId>,
32}
33
34//-------------------------------------------------------------------------------------------------
35/// Constants
36impl CallNode {
37    /// The domain of the call block (used for control block hashing).
38    pub const CALL_DOMAIN: Felt = Felt::new(OPCODE_CALL as u64);
39    /// The domain of the syscall block (used for control block hashing).
40    pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64);
41}
42
43//-------------------------------------------------------------------------------------------------
44/// Constructors
45impl CallNode {
46    /// Returns a new [`CallNode`] instantiated with the specified callee.
47    pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
48        if callee.as_usize() >= mast_forest.nodes.len() {
49            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
50        }
51        let digest = {
52            let callee_digest = mast_forest[callee].digest();
53
54            hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN)
55        };
56
57        Ok(Self {
58            callee,
59            is_syscall: false,
60            digest,
61            before_enter: Vec::new(),
62            after_exit: Vec::new(),
63        })
64    }
65
66    /// Returns a new [`CallNode`] from values that are assumed to be correct.
67    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
68    pub fn new_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
69        Self {
70            callee,
71            is_syscall: false,
72            digest,
73            before_enter: Vec::new(),
74            after_exit: Vec::new(),
75        }
76    }
77
78    /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
79    /// call.
80    pub fn new_syscall(
81        callee: MastNodeId,
82        mast_forest: &MastForest,
83    ) -> Result<Self, MastForestError> {
84        if callee.as_usize() >= mast_forest.nodes.len() {
85            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
86        }
87        let digest = {
88            let callee_digest = mast_forest[callee].digest();
89
90            hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN)
91        };
92
93        Ok(Self {
94            callee,
95            is_syscall: true,
96            digest,
97            before_enter: Vec::new(),
98            after_exit: Vec::new(),
99        })
100    }
101
102    /// Returns a new syscall [`CallNode`] from values that are assumed to be correct.
103    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
104    pub fn new_syscall_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
105        Self {
106            callee,
107            is_syscall: true,
108            digest,
109            before_enter: Vec::new(),
110            after_exit: Vec::new(),
111        }
112    }
113}
114
115//-------------------------------------------------------------------------------------------------
116/// Public accessors
117impl CallNode {
118    /// Returns a commitment to this Call node.
119    ///
120    /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the
121    /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on
122    /// whether the node represents a simple call or a syscall - i.e.,:
123    /// ```
124    /// # use miden_core::mast::CallNode;
125    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
126    /// # let callee_digest = Digest::default();
127    /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::CALL_DOMAIN);
128    /// ```
129    /// or
130    /// ```
131    /// # use miden_core::mast::CallNode;
132    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
133    /// # let callee_digest = Digest::default();
134    /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::SYSCALL_DOMAIN);
135    /// ```
136    pub fn digest(&self) -> RpoDigest {
137        self.digest
138    }
139
140    /// Returns the ID of the node to be invoked by this call node.
141    pub fn callee(&self) -> MastNodeId {
142        self.callee
143    }
144
145    /// Returns true if this call node represents a syscall.
146    pub fn is_syscall(&self) -> bool {
147        self.is_syscall
148    }
149
150    /// Returns the domain of this call node.
151    pub fn domain(&self) -> Felt {
152        if self.is_syscall() {
153            Self::SYSCALL_DOMAIN
154        } else {
155            Self::CALL_DOMAIN
156        }
157    }
158
159    /// Returns the decorators to be executed before this node is executed.
160    pub fn before_enter(&self) -> &[DecoratorId] {
161        &self.before_enter
162    }
163
164    /// Returns the decorators to be executed after this node is executed.
165    pub fn after_exit(&self) -> &[DecoratorId] {
166        &self.after_exit
167    }
168}
169
170/// Mutators
171impl CallNode {
172    pub fn remap_children(&self, remapping: &Remapping) -> Self {
173        let mut node = self.clone();
174        node.callee = node.callee.remap(remapping);
175        node
176    }
177
178    /// Sets the list of decorators to be executed before this node.
179    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
180        self.before_enter = decorator_ids;
181    }
182
183    /// Sets the list of decorators to be executed after this node.
184    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
185        self.after_exit = decorator_ids;
186    }
187}
188
189// PRETTY PRINTING
190// ================================================================================================
191
192impl CallNode {
193    pub(super) fn to_pretty_print<'a>(
194        &'a self,
195        mast_forest: &'a MastForest,
196    ) -> impl PrettyPrint + 'a {
197        CallNodePrettyPrint { node: self, mast_forest }
198    }
199
200    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
201        CallNodePrettyPrint { node: self, mast_forest }
202    }
203}
204
205struct CallNodePrettyPrint<'a> {
206    node: &'a CallNode,
207    mast_forest: &'a MastForest,
208}
209
210impl CallNodePrettyPrint<'_> {
211    /// Concatenates the provided decorators in a single line. If the list of decorators is not
212    /// empty, prepends `prepend` and appends `append` to the decorator document.
213    fn concatenate_decorators(
214        &self,
215        decorator_ids: &[DecoratorId],
216        prepend: Document,
217        append: Document,
218    ) -> Document {
219        let decorators = decorator_ids
220            .iter()
221            .map(|&decorator_id| self.mast_forest[decorator_id].render())
222            .reduce(|acc, doc| acc + const_text(" ") + doc)
223            .unwrap_or_default();
224
225        if decorators.is_empty() {
226            decorators
227        } else {
228            prepend + decorators + append
229        }
230    }
231
232    fn single_line_pre_decorators(&self) -> Document {
233        self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" "))
234    }
235
236    fn single_line_post_decorators(&self) -> Document {
237        self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty)
238    }
239
240    fn multi_line_pre_decorators(&self) -> Document {
241        self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl())
242    }
243
244    fn multi_line_post_decorators(&self) -> Document {
245        self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty)
246    }
247}
248
249impl PrettyPrint for CallNodePrettyPrint<'_> {
250    fn render(&self) -> Document {
251        let call_or_syscall = {
252            let callee_digest = self.mast_forest[self.node.callee].digest();
253            if self.node.is_syscall {
254                const_text("syscall")
255                    + const_text(".")
256                    + text(callee_digest.as_bytes().to_hex_with_prefix())
257            } else {
258                const_text("call")
259                    + const_text(".")
260                    + text(callee_digest.as_bytes().to_hex_with_prefix())
261            }
262        };
263
264        let single_line = self.single_line_pre_decorators()
265            + call_or_syscall.clone()
266            + self.single_line_post_decorators();
267        let multi_line =
268            self.multi_line_pre_decorators() + call_or_syscall + self.multi_line_post_decorators();
269
270        single_line | multi_line
271    }
272}
273
274impl fmt::Display for CallNodePrettyPrint<'_> {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        use crate::prettier::PrettyPrint;
277        self.pretty_print(f)
278    }
279}