use alloc::{boxed::Box, vec::Vec};
use core::fmt;
use miden_formatting::{
hex::ToHex,
prettier::{Document, PrettyPrint, const_text, nl, text},
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use super::{MastForestContributor, MastNodeExt};
#[cfg(debug_assertions)]
use crate::mast::MastNode;
use crate::{
Felt, Word,
chiplets::hasher,
mast::{
DecoratorId, DecoratorStore, MastForest, MastForestError, MastNodeFingerprint, MastNodeId,
},
operations::opcodes,
utils::{Idx, LookupByIdx},
};
/// A MAST node representing a `call` or `syscall` to another node of the forest (the callee).
///
/// The two kinds are pretty-printed with different keywords and hashed in different domains
/// (see [`CallNode::CALL_DOMAIN`] and [`CallNode::SYSCALL_DOMAIN`]).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct CallNode {
/// ID of the node invoked by this call.
callee: MastNodeId,
/// `true` for a syscall, `false` for a regular call.
is_syscall: bool,
/// Digest of this node. When built by `CallNodeBuilder` without a forced digest, this is the
/// callee digest merged with the default word in this node's hash domain.
digest: Word,
/// Before-enter / after-exit decorators, either owned by the node or linked to a forest by
/// node ID.
decorator_store: DecoratorStore,
}
/// Hash-domain constants.
impl CallNode {
/// Domain used when hashing a regular (non-syscall) call node; built from the `CALL` opcode.
pub const CALL_DOMAIN: Felt = Felt::new(opcodes::CALL as u64);
/// Domain used when hashing a syscall node; built from the `SYSCALL` opcode.
pub const SYSCALL_DOMAIN: Felt = Felt::new(opcodes::SYSCALL as u64);
}
/// Read-only accessors.
impl CallNode {
    /// Returns the ID of the node invoked by this call.
    pub fn callee(&self) -> MastNodeId {
        self.callee
    }

    /// Returns `true` if this node is a syscall, `false` if it is a regular call.
    pub fn is_syscall(&self) -> bool {
        self.is_syscall
    }

