use alloc::vec;
use alloc::vec::Vec;
use core::iter;
use itertools::{Itertools, izip};
use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
use p3_commit::Mmcs;
use p3_field::{ExtensionField, Field};
use p3_fri::{FriFoldingStrategy, FriParameters, compute_log_arity_for_round};
use p3_matrix::dense::RowMajorMatrix;
use p3_util::log2_strict_usize;
use tracing::{info_span, instrument};
use crate::{CircleCommitPhaseProofStep, CircleFriProof, CircleQueryProof};
/// Builds a Circle FRI proof for `inputs`.
///
/// The function proceeds in three steps:
/// 1. runs [`commit_phase`] over `inputs`, which commits to each folding
///    round and yields the final constant value;
/// 2. asks `challenger` to grind a proof-of-work witness of
///    `params.query_proof_of_work_bits` bits;
/// 3. samples `params.num_queries` query indices from `challenger` and
///    answers each one with an input opening plus the commit-phase openings.
///
/// # Arguments
/// * `folding` - folding strategy; also decides how many extra low-order bits
///   are sampled with every query index.
/// * `params` - FRI parameters (MMCS, blowup, number of queries, ...).
/// * `inputs` - input vectors ordered by non-increasing length.
/// * `challenger` - observes the commitments and supplies challenges, the
///   proof-of-work witness and the query indices.
/// * `open_input` - called with the full sampled index (extra bits included)
///   to produce the input proof of a query.
///
/// # Panics
/// * if `inputs` is empty;
/// * if `inputs` is not ordered by non-increasing length;
/// * whenever `log2_strict_usize` or [`commit_phase`] panic on the given data.
#[instrument(name = "FRI prover", skip_all)]
pub fn prove<Folding, Val, Challenge, M, Challenger>(
    folding: &Folding,
    params: &FriParameters<M>,
    inputs: Vec<Vec<Challenge>>,
    challenger: &mut Challenger,
    open_input: impl Fn(usize) -> Folding::InputProof,
) -> CircleFriProof<Challenge, M, Challenger::Witness, Folding::InputProof>
where
    Val: Field,
    Challenge: ExtensionField<Val>,
    M: Mmcs<Challenge>,
    Challenger: FieldChallenger<Val> + GrindingChallenger + CanObserve<M::Commitment>,
    Folding: FriFoldingStrategy<Val, Challenge>,
{
    // Precondition: inputs arrive longest first (equal lengths are tolerated).
    assert!(
        inputs
            .iter()
            .tuple_windows()
            .all(|(l, r)| l.len() >= r.len()),
        "FRI inputs must be ordered by non-increasing length"
    );

    // The ordering check above is vacuously true for an empty list, so the
    // non-empty precondition is stated explicitly instead of relying on an
    // index-out-of-bounds panic.
    let log_max_height = log2_strict_usize(
        inputs
            .first()
            .expect("FRI prover requires at least one input vector")
            .len(),
    );

    let commit_phase_result = commit_phase(folding, params, inputs, challenger);

    let pow_witness = challenger.grind(params.query_proof_of_work_bits);

    // Loop-invariant: read once instead of twice per query.
    let extra_query_index_bits = folding.extra_query_index_bits();

    let query_proofs = info_span!("query phase").in_scope(|| {
        iter::repeat_with(|| {
            let index = challenger.sample_bits(log_max_height + extra_query_index_bits);
            CircleQueryProof {
                // The input opening receives the full index ...
                input_proof: open_input(index),
                // ... while the commit-phase openings use it with the extra
                // low-order bits shifted off.
                commit_phase_openings: answer_query(
                    params,
                    &commit_phase_result.log_arities,
                    &commit_phase_result.data,
                    index >> extra_query_index_bits,
                ),
            }
        })
        .take(params.num_queries)
        .collect()
    });

    CircleFriProof {
        commit_phase_commits: commit_phase_result.commits,
        query_proofs,
        final_poly: commit_phase_result.final_poly,
        pow_witness,
    }
}
/// Output of `commit_phase`, consumed by `prove` to answer queries and to
/// assemble the final proof.
///
/// `commits`, `data` and `log_arities` each receive exactly one entry per
/// folding round, so they all have the same length and share round order.
struct CommitPhaseResult<F: Field, M: Mmcs<F>> {
/// Commitment to the matrix of each folding round.
commits: Vec<M::Commitment>,
/// MMCS prover data matching each entry of `commits`; used to open rows
/// when answering queries.
data: Vec<M::ProverData<RowMajorMatrix<F>>>,
/// Base-2 logarithm of the folding arity chosen for each round.
log_arities: Vec<usize>,
/// The single value that every remaining entry equals once folding stops.
final_poly: F,
}
/// Runs the commit phase: folds the first input down to `params.blowup()`
/// entries, committing to the current vector before every fold.
///
/// One round performs, in this exact order:
/// 1. pick the round's `log_arity` via `compute_log_arity_for_round`;
/// 2. commit to the current vector laid out as a matrix of width
///    `1 << log_arity`;
/// 3. let `challenger` observe the commitment, then sample the folding
///    challenge from it;
/// 4. fold the committed matrix with that challenge;
/// 5. if the next pending input has the length of the folded vector, add it
///    element-wise.
///
/// After the last round the remaining value is observed by `challenger`.
///
/// # Panics
/// * if `inputs` is empty;
/// * if the folded vector does not end with exactly `params.blowup()` entries;
/// * if those entries are not all equal.
#[instrument(name = "commit phase", skip_all)]
fn commit_phase<Folding, Val, Challenge, M, Challenger>(
    folding: &Folding,
    params: &FriParameters<M>,
    inputs: Vec<Vec<Challenge>>,
    challenger: &mut Challenger,
) -> CommitPhaseResult<Challenge, M>
where
    Val: Field,
    Challenge: ExtensionField<Val>,
    M: Mmcs<Challenge>,
    Challenger: FieldChallenger<Val> + CanObserve<M::Commitment>,
    Folding: FriFoldingStrategy<Val, Challenge>,
{
    let mut pending = inputs.into_iter().peekable();
    let mut current = pending.next().unwrap();

    let mut round_commits = vec![];
    let mut round_data = vec![];
    let mut round_log_arities = vec![];

    loop {
        // Folding stops once only `blowup` entries are left.
        if current.len() <= params.blowup() {
            break;
        }

        // The arity depends on the current height, the height of the next
        // pending input (if any), the final height and the configured cap.
        let log_arity = compute_log_arity_for_round(
            log2_strict_usize(current.len()),
            pending.peek().map(|next| log2_strict_usize(next.len())),
            params.log_blowup,
            params.max_log_arity,
        );
        round_log_arities.push(log_arity);

        // Each matrix row groups the `1 << log_arity` entries folded together.
        let (commit, prover_data) = params
            .mmcs
            .commit_matrix(RowMajorMatrix::new(current, 1 << log_arity));

        // The challenge must be drawn only after the commitment was observed.
        challenger.observe(commit.clone());
        let beta: Challenge = challenger.sample_algebra_element();

        // Ownership of the vector went to the MMCS; borrow the matrix back.
        let committed = params.mmcs.get_matrices(&prover_data).pop().unwrap();
        current = folding.fold_matrix(beta, log_arity, committed.as_view());

        round_commits.push(commit);
        round_data.push(prover_data);

        // Mix in the next input as soon as the folded vector reaches its length.
        // NOTE(review): an input whose length is never reached stays in
        // `pending` and is dropped without any check — confirm callers never
        // pass such inputs.
        if let Some(next_input) = pending.next_if(|v| v.len() == current.len()) {
            for (acc, extra) in izip!(&mut current, next_input) {
                *acc += extra;
            }
        }
    }

    // What remains must be `blowup` copies of one value.
    assert_eq!(current.len(), params.blowup());
    let final_poly = current[0];
    current
        .into_iter()
        .for_each(|value| assert_eq!(value, final_poly));

    challenger.observe_algebra_element(final_poly);

    CommitPhaseResult {
        commits: round_commits,
        data: round_data,
        log_arities: round_log_arities,
        final_poly,
    }
}
/// Produces the commit-phase openings for a single query.
///
/// Starting from `start_index`, every round opens row
/// `index >> log_arity` of that round's committed matrix, records all entries
/// of the row except the one at position `index % arity`, and continues to
/// the next round with the row index as the new index.
///
/// # Arguments
/// * `params` - FRI parameters; only `params.mmcs` is used here.
/// * `log_arities` - log2 of the folding arity of each round.
/// * `commit_phase_commits` - despite the name, the MMCS *prover data* of each
///   round (one committed matrix per entry).
/// * `start_index` - query index into the first round's vector.
///
/// # Panics
/// * if `log_arities` is shorter than `commit_phase_commits`;
/// * if an opening does not contain exactly one row, or that row does not
///   have `1 << log_arity` entries.
fn answer_query<F, M>(
    params: &FriParameters<M>,
    log_arities: &[usize],
    commit_phase_commits: &[M::ProverData<RowMajorMatrix<F>>],
    start_index: usize,
) -> Vec<CircleCommitPhaseProofStep<F, M>>
where
    F: Field,
    M: Mmcs<F>,
{
    let mut steps = Vec::with_capacity(commit_phase_commits.len());
    let mut index = start_index;

    for (round, prover_data) in commit_phase_commits.iter().enumerate() {
        let log_arity = log_arities[round];
        let arity: usize = 1 << log_arity;

        // Split the index into (row, position inside the row).
        let position_in_row = index % arity;
        let row_index = index >> log_arity;

        let (mut rows, opening_proof) = params.mmcs.open_batch(row_index, prover_data).unpack();

        // Exactly one matrix was committed per round, hence one opened row.
        assert_eq!(rows.len(), 1);
        let row = rows.pop().unwrap();
        assert_eq!(
            row.len(),
            arity,
            "Committed data should have arity {} elements",
            arity
        );

        // Keep every entry of the row except the queried position itself.
        let sibling_values: Vec<_> = row
            .into_iter()
            .enumerate()
            .filter_map(|(j, value)| (j != position_in_row).then_some(value))
            .collect();

        steps.push(CircleCommitPhaseProofStep {
            // NOTE(review): `as` truncates silently above 255; `log_arity` is
            // a shift amount, so presumably always far below that.
            log_arity: log_arity as u8,
            sibling_values,
            opening_proof,
        });

        // The opened row index is the query index of the next round.
        index = row_index;
    }

    steps
}