// risc0_circuit_recursion/cpp.rs

// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use core::ffi::{c_void, CStr};
16
17use anyhow::{anyhow, Result};
18use risc0_circuit_recursion_sys::ffi::{
19    get_trampoline, risc0_circuit_recursion_poly_fp, risc0_circuit_recursion_step_compute_accum,
20    risc0_circuit_recursion_step_exec, risc0_circuit_recursion_step_verify_accum,
21    risc0_circuit_recursion_step_verify_bytes, risc0_circuit_recursion_step_verify_mem,
22    risc0_circuit_recursion_string_free, risc0_circuit_recursion_string_ptr, Callback, RawError,
23};
24use risc0_zkp::{
25    adapter::{CircuitProveDef, CircuitStep, CircuitStepContext, CircuitStepHandler, PolyFp},
26    field::{
27        baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
28        Elem,
29    },
30    hal::cpu::SyncSlice,
31};
32
33use crate::CircuitImpl;
34
35impl CircuitStep<BabyBearElem> for CircuitImpl {
36    fn step_compute_accum<S: CircuitStepHandler<BabyBearElem>>(
37        &self,
38        ctx: &CircuitStepContext,
39        handler: &mut S,
40        args: &[SyncSlice<BabyBearElem>],
41    ) -> Result<BabyBearElem> {
42        call_step(
43            ctx,
44            handler,
45            args,
46            |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
47                risc0_circuit_recursion_step_compute_accum(
48                    err, ctx, trampoline, size, cycle, args_ptr, args_len,
49                )
50            },
51        )
52    }
53
54    fn step_verify_accum<S: CircuitStepHandler<BabyBearElem>>(
55        &self,
56        ctx: &CircuitStepContext,
57        handler: &mut S,
58        args: &[SyncSlice<BabyBearElem>],
59    ) -> Result<BabyBearElem> {
60        call_step(
61            ctx,
62            handler,
63            args,
64            |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
65                risc0_circuit_recursion_step_verify_accum(
66                    err, ctx, trampoline, size, cycle, args_ptr, args_len,
67                )
68            },
69        )
70    }
71
72    fn step_exec<S: CircuitStepHandler<BabyBearElem>>(
73        &self,
74        ctx: &CircuitStepContext,
75        handler: &mut S,
76        args: &[SyncSlice<BabyBearElem>],
77    ) -> Result<BabyBearElem> {
78        call_step(
79            ctx,
80            handler,
81            args,
82            |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
83                risc0_circuit_recursion_step_exec(
84                    err, ctx, trampoline, size, cycle, args_ptr, args_len,
85                )
86            },
87        )
88    }
89
90    fn step_verify_bytes<S: CircuitStepHandler<BabyBearElem>>(
91        &self,
92        ctx: &CircuitStepContext,
93        handler: &mut S,
94        args: &[SyncSlice<BabyBearElem>],
95    ) -> Result<BabyBearElem> {
96        call_step(
97            ctx,
98            handler,
99            args,
100            |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
101                risc0_circuit_recursion_step_verify_bytes(
102                    err, ctx, trampoline, size, cycle, args_ptr, args_len,
103                )
104            },
105        )
106    }
107
108    fn step_verify_mem<S: CircuitStepHandler<BabyBearElem>>(
109        &self,
110        ctx: &CircuitStepContext,
111        handler: &mut S,
112        args: &[SyncSlice<BabyBearElem>],
113    ) -> Result<BabyBearElem> {
114        call_step(
115            ctx,
116            handler,
117            args,
118            |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
119                risc0_circuit_recursion_step_verify_mem(
120                    err, ctx, trampoline, size, cycle, args_ptr, args_len,
121                )
122            },
123        )
124    }
125}
126
127impl PolyFp<BabyBear> for CircuitImpl {
128    fn poly_fp(
129        &self,
130        cycle: usize,
131        steps: usize,
132        mix: &[BabyBearExtElem],
133        args: &[&[BabyBearElem]],
134    ) -> BabyBearExtElem {
135        let args: Vec<*const BabyBearElem> = args.iter().map(|x| (*x).as_ptr()).collect();
136        let mut err = RawError::default();
137        let mut result = BabyBearExtElem::ZERO;
138        unsafe {
139            risc0_circuit_recursion_poly_fp(
140                &mut err,
141                cycle,
142                steps,
143                mix.as_ptr(),
144                args.as_ptr(),
145                &mut result,
146            )
147        };
148        if err.msg.is_null() {
149            Ok(result)
150        } else {
151            let what = unsafe {
152                let str = risc0_circuit_recursion_string_ptr(err.msg);
153                let msg = CStr::from_ptr(str).to_str().unwrap().to_string();
154                risc0_circuit_recursion_string_free(err.msg);
155                msg
156            };
157            Err(anyhow!(what))
158        }
159        .unwrap()
160    }
161}
162
163impl CircuitProveDef<BabyBear> for CircuitImpl {}
164
165pub(crate) fn call_step<S, F>(
166    ctx: &CircuitStepContext,
167    handler: &mut S,
168    args: &[SyncSlice<BabyBearElem>],
169    inner: F,
170) -> Result<BabyBearElem>
171where
172    S: CircuitStepHandler<BabyBearElem>,
173    F: FnOnce(
174        *mut RawError,
175        *mut c_void,
176        Callback,
177        usize,
178        usize,
179        *const *mut BabyBearElem,
180        usize,
181    ) -> BabyBearElem,
182{
183    let mut last_err = None;
184    let mut call =
185        |name: &str, extra: &str, args: &[BabyBearElem], outs: &mut [BabyBearElem]| match handler
186            .call(ctx.cycle, name, extra, args, outs)
187        {
188            Ok(()) => true,
189            Err(err) => {
190                last_err = Some(err);
191                false
192            }
193        };
194    let trampoline = get_trampoline(&call);
195    let mut err = RawError::default();
196    let args: Vec<*mut BabyBearElem> = args.iter().map(SyncSlice::get_ptr).collect();
197    let result = inner(
198        &mut err,
199        &mut call as *mut _ as *mut c_void,
200        trampoline,
201        ctx.size,
202        ctx.cycle,
203        args.as_ptr(),
204        args.len(),
205    );
206    if let Some(err) = last_err {
207        return Err(err);
208    }
209    if err.msg.is_null() {
210        Ok(result)
211    } else {
212        let what = unsafe {
213            let str = risc0_circuit_recursion_string_ptr(err.msg);
214            let msg = CStr::from_ptr(str).to_str().unwrap().to_string();
215            risc0_circuit_recursion_string_free(err.msg);
216            msg
217        };
218        Err(anyhow!(what))
219    }
220}