1use std::sync::Mutex;
16
17use rayon::prelude::*;
18use risc0_core::scope;
19use risc0_zkp::{
20 adapter::{CircuitStep, CircuitStepContext, PolyFp},
21 core::log2_ceil,
22 field::{
23 baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
24 map_pow, Elem, ExtElem, RootsOfUnity,
25 },
26 hal::{cpu::CpuBuffer, AccumPreflight, CircuitHal, Hal},
27 prove::accum::{Accum, Handler},
28 INV_RATE, ZK_CYCLES,
29};
30
31use crate::{
32 GLOBAL_MIX, GLOBAL_OUT, REGISTER_GROUP_ACCUM, REGISTER_GROUP_CTRL, REGISTER_GROUP_DATA,
33};
34
35pub struct CpuCircuitHal<'a, C: PolyFp<BabyBear>> {
36 circuit: &'a C,
37}
38
39impl<'a, C: PolyFp<BabyBear>> CpuCircuitHal<'a, C> {
40 pub fn new(circuit: &'a C) -> Self {
41 Self { circuit }
42 }
43}
44
45impl<C, H> CircuitHal<H> for CpuCircuitHal<'_, C>
46where
47 C: PolyFp<BabyBear> + Sync + CircuitStep<BabyBearElem>,
48 H: Hal<
49 Elem = BabyBearElem,
50 ExtElem = BabyBearExtElem,
51 Buffer<<H as Hal>::Elem> = CpuBuffer<BabyBearElem>,
52 >,
53{
54 fn eval_check(
55 &self,
56 check: &CpuBuffer<BabyBearElem>,
57 groups: &[&CpuBuffer<BabyBearElem>],
58 globals: &[&CpuBuffer<BabyBearElem>],
59 poly_mix: BabyBearExtElem,
60 po2: usize,
61 steps: usize,
62 ) {
63 const EXP_PO2: usize = log2_ceil(INV_RATE);
64 let domain = steps * INV_RATE;
65
66 let poly_mix_pows = map_pow(poly_mix, crate::info::POLY_MIX_POWERS);
67
68 let code = groups[REGISTER_GROUP_CTRL].as_slice();
74 let code = unsafe { std::slice::from_raw_parts(code.as_ptr(), code.len()) };
75 let data = groups[REGISTER_GROUP_DATA].as_slice();
76 let data = unsafe { std::slice::from_raw_parts(data.as_ptr(), data.len()) };
77 let accum = groups[REGISTER_GROUP_ACCUM].as_slice();
78 let accum = unsafe { std::slice::from_raw_parts(accum.as_ptr(), accum.len()) };
79 let mix = globals[GLOBAL_MIX].as_slice();
80 let mix = unsafe { std::slice::from_raw_parts(mix.as_ptr(), mix.len()) };
81 let out = globals[GLOBAL_OUT].as_slice();
82 let out = unsafe { std::slice::from_raw_parts(out.as_ptr(), out.len()) };
83 let check = check.as_slice();
84 let check = unsafe { std::slice::from_raw_parts(check.as_ptr(), check.len()) };
85 let poly_mix_pows = poly_mix_pows.as_slice();
86
87 let args: &[&[BabyBearElem]] = &[code, out, data, mix, accum];
88
89 (0..domain).into_par_iter().for_each(|cycle| {
90 let tot = self.circuit.poly_fp(cycle, domain, poly_mix_pows, args);
91 let x = BabyBearElem::ROU_FWD[po2 + EXP_PO2].pow(cycle);
92 let y = (BabyBearElem::new(3) * x).pow(1 << po2);
94 let ret = tot * (y - BabyBearElem::new(1)).inv();
95
96 let check = unsafe {
99 std::slice::from_raw_parts_mut(check.as_ptr() as *mut BabyBearElem, check.len())
100 };
101 for i in 0..BabyBearExtElem::EXT_SIZE {
102 check[i * domain + cycle] = ret.elems()[i];
103 }
104 });
105 }
106
107 fn accumulate(
108 &self,
109 _preflight: &AccumPreflight,
110 ctrl: &CpuBuffer<BabyBearElem>,
111 io: &CpuBuffer<BabyBearElem>,
112 data: &CpuBuffer<BabyBearElem>,
113 mix: &CpuBuffer<BabyBearElem>,
114 accum: &CpuBuffer<BabyBearElem>,
115 steps: usize,
116 ) {
117 {
118 let args = &[
119 ctrl.as_slice_sync(),
120 io.as_slice_sync(),
121 data.as_slice_sync(),
122 mix.as_slice_sync(),
123 accum.as_slice_sync(),
124 ];
125
126 let accumulator: Mutex<Accum<BabyBearExtElem>> = Mutex::new(Accum::new(steps));
127 scope!("step_compute_accum", {
128 (0..steps - ZK_CYCLES).into_par_iter().for_each_init(
129 || Handler::<BabyBear>::new(&accumulator),
130 |handler, cycle| {
131 self.circuit
132 .step_compute_accum(
133 &CircuitStepContext { size: steps, cycle },
134 handler,
135 args,
136 )
137 .unwrap();
138 },
139 );
140 });
141 scope!("calc_prefix_products", {
142 accumulator.lock().unwrap().calc_prefix_products();
143 });
144 scope!("step_verify_accum", {
145 (0..steps - ZK_CYCLES).into_par_iter().for_each_init(
146 || Handler::<BabyBear>::new(&accumulator),
147 |handler, cycle| {
148 self.circuit
149 .step_verify_accum(
150 &CircuitStepContext { size: steps, cycle },
151 handler,
152 args,
153 )
154 .unwrap();
155 },
156 );
157 });
158 }
159
160 {
161 let mut accum_slice = accum.as_slice_mut();
163 let mut io = io.as_slice_mut();
164 for value in accum_slice.iter_mut().chain(io.iter_mut()) {
165 *value = value.valid_or_zero();
166 }
167 }
168 }
169}