use alloc::vec::Vec;
use super::{MastForest, MastNode, MastNodeId};
use crate::{
advice::AdviceMap,
serde::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
pub(crate) mod asm_op;
pub(crate) mod decorator;
mod info;
use info::MastNodeInfo;
mod basic_blocks;
use basic_blocks::{BasicBlockDataBuilder, BasicBlockDataDecoder};
pub(crate) mod string_table;
pub(crate) use string_table::StringTable;
#[cfg(test)]
mod seed_gen;
#[cfg(test)]
mod tests;
// Offset/index aliases. NOTE(review): none of these is referenced in this part of the file;
// presumably they are used by this module's serialization submodules — confirm before removing.
type NodeDataOffset = u32;
type DecoratorDataOffset = u32;
type StringDataOffset = usize;
type StringIndex = usize;

/// Magic bytes written at the very start of every serialized MAST forest; the reader rejects
/// any input that does not begin with them.
const MAGIC: &[u8; 4] = b"MAST";

/// Header flag bit set when the forest was serialized without debug information.
///
/// When set, the header's decorator count is written as `0`, no debug info follows the advice
/// map, and the reader builds `DebugInfo::empty_for_nodes(node_count)` instead of reading one.
const FLAG_STRIPPED: u8 = 0x01;

/// Header flag bits that are not assigned yet and therefore must be zero.
///
/// Derived from the known flags (rather than a hand-computed literal) so the mask can never
/// drift out of sync with them; when a new flag is added, OR it into the negated set here.
const FLAGS_RESERVED_MASK: u8 = !FLAG_STRIPPED;

/// Serialization format version; the reader accepts only this exact version.
const VERSION: [u8; 3] = [0, 0, 2];
impl Serializable for MastForest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.write_into_with_options(target, false);
}
}
impl MastForest {
    /// Serializes this forest into `target`, optionally omitting debug information.
    ///
    /// Layout, in write order:
    /// 1. header: `MAGIC`, a flags byte (`FLAG_STRIPPED` when `stripped`), `VERSION`;
    /// 2. the node count, then the decorator count (`0` when `stripped`);
    /// 3. the procedure roots as a `Vec<u32>`;
    /// 4. the encoded basic-block data buffer;
    /// 5. one `MastNodeInfo` per node, in node order, with no length prefix (the reader
    ///    relies on the node count from step 2);
    /// 6. the advice map;
    /// 7. the debug info, only when `stripped` is `false`.
    fn write_into_with_options<W: ByteWriter>(&self, target: &mut W, stripped: bool) {
        // Header.
        let flags = if stripped { FLAG_STRIPPED } else { 0x00 };
        target.write_bytes(MAGIC);
        target.write_u8(flags);
        target.write_bytes(&VERSION);

        // Counts.
        target.write_usize(self.nodes.len());
        let decorator_count = if stripped { 0 } else { self.debug_info.num_decorators() };
        target.write_usize(decorator_count);

        // Roots, in their raw `u32` form.
        let root_ids: Vec<u32> = self.roots.iter().map(|&id| u32::from(id)).collect();
        root_ids.write_into(target);

        // Basic-block operations are appended to one shared buffer while the node infos are
        // built; each block's info records the offset of its operations in that buffer.
        let mut block_data = BasicBlockDataBuilder::new();
        let mut node_infos = Vec::with_capacity(self.nodes.len());
        for node in self.nodes.iter() {
            let ops_offset = match node {
                MastNode::Block(block) => block_data.encode_basic_block(block),
                _ => 0,
            };
            node_infos.push(MastNodeInfo::new(node, ops_offset));
        }

        // The buffer precedes the infos that point into it.
        block_data.finalize().write_into(target);
        for info in &node_infos {
            info.write_into(target);
        }

        self.advice_map.write_into(target);

        if stripped {
            return;
        }
        self.debug_info.write_into(target);
    }
}
impl Deserializable for MastForest {
    /// Reads a forest in the layout written by `MastForest::write_into_with_options`.
    ///
    /// When the header's `FLAG_STRIPPED` bit is set, no debug info is read; an empty one sized
    /// for the node count is used instead. The decorator count stored in the header is consumed
    /// but not otherwise used.
    ///
    /// # Errors
    /// Fails when the header is invalid, the node count exceeds `MastForest::MAX_NODES`, a node
    /// info cannot be turned into a node builder, a node cannot be added to the forest, a root
    /// id does not refer to a node of the forest, or the underlying reader fails.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let flags = read_and_validate_header(source)?;
        let is_stripped = (flags & FLAG_STRIPPED) != 0;

        let node_count = source.read_usize()?;
        if node_count > Self::MAX_NODES {
            return Err(DeserializationError::InvalidValue(format!(
                "node count {} exceeds maximum allowed {}",
                node_count,
                Self::MAX_NODES
            )));
        }

        // Decorator count: part of the header, but not needed on this path.
        let _ = source.read_usize()?;

