use std::collections::BTreeMap;
use crate::{
fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain, polynomial::PolyMultiplier},
polycommit::sonic_pc::{LabeledPolynomial, PolynomialInfo, PolynomialLabel},
snark::varuna::{
Circuit,
CircuitId,
SNARKMode,
ahp::{AHPForR1CS, verifier},
prover,
selectors::apply_randomized_selector,
witness_label,
},
};
use anyhow::Result;
use rand::RngCore;
use snarkvm_fields::PrimeField;
use snarkvm_utilities::{ExecutionPool, cfg_into_iter, cfg_reduce};
#[cfg(not(feature = "serial"))]
use rayon::prelude::*;
impl<F: PrimeField, SM: SNARKMode> AHPForR1CS<F, SM> {
pub const fn num_second_round_oracles() -> usize {
1
}
pub fn second_round_polynomial_info() -> BTreeMap<PolynomialLabel, PolynomialInfo> {
[PolynomialInfo::new("h_0".into(), None, None)].into_iter().map(|info| (info.label().into(), info)).collect()
}
pub fn prover_second_round<'a, R: RngCore>(
verifier_message: &verifier::FirstMessage<F>,
mut state: prover::State<'a, F, SM>,
_r: &mut R,
) -> Result<(prover::SecondOracles<F>, prover::State<'a, F, SM>)> {
let round_time = start_timer!(|| "AHP::Prover::SecondRound");
let zk_bound = Self::zk_bound();
let max_constraint_domain = state.max_constraint_domain;
let verifier::FirstMessage { first_round_batch_combiners, .. } = verifier_message;
let h_0 = Self::calculate_rowcheck_witness(&mut state, first_round_batch_combiners)?;
assert!(h_0.degree() <= 2 * max_constraint_domain.size() + 2 * zk_bound.unwrap_or(0) - 2);
let oracles = prover::SecondOracles { h_0: LabeledPolynomial::new("h_0", h_0, None, None) };
assert!(oracles.matches_info(&Self::second_round_polynomial_info()));
end_timer!(round_time);
Ok((oracles, state))
}
fn calculate_rowcheck_witness(
state: &mut prover::State<F, SM>,
batch_combiners: &BTreeMap<CircuitId, verifier::BatchCombiners<F>>,
) -> Result<DensePolynomial<F>> {
let mut job_pool = ExecutionPool::with_capacity(state.circuit_specific_states.len());
let max_constraint_domain = state.max_constraint_domain;
for (circuit, circuit_specific_state) in state.circuit_specific_states.iter_mut() {
let z_a = circuit_specific_state.z_a.take().unwrap();
let z_b = circuit_specific_state.z_b.take().unwrap();
let z_c = circuit_specific_state.z_c.take().unwrap();
let circuit_combiner = batch_combiners[&circuit.id].circuit_combiner;
let instance_combiners = batch_combiners[&circuit.id].instance_combiners.clone();
let constraint_domain = circuit_specific_state.constraint_domain;
let fft_precomputation = &circuit.fft_precomputation;
let ifft_precomputation = &circuit.ifft_precomputation;
let _circuit_id = &circuit.id;
for (j, (instance_combiner, z_a, z_b, z_c)) in
itertools::izip!(instance_combiners, z_a, z_b, z_c).enumerate()
{
job_pool.add_job(move || {
let mut instance_lhs = DensePolynomial::zero();
let za_label = witness_label(circuit.id, "z_a", j);
let zb_label = witness_label(circuit.id, "z_b", j);
let zc_label = witness_label(circuit.id, "z_c", j);
let z_a = Self::calculate_z_m(za_label, z_a, constraint_domain, circuit);
let z_b = Self::calculate_z_m(zb_label, z_b, constraint_domain, circuit);
let z_c = Self::calculate_z_m(zc_label, z_c, constraint_domain, circuit);
let mut multiplier_2 = PolyMultiplier::new();
multiplier_2.add_precomputation(fft_precomputation, ifft_precomputation);
multiplier_2.add_polynomial(z_a, "z_a");
multiplier_2.add_polynomial(z_b, "z_b");
let mut rowcheck = multiplier_2.multiply().unwrap();
rowcheck.coeffs.iter_mut().zip(&z_c.coeffs).for_each(|(ab, c)| *ab -= c);
instance_lhs += &(&rowcheck * instance_combiner);
let (h_0_i, remainder) = apply_randomized_selector(
&mut instance_lhs,
circuit_combiner,
&max_constraint_domain,
&constraint_domain,
false,
)?;
assert!(remainder.is_none());
Ok::<_, anyhow::Error>(h_0_i)
});
}
}
let h_sum_time = start_timer!(|| "AHP::Prover::SecondRound h_sum");
let h_sum: DensePolynomial<F> =
cfg_reduce!(cfg_into_iter!(job_pool.execute_all()), || Ok(DensePolynomial::zero()), |a, b| {
a.and_then(|a| {
b.map(|mut b| {
b += &a;
b
})
})
})?;
end_timer!(h_sum_time);
Ok(h_sum)
}
fn calculate_z_m(
label: impl ToString,
evaluations: Vec<F>,
constraint_domain: EvaluationDomain<F>,
circuit: &Circuit<F, SM>,
) -> DensePolynomial<F> {
let label = label.to_string();
let poly_time = start_timer!(|| format!("Computing {label}"));
let evals = EvaluationsOnDomain::from_vec_and_domain(evaluations, constraint_domain);
let poly = evals.interpolate_with_pc_by_ref(&circuit.ifft_precomputation);
debug_assert!(
poly.evaluate_over_domain_by_ref(constraint_domain)
.evaluations
.into_iter()
.zip(&evals.evaluations)
.all(|(z, e)| *e == z),
"Label: {label}\n1: {:#?}\n2: {:#?}",
poly.evaluate_over_domain_by_ref(constraint_domain).evaluations,
&evals.evaluations,
);
end_timer!(poly_time);
poly
}
}