use std::rc::Rc;
use metal::ComputePipelineDescriptor;
use risc0_zkp::{
core::log2_ceil,
field::{
baby_bear::{BabyBearElem, BabyBearExtElem},
RootsOfUnity,
},
hal::{
metal::{BufferImpl as MetalBuffer, MetalHal},
EvalCheck,
},
INV_RATE,
};
/// Bytes of the `kernels.metallib` file found in `OUT_DIR`, embedded into the binary at
/// compile time via `include_bytes!`.
// NOTE(review): presumably emitted by this crate's build script — confirm.
const METAL_LIB: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/kernels.metallib"));
/// [`EvalCheck`] implementation that runs the `eval_check` kernel through a [`MetalHal`].
#[derive(Debug)]
pub struct MetalEvalCheck {
    /// Shared HAL handle; supplies the device, command queue and dispatch used by
    /// `eval_check`.
    hal: Rc<MetalHal>,
    /// Compute pipeline descriptor whose compute function is the `eval_check` kernel.
    kernel: ComputePipelineDescriptor,
}
impl MetalEvalCheck {
    /// Creates an evaluator bound to `hal`.
    ///
    /// Loads the embedded Metal library on `hal.device`, looks up its `eval_check`
    /// function, and stores a compute pipeline descriptor configured with that function.
    ///
    /// # Panics
    ///
    /// Panics if the embedded library cannot be loaded on the device or if the
    /// `eval_check` function cannot be retrieved from it. Both indicate a broken build
    /// artifact rather than a recoverable runtime condition.
    pub fn new(hal: Rc<MetalHal>) -> Self {
        let library = hal
            .device
            .new_library_with_data(METAL_LIB)
            .expect("failed to load the embedded Metal library (kernels.metallib)");
        let function = library
            .get_function("eval_check", None)
            .expect("failed to look up the `eval_check` function in the embedded Metal library");

        let kernel = ComputePipelineDescriptor::new();
        kernel.set_compute_function(Some(&function));

        Self { hal, kernel }
    }
}
impl EvalCheck<MetalHal> for MetalEvalCheck {
    /// Dispatches the `eval_check` Metal kernel over the expanded evaluation domain.
    ///
    /// The domain size is `steps * INV_RATE`. Besides the six caller-provided buffers,
    /// four single-element buffers are uploaded for the kernel: `poly_mix`, the forward
    /// root of unity `ROU_FWD[po2 + log2_ceil(INV_RATE)]`, `po2` as `u32`, and the domain
    /// size as `u32`.
    ///
    /// # Panics
    ///
    /// - if `steps * INV_RATE` overflows `usize`;
    /// - if the domain size does not fit in the `u32` size argument handed to the kernel
    ///   (previously this was silently truncated while the dispatch count was not);
    /// - if `po2 + log2_ceil(INV_RATE)` is out of range for `BabyBearElem::ROU_FWD`.
    #[tracing::instrument(skip_all)]
    fn eval_check(
        &self,
        check: &MetalBuffer<BabyBearElem>,
        code: &MetalBuffer<BabyBearElem>,
        data: &MetalBuffer<BabyBearElem>,
        accum: &MetalBuffer<BabyBearElem>,
        mix: &MetalBuffer<BabyBearElem>,
        out: &MetalBuffer<BabyBearElem>,
        poly_mix: BabyBearExtElem,
        po2: usize,
        steps: usize,
    ) {
        const EXP_PO2: usize = log2_ceil(INV_RATE);

        // Validate the domain size up front: the kernel receives it as a `u32`, while the
        // dispatch receives it as a `u64`. A truncated size would make the two disagree.
        let domain = steps
            .checked_mul(INV_RATE)
            .expect("eval_check: steps * INV_RATE overflows usize");
        assert!(
            domain <= u32::MAX as usize,
            "eval_check: domain size {} does not fit in the kernel's u32 size argument",
            domain
        );

        let device = &self.hal.device;

        let poly_mix_buf = MetalBuffer::copy_from(device, self.hal.cmd_queue.clone(), &[poly_mix]);

        // Indexing panics if `po2 + EXP_PO2` is outside the root-of-unity table.
        let rou = BabyBearElem::ROU_FWD[po2 + EXP_PO2];
        let rou_buf = MetalBuffer::copy_from(device, self.hal.cmd_queue.clone(), &[rou]);

        // `po2` has just been used as a table index above, so it is presumably far below
        // `u32::MAX` here and the cast is lossless.
        let po2_buf = MetalBuffer::copy_from(device, self.hal.cmd_queue.clone(), &[po2 as u32]);

        // Lossless: guarded by the assertion above.
        let size_buf =
            MetalBuffer::copy_from(device, self.hal.cmd_queue.clone(), &[domain as u32]);

        // NOTE(review): the order presumably has to match the argument indices of the
        // `eval_check` shader — do not reorder without checking the kernel source.
        let buffers = &[
            check.as_arg(),
            code.as_arg(),
            data.as_arg(),
            accum.as_arg(),
            mix.as_arg(),
            out.as_arg(),
            poly_mix_buf.as_arg(),
            rou_buf.as_arg(),
            po2_buf.as_arg(),
            size_buf.as_arg(),
        ];

        self.hal
            .dispatch(&self.kernel, buffers, domain as u64, None);
    }
}
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use risc0_zkp::hal::{cpu::BabyBearCpuHal, metal::MetalHal};
    use test_log::test;

    use crate::cpu::CpuEvalCheck;

    /// Runs the shared `testutil::eval_check` harness with the CPU and Metal evaluators.
    /// Ignored by default.
    #[test]
    #[ignore]
    fn eval_check() {
        const PO2: usize = 4;

        let circuit = crate::CircuitImpl::new();

        // CPU side.
        let reference_hal = BabyBearCpuHal::new();
        let reference_eval = CpuEvalCheck::new(&circuit);

        // Metal side.
        let metal_hal = Rc::new(MetalHal::new());
        let metal_eval = super::MetalEvalCheck::new(Rc::clone(&metal_hal));

        crate::testutil::eval_check(
            &reference_hal,
            reference_eval,
            metal_hal.as_ref(),
            metal_eval,
            PO2,
        );
    }

    /// Constructs `EvalCheckParams` with an argument of 22 and discards the result.
    /// Ignored by default.
    #[test]
    #[ignore]
    fn memory_usage() {
        let _params = crate::testutil::EvalCheckParams::new(22);
    }
}