// p3-circle 0.5.3
//
// A STARK proof system built around the unit circle of a finite field, based on the
// Circle STARKs paper.
use alloc::vec;
use alloc::vec::Vec;
use core::iter;

use itertools::{Itertools, izip};
use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
use p3_commit::Mmcs;
use p3_field::{ExtensionField, Field};
use p3_fri::{FriFoldingStrategy, FriParameters, compute_log_arity_for_round};
use p3_matrix::dense::RowMajorMatrix;
use p3_util::log2_strict_usize;
use tracing::{info_span, instrument};

use crate::{CircleCommitPhaseProofStep, CircleFriProof, CircleQueryProof};

/// Runs the Circle FRI prover: commit phase, proof-of-work grinding, then the query phase.
///
/// * `folding` — strategy used to fold each committed layer and to report how many extra
///   query-index bits the input layer needs.
/// * `params` — FRI parameters (MMCS, blowup, maximum arity, number of queries, PoW bits).
/// * `inputs` — evaluation vectors, sorted by length in descending order (asserted). The
///   first one seeds the fold chain; each later one is added into the running fold once the
///   folded vector has shrunk to its length.
/// * `challenger` — Fiat–Shamir transcript. Commit-phase commitments and the final value are
///   observed into it, and the folding challenges, PoW witness and query indices are drawn
///   from it.
/// * `open_input` — called once per query with the full sampled index
///   (`log_max_height + extra_query_index_bits` bits) to produce that query's input proof.
///
/// # Panics
///
/// Panics if `inputs` is empty or not sorted by descending length, or if
/// `log2_strict_usize` rejects the largest input's length (presumably: not a power of two).
#[instrument(name = "FRI prover", skip_all)]
pub fn prove<Folding, Val, Challenge, M, Challenger>(
    folding: &Folding,
    params: &FriParameters<M>,
    inputs: Vec<Vec<Challenge>>,
    challenger: &mut Challenger,
    open_input: impl Fn(usize) -> Folding::InputProof,
) -> CircleFriProof<Challenge, M, Challenger::Witness, Folding::InputProof>
where
    Val: Field,
    Challenge: ExtensionField<Val>,
    M: Mmcs<Challenge>,
    Challenger: FieldChallenger<Val> + GrindingChallenger + CanObserve<M::Commitment>,
    Folding: FriFoldingStrategy<Val, Challenge>,
{
    assert!(!inputs.is_empty(), "FRI prover requires at least one input");
    // The first input seeds the fold chain; later ones are mixed in as the fold shrinks,
    // so they must come largest-first.
    assert!(
        inputs
            .iter()
            .tuple_windows()
            .all(|(l, r)| l.len() >= r.len()),
        "FRI inputs must be sorted by descending length"
    );

    let log_max_height = log2_strict_usize(inputs[0].len());
    // The low `extra_query_index_bits` bits of each sampled index are only used to open the
    // inputs; they are shifted off before answering the commit-phase queries.
    let extra_query_index_bits = folding.extra_query_index_bits();

    let commit_phase_result = commit_phase(folding, params, inputs, challenger);

    let pow_witness = challenger.grind(params.query_proof_of_work_bits);

    let query_proofs = info_span!("query phase").in_scope(|| {
        iter::repeat_with(|| {
            let index = challenger.sample_bits(log_max_height + extra_query_index_bits);
            // For each index, create a proof that the folding operations along the chain are correct.
            CircleQueryProof {
                input_proof: open_input(index),
                commit_phase_openings: answer_query(
                    params,
                    &commit_phase_result.log_arities,
                    &commit_phase_result.data,
                    index >> extra_query_index_bits,
                ),
            }
        })
        .take(params.num_queries)
        .collect()
    });

    CircleFriProof {
        commit_phase_commits: commit_phase_result.commits,
        query_proofs,
        final_poly: commit_phase_result.final_poly,
        pow_witness,
    }
}

/// Output of the commit phase, consumed by `prove` to build the proof and answer queries.
struct CommitPhaseResult<F, M>
where
    F: Field,
    M: Mmcs<F>,
{
    /// One MMCS commitment per folding round, in round order.
    commits: Vec<M::Commitment>,
    /// Prover data for each entry of `commits` (same order); rows are opened from it per query.
    data: Vec<M::ProverData<RowMajorMatrix<F>>>,
    /// Log2 of the folding arity used in each round (same order as `commits`).
    log_arities: Vec<usize>,
    /// The single value that every final folded evaluation was checked to equal.
    final_poly: F,
}

