zkvmc-context 0.0.1

zkVMc runtime context
Documentation
//! This module provides the context for the zkVMc runtime.
mod compiler;
mod syscall;

pub use compiler::*;
use rustc_hash::FxHashMap;
use std::{marker::PhantomPinned, ptr::NonNull};
pub use syscall::*;
use zkvmc_core::{
    addr2frame::LookupPc,
    image::Image,
    memory::{
        read_bytes, read_word, read_words, write_bytes, write_word, write_words, ForkableMemory,
        MmapOffset,
    },
    traits::Reset,
};

/// Size in bytes requested for the guest memory mapping; big enough to hold the whole
/// virtual memory.
///
/// The value is 2^32 + 3, i.e. the full 32-bit address range plus three bytes.
/// NOTE(review): the extra 3 bytes presumably keep a 4-byte access at the highest
/// `u32` address inside the mapping — confirm.
pub const MEM_SIZE: usize = 0x100000003;
// TODO: temporary value, should be calculated.
// Size in bytes (100 MiB) of the trace buffer allocated in `Context::new`.
// NOTE(review): the original comment called this "the maximum size of the code block";
// in this file it is only used to size `Context::trace_buf`.
const TRACE_BUFF_SIZE: usize = 100 * 1024 * 1024;

/// The Context for the zkVMc runtime.
///
/// Bundles the register file, program counters, guest memory, trace buffer and the
/// handler function pointers, followed by additional bookkeeping fields.
///
/// The struct is `#[repr(C)]`, so fields are laid out in declaration order.
/// NOTE(review): presumably compiled code addresses the leading fields by offset —
/// confirm before reordering or inserting fields.
#[repr(C)]
pub struct Context {
    /// The 32 registers, each 32 bits wide.
    pub regs: [u32; 32],
    /// Program Counter.
    pub pc: u32,
    /// Next Program Counter; copied into `pc` by `update_pc`.
    pub next_pc: u32,
    /// Linear memory between VM and JIT code.
    pub mem: ForkableMemory,
    /// Trace buffer for tracing the instructions.
    pub trace_buf: MmapOffset,
    /// Handler function for the `ecall` instruction.
    pub ecall_handler: extern "C" fn(NonNull<Context>),
    /// Handler function for the `ebreak` instruction.
    pub ebreak_handler: extern "C" fn(NonNull<Context>),
    /// Handler function for an undefined instruction.
    /// NOTE(review): the meaning of the `u32` argument is not visible here — confirm.
    pub undefined_handler: extern "C" fn(NonNull<Context>, u32),
    /// Handler function for the execution traces.
    /// `Context::new` fills this from its `interrupt_handler` argument.
    pub trace_handler: extern "C" fn(NonNull<Context>),
    /// Function called by compiled code or VM to lookup or translate the next compiled code block.
    pub trampoline: extern "C" fn(NonNull<Context>) -> JitFn,

    // --------------------------
    /// Additional context.
    ///
    /// Stored as a raw non-null pointer to a trait object; the `Drop` impl of `Context`
    /// reclaims it with `Box::from_raw`.
    pub addition: NonNull<dyn Addition>,
    /// Compile the code block at `pc` and return the compiled function.
    pub compiler: Box<dyn Compiler>,
    /// Configuration for the compiler; defaults to `CompileConfig::default()` in
    /// `Context::new` when none is supplied.
    pub compile_cfg: CompileConfig,
    /// Exit code of the program.
    /// Starts as `None` and is cleared back to `None` on reset.
    pub exit_code: Option<u32>,

    /// Program image built from the entry point and segment map given to `Context::new`;
    /// loaded into memory by `setup`.
    image: Image,
    /// Makes `Context` `!Unpin`.
    /// NOTE(review): presumably because handlers receive raw `NonNull<Context>` pointers
    /// that must stay valid — confirm.
    _marker: PhantomPinned,
}

/// Bound for the extra state attached to a [`Context`] through its `addition` pointer.
///
/// The trait declares no methods of its own; it only combines [`Reset`] (invoked when
/// the context is reset) with [`LookupPc`].
pub trait Addition: Reset + LookupPc {}

