use std::sync::Mutex;
use rayon::prelude::*;
use risc0_zkp::{
adapter::{CircuitStep, CircuitStepContext, PolyFp},
core::log2_ceil,
field::{
baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
map_pow, Elem, ExtElem, RootsOfUnity,
},
hal::{cpu::CpuBuffer, CircuitHal, Hal},
prove::accum::{Accum, Handler},
INV_RATE, ZK_CYCLES,
};
use crate::{
GLOBAL_MIX, GLOBAL_OUT, REGISTER_GROUP_ACCUM, REGISTER_GROUP_CTRL, REGISTER_GROUP_DATA,
};
/// CPU implementation of [`CircuitHal`] for a circuit over the BabyBear field.
///
/// The type only stores a borrow of the circuit. Trait bounds are placed on the
/// `impl` blocks that need them rather than on the struct definition, so naming
/// the type never forces callers to repeat those bounds.
pub struct CpuCircuitHal<'a, C> {
    /// The circuit whose constraint / accumulation steps are evaluated.
    circuit: &'a C,
}
impl<'a, C: PolyFp<BabyBear>> CpuCircuitHal<'a, C> {
pub fn new(circuit: &'a C) -> Self {
Self { circuit }
}
}
impl<'a, C, H> CircuitHal<H> for CpuCircuitHal<'a, C>
where
    C: PolyFp<BabyBear> + Sync + CircuitStep<BabyBearElem>,
    H: Hal<
        Elem = BabyBearElem,
        ExtElem = BabyBearExtElem,
        Buffer<<H as Hal>::Elem> = CpuBuffer<BabyBearElem>,
    >,
{
    /// Evaluates the circuit's constraint polynomial on every point of the
    /// evaluation domain and stores the result in `check`.
    ///
    /// The domain has `domain = steps * INV_RATE` points.
    ///
    /// For each point index `cycle`, this method:
    /// 1. calls `poly_fp` with the columns `[ctrl, out, data, mix, accum]` and the
    ///    powers of `poly_mix`;
    /// 2. computes `x = ROU_FWD[po2 + log2_ceil(INV_RATE)]^cycle` and
    ///    `y = (3 * x)^(2^po2)`;
    /// 3. multiplies the `poly_fp` result by `1 / (y - 1)`.
    ///
    /// The `EXT_SIZE` base-field components of each result are written
    /// component-major, at `check[i * domain + cycle]`.
    ///
    /// # Panics
    /// Panics if `check` holds fewer than `EXT_SIZE * domain` elements.
    fn eval_check(
        &self,
        check: &CpuBuffer<BabyBearElem>,
        groups: &[&CpuBuffer<BabyBearElem>],
        globals: &[&CpuBuffer<BabyBearElem>],
        poly_mix: BabyBearExtElem,
        po2: usize,
        steps: usize,
    ) {
        /// Raw pointer into the `check` output, shared by the rayon workers.
        ///
        /// The worker for `cycle` only writes indices `i * domain + cycle`, so two
        /// workers never touch the same element.
        struct CheckOut(*mut BabyBearElem);

        // SAFETY: the pointer comes from an exclusive borrow of `check` that stays
        // alive for the entire parallel loop. It is only used for in-bounds writes
        // to indices that are disjoint between workers (see `CheckOut::write`).
        unsafe impl Send for CheckOut {}
        unsafe impl Sync for CheckOut {}

        impl CheckOut {
            /// Stores `val` at element offset `idx`.
            ///
            /// # Safety
            /// `idx` must lie within the buffer the pointer was taken from, and no
            /// other thread may access that element concurrently.
            unsafe fn write(&self, idx: usize, val: BabyBearElem) {
                unsafe { self.0.add(idx).write(val) }
            }
        }

        const EXP_PO2: usize = log2_ceil(INV_RATE);
        let domain = steps * INV_RATE;
        let ext_size = BabyBearExtElem::EXT_SIZE;

        let poly_mix_pows = map_pow(poly_mix, crate::info::POLY_MIX_POWERS);
        let poly_mix_pows = poly_mix_pows.as_slice();

        // Keep the read guards alive for the whole function and pass `poly_fp`
        // plain shared slices. `&[BabyBearElem]` is `Sync`, so the workers can
        // share them without any unsafe re-borrowing.
        let code = groups[REGISTER_GROUP_CTRL].as_slice();
        let data = groups[REGISTER_GROUP_DATA].as_slice();
        let accum = groups[REGISTER_GROUP_ACCUM].as_slice();
        let mix = globals[GLOBAL_MIX].as_slice();
        let out = globals[GLOBAL_OUT].as_slice();
        let args: &[&[BabyBearElem]] = &[&*code, &*out, &*data, &*mix, &*accum];

        // The output is written concurrently, so take it exclusively instead of
        // writing through a pointer obtained from a shared borrow.
        let mut check_slice = check.as_slice_mut();
        assert!(
            check_slice.len() >= ext_size * domain,
            "check buffer too small: {} < {}",
            check_slice.len(),
            ext_size * domain
        );
        let check_out = CheckOut(check_slice.as_mut_ptr());

        // Loop-invariant values, hoisted out of the per-point closure.
        let root = BabyBearElem::ROU_FWD[po2 + EXP_PO2];
        // NOTE(review): 3 appears to be the coset shift of the evaluation domain,
        // making `y - 1` the vanishing polynomial of the 2^po2 trace domain — confirm.
        let shift = BabyBearElem::new(3);
        let one = BabyBearElem::new(1);

        (0..domain).into_par_iter().for_each(|cycle| {
            let tot = self.circuit.poly_fp(cycle, domain, poly_mix_pows, args);
            let x = root.pow(cycle);
            let y = (shift * x).pow(1 << po2);
            let ret = tot * (y - one).inv();
            for i in 0..ext_size {
                // SAFETY: `i < EXT_SIZE` and `cycle < domain`, so the index is below
                // `EXT_SIZE * domain`, which the assertion above bounds by the buffer
                // length. Distinct cycles write disjoint indices.
                unsafe { check_out.write(i * domain + cycle, ret.elems()[i]) };
            }
        });
    }

    /// Runs the circuit's accumulation steps over every non-ZK cycle, then
    /// sanitizes the `accum` and `io` buffers.
    ///
    /// The work happens in three traced phases, all against one shared
    /// `Accum<BabyBearExtElem>`:
    /// 1. `step_compute_accum` for each cycle in `0..steps - ZK_CYCLES`, with one
    ///    `Handler` per rayon work split;
    /// 2. `calc_prefix_products` on the shared accumulator;
    /// 3. `step_verify_accum` over the same cycle range.
    ///
    /// Afterwards every element of `accum` and `io` is replaced by its
    /// `valid_or_zero()` value.
    ///
    /// # Panics
    /// Panics if `steps < ZK_CYCLES`, or if any circuit step returns an error.
    fn accumulate(
        &self,
        ctrl: &CpuBuffer<BabyBearElem>,
        io: &CpuBuffer<BabyBearElem>,
        data: &CpuBuffer<BabyBearElem>,
        mix: &CpuBuffer<BabyBearElem>,
        accum: &CpuBuffer<BabyBearElem>,
        steps: usize,
    ) {
        // The trailing ZK_CYCLES rows are not stepped. Check the subtraction once
        // so a too-short trace fails clearly instead of overflowing.
        assert!(
            steps >= ZK_CYCLES,
            "accumulate: steps ({}) must be at least ZK_CYCLES ({})",
            steps,
            ZK_CYCLES
        );
        let active_cycles = steps - ZK_CYCLES;

        // The sync views live in their own scope so they are dropped before
        // `as_slice_mut` is called on `io` / `accum` below.
        // NOTE(review): presumably required by the buffers' borrow tracking — confirm.
        {
            let args = &[
                ctrl.as_slice_sync(),
                io.as_slice_sync(),
                data.as_slice_sync(),
                mix.as_slice_sync(),
                accum.as_slice_sync(),
            ];
            let accumulator: Mutex<Accum<BabyBearExtElem>> = Mutex::new(Accum::new(steps));
            tracing::info_span!("step_compute_accum").in_scope(|| {
                (0..active_cycles).into_par_iter().for_each_init(
                    || Handler::<BabyBear>::new(&accumulator),
                    |handler, cycle| {
                        self.circuit
                            .step_compute_accum(
                                &CircuitStepContext { size: steps, cycle },
                                handler,
                                args,
                            )
                            .unwrap();
                    },
                );
            });
            tracing::info_span!("calc_prefix_products").in_scope(|| {
                accumulator.lock().unwrap().calc_prefix_products();
            });
            tracing::info_span!("step_verify_accum").in_scope(|| {
                (0..active_cycles).into_par_iter().for_each_init(
                    || Handler::<BabyBear>::new(&accumulator),
                    |handler, cycle| {
                        self.circuit
                            .step_verify_accum(
                                &CircuitStepContext { size: steps, cycle },
                                handler,
                                args,
                            )
                            .unwrap();
                    },
                );
            });
        }

        // NOTE(review): `valid_or_zero` presumably maps the field's invalid /
        // uninitialized marker to zero — confirm.
        {
            let mut accum_slice = accum.as_slice_mut();
            let mut io_slice = io.as_slice_mut();
            for value in accum_slice.iter_mut().chain(io_slice.iter_mut()) {
                *value = value.valid_or_zero();
            }
        }
    }
}