#[cfg(test)]
use alloc::string::ToString;
use alloc::{boxed::Box, format, vec::Vec};
use core::mem::size_of;
use miden_utils_sync::OnceLockCompat;
use super::{MastForest, MastNode, MastNodeId};
use crate::{
Word,
advice::AdviceMap,
mast::node::MastNodeExt,
serde::{
BudgetedReader, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
SliceReader,
},
};
mod info;
pub use info::{MastNodeEntry, MastNodeInfo};
mod view;
use view::WireAdviceMapView;
pub use view::{AdviceMapView, AdviceValueView, MastForestView};
mod layout;
pub(super) use layout::ForestLayout;
use layout::{OffsetTrackingReader, TrackingReader, WireFlags, read_header_and_scan_layout};
mod resolved;
use resolved::{ResolvedSerializedForest, basic_block_offset_for_node_index};
mod basic_blocks;
use basic_blocks::{BasicBlockDataBuilder, basic_block_data_len};
#[cfg(test)]
mod seed_gen;
#[cfg(test)]
mod tests;
/// Offset type for node data.
/// NOTE(review): not used in this file; presumably an offset into the encoded basic-block data
/// section consumed by child modules — confirm.
type NodeDataOffset = u32;
/// Multiplier applied to the input length by `default_untrusted_allocation_budget`.
const DEFAULT_UNTRUSTED_ALLOCATION_BUDGET_MULTIPLIER: usize = 7;
/// Byte-read budget multiplier used by the trusted `MastForest::read_from_bytes` path.
const TRUSTED_BYTE_READ_BUDGET_MULTIPLIER: usize = 64;
/// Magic bytes written at the very start of every serialized forest.
const MAGIC: &[u8; 4] = b"MAST";
/// Header flag bit set when internal node digests are omitted from the payload.
pub(super) const FLAG_HASHLESS: u8 = 0x02;
/// Every flag bit other than `FLAG_HASHLESS` (i.e. `0xfd`).
/// NOTE(review): presumably used by the header reader to reject unknown flag bits — confirm.
const FLAGS_RESERVED_MASK: u8 = !FLAG_HASHLESS;
/// Wire-format version, written immediately after the flags byte.
const VERSION: [u8; 3] = [0, 0, 4];
impl Serializable for MastForest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.write_into_with_options(target, false);
}
}
impl MastForest {
    /// Serializes this forest into `target` using the MAST wire format.
    ///
    /// Sections are written in this order:
    /// 1. `MAGIC`, a flags byte (`FLAG_HASHLESS` when `hashless` is true, otherwise 0), and
    ///    `VERSION`;
    /// 2. the internal and external node counts;
    /// 3. the procedure roots, converted to `u32`;
    /// 4. the encoded basic-block data;
    /// 5. one `MastNodeEntry` per node, in node order;
    /// 6. the digests of all external nodes, in node order;
    /// 7. unless `hashless` is true, the digests of all internal nodes, in node order;
    /// 8. the advice map.
    fn write_into_with_options<W: ByteWriter>(&self, target: &mut W, hashless: bool) {
        let node_count = self.nodes.len();
        let external_node_count = self.nodes.iter().filter(|node| node.is_external()).count();
        let internal_node_count = node_count - external_node_count;

        target.write_bytes(MAGIC);
        let flags = if hashless { FLAG_HASHLESS } else { 0 };
        target.write_u8(flags);
        target.write_bytes(&VERSION);
        target.write_usize(internal_node_count);
        target.write_usize(external_node_count);
        let roots: Vec<u32> = self.roots.iter().copied().map(u32::from).collect();
        roots.write_into(target);

        // All per-node sections are gathered in a single pass over the nodes. Their final
        // lengths are known up front, so each vector reserves exactly once.
        let mut basic_block_data_builder = BasicBlockDataBuilder::new();
        let mut mast_node_entries = Vec::with_capacity(node_count);
        let mut external_digests = Vec::with_capacity(external_node_count);
        let mut node_hashes = Vec::with_capacity(if hashless { 0 } else { internal_node_count });
        for mast_node in self.nodes.iter() {
            // Only basic blocks carry encoded operation data; other nodes record offset 0.
            let ops_offset = if let MastNode::Block(basic_block) = mast_node {
                basic_block_data_builder.encode_basic_block(basic_block)
            } else {
                0
            };
            mast_node_entries.push(MastNodeEntry::new(mast_node, ops_offset));
            if mast_node.is_external() {
                external_digests.push(mast_node.digest());
            } else if !hashless {
                node_hashes.push(mast_node.digest());
            }
        }

        let basic_block_data = basic_block_data_builder.finalize();
        basic_block_data.write_into(target);
        for mast_node_entry in mast_node_entries {
            mast_node_entry.write_into(target);
        }
        for digest in external_digests {
            digest.write_into(target);
        }
        // `node_hashes` stays empty when `hashless` is true, so nothing is written then.
        for digest in node_hashes {
            digest.write_into(target);
        }
        self.advice_map.write_into(target);
    }
}
pub(super) fn write_hashless_into<W: ByteWriter>(forest: &MastForest, target: &mut W) {
forest.write_into_with_options(target, true);
}
/// Selects how `MastForest::read_view_from_bytes` interprets its input bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MastForestReadMode {
    /// Fully decode the bytes into an owned `MastForest` (via `MastForest::read_from_bytes`).
    Materialized,
    /// Keep borrowing the input bytes and answer queries from the wire encoding
    /// (via `MastForestWireView::new`).
    WireBacked,
}
/// Result of `MastForest::read_view_from_bytes`: either an owned, fully decoded forest or a
/// view that borrows the serialized bytes for lifetime `'a`.
#[derive(Debug)]
pub enum MastForestReadView<'a> {
    /// An owned, fully decoded forest.
    Materialized(MastForest),
    /// A wire-backed view over the input bytes.
    /// NOTE(review): boxed, presumably to keep this enum small — confirm.
    WireBacked(Box<MastForestWireView<'a>>),
}
/// Read-only view over a serialized `MastForest` that reads directly from the encoded bytes.
#[derive(Debug)]
pub struct MastForestWireView<'a> {
    // The complete serialized forest this view borrows from.
    bytes: &'a [u8],
    // Section layout produced by scanning the header in `new`.
    layout: ForestLayout,
    // Wire-level view of the advice map section.
    advice_map: WireAdviceMapView<'a>,
    // Lazily built on first use by `resolved()`. The outcome is cached, including a failure,
    // and the error is cloned on every later access.
    resolved: OnceLockCompat<Result<ResolvedSerializedForest<'a>, DeserializationError>>,
}
impl<'a> MastForestWireView<'a> {
    /// Creates a wire-backed view over `bytes` without materializing any nodes.
    ///
    /// Scans the header and section layout and the advice map, and requires that nothing
    /// follows the advice map.
    ///
    /// # Errors
    /// Returns an error if the header or layout cannot be scanned, if the advice map is
    /// malformed, or if bytes remain after the advice map.
    pub fn new(bytes: &'a [u8]) -> Result<Self, DeserializationError> {
        let mut slice_reader = SliceReader::new(bytes);
        let mut tracker = TrackingReader::new(&mut slice_reader);
        // Second argument is `allow_hashless`; the wire view does not accept hashless input.
        let (_, layout) = read_header_and_scan_layout(&mut tracker, false)?;
        let advice_map = WireAdviceMapView::new(bytes, layout.advice_map_offset())?;
        let payload_end = advice_map.end_offset();
        check_no_trailing_payload(bytes, payload_end)?;
        Ok(Self {
            bytes,
            layout,
            advice_map,
            resolved: OnceLockCompat::new(),
        })
    }

