1mod exec;
20mod plonk;
21mod preflight;
22mod program;
23pub mod zkr;
24
25use std::{collections::VecDeque, fmt::Debug, mem::take, rc::Rc};
26
27use crate::{
28 cpu::CpuCircuitHal, CircuitImpl, CIRCUIT, REGISTER_GROUP_ACCUM, REGISTER_GROUP_CTRL,
29 REGISTER_GROUP_DATA,
30};
31use anyhow::{bail, Result};
32use rand::thread_rng;
33use risc0_core::scope;
34use risc0_zkp::{
35 adapter::{CircuitInfo, CircuitStepContext, TapsProvider, PROOF_SYSTEM_INFO},
36 core::{digest::Digest, hash::poseidon2::Poseidon2HashSuite},
37 field::{
38 baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
39 Elem,
40 },
41 hal::{cpu::CpuHal, AccumPreflight, CircuitHal, Hal},
42 prove::adapter::ProveAdapter,
43 ZK_CYCLES,
44};
45use serde::{Deserialize, Serialize};
46
47use self::exec::RecursionExecutor;
48pub use self::program::Program;
49
50#[derive(Clone)]
52pub struct HalPair<H, C>
53where
54 H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
55 C: CircuitHal<H>,
56{
57 pub hal: Rc<H>,
59
60 pub circuit_hal: Rc<C>,
62}
63
/// Recursion code size constant.
///
/// NOTE(review): not referenced in this part of the file; presumably the number of code
/// columns per program row (compare `Program::code_size`). Confirm against its users.
const RECURSION_CODE_SIZE: usize = 23;
68
69#[derive(Clone, Debug, Serialize, Deserialize)]
70pub struct RecursionReceipt {
71 pub seal: Vec<u32>,
72 pub output: Vec<u32>,
73}
74
75impl RecursionReceipt {
76 pub fn seal_size(&self) -> usize {
78 core::mem::size_of_val(self.seal.as_slice())
79 }
80}
81
82pub struct Prover {
84 program: Program,
85 hashfn: String,
86 input: VecDeque<u32>,
87 split_points: Vec<usize>,
88 output: Vec<u32>,
89}
90
#[cfg(feature = "cuda")]
mod cuda {
    pub use crate::cuda::{CudaCircuitHalPoseidon2, CudaCircuitHalSha256};
    pub use risc0_zkp::hal::cuda::{CudaHalPoseidon2, CudaHalSha256};

    use super::{HalPair, Rc};

    /// Builds a `CudaHalSha256` and a `CudaCircuitHalSha256` that shares it.
    pub fn sha256_hal_pair() -> HalPair<CudaHalSha256, CudaCircuitHalSha256> {
        let gpu = Rc::new(CudaHalSha256::new());
        let circuit = Rc::new(CudaCircuitHalSha256::new(Rc::clone(&gpu)));
        HalPair {
            hal: gpu,
            circuit_hal: circuit,
        }
    }

    /// Builds a `CudaHalPoseidon2` and a `CudaCircuitHalPoseidon2` that shares it.
    pub fn poseidon2_hal_pair() -> HalPair<CudaHalPoseidon2, CudaCircuitHalPoseidon2> {
        let gpu = Rc::new(CudaHalPoseidon2::new());
        let circuit = Rc::new(CudaCircuitHalPoseidon2::new(Rc::clone(&gpu)));
        HalPair {
            hal: gpu,
            circuit_hal: circuit,
        }
    }
}
110
#[cfg(any(all(target_os = "macos", target_arch = "aarch64"), target_os = "ios"))]
mod metal {
    pub use crate::metal::MetalCircuitHal;
    pub use risc0_zkp::hal::metal::{
        MetalHalPoseidon2, MetalHalSha256, MetalHashPoseidon2, MetalHashSha256,
    };

    use super::{HalPair, Rc};

    /// Builds a `MetalHalSha256` and a `MetalCircuitHal<MetalHashSha256>` that shares it.
    pub fn sha256_hal_pair() -> HalPair<MetalHalSha256, MetalCircuitHal<MetalHashSha256>> {
        let gpu = Rc::new(MetalHalSha256::new());
        let circuit = Rc::new(MetalCircuitHal::<MetalHashSha256>::new(Rc::clone(&gpu)));
        HalPair {
            hal: gpu,
            circuit_hal: circuit,
        }
    }

