miden-processor 0.14.0

Miden VM processor
Documentation
use alloc::collections::BTreeMap;

use pretty_assertions::assert_eq;
use vm_core::crypto::hash::RpoDigest;

use super::*;
use crate::{MastForestStore, MemAdviceProvider, MemMastForestStore, MemoryAddress, ProcessState};

#[test]
fn test_advice_provider() {
    let kernel_source = "
        export.foo
            push.2323 mem_store.100 trace.11
        end 
    ";

    let program_source = "
    proc.truncate_stack.4
        loc_storew.0 dropw movupw.3
        sdepth neq.16
        while.true
            dropw movupw.3
            sdepth neq.16
        end
        loc_loadw.0
    end

    # mainly used to break basic blocks
    proc.noop
        swap swap
    end

    # Tests different cases of batch sizes
    proc.basic_block
        # batch with 1 group
        swap drop swap trace.1 
        
        call.noop

        # batch with 2 groups
        push.1 drop trace.2

        call.noop

        # batch with 3 groups (rounded up to 4)
        push.1 push.2 drop drop trace.3

        call.noop

        # batch with 5 groups (rounded up to 8)
        push.1 push.2 push.3 push.4 drop drop drop drop trace.4

        call.noop

        # batch with 8 pushes (which forces a noop to be inserted in the last position of the batch)
        push.0 push.1 push.2 push.3 push.4 push.5 push.6 push.7 trace.5

        call.noop

        # basic block with >1 batches (where clk needs to be incremented in-between batches due to the inserted RESPAN)
        push.0 push.1 push.2 push.3 push.4 push.5 push.6    trace.6 
        drop drop drop drop drop drop drop drop drop        trace.7
    end

    proc.exec_me
        push.22 mem_store.0
        trace.9
    end

    proc.dyncall_me
        push.23 mem_store.0
        trace.100
    end

    proc.dynexec_me
        push.24 mem_store.0
        trace.101
    end

    proc.will_syscall
        syscall.foo
    end

    proc.control_flow
        # if true
        push.1 trace.16 if.true
            swap swap trace.17
        else
            swap swap
        end

        # if false
        push.0 trace.18 if.true
            swap swap
        else
            swap swap trace.19
        end

        # loop
        push.3 push.1
        while.true
            trace.20
            sub.1 dup neq.0
        end

        trace.21
    end

    begin
        # Check that initial state is consistent
        trace.0 push.10 add drop trace.1

        # Check that basic blocks are handled correctly
        exec.basic_block

        # Check that memory state is restored properly after call
        push.42 mem_store.0 trace.8
        exec.exec_me
        trace.10

        # Check that syscalls are handled correctly
        call.will_syscall
        trace.12

        # Check that dyncalls are handled correctly
        procref.dyncall_me mem_storew.4 dropw push.4 dyncall trace.13
        procref.will_syscall mem_storew.8 dropw push.8 dyncall trace.14

        # Check that dynexecs are handled correctly
        procref.dynexec_me mem_storew.4 dropw push.4 dynexec trace.15

        # Check that control flow operations are handled correctly
        exec.control_flow

        exec.truncate_stack
        trace.22 
    end
    ";

    let stack_inputs = Vec::new();

    let (program, kernel_lib) = {
        let source_manager = Arc::new(DefaultSourceManager::default());

        let kernel_lib =
            Assembler::new(source_manager.clone()).assemble_kernel(kernel_source).unwrap();
        let program = Assembler::with_kernel(source_manager, kernel_lib.clone())
            .assemble_program(program_source)
            .unwrap();

        (program, kernel_lib)
    };

    // fast processor
    let mut fast_host = ConsistencyHost::new(kernel_lib.mast_forest().clone());
    let processor = FastProcessor::new_debug(&stack_inputs);
    let fast_stack_outputs = processor.execute(&program, &mut fast_host).unwrap();

    // slow processor
    let mut slow_host = ConsistencyHost::new(kernel_lib.mast_forest().clone());
    let mut slow_processor = Process::new(
        kernel_lib.kernel().clone(),
        StackInputs::new(stack_inputs).unwrap(),
        ExecutionOptions::default().with_tracing(),
    );
    let slow_stack_outputs = slow_processor.execute(&program, &mut slow_host).unwrap();

    // check outputs
    assert_eq!(fast_stack_outputs, slow_stack_outputs);

    // check hosts. Check one trace event at a time to help debugging.
    for (trace_id, fast_snapshots) in fast_host.snapshots.iter() {
        let slow_snapshots = slow_host.snapshots.get(trace_id).unwrap_or_else(|| {
            panic!("fast host has snapshot(s) for trace id {trace_id}, but slow host doesn't")
        });
        assert_eq!(fast_snapshots, slow_snapshots, "trace id: {trace_id}");
    }
    for (trace_id, slow_snapshots) in slow_host.snapshots.iter() {
        let fast_snapshots = fast_host.snapshots.get(trace_id).unwrap_or_else(|| {
            panic!("slow host has snapshot(s) for trace id {trace_id}, but fast host doesn't")
        });
        assert_eq!(fast_snapshots, slow_snapshots, "trace_id: {trace_id}");
    }

    // Still check the snapshots explicitly just in case we have a bug in the logic above.
    assert_eq!(fast_host.snapshots, slow_host.snapshots);
}

// Host Implementation
// ==============================================================================================

/// A snapshot of the state of a process at a given clock cycle.
#[derive(Debug, PartialEq, Eq)]
struct ProcessStateSnapshot {
    clk: RowIndex,
    ctx: ContextId,
    fmp: u64,
    stack_state: Vec<Felt>,
    stack_words: [Word; 4],
    mem_state: Vec<(MemoryAddress, Felt)>,
}

impl From<ProcessState<'_>> for ProcessStateSnapshot {
    fn from(state: ProcessState) -> Self {
        ProcessStateSnapshot {
            clk: state.clk(),
            ctx: state.ctx(),
            fmp: state.fmp(),
            stack_state: state.get_stack_state(),
            stack_words: [
                state.get_stack_word(0),
                state.get_stack_word(1),
                state.get_stack_word(2),
                state.get_stack_word(3),
            ],
            mem_state: state.get_mem_state(state.ctx()),
        }
    }
}

#[derive(Debug)]
struct ConsistencyHost {
    /// A map of trace ID to a list of snapshots. A single trace ID can be associated with multiple
    /// snapshots for example if it's used in a loop.
    snapshots: BTreeMap<u32, Vec<ProcessStateSnapshot>>,
    advice_provider: MemAdviceProvider,
    store: MemMastForestStore,
}

impl ConsistencyHost {
    fn new(kernel_forest: Arc<MastForest>) -> Self {
        let mut store = MemMastForestStore::default();
        store.insert(kernel_forest);

        Self {
            snapshots: BTreeMap::new(),
            advice_provider: MemAdviceProvider::default(),
            store,
        }
    }
}

impl Host for ConsistencyHost {
    type AdviceProvider = MemAdviceProvider;

    fn advice_provider(&self) -> &Self::AdviceProvider {
        &self.advice_provider
    }

    fn advice_provider_mut(&mut self) -> &mut Self::AdviceProvider {
        &mut self.advice_provider
    }

    fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option<Arc<MastForest>> {
        self.store.get(node_digest)
    }

    fn on_trace(&mut self, process: ProcessState, trace_id: u32) -> Result<(), ExecutionError> {
        let snapshot = ProcessStateSnapshot::from(process);
        self.snapshots.entry(trace_id).or_default().push(snapshot);

        Ok(())
    }
}