    /// Total number of nodes recorded in the scanned layout.
    pub fn node_count(&self) -> usize {
        self.layout.node_count
    }

    /// Number of procedure roots recorded in the scanned layout.
    pub fn procedure_root_count(&self) -> usize {
        self.layout.roots_count
    }

    /// Reads the `index`-th procedure root from the wire bytes.
    pub fn procedure_root_at(&self, index: usize) -> Result<MastNodeId, DeserializationError> {
        self.layout.read_procedure_root_at(self.bytes, index)
    }

    /// Combines the node entry and the node digest at `index`.
    ///
    /// Reading the digest may trigger the lazy resolution performed by `node_digest_at`.
    pub fn node_info_at(&self, index: usize) -> Result<MastNodeInfo, DeserializationError> {
        let entry = self.node_entry_at(index)?;
        let digest = self.node_digest_at(index)?;
        Ok(MastNodeInfo::from_entry(entry, digest))
    }

    /// Reads the `index`-th node entry from the wire bytes.
    pub fn node_entry_at(&self, index: usize) -> Result<MastNodeEntry, DeserializationError> {
        self.layout.read_node_entry_at(self.bytes, index)
    }

    /// Returns the digest of the `index`-th node.
    ///
    /// The first call builds (and caches) a `ResolvedSerializedForest`. If that fails, the same
    /// error is returned on every call.
    pub fn node_digest_at(&self, index: usize) -> Result<Word, DeserializationError> {
        let resolved = self.resolved()?;
        resolved.node_digest_at(index)
    }

