miden_core/mast/node/
call_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::{
6    hex::ToHex,
7    prettier::{Document, PrettyPrint, const_text, nl, text},
8};
9
10use super::MastNodeExt;
11use crate::{
12    OPCODE_CALL, OPCODE_SYSCALL,
13    chiplets::hasher,
14    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
15};
16
17// CALL NODE
18// ================================================================================================
19
/// A Call node describes a function call such that the callee is executed in a different execution
/// context from the currently executing code.
///
/// A call node can be of two types:
/// - A simple call: the callee is executed in the new user context.
/// - A syscall: the callee is executed in the root context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CallNode {
    /// ID of the node invoked by this call; the checked constructors require it to refer to a
    /// node of the [`MastForest`] the call node is built against.
    callee: MastNodeId,
    /// `true` for a syscall, `false` for a simple call; selects the hashing domain
    /// (see [`CallNode::domain`]).
    is_syscall: bool,
    /// Commitment to this node, returned by [`CallNode::digest`].
    digest: Word,
    /// Decorators to be executed before this node is executed.
    before_enter: Vec<DecoratorId>,
    /// Decorators to be executed after this node is executed.
    after_exit: Vec<DecoratorId>,
}
34
35//-------------------------------------------------------------------------------------------------
/// Constants
impl CallNode {
    /// The domain of the call block (used for control block hashing).
    ///
    /// This is the `OPCODE_CALL` opcode value lifted into a field element.
    pub const CALL_DOMAIN: Felt = Felt::new(OPCODE_CALL as u64);
    /// The domain of the syscall block (used for control block hashing).
    ///
    /// This is the `OPCODE_SYSCALL` opcode value lifted into a field element.
    pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64);
}
43
44//-------------------------------------------------------------------------------------------------
45/// Constructors
46impl CallNode {
47    /// Returns a new [`CallNode`] instantiated with the specified callee.
48    pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
49        if callee.as_usize() >= mast_forest.nodes.len() {
50            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
51        }
52        let digest = {
53            let callee_digest = mast_forest[callee].digest();
54
55            hasher::merge_in_domain(&[callee_digest, Word::default()], Self::CALL_DOMAIN)
56        };
57
58        Ok(Self {
59            callee,
60            is_syscall: false,
61            digest,
62            before_enter: Vec::new(),
63            after_exit: Vec::new(),
64        })
65    }
66
67    /// Returns a new [`CallNode`] from values that are assumed to be correct.
68    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
69    pub fn new_unsafe(callee: MastNodeId, digest: Word) -> Self {
70        Self {
71            callee,
72            is_syscall: false,
73            digest,
74            before_enter: Vec::new(),
75            after_exit: Vec::new(),
76        }
77    }
78
79    /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
80    /// call.
81    pub fn new_syscall(
82        callee: MastNodeId,
83        mast_forest: &MastForest,
84    ) -> Result<Self, MastForestError> {
85        if callee.as_usize() >= mast_forest.nodes.len() {
86            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
87        }
88        let digest = {
89            let callee_digest = mast_forest[callee].digest();
90
91            hasher::merge_in_domain(&[callee_digest, Word::default()], Self::SYSCALL_DOMAIN)
92        };
93
94        Ok(Self {
95            callee,
96            is_syscall: true,
97            digest,
98            before_enter: Vec::new(),
99            after_exit: Vec::new(),
100        })
101    }
102
103    /// Returns a new syscall [`CallNode`] from values that are assumed to be correct.
104    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
105    pub fn new_syscall_unsafe(callee: MastNodeId, digest: Word) -> Self {
106        Self {
107            callee,
108            is_syscall: true,
109            digest,
110            before_enter: Vec::new(),
111            after_exit: Vec::new(),
112        }
113    }
114}
115
116//-------------------------------------------------------------------------------------------------
117/// Public accessors
118impl CallNode {
119    /// Returns a commitment to this Call node.
120    ///
121    /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the
122    /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on
123    /// whether the node represents a simple call or a syscall - i.e.,:
124    /// ```
125    /// # use miden_core::mast::CallNode;
126    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
127    /// # let callee_digest = Word::default();
128    /// Hasher::merge_in_domain(&[callee_digest, Word::default()], CallNode::CALL_DOMAIN);
129    /// ```
130    /// or
131    /// ```
132    /// # use miden_core::mast::CallNode;
133    /// # use miden_crypto::{Word, hash::rpo::Rpo256 as Hasher};
134    /// # let callee_digest = Word::default();
135    /// Hasher::merge_in_domain(&[callee_digest, Word::default()], CallNode::SYSCALL_DOMAIN);
136    /// ```
137    pub fn digest(&self) -> Word {
138        self.digest
139    }
140
141    /// Returns the ID of the node to be invoked by this call node.
142    pub fn callee(&self) -> MastNodeId {
143        self.callee
144    }
145
146    /// Returns true if this call node represents a syscall.
147    pub fn is_syscall(&self) -> bool {
148        self.is_syscall
149    }
150
151    /// Returns the domain of this call node.
152    pub fn domain(&self) -> Felt {
153        if self.is_syscall() {
154            Self::SYSCALL_DOMAIN
155        } else {
156            Self::CALL_DOMAIN
157        }
158    }
159
160    /// Returns the decorators to be executed before this node is executed.
161    pub fn before_enter(&self) -> &[DecoratorId] {
162        &self.before_enter
163    }
164
165    /// Returns the decorators to be executed after this node is executed.
166    pub fn after_exit(&self) -> &[DecoratorId] {
167        &self.after_exit
168    }
169}
170
171//-------------------------------------------------------------------------------------------------
172/// Mutators
173impl CallNode {
174    pub fn remap_children(&self, remapping: &Remapping) -> Self {
175        let mut node = self.clone();
176        node.callee = node.callee.remap(remapping);
177        node
178    }
179
180    /// Sets the list of decorators to be executed before this node.
181    pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
182        self.before_enter.extend_from_slice(decorator_ids);
183    }
184
185    /// Sets the list of decorators to be executed after this node.
186    pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
187        self.after_exit.extend_from_slice(decorator_ids);
188    }
189
190    /// Removes all decorators from this node.
191    pub fn remove_decorators(&mut self) {
192        self.before_enter.truncate(0);
193        self.after_exit.truncate(0);
194    }
195}
196
197impl MastNodeExt for CallNode {
198    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
199        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
200    }
201}
202
203// PRETTY PRINTING
204// ================================================================================================
205
206impl CallNode {
207    pub(super) fn to_pretty_print<'a>(
208        &'a self,
209        mast_forest: &'a MastForest,
210    ) -> impl PrettyPrint + 'a {
211        CallNodePrettyPrint { node: self, mast_forest }
212    }
213
214    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
215        CallNodePrettyPrint { node: self, mast_forest }
216    }
217}
218
/// Borrowed view of a [`CallNode`] together with the [`MastForest`] needed to resolve its callee
/// digest and decorators while rendering.
struct CallNodePrettyPrint<'a> {
    /// The node being rendered.
    node: &'a CallNode,
    /// Forest used to look up the callee node and the decorators referenced by `node`.
    mast_forest: &'a MastForest,
}
223
224impl CallNodePrettyPrint<'_> {
225    /// Concatenates the provided decorators in a single line. If the list of decorators is not
226    /// empty, prepends `prepend` and appends `append` to the decorator document.
227    fn concatenate_decorators(
228        &self,
229        decorator_ids: &[DecoratorId],
230        prepend: Document,
231        append: Document,
232    ) -> Document {
233        let decorators = decorator_ids
234            .iter()
235            .map(|&decorator_id| self.mast_forest[decorator_id].render())
236            .reduce(|acc, doc| acc + const_text(" ") + doc)
237            .unwrap_or_default();
238
239        if decorators.is_empty() {
240            decorators
241        } else {
242            prepend + decorators + append
243        }
244    }
245
246    fn single_line_pre_decorators(&self) -> Document {
247        self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" "))
248    }
249
250    fn single_line_post_decorators(&self) -> Document {
251        self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty)
252    }
253
254    fn multi_line_pre_decorators(&self) -> Document {
255        self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl())
256    }
257
258    fn multi_line_post_decorators(&self) -> Document {
259        self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty)
260    }
261}
262
263impl PrettyPrint for CallNodePrettyPrint<'_> {
264    fn render(&self) -> Document {
265        let call_or_syscall = {
266            let callee_digest = self.mast_forest[self.node.callee].digest();
267            if self.node.is_syscall {
268                const_text("syscall")
269                    + const_text(".")
270                    + text(callee_digest.as_bytes().to_hex_with_prefix())
271            } else {
272                const_text("call")
273                    + const_text(".")
274                    + text(callee_digest.as_bytes().to_hex_with_prefix())
275            }
276        };
277
278        let single_line = self.single_line_pre_decorators()
279            + call_or_syscall.clone()
280            + self.single_line_post_decorators();
281        let multi_line =
282            self.multi_line_pre_decorators() + call_or_syscall + self.multi_line_post_decorators();
283
284        single_line | multi_line
285    }
286}
287
288impl fmt::Display for CallNodePrettyPrint<'_> {
289    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290        use crate::prettier::PrettyPrint;
291        self.pretty_print(f)
292    }
293}