1use core::ffi::{c_void, CStr};
16
17use anyhow::{anyhow, Result};
18use risc0_circuit_recursion_sys::ffi::{
19 get_trampoline, risc0_circuit_recursion_poly_fp, risc0_circuit_recursion_step_compute_accum,
20 risc0_circuit_recursion_step_exec, risc0_circuit_recursion_step_verify_accum,
21 risc0_circuit_recursion_step_verify_bytes, risc0_circuit_recursion_step_verify_mem,
22 risc0_circuit_recursion_string_free, risc0_circuit_recursion_string_ptr, Callback, RawError,
23};
24use risc0_zkp::{
25 adapter::{CircuitProveDef, CircuitStep, CircuitStepContext, CircuitStepHandler, PolyFp},
26 field::{
27 baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
28 Elem,
29 },
30 hal::cpu::SyncSlice,
31};
32
33use crate::CircuitImpl;
34
35impl CircuitStep<BabyBearElem> for CircuitImpl {
36 fn step_compute_accum<S: CircuitStepHandler<BabyBearElem>>(
37 &self,
38 ctx: &CircuitStepContext,
39 handler: &mut S,
40 args: &[SyncSlice<BabyBearElem>],
41 ) -> Result<BabyBearElem> {
42 call_step(
43 ctx,
44 handler,
45 args,
46 |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
47 risc0_circuit_recursion_step_compute_accum(
48 err, ctx, trampoline, size, cycle, args_ptr, args_len,
49 )
50 },
51 )
52 }
53
54 fn step_verify_accum<S: CircuitStepHandler<BabyBearElem>>(
55 &self,
56 ctx: &CircuitStepContext,
57 handler: &mut S,
58 args: &[SyncSlice<BabyBearElem>],
59 ) -> Result<BabyBearElem> {
60 call_step(
61 ctx,
62 handler,
63 args,
64 |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
65 risc0_circuit_recursion_step_verify_accum(
66 err, ctx, trampoline, size, cycle, args_ptr, args_len,
67 )
68 },
69 )
70 }
71
72 fn step_exec<S: CircuitStepHandler<BabyBearElem>>(
73 &self,
74 ctx: &CircuitStepContext,
75 handler: &mut S,
76 args: &[SyncSlice<BabyBearElem>],
77 ) -> Result<BabyBearElem> {
78 call_step(
79 ctx,
80 handler,
81 args,
82 |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
83 risc0_circuit_recursion_step_exec(
84 err, ctx, trampoline, size, cycle, args_ptr, args_len,
85 )
86 },
87 )
88 }
89
90 fn step_verify_bytes<S: CircuitStepHandler<BabyBearElem>>(
91 &self,
92 ctx: &CircuitStepContext,
93 handler: &mut S,
94 args: &[SyncSlice<BabyBearElem>],
95 ) -> Result<BabyBearElem> {
96 call_step(
97 ctx,
98 handler,
99 args,
100 |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
101 risc0_circuit_recursion_step_verify_bytes(
102 err, ctx, trampoline, size, cycle, args_ptr, args_len,
103 )
104 },
105 )
106 }
107
108 fn step_verify_mem<S: CircuitStepHandler<BabyBearElem>>(
109 &self,
110 ctx: &CircuitStepContext,
111 handler: &mut S,
112 args: &[SyncSlice<BabyBearElem>],
113 ) -> Result<BabyBearElem> {
114 call_step(
115 ctx,
116 handler,
117 args,
118 |err, ctx, trampoline, size, cycle, args_ptr, args_len| unsafe {
119 risc0_circuit_recursion_step_verify_mem(
120 err, ctx, trampoline, size, cycle, args_ptr, args_len,
121 )
122 },
123 )
124 }
125}
126
127impl PolyFp<BabyBear> for CircuitImpl {
128 fn poly_fp(
129 &self,
130 cycle: usize,
131 steps: usize,
132 mix: &[BabyBearExtElem],
133 args: &[&[BabyBearElem]],
134 ) -> BabyBearExtElem {
135 let args: Vec<*const BabyBearElem> = args.iter().map(|x| (*x).as_ptr()).collect();
136 let mut err = RawError::default();
137 let mut result = BabyBearExtElem::ZERO;
138 unsafe {
139 risc0_circuit_recursion_poly_fp(
140 &mut err,
141 cycle,
142 steps,
143 mix.as_ptr(),
144 args.as_ptr(),
145 &mut result,
146 )
147 };
148 if err.msg.is_null() {
149 Ok(result)
150 } else {
151 let what = unsafe {
152 let str = risc0_circuit_recursion_string_ptr(err.msg);
153 let msg = CStr::from_ptr(str).to_str().unwrap().to_string();
154 risc0_circuit_recursion_string_free(err.msg);
155 msg
156 };
157 Err(anyhow!(what))
158 }
159 .unwrap()
160 }
161}
162
163impl CircuitProveDef<BabyBear> for CircuitImpl {}
164
165pub(crate) fn call_step<S, F>(
166 ctx: &CircuitStepContext,
167 handler: &mut S,
168 args: &[SyncSlice<BabyBearElem>],
169 inner: F,
170) -> Result<BabyBearElem>
171where
172 S: CircuitStepHandler<BabyBearElem>,
173 F: FnOnce(
174 *mut RawError,
175 *mut c_void,
176 Callback,
177 usize,
178 usize,
179 *const *mut BabyBearElem,
180 usize,
181 ) -> BabyBearElem,
182{
183 let mut last_err = None;
184 let mut call =
185 |name: &str, extra: &str, args: &[BabyBearElem], outs: &mut [BabyBearElem]| match handler
186 .call(ctx.cycle, name, extra, args, outs)
187 {
188 Ok(()) => true,
189 Err(err) => {
190 last_err = Some(err);
191 false
192 }
193 };
194 let trampoline = get_trampoline(&call);
195 let mut err = RawError::default();
196 let args: Vec<*mut BabyBearElem> = args.iter().map(SyncSlice::get_ptr).collect();
197 let result = inner(
198 &mut err,
199 &mut call as *mut _ as *mut c_void,
200 trampoline,
201 ctx.size,
202 ctx.cycle,
203 args.as_ptr(),
204 args.len(),
205 );
206 if let Some(err) = last_err {
207 return Err(err);
208 }
209 if err.msg.is_null() {
210 Ok(result)
211 } else {
212 let what = unsafe {
213 let str = risc0_circuit_recursion_string_ptr(err.msg);
214 let msg = CStr::from_ptr(str).to_str().unwrap().to_string();
215 risc0_circuit_recursion_string_free(err.msg);
216 msg
217 };
218 Err(anyhow!(what))
219 }
220}