impl Context {
    // FIXME: use builder pattern
    pub fn new(
        compiler: Option<Box<dyn Compiler>>,
        entry: u32,
        image: FxHashMap<u32, &[u8]>,
        addition: NonNull<dyn Addition>,
        cfg: Option<CompileConfig>,
        ecall_handler: Option<extern "C" fn(NonNull<Context>)>,
        ebreak_handler: Option<extern "C" fn(NonNull<Context>)>,
        undefined_handler: Option<extern "C" fn(NonNull<Context>, u32)>,
        interrupt_handler: Option<extern "C" fn(NonNull<Context>)>,
    ) -> Context {
        let mem = ForkableMemory::new(MEM_SIZE).expect("Failed to create memory");
        let mut ctx = Context {
            regs: [0; 32],
            pc: 0,
            next_pc: 0,
            mem,
            ecall_handler: ecall_handler.unwrap_or(syscall::default_ecall_handler),
            ebreak_handler: ebreak_handler.unwrap_or(syscall::default_ebreak_handler),
            undefined_handler: undefined_handler.unwrap_or(syscall::default_undefined_handler),
            trace_handler: interrupt_handler.unwrap_or(syscall::default_trace_handler),
            addition,
            exit_code: None,
            trampoline,
            compiler: compiler.unwrap_or_else(|| Box::new(NoopCompiler)),
            compile_cfg: cfg.unwrap_or_default(),
            trace_buf: MmapOffset::new(TRACE_BUFF_SIZE).expect("Failed to create trace buffer"),
            image: Image::new(entry, image),
            _marker: PhantomPinned,
        };
        ctx.setup();
        ctx
    }

    fn setup(&mut self) {
        // load image into memory
        self.image.load(&mut self.pc, self.mem.as_send_sync_ptr());
        self.next_pc = self.pc;
    }

    #[inline]
    pub fn update_pc(&mut self) {
        self.pc = self.next_pc;
    }

    #[inline]
    pub fn write_bytes(&mut self, addr: u32, bytes: &[u8]) {
        unsafe { write_bytes(self.mem.as_ptr(), addr, bytes) };
    }

    #[inline]
    pub fn write_word(&mut self, addr: u32, word: u32) {
        unsafe { write_word(self.mem.as_ptr(), addr, word) };
    }

    #[inline]
    pub fn write_words(&mut self, addr: u32, words: &[u32]) {
        unsafe { write_words(self.mem.as_ptr(), addr, words) };
    }

    #[inline]
    pub fn read_bytes(&self, addr: u32, len: usize) -> &[u8] {
        unsafe { read_bytes(self.mem.as_ptr(), addr, len) }
    }

    #[inline]
    pub fn read_word(&self, addr: u32) -> u32 {
        unsafe { read_word(self.mem.as_ptr(), addr) }
    }

    #[inline]
    pub fn read_words(&self, addr: u32, len: usize) -> &[u32] {
        unsafe { read_words(self.mem.as_ptr(), addr, len) }
    }

    #[inline]
    pub fn as_addition<T>(&self) -> &T {
        unsafe { self.addition.cast().as_ref() }
    }

    #[inline]
    pub fn as_addition_mut<T>(&mut self) -> &mut T {
        unsafe { self.addition.cast().as_mut() }
    }
}

impl Drop for Context {
    /// Reclaims the object behind `addition` as a `Box` and frees it.
    fn drop(&mut self) {
        let raw = self.addition.as_ptr();
        // SAFETY: sound only if `addition` was produced from a `Box` and has not been
        // freed elsewhere; that invariant is established by whoever built the context.
        drop(unsafe { Box::from_raw(raw) });
    }
}

impl Reset for Context {
    /// Returns the context to its freshly constructed state.
    ///
    /// Clears the exit code and registers, resets the guest memory and trace buffer,
    /// reloads the image (which also re-derives `pc` / `next_pc`), and finally resets
    /// the attached `addition` object.
    fn reset(&mut self) {
        // Plain architectural state.
        self.exit_code = None;
        self.regs.fill(0);

        // Backing buffers must be reset before the image is loaded again.
        self.mem.reset();
        self.trace_buf.reset();
        self.setup();

        // SAFETY: sound only while `addition` points to a live object; it is freed
        // only in `Drop`, which cannot have run yet.
        let extra = unsafe { self.addition.as_mut() };
        extra.reset();
    }
}