risc0_circuit_recursion/prove/
mod.rs

1// Copyright 2024 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Prover implementation for the recursion VM.
16//!
17//! This module contains the recursion [Prover].
18
19mod exec;
20mod plonk;
21mod preflight;
22mod program;
23pub mod zkr;
24
25use std::{collections::VecDeque, fmt::Debug, mem::take, rc::Rc};
26
27use crate::{
28    cpu::CpuCircuitHal, CircuitImpl, CIRCUIT, REGISTER_GROUP_ACCUM, REGISTER_GROUP_CTRL,
29    REGISTER_GROUP_DATA,
30};
31use anyhow::{bail, Result};
32use rand::thread_rng;
33use risc0_core::scope;
34use risc0_zkp::{
35    adapter::{CircuitInfo, CircuitStepContext, TapsProvider, PROOF_SYSTEM_INFO},
36    core::{digest::Digest, hash::poseidon2::Poseidon2HashSuite},
37    field::{
38        baby_bear::{BabyBear, BabyBearElem, BabyBearExtElem},
39        Elem,
40    },
41    hal::{cpu::CpuHal, AccumPreflight, CircuitHal, Hal},
42    prove::adapter::ProveAdapter,
43    ZK_CYCLES,
44};
45use serde::{Deserialize, Serialize};
46
47use self::exec::RecursionExecutor;
48pub use self::program::Program;
49
50/// A pair of [Hal] and [CircuitHal].
51#[derive(Clone)]
52pub struct HalPair<H, C>
53where
54    H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
55    C: CircuitHal<H>,
56{
57    /// A [Hal] implementation.
58    pub hal: Rc<H>,
59
60    /// An [CircuitHal] implementation.
61    pub circuit_hal: Rc<C>,
62}
63
// TODO: Generate this constant from the circuit automatically, in a way that does not
// mess up bootstrap dependencies. Until then the value is maintained by hand.
/// Size of the code group in the taps of the recursion circuit.
const RECURSION_CODE_SIZE: usize = 23;
68
69#[derive(Clone, Debug, Serialize, Deserialize)]
70pub struct RecursionReceipt {
71    pub seal: Vec<u32>,
72    pub output: Vec<u32>,
73}
74
75impl RecursionReceipt {
76    /// Total number of bytes used by the seal of this receipt.
77    pub fn seal_size(&self) -> usize {
78        core::mem::size_of_val(self.seal.as_slice())
79    }
80}
81
82/// Prover for the recursion circuit.
83pub struct Prover {
84    program: Program,
85    hashfn: String,
86    input: VecDeque<u32>,
87    split_points: Vec<usize>,
88    output: Vec<u32>,
89}
90
#[cfg(feature = "cuda")]
mod cuda {
    pub use crate::cuda::{CudaCircuitHalPoseidon2, CudaCircuitHalSha256};
    pub use risc0_zkp::hal::cuda::{CudaHalPoseidon2, CudaHalSha256};

    use super::{HalPair, Rc};

    /// Builds the CUDA HAL pair that uses the SHA-256 variants of the HAL and circuit HAL.
    pub fn sha256_hal_pair() -> HalPair<CudaHalSha256, CudaCircuitHalSha256> {
        let hal = Rc::new(CudaHalSha256::new());
        HalPair {
            // The circuit HAL receives its own handle to the same underlying HAL.
            circuit_hal: Rc::new(CudaCircuitHalSha256::new(Rc::clone(&hal))),
            hal,
        }
    }

    /// Builds the CUDA HAL pair that uses the Poseidon2 variants of the HAL and circuit HAL.
    pub fn poseidon2_hal_pair() -> HalPair<CudaHalPoseidon2, CudaCircuitHalPoseidon2> {
        let hal = Rc::new(CudaHalPoseidon2::new());
        HalPair {
            // The circuit HAL receives its own handle to the same underlying HAL.
            circuit_hal: Rc::new(CudaCircuitHalPoseidon2::new(Rc::clone(&hal))),
            hal,
        }
    }
}
110
#[cfg(any(all(target_os = "macos", target_arch = "aarch64"), target_os = "ios"))]
mod metal {
    pub use crate::metal::MetalCircuitHal;
    pub use risc0_zkp::hal::metal::{
        MetalHalPoseidon2, MetalHalSha256, MetalHashPoseidon2, MetalHashSha256,
    };

