use std::string::ToString;
use super::*;
use crate::{
Felt, ONE, Word,
chiplets::hasher,
mast::{
BasicBlockNodeBuilder, CallNodeBuilder, DynNodeBuilder, ExternalNodeBuilder,
JoinNodeBuilder, LoopNodeBuilder, MastForestContributor, MastForestError, MastNodeExt,
MastNodeId, OP_BATCH_SIZE, OpBatch, SplitNodeBuilder, UntrustedMastForest,
},
operations::{DebugOptions, Decorator, Operation},
serde::{ByteReader, DeserializationError, SliceReader},
utils::Idx,
};
/// Compile-time tripwire for the serialization format.
///
/// The matches below are exhaustive over `Operation`, `Decorator`, and `DebugOptions`. Adding
/// or removing a variant therefore stops this file from compiling, which is the cue to revisit
/// the serialization code and the round-trip tests in this module. At runtime the test does
/// nothing beyond visiting one value of each enum.
#[test]
fn confirm_operation_and_decorator_structure() {
    fn visit_operation(op: &Operation) {
        match op {
            Operation::Noop
            | Operation::Assert(_)
            | Operation::SDepth
            | Operation::Caller
            | Operation::Clk
            | Operation::Add
            | Operation::Neg
            | Operation::Mul
            | Operation::Inv
            | Operation::Incr
            | Operation::And
            | Operation::Or
            | Operation::Not
            | Operation::Eq
            | Operation::Eqz
            | Operation::Expacc
            | Operation::Ext2Mul
            | Operation::U32split
            | Operation::U32add
            | Operation::U32assert2(_)
            | Operation::U32add3
            | Operation::U32sub
            | Operation::U32mul
            | Operation::U32madd
            | Operation::U32div
            | Operation::U32and
            | Operation::U32xor
            | Operation::Pad
            | Operation::Drop
            | Operation::Dup0
            | Operation::Dup1
            | Operation::Dup2
            | Operation::Dup3
            | Operation::Dup4
            | Operation::Dup5
            | Operation::Dup6
            | Operation::Dup7
            | Operation::Dup9
            | Operation::Dup11
            | Operation::Dup13
            | Operation::Dup15
            | Operation::Swap
            | Operation::SwapW
            | Operation::SwapW2
            | Operation::SwapW3
            | Operation::SwapDW
            | Operation::MovUp2
            | Operation::MovUp3
            | Operation::MovUp4
            | Operation::MovUp5
            | Operation::MovUp6
            | Operation::MovUp7
            | Operation::MovUp8
            | Operation::MovDn2
            | Operation::MovDn3
            | Operation::MovDn4
            | Operation::MovDn5
            | Operation::MovDn6
            | Operation::MovDn7
            | Operation::MovDn8
            | Operation::CSwap
            | Operation::CSwapW
            | Operation::Push(_)
            | Operation::AdvPop
            | Operation::AdvPopW
            | Operation::MLoadW
            | Operation::MStoreW
            | Operation::MLoad
            | Operation::MStore
            | Operation::MStream
            | Operation::Pipe
            | Operation::CryptoStream
            | Operation::HPerm
            | Operation::MpVerify(_)
            | Operation::MrUpdate
            | Operation::FriE2F4
            | Operation::HornerBase
            | Operation::HornerExt
            | Operation::EvalCircuit
            | Operation::Emit
            | Operation::LogPrecompile => {},
        }
    }

    fn visit_decorator(decorator: &Decorator) {
        match decorator {
            Decorator::Debug(
                DebugOptions::StackAll
                | DebugOptions::StackTop(_)
                | DebugOptions::MemAll
                | DebugOptions::MemInterval(..)
                | DebugOptions::LocalInterval(..)
                | DebugOptions::AdvStackTop(_),
            ) => {},
            Decorator::Trace(_) => {},
        }
    }

    visit_operation(&Operation::Noop);
    visit_decorator(&Decorator::Trace(0));
}
/// Round-trips a forest containing every node kind through `to_bytes` / `read_from_bytes` and
/// checks that the result equals the original forest.
///
/// The basic block contains every `Operation` variant, and its decorators cover every
/// `DebugOptions` variant plus a `Trace`. The operation list must stay in sync with
/// `confirm_operation_and_decorator_structure`. That test stops compiling when a variant is
/// added, and every variant it names should also appear here so that its opcode encoding is
/// actually exercised by a round trip.
#[test]
fn serialize_deserialize_all_nodes() {
    let mut mast_forest = MastForest::new();

    let basic_block_id = {
        // Same order as the exhaustive match in `confirm_operation_and_decorator_structure`.
        let operations = vec![
            Operation::Noop,
            Operation::Assert(Felt::from_u32(42)),
            Operation::SDepth,
            Operation::Caller,
            Operation::Clk,
            Operation::Add,
            Operation::Neg,
            Operation::Mul,
            Operation::Inv,
            Operation::Incr,
            Operation::And,
            Operation::Or,
            Operation::Not,
            Operation::Eq,
            Operation::Eqz,
            Operation::Expacc,
            Operation::Ext2Mul,
            Operation::U32split,
            Operation::U32add,
            Operation::U32assert2(Felt::from_u32(222)),
            Operation::U32add3,
            Operation::U32sub,
            Operation::U32mul,
            Operation::U32madd,
            Operation::U32div,
            Operation::U32and,
            Operation::U32xor,
            Operation::Pad,
            Operation::Drop,
            Operation::Dup0,
            Operation::Dup1,
            Operation::Dup2,
            Operation::Dup3,
            Operation::Dup4,
            Operation::Dup5,
            Operation::Dup6,
            Operation::Dup7,
            Operation::Dup9,
            Operation::Dup11,
            Operation::Dup13,
            Operation::Dup15,
            Operation::Swap,
            Operation::SwapW,
            Operation::SwapW2,
            Operation::SwapW3,
            Operation::SwapDW,
            Operation::MovUp2,
            Operation::MovUp3,
            Operation::MovUp4,
            Operation::MovUp5,
            Operation::MovUp6,
            Operation::MovUp7,
            Operation::MovUp8,
            Operation::MovDn2,
            Operation::MovDn3,
            Operation::MovDn4,
            Operation::MovDn5,
            Operation::MovDn6,
            Operation::MovDn7,
            Operation::MovDn8,
            Operation::CSwap,
            Operation::CSwapW,
            Operation::Push(Felt::new(45)),
            Operation::AdvPop,
            Operation::AdvPopW,
            Operation::MLoadW,
            Operation::MStoreW,
            Operation::MLoad,
            Operation::MStore,
            Operation::MStream,
            Operation::Pipe,
            // Previously missing from this list although present in the tripwire test.
            Operation::CryptoStream,
            Operation::HPerm,
            Operation::MpVerify(Felt::from_u32(1022)),
            Operation::MrUpdate,
            Operation::FriE2F4,
            Operation::HornerBase,
            Operation::HornerExt,
            // Previously missing from this list although present in the tripwire test.
            Operation::EvalCircuit,
            Operation::Emit,
            // Previously missing from this list although present in the tripwire test.
            Operation::LogPrecompile,
        ];

        let num_operations = operations.len();
        // One decorator per `DebugOptions` variant, plus a `Trace` at index `num_operations`,
        // i.e. positioned after the last operation.
        let decorators = vec![
            (0, Decorator::Debug(DebugOptions::StackAll)),
            (15, Decorator::Debug(DebugOptions::StackTop(255))),
            (15, Decorator::Debug(DebugOptions::MemAll)),
            (15, Decorator::Debug(DebugOptions::MemInterval(0, 16))),
            (17, Decorator::Debug(DebugOptions::LocalInterval(1, 2, 3))),
            (19, Decorator::Debug(DebugOptions::AdvStackTop(255))),
            (num_operations, Decorator::Trace(55)),
        ];

        // Register each decorator with the forest and pair the resulting id with its op index.
        let decorator_list: Vec<(usize, crate::mast::DecoratorId)> = decorators
            .into_iter()
            .map(|(idx, decorator)| -> Result<(usize, crate::mast::DecoratorId), MastForestError> {
                let decorator_id = mast_forest.add_decorator(decorator)?;
                Ok((idx, decorator_id))
            })
            .collect::<Result<Vec<_>, MastForestError>>()
            .unwrap();

        BasicBlockNodeBuilder::new(operations, decorator_list)
            .add_to_forest(&mut mast_forest)
            .unwrap()
    };

    // Every control-flow node below carries the same before-enter / after-exit decorators.
    let decorator_id1 = mast_forest.add_decorator(Decorator::Trace(1)).unwrap();
    let decorator_id2 = mast_forest.add_decorator(Decorator::Trace(2)).unwrap();

    let call_node_id = CallNodeBuilder::new(basic_block_id)
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let syscall_node_id = CallNodeBuilder::new_syscall(basic_block_id)
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let loop_node_id = LoopNodeBuilder::new(basic_block_id)
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let join_node_id = JoinNodeBuilder::new([basic_block_id, call_node_id])
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let split_node_id = SplitNodeBuilder::new([basic_block_id, call_node_id])
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let dyn_node_id = DynNodeBuilder::new_dyn()
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let dyncall_node_id = DynNodeBuilder::new_dyncall()
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let external_node_id = ExternalNodeBuilder::new(Word::default())
        .with_before_enter(vec![decorator_id1])
        .with_after_exit(vec![decorator_id2])
        .add_to_forest(&mut mast_forest)
        .unwrap();

    mast_forest.make_root(join_node_id);
    mast_forest.make_root(syscall_node_id);
    mast_forest.make_root(loop_node_id);
    mast_forest.make_root(split_node_id);
    mast_forest.make_root(dyn_node_id);
    mast_forest.make_root(dyncall_node_id);
    mast_forest.make_root(external_node_id);

    let serialized_mast_forest = mast_forest.to_bytes();
    let deserialized_mast_forest = MastForest::read_from_bytes(&serialized_mast_forest).unwrap();
    assert_eq!(mast_forest, deserialized_mast_forest);
}
/// Deserialization must accept a join node whose children have larger ids than the join itself.
///
/// Nodes are added as `placeholder` (0), `left` (1), `right` (2), join (3). Calling
/// `swap_remove` on slot 0 moves the join into slot 0, so its children (1 and 2) now come after
/// their parent.
#[test]
fn mast_forest_serialize_deserialize_with_child_ids_exceeding_parent_id() {
    let mut forest = MastForest::new();
    let trace0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
    let trace1 = forest.add_decorator(Decorator::Trace(1)).unwrap();

    let placeholder = BasicBlockNodeBuilder::new(vec![Operation::U32div], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    let left = BasicBlockNodeBuilder::new(vec![Operation::U32add], vec![(0, trace0)])
        .add_to_forest(&mut forest)
        .unwrap();
    let right = BasicBlockNodeBuilder::new(vec![Operation::U32and], vec![(1, trace1)])
        .add_to_forest(&mut forest)
        .unwrap();
    JoinNodeBuilder::new([left, right]).add_to_forest(&mut forest).unwrap();

    // The last node (the join) takes the placeholder's slot.
    forest.nodes.swap_remove(placeholder.to_usize());

    let bytes = forest.to_bytes();
    MastForest::read_from_bytes(&bytes).unwrap();
}
/// A node that references a child id past the end of the node list must be rejected during
/// deserialization with an `InvalidValue` error that mentions the number of nodes.
#[test]
fn mast_forest_serialize_deserialize_with_overflowing_ids_fails() {
    let eqz_block = || BasicBlockNodeBuilder::new(vec![Operation::Eqz], Vec::new());

    // Build a join in a donor forest; its right child is the third node (id 2).
    let mut donor = MastForest::new();
    let left = eqz_block().add_to_forest(&mut donor).unwrap();
    eqz_block().add_to_forest(&mut donor).unwrap();
    let right = eqz_block().add_to_forest(&mut donor).unwrap();
    let donor_join = JoinNodeBuilder::new([left, right]).add_to_forest(&mut donor).unwrap();
    let foreign_join = donor[donor_join].clone();

    // Graft it into a forest that ends up holding only two nodes, so id 2 is out of range.
    let mut forest = MastForest::new();
    let trace0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
    let trace1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
    BasicBlockNodeBuilder::new(vec![Operation::U32add], vec![(0, trace0), (1, trace1)])
        .add_to_forest(&mut forest)
        .unwrap();
    forest.nodes.push(foreign_join).unwrap();

    let bytes = forest.to_bytes();
    assert_matches!(
        MastForest::read_from_bytes(&bytes),
        Err(DeserializationError::InvalidValue(msg)) if msg.contains("number of nodes")
    );
}
/// Each builder that takes child ids must reject a child that does not exist in the target forest
/// with `NodeIdOverflow(child, 2)`, where 2 is the target forest's node count. A valid node
/// can still be added afterwards.
#[test]
fn mast_forest_invalid_node_id() {
    let u32div_block = || BasicBlockNodeBuilder::new(vec![Operation::U32div], Vec::new());

    let mut forest = MastForest::new();
    let first = u32div_block().add_to_forest(&mut forest).unwrap();
    let second = u32div_block().add_to_forest(&mut forest).unwrap();

    // The fourth node of another forest has an id that `forest` (two nodes) does not contain.
    let mut overflow_forest = MastForest::new();
    for _ in 0..3 {
        u32div_block().add_to_forest(&mut overflow_forest).unwrap();
    }
    let overflow = u32div_block().add_to_forest(&mut overflow_forest).unwrap();

    let expected: Result<MastNodeId, MastForestError> =
        Err(MastForestError::NodeIdOverflow(overflow, 2));

    assert_eq!(JoinNodeBuilder::new([overflow, second]).add_to_forest(&mut forest), expected);
    assert_eq!(JoinNodeBuilder::new([first, overflow]).add_to_forest(&mut forest), expected);
    assert_eq!(SplitNodeBuilder::new([overflow, second]).add_to_forest(&mut forest), expected);
    assert_eq!(SplitNodeBuilder::new([first, overflow]).add_to_forest(&mut forest), expected);
    assert_eq!(LoopNodeBuilder::new(overflow).add_to_forest(&mut forest), expected);
    assert_eq!(CallNodeBuilder::new(overflow).add_to_forest(&mut forest), expected);
    assert_eq!(CallNodeBuilder::new_syscall(overflow).add_to_forest(&mut forest), expected);

    // The rejected insertions leave the forest usable.
    JoinNodeBuilder::new([first, second]).add_to_forest(&mut forest).unwrap();
}
/// The forest's advice map survives a serialization round trip unchanged.
#[test]
fn mast_forest_serialize_deserialize_advice_map() {
    let mut forest = MastForest::new();
    let [trace0, trace1] = [0, 1].map(|n| forest.add_decorator(Decorator::Trace(n)).unwrap());

    let add_block = BasicBlockNodeBuilder::new(vec![Operation::U32add], vec![(0, trace0)])
        .add_to_forest(&mut forest)
        .unwrap();
    let and_block = BasicBlockNodeBuilder::new(vec![Operation::U32and], vec![(1, trace1)])
        .add_to_forest(&mut forest)
        .unwrap();
    JoinNodeBuilder::new([add_block, and_block]).add_to_forest(&mut forest).unwrap();

    forest.advice_map_mut().insert(Word::new([ONE; 4]), vec![ONE; 2]);

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    assert_eq!(forest.advice_map, restored.advice_map);
}
/// After a round trip, a basic block's before-enter and after-exit decorators each appear
/// exactly once. They must not also show up in the op-indexed decorator list, which must
/// contain only the single op decorator.
#[test]
fn mast_forest_basic_block_serialization_no_decorator_duplication() {
    let mut forest = MastForest::new();
    let before_enter_deco = forest.add_decorator(Decorator::Trace(1)).unwrap();
    let op_deco = forest.add_decorator(Decorator::Trace(2)).unwrap();
    let after_exit_deco = forest.add_decorator(Decorator::Trace(3)).unwrap();

    let block_id =
        BasicBlockNodeBuilder::new(vec![Operation::Add, Operation::Mul], vec![(0, op_deco)])
            .with_before_enter(vec![before_enter_deco])
            .with_after_exit(vec![after_exit_deco])
            .add_to_forest(&mut forest)
            .unwrap();
    forest.make_root(block_id);

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    let root = restored.procedure_roots()[0];
    let crate::mast::MastNode::Block(block) = &restored[root] else {
        panic!("Expected a block node");
    };

    assert_eq!(
        block.before_enter(&restored),
        &[before_enter_deco],
        "before_enter decorator should appear exactly once"
    );
    assert_eq!(
        block.after_exit(&restored),
        &[after_exit_deco],
        "after_exit decorator should appear exactly once"
    );

    let indexed: Vec<_> = block.indexed_decorator_iter(&restored).collect();
    assert_eq!(indexed.len(), 1, "Should have exactly one op-indexed decorator");
    assert_eq!(indexed[0].1, op_deco, "Op-indexed decorator should be preserved");
    assert!(
        indexed.iter().all(|&(_, id)| id != before_enter_deco),
        "before_enter decorator should not be duplicated in indexed decorators"
    );
    assert!(
        indexed.iter().all(|&(_, id)| id != after_exit_deco),
        "after_exit decorator should not be duplicated in indexed decorators"
    );
}
/// Overwrites the first node-info entry of a serialized forest with a basic-block entry whose
/// low 32 bits are `u32::MAX`, and checks that deserialization fails with `InvalidValue`.
///
/// The byte offset of that entry is derived from the layout this test assumes:
/// - an 8-byte header,
/// - 8-byte node and decorator counts,
/// - the roots vector (8-byte length prefix plus one 4-byte root),
/// - the basic-block data vector (8-byte length prefix plus its payload).
#[test]
fn mast_forest_deserialize_invalid_ops_offset_fails() {
    const HEADER_LEN: usize = 4 + 1 + 3; // magic + flags + version
    const COUNT_LEN: usize = 8;
    const LEN_PREFIX: usize = 8;
    const ROOT_LEN: usize = 4;
    // Value placed in the top bits of the entry to tag it as a basic block.
    const BLOCK_DISCRIMINANT: u64 = 3;

    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::new(vec![Operation::Add, Operation::Mul], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(block_id);
    let serialized = forest.to_bytes();

    // Read past everything up to and including the basic-block data to learn its length.
    let mut reader = SliceReader::new(&serialized);
    let _header: [u8; 8] = reader.read_array().unwrap();
    let _node_count: usize = reader.read().unwrap();
    let _decorator_count: usize = reader.read().unwrap();
    let _roots: Vec<u32> = Deserializable::read_from(&mut reader).unwrap();
    let basic_block_data: Vec<u8> = Deserializable::read_from(&mut reader).unwrap();

    let node_info_offset = HEADER_LEN
        + 2 * COUNT_LEN
        + (LEN_PREFIX + ROOT_LEN)
        + LEN_PREFIX
        + basic_block_data.len();

    // Basic-block tag in the top bits, `u32::MAX` in the low bits.
    let corrupted_value = (BLOCK_DISCRIMINANT << 60) | u64::from(u32::MAX);
    let mut corrupted = serialized;
    corrupted_value.write_into(&mut &mut corrupted[node_info_offset..node_info_offset + 8]);

    assert_matches!(
        MastForest::read_from_bytes(&corrupted),
        Err(DeserializationError::InvalidValue(_))
    );
}
/// A procedure name registered for a root's digest is serialized and restored, and the
/// forest as a whole compares equal after the round trip.
#[test]
fn mast_forest_serialize_deserialize_procedure_names() {
    const NAME: &str = "test_procedure";

    let mut forest = MastForest::new();
    let root = BasicBlockNodeBuilder::new(vec![Operation::Add, Operation::Mul], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let digest = forest[root].digest();
    forest.insert_procedure_name(digest, NAME.into());
    assert_eq!(forest.procedure_name(&digest), Some(NAME));
    assert_eq!(forest.debug_info.num_procedure_names(), 1);

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    assert_eq!(restored.procedure_name(&digest), Some(NAME));
    assert_eq!(restored.debug_info.num_procedure_names(), 1);
    assert_eq!(forest, restored);
}
/// Several procedure names (one per root) are all serialized and restored.
///
/// The steps run in the same phases as a hand-written version: add all blocks, make them
/// roots, compute digests, then register names.
#[test]
fn mast_forest_serialize_deserialize_multiple_procedure_names() {
    let ops = [Operation::Add, Operation::Mul, Operation::U32sub];
    let names = ["proc_add", "proc_mul", "proc_sub"];

    let mut forest = MastForest::new();
    let block_ids: Vec<_> = ops
        .into_iter()
        .map(|op| {
            BasicBlockNodeBuilder::new(vec![op], Vec::new()).add_to_forest(&mut forest).unwrap()
        })
        .collect();
    for &id in &block_ids {
        forest.make_root(id);
    }

    let digests: Vec<_> = block_ids.iter().map(|&id| forest[id].digest()).collect();
    for (&digest, name) in digests.iter().zip(names) {
        forest.insert_procedure_name(digest, name.into());
    }
    assert_eq!(forest.debug_info.num_procedure_names(), 3);

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    for (digest, name) in digests.iter().zip(names) {
        assert_eq!(restored.procedure_name(digest), Some(name));
    }
    assert_eq!(restored.debug_info.num_procedure_names(), 3);
    assert_eq!(forest, restored);
}
/// The `OpBatch` layout of a block containing `Push` immediates is identical after a round trip.
#[test]
fn test_opbatch_roundtrip_preservation() {
    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::new(
        vec![
            Operation::Add,
            Operation::Push(Felt::new(100)),
            Operation::Push(Felt::new(200)),
            Operation::Mul,
        ],
        Vec::new(),
    )
    .add_to_forest(&mut forest)
    .unwrap();

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    assert_eq!(
        forest[block_id].unwrap_basic_block().op_batches(),
        restored[block_id].unwrap_basic_block().op_batches()
    );
}
/// A block of 80 `Push` operations spans several `OpBatch`es. Its batch layout must be
/// identical after a round trip.
#[test]
fn test_multi_batch_roundtrip() {
    let operations: Vec<_> = (0..80).map(Felt::new).map(Operation::Push).collect();

    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::new(operations, Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    let before = forest[block_id].unwrap_basic_block().op_batches();
    assert!(before.len() > 1, "Should have multiple batches");

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    assert_eq!(before, restored[block_id].unwrap_basic_block().op_batches());
}
/// A decorator attached to a `Push` keeps its (padded) index across a round trip. In the
/// restored block, that index still selects the `Push` from the operation stream.
#[test]
fn test_decorator_indices_preserved_with_padding() {
    let mut forest = MastForest::new();
    let trace = forest.add_decorator(Decorator::Trace(42)).unwrap();
    let block_id = BasicBlockNodeBuilder::new(
        vec![Operation::Add, Operation::Mul, Operation::Push(Felt::new(100)), Operation::Drop],
        vec![(2, trace)],
    )
    .add_to_forest(&mut forest)
    .unwrap();

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    let before: Vec<_> =
        forest[block_id].unwrap_basic_block().indexed_decorator_iter(&forest).collect();
    let restored_block = restored[block_id].unwrap_basic_block();
    let after: Vec<_> = restored_block.indexed_decorator_iter(&restored).collect();

    assert_eq!(before, after, "Decorator indices should be preserved");
    assert_eq!(after.len(), 1, "Should have one decorator");

    let (padded_idx, _) = after[0];
    let decorated_op = restored_block.operations().nth(padded_idx).unwrap();
    assert!(
        matches!(decorated_op, Operation::Push(_)),
        "Decorator should point to PUSH operation"
    );
}
/// Rebuilding a block with `BasicBlockNodeBuilder::from_op_batches` from an existing block's
/// batches, decorators, and digest gives the same operations, batches, digest, and decorators.
#[test]
fn test_raw_vs_batched_construction_equivalence() {
    let mut source_forest = MastForest::new();
    let mut rebuilt_forest = MastForest::new();
    let source_deco = source_forest.add_decorator(Decorator::Trace(1)).unwrap();
    // Mirror the decorator in the second forest, presumably so that the decorator ids copied
    // from the source block resolve there as well.
    let _mirrored_deco = rebuilt_forest.add_decorator(Decorator::Trace(1)).unwrap();

    let operations =
        vec![Operation::Add, Operation::Mul, Operation::Push(Felt::new(100)), Operation::Drop];
    let source_id = BasicBlockNodeBuilder::new(operations, vec![(2, source_deco)])
        .add_to_forest(&mut source_forest)
        .unwrap();

    // The source forest must also deserialize cleanly.
    let _restored = MastForest::read_from_bytes(&source_forest.to_bytes()).unwrap();

    let source = source_forest[source_id].unwrap_basic_block();
    let source_decorators: Vec<_> = source.indexed_decorator_iter(&source_forest).collect();
    let rebuilt_id = BasicBlockNodeBuilder::from_op_batches(
        source.op_batches().to_vec(),
        source_decorators,
        source.digest(),
    )
    .add_to_forest(&mut rebuilt_forest)
    .unwrap();
    let rebuilt = rebuilt_forest[rebuilt_id].unwrap_basic_block();

    assert_eq!(
        source.operations().collect::<Vec<_>>(),
        rebuilt.operations().collect::<Vec<_>>(),
        "Operations should match"
    );
    assert_eq!(source.op_batches(), rebuilt.op_batches(), "OpBatch structures should match");
    assert_eq!(source.digest(), rebuilt.digest(), "Digests should match");
    assert_eq!(
        source.indexed_decorator_iter(&source_forest).collect::<Vec<_>>(),
        rebuilt.indexed_decorator_iter(&rebuilt_forest).collect::<Vec<_>>(),
        "Decorators should match"
    );
}
/// A block's digest is the same before serialization and after deserialization.
#[test]
fn test_raw_batched_digest_equivalence() {
    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::new(
        vec![
            Operation::Add,
            Operation::Mul,
            Operation::Push(Felt::new(42)),
            Operation::Drop,
            Operation::Dup0,
        ],
        Vec::new(),
    )
    .add_to_forest(&mut forest)
    .unwrap();

    let digest_in = |f: &MastForest| f[block_id].unwrap_basic_block().digest();
    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();

    assert_eq!(
        digest_in(&forest),
        digest_in(&restored),
        "Digests from Raw and Batched paths should match"
    );
}
/// Passing a block's own `OpBatch`es back into `BasicBlockNodeBuilder::from_op_batches`
/// yields exactly the same batches.
#[test]
fn test_batched_construction_preserves_structure() {
    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::new(
        vec![
            Operation::Add,
            Operation::Mul,
            Operation::Push(Felt::new(100)),
            Operation::Push(Felt::new(200)),
        ],
        Vec::new(),
    )
    .add_to_forest(&mut forest)
    .unwrap();
    let template = forest[block_id].unwrap_basic_block();

    let mut rebuilt_forest = MastForest::new();
    let rebuilt_id = BasicBlockNodeBuilder::from_op_batches(
        template.op_batches().to_vec(),
        Vec::new(),
        template.digest(),
    )
    .add_to_forest(&mut rebuilt_forest)
    .unwrap();

    assert_eq!(
        template.op_batches(),
        rebuilt_forest[rebuilt_id].unwrap_basic_block().op_batches(),
        "OpBatch structure should be exactly preserved"
    );
}
/// A full serialization starts with the `MAST` magic, flags byte `0x00`, and version `[0, 0, 2]`.
#[test]
fn test_header_backward_compatible() {
    let mut forest = MastForest::new();
    let root = BasicBlockNodeBuilder::new(vec![Operation::Add], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let bytes = forest.to_bytes();
    let (magic, rest) = bytes.split_at(4);
    let (flags, version) = (rest[0], &rest[1..4]);

    assert_eq!(magic, b"MAST", "Magic should be MAST");
    assert_eq!(flags, 0x00, "Flags should be 0x00 for full serialization");
    assert_eq!(version, &[0, 0, 2], "Version should be [0, 0, 2]");
}
/// For a forest carrying a decorator and a procedure name, the stripped encoding is strictly
/// shorter than the full one.
#[test]
fn test_stripped_serialization_smaller_than_full() {
    let mut forest = MastForest::new();
    let trace = forest.add_decorator(Decorator::Trace(42)).unwrap();
    let root = BasicBlockNodeBuilder::new(
        vec![Operation::Add, Operation::Mul, Operation::Drop],
        vec![(0, trace)],
    )
    .add_to_forest(&mut forest)
    .unwrap();
    forest.make_root(root);
    let digest = forest[root].digest();
    forest.insert_procedure_name(digest, "test_proc".into());

    let full_len = forest.to_bytes().len();
    let mut stripped: Vec<u8> = Vec::new();
    forest.write_stripped(&mut stripped);
    let stripped_len = stripped.len();

    assert!(
        stripped_len < full_len,
        "Stripped ({} bytes) should be smaller than full ({} bytes)",
        stripped_len,
        full_len
    );
}
/// Reading back a stripped encoding keeps the node and root counts. It restores no debug
/// info, no decorators, and no procedure names.
#[test]
fn test_stripped_serialization_roundtrip() {
    let mut forest = MastForest::new();
    let trace = forest.add_decorator(Decorator::Trace(42)).unwrap();
    let root = BasicBlockNodeBuilder::new(
        vec![Operation::Add, Operation::Mul, Operation::Drop],
        vec![(0, trace)],
    )
    .add_to_forest(&mut forest)
    .unwrap();
    forest.make_root(root);
    let digest = forest[root].digest();
    forest.insert_procedure_name(digest, "test_proc".into());
    let _ = forest.register_error("test error".into());

    let mut stripped: Vec<u8> = Vec::new();
    forest.write_stripped(&mut stripped);
    let restored = MastForest::read_from_bytes(&stripped).unwrap();

    // Structure survives...
    assert_eq!(forest.num_nodes(), restored.num_nodes());
    assert_eq!(forest.procedure_roots().len(), restored.procedure_roots().len());
    // ...while debug data does not.
    assert!(restored.debug_info.is_empty(), "DebugInfo should be empty after stripped roundtrip");
    assert_eq!(restored.decorators().len(), 0);
    assert_eq!(restored.procedure_name(&digest), None);
}
/// A stripped serialization starts with the `MAST` magic, flags byte `0x01`, and version
/// `[0, 0, 2]`.
#[test]
fn test_stripped_header_flags() {
    let mut forest = MastForest::new();
    let root = BasicBlockNodeBuilder::new(vec![Operation::Add], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let mut bytes: Vec<u8> = Vec::new();
    forest.write_stripped(&mut bytes);

    let header = &bytes[..8];
    assert_eq!(&header[..4], b"MAST", "Magic should be MAST");
    assert_eq!(header[4], 0x01, "Flags should be 0x01 for stripped serialization");
    assert_eq!(&header[5..], &[0, 0, 2], "Version should be [0, 0, 2]");
}
/// Every node's digest, in node order, is unchanged after a stripped round trip.
#[test]
fn test_stripped_preserves_digests() {
    let node_digests =
        |f: &MastForest| f.nodes().iter().map(|node| node.digest()).collect::<Vec<_>>();

    let mut forest = MastForest::new();
    let trace = forest.add_decorator(Decorator::Trace(1)).unwrap();
    let add_block = BasicBlockNodeBuilder::new(vec![Operation::Add], vec![(0, trace)])
        .add_to_forest(&mut forest)
        .unwrap();
    let mul_block = BasicBlockNodeBuilder::new(vec![Operation::Mul], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    let join = JoinNodeBuilder::new([add_block, mul_block]).add_to_forest(&mut forest).unwrap();
    forest.make_root(join);

    let before = node_digests(&forest);
    let mut stripped: Vec<u8> = Vec::new();
    forest.write_stripped(&mut stripped);
    let restored = MastForest::read_from_bytes(&stripped).unwrap();

    assert_eq!(before, node_digests(&restored), "Node digests should be preserved");
}
/// Deserialization rejects a header whose flags byte has an unknown bit (`0x02`) set. The
/// resulting `InvalidValue` message must mention "reserved" or "flags".
#[test]
fn test_deserialize_rejects_unknown_flags() {
    const FLAGS_OFFSET: usize = 4;
    const UNKNOWN_FLAGS: u8 = 0x02;

    let mut forest = MastForest::new();
    let root = BasicBlockNodeBuilder::new(vec![Operation::Add], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let mut bytes = forest.to_bytes();
    bytes[FLAGS_OFFSET] = UNKNOWN_FLAGS;

    assert_matches!(
        MastForest::read_from_bytes(&bytes),
        Err(DeserializationError::InvalidValue(msg)) if msg.contains("reserved") || msg.contains("flags")
    );
}
mod proptests {
use proptest::{prelude::*, strategy::Just};
use super::*;
use crate::{
mast::{BasicBlockNodeBuilder, MastForest, MastNode, arbitrary::MastForestParams},
operations::Decorator,
};
proptest! {
#[test]
fn proptest_mast_forest_roundtrip(
forest in any_with::<MastForest>(MastForestParams {
decorators: 5,
blocks: 1..=5,
max_joins: 3,
max_splits: 2,
max_loops: 2,
max_calls: 2,
max_syscalls: 0, // Avoid syscalls in roundtrip tests
max_externals: 1,
max_dyns: 1,
})
) {
let serialized = forest.to_bytes();
let deserialized = MastForest::read_from_bytes(&serialized)
.expect("Deserialization should succeed");
prop_assert_eq!(
forest.num_nodes(),
deserialized.num_nodes(),
"Node count should match"
);
for (idx, original) in forest.nodes().iter().enumerate() {
let node_id = crate::mast::MastNodeId::new_unchecked(idx as u32);
let deserialized_node = &deserialized[node_id];
prop_assert_eq!(
original.digest(),
deserialized_node.digest(),
"Node {:?} digest mismatch", node_id
);
if let MastNode::Block(original_block) = original
&& let MastNode::Block(deserialized_block) = deserialized_node
{
prop_assert_eq!(
original_block.op_batches(),
deserialized_block.op_batches(),
"Node {:?}: OpBatch mismatch", node_id
);
let orig_decorators: Vec<_> =
original_block.indexed_decorator_iter(&forest).collect();
let deser_decorators: Vec<_> =
deserialized_block.indexed_decorator_iter(&deserialized).collect();
prop_assert_eq!(
orig_decorators.len(),
deser_decorators.len(),
"Node {:?}: Decorator count mismatch", node_id
);
for ((orig_idx, orig_dec_id), (deser_idx, deser_dec_id)) in
orig_decorators.iter().zip(&deser_decorators)
{
prop_assert_eq!(orig_idx, deser_idx, "Node {:?}: Decorator index mismatch", node_id);
prop_assert_eq!(
forest.decorator_by_id(*orig_dec_id),
deserialized.decorator_by_id(*deser_dec_id),
"Node {:?}: Decorator content mismatch", node_id
);
}
}
}
}
#[test]
fn proptest_multi_batch_roundtrip(
ops in prop::collection::vec(
prop::sample::select(vec![
Operation::Add,
Operation::Mul,
Operation::Push(crate::Felt::new(42)),
Operation::Drop,
Operation::Dup0,
Operation::Swap,
]),
73..=150 )
) {
let mut forest = MastForest::new();
let block_id = BasicBlockNodeBuilder::new(ops, Vec::new())
.add_to_forest(&mut forest)
.unwrap();
let original_block = forest[block_id].unwrap_basic_block();
let original_batches = original_block.op_batches();
prop_assume!(original_batches.len() > 1, "Need multiple batches for this test");
let serialized = forest.to_bytes();
let deserialized_forest = MastForest::read_from_bytes(&serialized)
.expect("Deserialization should succeed");
let deserialized_block = deserialized_forest[block_id].unwrap_basic_block();
let deserialized_batches = deserialized_block.op_batches();
prop_assert_eq!(
original_batches.len(),
deserialized_batches.len(),
"Batch count should match"
);
for (i, (orig_batch, deser_batch)) in
original_batches.iter().zip(deserialized_batches).enumerate()
{
prop_assert_eq!(
orig_batch.ops(),
deser_batch.ops(),
"Batch {}: Operations should match exactly", i
);
prop_assert_eq!(
orig_batch.indptr(),
deser_batch.indptr(),
"Batch {}: Indptr arrays should match exactly", i
);
prop_assert_eq!(
orig_batch.padding(),
deser_batch.padding(),
"Batch {}: Padding metadata should match exactly", i
);
prop_assert_eq!(
orig_batch.groups(),
deser_batch.groups(),
"Batch {}: Groups arrays should match exactly", i
);
prop_assert_eq!(
orig_batch.num_groups(),
deser_batch.num_groups(),
"Batch {}: num_groups should match exactly", i
);
}
}
/// Property: decorator positions and decorator contents of a basic block survive a full
/// serialize/deserialize roundtrip unchanged.
#[test]
fn proptest_decorator_indices_roundtrip(
    (ops, decorator_indices) in (
        prop::collection::vec(
            prop::sample::select(vec![
                Operation::Add,
                Operation::Mul,
                Operation::Push(Felt::new(99)),
                Operation::Drop,
                Operation::Dup0,
            ]),
            10..=50
        )
    ).prop_flat_map(|ops| {
        let ops_len = ops.len();
        (
            Just(ops),
            prop::collection::vec((0..ops_len, 0..5_u32), 1..=10)
        )
    })
) {
    let mut forest = MastForest::new();

    // Register Trace(1)..=Trace(5); the strategy picks among them by index 0..5.
    let decorator_ids: Vec<_> = (1..=5)
        .map(|tag| forest.add_decorator(Decorator::Trace(tag)).unwrap())
        .collect();

    // Resolve (op position, decorator choice) pairs, then keep one decorator per op position:
    // a stable sort followed by dedup retains the first generated entry for each position.
    let mut decorators: Vec<(usize, _)> = decorator_indices
        .into_iter()
        .map(|(op_pos, choice)| (op_pos, decorator_ids[choice as usize]))
        .collect();
    decorators.sort_by_key(|(op_pos, _)| *op_pos);
    decorators.dedup_by_key(|(op_pos, _)| *op_pos);

    let block_id = BasicBlockNodeBuilder::new(ops, decorators)
        .add_to_forest(&mut forest)
        .unwrap();

    let restored_forest = MastForest::read_from_bytes(&forest.to_bytes())
        .expect("Deserialization should succeed");

    let source_block = forest[block_id].unwrap_basic_block();
    let restored_block = restored_forest[block_id].unwrap_basic_block();

    let before: Vec<_> = source_block.indexed_decorator_iter(&forest).collect();
    let after: Vec<_> = restored_block.indexed_decorator_iter(&restored_forest).collect();

    prop_assert_eq!(before.len(), after.len(), "Decorator count should match");

    for ((orig_idx, orig_dec_id), (deser_idx, deser_dec_id)) in before.iter().zip(&after) {
        prop_assert_eq!(orig_idx, deser_idx, "Decorator indices should match (padded form)");
        prop_assert_eq!(
            forest.decorator_by_id(*orig_dec_id),
            restored_forest.decorator_by_id(*deser_dec_id),
            "Decorator content should match"
        );
    }
}
/// Property: a forest written in stripped form reads back with the same nodes (matching
/// digests at every index) and with no debug info.
#[test]
fn proptest_stripped_roundtrip(
    forest in any_with::<MastForest>(MastForestParams {
        decorators: 10,
        blocks: 1..=5,
        max_joins: 3,
        max_splits: 2,
        max_loops: 2,
        max_calls: 2,
        max_syscalls: 0,
        max_externals: 1,
        max_dyns: 1,
    })
) {
    let mut stripped_bytes = Vec::new();
    forest.write_stripped(&mut stripped_bytes);

    let restored = MastForest::read_from_bytes(&stripped_bytes)
        .expect("Stripped deserialization should succeed");

    prop_assert_eq!(forest.num_nodes(), restored.num_nodes(), "Node count should match");

    // Every node must keep its digest at the same index after stripping.
    for (position, node) in forest.nodes().iter().enumerate() {
        let node_id = MastNodeId::new_unchecked(position as u32);
        prop_assert_eq!(
            node.digest(),
            restored[node_id].digest(),
            "Node {:?} digest mismatch", node_id
        );
    }

    prop_assert!(
        restored.debug_info.is_empty(),
        "DebugInfo should be empty after stripped roundtrip"
    );
}
}
}
/// A forest with a single undecorated root block roundtrips with zero decorators.
#[test]
fn test_debuginfo_serialization_empty() {
    let mut forest = MastForest::new();
    let root = BasicBlockNodeBuilder::new(vec![Operation::Noop; 4], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();

    assert_eq!(forest.num_nodes(), restored.num_nodes());
    assert_eq!(forest.decorators().len(), 0);
    assert_eq!(restored.decorators().len(), 0);
}
/// Ten 4-op blocks where only every fifth block (indices 0 and 5) has a decorator on op 0.
/// The per-op decorator lookup must agree before and after a roundtrip.
#[test]
fn test_debuginfo_serialization_sparse() {
    let mut forest = MastForest::new();
    for i in 0..10 {
        let decorators = if i % 5 == 0 {
            vec![(0, forest.add_decorator(Decorator::Trace(i)).unwrap())]
        } else {
            Vec::new()
        };
        BasicBlockNodeBuilder::new(vec![Operation::Noop; 4], decorators)
            .add_to_forest(&mut forest)
            .unwrap();
    }

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();

    assert_eq!(forest.decorators().len(), 2);
    assert_eq!(restored.decorators().len(), 2);

    for i in 0..10 {
        let node_id = MastNodeId::new_unchecked(i);
        assert_eq!(
            forest.decorator_indices_for_op(node_id, 0),
            restored.decorator_indices_for_op(node_id, 0),
            "Decorators at node {} should match",
            i
        );
    }
}
/// Ten 4-op blocks where the first eight have a decorator on op 0 and the last two have none.
/// Both the lookups and the per-node decorator counts must survive a roundtrip.
#[test]
fn test_debuginfo_serialization_dense() {
    // Nodes with index below this bound are decorated.
    const DECORATED: u32 = 8;

    let mut forest = MastForest::new();
    for i in 0..10 {
        let decorators = if i < DECORATED {
            vec![(0, forest.add_decorator(Decorator::Trace(i)).unwrap())]
        } else {
            Vec::new()
        };
        BasicBlockNodeBuilder::new(vec![Operation::Noop; 4], decorators)
            .add_to_forest(&mut forest)
            .unwrap();
    }

    let restored = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();

    assert_eq!(forest.decorators().len(), 8);
    assert_eq!(restored.decorators().len(), 8);

    for i in 0..10 {
        let node_id = MastNodeId::new_unchecked(i);
        let before = forest.decorator_indices_for_op(node_id, 0);
        let after = restored.decorator_indices_for_op(node_id, 0);
        assert_eq!(before, after, "Decorators at node {} should match", i);

        if i >= DECORATED {
            assert_eq!(before.len(), 0, "Node {} should have no decorators", i);
            assert_eq!(
                after.len(),
                0,
                "Node {} should have no decorators after deserialization",
                i
            );
        } else {
            assert_eq!(before.len(), 1, "Node {} should have 1 decorator", i);
            assert_eq!(
                after.len(),
                1,
                "Node {} should have 1 decorator after deserialization",
                i
            );
        }
    }
}
/// A well-formed forest (a join over two one-op blocks) validates through the untrusted path.
/// The validated result equals the original forest.
#[test]
fn test_untrusted_forest_valid_roundtrip() {
    let mut forest = MastForest::new();
    let [add_block, mul_block] = [Operation::Add, Operation::Mul].map(|op| {
        BasicBlockNodeBuilder::new(vec![op], Vec::new())
            .add_to_forest(&mut forest)
            .unwrap()
    });
    let join = JoinNodeBuilder::new([add_block, mul_block]).add_to_forest(&mut forest).unwrap();
    forest.make_root(join);

    let validated = UntrustedMastForest::read_from_bytes(&forest.to_bytes())
        .unwrap()
        .validate()
        .unwrap();

    assert_eq!(forest, validated);
}
/// Validation must reject a forest in which a node references children stored after it.
#[test]
fn test_untrusted_forest_detects_forward_reference() {
    let mut forest = MastForest::new();
    // Node 0 is a placeholder; nodes 1 and 2 become the children of the join at index 3.
    let mut add_block = |op| {
        BasicBlockNodeBuilder::new(vec![op], Vec::new())
            .add_to_forest(&mut forest)
            .unwrap()
    };
    let placeholder = add_block(Operation::U32div);
    let left = add_block(Operation::U32add);
    let right = add_block(Operation::U32and);
    JoinNodeBuilder::new([left, right]).add_to_forest(&mut forest).unwrap();

    // `swap_remove` moves the last node (the join) into the placeholder's slot 0. The join
    // then points at children with higher indices than its own.
    forest.nodes.swap_remove(placeholder.to_usize());

    let untrusted = UntrustedMastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    assert_matches!(untrusted.validate(), Err(MastForestError::ForwardReference(_, _)));
}
/// Corrupting the stored digest of a basic block must make untrusted validation report a
/// `HashMismatch`, because the digest no longer agrees with the block's contents.
#[test]
fn test_untrusted_forest_detects_hash_mismatch() {
    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::new(vec![Operation::Add], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(block_id);
    let bytes = forest.to_bytes();

    // Find the digest by walking the real serialized layout. The `usize` counts and lengths
    // that precede it are read through a `ByteReader` and need not be 8 bytes wide, so
    // summing assumed fixed field widths does not reliably point at the digest.
    let (_, digest_offset) = locate_single_block_indptr_and_digest_offsets(&bytes);

    // Before corrupting anything, confirm the offset really addresses this node's digest.
    let mut digest_reader = SliceReader::new(&bytes[digest_offset..]);
    let stored_digest: Word = digest_reader.read().unwrap();
    assert_eq!(
        stored_digest,
        forest[block_id].digest(),
        "digest offset computation is wrong"
    );

    // Flip one byte of the stored digest. The bytes still deserialize, but validation
    // must now detect the mismatch.
    let mut corrupted = bytes;
    corrupted[digest_offset] ^= 0xff;
    let untrusted = UntrustedMastForest::read_from_bytes(&corrupted).unwrap();
    let result = untrusted.validate();
    assert_matches!(result, Err(MastForestError::HashMismatch { .. }));
}
/// Packs the opcodes of `ops` into a single field element. Op `k` occupies bits
/// `k * Operation::OP_BITS ..`, with the first op in the least-significant position.
fn build_group(ops: &[Operation]) -> Felt {
    let packed = ops.iter().enumerate().fold(0u64, |acc, (slot, op)| {
        acc | ((op.op_code() as u64) << (Operation::OP_BITS * slot))
    });
    Felt::new(packed)
}
fn make_batch(num_groups: usize, op: Operation) -> OpBatch {
let ops: Vec<Operation> = (0..num_groups).map(|_| op).collect();
let mut indptr = [0usize; OP_BATCH_SIZE + 1];
for i in 0..num_groups {
indptr[i + 1] = i + 1;
}
for i in (num_groups + 1)..=OP_BATCH_SIZE {
indptr[i] = indptr[i - 1];
}
let mut padding = [false; OP_BATCH_SIZE];
for pad in padding.iter_mut().skip(num_groups) {
*pad = true;
}
let mut groups = [Felt::new(0); OP_BATCH_SIZE];
for group in groups.iter_mut().take(num_groups) {
*group = build_group(&[op]);
}
OpBatch::new_from_parts(ops, indptr, padding, groups, num_groups)
}
/// Serializes a one-block forest `[Push(push_imm), Noop, Add]`, then tampers with it.
///
/// 1. The batch's 4-byte packed indptr is overwritten: `0x03` becomes `0x12`.
/// 2. If the tampered bytes still deserialize, the stored digest is replaced with one
///    recomputed from the decoded op groups. This "re-seals" the block so it is not trivially
///    rejected for a hash mismatch.
fn build_malicious_single_block_forest_bytes(push_imm: Felt) -> Vec<u8> {
    let mut forest = MastForest::new();
    let ops = vec![Operation::Push(push_imm), Operation::Noop, Operation::Add];
    let root = BasicBlockNodeBuilder::new(ops, Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let mut bytes = forest.to_bytes();
    let (indptr_offset, digest_offset) = locate_single_block_indptr_and_digest_offsets(&bytes);
    let indptr_span = indptr_offset..indptr_offset + 4;

    // Guard against a wrong offset: the untampered block must carry the expected encoding.
    assert_eq!(
        &bytes[indptr_span.clone()],
        &[0x03, 0x00, 0x00, 0x00],
        "unexpected original packed indptr (offset computation likely wrong)"
    );
    bytes[indptr_span].copy_from_slice(&[0x12, 0x00, 0x00, 0x00]);

    match compute_single_block_digest_from_decoded_groups(&bytes) {
        Some(resealed) => {
            bytes[digest_offset..digest_offset + 32].copy_from_slice(&resealed.to_bytes());
        },
        None => {},
    }
    bytes
}
/// Minimal [`ByteReader`] over a byte slice that also exposes its current read position.
/// Tests use the position to compute absolute offsets of fields inside serialized data.
struct OffsetReader<'a> {
    /// The bytes being read.
    source: &'a [u8],
    /// Index of the next unread byte. Every advance is preceded by `check_eor`, so it never
    /// exceeds `source.len()`.
    pos: usize,
}

impl<'a> OffsetReader<'a> {
    /// Creates a reader positioned at the start of `source`.
    fn new(source: &'a [u8]) -> Self {
        Self { source, pos: 0 }
    }

    /// Returns the number of bytes consumed so far (the offset of the next read).
    fn position(&self) -> usize {
        self.pos
    }
}

impl ByteReader for OffsetReader<'_> {
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.check_eor(1)?;
        let result = self.source[self.pos];
        self.pos += 1;
        Ok(result)
    }

    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        self.check_eor(1)?;
        Ok(self.source[self.pos])
    }

    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        self.check_eor(len)?;
        let result = &self.source[self.pos..self.pos + len];
        self.pos += len;
        Ok(result)
    }

    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        self.check_eor(N)?;
        let mut result = [0_u8; N];
        result.copy_from_slice(&self.source[self.pos..self.pos + N]);
        self.pos += N;
        Ok(result)
    }

    /// Returns `UnexpectedEOF` unless at least `num_bytes` bytes remain.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        // A corrupted length prefix can be arbitrarily large, and a plain `pos + num_bytes`
        // would overflow (panicking in debug builds). Checked arithmetic turns that case into
        // an ordinary EOF error.
        match self.pos.checked_add(num_bytes) {
            Some(end) if end <= self.source.len() => Ok(()),
            _ => Err(DeserializationError::UnexpectedEOF),
        }
    }

    fn has_more_bytes(&self) -> bool {
        self.pos < self.source.len()
    }
}
/// Walks the serialized bytes of a forest and returns `(indptr_offset, digest_offset)`. The
/// forest must have exactly one node: a basic block with a single op batch.
///
/// - `indptr_offset`: absolute offset of the batch field that follows the op list and batch
///   count (the packed indptr).
/// - `digest_offset`: absolute offset of the node's stored digest, which follows its 8-byte
///   node-type word.
///
/// Panics (via asserts/unwraps) if the bytes do not have that shape.
fn locate_single_block_indptr_and_digest_offsets(bytes: &[u8]) -> (usize, usize) {
    // The node-type word keeps the node discriminant in its top 4 bits. The low 60 bits are
    // a payload; for a block node it is the block's offset inside the basic-block section.
    const NODE_TYPE_LEN: usize = 8;
    const DISCRIMINANT_SHIFT: u64 = 60;
    const PAYLOAD_MASK: u64 = (1 << DISCRIMINANT_SHIFT) - 1;
    const BLOCK_DISCRIMINANT: u8 = 3;

    // Forest header and counts, read with a position-tracking reader.
    let mut reader = OffsetReader::new(bytes);
    let _header: [u8; 8] = reader.read_array().unwrap();
    let node_count: usize = reader.read().unwrap();
    assert_eq!(node_count, 1);
    let _decorator_count: usize = reader.read().unwrap();
    let _roots: Vec<u32> = Deserializable::read_from(&mut reader).unwrap();
    let bb_data_len: usize = reader.read().unwrap();
    let bb_payload_start = reader.position();
    let bb_payload_end = bb_payload_start + bb_data_len;

    // Node-info entries begin right after the basic-block section; ours is the only one.
    let node_info_start = bb_payload_end;
    let node_type_bytes: [u8; NODE_TYPE_LEN] = bytes
        [node_info_start..node_info_start + NODE_TYPE_LEN]
        .try_into()
        .expect("node type bytes");
    let node_type = u64::from_le_bytes(node_type_bytes);
    assert_eq!(
        (node_type >> DISCRIMINANT_SHIFT) as u8,
        BLOCK_DISCRIMINANT,
        "expected a Block node"
    );
    let payload = node_type & PAYLOAD_MASK;
    assert!(payload <= u64::from(u32::MAX), "Block ops_offset payload must fit in u32");
    let digest_offset = node_info_start + NODE_TYPE_LEN;

    // Skip the block's op list and batch count; the next field is the batch's packed indptr.
    let block_start = bb_payload_start + payload as usize;
    assert!(block_start < bb_payload_end);
    let mut block_reader = OffsetReader::new(&bytes[block_start..bb_payload_end]);
    let _ops: Vec<Operation> = Deserializable::read_from(&mut block_reader).unwrap();
    let num_batches: u32 = block_reader.read().unwrap();
    assert_eq!(num_batches, 1);

    (block_start + block_reader.position(), digest_offset)
}
fn compute_single_block_digest_from_decoded_groups(bytes: &[u8]) -> Option<Word> {
use crate::chiplets::hasher;
let forest = MastForest::read_from_bytes(bytes).ok()?;
let block = forest[MastNodeId::new_unchecked(0)].unwrap_basic_block().clone();
let op_groups: Vec<Felt> =
block.op_batches().iter().flat_map(|batch| *batch.groups()).collect();
Some(hasher::hash_elements(&op_groups))
}
/// A non-final batch that is only partially filled (4 of `OP_BATCH_SIZE` groups) must be
/// rejected as invalid batch padding, even when the digest is consistent with the groups.
#[test]
fn test_untrusted_forest_rejects_non_full_prefix_batch() {
    let batches = vec![make_batch(4, Operation::Add), make_batch(2, Operation::Mul)];
    let all_groups: Vec<Felt> =
        batches.iter().flat_map(|b| b.groups().iter().copied()).collect();
    let digest = hasher::hash_elements(&all_groups);

    let mut forest = MastForest::new();
    let root = BasicBlockNodeBuilder::from_op_batches(batches, Vec::new(), digest)
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let untrusted = UntrustedMastForest::read_from_bytes(&forest.to_bytes()).unwrap();
    assert_matches!(untrusted.validate(), Err(MastForestError::InvalidBatchPadding(_, _)));
}
/// A completely full first batch followed by a partially filled final batch is a valid layout
/// and must pass untrusted validation.
#[test]
fn test_untrusted_forest_accepts_full_prefix_batch() {
    let op_batches = vec![make_batch(OP_BATCH_SIZE, Operation::Add), make_batch(4, Operation::Mul)];
    let op_groups: Vec<Felt> =
        op_batches.iter().flat_map(|batch| batch.groups()).copied().collect();
    let digest = hasher::hash_elements(&op_groups);
    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::from_op_batches(op_batches, Vec::new(), digest)
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(block_id);
    let bytes = forest.to_bytes();
    let untrusted = UntrustedMastForest::read_from_bytes(&bytes).unwrap();
    // Report the actual validation error on failure; a bare `is_ok()` assertion hides it.
    if let Err(err) = untrusted.validate() {
        panic!("full prefix batches should validate, but validation failed: {err:?}");
    }
}
// Regression test: tampering with a basic block's packed indptr must not let two blocks with
// different `Push` immediates both be accepted with identical digests.
//
// `build_malicious_single_block_forest_bytes(imm)` serializes `[Push(imm), Noop, Add]` and
// rewrites the packed indptr from 0x03 to 0x12. When the tampered bytes still decode, it also
// overwrites the stored digest with one recomputed from the decoded op groups.
//
// Accepted outcomes:
//   1. deserialization rejects the bytes with `InvalidValue` mentioning "push immediate";
//   2. `validate()` rejects them with `InvalidBatchPadding` mentioning "push immediate";
//   3. both forests validate, in which case each block must still start with its own
//      `Push` immediate and the two digests must differ.
#[test]
fn test_untrusted_forest_rejects_basic_block_indptr_that_breaks_push_immediate_commitment() {
let imm_a = Felt::new(0xdead_beef_dead_beef);
let imm_b = Felt::new(0xfeed_face_feed_face);
let bytes_a = build_malicious_single_block_forest_bytes(imm_a);
let bytes_b = build_malicious_single_block_forest_bytes(imm_b);
// Deserialize + validate A. A "push immediate" rejection at deserialization time is an
// accepted outcome and ends the test early (B is then not examined).
let validated_a = match UntrustedMastForest::read_from_bytes(&bytes_a) {
Ok(untrusted) => untrusted.validate(),
Err(DeserializationError::InvalidValue(msg)) => {
assert!(msg.contains("push immediate"));
return;
},
Err(err) => panic!("unexpected deserialization error: {err:?}"),
};
// Same for B.
let validated_b = match UntrustedMastForest::read_from_bytes(&bytes_b) {
Ok(untrusted) => untrusted.validate(),
Err(DeserializationError::InvalidValue(msg)) => {
assert!(msg.contains("push immediate"));
return;
},
Err(err) => panic!("unexpected deserialization error: {err:?}"),
};
let (forest_a, forest_b) = match (validated_a, validated_b) {
(Ok(forest_a), Ok(forest_b)) => (forest_a, forest_b),
// At least one side failed validation. Every failure must be the push-immediate
// padding error; any other error fails the test.
// NOTE(review): a side that validated `Ok` is tolerated here and no digest comparison
// happens in this arm. Confirm that a one-sided acceptance is really acceptable.
(validated_a, validated_b) => {
match validated_a {
Err(MastForestError::InvalidBatchPadding(_, msg)) => {
assert!(msg.contains("push immediate"));
},
Err(err) => panic!("unexpected validation error: {err:?}"),
Ok(_) => {},
}
match validated_b {
Err(MastForestError::InvalidBatchPadding(_, msg)) => {
assert!(msg.contains("push immediate"));
},
Err(err) => panic!("unexpected validation error: {err:?}"),
Ok(_) => {},
}
return;
},
};
// Both forests were accepted. Check that each decoded block still begins with its own
// `Push` immediate.
let block_a = forest_a[MastNodeId::new_unchecked(0)].unwrap_basic_block().clone();
let block_b = forest_b[MastNodeId::new_unchecked(0)].unwrap_basic_block().clone();
let ops_a: Vec<Operation> = block_a.operations().copied().collect();
let ops_b: Vec<Operation> = block_b.operations().copied().collect();
assert!(
matches!(ops_a.as_slice(), [Operation::Push(v), ..] if *v == imm_a),
"unexpected ops in forest_a: {ops_a:?}"
);
assert!(
matches!(ops_b.as_slice(), [Operation::Push(v), ..] if *v == imm_b),
"unexpected ops in forest_b: {ops_b:?}"
);
// Different immediates must yield different digests. Equal digests would mean the
// tampered layout lets the immediate escape the block's commitment.
assert_ne!(
block_a.digest(),
block_b.digest(),
"BUG: UntrustedMastForest::validate() accepted two basic blocks with different Push immediates \
but identical digests.\n\
digest={:?}\n\
ops_a={ops_a:?}\n\
ops_b={ops_b:?}\n\
groups_a={:?}\n\
groups_b={:?}\n",
block_a.digest(),
block_a.op_batches()[0].groups(),
block_b.op_batches()[0].groups(),
);
}
/// A forest with one root of each node kind validates through the untrusted path. The kinds
/// are join, split, loop, call, syscall, dyn, dyncall and external. The validated result
/// equals the original forest.
#[test]
fn test_untrusted_forest_validates_all_node_types() {
    let mut forest = MastForest::new();
    let add_block = BasicBlockNodeBuilder::new(vec![Operation::Add], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    let mul_block = BasicBlockNodeBuilder::new(vec![Operation::Mul], Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();

    // Array elements are evaluated left to right, so nodes are added in this order.
    let roots = [
        JoinNodeBuilder::new([add_block, mul_block]).add_to_forest(&mut forest).unwrap(),
        SplitNodeBuilder::new([add_block, mul_block]).add_to_forest(&mut forest).unwrap(),
        LoopNodeBuilder::new(add_block).add_to_forest(&mut forest).unwrap(),
        CallNodeBuilder::new(add_block).add_to_forest(&mut forest).unwrap(),
        CallNodeBuilder::new_syscall(add_block).add_to_forest(&mut forest).unwrap(),
        DynNodeBuilder::new_dyn().add_to_forest(&mut forest).unwrap(),
        DynNodeBuilder::new_dyncall().add_to_forest(&mut forest).unwrap(),
        ExternalNodeBuilder::new(Word::default()).add_to_forest(&mut forest).unwrap(),
    ];
    for root in roots {
        forest.make_root(root);
    }

    let validated = UntrustedMastForest::read_from_bytes(&forest.to_bytes())
        .unwrap()
        .validate()
        .unwrap();
    assert_eq!(forest, validated);
}
/// A decorated forest written in stripped form validates through the untrusted path, keeps
/// its node count, and carries no debug info.
#[test]
fn test_untrusted_forest_validates_stripped() {
    let mut forest = MastForest::new();
    let trace = forest.add_decorator(Decorator::Trace(42)).unwrap();
    let root = BasicBlockNodeBuilder::new(vec![Operation::Add], vec![(0, trace)])
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(root);

    let mut stripped_bytes = Vec::new();
    forest.write_stripped(&mut stripped_bytes);

    let validated = UntrustedMastForest::read_from_bytes(&stripped_bytes)
        .unwrap()
        .validate()
        .unwrap();
    assert_eq!(forest.num_nodes(), validated.num_nodes());
    assert!(validated.debug_info.is_empty());
}
/// Deserialization must reject a header whose node count exceeds `MastForest::MAX_NODES`.
///
/// The hand-built input ends right after the counts. Getting an "exceeds maximum" error,
/// rather than an EOF, therefore shows the limit is enforced before any node data is read.
#[test]
fn test_deserialization_rejects_excessive_node_count() {
    let mut bytes: Vec<u8> = Vec::new();
    super::MAGIC.write_into(&mut bytes);
    // Single byte written between the magic and the version (presumably flags).
    bytes.write_u8(0);
    super::VERSION.write_into(&mut bytes);

    let node_count: usize = MastForest::MAX_NODES + 1;
    node_count.write_into(&mut bytes);
    let decorator_count = 0usize;
    decorator_count.write_into(&mut bytes);

    let result = MastForest::read_from_bytes(&bytes);
    assert!(result.is_err());
    let message = result.unwrap_err().to_string();
    assert!(
        message.contains("exceeds maximum"),
        "Expected error about exceeding maximum, got: {message}"
    );
}