    /// Returns a view of the advice map section.
    pub fn advice_map(&self) -> AdviceMapView<'_> {
        AdviceMapView::wire(&self.advice_map)
    }

    /// Returns the cached resolved forest, building it on first access.
    fn resolved(&self) -> Result<&ResolvedSerializedForest<'a>, DeserializationError> {
        let cached = self
            .resolved
            .get_or_init(|| ResolvedSerializedForest::new(self.bytes, self.layout));
        match cached {
            Ok(forest) => Ok(forest),
            Err(err) => Err(err.clone()),
        }
    }
}
/// Succeeds only when `bytes` ends exactly at `debug_info_offset`.
///
/// # Errors
/// - `DeserializationError::UnexpectedEOF` if `debug_info_offset` lies past the end of `bytes`;
/// - the "extra bytes after MastForest payload" error if any bytes follow the offset.
fn check_no_trailing_payload(
    bytes: &[u8],
    debug_info_offset: usize,
) -> Result<(), DeserializationError> {
    match bytes.get(debug_info_offset..) {
        None => Err(DeserializationError::UnexpectedEOF),
        Some([]) => Ok(()),
        Some(_) => Err(extra_bytes_after_mast_forest_payload_error()),
    }
}
/// Builds the error reported when bytes remain after a complete serialized `MastForest`.
fn extra_bytes_after_mast_forest_payload_error() -> DeserializationError {
    let message = "extra bytes after MastForest payload";
    DeserializationError::InvalidValue(message.into())
}
impl MastForest {
    /// Reads a serialized forest from `bytes` in the requested `mode`.
    ///
    /// - `MastForestReadMode::Materialized` decodes an owned `MastForest` through
    ///   `read_from_bytes`, which rejects trailing bytes.
    /// - `MastForestReadMode::WireBacked` builds a `MastForestWireView` that keeps borrowing
    ///   `bytes`. Node digests are resolved lazily on first use.
    ///
    /// # Errors
    /// Propagates the decoding error reported by the selected mode.
    pub fn read_view_from_bytes(
        bytes: &[u8],
        mode: MastForestReadMode,
    ) -> Result<MastForestReadView<'_>, DeserializationError> {
        let view = match mode {
            MastForestReadMode::Materialized => {
                MastForestReadView::Materialized(Self::read_from_bytes(bytes)?)
            },
            MastForestReadMode::WireBacked => {
                let wire = MastForestWireView::new(bytes)?;
                MastForestReadView::WireBacked(Box::new(wire))
            },
        };
        Ok(view)
    }
}
/// Trait-level access to a wire-backed forest.
///
/// Every method forwards to the inherent `MastForestWireView` method of the same name. The
/// explicit `MastForestWireView::method(self, ..)` form is deliberate. When resolving such a path,
/// inherent associated functions are considered before trait items, so these calls reach the
/// inherent methods rather than recursing into this impl.
impl MastForestView for MastForestWireView<'_> {
    fn node_count(&self) -> usize {
        MastForestWireView::node_count(self)
    }
    fn node_entry_at(&self, index: usize) -> Result<MastNodeEntry, DeserializationError> {
        MastForestWireView::node_entry_at(self, index)
    }
    // May trigger lazy resolution of the serialized forest.
    fn node_digest_at(&self, index: usize) -> Result<Word, DeserializationError> {
        MastForestWireView::node_digest_at(self, index)
    }
    fn procedure_root_count(&self) -> usize {
        MastForestWireView::procedure_root_count(self)
    }
    fn procedure_root_at(&self, index: usize) -> Result<MastNodeId, DeserializationError> {
        MastForestWireView::procedure_root_at(self, index)
    }
    fn advice_map(&self) -> AdviceMapView<'_> {
        MastForestWireView::advice_map(self)
    }
}
impl MastForestView for MastForestReadView<'_> {
fn node_count(&self) -> usize {
match self {
MastForestReadView::Materialized(forest) => MastForestView::node_count(forest),
MastForestReadView::WireBacked(view) => view.node_count(),
}
}
fn node_entry_at(&self, index: usize) -> Result<MastNodeEntry, DeserializationError> {
match self {
MastForestReadView::Materialized(forest) => {
MastForestView::node_entry_at(forest, index)
},
MastForestReadView::WireBacked(view) => view.node_entry_at(index),
}
}
fn node_digest_at(&self, index: usize) -> Result<Word, DeserializationError> {
match self {
MastForestReadView::Materialized(forest) => {
MastForestView::node_digest_at(forest, index)
},
MastForestReadView::WireBacked(view) => view.node_digest_at(index),
}
}
fn procedure_root_count(&self) -> usize {
match self {
MastForestReadView::Materialized(forest) => {
MastForestView::procedure_root_count(forest)
},
MastForestReadView::WireBacked(view) => view.procedure_root_count(),
}
}
fn procedure_root_at(&self, index: usize) -> Result<MastNodeId, DeserializationError> {
match self {
MastForestReadView::Materialized(forest) => {
MastForestView::procedure_root_at(forest, index)
},
MastForestReadView::WireBacked(view) => view.procedure_root_at(index),
}
}
fn advice_map(&self) -> AdviceMapView<'_> {
match self {
MastForestReadView::Materialized(forest) => MastForestView::advice_map(forest),
MastForestReadView::WireBacked(view) => view.advice_map(),
}
}
}
/// Error returned by the materialized view when a node index is past the end of the node list.
fn node_index_out_of_bounds(index: usize) -> DeserializationError {
    DeserializationError::InvalidValue(format!("node index {index} out of bounds"))
}

