risc0_circuit_recursion/
cpu.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Mutex;
16
17use rayon::prelude::*;
18use risc0_core::scope;
19use risc0_zkp::{
20    adapter::{CircuitStep, CircuitStepContext, PolyFp},
21    core::log2_ceil,
22    field::{
23        baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
24        map_pow, Elem, ExtElem, RootsOfUnity,
25    },
26    hal::{cpu::CpuBuffer, AccumPreflight, CircuitHal, Hal},
27    prove::accum::{Accum, Handler},
28    INV_RATE, ZK_CYCLES,
29};
30
31use crate::{
32    GLOBAL_MIX, GLOBAL_OUT, REGISTER_GROUP_ACCUM, REGISTER_GROUP_CTRL, REGISTER_GROUP_DATA,
33};
34
35pub struct CpuCircuitHal<'a, C: PolyFp<BabyBear>> {
36    circuit: &'a C,
37}
38
39impl<'a, C: PolyFp<BabyBear>> CpuCircuitHal<'a, C> {
40    pub fn new(circuit: &'a C) -> Self {
41        Self { circuit }
42    }
43}
44
45impl<C, H> CircuitHal<H> for CpuCircuitHal<'_, C>
46where
47    C: PolyFp<BabyBear> + Sync + CircuitStep<BabyBearElem>,
48    H: Hal<
49        Elem = BabyBearElem,
50        ExtElem = BabyBearExtElem,
51        Buffer<<H as Hal>::Elem> = CpuBuffer<BabyBearElem>,
52    >,
53{
54    fn eval_check(
55        &self,
56        check: &CpuBuffer<BabyBearElem>,
57        groups: &[&CpuBuffer<BabyBearElem>],
58        globals: &[&CpuBuffer<BabyBearElem>],
59        poly_mix: BabyBearExtElem,
60        po2: usize,
61        steps: usize,
62    ) {
63        const EXP_PO2: usize = log2_ceil(INV_RATE);
64        let domain = steps * INV_RATE;
65
66        let poly_mix_pows = map_pow(poly_mix, crate::info::POLY_MIX_POWERS);
67
68        // SAFETY: Convert a borrow of a cell into a raw const slice so that we can pass
69        // it over the thread boundary. This should be safe because the scope of the
70        // usage is within this function and each thread access will not overlap with
71        // each other.
72
73        let code = groups[REGISTER_GROUP_CTRL].as_slice();
74        let code = unsafe { std::slice::from_raw_parts(code.as_ptr(), code.len()) };
75        let data = groups[REGISTER_GROUP_DATA].as_slice();
76        let data = unsafe { std::slice::from_raw_parts(data.as_ptr(), data.len()) };
77        let accum = groups[REGISTER_GROUP_ACCUM].as_slice();
78        let accum = unsafe { std::slice::from_raw_parts(accum.as_ptr(), accum.len()) };
79        let mix = globals[GLOBAL_MIX].as_slice();
80        let mix = unsafe { std::slice::from_raw_parts(mix.as_ptr(), mix.len()) };
81        let out = globals[GLOBAL_OUT].as_slice();
82        let out = unsafe { std::slice::from_raw_parts(out.as_ptr(), out.len()) };
83        let check = check.as_slice();
84        let check = unsafe { std::slice::from_raw_parts(check.as_ptr(), check.len()) };
85        let poly_mix_pows = poly_mix_pows.as_slice();
86
87        let args: &[&[BabyBearElem]] = &[code, out, data, mix, accum];
88
89        (0..domain).into_par_iter().for_each(|cycle| {
90            let tot = self.circuit.poly_fp(cycle, domain, poly_mix_pows, args);
91            let x = BabyBearElem::ROU_FWD[po2 + EXP_PO2].pow(cycle);
92            // TODO: what is this magic number 3?
93            let y = (BabyBearElem::new(3) * x).pow(1 << po2);
94            let ret = tot * (y - BabyBearElem::new(1)).inv();
95
96            // SAFETY: This conversion is to make the check slice mutable, which should be
97            // safe because each thread access will not overlap with each other.
98            let check = unsafe {
99                std::slice::from_raw_parts_mut(check.as_ptr() as *mut BabyBearElem, check.len())
100            };
101            for i in 0..BabyBearExtElem::EXT_SIZE {
102                check[i * domain + cycle] = ret.elems()[i];
103            }
104        });
105    }
106
107    fn accumulate(
108        &self,
109        _preflight: &AccumPreflight,
110        ctrl: &CpuBuffer<BabyBearElem>,
111        io: &CpuBuffer<BabyBearElem>,
112        data: &CpuBuffer<BabyBearElem>,
113        mix: &CpuBuffer<BabyBearElem>,
114        accum: &CpuBuffer<BabyBearElem>,
115        steps: usize,
116    ) {
117        {
118            let args = &[
119                ctrl.as_slice_sync(),
120                io.as_slice_sync(),
121                data.as_slice_sync(),
122                mix.as_slice_sync(),
123                accum.as_slice_sync(),
124            ];
125
126            let accumulator: Mutex<Accum<BabyBearExtElem>> = Mutex::new(Accum::new(steps));
127            scope!("step_compute_accum", {
128                (0..steps - ZK_CYCLES).into_par_iter().for_each_init(
129                    || Handler::<BabyBear>::new(&accumulator),
130                    |handler, cycle| {
131                        self.circuit
132                            .step_compute_accum(
133                                &CircuitStepContext { size: steps, cycle },
134                                handler,
135                                args,
136                            )
137                            .unwrap();
138                    },
139                );
140            });
141            scope!("calc_prefix_products", {
142                accumulator.lock().unwrap().calc_prefix_products();
143            });
144            scope!("step_verify_accum", {
145                (0..steps - ZK_CYCLES).into_par_iter().for_each_init(
146                    || Handler::<BabyBear>::new(&accumulator),
147                    |handler, cycle| {
148                        self.circuit
149                            .step_verify_accum(
150                                &CircuitStepContext { size: steps, cycle },
151                                handler,
152                                args,
153                            )
154                            .unwrap();
155                    },
156                );
157            });
158        }
159
160        {
161            // Zero out 'invalid' entries in accum and io
162            let mut accum_slice = accum.as_slice_mut();
163            let mut io = io.as_slice_mut();
164            for value in accum_slice.iter_mut().chain(io.iter_mut()) {
165                *value = value.valid_or_zero();
166            }
167        }
168    }
169}