use alloc::collections::BTreeMap;
use pretty_assertions::assert_eq;
use vm_core::crypto::hash::RpoDigest;
use super::*;
use crate::{MastForestStore, MemAdviceProvider, MemMastForestStore, MemoryAddress, ProcessState};
/// Checks that [`FastProcessor`] and the trace-enabled slow [`Process`] agree on all process
/// state observable from `trace` decorators.
///
/// Despite its name, this test does not exercise the advice provider. It runs one program on
/// both processors, each with its own [`ConsistencyHost`]. The program covers:
/// - basic blocks of several batch sizes, including a RESPAN;
/// - `call`, `exec`, `syscall`, `dyncall` and `dynexec`;
/// - `if`/`else` and `while` loops.
///
/// It then compares the final stack outputs and the per-trace-id [`ProcessStateSnapshot`]s.
#[test]
fn test_advice_provider() {
    let kernel_source = "
export.foo
push.2323 mem_store.100 trace.11
end
";
    let program_source = "
proc.truncate_stack.4
loc_storew.0 dropw movupw.3
sdepth neq.16
while.true
dropw movupw.3
sdepth neq.16
end
loc_loadw.0
end
# mainly used to break basic blocks
proc.noop
swap swap
end
# Tests different cases of batch sizes
proc.basic_block
# batch with 1 group
swap drop swap trace.1
call.noop
# batch with 2 groups
push.1 drop trace.2
call.noop
# batch with 3 groups (rounded up to 4)
push.1 push.2 drop drop trace.3
call.noop
# batch with 5 groups (rounded up to 8)
push.1 push.2 push.3 push.4 drop drop drop drop trace.4
call.noop
# batch with 8 pushes (which forces a noop to be inserted in the last position of the batch)
push.0 push.1 push.2 push.3 push.4 push.5 push.6 push.7 trace.5
call.noop
# basic block with >1 batches (where clk needs to be incremented in-between batches due to the inserted RESPAN)
push.0 push.1 push.2 push.3 push.4 push.5 push.6 trace.6
drop drop drop drop drop drop drop drop drop trace.7
end
proc.exec_me
push.22 mem_store.0
trace.9
end
proc.dyncall_me
push.23 mem_store.0
trace.100
end
proc.dynexec_me
push.24 mem_store.0
trace.101
end
proc.will_syscall
syscall.foo
end
proc.control_flow
# if true
push.1 trace.16 if.true
swap swap trace.17
else
swap swap
end
# if false
push.0 trace.18 if.true
swap swap
else
swap swap trace.19
end
# loop
push.3 push.1
while.true
trace.20
sub.1 dup neq.0
end
trace.21
end
begin
# Check that initial state is consistent
trace.0 push.10 add drop trace.1
# Check that basic blocks are handled correctly
exec.basic_block
# Check that memory state is restored properly after call
push.42 mem_store.0 trace.8
exec.exec_me
trace.10
# Check that syscalls are handled correctly
call.will_syscall
trace.12
# Check that dyncalls are handled correctly
procref.dyncall_me mem_storew.4 dropw push.4 dyncall trace.13
procref.will_syscall mem_storew.8 dropw push.8 dyncall trace.14
# Check that dynexecs are handled correctly
procref.dynexec_me mem_storew.4 dropw push.4 dynexec trace.15
# Check that control flow operations are handled correctly
exec.control_flow
exec.truncate_stack
trace.22
end
";
    let stack_inputs = Vec::new();

    // Both processors need the kernel: the program calls `syscall.foo`.
    let (program, kernel_lib) = {
        let source_manager = Arc::new(DefaultSourceManager::default());
        let kernel_lib =
            Assembler::new(source_manager.clone()).assemble_kernel(kernel_source).unwrap();
        let program = Assembler::with_kernel(source_manager, kernel_lib.clone())
            .assemble_program(program_source)
            .unwrap();
        (program, kernel_lib)
    };

    // Fast processor run.
    let mut fast_host = ConsistencyHost::new(kernel_lib.mast_forest().clone());
    let processor = FastProcessor::new_debug(&stack_inputs);
    let fast_stack_outputs = processor.execute(&program, &mut fast_host).unwrap();

    // Slow processor run, with tracing enabled so `trace` decorators reach the host.
    let mut slow_host = ConsistencyHost::new(kernel_lib.mast_forest().clone());
    let mut slow_processor = Process::new(
        kernel_lib.kernel().clone(),
        StackInputs::new(stack_inputs).unwrap(),
        ExecutionOptions::default().with_tracing(),
    );
    let slow_stack_outputs = slow_processor.execute(&program, &mut slow_host).unwrap();

    assert_eq!(fast_stack_outputs, slow_stack_outputs);

    // Guard against a vacuous pass. If neither host recorded anything (e.g. tracing was
    // disabled on both processors), every comparison below would trivially succeed.
    assert!(
        !fast_host.snapshots.is_empty(),
        "no trace snapshots were recorded; is tracing enabled on the fast processor?"
    );

    // Compare per trace id, so a mismatch reports which `trace.N` diverged.
    for (trace_id, fast_snapshots) in &fast_host.snapshots {
        let slow_snapshots = slow_host.snapshots.get(trace_id).unwrap_or_else(|| {
            panic!("fast host has snapshot(s) for trace id {trace_id}, but slow host doesn't")
        });
        assert_eq!(fast_snapshots, slow_snapshots, "trace id: {trace_id}");
    }

    // The loop above already compared values for ids present in both maps. All that is left
    // is to check that the slow host did not record an id the fast host missed. Together
    // these two checks imply the two snapshot maps are equal.
    for trace_id in slow_host.snapshots.keys() {
        assert!(
            fast_host.snapshots.contains_key(trace_id),
            "slow host has snapshot(s) for trace id {trace_id}, but fast host doesn't"
        );
    }
}
/// An owned copy of the process state exposed by [`ProcessState`], captured when a `trace`
/// decorator reaches the host.
///
/// The derived `PartialEq` and `Debug` visit fields in declaration order. That order is
/// therefore the order in which a mismatch is reported.
#[derive(Debug, PartialEq, Eq)]
struct ProcessStateSnapshot {
/// Value of `ProcessState::clk()` at the trace point.
clk: RowIndex,
/// Value of `ProcessState::ctx()` at the trace point.
ctx: ContextId,
/// Value of `ProcessState::fmp()` at the trace point.
fmp: u64,
/// Stack contents as returned by `ProcessState::get_stack_state()`.
stack_state: Vec<Felt>,
/// Results of `ProcessState::get_stack_word(i)` for `i` in `0..=3`.
stack_words: [Word; 4],
/// `ProcessState::get_mem_state` for the context that is active at the trace point.
mem_state: Vec<(MemoryAddress, Felt)>,
}
impl From<ProcessState<'_>> for ProcessStateSnapshot {
fn from(state: ProcessState) -> Self {
ProcessStateSnapshot {
clk: state.clk(),
ctx: state.ctx(),
fmp: state.fmp(),
stack_state: state.get_stack_state(),
stack_words: [
state.get_stack_word(0),
state.get_stack_word(1),
state.get_stack_word(2),
state.get_stack_word(3),
],
mem_state: state.get_mem_state(state.ctx()),
}
}
}
/// A test [`Host`] that stores a [`ProcessStateSnapshot`] each time `on_trace` is invoked.
/// Runs of different processors can then be compared snapshot-by-snapshot.
#[derive(Debug)]
struct ConsistencyHost {
/// Recorded snapshots keyed by trace id. Each `Vec` keeps that id's snapshots in
/// the order they were recorded (an id can fire several times, e.g. inside a loop).
snapshots: BTreeMap<u32, Vec<ProcessStateSnapshot>>,
/// Advice provider exposed through the `Host` trait; `new` creates it with `default()`.
advice_provider: MemAdviceProvider,
/// Backing store for `Host::get_mast_forest`; `new` seeds it with the kernel's forest.
store: MemMastForestStore,
}
impl ConsistencyHost {
    /// Builds a host with no snapshots yet, a default (empty) advice provider, and a forest
    /// store that already contains `kernel_forest`.
    fn new(kernel_forest: Arc<MastForest>) -> Self {
        let store = {
            let mut forests = MemMastForestStore::default();
            forests.insert(kernel_forest);
            forests
        };

        Self {
            snapshots: BTreeMap::default(),
            advice_provider: MemAdviceProvider::default(),
            store,
        }
    }
}
impl Host for ConsistencyHost {
    type AdviceProvider = MemAdviceProvider;

    /// Shared access to the host's in-memory advice provider.
    fn advice_provider(&self) -> &Self::AdviceProvider {
        &self.advice_provider
    }

    /// Exclusive access to the host's in-memory advice provider.
    fn advice_provider_mut(&mut self) -> &mut Self::AdviceProvider {
        &mut self.advice_provider
    }

    /// Resolves `node_digest` against this host's forest store.
    fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option<Arc<MastForest>> {
        self.store.get(node_digest)
    }

    /// Appends a snapshot of `process` to the list kept for `trace_id`. This never fails.
    fn on_trace(&mut self, process: ProcessState, trace_id: u32) -> Result<(), ExecutionError> {
        let recorded = self.snapshots.entry(trace_id).or_insert_with(Vec::new);
        recorded.push(ProcessStateSnapshot::from(process));
        Ok(())
    }
}