1use alloc::{boxed::Box, vec::Vec};
2use core::fmt;
3
4use miden_crypto::{Felt, Word};
5use miden_formatting::{
6 hex::ToHex,
7 prettier::{Document, PrettyPrint, const_text, nl, text},
8};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use super::{MastForestContributor, MastNodeErrorContext, MastNodeExt};
13use crate::{
14 Idx, OPCODE_CALL, OPCODE_SYSCALL,
15 chiplets::hasher,
16 mast::{DecoratedOpLink, DecoratorId, DecoratorStore, MastForest, MastForestError, MastNodeId},
17};
18
/// A MAST node that invokes another node of the forest, either as a regular `call` or as a
/// `syscall`.
///
/// The node has a single child (its callee) and stores its own digest together with the
/// decorators attached before entering and after exiting it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct CallNode {
    /// Id of the node invoked by this call.
    callee: MastNodeId,
    /// `true` for a `syscall`, `false` for a regular `call`.
    is_syscall: bool,
    /// Digest of this node. When built through [`CallNodeBuilder`] without a forced digest, it is
    /// `merge_in_domain([callee_digest, Word::default()], domain)`.
    digest: Word,
    /// Before-enter / after-exit decorators, either owned by the node or linked to the forest.
    decorator_store: DecoratorStore,
}
37
impl CallNode {
    /// Hashing domain used for the digest of a regular `call` node (the `OPCODE_CALL` value).
    pub const CALL_DOMAIN: Felt = Felt::new(OPCODE_CALL as u64);
    /// Hashing domain used for the digest of a `syscall` node (the `OPCODE_SYSCALL` value).
    pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64);
}
46
47impl CallNode {
50 pub fn callee(&self) -> MastNodeId {
52 self.callee
53 }
54
55 pub fn is_syscall(&self) -> bool {
57 self.is_syscall
58 }
59
60 pub fn domain(&self) -> Felt {
62 if self.is_syscall() {
63 Self::SYSCALL_DOMAIN
64 } else {
65 Self::CALL_DOMAIN
66 }
67 }
68}
69
70impl MastNodeErrorContext for CallNode {
71 fn decorators<'a>(
72 &'a self,
73 forest: &'a MastForest,
74 ) -> impl Iterator<Item = DecoratedOpLink> + 'a {
75 let before_enter = self.decorator_store.before_enter(forest);
77 let after_exit = self.decorator_store.after_exit(forest);
78
79 before_enter
81 .iter()
82 .map(|&deco_id| (0, deco_id))
83 .chain(after_exit.iter().map(|&deco_id| (1, deco_id)))
84 }
85}
86
87impl CallNode {
91 pub(super) fn to_pretty_print<'a>(
92 &'a self,
93 mast_forest: &'a MastForest,
94 ) -> impl PrettyPrint + 'a {
95 CallNodePrettyPrint { node: self, mast_forest }
96 }
97
98 pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
99 CallNodePrettyPrint { node: self, mast_forest }
100 }
101}
102
/// Pairs a borrowed [`CallNode`] with the [`MastForest`] used to resolve its callee digest and
/// decorators while rendering.
struct CallNodePrettyPrint<'a> {
    /// The node being rendered.
    node: &'a CallNode,
    /// Forest used to look up the callee node and the decorators by id.
    mast_forest: &'a MastForest,
}
107
108impl CallNodePrettyPrint<'_> {
109 fn concatenate_decorators(
112 &self,
113 decorator_ids: &[DecoratorId],
114 prepend: Document,
115 append: Document,
116 ) -> Document {
117 let decorators = decorator_ids
118 .iter()
119 .map(|&decorator_id| self.mast_forest[decorator_id].render())
120 .reduce(|acc, doc| acc + const_text(" ") + doc)
121 .unwrap_or_default();
122
123 if decorators.is_empty() {
124 decorators
125 } else {
126 prepend + decorators + append
127 }
128 }
129
130 fn single_line_pre_decorators(&self) -> Document {
131 self.concatenate_decorators(
132 self.node.before_enter(self.mast_forest),
133 Document::Empty,
134 const_text(" "),
135 )
136 }
137
138 fn single_line_post_decorators(&self) -> Document {
139 self.concatenate_decorators(
140 self.node.after_exit(self.mast_forest),
141 const_text(" "),
142 Document::Empty,
143 )
144 }
145
146 fn multi_line_pre_decorators(&self) -> Document {
147 self.concatenate_decorators(self.node.before_enter(self.mast_forest), Document::Empty, nl())
148 }
149
150 fn multi_line_post_decorators(&self) -> Document {
151 self.concatenate_decorators(self.node.after_exit(self.mast_forest), nl(), Document::Empty)
152 }
153}
154
155impl PrettyPrint for CallNodePrettyPrint<'_> {
156 fn render(&self) -> Document {
157 let call_or_syscall = {
158 let callee_digest = self.mast_forest[self.node.callee].digest();
159 if self.node.is_syscall {
160 const_text("syscall")
161 + const_text(".")
162 + text(callee_digest.as_bytes().to_hex_with_prefix())
163 } else {
164 const_text("call")
165 + const_text(".")
166 + text(callee_digest.as_bytes().to_hex_with_prefix())
167 }
168 };
169
170 let single_line = self.single_line_pre_decorators()
171 + call_or_syscall.clone()
172 + self.single_line_post_decorators();
173 let multi_line =
174 self.multi_line_pre_decorators() + call_or_syscall + self.multi_line_post_decorators();
175
176 single_line | multi_line
177 }
178}
179
180impl fmt::Display for CallNodePrettyPrint<'_> {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 use crate::prettier::PrettyPrint;
183 self.pretty_print(f)
184 }
185}
186
187impl MastNodeExt for CallNode {
191 fn digest(&self) -> Word {
210 self.digest
211 }
212
213 fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
215 #[cfg(debug_assertions)]
216 self.verify_node_in_forest(forest);
217 self.decorator_store.before_enter(forest)
218 }
219
220 fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
222 #[cfg(debug_assertions)]
223 self.verify_node_in_forest(forest);
224 self.decorator_store.after_exit(forest)
225 }
226
227 fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
228 Box::new(CallNode::to_display(self, mast_forest))
229 }
230
231 fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
232 Box::new(CallNode::to_pretty_print(self, mast_forest))
233 }
234
235 fn has_children(&self) -> bool {
236 true
237 }
238
239 fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
240 target.push(self.callee());
241 }
242
243 fn for_each_child<F>(&self, mut f: F)
244 where
245 F: FnMut(MastNodeId),
246 {
247 f(self.callee());
248 }
249
250 fn domain(&self) -> Felt {
251 self.domain()
252 }
253
254 type Builder = CallNodeBuilder;
255
256 fn to_builder(self, forest: &MastForest) -> Self::Builder {
257 match self.decorator_store {
259 DecoratorStore::Owned { before_enter, after_exit, .. } => {
260 let mut builder = if self.is_syscall {
261 CallNodeBuilder::new_syscall(self.callee)
262 } else {
263 CallNodeBuilder::new(self.callee)
264 };
265 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
266 builder
267 },
268 DecoratorStore::Linked { id } => {
269 let before_enter = forest.before_enter_decorators(id).to_vec();
271 let after_exit = forest.after_exit_decorators(id).to_vec();
272 let mut builder = if self.is_syscall {
273 CallNodeBuilder::new_syscall(self.callee)
274 } else {
275 CallNodeBuilder::new(self.callee)
276 };
277 builder = builder.with_before_enter(before_enter).with_after_exit(after_exit);
278 builder
279 },
280 }
281 }
282
283 #[cfg(debug_assertions)]
284 fn verify_node_in_forest(&self, forest: &MastForest) {
285 if let Some(id) = self.decorator_store.linked_id() {
286 let self_ptr = self as *const Self;
288 let forest_node = &forest.nodes[id];
289 let forest_node_ptr = match forest_node {
290 crate::mast::MastNode::Call(call_node) => call_node as *const CallNode as *const (),
291 _ => panic!("Node type mismatch at {:?}", id),
292 };
293 let self_as_void = self_ptr as *const ();
294 debug_assert_eq!(
295 self_as_void, forest_node_ptr,
296 "Node pointer mismatch: expected node at {:?} to be self",
297 id
298 );
299 }
300 }
301}
302
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for CallNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates call nodes with an arbitrary callee id, call kind and digest, and the default
    /// `DecoratorStore`.
    ///
    /// Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        use crate::Felt;

        let inputs = (any::<MastNodeId>(), any::<[u64; 4]>(), any::<bool>());
        inputs
            .prop_map(|(callee, raw_digest, is_syscall)| CallNode {
                callee,
                is_syscall,
                digest: Word::from(raw_digest.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
334
/// Builder for [`CallNode`]s.
///
/// Collects the callee, the call kind, the before-enter / after-exit decorators and an optional
/// forced digest. The node itself is produced by `build`, or by adding the builder to a
/// [`MastForest`].
#[derive(Debug)]
pub struct CallNodeBuilder {
    /// Id of the node to invoke.
    callee: MastNodeId,
    /// `true` to build a `syscall`, `false` for a regular `call`.
    is_syscall: bool,
    /// Before-enter decorators of the node.
    before_enter: Vec<DecoratorId>,
    /// After-exit decorators of the node.
    after_exit: Vec<DecoratorId>,
    /// Digest to use instead of computing one from the callee, if set.
    digest: Option<Word>,
}
345
346impl CallNodeBuilder {
347 pub fn new(callee: MastNodeId) -> Self {
349 Self {
350 callee,
351 is_syscall: false,
352 before_enter: Vec::new(),
353 after_exit: Vec::new(),
354 digest: None,
355 }
356 }
357
358 pub fn new_syscall(callee: MastNodeId) -> Self {
360 Self {
361 callee,
362 is_syscall: true,
363 before_enter: Vec::new(),
364 after_exit: Vec::new(),
365 digest: None,
366 }
367 }
368
369 pub fn build(self, mast_forest: &MastForest) -> Result<CallNode, MastForestError> {
371 if self.callee.to_usize() >= mast_forest.nodes.len() {
372 return Err(MastForestError::NodeIdOverflow(self.callee, mast_forest.nodes.len()));
373 }
374
375 let digest = if let Some(forced_digest) = self.digest {
377 forced_digest
378 } else {
379 let callee_digest = mast_forest[self.callee].digest();
380 let domain = if self.is_syscall {
381 CallNode::SYSCALL_DOMAIN
382 } else {
383 CallNode::CALL_DOMAIN
384 };
385
386 hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
387 };
388
389 Ok(CallNode {
390 callee: self.callee,
391 is_syscall: self.is_syscall,
392 digest,
393 decorator_store: DecoratorStore::new_owned_with_decorators(
394 self.before_enter,
395 self.after_exit,
396 ),
397 })
398 }
399}
400
401impl MastForestContributor for CallNodeBuilder {
402 fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
403 if self.callee.to_usize() >= forest.nodes.len() {
404 return Err(MastForestError::NodeIdOverflow(self.callee, forest.nodes.len()));
405 }
406
407 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
409
410 let digest = if let Some(forced_digest) = self.digest {
412 forced_digest
413 } else {
414 let callee_digest = forest[self.callee].digest();
415 let domain = if self.is_syscall {
416 CallNode::SYSCALL_DOMAIN
417 } else {
418 CallNode::CALL_DOMAIN
419 };
420
421 hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
422 };
423
424 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
426
427 let node_id = forest
430 .nodes
431 .push(
432 CallNode {
433 callee: self.callee,
434 is_syscall: self.is_syscall,
435 digest,
436 decorator_store: DecoratorStore::Linked { id: future_node_id },
437 }
438 .into(),
439 )
440 .map_err(|_| MastForestError::TooManyNodes)?;
441
442 Ok(node_id)
443 }
444
445 fn fingerprint_for_node(
446 &self,
447 forest: &MastForest,
448 hash_by_node_id: &impl crate::LookupByIdx<MastNodeId, crate::mast::MastNodeFingerprint>,
449 ) -> Result<crate::mast::MastNodeFingerprint, MastForestError> {
450 crate::mast::node_fingerprint::fingerprint_from_parts(
452 forest,
453 hash_by_node_id,
454 &self.before_enter,
455 &self.after_exit,
456 &[self.callee],
457 if let Some(forced_digest) = self.digest {
459 forced_digest
460 } else {
461 let callee_digest = forest[self.callee].digest();
462 let domain = if self.is_syscall {
463 CallNode::SYSCALL_DOMAIN
464 } else {
465 CallNode::CALL_DOMAIN
466 };
467
468 crate::chiplets::hasher::merge_in_domain(
469 &[callee_digest, miden_crypto::Word::default()],
470 domain,
471 )
472 },
473 )
474 }
475
476 fn remap_children(
477 self,
478 remapping: &impl crate::LookupByIdx<crate::mast::MastNodeId, crate::mast::MastNodeId>,
479 ) -> Self {
480 CallNodeBuilder {
481 callee: *remapping.get(self.callee).unwrap_or(&self.callee),
482 is_syscall: self.is_syscall,
483 before_enter: self.before_enter,
484 after_exit: self.after_exit,
485 digest: self.digest,
486 }
487 }
488
489 fn with_before_enter(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
490 self.before_enter = decorators.into();
491 self
492 }
493
494 fn with_after_exit(mut self, decorators: impl Into<Vec<crate::mast::DecoratorId>>) -> Self {
495 self.after_exit = decorators.into();
496 self
497 }
498
499 fn append_before_enter(
500 &mut self,
501 decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
502 ) {
503 self.before_enter.extend(decorators);
504 }
505
506 fn append_after_exit(
507 &mut self,
508 decorators: impl IntoIterator<Item = crate::mast::DecoratorId>,
509 ) {
510 self.after_exit.extend(decorators);
511 }
512
513 fn with_digest(mut self, digest: crate::Word) -> Self {
514 self.digest = Some(digest);
515 self
516 }
517}
518
519impl CallNodeBuilder {
520 pub(in crate::mast) fn add_to_forest_relaxed(
530 self,
531 forest: &mut MastForest,
532 ) -> Result<MastNodeId, MastForestError> {
533 let Some(digest) = self.digest else {
536 return Err(MastForestError::DigestRequiredForDeserialization);
537 };
538
539 let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);
540
541 forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);
543
544 let node_id = forest
547 .nodes
548 .push(
549 CallNode {
550 callee: self.callee,
551 is_syscall: self.is_syscall,
552 digest,
553 decorator_store: DecoratorStore::Linked { id: future_node_id },
554 }
555 .into(),
556 )
557 .map_err(|_| MastForestError::TooManyNodes)?;
558
559 Ok(node_id)
560 }
561}
562
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for CallNodeBuilder {
    type Parameters = CallNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with an arbitrary callee and call kind.
    ///
    /// Each builder gets up to `params.max_decorators` before-enter decorators and up to
    /// `params.max_decorators` after-exit decorators. The decorator ids come from
    /// `decorator_id_strategy(params.max_decorator_id_u32)`.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let decorator_list = || {
            proptest::collection::vec(
                super::arbitrary::decorator_id_strategy(params.max_decorator_id_u32),
                0..=params.max_decorators,
            )
        };

        (any::<crate::mast::MastNodeId>(), any::<bool>(), decorator_list(), decorator_list())
            .prop_map(|(callee, is_syscall, before_enter, after_exit)| {
                let builder = if is_syscall {
                    Self::new_syscall(callee)
                } else {
                    Self::new(callee)
                };
                builder.with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
595
/// Parameters for the `Arbitrary` strategy of [`CallNodeBuilder`].
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct CallNodeBuilderParams {
    /// Maximum number of decorators generated for each of the before-enter and after-exit lists.
    pub max_decorators: usize,
    /// Bound passed to `decorator_id_strategy` when generating decorator ids.
    pub max_decorator_id_u32: u32,
}
603
#[cfg(any(test, feature = "arbitrary"))]
impl Default for CallNodeBuilderParams {
    /// Allows at most 4 decorators per list, with `max_decorator_id_u32` set to 10.
    fn default() -> Self {
        let max_decorators = 4;
        let max_decorator_id_u32 = 10;
        Self { max_decorators, max_decorator_id_u32 }
    }
}