miden_core/mast/node/
call_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::{
6    hex::ToHex,
7    prettier::{Document, PrettyPrint, const_text, nl, text},
8};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use super::{MastNodeErrorContext, MastNodeExt};
13use crate::{
14    OPCODE_CALL, OPCODE_SYSCALL,
15    chiplets::hasher,
16    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
17};
18
19// CALL NODE
20// ================================================================================================
21
/// A Call node describes a function call such that the callee is executed in a different execution
/// context from the currently executing code.
///
/// A call node can be of two types:
/// - A simple call: the callee is executed in the new user context.
/// - A syscall: the callee is executed in the root context.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CallNode {
    /// ID of the node invoked by this call. The checked constructors verify that it is in bounds
    /// for the [`MastForest`] passed to them; the `*_unsafe` constructors do not.
    callee: MastNodeId,
    /// `true` for a syscall, `false` for a simple call; selects the hashing domain.
    is_syscall: bool,
    /// Commitment to this node. It is computed by the checked constructors, or supplied by the
    /// caller for the `*_unsafe` constructors.
    digest: Word,
    /// Decorators to be executed before this node is executed.
    // Omitted from serialized output when empty; defaults to empty when absent on deserialization.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    before_enter: Vec<DecoratorId>,
    /// Decorators to be executed after this node is executed.
    // Omitted from serialized output when empty; defaults to empty when absent on deserialization.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    after_exit: Vec<DecoratorId>,
}
39
40//-------------------------------------------------------------------------------------------------
41/// Constants
42impl CallNode {
43    /// The domain of the call block (used for control block hashing).
44    pub const CALL_DOMAIN: Felt = Felt::new(OPCODE_CALL as u64);
45    /// The domain of the syscall block (used for control block hashing).
46    pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64);
47}
48
49//-------------------------------------------------------------------------------------------------
50/// Constructors
51impl CallNode {
52    /// Returns a new [`CallNode`] instantiated with the specified callee.
53    pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
54        if callee.as_usize() >= mast_forest.nodes.len() {
55            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
56        }
57        let digest = {
58            let callee_digest = mast_forest[callee].digest();
59
60            hasher::merge_in_domain(&[callee_digest, Word::default()], Self::CALL_DOMAIN)
61        };
62
63        Ok(Self {
64            callee,
65            is_syscall: false,
66            digest,
67            before_enter: Vec::new(),
68            after_exit: Vec::new(),
69        })
70    }
71
72    /// Returns a new [`CallNode`] from values that are assumed to be correct.
73    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
74    pub fn new_unsafe(callee: MastNodeId, digest: Word) -> Self {
75        Self {
76            callee,
77            is_syscall: false,
78            digest,
79            before_enter: Vec::new(),
80            after_exit: Vec::new(),
81        }
82    }
83
84    /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
85    /// call.
86    pub fn new_syscall(
87        callee: MastNodeId,
88        mast_forest: &MastForest,
89    ) -> Result<Self, MastForestError> {
90        if callee.as_usize() >= mast_forest.nodes.len() {
91            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
92        }
93        let digest = {
94            let callee_digest = mast_forest[callee].digest();
95
96            hasher::merge_in_domain(&[callee_digest, Word::default()], Self::SYSCALL_DOMAIN)
97        };
98
99        Ok(Self {
100            callee,
101            is_syscall: true,
102            digest,
103            before_enter: Vec::new(),
104            after_exit: Vec::new(),
105        })
106    }
107
108    /// Returns a new syscall [`CallNode`] from values that are assumed to be correct.
109    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
110    pub fn new_syscall_unsafe(callee: MastNodeId, digest: Word) -> Self {
111        Self {
112            callee,
113            is_syscall: true,
114            digest,
115            before_enter: Vec::new(),
116            after_exit: Vec::new(),
117        }
118    }
119}
120
121//-------------------------------------------------------------------------------------------------
122/// Public accessors
123impl CallNode {
124    /// Returns the ID of the node to be invoked by this call node.
125    pub fn callee(&self) -> MastNodeId {
126        self.callee
127    }
128
129    /// Returns true if this call node represents a syscall.
130    pub fn is_syscall(&self) -> bool {
131        self.is_syscall
132    }
133
134    /// Returns the domain of this call node.
135    pub fn domain(&self) -> Felt {
136        if self.is_syscall() {
137            Self::SYSCALL_DOMAIN
138        } else {
139            Self::CALL_DOMAIN
140        }
141    }
142}
143
144impl MastNodeErrorContext for CallNode {
145    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
146        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
147    }
148}
149
150// PRETTY PRINTING
151// ================================================================================================
152
153impl CallNode {
154    pub(super) fn to_pretty_print<'a>(
155        &'a self,
156        mast_forest: &'a MastForest,
157    ) -> impl PrettyPrint + 'a {
158        CallNodePrettyPrint { node: self, mast_forest }
159    }
160
161    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
162        CallNodePrettyPrint { node: self, mast_forest }
163    }
164}
165
/// Borrowed view of a [`CallNode`] together with the [`MastForest`] it belongs to.
///
/// The forest is needed to look up the callee's digest and to render decorators.
struct CallNodePrettyPrint<'a> {
    /// The call node being printed.
    node: &'a CallNode,
    /// Forest used to resolve the callee node and decorator IDs.
    mast_forest: &'a MastForest,
}
170
171impl CallNodePrettyPrint<'_> {
172    /// Concatenates the provided decorators in a single line. If the list of decorators is not
173    /// empty, prepends `prepend` and appends `append` to the decorator document.
174    fn concatenate_decorators(
175        &self,
176        decorator_ids: &[DecoratorId],
177        prepend: Document,
178        append: Document,
179    ) -> Document {
180        let decorators = decorator_ids
181            .iter()
182            .map(|&decorator_id| self.mast_forest[decorator_id].render())
183            .reduce(|acc, doc| acc + const_text(" ") + doc)
184            .unwrap_or_default();
185
186        if decorators.is_empty() {
187            decorators
188        } else {
189            prepend + decorators + append
190        }
191    }
192
193    fn single_line_pre_decorators(&self) -> Document {
194        self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" "))
195    }
196
197    fn single_line_post_decorators(&self) -> Document {
198        self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty)
199    }
200
201    fn multi_line_pre_decorators(&self) -> Document {
202        self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl())
203    }
204
205    fn multi_line_post_decorators(&self) -> Document {
206        self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty)
207    }
208}
209
210impl PrettyPrint for CallNodePrettyPrint<'_> {
211    fn render(&self) -> Document {
212        let call_or_syscall = {
213            let callee_digest = self.mast_forest[self.node.callee].digest();
214            if self.node.is_syscall {
215                const_text("syscall")
216                    + const_text(".")
217                    + text(callee_digest.as_bytes().to_hex_with_prefix())
218            } else {
219                const_text("call")
220                    + const_text(".")
221                    + text(callee_digest.as_bytes().to_hex_with_prefix())
222            }
223        };
224
225        let single_line = self.single_line_pre_decorators()
226            + call_or_syscall.clone()
227            + self.single_line_post_decorators();
228        let multi_line =
229            self.multi_line_pre_decorators() + call_or_syscall + self.multi_line_post_decorators();
230
231        single_line | multi_line
232    }
233}
234
235impl fmt::Display for CallNodePrettyPrint<'_> {
236    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237        use crate::prettier::PrettyPrint;
238        self.pretty_print(f)
239    }
240}
241
242// MAST NODE TRAIT IMPLEMENTATION
243// ================================================================================================
244
245impl MastNodeExt for CallNode {
246    /// Returns a commitment to this Call node.
247    ///
248    /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the
249    /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on
250    /// whether the node represents a simple call or a syscall - i.e.,:
251    /// ```
252    /// # use miden_core::mast::CallNode;
253    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
254    /// # let callee_digest = Word::default();
255    /// Hasher::merge_in_domain(&[callee_digest, Word::default()], CallNode::CALL_DOMAIN);
256    /// ```
257    /// or
258    /// ```
259    /// # use miden_core::mast::CallNode;
260    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
261    /// # let callee_digest = Word::default();
262    /// Hasher::merge_in_domain(&[callee_digest, Word::default()], CallNode::SYSCALL_DOMAIN);
263    /// ```
264    fn digest(&self) -> Word {
265        self.digest
266    }
267
268    /// Returns the decorators to be executed before this node is executed.
269    fn before_enter(&self) -> &[DecoratorId] {
270        &self.before_enter
271    }
272
273    /// Returns the decorators to be executed after this node is executed.
274    fn after_exit(&self) -> &[DecoratorId] {
275        &self.after_exit
276    }
277
278    /// Sets the list of decorators to be executed before this node.
279    fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
280        self.before_enter.extend_from_slice(decorator_ids);
281    }
282
283    /// Sets the list of decorators to be executed after this node.
284    fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
285        self.after_exit.extend_from_slice(decorator_ids);
286    }
287
288    /// Removes all decorators from this node.
289    fn remove_decorators(&mut self) {
290        self.before_enter.truncate(0);
291        self.after_exit.truncate(0);
292    }
293
294    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
295        Box::new(CallNode::to_display(self, mast_forest))
296    }
297
298    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
299        Box::new(CallNode::to_pretty_print(self, mast_forest))
300    }
301
302    fn remap_children(&self, remapping: &Remapping) -> Self {
303        let mut node = self.clone();
304        node.callee = node.callee.remap(remapping);
305        node
306    }
307
308    fn has_children(&self) -> bool {
309        true
310    }
311
312    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
313        target.push(self.callee());
314    }
315
316    fn domain(&self) -> Felt {
317        self.domain()
318    }
319}