    use super::{HalPair, Rc};

    /// Builds the Metal HAL pair that uses the SHA-256 variants of the HAL and circuit HAL.
    pub fn sha256_hal_pair() -> HalPair<MetalHalSha256, MetalCircuitHal<MetalHashSha256>> {
        let hal = Rc::new(MetalHalSha256::new());
        HalPair {
            // The circuit HAL receives its own handle to the same underlying HAL.
            circuit_hal: Rc::new(MetalCircuitHal::<MetalHashSha256>::new(Rc::clone(&hal))),
            hal,
        }
    }

    /// Builds the Metal HAL pair that uses the Poseidon2 variants of the HAL and circuit HAL.
    pub fn poseidon2_hal_pair() -> HalPair<MetalHalPoseidon2, MetalCircuitHal<MetalHashPoseidon2>> {
        let hal = Rc::new(MetalHalPoseidon2::new());
        HalPair {
            // The circuit HAL receives its own handle to the same underlying HAL.
            circuit_hal: Rc::new(MetalCircuitHal::<MetalHashPoseidon2>::new(Rc::clone(&hal))),
            hal,
        }
    }
}
132
133mod cpu {
134    use risc0_zkp::core::hash::{poseidon_254::Poseidon254HashSuite, sha::Sha256HashSuite};
135
136    use super::{
137        BabyBear, CircuitImpl, CpuCircuitHal, CpuHal, HalPair, Poseidon2HashSuite, Rc, CIRCUIT,
138    };
139
140    #[allow(dead_code)]
141    pub fn sha256_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
142        let hal = Rc::new(CpuHal::new(Sha256HashSuite::new_suite()));
143        let circuit_hal = Rc::new(CpuCircuitHal::new(&CIRCUIT));
144        HalPair { hal, circuit_hal }
145    }
146
147    #[allow(dead_code)]
148    pub fn poseidon2_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
149        let hal = Rc::new(CpuHal::new(Poseidon2HashSuite::new_suite()));
150        let circuit_hal = Rc::new(CpuCircuitHal::new(&CIRCUIT));
151        HalPair { hal, circuit_hal }
152    }
153
154    #[allow(dead_code)]
155    pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>>
156    {
157        let hal = Rc::new(CpuHal::new(Poseidon254HashSuite::new_suite()));
158        let circuit_hal = Rc::new(CpuCircuitHal::new(&CIRCUIT));
159        HalPair { hal, circuit_hal }
160    }
161}
162
163cfg_if::cfg_if! {
164    if #[cfg(feature = "cuda")] {
165        /// TODO
166        #[allow(dead_code)]
167        pub fn sha256_hal_pair() -> HalPair<cuda::CudaHalSha256, cuda::CudaCircuitHalSha256> {
168            cuda::sha256_hal_pair()
169        }
170
171        /// TODO
172        #[allow(dead_code)]
173        pub fn poseidon2_hal_pair() -> HalPair<cuda::CudaHalPoseidon2, cuda::CudaCircuitHalPoseidon2> {
174            cuda::poseidon2_hal_pair()
175        }
176
177        /// TODO
178        #[allow(dead_code)]
179        pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
180            cpu::poseidon254_hal_pair()
181        }
182    } else if #[cfg(any(all(target_os = "macos", target_arch = "aarch64"), target_os = "ios"))] {
183        /// TODO
184        #[allow(dead_code)]
185        pub fn sha256_hal_pair() -> HalPair<metal::MetalHalSha256, metal::MetalCircuitHal<metal::MetalHashSha256>> {
186            metal::sha256_hal_pair()
187        }
188
189        /// TODO
190        #[allow(dead_code)]
191        pub fn poseidon2_hal_pair() -> HalPair<metal::MetalHalPoseidon2, metal::MetalCircuitHal<metal::MetalHashPoseidon2>> {
192            metal::poseidon2_hal_pair()
193        }
194
195        /// TODO
196        #[allow(dead_code)]
197        pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
198            cpu::poseidon254_hal_pair()
199        }
200    } else {
201        /// TODO
202        #[allow(dead_code)]
203        pub fn sha256_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
204            cpu::sha256_hal_pair()
205        }
206
207        /// TODO
208        #[allow(dead_code)]
209        pub fn poseidon2_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
210            cpu::poseidon2_hal_pair()
211        }
212
213        /// TODO
214        #[allow(dead_code)]
215        pub fn poseidon254_hal_pair() -> HalPair<CpuHal<BabyBear>, CpuCircuitHal<'static, CircuitImpl>> {
216            cpu::poseidon254_hal_pair()
217        }
218    }
219}
220
/// Kinds of digests recognized by the recursion program language.
///
/// The kind decides how a digest is encoded when it is queued as program input.
// NOTE: Default is additionally a recognized type in the recursion program language. It's not
// yet supported here because some of the code in this module assumes Poseidon2 is Default.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DigestKind {
    /// A Poseidon2 digest; its words are passed to the program unchanged.
    Poseidon2,
    /// A SHA-256 digest; each word is split into two 16-bit halves before being passed on.
    Sha256,
}
229
230impl Prover {
231    /// Creates a new prover with the given recursion program.  This
232    /// is a low-level interface; for the zkVM, prefer to use
233    /// risc0_zkvm::host::recursion::prove::Prover.
234    pub fn new(program: Program, hashfn: &str) -> Self {
235        Self {
236            program,
237            hashfn: hashfn.to_string(),
238            input: VecDeque::new(),
239            split_points: Vec::new(),
240            output: Vec::new(),
241        }
242    }
243
244    /// Add a set of u32s to the input for the recursion program.
245    pub fn add_input(&mut self, input: &[u32]) {
246        self.input.extend(input);
247    }
248
249    /// Add a digest to the input for the recursion program.
250    pub fn add_input_digest(&mut self, digest: &Digest, kind: DigestKind) {
251        match kind {
252            // Poseidon2 digests consist of  BabyBear field elems and do not need to be split.
253            DigestKind::Poseidon2 => self.add_input(digest.as_words()),
254            // SHA-256 digests need to be split into 16-bit half words to avoid overflowing.
255            DigestKind::Sha256 => self.add_input(bytemuck::cast_slice(
256                &digest
257                    .as_words()
258                    .iter()
259                    .copied()
260                    .flat_map(|x| [x & 0xffff, x >> 16])
261                    .map(BabyBearElem::new)
262                    .collect::<Vec<_>>(),
263            )),
264        }
265    }
266
267    /// Run the prover, producing a receipt of execution for the recursion circuit over the loaded
268    /// program and input.
269    pub fn run(&mut self) -> Result<RecursionReceipt> {
270        // NOTE: Code is repeated across match arms to satisfy generics.
271        match self.hashfn.as_ref() {
272            "poseidon2" => {
273                let hal_pair = poseidon2_hal_pair();
274                let (hal, circuit_hal) = (hal_pair.hal.as_ref(), hal_pair.circuit_hal.as_ref());
275                self.run_with_hal(hal, circuit_hal)
276            }
277            "poseidon_254" => {
278                let hal_pair = poseidon254_hal_pair();
279                let (hal, circuit_hal) = (hal_pair.hal.as_ref(), hal_pair.circuit_hal.as_ref());
280                self.run_with_hal(hal, circuit_hal)
281            }
282            "sha-256" => {
283                let hal_pair = sha256_hal_pair();
284                let (hal, circuit_hal) = (hal_pair.hal.as_ref(), hal_pair.circuit_hal.as_ref());
285                self.run_with_hal(hal, circuit_hal)
286            }
287            _ => bail!("no hal found for {}", self.hashfn),
288        }
289    }
290
291    /// Run the prover, producing a receipt of execution for the recursion circuit over the loaded
292    /// program and input, using the specified HAL.
293    pub fn run_with_hal<H, C>(&mut self, hal: &H, circuit_hal: &C) -> Result<RecursionReceipt>
294    where
295        H: Hal<Field = BabyBear, Elem = BabyBearElem, ExtElem = BabyBearExtElem>,
296        C: CircuitHal<H>,
297    {
298        scope!("run_with_hal");
299
300        let machine_ctx = self.preflight()?;
301
302        let split_points = core::mem::take(&mut self.split_points);
303
304        let mut executor = scope!("witgen", {
305            let mut executor =
306                RecursionExecutor::new(&CIRCUIT, &self.program, machine_ctx, split_points);
307            executor.run()?;
308            Result::<RecursionExecutor, anyhow::Error>::Ok(executor)
309        })?;
310
311        let seal = scope!("prove", {
312            let mut adapter = ProveAdapter::new(&mut executor.executor);
313            let mut prover = risc0_zkp::prove::Prover::new(hal, CIRCUIT.get_taps());
314            let hashfn = Rc::clone(&hal.get_hash_suite().hashfn);
315
316            let (mix, io) = scope!("main", {
317                // At the start of the protocol, seed the Fiat-Shamir transcript with context information
318                // about the proof system and circuit.
319                prover
320                    .iop()
321                    .commit(&hashfn.hash_elem_slice(&PROOF_SYSTEM_INFO.encode()));
322                prover
323                    .iop()
324                    .commit(&hashfn.hash_elem_slice(&CircuitImpl::CIRCUIT_INFO.encode()));
325
326                adapter.execute(prover.iop(), hal);
327
328                prover.set_po2(adapter.po2() as usize);
329
330                let ctrl = hal.copy_from_elem("ctrl", &adapter.get_code().as_slice());
331                prover.commit_group(REGISTER_GROUP_CTRL, &ctrl);
332
333                let data = hal.copy_from_elem("data", &adapter.get_data().as_slice());
334                prover.commit_group(REGISTER_GROUP_DATA, &data);
335
336                // Make the mixing values
337                let mix = scope!("alloc+copy(mix)", {
338                    let mix: Vec<_> = (0..CircuitImpl::MIX_SIZE)
339                        .map(|_| prover.iop().random_elem())
340                        .collect();
341                    hal.copy_from_elem("mix", mix.as_slice())
342                });
343
344                let steps = adapter.get_steps();
345
346                let accum = scope!(
347                    "alloc(accum)",
348                    hal.alloc_elem_init(
349                        "accum",
350                        steps * CIRCUIT.accum_size(),
351                        BabyBearElem::INVALID,
352                    )
353                );
354
355                // Add random noise to end of accum
356                scope!("noise(accum)", {
357                    let mut rng = thread_rng();
358                    let noise =
359                        vec![BabyBearElem::random(&mut rng); ZK_CYCLES * CIRCUIT.accum_size()];
360                    hal.eltwise_copy_elem_slice(
361                        &accum,
362                        &noise,
363                        CIRCUIT.accum_size(), // from_rows
364                        ZK_CYCLES,            // from_cols
365                        0,                    // from_offset
366                        ZK_CYCLES,            // from_stride
367                        steps - ZK_CYCLES,    // into_offset
368                        steps,                // into_stride
369                    );
370                });
371
372                let io = scope!(
373                    "copy(io)",
374                    hal.copy_from_elem("io", &adapter.get_io().as_slice())
375                );
376
377                // The recursion circuit doesn't make use of the preflight.
378                let preflight = AccumPreflight::default();
379                circuit_hal.accumulate(&preflight, &ctrl, &io, &data, &mix, &accum, steps);
380
381                prover.commit_group(REGISTER_GROUP_ACCUM, &accum);
382
383                (mix, io)
384            });
385
386            prover.finalize(&[&mix, &io], circuit_hal)
387        });
388
389        Ok(RecursionReceipt {
390            seal,
391            output: self.output.clone(),
392        })
393    }
394
395    fn preflight(&mut self) -> Result<exec::MachineContext> {
396        scope!("preflight");
397
398        let mut machine = exec::MachineContext::new(take(&mut self.input));
399        let mut preflight = preflight::Preflight::new(&mut machine);
400        let size = (1 << self.program.po2) - ZK_CYCLES;
401
402        for (cycle, row) in self.program.code_by_row().enumerate() {
403            let ctx = CircuitStepContext { cycle, size };
404
405            preflight.set_top(&ctx, row)?
406        }
407
408        // TODO: is this necessary?
409        let zero_row = vec![BabyBearElem::ZERO; self.program.code_size];
410        for cycle in self.program.code_rows()..size {
411            let ctx = CircuitStepContext { cycle, size };
412
413            preflight.set_top(&ctx, &zero_row)?
414        }
415
416        self.split_points = preflight.split_points;
417        self.split_points.push(size);
418        self.output = preflight.output;
419        (machine.iop_reads, machine.byte_reads) = (preflight.iop_reads, preflight.byte_reads);
420        Ok(machine)
421    }
422}