/// `MastForestView` over an in-memory forest.
///
/// Node entries are rebuilt on demand so they match what the serializer would write.
impl MastForestView for MastForest {
    fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Rebuilds the entry for the node at `index`.
    ///
    /// Basic blocks get their operation-data offset from `basic_block_offset_for_node_index`.
    /// Every other node kind uses offset 0.
    fn node_entry_at(&self, index: usize) -> Result<MastNodeEntry, DeserializationError> {
        let node = self
            .nodes
            .as_slice()
            .get(index)
            .ok_or_else(|| node_index_out_of_bounds(index))?;
        let ops_offset = match node {
            MastNode::Block(_) => basic_block_offset_for_node_index(self.nodes.as_slice(), index)?,
            _ => 0,
        };
        Ok(MastNodeEntry::new(node, ops_offset))
    }

    fn node_digest_at(&self, index: usize) -> Result<Word, DeserializationError> {
        match self.nodes.as_slice().get(index) {
            Some(node) => Ok(node.digest()),
            None => Err(node_index_out_of_bounds(index)),
        }
    }

    fn procedure_root_count(&self) -> usize {
        self.roots.len()
    }

    fn procedure_root_at(&self, index: usize) -> Result<MastNodeId, DeserializationError> {
        match self.roots.get(index) {
            Some(&root) => Ok(root),
            None => Err(DeserializationError::InvalidValue(format!(
                "root index {} out of bounds for {} roots",
                index,
                self.roots.len()
            ))),
        }
    }

