#[cfg(test)]
use alloc::string::ToString;
use alloc::{format, vec::Vec};
use miden_utils_sync::OnceLockCompat;
use super::{MastForest, MastNode, MastNodeId};
use crate::{
advice::AdviceMap,
mast::node::MastNodeExt,
serde::{
BudgetedReader, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
SliceReader,
},
};
pub(crate) mod asm_op;
pub(crate) mod decorator;
mod info;
pub use info::{MastNodeEntry, MastNodeInfo};
mod view;
pub use view::MastForestView;
mod layout;
pub(super) use layout::ForestLayout;
use layout::{OffsetTrackingReader, TrackingReader, WireFlags, read_header_and_scan_layout};
mod resolved;
use resolved::{ResolvedSerializedForest, basic_block_offset_for_node_index};
mod basic_blocks;
use basic_blocks::{BasicBlockDataBuilder, basic_block_data_len};
pub(crate) mod string_table;
pub(crate) use string_table::StringTable;
#[cfg(test)]
mod seed_gen;
#[cfg(test)]
mod tests;
// Integer aliases for offsets and indices.
// NOTE(review): none of these aliases is referenced in the visible part of this file; their
// meaning is inferred from the names only — confirm against the submodules that use them.
type NodeDataOffset = u32;
type DecoratorDataOffset = u32;
type StringDataOffset = usize;
type StringIndex = usize;
/// Factor by which `default_untrusted_allocation_budget` multiplies the input length to
/// obtain the default allocation budget for untrusted input.
const DEFAULT_UNTRUSTED_ALLOCATION_BUDGET_MULTIPLIER: usize = 7;
/// Factor by which `MastForest::read_from_bytes` multiplies the input length to size the
/// budget handed to `BudgetedReader`.
const TRUSTED_BYTE_READ_BUDGET_MULTIPLIER: usize = 64;
/// Four-byte marker written first in every serialized forest.
const MAGIC: &[u8; 4] = b"MAST";
/// Flags-byte bit set by the writer for stripped output; the writer also sets it whenever
/// `FLAG_HASHLESS` is set.
const FLAG_STRIPPED: u8 = 0x01;
/// Flags-byte bit set by the writer when digests of non-external nodes are omitted.
pub(super) const FLAG_HASHLESS: u8 = 0x02;
/// Every flags-byte bit other than `FLAG_STRIPPED` and `FLAG_HASHLESS`.
/// NOTE(review): presumably the reader rejects input with any of these bits set; the check is
/// not in this file — confirm in `layout`.
const FLAGS_RESERVED_MASK: u8 = 0xfc;
/// Three-byte format version written directly after the flags byte.
const VERSION: [u8; 3] = [0, 0, 3];
impl Serializable for MastForest {
    /// Serializes the forest in its full form: neither the stripped nor the hashless
    /// option is enabled.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let (stripped, hashless) = (false, false);
        self.write_into_with_options(target, stripped, hashless);
    }
}
impl MastForest {
    /// Writes this forest to `target`.
    ///
    /// Sections are emitted in this order: magic, flags byte, version, internal node count,
    /// external node count, procedure roots (as `u32`), basic-block data, one
    /// `MastNodeEntry` per node, digests of external nodes, digests of all other nodes
    /// (omitted when `hashless`), the advice map, and finally the debug info (omitted when
    /// `stripped`).
    ///
    /// `hashless` implies `stripped`: the flags byte always carries `FLAG_STRIPPED` when
    /// `FLAG_HASHLESS` is set, so the debug-info section is skipped in that case as well.
    fn write_into_with_options<W: ByteWriter>(
        &self,
        target: &mut W,
        stripped: bool,
        hashless: bool,
    ) {
        // Keep the payload consistent with the flags byte. Previously a call with
        // `stripped == false, hashless == true` advertised FLAG_STRIPPED in the header while
        // still appending the debug info, which a reader honouring the flag never consumes.
        let stripped = stripped || hashless;

        let mut basic_block_data_builder = BasicBlockDataBuilder::new();

        // Header: magic, flags, version.
        target.write_bytes(MAGIC);
        let mut flags = 0u8;
        if stripped {
            flags |= FLAG_STRIPPED;
        }
        if hashless {
            flags |= FLAG_HASHLESS;
        }
        target.write_u8(flags);
        target.write_bytes(&VERSION);

        // Node counts, split into non-external and external nodes.
        let node_count = self.nodes.len();
        let external_node_count = self.nodes.iter().filter(|node| node.is_external()).count();
        let internal_node_count = node_count - external_node_count;
        target.write_usize(internal_node_count);
        target.write_usize(external_node_count);

        // Procedure roots.
        let roots: Vec<u32> = self.roots.iter().copied().map(u32::from).collect();
        roots.write_into(target);

        // Single pass over the nodes: encode basic blocks and collect the per-node entries
        // and the digests that will be written after the basic-block data.
        let mut mast_node_entries = Vec::with_capacity(node_count);
        let mut external_digests = Vec::with_capacity(external_node_count);
        let mut node_hashes = Vec::with_capacity(if hashless { 0 } else { internal_node_count });
        for mast_node in self.nodes.iter() {
            // Only basic blocks have an offset into the basic-block data; others use 0.
            let ops_offset = if let MastNode::Block(basic_block) = mast_node {
                basic_block_data_builder.encode_basic_block(basic_block)
            } else {
                0
            };
            mast_node_entries.push(MastNodeEntry::new(mast_node, ops_offset));

            if mast_node.is_external() {
                external_digests.push(mast_node.digest());
            } else if !hashless {
                node_hashes.push(mast_node.digest());
            }
        }

        let basic_block_data = basic_block_data_builder.finalize();
        basic_block_data.write_into(target);

        for mast_node_entry in mast_node_entries {
            mast_node_entry.write_into(target);
        }
        for digest in external_digests {
            digest.write_into(target);
        }
        // Empty in hashless mode, so nothing is written for non-external nodes there.
        for digest in node_hashes {
            digest.write_into(target);
        }

        self.advice_map.write_into(target);
        if !stripped {
            self.debug_info.write_into(target);
        }
    }
}
/// Serializes `forest` into `target` with the stripped option enabled and the hashless
/// option disabled.
pub(super) fn write_stripped_into<W: ByteWriter>(forest: &MastForest, target: &mut W) {
    let (stripped, hashless) = (true, false);
    forest.write_into_with_options(target, stripped, hashless);
}
/// Serializes `forest` into `target` with both the stripped and the hashless options
/// enabled.
pub(super) fn write_hashless_into<W: ByteWriter>(forest: &MastForest, target: &mut W) {
    let (stripped, hashless) = (true, true);
    forest.write_into_with_options(target, stripped, hashless);
}
/// Returns the size hint for `forest` serialized with the stripped option enabled and the
/// hashless option disabled.
pub(super) fn stripped_size_hint(forest: &MastForest) -> usize {
    let (stripped, hashless) = (true, false);
    serialized_size_hint(forest, stripped, hashless)
}
/// Estimates the number of bytes needed to serialize `forest` with the given options.
///
/// The estimate adds up the header, the two node counts, the roots, the basic-block data
/// with its length prefix, the node entries, the digests and the advice map. Digests of
/// non-external nodes are only counted when `hashless` is false, and the debug info only
/// when `stripped` is false.
fn serialized_size_hint(forest: &MastForest, stripped: bool, hashless: bool) -> usize {
    let total_nodes = forest.nodes.len();
    let external_nodes = forest.nodes.iter().filter(|node| node.is_external()).count();
    let internal_nodes = total_nodes - external_nodes;
    let root_count = forest.roots.len();

    // Total encoded length of all basic blocks; other node kinds contribute nothing.
    let block_data_len: usize = forest
        .nodes
        .iter()
        .map(|node| match node {
            MastNode::Block(block) => basic_block_data_len(block),
            _ => 0,
        })
        .sum();

    // Magic, one flags byte, version.
    let header_len = MAGIC.len() + 1 + VERSION.len();
    let counts_len = internal_nodes.get_size_hint() + external_nodes.get_size_hint();
    let roots_len = root_count.get_size_hint() + root_count * size_of::<u32>();
    let blocks_len = block_data_len.get_size_hint() + block_data_len;
    let entries_len = total_nodes * MastNodeEntry::SERIALIZED_SIZE;
    let external_digests_len = external_nodes * crate::Word::min_serialized_size();

    let mut total =
        header_len + counts_len + roots_len + blocks_len + entries_len + external_digests_len;
    if !hashless {
        total += internal_nodes * crate::Word::min_serialized_size();
    }
    total += forest.advice_map.serialized_size_hint();
    if !stripped {
        total += forest.debug_info.get_size_hint();
    }
    total
}
/// View over a serialized forest that borrows the input bytes.
///
/// `new` scans the header and layout up front; the `resolved` data is built on the first
/// call that needs it and then cached, including a failed outcome.
#[derive(Debug)]
pub struct SerializedMastForest<'a> {
/// The serialized input this view was created from.
bytes: &'a [u8],
/// Flags returned by the header scan; backs `is_stripped` and `is_hashless`.
flags: WireFlags,
/// Layout returned by the header scan; provides counts and the entry/root readers.
layout: ForestLayout,
/// Lazily initialised result of `ResolvedSerializedForest::new` for `bytes` and `layout`.
resolved: OnceLockCompat<Result<ResolvedSerializedForest<'a>, DeserializationError>>,
}
impl<'a> SerializedMastForest<'a> {
    /// Creates a view over `bytes` by reading the header and scanning the layout.
    ///
    /// Hashless input is accepted. The resolved data is not built here; it is created on
    /// first use.
    ///
    /// # Errors
    /// Returns the error produced by the header/layout scan.
    pub fn new(bytes: &'a [u8]) -> Result<Self, DeserializationError> {
        let mut slice_reader = SliceReader::new(bytes);
        let mut tracker = TrackingReader::new(&mut slice_reader);
        let (flags, layout) = read_header_and_scan_layout(&mut tracker, true)?;

        Ok(Self {
            bytes,
            flags,
            layout,
            resolved: OnceLockCompat::new(),
        })
    }