        let root_ids: Vec<u32> = Deserializable::read_from(source)?;
        let block_data: Vec<u8> = Deserializable::read_from(source)?;
        let node_infos = node_infos_iter(source, node_count).collect::<Result<Vec<_>, _>>()?;
        let advice_map = AdviceMap::read_from(source)?;
        let debug_info = if !is_stripped {
            super::DebugInfo::read_from(source)?
        } else {
            super::DebugInfo::empty_for_nodes(node_count)
        };

        let mut forest = MastForest::new();
        forest.debug_info = debug_info;

        // Convert every node info before inserting anything, so a malformed info is reported
        // ahead of any insertion failure.
        let decoder = BasicBlockDataDecoder::new(&block_data);
        let mut builders = Vec::with_capacity(node_infos.len());
        for info in node_infos {
            builders.push(info.try_into_mast_node_builder(node_count, &decoder)?);
        }

        for builder in builders {
            builder.add_to_forest_relaxed(&mut forest).map_err(|e| {
                DeserializationError::InvalidValue(format!(
                    "failed to add node to MAST forest while deserializing: {e}",
                ))
            })?;
        }

        // Roots are resolved only after all nodes exist, so their ids can be range-checked.
        for raw_root in root_ids {
            let root_id = MastNodeId::from_u32_safe(raw_root, &forest)?;
            forest.make_root(root_id);
        }

        forest.advice_map = advice_map;
        Ok(forest)
    }
}
/// Reads the fixed-size MAST header and returns its flags byte.
///
/// The header is, in order: 4 magic bytes, 1 flags byte, 3 version bytes. Each field is
/// checked immediately after it is read.
///
/// # Errors
/// Returns `DeserializationError::InvalidValue` when the magic bytes differ from `MAGIC`, any
/// bit in `FLAGS_RESERVED_MASK` is set, or the version is not exactly `VERSION`; read errors
/// from `source` are propagated.
fn read_and_validate_header<R: ByteReader>(source: &mut R) -> Result<u8, DeserializationError> {
    let found_magic: [u8; 4] = source.read_array()?;
    if found_magic != *MAGIC {
        return Err(DeserializationError::InvalidValue(format!(
            "Invalid magic bytes. Expected '{:?}', got '{:?}'",
            *MAGIC, found_magic
        )));
    }

    let flags: u8 = source.read_u8()?;
    let unknown_flags = flags & FLAGS_RESERVED_MASK;
    if unknown_flags != 0 {
        return Err(DeserializationError::InvalidValue(format!(
            "Unknown flags set in MAST header: {:#04x}. Reserved bits must be zero.",
            unknown_flags
        )));
    }

    let version: [u8; 3] = source.read_array()?;
    if version != VERSION {
        return Err(DeserializationError::InvalidValue(format!(
            "Unsupported version. Got '{version:?}', but only '{VERSION:?}' is supported",
        )));
    }

    Ok(flags)
}
/// Lazily reads `node_count` consecutive `MastNodeInfo` entries from `source`.
///
/// Each call to `next` performs one read; the iterator ends after `node_count` items.
/// Built on `iter::from_fn`, so its size hint is `(0, None)` and collecting it does not
/// preallocate space for the declared count up front.
fn node_infos_iter<'a, R>(
    source: &'a mut R,
    node_count: usize,
) -> impl Iterator<Item = Result<MastNodeInfo, DeserializationError>> + 'a
where
    R: ByteReader + 'a,
{
    let mut left = node_count;
    core::iter::from_fn(move || {
        // `checked_sub` yields `None` once every entry has been produced, ending iteration.
        left = left.checked_sub(1)?;
        Some(MastNodeInfo::read_from(source))
    })
}
impl Deserializable for super::UntrustedMastForest {
    /// Reads a `MastForest` and wraps it as untrusted; no extra validation happens here.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        MastForest::read_from(source).map(super::UntrustedMastForest)
    }

    /// Deserializes from a byte slice by delegating to `UntrustedMastForest::read_from_bytes`.
    ///
    /// NOTE(review): in a `Type::name` path, inherent associated functions take precedence
    /// over trait methods, so this presumably forwards to an inherent `read_from_bytes` defined
    /// on `UntrustedMastForest` elsewhere. If that inherent function were removed, this call
    /// would resolve back to this trait method and recurse unconditionally — confirm it exists.
    fn read_from_bytes(bytes: &[u8]) -> Result<Self, DeserializationError> {
        super::UntrustedMastForest::read_from_bytes(bytes)
    }
}
/// Serialization adapter that writes the wrapped forest with `FLAG_STRIPPED` set.
///
/// The output omits the forest's debug information and records a decorator count of zero.
pub(super) struct StrippedMastForest<'a>(pub(super) &'a MastForest);

impl Serializable for StrippedMastForest<'_> {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self(forest) = self;
        forest.write_into_with_options(target, true);
    }
}