risc0_circuit_recursion/prove/
program.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::Result;
16
17use risc0_zkp::{
18    core::{digest::Digest, hash::HashSuite},
19    field::baby_bear::{BabyBear, BabyBearElem},
20    hal::{self, Hal},
21    prove::poly_group::PolyGroup,
22    ZK_CYCLES,
23};
24
25use super::RECURSION_CODE_SIZE;
26
/// A Program for the recursion circuit (e.g. lift_20 or join).
///
/// The recursion circuit is an application specific virtual machine with a limited instruction
/// set, no control flow operations, and a write-once memory tape. Although it is not general
/// purpose, it can load and execute a program, similar to the rv32im zkVM.
///
/// Programs for the recursion circuit are loaded into the control columns, which is a set of
/// public columns in the witness. Programs are therefore identified by their control ID, which is
/// similar but not the same as the image ID used to identify rv32im programs.
///
/// The code is stored flattened in row-major order: every `code_size` consecutive elements of
/// `code` form one row of the code group.
#[derive(Clone)]
pub struct Program {
    /// The code of the program, encoded as Baby Bear field elements.
    ///
    /// Rows are stored back to back, `code_size` elements per row.
    pub code: Vec<BabyBearElem>,

    /// The number of code columns.
    ///
    /// This is the width of a single row, i.e. the chunk length used when iterating over `code`
    /// row by row.
    pub code_size: usize,

    /// 1 << po2 is the number of cycles executed.
    ///
    /// The control ID computation allocates `(1 << po2) * code_size` elements for the code group.
    pub po2: usize,
}
47
48impl Program {
49    /// Create a [Program] from a stream of data encoded by Zirgen.
50    pub fn from_encoded(encoded: &[u32], po2: usize) -> Self {
51        let prog = Self {
52            code: encoded.iter().copied().map(BabyBearElem::from).collect(),
53            code_size: RECURSION_CODE_SIZE,
54            po2,
55        };
56        assert_eq!(prog.code.len() % RECURSION_CODE_SIZE, 0);
57        assert!(prog.code.len() <= (RECURSION_CODE_SIZE * ((1 << po2) - ZK_CYCLES)));
58        prog
59    }
60
61    /// Total number of rows in the code group for this program.
62    pub fn code_rows(&self) -> usize {
63        self.code.len() / self.code_size
64    }
65
66    /// An iterator over the rows of the code group.
67    pub fn code_by_row(&self) -> impl Iterator<Item = &[BabyBearElem]> {
68        self.code.as_slice().chunks(self.code_size)
69    }
70
71    /// Given a [Program] for the recursion circuit, compute the control ID as the FRI Merkle root
72    /// of the code group. This uniquely identifies the program running on the recursion circuit
73    /// (e.g. lift_20 or join)
74    pub fn compute_control_id(&self, hash_suite: HashSuite<BabyBear>) -> Result<Digest> {
75        #[cfg(feature = "cuda")]
76        let digest =
77            self.compute_control_id_inner(&hal::cuda::CudaHal::new_from_hash_suite(hash_suite)?);
78
79        #[cfg(not(feature = "cuda"))]
80        let digest = self.compute_control_id_inner(&hal::cpu::CpuHal::new(hash_suite));
81
82        Ok(digest)
83    }
84
85    fn compute_control_id_inner(&self, hal: &impl Hal<Elem = BabyBearElem>) -> Digest {
86        let cycles = 1 << self.po2;
87
88        let mut code = vec![BabyBearElem::default(); cycles * self.code_size];
89
90        for (cycle, row) in self.code_by_row().enumerate() {
91            for (i, elem) in row.iter().enumerate() {
92                code[cycles * i + cycle] = *elem;
93            }
94        }
95        let coeffs = hal.copy_from_elem("coeffs", &code);
96        // Do interpolate & shift
97        hal.batch_interpolate_ntt(&coeffs, self.code_size);
98        hal.zk_shift(&coeffs, self.code_size);
99        // Make the poly-group & extract the root
100        let code_group = PolyGroup::new(hal, coeffs, self.code_size, cycles, "code");
101        let root = *code_group.merkle.root();
102        tracing::trace!("Computed recursion code: {root:?}");
103        root
104    }
105}