risc0_circuit_recursion/prove/
mod.rs1mod hal;
20mod preflight;
21mod program;
22mod witgen;
23pub mod zkr;
24
25use std::{collections::VecDeque, fmt::Debug, rc::Rc};
26
27use anyhow::Result;
28use cfg_if::cfg_if;
29use risc0_core::scope;
30use risc0_zkp::{
31 adapter::{CircuitInfo, PROOF_SYSTEM_INFO},
32 core::digest::Digest,
33 field::{
34 baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
35 Elem as _,
36 },
37 hal::{Buffer, CircuitHal, Hal},
38};
39use serde::{Deserialize, Serialize};
40
41use self::{
42 hal::{CircuitAccumulator, CircuitWitnessGenerator},
43 preflight::Preflight,
44 witgen::WitnessGenerator,
45};
46use crate::{
47 taps::TAPSET, CircuitImpl, REGISTER_GROUP_ACCUM, REGISTER_GROUP_CTRL, REGISTER_GROUP_DATA,
48};
49
50pub use self::program::Program;
51
/// Code size constant for the recursion circuit (fixed at 23).
// NOTE(review): presumably the number of columns in the code register group;
// nothing in this file reads it, so confirm against the circuit definition.
const RECURSION_CODE_SIZE: usize = 23;
56
57#[derive(Clone, Debug, Serialize, Deserialize)]
58#[non_exhaustive]
59pub struct RecursionReceipt {
60 pub seal: Vec<u32>,
61 pub output: Vec<u32>,
62}
63
64impl RecursionReceipt {
65 pub fn seal_size(&self) -> usize {
67 core::mem::size_of_val(self.seal.as_slice())
68 }
69
70 pub fn out_stream(&self) -> VecDeque<u32> {
72 let mut vec: VecDeque<u32> = VecDeque::new();
73 vec.extend(self.output.iter());
74 vec
75 }
76}
77
78pub trait RecursionProver {
79 fn prove(&self, program: Program, input: VecDeque<u32>) -> Result<RecursionReceipt>;
80}
81
82pub fn recursion_prover(hashfn: &str) -> Result<Box<dyn RecursionProver>> {
83 cfg_if! {
84 if #[cfg(feature = "cuda")] {
85 self::hal::cuda::recursion_prover(hashfn)
86 } else {
89 self::hal::cpu::recursion_prover(hashfn)
90 }
91 }
92}
93
94pub struct Prover {
96 program: Program,
97 hashfn: String,
98 input: VecDeque<u32>,
99}
100
/// Selects how a digest is encoded when it is appended to the prover input
/// via `Prover::add_input_digest`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum DigestKind {
    /// The digest words are appended to the input unchanged.
    Poseidon2,
    /// Each digest word is split into its low and high 16-bit halves, and
    /// each half is converted to a `BabyBearElem` before being appended.
    Sha256,
}
109
110impl Prover {
111 pub fn new(program: Program, hashfn: &str) -> Self {
113 Self {
114 program,
115 hashfn: hashfn.to_string(),
116 input: VecDeque::new(),
117 }
118 }
119
120 pub fn add_input(&mut self, input: &[u32]) {
122 self.input.extend(input);
123 }
124
125 pub fn add_input_digest(&mut self, digest: &Digest, kind: DigestKind) {
127 match kind {
128 DigestKind::Poseidon2 => self.add_input(digest.as_words()),
130 DigestKind::Sha256 => self.add_input(bytemuck::cast_slice(
132 &digest
133 .as_words()
134 .iter()
135 .copied()
136 .flat_map(|x| [x & 0xffff, x >> 16])
137 .map(BabyBearElem::new)
138 .collect::<Vec<_>>(),
139 )),
140 }
141 }
142
143 pub fn run(&mut self) -> Result<RecursionReceipt> {
146 let prover = recursion_prover(&self.hashfn)?;
147 prover.prove(self.program.clone(), self.input.clone())
148 }
149}
150
151pub(crate) struct RecursionProverImpl<H, C>
152where
153 H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
154 C: CircuitHal<H> + CircuitWitnessGenerator<H>,
155{
156 hal: Rc<H>,
157 circuit_hal: Rc<C>,
158}
159
160impl<H, C> RecursionProver for RecursionProverImpl<H, C>
161where
162 H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
163 C: CircuitHal<H> + CircuitWitnessGenerator<H> + CircuitAccumulator<H>,
164{
165 fn prove(&self, program: Program, input: VecDeque<u32>) -> Result<RecursionReceipt> {
166 scope!("prove");
167
168 let preflight = self.preflight(&program, input)?;
169
170 let witgen = WitnessGenerator::new(
171 self.hal.as_ref(),
172 self.circuit_hal.as_ref(),
173 &program,
174 &preflight,
175 )?;
176
177 let global = &witgen.global;
178
179 let seal = scope!("prove", {
180 let mut prover = risc0_zkp::prove::Prover::new(self.hal.as_ref(), TAPSET);
181 let hashfn = &self.hal.get_hash_suite().hashfn;
182
183 let mix = scope!("main", {
184 prover
187 .iop()
188 .commit(&hashfn.hash_elem_slice(&PROOF_SYSTEM_INFO.encode()));
189 prover
190 .iop()
191 .commit(&hashfn.hash_elem_slice(&CircuitImpl::CIRCUIT_INFO.encode()));
192
193 let global_len = global.size();
195 let mut header = vec![BabyBearElem::ZERO; global_len + 1];
196 global.view_mut(|view| {
197 for (i, elem) in view.iter_mut().enumerate() {
198 *elem = elem.valid_or_zero();
199 header[i] = *elem;
200 }
201 header[global_len] = BabyBearElem::new_raw(program.po2 as u32);
202 });
203
204 let header_digest = hashfn.hash_elem_slice(&header);
205 prover.iop().commit(&header_digest);
206 prover.iop().write_field_elem_slice(header.as_slice());
207 prover.set_po2(program.po2);
208
209 prover.commit_group(REGISTER_GROUP_CTRL, &witgen.ctrl);
210 prover.commit_group(REGISTER_GROUP_DATA, &witgen.data);
211
212 let mix: [BabyBearElem; CircuitImpl::MIX_SIZE] =
214 std::array::from_fn(|_| prover.iop().random_elem());
215
216 let mix = witgen.accum(&self.hal, self.circuit_hal.as_ref(), &mix)?;
217
218 prover.commit_group(REGISTER_GROUP_ACCUM, &witgen.accum);
219
220 mix
221 });
222
223 prover.finalize(&[&mix, global], self.circuit_hal.as_ref())
224 });
225
226 Ok(RecursionReceipt {
227 seal,
228 output: preflight.output,
229 })
230 }
231}
232
233impl<H, C> RecursionProverImpl<H, C>
234where
235 H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
236 C: CircuitHal<H> + CircuitWitnessGenerator<H>,
237{
238 pub fn new(hal: Rc<H>, circuit_hal: Rc<C>) -> Self {
239 Self { hal, circuit_hal }
240 }
241
242 fn preflight(&self, program: &Program, input: VecDeque<u32>) -> Result<Preflight> {
243 scope!("preflight");
244
245 let mut preflight = Preflight::new(input);
246 for (cycle, row) in program.code_by_row().enumerate() {
247 preflight.step(cycle, row)?
248 }
249
250 Ok(preflight)
251 }
252}