Documentation
use std::ptr::NonNull;

use rustc_hash::FxHashMap;
use zkvmc_context::{Compiler, Context, JitFn};

/// A JIT cache that stores compiled functions.
///
/// Wraps another compiler (`inner`) and memoizes the [`JitFn`]s it produces.
/// Entries are keyed by the `pc` field read from the [`Context`] passed to
/// `compile`. The wrapped compiler is only invoked for a `pc` that has no
/// entry yet.
///
/// NOTE(review): nothing in the visible code evicts or invalidates entries.
/// This presumably assumes the code at a given `pc` never changes after it is
/// first compiled; confirm.
pub struct JitCache<C> {
    /// The underlying compiler, called on cache misses.
    inner: C,
    /// Previously compiled functions, keyed by the `pc` they were compiled for.
    cache: FxHashMap<u32, JitFn>,
}

impl<C> JitCache<C> {
    /// Builds a cache around `inner` that starts with no compiled functions.
    ///
    /// Nothing is compiled here; the map is filled on demand as functions are
    /// requested.
    pub fn new(inner: C) -> Self {
        let cache = FxHashMap::default();
        Self { inner, cache }
    }
}

impl<C> Compiler for JitCache<C>
where
    C: Compiler,
{
    /// Returns the function compiled for the `pc` stored in `ctx`.
    ///
    /// On the first request for a given `pc`, the wrapped compiler is invoked
    /// and its result is stored. Later requests for the same `pc` return the
    /// stored function without calling the wrapped compiler again.
    fn compile(&mut self, ctx: NonNull<Context>) -> JitFn {
        // SAFETY: assumes the caller passes a pointer to a live, properly
        // aligned `Context` (the `Compiler::compile` contract is not visible
        // here — TODO confirm). The temporary `&Context` is dropped at the end
        // of this statement, once `pc` has been copied out. It therefore does
        // not overlap with `inner.compile(ctx)` below.
        let pc = unsafe { ctx.as_ref() }.pc;

        // Split the borrow of `self`. The map entry and the miss closure can
        // then hold `cache` and `inner` mutably at the same time.
        let Self { inner, cache } = self;
        *cache.entry(pc).or_insert_with(|| inner.compile(ctx))
    }
}