Skip to main content

miden_core/mast/node/
call_node.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_formatting::{
5    hex::ToHex,
6    prettier::{Document, PrettyPrint, const_text, text},
7};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use super::{MastForestContributor, MastNodeExt, fingerprint_with_child_fingerprints};
12use crate::{
13    Felt, Word,
14    chiplets::hasher,
15    mast::{MastForest, MastForestError, MastNodeId},
16    operations::opcodes,
17    utils::{Idx, LookupByIdx},
18};
19
20// CALL NODE
21// ================================================================================================
22
/// A Call node describes a function call such that the callee is executed in a different execution
/// context from the currently executing code.
///
/// A call node can be of two types:
/// - A simple call: the callee is executed in the new user context.
/// - A syscall: the callee is executed in the root context.
///
/// Instances are normally produced by [`CallNodeBuilder`], which validates the callee against a
/// [`MastForest`] and computes (or accepts a forced) digest.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct CallNode {
    /// ID of the node invoked by this call node.
    callee: MastNodeId,
    /// `true` for a syscall, `false` for a simple call; selects the hashing domain
    /// (`SYSCALL_DOMAIN` vs `CALL_DOMAIN`).
    is_syscall: bool,
    /// Commitment to this node, returned by `digest()`. The builder derives it from the callee's
    /// digest unless a digest was forced via `with_digest`.
    digest: Word,
}
37
38//-------------------------------------------------------------------------------------------------
/// Constants
///
/// Hashing-domain constants used when computing call/syscall node digests.
impl CallNode {
    /// The domain of the call block (used for control block hashing).
    ///
    /// Taken from the `CALL` opcode value.
    pub const CALL_DOMAIN: Felt = Felt::new_unchecked(opcodes::CALL as u64);
    /// The domain of the syscall block (used for control block hashing).
    ///
    /// Taken from the `SYSCALL` opcode value.
    pub const SYSCALL_DOMAIN: Felt = Felt::new_unchecked(opcodes::SYSCALL as u64);
}
46
47//-------------------------------------------------------------------------------------------------
48/// Public accessors
49impl CallNode {
50    /// Returns the ID of the node to be invoked by this call node.
51    pub fn callee(&self) -> MastNodeId {
52        self.callee
53    }
54
55    /// Returns true if this call node represents a syscall.
56    pub fn is_syscall(&self) -> bool {
57        self.is_syscall
58    }
59
60    /// Returns the domain of this call node.
61    pub fn domain(&self) -> Felt {
62        if self.is_syscall() {
63            Self::SYSCALL_DOMAIN
64        } else {
65            Self::CALL_DOMAIN
66        }
67    }
68}
69
70// PRETTY PRINTING
71// ================================================================================================
72
73impl CallNode {
74    pub(super) fn to_pretty_print<'a>(
75        &'a self,
76        mast_forest: &'a MastForest,
77    ) -> impl PrettyPrint + 'a {
78        CallNodePrettyPrint { node: self, mast_forest }
79    }
80
81    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
82        CallNodePrettyPrint { node: self, mast_forest }
83    }
84}
85
/// Rendering helper pairing a [`CallNode`] with the [`MastForest`] that contains its callee.
struct CallNodePrettyPrint<'a> {
    /// The node being rendered.
    node: &'a CallNode,
    /// Forest used to look up the callee's digest.
    mast_forest: &'a MastForest,
}
90
91impl PrettyPrint for CallNodePrettyPrint<'_> {
92    fn render(&self) -> Document {
93        let callee_digest = self.mast_forest[self.node.callee].digest();
94        if self.node.is_syscall {
95            const_text("syscall")
96                + const_text(".")
97                + text(callee_digest.as_bytes().to_hex_with_prefix())
98        } else {
99            const_text("call")
100                + const_text(".")
101                + text(callee_digest.as_bytes().to_hex_with_prefix())
102        }
103    }
104}
105
106impl fmt::Display for CallNodePrettyPrint<'_> {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        use crate::prettier::PrettyPrint;
109        self.pretty_print(f)
110    }
111}
112
113// MAST NODE TRAIT IMPLEMENTATION
114// ================================================================================================
115
116impl MastNodeExt for CallNode {
117    /// Returns a commitment to this Call node.
118    ///
119    /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the
120    /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on
121    /// whether the node represents a simple call or a syscall - i.e.,:
122    /// ```
123    /// # use miden_core::mast::CallNode;
124    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
125    /// # let callee_digest = Word::default();
126    /// Hasher::merge_in_domain(&[callee_digest, Word::default()], CallNode::CALL_DOMAIN);
127    /// ```
128    /// or
129    /// ```
130    /// # use miden_core::mast::CallNode;
131    /// # use miden_crypto::{Word, hash::poseidon2::Poseidon2 as Hasher};
132    /// # let callee_digest = Word::default();
133    /// Hasher::merge_in_domain(&[callee_digest, Word::default()], CallNode::SYSCALL_DOMAIN);
134    /// ```
135    fn digest(&self) -> Word {
136        self.digest
137    }
138
139    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
140        Box::new(CallNode::to_display(self, mast_forest))
141    }
142
143    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
144        Box::new(CallNode::to_pretty_print(self, mast_forest))
145    }
146
147    fn has_children(&self) -> bool {
148        true
149    }
150
151    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
152        target.push(self.callee());
153    }
154
155    fn for_each_child<F>(&self, mut f: F)
156    where
157        F: FnMut(MastNodeId),
158    {
159        f(self.callee());
160    }
161
162    fn domain(&self) -> Felt {
163        self.domain()
164    }
165
166    type Builder = CallNodeBuilder;
167
168    fn to_builder(self, _forest: &MastForest) -> Self::Builder {
169        let builder = if self.is_syscall {
170            CallNodeBuilder::new_syscall(self.callee)
171        } else {
172            CallNodeBuilder::new(self.callee)
173        };
174        builder.with_digest(self.digest)
175    }
176}
177
178// ARBITRARY IMPLEMENTATION
179// ================================================================================================
180
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for CallNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates call nodes with a random callee, a random digest and a random syscall flag.
    ///
    /// Nodes are constructed directly (not via `CallNodeBuilder`), so the callee is not validated
    /// against any forest and the digest is unrelated to the callee.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let callee_strategy = any::<MastNodeId>();
        let digest_strategy = any::<[u64; 4]>();
        let syscall_strategy = any::<bool>();

        (callee_strategy, digest_strategy, syscall_strategy)
            .prop_map(|(callee, raw_digest, is_syscall)| CallNode {
                callee,
                is_syscall,
                digest: Word::from(raw_digest.map(Felt::new_unchecked)),
            })
            // Values are purely random; there is no meaningful shrinking order.
            .no_shrink()
            .boxed()
    }
}
208
209// ------------------------------------------------------------------------------------------------
210/// Builder for creating [`CallNode`] instances.
211#[derive(Debug)]
212pub struct CallNodeBuilder {
213    callee: MastNodeId,
214    is_syscall: bool,
215    digest: Option<Word>,
216}
217
218impl CallNodeBuilder {
219    /// Creates a new builder for a CallNode with the specified callee.
220    pub fn new(callee: MastNodeId) -> Self {
221        Self { callee, is_syscall: false, digest: None }
222    }
223
224    /// Creates a new builder for a syscall CallNode with the specified callee.
225    pub fn new_syscall(callee: MastNodeId) -> Self {
226        Self { callee, is_syscall: true, digest: None }
227    }
228
229    /// Builds the CallNode.
230    pub fn build(self, mast_forest: &MastForest) -> Result<CallNode, MastForestError> {
231        if self.callee.to_usize() >= mast_forest.nodes.len() {
232            return Err(MastForestError::NodeIdOverflow(self.callee, mast_forest.nodes.len()));
233        }
234
235        // Use the forced digest if provided, otherwise compute the digest
236        let digest = if let Some(forced_digest) = self.digest {
237            forced_digest
238        } else {
239            let callee_digest = mast_forest[self.callee].digest();
240            let domain = if self.is_syscall {
241                CallNode::SYSCALL_DOMAIN
242            } else {
243                CallNode::CALL_DOMAIN
244            };
245
246            hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
247        };
248
249        Ok(CallNode {
250            callee: self.callee,
251            is_syscall: self.is_syscall,
252            digest,
253        })
254    }
255
256    pub(in crate::mast) fn build_linked(self) -> Result<CallNode, MastForestError> {
257        Ok(CallNode {
258            callee: self.callee,
259            is_syscall: self.is_syscall,
260            digest: self.digest.ok_or(MastForestError::DigestRequiredForDeserialization)?,
261        })
262    }
263}
264
265impl MastForestContributor for CallNodeBuilder {
266    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
267        if self.callee.to_usize() >= forest.nodes.len() {
268            return Err(MastForestError::NodeIdOverflow(self.callee, forest.nodes.len()));
269        }
270
271        // Use the forced digest if provided, otherwise compute the digest directly
272        let digest = if let Some(forced_digest) = self.digest {
273            forced_digest
274        } else {
275            let callee_digest = forest[self.callee].digest();
276            let domain = if self.is_syscall {
277                CallNode::SYSCALL_DOMAIN
278            } else {
279                CallNode::CALL_DOMAIN
280            };
281
282            hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
283        };
284
285        // Create the node in the forest with Linked variant from the start
286        // Move the data directly without intermediate Owned node creation
287        let node_id = forest
288            .nodes
289            .push(
290                CallNode {
291                    callee: self.callee,
292                    is_syscall: self.is_syscall,
293                    digest,
294                }
295                .into(),
296            )
297            .map_err(|_| MastForestError::TooManyNodes)?;
298
299        Ok(node_id)
300    }
301
302    fn fingerprint_for_node(
303        &self,
304        forest: &MastForest,
305        hash_by_node_id: &impl LookupByIdx<MastNodeId, Word>,
306    ) -> Result<Word, MastForestError> {
307        let node_digest = if let Some(forced_digest) = self.digest {
308            forced_digest
309        } else {
310            let callee_digest = forest[self.callee].digest();
311            let domain = if self.is_syscall {
312                CallNode::SYSCALL_DOMAIN
313            } else {
314                CallNode::CALL_DOMAIN
315            };
316
317            hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
318        };
319
320        fingerprint_with_child_fingerprints(node_digest, &[self.callee], forest, hash_by_node_id)
321    }
322
323    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
324        CallNodeBuilder {
325            callee: *remapping.get(self.callee).unwrap_or(&self.callee),
326            is_syscall: self.is_syscall,
327            digest: self.digest,
328        }
329    }
330
331    fn with_digest(mut self, digest: Word) -> Self {
332        self.digest = Some(digest);
333        self
334    }
335}
336
337impl CallNodeBuilder {
338    /// Add this node to a forest using relaxed validation.
339    ///
340    /// This method is used during deserialization where nodes may reference child nodes
341    /// that haven't been added to the forest yet. The child node IDs have already been
342    /// validated against the expected final node count during the `try_into_mast_node_builder`
343    /// step, so we can safely skip validation here.
344    ///
345    /// Note: This is not part of the `MastForestContributor` trait because it's only
346    /// intended for internal use during deserialization.
347    pub(in crate::mast) fn add_to_forest_relaxed(
348        self,
349        forest: &mut MastForest,
350    ) -> Result<MastNodeId, MastForestError> {
351        // Use the forced digest if provided, otherwise use a default digest
352        // The actual digest computation will be handled when the forest is complete
353        let Some(digest) = self.digest else {
354            return Err(MastForestError::DigestRequiredForDeserialization);
355        };
356
357        // Create the node in the forest with Linked variant from the start
358        // Move the data directly without intermediate cloning
359        let node_id = forest
360            .nodes
361            .push(
362                CallNode {
363                    callee: self.callee,
364                    is_syscall: self.is_syscall,
365                    digest,
366                }
367                .into(),
368            )
369            .map_err(|_| MastForestError::TooManyNodes)?;
370
371        Ok(node_id)
372    }
373}
374
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for CallNodeBuilder {
    type Parameters = CallNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with a random callee, a random call/syscall kind and no forced digest.
    ///
    /// `CallNodeBuilderParams` carries no options, so the parameters are ignored.
    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<bool>())
            .prop_map(|(callee, is_syscall)| match is_syscall {
                true => Self::new_syscall(callee),
                false => Self::new(callee),
            })
            .boxed()
    }
}
395
/// Parameters for generating CallNodeBuilder instances.
///
/// Currently empty: the `Arbitrary` implementation for `CallNodeBuilder` reads no parameters.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug, Default)]
pub struct CallNodeBuilderParams {}