// Source file: provekit_prover/lib.rs

1#[cfg(test)]
2use crate::r1cs::R1CSSolver;
3use {
4    crate::{
5        r1cs::{CompressedLayers, CompressedR1CS},
6        whir_r1cs::WhirR1CSProver,
7    },
8    acir::native_types::{Witness, WitnessMap},
9    anyhow::{Context, Result},
10    provekit_common::{
11        utils::noir_to_native, FieldElement, NoirElement, NoirProof, Prover, PublicInputs,
12        TranscriptSponge,
13    },
14    std::mem::size_of,
15    tracing::{debug, info_span, instrument},
16    whir::transcript::ProverState,
17};
18#[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
19use {
20    bn254_blackbox_solver::Bn254BlackBoxSolver,
21    nargo::foreign_calls::DefaultForeignCallBuilder,
22    noir_artifact_cli::fs::inputs::read_inputs_from_file,
23    provekit_common::{Format, InputMap},
24    std::path::Path,
25};
26
27mod r1cs;
28mod whir_r1cs;
29mod witness;
30
31pub trait Prove {
32    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
33    fn generate_witness(&mut self, input_map: InputMap) -> Result<WitnessMap<NoirElement>>;
34
35    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
36    fn prove(self, input_map: InputMap) -> Result<NoirProof>;
37
38    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
39    fn prove_with_toml(self, prover_toml: impl AsRef<Path>) -> Result<NoirProof>;
40
41    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
42    fn prove_with_inputs(self, inputs: &str, format: Format) -> Result<NoirProof>;
43
44    fn prove_with_witness(self, witness: WitnessMap<NoirElement>) -> Result<NoirProof>;
45}
46
47impl Prove for Prover {
48    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
49    #[instrument(skip_all)]
50    fn generate_witness(&mut self, input_map: InputMap) -> Result<WitnessMap<NoirElement>> {
51        let solver = Bn254BlackBoxSolver::default();
52        let mut output_buffer = Vec::new();
53        let mut foreign_call_executor = DefaultForeignCallBuilder {
54            output:       &mut output_buffer,
55            enable_mocks: false,
56            resolver_url: None,
57            root_path:    None,
58            package_name: None,
59        }
60        .build();
61
62        let initial_witness = self.witness_generator.abi().encode(&input_map, None)?;
63
64        let mut witness_stack = nargo::ops::execute_program(
65            &self.program,
66            initial_witness,
67            &solver,
68            &mut foreign_call_executor,
69        )?;
70
71        Ok(witness_stack
72            .pop()
73            .context("Missing witness results")?
74            .witness)
75    }
76
77    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
78    #[instrument(skip_all)]
79    fn prove(mut self, input_map: InputMap) -> Result<NoirProof> {
80        let witness = self.generate_witness(input_map)?;
81        self.prove_with_witness(witness)
82    }
83
84    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
85    #[instrument(skip_all)]
86    fn prove_with_toml(self, prover_toml: impl AsRef<Path>) -> Result<NoirProof> {
87        let (input_map, _expected_return) =
88            read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?;
89        self.prove(input_map)
90    }
91
92    #[cfg(all(feature = "witness-generation", not(target_arch = "wasm32")))]
93    #[instrument(skip_all)]
94    fn prove_with_inputs(self, inputs: &str, format: Format) -> Result<NoirProof> {
95        let input_map = format
96            .parse(inputs, self.witness_generator.abi())
97            .with_context(|| format!("While parsing {} inputs", format.ext()))?;
98        self.prove(input_map)
99    }
100
101    #[instrument(skip_all)]
102    fn prove_with_witness(
103        self,
104        acir_witness_idx_to_value_map: WitnessMap<NoirElement>,
105    ) -> Result<NoirProof> {
106        provekit_common::register_ntt();
107
108        let mut public_input_indices = self.program.functions[0].public_inputs().indices();
109        public_input_indices.sort_unstable();
110        let public_inputs = PublicInputs::from_vec(
111            public_input_indices
112                .iter()
113                .map(|&idx| {
114                    acir_witness_idx_to_value_map
115                        .get(&Witness::from(idx))
116                        .map(|v| noir_to_native(*v))
117                        .ok_or_else(|| anyhow::anyhow!("Missing public input at index {idx}"))
118                })
119                .collect::<Result<Vec<_>>>()?,
120        );
121
122        drop(self.program);
123        drop(self.witness_generator);
124
125        // R1CS matrices are only needed at sumcheck; compress to free memory during
126        // commits.
127        let compressed_r1cs =
128            CompressedR1CS::compress(self.r1cs).context("While compressing R1CS")?;
129        let num_witnesses = compressed_r1cs.num_witnesses();
130        let num_constraints = compressed_r1cs.num_constraints();
131
132        // Set up transcript with public inputs bound to the instance.
133        let instance = public_inputs.hash_bytes();
134        let ds = self
135            .whir_for_witness
136            .create_domain_separator()
137            .instance(&instance);
138        let mut merlin = ProverState::new(&ds, TranscriptSponge::default());
139
140        let mut witness: Vec<Option<FieldElement>> = vec![None; num_witnesses];
141
142        // Solve w1 (or all witnesses if no challenges).
143        // Outer span captures memory AFTER w1_layers parameter is freed
144        // (parameter drop happens before outer span close).
145        {
146            let _s = info_span!("solve_w1").entered();
147            crate::r1cs::solve_witness_vec(
148                &mut witness,
149                self.split_witness_builders.w1_layers,
150                &acir_witness_idx_to_value_map,
151                &mut merlin,
152            )
153            .context("While solving w1 witnesses")?;
154        }
155
156        // Compress w2 layers to free memory during w1 commit (only when
157        // challenges exist; otherwise just drop them).
158        let has_challenges = self.whir_for_witness.num_challenges > 0;
159        let compressed_w2_layers = if has_challenges {
160            Some(
161                CompressedLayers::compress(self.split_witness_builders.w2_layers)
162                    .context("While compressing w2 layers")?,
163            )
164        } else {
165            drop(self.split_witness_builders.w2_layers);
166            None
167        };
168
169        debug!(
170            witness_heap_bytes = witness.capacity() * size_of::<Option<FieldElement>>(),
171            compressed_r1cs_blob_bytes = compressed_r1cs.blob_len(),
172            "component sizes after solve_w1"
173        );
174
175        // Verify that ACIR-derived public inputs match the R1CS witness layout.
176        debug_assert!(
177            {
178                let n = public_inputs.0.len();
179                n == 0
180                    || witness[1..=n]
181                        .iter()
182                        .zip(public_inputs.0.iter())
183                        .all(|(w, pi)| w.map_or(false, |v| v == *pi))
184            },
185            "Public inputs from ACIR witness map do not match witness[1..=N]"
186        );
187
188        let w1 = {
189            let _s = info_span!("allocate_w1").entered();
190            witness[..self.whir_for_witness.w1_size]
191                .iter()
192                .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing")))
193                .collect::<Result<Vec<_>>>()?
194        };
195
196        let commitment_1 = self
197            .whir_for_witness
198            .commit(&mut merlin, num_witnesses, num_constraints, w1, true)
199            .context("While committing to w1")?;
200
201        let commitments = if has_challenges {
202            let w2_layers = compressed_w2_layers
203                .unwrap()
204                .decompress()
205                .context("While decompressing w2 layers")?;
206            {
207                let _s = info_span!("solve_w2").entered();
208                crate::r1cs::solve_witness_vec(
209                    &mut witness,
210                    w2_layers,
211                    &acir_witness_idx_to_value_map,
212                    &mut merlin,
213                )
214                .context("While solving w2 witnesses")?;
215            }
216            drop(acir_witness_idx_to_value_map);
217
218            let w2 = {
219                let _s = info_span!("allocate_w2").entered();
220                witness[self.whir_for_witness.w1_size..]
221                    .iter()
222                    .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w2 are missing")))
223                    .collect::<Result<Vec<_>>>()?
224            };
225
226            let commitment_2 = self
227                .whir_for_witness
228                .commit(&mut merlin, num_witnesses, num_constraints, w2, false)
229                .context("While committing to w2")?;
230
231            vec![commitment_1, commitment_2]
232        } else {
233            drop(acir_witness_idx_to_value_map);
234            vec![commitment_1]
235        };
236
237        // Decompress R1CS for the sumcheck and matrix operations.
238        let r1cs = compressed_r1cs
239            .decompress()
240            .context("While decompressing R1CS")?;
241
242        #[cfg(test)]
243        r1cs.test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::<Vec<_>>())
244            .context("While verifying R1CS instance")?;
245
246        let full_witness: Vec<FieldElement> = witness
247            .into_iter()
248            .enumerate()
249            .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving")))
250            .collect::<Result<Vec<_>>>()?;
251
252        let whir_r1cs_proof = self
253            .whir_for_witness
254            .prove(merlin, r1cs, commitments, full_witness, &public_inputs)
255            .context("While proving R1CS instance")?;
256
257        Ok(NoirProof {
258            public_inputs,
259            whir_r1cs_proof,
260        })
261    }
262}