1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_formatting::{
5 hex::ToHex,
6 prettier::{Document, PrettyPrint, const_text, text},
7};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use super::{MastForestContributor, MastNodeExt, fingerprint_with_child_fingerprints};
12use crate::{
13 Felt, Word,
14 chiplets::hasher,
15 mast::{MastForest, MastForestError, MastNodeId},
16 operations::opcodes,
17 utils::{Idx, LookupByIdx},
18};
19
20#[derive(Debug, Clone, PartialEq, Eq)]
30#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
31#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
32pub struct CallNode {
33 callee: MastNodeId,
34 is_syscall: bool,
35 digest: Word,
36}
37
38impl CallNode {
41 pub const CALL_DOMAIN: Felt = Felt::new_unchecked(opcodes::CALL as u64);
43 pub const SYSCALL_DOMAIN: Felt = Felt::new_unchecked(opcodes::SYSCALL as u64);
45}
46
47impl CallNode {
50 pub fn callee(&self) -> MastNodeId {
52 self.callee
53 }
54
55 pub fn is_syscall(&self) -> bool {
57 self.is_syscall
58 }
59
60 pub fn domain(&self) -> Felt {
62 if self.is_syscall() {
63 Self::SYSCALL_DOMAIN
64 } else {
65 Self::CALL_DOMAIN
66 }
67 }
68}
69
70impl CallNode {
74 pub(super) fn to_pretty_print<'a>(
75 &'a self,
76 mast_forest: &'a MastForest,
77 ) -> impl PrettyPrint + 'a {
78 CallNodePrettyPrint { node: self, mast_forest }
79 }
80
81 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
82 CallNodePrettyPrint { node: self, mast_forest }
83 }
84}
85
86struct CallNodePrettyPrint<'a> {
87 node: &'a CallNode,
88 mast_forest: &'a MastForest,
89}
90
91impl PrettyPrint for CallNodePrettyPrint<'_> {
92 fn render(&self) -> Document {
93 let callee_digest = self.mast_forest[self.node.callee].digest();
94 if self.node.is_syscall {
95 const_text("syscall")
96 + const_text(".")
97 + text(callee_digest.as_bytes().to_hex_with_prefix())
98 } else {
99 const_text("call")
100 + const_text(".")
101 + text(callee_digest.as_bytes().to_hex_with_prefix())
102 }
103 }
104}
105
106impl fmt::Display for CallNodePrettyPrint<'_> {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 use crate::prettier::PrettyPrint;
109 self.pretty_print(f)
110 }
111}
112
113impl MastNodeExt for CallNode {
117 fn digest(&self) -> Word {
136 self.digest
137 }
138
139 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
140 Box::new(CallNode::to_display(self, mast_forest))
141 }
142
143 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
144 Box::new(CallNode::to_pretty_print(self, mast_forest))
145 }
146
147 fn has_children(&self) -> bool {
148 true
149 }
150
151 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
152 target.push(self.callee());
153 }
154
155 fn for_each_child<F>(&self, mut f: F)
156 where
157 F: FnMut(MastNodeId),
158 {
159 f(self.callee());
160 }
161
162 fn domain(&self) -> Felt {
163 self.domain()
164 }
165
166 type Builder = CallNodeBuilder;
167
168 fn to_builder(self, _forest: &MastForest) -> Self::Builder {
169 let builder = if self.is_syscall {
170 CallNodeBuilder::new_syscall(self.callee)
171 } else {
172 CallNodeBuilder::new(self.callee)
173 };
174 builder.with_digest(self.digest)
175 }
176}
177
#[cfg(all(feature = "arbitrary", test))]
/// Generates `CallNode`s with an arbitrary callee, flag and digest.
///
/// The digest is drawn independently of the callee, so it is not the hash the builder
/// would compute.
impl proptest::prelude::Arbitrary for CallNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<[u64; 4]>(), any::<bool>())
            .prop_map(|(callee, raw_digest, is_syscall)| CallNode {
                callee,
                is_syscall,
                digest: Word::from(raw_digest.map(Felt::new_unchecked)),
            })
            .no_shrink()
            .boxed()
    }
}
208
209#[derive(Debug)]
212pub struct CallNodeBuilder {
213 callee: MastNodeId,
214 is_syscall: bool,
215 digest: Option<Word>,
216}
217
218impl CallNodeBuilder {
219 pub fn new(callee: MastNodeId) -> Self {
221 Self { callee, is_syscall: false, digest: None }
222 }
223
224 pub fn new_syscall(callee: MastNodeId) -> Self {
226 Self { callee, is_syscall: true, digest: None }
227 }
228
229 pub fn build(self, mast_forest: &MastForest) -> Result<CallNode, MastForestError> {
231 if self.callee.to_usize() >= mast_forest.nodes.len() {
232 return Err(MastForestError::NodeIdOverflow(self.callee, mast_forest.nodes.len()));
233 }
234
235 let digest = if let Some(forced_digest) = self.digest {
237 forced_digest
238 } else {
239 let callee_digest = mast_forest[self.callee].digest();
240 let domain = if self.is_syscall {
241 CallNode::SYSCALL_DOMAIN
242 } else {
243 CallNode::CALL_DOMAIN
244 };
245
246 hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
247 };
248
249 Ok(CallNode {
250 callee: self.callee,
251 is_syscall: self.is_syscall,
252 digest,
253 })
254 }
255
256 pub(in crate::mast) fn build_linked(self) -> Result<CallNode, MastForestError> {
257 Ok(CallNode {
258 callee: self.callee,
259 is_syscall: self.is_syscall,
260 digest: self.digest.ok_or(MastForestError::DigestRequiredForDeserialization)?,
261 })
262 }
263}
264
265impl MastForestContributor for CallNodeBuilder {
266 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
267 if self.callee.to_usize() >= forest.nodes.len() {
268 return Err(MastForestError::NodeIdOverflow(self.callee, forest.nodes.len()));
269 }
270
271 let digest = if let Some(forced_digest) = self.digest {
273 forced_digest
274 } else {
275 let callee_digest = forest[self.callee].digest();
276 let domain = if self.is_syscall {
277 CallNode::SYSCALL_DOMAIN
278 } else {
279 CallNode::CALL_DOMAIN
280 };
281
282 hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
283 };
284
285 let node_id = forest
288 .nodes
289 .push(
290 CallNode {
291 callee: self.callee,
292 is_syscall: self.is_syscall,
293 digest,
294 }
295 .into(),
296 )
297 .map_err(|_| MastForestError::TooManyNodes)?;
298
299 Ok(node_id)
300 }
301
302 fn fingerprint_for_node(
303 &self,
304 forest: &MastForest,
305 hash_by_node_id: &impl LookupByIdx<MastNodeId, Word>,
306 ) -> Result<Word, MastForestError> {
307 let node_digest = if let Some(forced_digest) = self.digest {
308 forced_digest
309 } else {
310 let callee_digest = forest[self.callee].digest();
311 let domain = if self.is_syscall {
312 CallNode::SYSCALL_DOMAIN
313 } else {
314 CallNode::CALL_DOMAIN
315 };
316
317 hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
318 };
319
320 fingerprint_with_child_fingerprints(node_digest, &[self.callee], forest, hash_by_node_id)
321 }
322
323 fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
324 CallNodeBuilder {
325 callee: *remapping.get(self.callee).unwrap_or(&self.callee),
326 is_syscall: self.is_syscall,
327 digest: self.digest,
328 }
329 }
330
331 fn with_digest(mut self, digest: Word) -> Self {
332 self.digest = Some(digest);
333 self
334 }
335}
336
337impl CallNodeBuilder {
338 pub(in crate::mast) fn add_to_forest_relaxed(
348 self,
349 forest: &mut MastForest,
350 ) -> Result<MastNodeId, MastForestError> {
351 let Some(digest) = self.digest else {
354 return Err(MastForestError::DigestRequiredForDeserialization);
355 };
356
357 let node_id = forest
360 .nodes
361 .push(
362 CallNode {
363 callee: self.callee,
364 is_syscall: self.is_syscall,
365 digest,
366 }
367 .into(),
368 )
369 .map_err(|_| MastForestError::TooManyNodes)?;
370
371 Ok(node_id)
372 }
373}
374
#[cfg(any(test, feature = "arbitrary"))]
/// Generates builders with an arbitrary callee and call kind and no forced digest.
impl proptest::prelude::Arbitrary for CallNodeBuilder {
    type Parameters = CallNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        (any::<MastNodeId>(), any::<bool>())
            .prop_map(|(callee, is_syscall)| match is_syscall {
                true => Self::new_syscall(callee),
                false => Self::new(callee),
            })
            .boxed()
    }
}
395
/// Parameters for the `proptest` strategy that generates `CallNodeBuilder` values.
///
/// Currently carries no configuration.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Default, Debug, Clone)]
pub struct CallNodeBuilderParams {}