use crate::{
fft::DensePolynomial,
snark::varuna::{
AHPError,
Matrix,
SNARKMode,
ahp::{AHPForR1CS, indexer::CircuitId, verifier},
prover::{self, MatrixSums, ThirdMessage},
},
};
use snarkvm_fields::PrimeField;
use snarkvm_utilities::ExecutionPool;
use anyhow::Result;
use itertools::Itertools;
use rand::RngCore;
use std::collections::{BTreeMap, VecDeque};
/// Output of one lineval-sumcheck preparation job, i.e. a single (instance, matrix) pair.
struct LinevalPrepInstance<F: PrimeField> {
    /// Sum of the evaluations of `z_m_at_alpha` over the circuit's variable domain.
    sum: F,
    /// Witness polynomial produced for this instance and matrix at the verifier challenge `alpha`.
    z_m_at_alpha: DensePolynomial<F>,
}
impl<F: PrimeField, SM: SNARKMode> AHPForR1CS<F, SM> {
pub fn prover_prepare_third_round<'a, R: RngCore>(
verifier_message: &verifier::FirstMessage<F>,
verifier_second_message: &verifier::SecondMessage<F>,
mut state: prover::State<'a, F, SM>,
_r: &mut R,
) -> Result<(prover::ThirdMessage<F>, prover::State<'a, F, SM>), AHPError> {
let round_time = start_timer!(|| "AHP::Prover::ThirdRound");
let verifier::FirstMessage { first_round_batch_combiners } = verifier_message;
let verifier::SecondMessage { alpha, eta_b, eta_c } = verifier_second_message;
if eta_b.is_some() || eta_c.is_some() {
return Err(AHPError::AnyhowError(anyhow::anyhow!(
"Did not expect eta_b,c in SecondMessage in VarunaVersion::V2"
)));
}
let assignments = Self::calculate_assignments(&mut state)?;
let matrix_transposes = Self::calculate_matrix_transpose(&mut state)?;
let msg = Self::calculate_prep_lineval_sumcheck_witness(
&mut state,
first_round_batch_combiners,
assignments,
matrix_transposes,
alpha,
)?;
end_timer!(round_time);
Ok((msg, state))
}
fn calculate_prep_lineval_sumcheck_witness(
state: &mut prover::State<F, SM>,
first_round_batch_combiners: &BTreeMap<CircuitId, verifier::BatchCombiners<F>>,
assignments: BTreeMap<CircuitId, Vec<DensePolynomial<F>>>,
matrix_transposes: BTreeMap<CircuitId, BTreeMap<String, Matrix<F>>>,
alpha: &F,
) -> Result<ThirdMessage<F>> {
let num_instances = first_round_batch_combiners.values().map(|c| c.instance_combiners.len()).collect_vec();
let total_instances = num_instances.iter().sum::<usize>();
let matrix_labels = ["a", "b", "c"];
let fft_precomputations = state
.circuit_specific_states
.keys()
.map(|circuit| (circuit.fft_precomputation.clone(), circuit.ifft_precomputation.clone()))
.collect_vec();
let mut job_pool = ExecutionPool::with_capacity(total_instances * 3);
anyhow::ensure!(
state.circuit_specific_states.len() == fft_precomputations.len(),
"[calculate Prep Lineval Sumcheck Witness] Expected {} circuit specific states, but {} were provided.",
fft_precomputations.len(),
state.circuit_specific_states.len()
);
anyhow::ensure!(
state.circuit_specific_states.len() == assignments.len(),
"[calculate Prep Lineval Sumcheck Witness] Expected {} assignments, but {} were provided.",
assignments.len(),
state.circuit_specific_states.len()
);
anyhow::ensure!(
state.circuit_specific_states.len() == matrix_transposes.len(),
"[calculate Prep Lineval Sumcheck Witness] Expected {} matrix transposes, but {} were provided.",
matrix_transposes.len(),
state.circuit_specific_states.len()
);
for ((((&circuit, circuit_specific_state), precomp), assignments_i), matrix_transposes_i) in state
.circuit_specific_states
.iter()
.zip_eq(fft_precomputations.iter())
.zip_eq(assignments.values())
.zip_eq(matrix_transposes.values())
{
for assignment in assignments_i {
for label in matrix_labels {
let matrix_transpose = &matrix_transposes_i[label];
job_pool.add_job(move || {
let z_m_at_alpha = Self::calculate_lineval_sumcheck_instance_witness(
label,
&circuit_specific_state.constraint_domain,
&circuit_specific_state.variable_domain,
&precomp.0,
&precomp.1,
assignment,
matrix_transpose,
*alpha,
)?;
let sum = z_m_at_alpha
.evaluate_over_domain_by_ref(circuit_specific_state.variable_domain)
.evaluations
.into_iter()
.sum::<F>();
Ok((circuit, LinevalPrepInstance { z_m_at_alpha, sum }))
});
}
}
}
let mut sums = num_instances.iter().map(|n| Vec::with_capacity(*n)).collect_vec();
let mut circuit_index = 0;
let mut instances_seen = 0;
for (i, ((circuit_a, lineval_a), (circuit_b, lineval_b), (circuit_c, lineval_c))) in
job_pool.execute_all().into_iter().collect::<Result<Vec<_>>>()?.into_iter().tuples().enumerate()
{
assert_eq!(circuit_a, circuit_b);
assert_eq!(circuit_a, circuit_c);
sums[circuit_index].push(MatrixSums { sum_a: lineval_a.sum, sum_b: lineval_b.sum, sum_c: lineval_c.sum });
if 1 + i - instances_seen == num_instances[circuit_index] {
instances_seen += num_instances[circuit_index];
circuit_index += 1;
}
match &mut state.circuit_specific_states.get_mut(circuit_a).unwrap().z_m_at_alpha_polys {
None => {
let mut z_m_at_alpha_polys = VecDeque::new();
z_m_at_alpha_polys.push_back([
lineval_a.z_m_at_alpha,
lineval_b.z_m_at_alpha,
lineval_c.z_m_at_alpha,
]);
state.circuit_specific_states.get_mut(circuit_a).unwrap().z_m_at_alpha_polys =
Some(z_m_at_alpha_polys);
}
Some(z_m_at_alpha_polys) => {
z_m_at_alpha_polys.push_back([
lineval_a.z_m_at_alpha,
lineval_b.z_m_at_alpha,
lineval_c.z_m_at_alpha,
]);
}
}
}
let msg = ThirdMessage { sums };
Ok(msg)
}
}