    /// Returns the hash domain matching the kind of this node:
    /// [`Self::SYSCALL_DOMAIN`] for syscalls, [`Self::CALL_DOMAIN`] otherwise.
    pub fn domain(&self) -> Felt {
        if !self.is_syscall {
            return Self::CALL_DOMAIN;
        }
        Self::SYSCALL_DOMAIN
    }
}
/// Formatting entry points; both pair the node with the forest it is resolved against.
impl CallNode {
    /// Returns a value that pretty-prints this node, using `mast_forest` to look up the callee
    /// digest and the node's decorators.
    pub(super) fn to_pretty_print<'a>(
        &'a self,
        mast_forest: &'a MastForest,
    ) -> impl PrettyPrint + 'a {
        let node = self;
        CallNodePrettyPrint { node, mast_forest }
    }

    /// Returns a value whose `Display` output is the pretty-printed form of this node.
    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
        let node = self;
        CallNodePrettyPrint { node, mast_forest }
    }
}
/// Borrowed pairing of a [`CallNode`] with the [`MastForest`] used to render it.
struct CallNodePrettyPrint<'a> {
/// The call node being rendered.
node: &'a CallNode,
/// Forest used to look up the callee digest and the decorators of `node`.
mast_forest: &'a MastForest,
}
impl CallNodePrettyPrint<'_> {
    /// Renders every decorator in `decorator_ids`, joins the results with single spaces, and
    /// wraps the joined document between `prepend` and `append`.
    ///
    /// If the joined document is empty it is returned unchanged, i.e. neither `prepend` nor
    /// `append` is emitted.
    fn concatenate_decorators(
        &self,
        decorator_ids: &[DecoratorId],
        prepend: Document,
        append: Document,
    ) -> Document {
        let mut rendered = decorator_ids.iter().map(|&id| self.mast_forest[id].render());

        // Seed the fold with the first decorator so separators only appear between items.
        let joined = match rendered.next() {
            Some(first) => rendered.fold(first, |line, next| line + const_text(" ") + next),
            None => Document::default(),
        };

        if joined.is_empty() {
            return joined;
        }
        prepend + joined + append
    }

    /// Before-enter decorators for the single-line layout: followed by a space.
    fn single_line_pre_decorators(&self) -> Document {
        let ids = self.node.before_enter(self.mast_forest);
        self.concatenate_decorators(ids, Document::Empty, const_text(" "))
    }

    /// After-exit decorators for the single-line layout: preceded by a space.
    fn single_line_post_decorators(&self) -> Document {
        let ids = self.node.after_exit(self.mast_forest);
        self.concatenate_decorators(ids, const_text(" "), Document::Empty)
    }

    /// Before-enter decorators for the multi-line layout: followed by a newline.
    fn multi_line_pre_decorators(&self) -> Document {
        let ids = self.node.before_enter(self.mast_forest);
        self.concatenate_decorators(ids, Document::Empty, nl())
    }

    /// After-exit decorators for the multi-line layout: preceded by a newline.
    fn multi_line_post_decorators(&self) -> Document {
        let ids = self.node.after_exit(self.mast_forest);
        self.concatenate_decorators(ids, nl(), Document::Empty)
    }
}
impl PrettyPrint for CallNodePrettyPrint<'_> {
    /// Renders the node as `call.<callee digest hex>` or `syscall.<callee digest hex>`,
    /// surrounded by its decorators.
    ///
    /// Two layouts are produced and combined with `|`: one with decorators on the same line
    /// and one with decorators on separate lines.
    fn render(&self) -> Document {
        let callee_digest = self.mast_forest[self.node.callee].digest();
        let keyword = if self.node.is_syscall { "syscall" } else { "call" };
        let instruction = const_text(keyword)
            + const_text(".")
            + text(callee_digest.as_bytes().to_hex_with_prefix());

        let single_line = self.single_line_pre_decorators()
            + instruction.clone()
            + self.single_line_post_decorators();
        let multi_line =
            self.multi_line_pre_decorators() + instruction + self.multi_line_post_decorators();

        single_line | multi_line
    }
}
impl fmt::Display for CallNodePrettyPrint<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use crate::prettier::PrettyPrint;
self.pretty_print(f)
}
}
impl MastNodeExt for CallNode {
    /// Returns the digest stored in this node.
    fn digest(&self) -> Word {
        self.digest
    }