/// Commit phase of the Circle FRI prover.
///
/// Starting from the first (largest) input, each round:
/// 1. chooses a log-arity with `compute_log_arity_for_round`. It is given the current height,
///    the height of the next pending input (if any), the final height `params.log_blowup`
///    and `params.max_log_arity`;
/// 2. commits the current vector as a matrix of width `2^log_arity` and observes the
///    commitment;
/// 3. samples a folding challenge `beta` and folds the committed matrix with it;
/// 4. if the next input has exactly the folded length, adds it element-wise into the fold.
///
/// Folding stops once the vector has `params.blowup()` elements. These are checked to be
/// all equal, and that value is observed and returned as `final_poly`.
///
/// # Panics
///
/// Panics if `inputs` is empty, if the loop does not end at exactly `params.blowup()`
/// elements, or if the final elements are not all equal.
#[instrument(name = "commit phase", skip_all)]
fn commit_phase<Folding, Val, Challenge, M, Challenger>(
    folding: &Folding,
    params: &FriParameters<M>,
    inputs: Vec<Vec<Challenge>>,
    challenger: &mut Challenger,
) -> CommitPhaseResult<Challenge, M>
where
    Val: Field,
    Challenge: ExtensionField<Val>,
    M: Mmcs<Challenge>,
    Challenger: FieldChallenger<Val> + CanObserve<M::Commitment>,
    Folding: FriFoldingStrategy<Val, Challenge>,
{
    let mut inputs_iter = inputs.into_iter().peekable();
    let mut folded = inputs_iter
        .next()
        .expect("commit phase requires at least one input");
    let mut commits = vec![];
    let mut data = vec![];
    let mut log_arities = vec![];

    // For Circle, we fold down to `blowup` elements (there is no separate `final_poly_len`).
    let log_final_height = params.log_blowup;

    while folded.len() > params.blowup() {
        let log_current_height = log2_strict_usize(folded.len());
        let next_input_log_height = inputs_iter.peek().map(|v| log2_strict_usize(v.len()));

        // Compute the arity for this round.
        let log_arity = compute_log_arity_for_round(
            log_current_height,
            next_input_log_height,
            log_final_height,
            params.max_log_arity,
        );
        let arity = 1 << log_arity;
        log_arities.push(log_arity);

        // Each row of the committed matrix holds `arity` consecutive evaluations.
        let leaves = RowMajorMatrix::new(folded, arity);
        let (commit, prover_data) = params.mmcs.commit_matrix(leaves);
        challenger.observe(commit.clone());

        let beta: Challenge = challenger.sample_algebra_element();
        // We passed ownership of `folded` to the MMCS, so get a reference to it back.
        let leaves = params.mmcs.get_matrices(&prover_data).pop().unwrap();
        folded = folding.fold_matrix(beta, log_arity, leaves.as_view());

        commits.push(commit);
        data.push(prover_data);

        // Mix in the next input once the fold has shrunk to exactly its length.
        if let Some(v) = inputs_iter.next_if(|v| v.len() == folded.len()) {
            izip!(&mut folded, v).for_each(|(c, x)| *c += x);
        }
    }

    // NOTE(review): any input whose length the fold never matched exactly (e.g. one shorter
    // than `blowup`) is still in `inputs_iter` here and is silently left out of the proof.
    // Confirm that callers never pass such inputs.

    // We should be left with `blowup` evaluations of a constant polynomial.
    assert_eq!(folded.len(), params.blowup());
    let final_poly = folded[0];
    for x in folded {
        assert_eq!(x, final_poly, "final folded evaluations are not all equal");
    }
    challenger.observe_algebra_element(final_poly);

    CommitPhaseResult {
        commits,
        data,
        log_arities,
        final_poly,
    }
}

/// Opens the commit-phase matrices along the query path starting at `start_index`.
///
/// In round `i`, the committed matrix has rows of `2^log_arities[i]` elements. The current
/// query position is split into a row (`group_index`) and a column (`index_in_group`). The
/// whole row is opened, and every value in it except the queried one is returned as
/// `sibling_values`, in the original order. The row index then becomes the query position
/// for the next round.
///
/// # Panics
///
/// Panics in any of these cases:
/// * `log_arities` and `commit_phase_commits` have different lengths;
/// * an opening does not return exactly one row of `2^log_arity` elements;
/// * a log-arity does not fit in a `u8`.
fn answer_query<F, M>(
    params: &FriParameters<M>,
    log_arities: &[usize],
    commit_phase_commits: &[M::ProverData<RowMajorMatrix<F>>],
    start_index: usize,
) -> Vec<CircleCommitPhaseProofStep<F, M>>
where
    F: Field,
    M: Mmcs<F>,
{
    assert_eq!(
        log_arities.len(),
        commit_phase_commits.len(),
        "expected exactly one log_arity per commit-phase round"
    );

    let mut current_index = start_index;

    commit_phase_commits
        .iter()
        .zip(log_arities)
        .map(|(commit, &log_arity)| {
            let arity = 1 << log_arity;

            // Column of the queried element within its row (arity is a power of two).
            let index_in_group = current_index & (arity - 1);
            // Row of the committed matrix that holds the queried element.
            let group_index = current_index >> log_arity;

            let (mut opened_rows, opening_proof) =
                params.mmcs.open_batch(group_index, commit).unpack();
            assert_eq!(opened_rows.len(), 1);
            let mut sibling_values = opened_rows.pop().unwrap();
            assert_eq!(
                sibling_values.len(),
                arity,
                "Committed data should have arity {} elements",
                arity
            );
            // Drop the queried element itself. The remaining siblings keep their order.
            sibling_values.remove(index_in_group);

            // The row index is the query position in the next (folded) round.
            current_index = group_index;

            CircleCommitPhaseProofStep {
                log_arity: u8::try_from(log_arity).expect("log_arity must fit in a u8"),
                sibling_values,
                opening_proof,
            }
        })
        .collect()
}