    /// Returns the node count recorded in the layout.
    pub fn node_count(&self) -> usize {
        self.layout.node_count
    }

    /// Returns whether the wire flags mark the input as hashless.
    pub fn is_hashless(&self) -> bool {
        self.flags.is_hashless()
    }

    /// Returns whether the wire flags mark the input as stripped.
    pub fn is_stripped(&self) -> bool {
        self.flags.is_stripped()
    }

    /// Returns the number of procedure roots recorded in the layout.
    pub fn procedure_root_count(&self) -> usize {
        self.layout.roots_count
    }

    /// Reads the procedure root at `index` from the serialized bytes.
    pub fn procedure_root_at(&self, index: usize) -> Result<MastNodeId, DeserializationError> {
        self.layout.read_procedure_root_at(self.bytes, index)
    }

    /// Builds the node info at `index` from its entry and its digest.
    ///
    /// The entry is read first; the digest lookup only happens if that succeeds.
    pub fn node_info_at(&self, index: usize) -> Result<MastNodeInfo, DeserializationError> {
        let entry = self.node_entry_at(index)?;
        let digest = self.node_digest_at(index)?;
        Ok(MastNodeInfo::from_entry(entry, digest))
    }

    /// Reads the node entry at `index` from the serialized bytes.
    pub fn node_entry_at(&self, index: usize) -> Result<MastNodeEntry, DeserializationError> {
        self.layout.read_node_entry_at(self.bytes, index)
    }