    /// Builds a `MetalHalPoseidon2` and a `MetalCircuitHal<MetalHashPoseidon2>` that shares it.
    pub fn poseidon2_hal_pair() -> HalPair<MetalHalPoseidon2, MetalCircuitHal<MetalHashPoseidon2>> {
        let gpu = Rc::new(MetalHalPoseidon2::new());
        let circuit = Rc::new(MetalCircuitHal::<MetalHashPoseidon2>::new(Rc::clone(&gpu)));
        HalPair {
            hal: gpu,
            circuit_hal: circuit,
        }
    }
}
132
133mod cpu {
134 use risc0_zkp::core::hash::{poseidon_254::Poseidon254HashSuite, sha::Sha256HashSuite};
135
136 use super::{
137 BabyBear, CircuitImpl, CpuCircuitHal, CpuHal, HalPair, Poseidon2HashSuite, Rc, CIRCUIT,
138 };
139
140 #[allow(dead_code)]
141 pub fn sha256_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
142 let hal = Rc::new(CpuHal::new(Sha256HashSuite::new_suite()));
143 let circuit_hal = Rc::new(CpuCircuitHal::new(&CIRCUIT));
144 HalPair { hal, circuit_hal }
145 }
146
147 #[allow(dead_code)]
148 pub fn poseidon2_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
149 let hal = Rc::new(CpuHal::new(Poseidon2HashSuite::new_suite()));
150 let circuit_hal = Rc::new(CpuCircuitHal::new(&CIRCUIT));
151 HalPair { hal, circuit_hal }
152 }
153
154 #[allow(dead_code)]
155 pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>>
156 {
157 let hal = Rc::new(CpuHal::new(Poseidon254HashSuite::new_suite()));
158 let circuit_hal = Rc::new(CpuCircuitHal::new(&CIRCUIT));
159 HalPair { hal, circuit_hal }
160 }
161}
162
163cfg_if::cfg_if! {
164 if #[cfg(feature = "cuda")] {
165 #[allow(dead_code)]
167 pub fn sha256_hal_pair() -> HalPair<cuda::CudaHalSha256, cuda::CudaCircuitHalSha256> {
168 cuda::sha256_hal_pair()
169 }
170
171 #[allow(dead_code)]
173 pub fn poseidon2_hal_pair() -> HalPair<cuda::CudaHalPoseidon2, cuda::CudaCircuitHalPoseidon2> {
174 cuda::poseidon2_hal_pair()
175 }
176
177 #[allow(dead_code)]
179 pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
180 cpu::poseidon254_hal_pair()
181 }
182 } else if #[cfg(any(all(target_os = "macos", target_arch = "aarch64"), target_os = "ios"))] {
183 #[allow(dead_code)]
185 pub fn sha256_hal_pair() -> HalPair<metal::MetalHalSha256, metal::MetalCircuitHal<metal::MetalHashSha256>> {
186 metal::sha256_hal_pair()
187 }
188
189 #[allow(dead_code)]
191 pub fn poseidon2_hal_pair() -> HalPair<metal::MetalHalPoseidon2, metal::MetalCircuitHal<metal::MetalHashPoseidon2>> {
192 metal::poseidon2_hal_pair()
193 }
194
195 #[allow(dead_code)]
197 pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
198 cpu::poseidon254_hal_pair()
199 }
200 } else {
201 #[allow(dead_code)]
203 pub fn sha256_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
204 cpu::sha256_hal_pair()
205 }
206
207 #[allow(dead_code)]
209 pub fn poseidon2_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
210 cpu::poseidon2_hal_pair()
211 }
212
213 #[allow(dead_code)]
215 pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
216 cpu::poseidon254_hal_pair()
217 }
218 }
219}
220
/// Selects how [`Prover::add_input_digest`] encodes a digest into the input stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum DigestKind {
    /// Append the digest's 32-bit words unchanged.
    Poseidon2,
    /// Split each digest word into low/high 16-bit halves, each encoded as a field element.
    Sha256,
}
229
230impl Prover {
231 pub fn new(program: Program, hashfn: &str) -> Self {
235 Self {
236 program,
237 hashfn: hashfn.to_string(),
238 input: VecDeque::new(),
239 split_points: Vec::new(),
240 output: Vec::new(),
241 }
242 }
243
244 pub fn add_input(&mut self, input: &[u32]) {
246 self.input.extend(input);
247 }
248
249 pub fn add_input_digest(&mut self, digest: &Digest, kind: DigestKind) {
251 match kind {
252 DigestKind::Poseidon2 => self.add_input(digest.as_words()),
254 DigestKind::Sha256 => self.add_input(bytemuck::cast_slice(
256 &digest
257 .as_words()
258 .iter()
259 .copied()
260 .flat_map(|x| [x & 0xffff, x >> 16])
261 .map(BabyBearElem::new)
262 .collect::<Vec<_>>(),
263 )),
264 }
265 }
266
267 pub fn run(&mut self) -> Result<RecursionReceipt> {
270 match self.hashfn.as_ref() {
272 "poseidon2" => {
273 let hal_pair = poseidon2_hal_pair();
274 let (hal, circuit_hal) = (hal_pair.hal.as_ref(), hal_pair.circuit_hal.as_ref());
275 self.run_with_hal(hal, circuit_hal)
276 }
277 "poseidon_254" => {
278 let hal_pair = poseidon254_hal_pair();
279 let (hal, circuit_hal) = (hal_pair.hal.as_ref(), hal_pair.circuit_hal.as_ref());
280 self.run_with_hal(hal, circuit_hal)
281 }
282 "sha-256" => {
283 let hal_pair = sha256_hal_pair();
284 let (hal, circuit_hal) = (hal_pair.hal.as_ref(), hal_pair.circuit_hal.as_ref());
285 self.run_with_hal(hal, circuit_hal)
286 }
287 _ => bail!("no hal found for {}", self.hashfn),
288 }
289 }
290
291 pub fn run_with_hal<H, C>(&mut self, hal: &H, circuit_hal: &C) -> Result<RecursionReceipt>
294 where
295 H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
296 C: CircuitHal<H>,
297 {
298 scope!("run_with_hal");
299
300 let machine_ctx = self.preflight()?;
301
302 let split_points = core::mem::take(&mut self.split_points);
303
304 let mut executor = scope!("witgen", {
305 let mut executor =
306 RecursionExecutor::new(&CIRCUIT, &self.program, machine_ctx, split_points);
307 executor.run()?;
308 Result::<RecursionExecutor, anyhow::Error>::Ok(executor)
309 })?;
310
311 let seal = scope!("prove", {
312 let mut adapter = ProveAdapter::new(&mut executor.executor);
313 let mut prover = risc0_zkp::prove::Prover::new(hal, CIRCUIT.get_taps());
314 let hashfn = Rc::clone(&hal.get_hash_suite().hashfn);
315
316 let (mix, io) = scope!("main", {
317 prover
320 .iop()
321 .commit(&hashfn.hash_elem_slice(&PROOF_SYSTEM_INFO.encode()));
322 prover
323 .iop()
324 .commit(&hashfn.hash_elem_slice(&CircuitImpl::CIRCUIT_INFO.encode()));
325
326 adapter.execute(prover.iop(), hal);
327
328 prover.set_po2(adapter.po2() as usize);
329
330 let ctrl = hal.copy_from_elem("ctrl", &adapter.get_code().as_slice());
331 prover.commit_group(REGISTER_GROUP_CTRL, &ctrl);
332
333 let data = hal.copy_from_elem("data", &adapter.get_data().as_slice());
334 prover.commit_group(REGISTER_GROUP_DATA, &data);
335
336 let mix = scope!("alloc+copy(mix)", {
338 let mix: Vec<_> = (0..CircuitImpl::MIX_SIZE)
339 .map(|_| prover.iop().random_elem())
340 .collect();
341 hal.copy_from_elem("mix", mix.as_slice())
342 });
343
344 let steps = adapter.get_steps();
345
346 let accum = scope!(
347 "alloc(accum)",
348 hal.alloc_elem_init(
349 "accum",
350 steps * CIRCUIT.accum_size(),
351 BabyBearElem::INVALID,
352 )
353 );
354
355 scope!("noise(accum)", {
357 let mut rng = thread_rng();
358 let noise =
359 vec![BabyBearElem::random(&mut rng); ZK_CYCLES * CIRCUIT.accum_size()];
360 hal.eltwise_copy_elem_slice(
361 &accum,
362 &noise,
363 CIRCUIT.accum_size(), ZK_CYCLES, 0, ZK_CYCLES, steps - ZK_CYCLES, steps, );
370 });
371
372 let io = scope!(
373 "copy(io)",
374 hal.copy_from_elem("io", &adapter.get_io().as_slice())
375 );
376
377 let preflight = AccumPreflight::default();
379 circuit_hal.accumulate(&preflight, &ctrl, &io, &data, &mix, &accum, steps);
380
381 prover.commit_group(REGISTER_GROUP_ACCUM, &accum);
382
383 (mix, io)
384 });
385
386 prover.finalize(&[&mix, &io], circuit_hal)
387 });
388
389 Ok(RecursionReceipt {
390 seal,
391 output: self.output.clone(),
392 })
393 }
394
395 fn preflight(&mut self) -> Result<exec::MachineContext> {
396 scope!("preflight");
397
398 let mut machine = exec::MachineContext::new(take(&mut self.input));
399 let mut preflight = preflight::Preflight::new(&mut machine);
400 let size = (1 << self.program.po2) - ZK_CYCLES;
401
402 for (cycle, row) in self.program.code_by_row().enumerate() {
403 let ctx = CircuitStepContext { cycle, size };
404
405 preflight.set_top(&ctx, row)?
406 }
407
408 let zero_row = vec![BabyBearElem::ZERO; self.program.code_size];
410 for cycle in self.program.code_rows()..size {
411 let ctx = CircuitStepContext { cycle, size };
412
413 preflight.set_top(&ctx, &zero_row)?
414 }
415
416 self.split_points = preflight.split_points;
417 self.split_points.push(size);
418 self.output = preflight.output;
419 (machine.iop_reads, machine.byte_reads) = (preflight.iop_reads, preflight.byte_reads);
420 Ok(machine)
421 }
422}