miden_core/mast/node/call_node.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use miden_crypto::{Felt, hash::rpo::RpoDigest};
5use miden_formatting::{
6    hex::ToHex,
7    prettier::{Document, PrettyPrint, const_text, nl, text},
8};
9
10use super::MastNodeExt;
11use crate::{
12    OPCODE_CALL, OPCODE_SYSCALL,
13    chiplets::hasher,
14    mast::{DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
15};
16
17// CALL NODE
18// ================================================================================================
19
20/// A Call node describes a function call such that the callee is executed in a different execution
21/// context from the currently executing code.
22///
23/// A call node can be of two types:
24/// - A simple call: the callee is executed in the new user context.
25/// - A syscall: the callee is executed in the root context.
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub struct CallNode {
28    callee: MastNodeId,
29    is_syscall: bool,
30    digest: RpoDigest,
31    before_enter: Vec<DecoratorId>,
32    after_exit: Vec<DecoratorId>,
33}
34
35//-------------------------------------------------------------------------------------------------
36/// Constants
37impl CallNode {
38    /// The domain of the call block (used for control block hashing).
39    pub const CALL_DOMAIN: Felt = Felt::new(OPCODE_CALL as u64);
40    /// The domain of the syscall block (used for control block hashing).
41    pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64);
42}
43
44//-------------------------------------------------------------------------------------------------
45/// Constructors
46impl CallNode {
47    /// Returns a new [`CallNode`] instantiated with the specified callee.
48    pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
49        if callee.as_usize() >= mast_forest.nodes.len() {
50            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
51        }
52        let digest = {
53            let callee_digest = mast_forest[callee].digest();
54
55            hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN)
56        };
57
58        Ok(Self {
59            callee,
60            is_syscall: false,
61            digest,
62            before_enter: Vec::new(),
63            after_exit: Vec::new(),
64        })
65    }
66
67    /// Returns a new [`CallNode`] from values that are assumed to be correct.
68    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
69    pub fn new_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
70        Self {
71            callee,
72            is_syscall: false,
73            digest,
74            before_enter: Vec::new(),
75            after_exit: Vec::new(),
76        }
77    }
78
79    /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
80    /// call.
81    pub fn new_syscall(
82        callee: MastNodeId,
83        mast_forest: &MastForest,
84    ) -> Result<Self, MastForestError> {
85        if callee.as_usize() >= mast_forest.nodes.len() {
86            return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
87        }
88        let digest = {
89            let callee_digest = mast_forest[callee].digest();
90
91            hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN)
92        };
93
94        Ok(Self {
95            callee,
96            is_syscall: true,
97            digest,
98            before_enter: Vec::new(),
99            after_exit: Vec::new(),
100        })
101    }
102
103    /// Returns a new syscall [`CallNode`] from values that are assumed to be correct.
104    /// Should only be used when the source of the inputs is trusted (e.g. deserialization).
105    pub fn new_syscall_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
106        Self {
107            callee,
108            is_syscall: true,
109            digest,
110            before_enter: Vec::new(),
111            after_exit: Vec::new(),
112        }
113    }
114}
115
116//-------------------------------------------------------------------------------------------------
117/// Public accessors
118impl CallNode {
119    /// Returns a commitment to this Call node.
120    ///
121    /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the
122    /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on
123    /// whether the node represents a simple call or a syscall - i.e.,:
124    /// ```
125    /// # use miden_core::mast::CallNode;
126    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
127    /// # let callee_digest = Digest::default();
128    /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::CALL_DOMAIN);
129    /// ```
130    /// or
131    /// ```
132    /// # use miden_core::mast::CallNode;
133    /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}};
134    /// # let callee_digest = Digest::default();
135    /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::SYSCALL_DOMAIN);
136    /// ```
137    pub fn digest(&self) -> RpoDigest {
138        self.digest
139    }
140
141    /// Returns the ID of the node to be invoked by this call node.
142    pub fn callee(&self) -> MastNodeId {
143        self.callee
144    }
145
146    /// Returns true if this call node represents a syscall.
147    pub fn is_syscall(&self) -> bool {
148        self.is_syscall
149    }
150
151    /// Returns the domain of this call node.
152    pub fn domain(&self) -> Felt {
153        if self.is_syscall() {
154            Self::SYSCALL_DOMAIN
155        } else {
156            Self::CALL_DOMAIN
157        }
158    }
159
160    /// Returns the decorators to be executed before this node is executed.
161    pub fn before_enter(&self) -> &[DecoratorId] {
162        &self.before_enter
163    }
164
165    /// Returns the decorators to be executed after this node is executed.
166    pub fn after_exit(&self) -> &[DecoratorId] {
167        &self.after_exit
168    }
169}
170
171/// Mutators
172impl CallNode {
173    pub fn remap_children(&self, remapping: &Remapping) -> Self {
174        let mut node = self.clone();
175        node.callee = node.callee.remap(remapping);
176        node
177    }
178
179    /// Sets the list of decorators to be executed before this node.
180    pub fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
181        self.before_enter.extend_from_slice(decorator_ids);
182    }
183
184    /// Sets the list of decorators to be executed after this node.
185    pub fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
186        self.after_exit.extend_from_slice(decorator_ids);
187    }
188}
189
190impl MastNodeExt for CallNode {
191    fn decorators(&self) -> impl Iterator<Item = (usize, DecoratorId)> {
192        self.before_enter.iter().chain(&self.after_exit).copied().enumerate()
193    }
194}
195
196// PRETTY PRINTING
197// ================================================================================================
198
199impl CallNode {
200    pub(super) fn to_pretty_print<'a>(
201        &'a self,
202        mast_forest: &'a MastForest,
203    ) -> impl PrettyPrint + 'a {
204        CallNodePrettyPrint { node: self, mast_forest }
205    }
206
207    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
208        CallNodePrettyPrint { node: self, mast_forest }
209    }
210}
211
212struct CallNodePrettyPrint<'a> {
213    node: &'a CallNode,
214    mast_forest: &'a MastForest,
215}
216
217impl CallNodePrettyPrint<'_> {
218    /// Concatenates the provided decorators in a single line. If the list of decorators is not
219    /// empty, prepends `prepend` and appends `append` to the decorator document.
220    fn concatenate_decorators(
221        &self,
222        decorator_ids: &[DecoratorId],
223        prepend: Document,
224        append: Document,
225    ) -> Document {
226        let decorators = decorator_ids
227            .iter()
228            .map(|&decorator_id| self.mast_forest[decorator_id].render())
229            .reduce(|acc, doc| acc + const_text(" ") + doc)
230            .unwrap_or_default();
231
232        if decorators.is_empty() {
233            decorators
234        } else {
235            prepend + decorators + append
236        }
237    }
238
239    fn single_line_pre_decorators(&self) -> Document {
240        self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" "))
241    }
242
243    fn single_line_post_decorators(&self) -> Document {
244        self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty)
245    }
246
247    fn multi_line_pre_decorators(&self) -> Document {
248        self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl())
249    }
250
251    fn multi_line_post_decorators(&self) -> Document {
252        self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty)
253    }
254}
255
256impl PrettyPrint for CallNodePrettyPrint<'_> {
257    fn render(&self) -> Document {
258        let call_or_syscall = {
259            let callee_digest = self.mast_forest[self.node.callee].digest();
260            if self.node.is_syscall {
261                const_text("syscall")
262                    + const_text(".")
263                    + text(callee_digest.as_bytes().to_hex_with_prefix())
264            } else {
265                const_text("call")
266                    + const_text(".")
267                    + text(callee_digest.as_bytes().to_hex_with_prefix())
268            }
269        };
270
271        let single_line = self.single_line_pre_decorators()
272            + call_or_syscall.clone()
273            + self.single_line_post_decorators();
274        let multi_line =
275            self.multi_line_pre_decorators() + call_or_syscall + self.multi_line_post_decorators();
276
277        single_line | multi_line
278    }
279}
280
281impl fmt::Display for CallNodePrettyPrint<'_> {
282    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283        use crate::prettier::PrettyPrint;
284        self.pretty_print(f)
285    }
286}