    /// Returns the decorators to run before this node is entered, resolved through `forest`.
    fn before_enter<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
        #[cfg(debug_assertions)]
        self.verify_node_in_forest(forest);
        self.decorator_store.before_enter(forest)
    }

    /// Returns the decorators to run after this node is exited, resolved through `forest`.
    fn after_exit<'a>(&'a self, forest: &'a MastForest) -> &'a [DecoratorId] {
        #[cfg(debug_assertions)]
        self.verify_node_in_forest(forest);
        self.decorator_store.after_exit(forest)
    }

    /// Boxed variant of the inherent [`CallNode::to_display`].
    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
        let display = CallNode::to_display(self, mast_forest);
        Box::new(display)
    }

    /// Boxed variant of the inherent [`CallNode::to_pretty_print`].
    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
        let printer = CallNode::to_pretty_print(self, mast_forest);
        Box::new(printer)
    }

    /// A call node always has exactly one child: its callee.
    fn has_children(&self) -> bool {
        true
    }

    /// Appends the callee ID to `target`.
    fn append_children_to(&self, target: &mut Vec<MastNodeId>) {
        target.push(self.callee);
    }

    /// Invokes `f` once, with the callee ID.
    fn for_each_child<F>(&self, mut f: F)
    where
        F: FnMut(MastNodeId),
    {
        f(self.callee);
    }

    /// Returns the hash domain of this node by delegating to the inherent
    /// [`CallNode::domain`].
    fn domain(&self) -> Felt {
        // Path syntax picks the inherent method, so this does not recurse.
        CallNode::domain(self)
    }

    type Builder = CallNodeBuilder;

    /// Converts this node back into a builder with the same callee, kind and decorators.
    ///
    /// Owned decorators are moved into the builder; linked decorators are copied out of
    /// `forest`. No digest is set on the returned builder.
    fn to_builder(self, forest: &MastForest) -> Self::Builder {
        let builder = if self.is_syscall {
            CallNodeBuilder::new_syscall(self.callee)
        } else {
            CallNodeBuilder::new(self.callee)
        };

        match self.decorator_store {
            DecoratorStore::Owned { before_enter, after_exit, .. } => {
                builder.with_before_enter(before_enter).with_after_exit(after_exit)
            },
            DecoratorStore::Linked { id } => builder
                .with_before_enter(forest.before_enter_decorators(id).to_vec())
                .with_after_exit(forest.after_exit_decorators(id).to_vec()),
        }
    }

    /// Debug-only sanity check: if this node's decorators are linked to a forest node ID,
    /// asserts that the node stored in `forest` at that ID is this very node (same address).
    ///
    /// # Panics
    /// Panics if the forest node at the linked ID is not a call node, or (via
    /// `debug_assert_eq!`) if its address differs from `self`.
    #[cfg(debug_assertions)]
    fn verify_node_in_forest(&self, forest: &MastForest) {
        let Some(id) = self.decorator_store.linked_id() else {
            return;
        };

        let MastNode::Call(stored) = &forest.nodes[id] else {
            panic!("Node type mismatch at {:?}", id)
        };

        // Compare type-erased addresses of `self` and the node held by the forest.
        let self_as_void = self as *const Self as *const ();
        let forest_node_ptr = stored as *const CallNode as *const ();
        debug_assert_eq!(
            self_as_void, forest_node_ptr,
            "Node pointer mismatch: expected node at {:?} to be self",
            id
        );
    }
}
#[cfg(all(feature = "arbitrary", test))]
impl proptest::prelude::Arbitrary for CallNode {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates call nodes with an arbitrary callee ID, an arbitrary digest, an arbitrary
    /// call/syscall flag, and a default decorator store. Shrinking is disabled.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        use crate::Felt;

        let raw_fields = (any::<MastNodeId>(), any::<[u64; 4]>(), any::<bool>());
        raw_fields
            .prop_map(|(callee, raw_digest, is_syscall)| CallNode {
                callee,
                is_syscall,
                digest: Word::from(raw_digest.map(Felt::new)),
                decorator_store: DecoratorStore::default(),
            })
            .no_shrink()
            .boxed()
    }
}
/// Builder for [`CallNode`]s, covering both regular calls and syscalls.
#[derive(Debug)]
pub struct CallNodeBuilder {
/// ID of the node to be invoked.
callee: MastNodeId,
/// `true` to build a syscall, `false` to build a regular call.
is_syscall: bool,
/// Decorators to attach before the node is entered.
before_enter: Vec<DecoratorId>,
/// Decorators to attach after the node is exited.
after_exit: Vec<DecoratorId>,
/// Optional forced digest; when `None`, the digest is computed from the callee's digest.
digest: Option<Word>,
}
impl CallNodeBuilder {
    /// Creates a builder for a regular (non-syscall) call to `callee`, with no decorators and
    /// no forced digest.
    pub fn new(callee: MastNodeId) -> Self {
        Self {
            callee,
            is_syscall: false,
            before_enter: Vec::new(),
            after_exit: Vec::new(),
            digest: None,
        }
    }

    /// Creates a builder for a syscall to `callee`, with no decorators and no forced digest.
    pub fn new_syscall(callee: MastNodeId) -> Self {
        Self { is_syscall: true, ..Self::new(callee) }
    }

