use anyhow::{Context as _, Result};
use risc0_circuit_recursion_sys::{RawPreflightTrace, StepMode};
use risc0_core::scope;
use risc0_zkp::{
adapter::{CircuitInfo as _, TapsProvider as _},
field::{
baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
Elem as _,
},
hal::Hal,
};
use crate::{CircuitImpl, CIRCUIT};
use super::{preflight::Preflight, CircuitAccumulator, CircuitWitnessGenerator, Program};
pub(crate) struct WitnessGenerator<H: Hal> {
work_cycles: u32,
total_cycles: u32,
pub global: H::Buffer<H::Elem>,
pub ctrl: H::Buffer<H::Elem>,
pub data: H::Buffer<H::Elem>,
pub accum: H::Buffer<H::Elem>,
}
impl<H> WitnessGenerator<H>
where
H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
{
pub fn new<C: CircuitWitnessGenerator<H>>(
hal: &H,
circuit_hal: &C,
zkr: &Program,
preflight: &Preflight,
) -> Result<Self> {
scope!("witgen");
let total_cycles = 1 << zkr.po2;
let global = vec![BabyBearElem::INVALID; CircuitImpl::OUTPUT_SIZE];
let global = hal.copy_from_elem("global", &global);
let mut ctrl = vec![BabyBearElem::ZERO; total_cycles * CIRCUIT.ctrl_size()];
let ctrl_size = CIRCUIT.ctrl_size();
assert_eq!(ctrl_size, zkr.code_size);
for i in 0..zkr.code_rows() {
for j in 0..ctrl_size {
ctrl[j * total_cycles + i] = zkr.code[i * ctrl_size + j];
}
}
let ctrl = hal.copy_from_elem("ctrl", &ctrl);
let data = hal.alloc_elem_init(
"data",
total_cycles * CIRCUIT.data_size(),
BabyBearElem::INVALID,
);
let accum = hal.alloc_elem_init(
"accum",
total_cycles * CIRCUIT.accum_size(),
BabyBearElem::INVALID,
);
let work_cycles = zkr.code_rows() as u32;
let raw_trace = RawPreflightTrace {
wom: preflight.trace.wom.as_ptr(),
cycles: preflight.trace.cycles.as_ptr(),
iops: preflight.trace.iops.as_ptr(),
num_woms: preflight.trace.wom.len() as u32,
num_cycles: work_cycles,
num_iops: preflight.trace.iops.len() as u32,
};
circuit_hal
.generate_witness(
StepMode::Parallel,
total_cycles as u32,
&raw_trace,
&ctrl,
&data,
&global,
)
.context("witness generation failure")?;
scope!("noise", {
use risc0_zkp::ZK_CYCLES;
let mut rng = rand::rng();
let noise = vec![BabyBearElem::random(&mut rng); ZK_CYCLES * CIRCUIT.data_size()];
hal.eltwise_copy_elem_slice(
&data,
&noise,
CIRCUIT.data_size(), ZK_CYCLES, 0, ZK_CYCLES, total_cycles - ZK_CYCLES, total_cycles, );
});
scope!("zeroize", {
hal.eltwise_zeroize_elem(&data);
hal.eltwise_zeroize_elem(&global);
});
Ok(Self {
work_cycles,
total_cycles: total_cycles as u32,
global,
ctrl,
data,
accum,
})
}
pub fn accum<C: CircuitAccumulator<H>>(
&self,
hal: &H,
circuit_hal: &C,
mix: &[BabyBearElem],
) -> Result<H::Buffer<H::Elem>> {
scope!("accum");
let mix = hal.copy_from_elem("mix", mix);
scope!("noise", {
use risc0_zkp::ZK_CYCLES;
let mut rng = rand::rng();
let total_cycles = self.total_cycles as usize;
let noise = vec![BabyBearElem::random(&mut rng); ZK_CYCLES * CIRCUIT.accum_size()];
hal.eltwise_copy_elem_slice(
&self.accum,
&noise,
CIRCUIT.accum_size(), ZK_CYCLES, 0, ZK_CYCLES, total_cycles - ZK_CYCLES, total_cycles, );
});
circuit_hal.accumulate(
self.work_cycles,
self.total_cycles,
&self.ctrl,
&self.global,
&self.data,
&mix,
&self.accum,
)?;
scope!("zeroize", {
hal.eltwise_zeroize_elem(&self.accum);
hal.eltwise_zeroize_elem(&self.global);
});
Ok(mix)
}
}