miden_core/mast/node/
call_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{hash::rpo::RpoDigest, Felt};
5use miden_formatting::{
6    hex::ToHex,
7    prettier::{const_text, nl, text, Document, PrettyPrint},
8};
9
10use crate::{
11    chiplets::hasher,
12    mast::{DecoratorId, MastForest, MastForestError, MastNodeId},
13    OPCODE_CALL, OPCODE_SYSCALL,
14};
15
16// CALL NODE
17// ================================================================================================
18
19/// A Call node describes a function call such that the callee is executed in a different execution
20/// context from the currently executing code.
21///
22/// A call node can be of two types:
23/// - A simple call: the callee is executed in the new user context.
24/// - A syscall: the callee is executed in the root context.
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct CallNode {
27    callee: MastNodeId,
28    is_syscall: bool,
29    digest: RpoDigest,
30    before_enter: Vec<DecoratorId>,
31    after_exit: Vec<DecoratorId>,
32}
33
34//-------------------------------------------------------------------------------------------------
35/// Constants
36impl CallNode {
37    /// The domain of the call block (used for control block hashing).
38    pub const CALL_DOMAIN: Felt = Felt::new(OPCODE_CALL as u64);
39    /// The domain of the syscall block (used for control block hashing).
40    pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64);
41}
42
43//-------------------------------------------------------------------------------------------------
44/// Constructors
45impl CallNode {
46    /// Returns a new [`CallNode`] instantiated with the specified callee.
47    pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
48        if callee.as_usize() >= mast_forest.nodes.len() {
49            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
50        }
51        let digest = {
52            let callee_digest = mast_forest[callee].digest();
53
54            hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN)
55        };
56
57        Ok(Self {
58            callee,
59            is_syscall: false,
60            digest,
61            before_enter: Vec::new(),
62            after_exit: Vec::new(),
63        })
64    }
65
66    /// Returns a new [`CallNode`] from values that are assumed to be correct.
67    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
68    pub fn new_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
69        Self {
70            callee,
71            is_syscall: false,
72            digest,
73            before_enter: Vec::new(),
74            after_exit: Vec::new(),
75        }
76    }
77
78    /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
79    /// call.
80    pub fn new_syscall(
81        callee: MastNodeId,
82        mast_forest: &MastForest,
83    ) -> Result<Self, MastForestError> {
84        if callee.as_usize() >= mast_forest.nodes.len() {
85            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
86        }
87        let digest = {
88            let callee_digest = mast_forest[callee].digest();
89
90            hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN)
91        };
92
93        Ok(Self {
94            callee,
95            is_syscall: true,
96            digest,
97            before_enter: Vec::new(),
98            after_exit: Vec::new(),
99        })
100    }
101
102    /// Returns a new syscall [`CallNode`] from values that are assumed to be correct.
103    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
104    pub fn new_syscall_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
105        Self {
106            callee,
107            is_syscall: true,
108            digest,
109            before_enter: Vec::new(),
110            after_exit: Vec::new(),
111        }
112    }
113}
114
115//-------------------------------------------------------------------------------------------------
116/// Public accessors
117impl CallNode {
118    /// Returns a commitment to this Call node.
119    ///
120    /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the
121    /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on
122    /// whether the node represents a simple call or a syscall - i.e.,:
123    /// ```
124    /// # use miden_core::mast::CallNode;
125    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
126    /// # let callee_digest = Digest::default();
127    /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::CALL_DOMAIN);
128    /// ```
129    /// or
130    /// ```
131    /// # use miden_core::mast::CallNode;
132    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
133    /// # let callee_digest = Digest::default();
134    /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::SYSCALL_DOMAIN);
135    /// ```
136    pub fn digest(&self) -> RpoDigest {
137        self.digest
138    }
139
140    /// Returns the ID of the node to be invoked by this call node.
141    pub fn callee(&self) -> MastNodeId {
142        self.callee
143    }
144
145    /// Returns true if this call node represents a syscall.
146    pub fn is_syscall(&self) -> bool {
147        self.is_syscall
148    }
149
150    /// Returns the domain of this call node.
151    pub fn domain(&self) -> Felt {
152        if self.is_syscall() {
153            Self::SYSCALL_DOMAIN
154        } else {
155            Self::CALL_DOMAIN
156        }
157    }
158
159    /// Returns the decorators to be executed before this node is executed.
160    pub fn before_enter(&self) -> &[DecoratorId] {
161        &self.before_enter
162    }
163
164    /// Returns the decorators to be executed after this node is executed.
165    pub fn after_exit(&self) -> &[DecoratorId] {
166        &self.after_exit
167    }
168}
169
170/// Mutators
171impl CallNode {
172    /// Sets the list of decorators to be executed before this node.
173    pub fn set_before_enter(&mut self, decorator_ids: Vec<DecoratorId>) {
174        self.before_enter = decorator_ids;
175    }
176
177    /// Sets the list of decorators to be executed after this node.
178    pub fn set_after_exit(&mut self, decorator_ids: Vec<DecoratorId>) {
179        self.after_exit = decorator_ids;
180    }
181}
182
183// PRETTY PRINTING
184// ================================================================================================
185
186impl CallNode {
187    pub(super) fn to_pretty_print<'a>(
188        &'a self,
189        mast_forest: &'a MastForest,
190    ) -> impl PrettyPrint + 'a {
191        CallNodePrettyPrint { node: self, mast_forest }
192    }
193
194    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
195        CallNodePrettyPrint { node: self, mast_forest }
196    }
197}
198
199struct CallNodePrettyPrint<'a> {
200    node: &'a CallNode,
201    mast_forest: &'a MastForest,
202}
203
204impl CallNodePrettyPrint<'_> {
205    /// Concatenates the provided decorators in a single line. If the list of decorators is not
206    /// empty, prepends `prepend` and appends `append` to the decorator document.
207    fn concatenate_decorators(
208        &self,
209        decorator_ids: &[DecoratorId],
210        prepend: Document,
211        append: Document,
212    ) -> Document {
213        let decorators = decorator_ids
214            .iter()
215            .map(|&decorator_id| self.mast_forest[decorator_id].render())
216            .reduce(|acc, doc| acc + const_text(" ") + doc)
217            .unwrap_or_default();
218
219        if decorators.is_empty() {
220            decorators
221        } else {
222            prepend + decorators + append
223        }
224    }
225
226    fn single_line_pre_decorators(&self) -> Document {
227        self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" "))
228    }
229
230    fn single_line_post_decorators(&self) -> Document {
231        self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty)
232    }
233
234    fn multi_line_pre_decorators(&self) -> Document {
235        self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl())
236    }
237
238    fn multi_line_post_decorators(&self) -> Document {
239        self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty)
240    }
241}
242
243impl PrettyPrint for CallNodePrettyPrint<'_> {
244    fn render(&self) -> Document {
245        let call_or_syscall = {
246            let callee_digest = self.mast_forest[self.node.callee].digest();
247            if self.node.is_syscall {
248                const_text("syscall")
249                    + const_text(".")
250                    + text(callee_digest.as_bytes().to_hex_with_prefix())
251            } else {
252                const_text("call")
253                    + const_text(".")
254                    + text(callee_digest.as_bytes().to_hex_with_prefix())
255            }
256        };
257
258        let single_line = self.single_line_pre_decorators()
259            + call_or_syscall.clone()
260            + self.single_line_post_decorators();
261        let multi_line =
262            self.multi_line_pre_decorators() + call_or_syscall + self.multi_line_post_decorators();
263
264        single_line | multi_line
265    }
266}
267
268impl fmt::Display for CallNodePrettyPrint<'_> {
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        use crate::prettier::PrettyPrint;
271        self.pretty_print(f)
272    }
273}