    /// Builds a standalone [`CallNode`] that owns its decorators.
    ///
    /// If a digest was forced on the builder it is used as-is; otherwise the digest is the
    /// callee's digest merged with the default word in the call or syscall domain.
    ///
    /// # Errors
    /// Returns [`MastForestError::NodeIdOverflow`] if the callee ID is not smaller than the
    /// number of nodes in `mast_forest`.
    pub fn build(self, mast_forest: &MastForest) -> Result<CallNode, MastForestError> {
        let node_count = mast_forest.nodes.len();
        if self.callee.to_usize() >= node_count {
            return Err(MastForestError::NodeIdOverflow(self.callee, node_count));
        }

        let digest = match self.digest {
            Some(forced_digest) => forced_digest,
            None => {
                let domain = if self.is_syscall {
                    CallNode::SYSCALL_DOMAIN
                } else {
                    CallNode::CALL_DOMAIN
                };
                let callee_digest = mast_forest[self.callee].digest();
                hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
            },
        };

        let decorator_store =
            DecoratorStore::new_owned_with_decorators(self.before_enter, self.after_exit);
        Ok(CallNode {
            callee: self.callee,
            is_syscall: self.is_syscall,
            digest,
            decorator_store,
        })
    }
}
impl MastForestContributor for CallNodeBuilder {
    /// Adds the call node described by this builder to `forest` and returns its ID.
    ///
    /// The node's decorators are registered with the forest under the ID the node is about to
    /// receive, and the stored node links to them by that ID.
    ///
    /// # Errors
    /// - [`MastForestError::NodeIdOverflow`] if the callee ID is not smaller than the current
    ///   number of nodes in `forest`.
    /// - [`MastForestError::TooManyNodes`] if the forest rejects the new node.
    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
        let node_count = forest.nodes.len();
        if self.callee.to_usize() >= node_count {
            return Err(MastForestError::NodeIdOverflow(self.callee, node_count));
        }

        // The new node is appended, so its ID equals the current node count.
        let future_node_id = MastNodeId::new_unchecked(node_count as u32);

        let digest = match self.digest {
            Some(forced_digest) => forced_digest,
            None => {
                let domain = if self.is_syscall {
                    CallNode::SYSCALL_DOMAIN
                } else {
                    CallNode::CALL_DOMAIN
                };
                let callee_digest = forest[self.callee].digest();
                hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
            },
        };

        forest.register_node_decorators(future_node_id, &self.before_enter, &self.after_exit);

