risc0_circuit_recursion/prove/
mod.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Prover implementation for the recursion VM.
16//!
17//! This module contains the recursion [Prover].
18
19mod hal;
20mod preflight;
21mod program;
22mod witgen;
23pub mod zkr;
24
25use std::{collections::VecDeque, fmt::Debug, rc::Rc};
26
27use anyhow::Result;
28use cfg_if::cfg_if;
29use risc0_core::scope;
30use risc0_zkp::{
31    adapter::{CircuitInfo, PROOF_SYSTEM_INFO},
32    core::digest::Digest,
33    field::{
34        baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
35        Elem as _,
36    },
37    hal::{Buffer, CircuitHal, Hal},
38};
39use serde::{Deserialize, Serialize};
40
41use self::{
42    hal::{CircuitAccumulator, CircuitWitnessGenerator},
43    preflight::Preflight,
44    witgen::WitnessGenerator,
45};
46use crate::{
47    taps::TAPSET, CircuitImpl, REGISTER_GROUP_ACCUM, REGISTER_GROUP_CTRL, REGISTER_GROUP_DATA,
48};
49
50pub use self::program::Program;
51
// TODO: Derive this value from the circuit definition automatically, provided that can be done
// without messing up bootstrap dependencies.
/// Size of the code group in the taps of the recursion circuit.
const RECURSION_CODE_SIZE: usize = 23;
56
57#[derive(Clone, Debug, Serialize, Deserialize)]
58#[non_exhaustive]
59pub struct RecursionReceipt {
60    pub seal: Vec<u32>,
61    pub output: Vec<u32>,
62}
63
64impl RecursionReceipt {
65    /// Total number of bytes used by the seal of this receipt.
66    pub fn seal_size(&self) -> usize {
67        core::mem::size_of_val(self.seal.as_slice())
68    }
69
70    /// Allocates a [VecDeque] and copies the output stream into it for decoding.
71    pub fn out_stream(&self) -> VecDeque<u32> {
72        let mut vec: VecDeque<u32> = VecDeque::new();
73        vec.extend(self.output.iter());
74        vec
75    }
76}
77
78pub trait RecursionProver {
79    fn prove(&self, program: Program, input: VecDeque<u32>) -> Result<RecursionReceipt>;
80}
81
82pub fn recursion_prover(hashfn: &str) -> Result<Box<dyn RecursionProver>> {
83    cfg_if! {
84        if #[cfg(feature = "cuda")] {
85            self::hal::cuda::recursion_prover(hashfn)
86        // } else if #[cfg(any(all(target_os = "macos", target_arch = "aarch64"), target_os = "ios"))] {
87        // self::hal::metal::recursion_prover(hashfn)
88        } else {
89            self::hal::cpu::recursion_prover(hashfn)
90        }
91    }
92}
93
94/// Prover for the recursion circuit.
95pub struct Prover {
96    program: Program,
97    hashfn: String,
98    input: VecDeque<u32>,
99}
100
/// Kinds of digests recognized by the recursion program language.
///
/// The kind determines how a digest is encoded into input words when it is added to a
/// [Prover] via `add_input_digest`.
// NOTE: Default is additionally a recognized type in the recursion program language. It's not
// yet supported here because some of the code in this module assumes Poseidon2 is Default.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum DigestKind {
    /// Poseidon2 digest; its words are added to the input as-is.
    Poseidon2,
    /// SHA-256 digest; each word is split into two 16-bit halves before being added.
    Sha256,
}
109
110impl Prover {
111    /// Creates a new prover with the given recursion program.
112    pub fn new(program: Program, hashfn: &str) -> Self {
113        Self {
114            program,
115            hashfn: hashfn.to_string(),
116            input: VecDeque::new(),
117        }
118    }
119
120    /// Add a set of u32s to the input for the recursion program.
121    pub fn add_input(&mut self, input: &[u32]) {
122        self.input.extend(input);
123    }
124
125    /// Add a digest to the input for the recursion program.
126    pub fn add_input_digest(&mut self, digest: &Digest, kind: DigestKind) {
127        match kind {
128            // Poseidon2 digests consist of  BabyBear field elems and do not need to be split.
129            DigestKind::Poseidon2 => self.add_input(digest.as_words()),
130            // SHA-256 digests need to be split into 16-bit half words to avoid overflowing.
131            DigestKind::Sha256 => self.add_input(bytemuck::cast_slice(
132                &digest
133                    .as_words()
134                    .iter()
135                    .copied()
136                    .flat_map(|x| [x & 0xffff, x >> 16])
137                    .map(BabyBearElem::new)
138                    .collect::<Vec<_>>(),
139            )),
140        }
141    }
142
143    /// Run the prover, producing a receipt of execution for the recursion circuit over the loaded
144    /// program and input.
145    pub fn run(&mut self) -> Result<RecursionReceipt> {
146        let prover = recursion_prover(&self.hashfn)?;
147        prover.prove(self.program.clone(), self.input.clone())
148    }
149}
150
151pub(crate) struct RecursionProverImpl<H, C>
152where
153    H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
154    C: CircuitHal<H> + CircuitWitnessGenerator<H>,
155{
156    hal: Rc<H>,
157    circuit_hal: Rc<C>,
158}
159
160impl<H, C> RecursionProver for RecursionProverImpl<H, C>
161where
162    H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
163    C: CircuitHal<H> + CircuitWitnessGenerator<H> + CircuitAccumulator<H>,
164{
165    fn prove(&self, program: Program, input: VecDeque<u32>) -> Result<RecursionReceipt> {
166        scope!("prove");
167
168        let preflight = self.preflight(&program, input)?;
169
170        let witgen = WitnessGenerator::new(
171            self.hal.as_ref(),
172            self.circuit_hal.as_ref(),
173            &program,
174            &preflight,
175        )?;
176
177        let global = &witgen.global;
178
179        let seal = scope!("prove", {
180            let mut prover = risc0_zkp::prove::Prover::new(self.hal.as_ref(), TAPSET);
181            let hashfn = &self.hal.get_hash_suite().hashfn;
182
183            let mix = scope!("main", {
184                // At the start of the protocol, seed the Fiat-Shamir transcript with context information
185                // about the proof system and circuit.
186                prover
187                    .iop()
188                    .commit(&hashfn.hash_elem_slice(&PROOF_SYSTEM_INFO.encode()));
189                prover
190                    .iop()
191                    .commit(&hashfn.hash_elem_slice(&CircuitImpl::CIRCUIT_INFO.encode()));
192
193                // Concat globals and po2 into a vector.
194                let global_len = global.size();
195                let mut header = vec![BabyBearElem::ZERO; global_len + 1];
196                global.view_mut(|view| {
197                    for (i, elem) in view.iter_mut().enumerate() {
198                        *elem = elem.valid_or_zero();
199                        header[i] = *elem;
200                    }
201                    header[global_len] = BabyBearElem::new_raw(program.po2 as u32);
202                });
203
204                let header_digest = hashfn.hash_elem_slice(&header);
205                prover.iop().commit(&header_digest);
206                prover.iop().write_field_elem_slice(header.as_slice());
207                prover.set_po2(program.po2);
208
209                prover.commit_group(REGISTER_GROUP_CTRL, &witgen.ctrl);
210                prover.commit_group(REGISTER_GROUP_DATA, &witgen.data);
211
212                // Make the mixing values
213                let mix: [BabyBearElem; CircuitImpl::MIX_SIZE] =
214                    std::array::from_fn(|_| prover.iop().random_elem());
215
216                let mix = witgen.accum(&self.hal, self.circuit_hal.as_ref(), &mix)?;
217
218                prover.commit_group(REGISTER_GROUP_ACCUM, &witgen.accum);
219
220                mix
221            });
222
223            prover.finalize(&[&mix, global], self.circuit_hal.as_ref())
224        });
225
226        Ok(RecursionReceipt {
227            seal,
228            output: preflight.output,
229        })
230    }
231}
232
233impl<H, C> RecursionProverImpl<H, C>
234where
235    H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
236    C: CircuitHal<H> + CircuitWitnessGenerator<H>,
237{
238    pub fn new(hal: Rc<H>, circuit_hal: Rc<C>) -> Self {
239        Self { hal, circuit_hal }
240    }
241
242    fn preflight(&self, program: &Program, input: VecDeque<u32>) -> Result<Preflight> {
243        scope!("preflight");
244
245        let mut preflight = Preflight::new(input);
246        for (cycle, row) in program.code_by_row().enumerate() {
247            preflight.step(cycle, row)?
248        }
249
250        Ok(preflight)
251    }
252}