    /// Returns the digest of the node at `index`, resolving the forest on first use.
    pub fn node_digest_at(&self, index: usize) -> Result<crate::Word, DeserializationError> {
        let resolved = self.resolved()?;
        resolved.node_digest_at(index)
    }

    /// Returns the cached resolved forest, creating it on the first call.
    ///
    /// A failed resolution is cached as well; each later call returns a clone of that error.
    fn resolved(&self) -> Result<&ResolvedSerializedForest<'a>, DeserializationError> {
        let cached = self
            .resolved
            .get_or_init(|| ResolvedSerializedForest::new(self.bytes, self.layout));
        match cached {
            Ok(resolved) => Ok(resolved),
            Err(err) => Err(err.clone()),
        }
    }
}
/// Trait implementation that forwards every method to the inherent method of the same name.
impl MastForestView for SerializedMastForest<'_> {
    /// Forwards to the inherent `node_count`.
    fn node_count(&self) -> usize {
        Self::node_count(self)
    }

    /// Forwards to the inherent `node_entry_at`.
    fn node_entry_at(&self, index: usize) -> Result<MastNodeEntry, DeserializationError> {
        Self::node_entry_at(self, index)
    }

    /// Forwards to the inherent `node_digest_at`.
    fn node_digest_at(&self, index: usize) -> Result<crate::Word, DeserializationError> {
        Self::node_digest_at(self, index)
    }

    /// Forwards to the inherent `procedure_root_count`.
    fn procedure_root_count(&self) -> usize {
        Self::procedure_root_count(self)
    }

    /// Forwards to the inherent `procedure_root_at`.
    fn procedure_root_at(&self, index: usize) -> Result<MastNodeId, DeserializationError> {
        Self::procedure_root_at(self, index)
    }
}
/// View implementation backed by an in-memory forest.
impl MastForestView for MastForest {
    /// Returns the number of nodes in the forest.
    fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Builds the entry for the node at `index`.
    ///
    /// Basic blocks get their offset from `basic_block_offset_for_node_index`; every other
    /// node kind uses offset 0.
    ///
    /// # Errors
    /// Returns `InvalidValue` when `index` is out of bounds, or the error from the offset
    /// computation.
    fn node_entry_at(&self, index: usize) -> Result<MastNodeEntry, DeserializationError> {
        let node = match self.nodes.as_slice().get(index) {
            Some(node) => node,
            None => {
                return Err(DeserializationError::InvalidValue(format!(
                    "node index {index} out of bounds"
                )));
            },
        };

        let ops_offset = match node {
            MastNode::Block(_) => basic_block_offset_for_node_index(self.nodes.as_slice(), index)?,
            _ => 0,
        };

        Ok(MastNodeEntry::new(node, ops_offset))
    }