        let node = CallNode {
            callee: self.callee,
            is_syscall: self.is_syscall,
            digest,
            decorator_store: DecoratorStore::Linked { id: future_node_id },
        };
        forest.nodes.push(node.into()).map_err(|_| MastForestError::TooManyNodes)
    }

    /// Computes the fingerprint this node would have, from its decorators, its callee and its
    /// digest (forced if set, otherwise derived from the callee's digest in `forest`).
    fn fingerprint_for_node(
        &self,
        forest: &MastForest,
        hash_by_node_id: &impl LookupByIdx<MastNodeId, MastNodeFingerprint>,
    ) -> Result<MastNodeFingerprint, MastForestError> {
        let digest = match self.digest {
            Some(forced_digest) => forced_digest,
            None => {
                let domain = if self.is_syscall {
                    CallNode::SYSCALL_DOMAIN
                } else {
                    CallNode::CALL_DOMAIN
                };
                let callee_digest = forest[self.callee].digest();
                hasher::merge_in_domain(&[callee_digest, Word::default()], domain)
            },
        };

        crate::mast::node_fingerprint::fingerprint_from_parts(
            forest,
            hash_by_node_id,
            &self.before_enter,
            &self.after_exit,
            &[self.callee],
            digest,
        )
    }

    /// Returns the builder with its callee replaced by the entry `remapping` holds for it;
    /// the callee is kept unchanged when `remapping` has no entry.
    fn remap_children(self, remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
        let callee = match remapping.get(self.callee) {
            Some(&mapped) => mapped,
            None => self.callee,
        };
        CallNodeBuilder { callee, ..self }
    }

    /// Replaces the before-enter decorator list.
    fn with_before_enter(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
        self.before_enter = decorators.into();
        self
    }

    /// Replaces the after-exit decorator list.
    fn with_after_exit(mut self, decorators: impl Into<Vec<DecoratorId>>) -> Self {
        self.after_exit = decorators.into();
        self
    }

    /// Appends `decorators` to the existing before-enter decorator list.
    fn append_before_enter(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
        self.before_enter.extend(decorators);
    }

    /// Appends `decorators` to the existing after-exit decorator list.
    fn append_after_exit(&mut self, decorators: impl IntoIterator<Item = DecoratorId>) {
        self.after_exit.extend(decorators);
    }

    /// Forces the digest of the built node to `digest` instead of computing it.
    fn with_digest(mut self, digest: crate::Word) -> Self {
        self.digest = Some(digest);
        self
    }
}
impl CallNodeBuilder {
    /// Adds the call node to `forest` without checking that the callee ID is in range.
    ///
    /// A forced digest is mandatory here; it is stored as-is and never recomputed. The stored
    /// node links its decorators to the ID it is about to receive.
    ///
    /// NOTE(review): unlike `add_to_forest`, this does not call `register_node_decorators`,
    /// so the builder's `before_enter` / `after_exit` lists are dropped. Presumably the caller
    /// has already populated the forest's decorator data — confirm against the call site.
    ///
    /// # Errors
    /// - [`MastForestError::DigestRequiredForDeserialization`] if no digest was set.
    /// - [`MastForestError::TooManyNodes`] if the forest rejects the new node.
    pub(in crate::mast) fn add_to_forest_relaxed(
        self,
        forest: &mut MastForest,
    ) -> Result<MastNodeId, MastForestError> {
        let digest = self.digest.ok_or(MastForestError::DigestRequiredForDeserialization)?;

        // The new node is appended, so its ID equals the current node count.
        let future_node_id = MastNodeId::new_unchecked(forest.nodes.len() as u32);

        let node = CallNode {
            callee: self.callee,
            is_syscall: self.is_syscall,
            digest,
            decorator_store: DecoratorStore::Linked { id: future_node_id },
        };
        forest.nodes.push(node.into()).map_err(|_| MastForestError::TooManyNodes)
    }
}
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for CallNodeBuilder {
    type Parameters = CallNodeBuilderParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders with an arbitrary callee ID, an arbitrary call/syscall flag, and
    /// two independently generated decorator lists whose length and ID range are bounded by
    /// `params`.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        let max_len = params.max_decorators;
        let max_id = params.max_decorator_id_u32;
        // Fresh strategy per call, so the two lists are generated independently.
        let decorator_list = || {
            proptest::collection::vec(super::arbitrary::decorator_id_strategy(max_id), 0..=max_len)
        };

        (any::<MastNodeId>(), any::<bool>(), decorator_list(), decorator_list())
            .prop_map(|(callee, is_syscall, before_enter, after_exit)| {
                let base = if is_syscall {
                    Self::new_syscall(callee)
                } else {
                    Self::new(callee)
                };
                base.with_before_enter(before_enter).with_after_exit(after_exit)
            })
            .boxed()
    }
}
/// Bounds used when generating arbitrary [`CallNodeBuilder`] values.
#[cfg(any(test, feature = "arbitrary"))]
#[derive(Clone, Debug)]
pub struct CallNodeBuilderParams {
    /// Upper bound (inclusive) on the length of each generated decorator list.
    pub max_decorators: usize,
    /// Bound passed to the decorator-ID strategy when generating decorator IDs.
    pub max_decorator_id_u32: u32,
}

#[cfg(any(test, feature = "arbitrary"))]
impl Default for CallNodeBuilderParams {
    /// Defaults to at most 4 decorators per list and a decorator-ID bound of 10.
    fn default() -> Self {
        const DEFAULT_MAX_DECORATORS: usize = 4;
        const DEFAULT_MAX_DECORATOR_ID: u32 = 10;

        Self {
            max_decorators: DEFAULT_MAX_DECORATORS,
            max_decorator_id_u32: DEFAULT_MAX_DECORATOR_ID,
        }
    }
}