Skip to main content

miden_lifted_stark/prover/
mod.rs

1//! Lifted STARK prover.
2//!
3//! This module provides:
4//! - [`prove_single`]: Prove a single AIR instance.
5//! - [`prove_multi`]: Prove multiple AIR instances with traces of different heights.
6//!
7//! These functions write the proof into a [`miden_stark_transcript::ProverChannel`]
8//! (commitments, grinding witnesses, and openings).
9//!
10//! # Fiat-Shamir / transcript binding (initial challenger state)
11//!
//! This crate does **not** prescribe the *initial* transcript state. The caller
//! must bind the full statement into the Fiat-Shamir challenger before calling
//! [`prove_single`] or [`prove_multi`]. Prover and verifier must bring their
//! challengers to identical states. Concretely, the caller **MUST** observe:
16//!
17//! 1. **Protocol parameters** — e.g. the STARK configuration, blowup factor, and any
18//!    application-level domain separator.
19//!
20//! 2. **Public values and variable-length inputs** — `public_values` and `var_len_public_inputs`
21//!    for every instance. Without this, Fiat-Shamir challenges are independent of the statement.
22//!
//! 3. **AIR configurations and `air_order`** — The proof defines an ordering of AIR instances
//!    (`air_order()[j]` is the caller's original index at proof position `j`), queryable via
//!    [`InstanceShapes::air_order`]. The ordering is deterministic: instances are sorted by
//!    `(log_trace_height, caller_index)`. The prover absorbs the trace heights itself, but
//!    neither the AIR configurations nor `air_order` is absorbed into the transcript, so the
//!    caller must bind both into the challenger. How this is done is up to the caller — see the
//!    examples below. The prover can precompute `air_order` via
//!    [`InstanceShapes::from_trace_heights`]; the verifier reads it from the proof.
30//!
31//! ## Recommended pattern
32//!
33//! Pre-seed the challenger so statement data stays out of the proof:
34//!
35//! ```ignore
36//! // --- Bind statement into Fiat-Shamir ---
37//! let mut ch = Challenger::new(perm.clone());
38//! ch.observe_slice(&b"MY_APP_V1".map(|b| F::from_u8(b)));  // domain separator
39//! ch.observe(F::from_u8(config.pcs().log_blowup()));        // protocol parameters
40//! // ... observe remaining protocol parameters ...
41//! ch.observe_slice(&public_values);
42//! for vl in &var_len_public_inputs {
43//!     ch.observe_slice(vl);
44//! }
45//! // For multi-AIR: bind AIR configurations and air_order (see below).
46//!
47//! // --- Prove ---
48//! let output = prove_multi(&config, &instances, ch)?;
49//!
50//! // --- Verify (identical binding) ---
51//! let mut ch = Challenger::new(perm);
52//! ch.observe_slice(&b"MY_APP_V1".map(|b| F::from_u8(b)));
53//! ch.observe(F::from_u8(config.pcs().log_blowup()));
54//! // ... observe remaining protocol parameters ...
55//! ch.observe_slice(&public_values);
56//! for vl in &var_len_public_inputs {
57//!     ch.observe_slice(vl);
58//! }
59//! let verifier_digest = verify_multi(&config, &verifier_instances, &output.proof, ch)?;
60//! assert_eq!(output.digest, verifier_digest);
61//! ```
62//!
63//! ## Multi-AIR binding examples
64//!
65//! ```text
66//! // Prover: precompute air_order before building the challenger.
67//! let shapes = InstanceShapes::from_trace_heights(trace_heights)?;
68//! let air_order = shapes.air_order();
69//!
70//! // Verifier: read air_order from the proof.
71//! let air_order = proof.air_order();
72//!
73//! // Option A: reorder AIRs to proof order and commit — the ordering is
74//! // implicit in the commitment.
75//! let ordered_airs: Vec<_> = air_order.iter().map(|&idx| &airs[idx as usize]).collect();
76//! let circuit = Circuit::from_airs(&ordered_airs);
77//! challenger.observe(circuit.commitment());
78//!
79//! // Option B: commit to AIRs in their natural order, then observe
80//! // air_order to bind the ordering explicitly.
81//! for air in &airs {
82//!     challenger.observe(air.commitment());
83//! }
84//! challenger.observe_slice(air_order);
85//! ```
86
87extern crate alloc;
88
89pub mod commit;
90pub mod constraints;
91pub mod periodic;
92pub mod quotient;
93
94use alloc::{vec, vec::Vec};
95
96use commit::commit_traces;
97use constraints::{evaluate_constraints_into, layout::get_constraint_layout};
98use miden_lifted_air::{AuxBuilder, LiftedAir, VarLenPublicInputs, log2_strict_u8};
99use miden_stark_transcript::{Channel, ProverChannel, ProverTranscript};
100use p3_field::{BasedVectorSpace, ExtensionField, TwoAdicField};
101use p3_matrix::{Matrix, dense::RowMajorMatrix};
102use periodic::PeriodicLde;
103use thiserror::Error;
104use tracing::{info_span, instrument};
105
106use crate::{
107    StarkConfig,
108    coset::LiftedCoset,
109    instance::{AirWitness, InstanceShapes, InstanceValidationError, validate_inputs},
110    pcs::prover::open_with_channel,
111    proof::{StarkOutput, StarkProof},
112};
113
114/// Errors that can occur during proving.
115#[derive(Debug, Error)]
116pub enum ProverError {
117    #[error("instance validation failed: {0}")]
118    Instance(#[from] InstanceValidationError),
119    #[error(
120        "constraint degree exceeds blowup: \
121         log_quotient_degree {log_quotient_degree} > log_blowup {log_blowup}"
122    )]
123    ConstraintDegreeTooHigh { log_quotient_degree: u8, log_blowup: u8 },
124}
125
126/// Prove a single AIR.
127///
128/// The caller's challenger must already be bound to the full statement
129/// (protocol parameters, AIR configuration, public values, and
130/// variable-length inputs) — see the module-level docs.
131///
132/// This is a convenience wrapper around [`prove_multi`] for the single-AIR case.
133///
134/// # Returns
135/// `Ok(StarkOutput { digest, proof })` on success, or a `ProverError` if validation fails.
136pub fn prove_single<F, EF, A, B, SC>(
137    config: &SC,
138    air: &A,
139    trace: &RowMajorMatrix<F>,
140    public_values: &[F],
141    var_len_public_inputs: VarLenPublicInputs<'_, F>,
142    aux_builder: &B,
143    challenger: SC::Challenger,
144) -> Result<StarkOutput<F, EF, SC>, ProverError>
145where
146    F: TwoAdicField,
147    EF: ExtensionField<F>,
148    SC: StarkConfig<F, EF>,
149    A: LiftedAir<F, EF>,
150    B: AuxBuilder<F, EF>,
151{
152    let witness = AirWitness::new(trace, public_values, var_len_public_inputs);
153    prove_multi(config, &[(air, witness, aux_builder)], challenger)
154}
155
156/// Prove multiple AIRs with traces of different heights.
157///
158/// The caller's challenger must already be bound to the full statement
159/// (protocol parameters, AIR configurations, AIR ordering, and public
160/// inputs — both fixed and variable-length) — see the module-level docs.
161///
162/// # Arguments
163/// - `config`: STARK configuration (PCS params, LMCS, DFT)
164/// - `instances`: Pairs of (AIR, witness, aux_builder)
165/// - `challenger`: Fiat-Shamir challenger (heights are observed before use)
166///
167/// # Returns
168/// `Ok(StarkOutput { digest, proof })` on success, or a `ProverError` if validation fails.
169#[instrument(name = "prove", skip_all)]
170pub fn prove_multi<F, EF, A, B, SC>(
171    config: &SC,
172    instances: &[(&A, AirWitness<'_, F>, &B)],
173    mut challenger: SC::Challenger,
174) -> Result<StarkOutput<F, EF, SC>, ProverError>
175where
176    F: TwoAdicField,
177    EF: ExtensionField<F>,
178    SC: StarkConfig<F, EF>,
179    A: LiftedAir<F, EF>,
180    B: AuxBuilder<F, EF>,
181{
182    let trace_heights: Vec<usize> = instances.iter().map(|(_, w, _)| w.trace.height()).collect();
183    let instance_shapes = InstanceShapes::from_trace_heights(trace_heights)?;
184
185    // Reorder instances to the proof's AIR ordering.
186    let instances = instance_shapes.reorder(instances.to_vec())?;
187
188    let verifier_instances: Vec<_> =
189        instances.iter().map(|(air, w, _)| (*air, w.to_instance())).collect();
190
191    let log_blowup = config.pcs().log_blowup();
192
193    // Validate AIR structure, instance dimensions, heights, and trace widths.
194    let log_max_trace_height = validate_inputs(&verifier_instances, &instance_shapes, log_blowup)?;
195    for &(air, w, _) in &instances {
196        if w.trace.width() != air.width() {
197            return Err(InstanceValidationError::WidthMismatch {
198                expected: air.width(),
199                actual: w.trace.width(),
200            }
201            .into());
202        }
203    }
204
205    // Observe shape metadata before creating the transcript.
206    instance_shapes.observe_heights::<F, _>(&mut challenger);
207
208    let mut channel = ProverTranscript::new(challenger);
209
210    // Clear the challenger's absorb buffer after observing instance shapes by
211    // squeezing a throwaway extension element. This guarantees later sampled
212    // challenges depend on all prior inputs regardless of sponge state.
213    let _instance_challenge: EF = channel.sample_algebra_element::<EF>();
214
215    // Infer constraint degree from symbolic AIR analysis (max across all AIRs)
216    let log_constraint_degree =
217        instances.iter().map(|(air, ..)| air.log_quotient_degree()).max().unwrap_or(1) as u8;
218
219    if log_constraint_degree > log_blowup {
220        return Err(ProverError::ConstraintDegreeTooHigh {
221            log_quotient_degree: log_constraint_degree,
222            log_blowup,
223        });
224    }
225
226    let log_lde_height = log_max_trace_height + log_blowup;
227
228    // Max LDE coset (for the largest trace, no lifting)
229    let max_lde_coset = LiftedCoset::unlifted(log_max_trace_height, log_blowup);
230    let max_quotient_coset = max_lde_coset.quotient_domain(log_constraint_degree);
231    let max_quotient_height = max_quotient_coset.lde_height();
232
233    // 1. Commit all main traces (trace order — ascending height).
234    //
235    // Clone with blowup × capacity so the DFT resize doesn't reallocate.
236    let blowup = 1 << log_blowup as usize;
237    let main_traces: Vec<_> = instances
238        .iter()
239        .map(|(_, w, _)| {
240            let src = &w.trace.values;
241            let mut values = Vec::with_capacity(src.len() * blowup);
242            values.extend_from_slice(src);
243            RowMajorMatrix::new(values, w.trace.width())
244        })
245        .collect();
246    let main_committed =
247        info_span!("commit to main traces").in_scope(|| commit_traces(config, main_traces));
248    channel.send_commitment(main_committed.root());
249
250    // 2. Sample randomness and build aux traces for all AIRs
251    let max_num_randomness =
252        instances.iter().map(|(air, ..)| air.num_randomness()).max().unwrap_or(0);
253
254    let randomness: Vec<EF> = (0..max_num_randomness)
255        .map(|_| channel.sample_algebra_element::<EF>())
256        .collect();
257
258    // Build aux traces via AuxBuilder
259    let (aux_traces_ef, all_aux_values): (Vec<RowMajorMatrix<EF>>, Vec<Vec<EF>>) =
260        info_span!("build aux traces").in_scope(|| {
261            let mut traces = Vec::with_capacity(instances.len());
262            let mut values = Vec::with_capacity(instances.len());
263            for (air, w, aux_builder) in &instances {
264                let num_rand = air.num_randomness();
265                let (aux, aux_vals) = aux_builder.build_aux_trace(w.trace, &randomness[..num_rand]);
266
267                assert_eq!(aux.width(), air.aux_width(), "aux trace width mismatch");
268                assert_eq!(
269                    aux_vals.len(),
270                    air.num_aux_values(),
271                    "aux values length mismatch: build_aux_trace returned {} values, \
272                     but num_aux_values() is {}",
273                    aux_vals.len(),
274                    air.num_aux_values()
275                );
276                assert_eq!(aux.height(), w.trace.height());
277                traces.push(aux);
278                values.push(aux_vals);
279            }
280            (traces, values)
281        });
282
283    // Flatten EF -> F and commit aux traces
284    let aux_traces: Vec<RowMajorMatrix<F>> = aux_traces_ef
285        .into_iter()
286        .map(|aux| {
287            let base_width = aux.width() * EF::DIMENSION;
288            let base_values = <EF as BasedVectorSpace<F>>::flatten_to_base(aux.values);
289            RowMajorMatrix::new(base_values, base_width)
290        })
291        .collect();
292
293    let aux_committed =
294        info_span!("commit to aux traces").in_scope(|| commit_traces(config, aux_traces));
295    channel.send_commitment(aux_committed.root());
296
297    // Observe aux values into the transcript (binds to Fiat-Shamir state).
298    // When no AIR has aux columns, each entry is empty so nothing is sent.
299    for vals in &all_aux_values {
300        for &val in vals {
301            channel.send_algebra_element(val);
302        }
303    }
304
305    // 4. Sample constraint folding alpha and accumulation beta
306    let alpha: EF = channel.sample_algebra_element::<EF>();
307    let beta: EF = channel.sample_algebra_element::<EF>();
308
309    // 5. Evaluate constraints and accumulate with beta folding.
310    //
311    // Single accumulator, processed in trace order (ascending height):
312    //   1. Cyclically extend accumulator to the next quotient height
313    //   2. Multiply every element by beta (Horner)
314    //   3. Add constraint evaluations in-place: acc[i] += eval(i)
315    //
316    // Pre-allocate with LDE capacity so commit_quotient's resize doesn't reallocate.
317    let constraint_degree = 1 << log_constraint_degree as usize;
318    let mut accumulator: Vec<EF> = Vec::with_capacity(max_quotient_height * blowup);
319
320    // Pre-compute constraint layouts for each AIR (base/ext index mapping)
321    let layouts: Vec<_> = instances
322        .iter()
323        .map(|(air, ..)| get_constraint_layout::<F, EF, A>(*air))
324        .collect();
325
326    info_span!("evaluate constraints").in_scope(|| {
327        for (i, (air, w, _)) in instances.iter().enumerate() {
328            let trace_height = w.trace.height();
329            let log_trace_height = log2_strict_u8(trace_height);
330
331            // Create LiftedCoset for this trace (may be lifted relative to max)
332            let this_lde_coset =
333                LiftedCoset::new(log_trace_height, log_blowup, log_max_trace_height);
334            let this_quotient_coset = this_lde_coset.quotient_domain(log_constraint_degree);
335            let this_quotient_height = this_quotient_coset.lde_height();
336
337            // Truncate the committed LDE to the quotient evaluation domain gJ (size N·D).
338            // Since B ≥ D, the committed LDE on gK (size N·B) contains gJ as a prefix in
339            // bit-reversed storage, so this is a zero-copy view.
340            let main_on_gj = main_committed.evals_on_quotient_domain(i, constraint_degree);
341            let aux_on_gj = aux_committed.evals_on_quotient_domain(i, constraint_degree);
342
343            // Build periodic LDE for this trace via coset method
344            let periodic_lde =
345                PeriodicLde::build(&this_quotient_coset, air.periodic_columns_matrix());
346
347            // Cyclically extend accumulator to this quotient height and scale by beta.
348            // On the first iteration the accumulator is empty, so this is a no-op
349            // and evaluate_constraints_into writes into a zero-filled buffer.
350            tracing::debug_span!(
351                "cyclic_extend",
352                acc_len = accumulator.len(),
353                target = this_quotient_height
354            )
355            .in_scope(|| {
356                quotient::cyclic_extend_and_scale(&mut accumulator, this_quotient_height, beta);
357            });
358
359            let aux_values_i = &all_aux_values[i];
360
361            // Add constraint evaluations in-place: accumulator[i] += eval(i)
362            info_span!("eval_instance", instance = i, height = this_quotient_height).in_scope(
363                || {
364                    evaluate_constraints_into::<F, EF, A>(
365                        &mut accumulator,
366                        *air,
367                        &main_on_gj,
368                        &aux_on_gj,
369                        &this_quotient_coset,
370                        alpha,
371                        &randomness[..air.num_randomness()],
372                        w.public_values,
373                        &periodic_lde,
374                        &layouts[i],
375                        aux_values_i,
376                    );
377                },
378            );
379        }
380    });
381
382    // Verify we have the expected size (max quotient domain)
383    assert_eq!(accumulator.len(), max_quotient_height);
384
385    // 6. Divide by vanishing polynomial once on full gJ (in-place)
386    tracing::debug_span!("divide_by_vanishing", height = max_quotient_height).in_scope(|| {
387        quotient::divide_by_vanishing_in_place::<F, EF>(&mut accumulator, &max_quotient_coset);
388    });
389
390    // 7. Commit quotient
391    let quotient_committed = info_span!("commit to quotient poly chunks")
392        .in_scope(|| quotient::commit_quotient(config, accumulator, &max_lde_coset));
393    channel.send_commitment(quotient_committed.root());
394
395    // 8. Sample OOD point (outside H and gK)
396    let z: EF = max_lde_coset.sample_ood_point(&mut channel);
397    let h = F::two_adic_generator(log_max_trace_height.into());
398    let z_next = z * h;
399
400    // 9. Open via PCS
401    let trees = vec![main_committed.tree(), aux_committed.tree(), quotient_committed.tree()];
402
403    info_span!("open").in_scope(|| {
404        open_with_channel::<F, EF, SC::Lmcs, RowMajorMatrix<F>, _, 2>(
405            config.pcs(),
406            config.lmcs(),
407            log_lde_height,
408            [z, z_next],
409            &trees,
410            &mut channel,
411        )
412    });
413
414    let (digest, transcript) = channel.finalize();
415    let proof = StarkProof { instance_shapes, transcript };
416    Ok(StarkOutput { digest, proof })
417}