    /// Returns the digest of the node at `index`.
    ///
    /// # Errors
    /// Returns `InvalidValue` when `index` is out of bounds.
    fn node_digest_at(&self, index: usize) -> Result<crate::Word, DeserializationError> {
        match self.nodes.as_slice().get(index) {
            Some(node) => Ok(node.digest()),
            None => Err(DeserializationError::InvalidValue(format!(
                "node index {index} out of bounds"
            ))),
        }
    }

    /// Returns the number of procedure roots.
    fn procedure_root_count(&self) -> usize {
        self.roots.len()
    }

    /// Returns the procedure root at `index`.
    ///
    /// # Errors
    /// Returns `InvalidValue` when `index` is out of bounds.
    fn procedure_root_at(&self, index: usize) -> Result<MastNodeId, DeserializationError> {
        match self.roots.get(index) {
            Some(root) => Ok(*root),
            None => Err(DeserializationError::InvalidValue(format!(
                "root index {} out of bounds for {} roots",
                index,
                self.roots.len()
            ))),
        }
    }
}
#[cfg(test)]
impl SerializedMastForest<'_> {
    /// Test-only accessor for the advice-map offset reported by the layout.
    fn advice_map_offset(&self) -> Result<usize, DeserializationError> {
        self.layout.advice_map_offset()
    }

    /// Test-only accessor for the layout's `node_entry_offset`.
    fn node_entry_offset(&self) -> usize {
        self.layout.node_entry_offset
    }

    /// Test-only accessor for the layout's `node_hash_offset`.
    fn node_hash_offset(&self) -> Option<usize> {
        self.layout.node_hash_offset
    }

    /// Test-only accessor for the digest slot of the node at `index`.
    ///
    /// # Panics
    /// Panics if the serialized forest cannot be resolved.
    fn digest_slot_at(&self, index: usize) -> usize {
        let resolved = self
            .resolved()
            .expect("digest slots should be readable for a valid serialized view");
        resolved.digest_slot_at(index)
    }
}
/// Reads the byte at `*offset` and advances `offset` by one.
///
/// # Errors
/// Returns whatever `read_slice_at` reports for a one-byte read.
#[cfg(test)]
fn read_u8_at(bytes: &[u8], offset: &mut usize) -> Result<u8, DeserializationError> {
    let single = read_slice_at(bytes, offset, 1)?;
    Ok(single[0])
}
/// Reads `N` bytes starting at `*offset` into a fixed-size array and advances `offset`.
///
/// # Errors
/// Returns whatever `read_slice_at` reports for an `N`-byte read.
#[cfg(test)]
fn read_array_at<const N: usize>(
    bytes: &[u8],
    offset: &mut usize,
) -> Result<[u8; N], DeserializationError> {
    read_slice_at(bytes, offset, N).map(|chunk| {
        let mut array = [0u8; N];
        array.copy_from_slice(chunk);
        array
    })
}
/// Returns the `len` bytes of `bytes` starting at `*offset` and moves `offset` past them.
///
/// `offset` is left untouched when the read fails.
///
/// # Errors
/// - `InvalidValue("offset overflow")` when `*offset + len` overflows `usize`.
/// - `UnexpectedEOF` when the requested range extends past the end of `bytes`.
#[cfg(test)]
fn read_slice_at<'a>(
    bytes: &'a [u8],
    offset: &mut usize,
    len: usize,
) -> Result<&'a [u8], DeserializationError> {
    let start = *offset;
    let end = match start.checked_add(len) {
        Some(end) => end,
        None => {
            return Err(DeserializationError::InvalidValue("offset overflow".to_string()));
        },
    };

    // `get` yields `None` exactly when `end` lies beyond the end of the input.
    match bytes.get(start..end) {
        Some(slice) => {
            *offset = end;
            Ok(slice)
        },
        None => Err(DeserializationError::UnexpectedEOF),
    }
}
/// Decodes a variable-length `usize` starting at `*offset` and advances `offset` past it.
///
/// The number of trailing zero bits in the first byte, plus one, gives the encoded length.
/// A first byte of zero selects the 9-byte form: one marker byte followed by the value as
/// eight little-endian bytes. Otherwise the `length` bytes are read as a little-endian
/// integer and shifted right by `length` bits to drop the length prefix.
///
/// # Errors
/// - `UnexpectedEOF` when there is no byte at `*offset` or the encoding is truncated.
/// - `InvalidValue` when the decoded value does not fit into `usize`.
#[cfg(test)]
fn read_usize_at(bytes: &[u8], offset: &mut usize) -> Result<usize, DeserializationError> {
    let first_byte = match bytes.get(*offset) {
        Some(byte) => *byte,
        None => return Err(DeserializationError::UnexpectedEOF),
    };
    let length = first_byte.trailing_zeros() as usize + 1;

    let decoded = match length {
        9 => {
            // Skip the all-zero marker byte, then take the raw 8-byte value.
            read_u8_at(bytes, offset)?;
            u64::from_le_bytes(read_array_at::<8>(bytes, offset)?)
        },
        _ => {
            let payload = read_slice_at(bytes, offset, length)?;
            let mut buffer = [0u8; 8];
            buffer[..length].copy_from_slice(payload);
            u64::from_le_bytes(buffer) >> length
        },
    };

    if decoded > usize::MAX as u64 {
        return Err(DeserializationError::InvalidValue(format!(
            "Encoded value must be less than {}, but {} was provided",
            usize::MAX,
            decoded
        )));
    }
    Ok(decoded as usize)
}
impl Deserializable for MastForest {
    /// Decodes a forest from `source` and materializes it.
    ///
    /// Hashless input is not allowed on this path and no allocation budget is applied.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        decode_from_reader(source, false).and_then(|(_flags, forest)| forest.into_materialized())
    }

    /// Decodes a forest from `bytes` through a `BudgetedReader`.
    ///
    /// The reader budget is the input length times `TRUSTED_BYTE_READ_BUDGET_MULTIPLIER`,
    /// saturating at `usize::MAX`.
    fn read_from_bytes(bytes: &[u8]) -> Result<Self, DeserializationError> {
        let byte_budget = TRUSTED_BYTE_READ_BUDGET_MULTIPLIER.saturating_mul(bytes.len());
        let slice_reader = SliceReader::new(bytes);
        let mut budgeted = BudgetedReader::new(slice_reader, byte_budget);
        Self::read_from(&mut budgeted)
    }
}
impl super::UntrustedMastForest {
    /// Resolves the recorded bytes against the scanned layout and materializes a
    /// `MastForest` from the result, the advice map and the debug info.
    ///
    /// When an allocation budget is present, resolution goes through the budget-aware
    /// constructor; otherwise the plain constructor is used.
    ///
    /// # Errors
    /// Returns any error from resolution or materialization.
    pub(super) fn into_materialized(self) -> Result<MastForest, DeserializationError> {
        let resolved = match self.remaining_allocation_budget {
            Some(allocation_budget) => ResolvedSerializedForest::new_with_allocation_budget(
                &self.bytes,
                self.layout,
                allocation_budget,
            )?,
            None => ResolvedSerializedForest::new(&self.bytes, self.layout)?,
        };
        resolved.materialize(self.advice_map, self.debug_info)
    }
}
/// Decodes an untrusted forest from `source`, accepting hashless input, and returns it
/// together with the raw flag bits.
///
/// After a successful decode, input that is not hashless or not stripped is logged.
pub(super) fn read_untrusted_with_flags<R: ByteReader>(
    source: &mut R,
) -> Result<(super::UntrustedMastForest, u8), DeserializationError> {
    decode_from_reader(source, true).map(|(flags, forest)| {
        log_untrusted_overspecification(flags);
        (forest, flags.bits())
    })
}
/// Same as `read_untrusted_with_flags`, but decodes with the given `allocation_budget`
/// (in bytes).
pub(super) fn read_untrusted_with_flags_and_allocation_budget<R: ByteReader>(
    source: &mut R,
    allocation_budget: usize,
) -> Result<(super::UntrustedMastForest, u8), DeserializationError> {
    let budget = Some(allocation_budget);
    decode_from_reader_inner(source, true, budget).map(|(flags, forest)| {
        log_untrusted_overspecification(flags);
        (forest, flags.bits())
    })
}
/// Logs an error for each way the untrusted input carries more than expected: once when it
/// is not hashless and once when it is not stripped.
fn log_untrusted_overspecification(flags: WireFlags) {
    let carries_node_hashes = !flags.is_hashless();
    if carries_node_hashes {
        log::error!(
            "UntrustedMastForest expected HASHLESS input; supplied artifact includes wire node hashes, and validation will recompute them and require them to match"
        );
    }

    let carries_debug_info = !flags.is_stripped();
    if carries_debug_info {
        log::error!(
            "UntrustedMastForest expected STRIPPED input; supplied artifact includes DebugInfo and other optional payloads over the wire"
        );
    }
}
/// Decodes a forest from `source` without an allocation budget.
///
/// `allow_hashless` is passed through unchanged to `decode_from_reader_inner`.
fn decode_from_reader<R: ByteReader>(
    source: &mut R,
    allow_hashless: bool,
) -> Result<(WireFlags, super::UntrustedMastForest), DeserializationError> {
    let no_allocation_budget: Option<usize> = None;
    decode_from_reader_inner(source, allow_hashless, no_allocation_budget)
}
/// Reads a serialized forest from `source` while recording the bytes that were consumed.
///
/// Steps, in order: read the header and scan the layout, read the advice map, then either
/// read the debug info or — for stripped input — create empty debug info sized for the
/// node count. For stripped input with an allocation budget, `node_count + 1` `usize`
/// slots are charged against the budget before the empty debug info is created.
///
/// Returns the wire flags and an `UntrustedMastForest` holding the recorded bytes, the
/// layout, the advice map, the debug info and whatever budget is left.
///
/// # Errors
/// Returns any read error, `InvalidValue` when `node_count + 1` overflows, or the error
/// from `reserve_allocation` when the budget is insufficient.
fn decode_from_reader_inner<R: ByteReader>(
    source: &mut R,
    allow_hashless: bool,
    mut remaining_allocation_budget: Option<usize>,
) -> Result<(WireFlags, super::UntrustedMastForest), DeserializationError> {
    let mut recording = TrackingReader::new_recording(source);
    let (flags, layout) = read_header_and_scan_layout(&mut recording, allow_hashless)?;
    // The layout scan is expected to stop exactly where the advice map begins.
    debug_assert_eq!(recording.offset(), layout.advice_map_offset);
    let advice_map = AdviceMap::read_from(&mut recording)?;

    let debug_info = if !flags.is_stripped() {
        super::DebugInfo::read_from(&mut recording)?
    } else {
        // Nothing to read from the wire; account for the scaffolding only when a budget
        // is being enforced.
        if let Some(budget) = remaining_allocation_budget.as_mut() {
            let slot_count = match layout.node_count.checked_add(1) {
                Some(count) => count,
                None => {
                    return Err(DeserializationError::InvalidValue(
                        "debug-info node count overflow".into(),
                    ));
                },
            };
            reserve_allocation::<usize>(budget, slot_count, "empty debug-info scaffolding")?;
        }
        super::DebugInfo::empty_for_nodes(layout.node_count)
    };

    let forest = super::UntrustedMastForest {
        bytes: recording.into_recorded(),
        layout,
        advice_map,
        debug_info,
        remaining_allocation_budget,
    };
    Ok((flags, forest))
}
/// Charges `count` values of type `T` against `remaining_budget` (in bytes).
///
/// On success the budget is reduced by `count * size_of::<T>()`. On failure the budget is
/// left unchanged. `label` names the allocation in error messages.
///
/// # Errors
/// Returns `InvalidValue` when the byte size overflows `usize` or exceeds the remaining
/// budget.
pub(super) fn reserve_allocation<T>(
    remaining_budget: &mut usize,
    count: usize,
    label: &str,
) -> Result<(), DeserializationError> {
    let bytes_needed = match count.checked_mul(size_of::<T>()) {
        Some(bytes) => bytes,
        None => {
            return Err(DeserializationError::InvalidValue(format!("{label} size overflow")));
        },
    };

    // `checked_sub` fails exactly when the request is larger than what is left.
    match remaining_budget.checked_sub(bytes_needed) {
        Some(left) => {
            *remaining_budget = left;
            Ok(())
        },
        None => Err(DeserializationError::InvalidValue(format!(
            "{label} requires {bytes_needed} bytes, exceeding the remaining untrusted allocation budget of {} bytes",
            *remaining_budget
        ))),
    }
}
/// Returns the default allocation budget for untrusted input of `bytes_len` bytes: the
/// length times `DEFAULT_UNTRUSTED_ALLOCATION_BUDGET_MULTIPLIER`, saturating at
/// `usize::MAX`.
pub(super) fn default_untrusted_allocation_budget(bytes_len: usize) -> usize {
    let multiplier = DEFAULT_UNTRUSTED_ALLOCATION_BUDGET_MULTIPLIER;
    multiplier.saturating_mul(bytes_len)
}
impl Deserializable for super::UntrustedMastForest {
    /// Decodes an untrusted forest from `source` and discards the raw flag bits.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let (forest, _flags) = read_untrusted_with_flags(source)?;
        Ok(forest)
    }

    /// Decodes an untrusted forest from `bytes`.
    fn read_from_bytes(bytes: &[u8]) -> Result<Self, DeserializationError> {
        // NOTE(review): presumably this resolves to an inherent
        // `UntrustedMastForest::read_from_bytes` defined in the parent module; if no such
        // inherent method exists this call would recurse into itself — confirm.
        super::UntrustedMastForest::read_from_bytes(bytes)
    }
}