    fn advice_map(&self) -> AdviceMapView<'_> {
        AdviceMapView::materialized(&self.advice_map)
    }
}
#[cfg(test)]
impl MastForestWireView<'_> {
    /// Offset immediately after the advice map section.
    fn debug_info_offset(&self) -> usize {
        self.advice_map.end_offset()
    }

    /// Start offset of the node-entry section, as recorded in the layout.
    fn node_entry_offset(&self) -> usize {
        self.layout.node_entry_offset()
    }

    /// Start offset of the external-digest section, as recorded in the layout.
    fn external_digest_offset(&self) -> usize {
        self.layout.external_digest_offset()
    }

    /// Start offset of the internal node-hash section, if the layout has one.
    fn node_hash_offset(&self) -> Option<usize> {
        self.layout.node_hash_offset()
    }

    /// Digest slot for node `index`. Panics if the forest cannot be resolved.
    fn digest_slot_at(&self, index: usize) -> usize {
        let resolved = self
            .resolved()
            .expect("digest slots should be readable for a valid serialized view");
        resolved.digest_slot_at(index)
    }
}
#[cfg(test)]
/// Reads one byte at `*offset` and advances the offset past it.
fn read_u8_at(bytes: &[u8], offset: &mut usize) -> Result<u8, DeserializationError> {
    let single = read_slice_at(bytes, offset, 1)?;
    Ok(single[0])
}
#[cfg(test)]
/// Reads exactly `N` bytes at `*offset` into an array and advances the offset past them.
fn read_array_at<const N: usize>(
    bytes: &[u8],
    offset: &mut usize,
) -> Result<[u8; N], DeserializationError> {
    let mut array = [0u8; N];
    array.copy_from_slice(read_slice_at(bytes, offset, N)?);
    Ok(array)
}
#[cfg(test)]
/// Borrows `len` bytes starting at `*offset` and advances `*offset` past them.
///
/// `*offset` is left unchanged on error.
///
/// # Errors
/// - `InvalidValue("offset overflow")` if `*offset + len` overflows `usize`;
/// - `UnexpectedEOF` if the range extends past the end of `bytes`.
fn read_slice_at<'a>(
    bytes: &'a [u8],
    offset: &mut usize,
    len: usize,
) -> Result<&'a [u8], DeserializationError> {
    let Some(end) = offset.checked_add(len) else {
        return Err(DeserializationError::InvalidValue("offset overflow".to_string()));
    };
    let window = bytes.get(*offset..end).ok_or(DeserializationError::UnexpectedEOF)?;
    *offset = end;
    Ok(window)
}
#[cfg(test)]
/// Decodes a variable-length unsigned integer at `*offset` and advances past it.
///
/// The number of trailing zero bits in the first byte, plus one, is the encoded length.
/// For lengths 1–8, those bytes are read little-endian and shifted right by the length, which
/// drops the length prefix. A first byte of zero marks the 9-byte form: that marker byte is
/// followed by the value as 8 little-endian bytes.
///
/// # Errors
/// `UnexpectedEOF` if the input is too short, or `InvalidValue` if the decoded value does not
/// fit in `usize`.
fn read_usize_at(bytes: &[u8], offset: &mut usize) -> Result<usize, DeserializationError> {
    let Some(&first_byte) = bytes.get(*offset) else {
        return Err(DeserializationError::UnexpectedEOF);
    };
    let decoded = match first_byte.trailing_zeros() as usize + 1 {
        9 => {
            let _ = read_u8_at(bytes, offset)?;
            u64::from_le_bytes(read_array_at::<8>(bytes, offset)?)
        },
        length => {
            let mut widened = [0u8; 8];
            widened[..length].copy_from_slice(read_slice_at(bytes, offset, length)?);
            u64::from_le_bytes(widened) >> length
        },
    };
    if decoded > usize::MAX as u64 {
        return Err(DeserializationError::InvalidValue(format!(
            "Encoded value must be less than {}, but {} was provided",
            usize::MAX,
            decoded
        )));
    }
    Ok(decoded as usize)
}
impl Deserializable for MastForest {
    /// Decodes a forest from `source` (hashless input not allowed) and materializes it without
    /// an allocation budget.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        decode_from_reader(source, false)
            .and_then(|(_flags, untrusted)| untrusted.into_materialized())
    }

    /// Decodes a forest from `bytes` and requires every byte to be consumed.
    ///
    /// Reads are capped at `TRUSTED_BYTE_READ_BUDGET_MULTIPLIER` times the input length.
    fn read_from_bytes(bytes: &[u8]) -> Result<Self, DeserializationError> {
        let read_budget = bytes.len().saturating_mul(TRUSTED_BYTE_READ_BUDGET_MULTIPLIER);
        let mut budgeted = BudgetedReader::new(SliceReader::new(bytes), read_budget);
        let forest = Self::read_from(&mut budgeted)?;
        if budgeted.has_more_bytes() {
            Err(extra_bytes_after_mast_forest_payload_error())
        } else {
            Ok(forest)
        }
    }
}
impl super::UntrustedMastForest {
    /// Resolves the recorded wire bytes and builds the owned `MastForest`.
    ///
    /// When `remaining_allocation_budget` is set, resolution is bounded by that budget.
    /// Otherwise it is unbounded.
    pub(super) fn into_materialized(self) -> Result<MastForest, DeserializationError> {
        let resolved = match self.remaining_allocation_budget {
            Some(budget) => ResolvedSerializedForest::new_with_allocation_budget(
                &self.bytes,
                self.layout,
                budget,
            )?,
            None => ResolvedSerializedForest::new(&self.bytes, self.layout)?,
        };
        resolved.materialize(self.advice_map)
    }
}
/// Decodes an untrusted forest from `source` (hashless input allowed, no allocation budget).
///
/// Logs an error when the artifact still carries node hashes. Returns the forest together with
/// the flags byte from its header.
pub(super) fn read_untrusted_with_flags<R: ByteReader>(
    source: &mut R,
) -> Result<(super::UntrustedMastForest, u8), DeserializationError> {
    decode_from_reader(source, true).map(|(flags, forest)| {
        log_untrusted_overspecification(flags);
        (forest, flags.bits())
    })
}
/// Like `read_untrusted_with_flags`, but attaches `allocation_budget` to the forest.
///
/// The budget is enforced later, when the forest is materialized.
pub(super) fn read_untrusted_with_flags_and_allocation_budget<R: ByteReader>(
    source: &mut R,
    allocation_budget: usize,
) -> Result<(super::UntrustedMastForest, u8), DeserializationError> {
    let budget = Some(allocation_budget);
    let (flags, forest) = decode_from_reader_inner(source, true, budget)?;
    log_untrusted_overspecification(flags);
    let raw_flags = flags.bits();
    Ok((forest, raw_flags))
}
/// Logs an error when an untrusted artifact includes node hashes instead of being hashless.
fn log_untrusted_overspecification(flags: WireFlags) {
    // Hashless input is the expected form, so there is nothing to report.
    if flags.is_hashless() {
        return;
    }
    log::error!(
        "UntrustedMastForest expected HASHLESS input; supplied artifact includes wire node hashes, and validation will recompute them and require them to match"
    );
}
/// Decodes the header, layout and advice map from `source` without attaching an allocation
/// budget. `into_materialized` therefore resolves such a forest without a budget cap.
fn decode_from_reader<R: ByteReader>(
    source: &mut R,
    allow_hashless: bool,
) -> Result<(WireFlags, super::UntrustedMastForest), DeserializationError> {
    let remaining_allocation_budget: Option<usize> = None;
    decode_from_reader_inner(source, allow_hashless, remaining_allocation_budget)
}
/// Scans the header and layout, reads the advice map, and keeps a copy of every byte consumed.
///
/// The recorded bytes, the layout, the advice map and the optional budget are returned as an
/// `UntrustedMastForest`, alongside the header flags.
fn decode_from_reader_inner<R: ByteReader>(
    source: &mut R,
    allow_hashless: bool,
    remaining_allocation_budget: Option<usize>,
) -> Result<(WireFlags, super::UntrustedMastForest), DeserializationError> {
    let mut recorder = TrackingReader::new_recording(source);
    let (flags, layout) = read_header_and_scan_layout(&mut recorder, allow_hashless)?;
    // The layout scan must stop exactly where the advice map begins.
    debug_assert_eq!(recorder.offset(), layout.advice_map_offset());
    let advice_map = AdviceMap::read_from(&mut recorder)?;
    let bytes = recorder.into_recorded();
    let forest = super::UntrustedMastForest {
        bytes,
        layout,
        advice_map,
        remaining_allocation_budget,
    };
    Ok((flags, forest))
}
/// Charges `count * size_of::<T>()` bytes against `remaining_budget`.
///
/// # Errors
/// Returns `InvalidValue` if the size computation overflows, or if it exceeds the remaining
/// budget. On error, `remaining_budget` is left unchanged.
pub(super) fn reserve_allocation<T>(
    remaining_budget: &mut usize,
    count: usize,
    label: &str,
) -> Result<(), DeserializationError> {
    let Some(bytes_needed) = count.checked_mul(size_of::<T>()) else {
        return Err(DeserializationError::InvalidValue(format!("{label} size overflow")));
    };
    match remaining_budget.checked_sub(bytes_needed) {
        Some(left_over) => {
            *remaining_budget = left_over;
            Ok(())
        },
        None => Err(DeserializationError::InvalidValue(format!(
            "{label} requires {bytes_needed} bytes, exceeding the remaining untrusted allocation budget of {} bytes",
            *remaining_budget
        ))),
    }
}
/// Default allocation budget for untrusted input of `bytes_len` bytes.
///
/// The budget is `DEFAULT_UNTRUSTED_ALLOCATION_BUDGET_MULTIPLIER` times the length, saturating
/// at `usize::MAX`.
pub(super) fn default_untrusted_allocation_budget(bytes_len: usize) -> usize {
    let multiplier = DEFAULT_UNTRUSTED_ALLOCATION_BUDGET_MULTIPLIER;
    bytes_len.saturating_mul(multiplier)
}
/// Untrusted decoding entry points.
impl Deserializable for super::UntrustedMastForest {
    /// Decodes an untrusted forest from `source` (hashless input allowed) and discards the
    /// flags byte.
    ///
    /// NOTE(review): this goes through `read_untrusted_with_flags`, which attaches no allocation
    /// budget. If the forest is later materialized via `into_materialized`, it is resolved
    /// without a budget cap. Confirm that callers bound `source` some other way.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        read_untrusted_with_flags(source).map(|(forest, _flags)| forest)
    }
    /// Delegates to the inherent `UntrustedMastForest::read_from_bytes`, which is presumably
    /// defined in the parent module.
    ///
    /// When resolving this path, inherent associated functions are considered before trait
    /// items, so this is not a call back into this trait method. It would become infinite
    /// recursion if that inherent method were removed or renamed.
    fn read_from_bytes(bytes: &[u8]) -> Result<Self, DeserializationError> {
        super::UntrustedMastForest::read_from_bytes(bytes)
    }
}