use std::{collections::BTreeMap, sync::Arc};
use crate::{
fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain},
r1cs::SynthesisResult,
snark::marlin::{ahp::verifier, AHPError, AHPForR1CS, Circuit, MarlinMode},
};
use snarkvm_fields::PrimeField;
/// Prover state kept for a single circuit and the batch of instances proved for it.
///
/// One value is built per circuit by `State::initialize`.
pub struct CircuitSpecificState<F: PrimeField> {
/// Evaluation domain sized from the length of the first padded public input in the batch.
pub(super) input_domain: EvaluationDomain<F>,
/// The `constraint_domain` returned by `AHPForR1CS::max_constraint_domain` for this circuit.
pub(super) constraint_domain: EvaluationDomain<F>,
/// The `domain_a` returned by `AHPForR1CS::max_non_zero_domain` for this circuit.
pub(super) non_zero_a_domain: EvaluationDomain<F>,
/// The `domain_b` returned by `AHPForR1CS::max_non_zero_domain` for this circuit.
pub(super) non_zero_b_domain: EvaluationDomain<F>,
/// The `domain_c` returned by `AHPForR1CS::max_non_zero_domain` for this circuit.
pub(super) non_zero_c_domain: EvaluationDomain<F>,
/// Number of instances (assignments) in this circuit's batch.
pub(in crate::snark) batch_size: usize,
/// Padded public inputs, one vector per instance in the batch.
pub(super) padded_public_variables: Vec<Vec<F>>,
/// Private inputs, one vector per instance in the batch.
pub(super) private_variables: Vec<Vec<F>>,
/// The `z_a` vector of every instance in the batch; populated with `Some` at initialization.
/// NOTE(review): presumably the `A * z` products of the R1CS instance; confirm.
pub(super) z_a: Option<Vec<Vec<F>>>,
/// The `z_b` vector of every instance in the batch; populated with `Some` at initialization.
/// NOTE(review): presumably the `B * z` products of the R1CS instance; confirm.
pub(super) z_b: Option<Vec<Vec<F>>>,
/// Interpolation of each padded public input over `input_domain`, one polynomial per instance.
pub(super) x_polys: Vec<DensePolynomial<F>>,
/// Starts out as `None` after initialization.
/// NOTE(review): presumably filled in by a later prover round; confirm against the round code.
pub(super) mz_poly_randomizer: Option<Vec<F>>,
/// Starts out as `None` after initialization. `State::lhs_polys_into_iter` unwraps this
/// field, so it has to be set before that method's iterator is consumed.
pub(super) lhs_polynomials: Option<[DensePolynomial<F>; 3]>,
}
/// Prover state shared across all circuits in a proving session.
///
/// Holds one `CircuitSpecificState` per circuit plus values aggregated over every circuit.
pub struct State<'a, F: PrimeField, MM: MarlinMode> {
/// Per-circuit state, keyed by a reference to the circuit it belongs to.
pub(super) circuit_specific_states: BTreeMap<&'a Circuit<F, MM>, CircuitSpecificState<F>>,
/// The verifier's first message; `None` right after `State::initialize`.
pub(super) verifier_first_message: Option<verifier::FirstMessage<F>>,
/// The first-round oracles; `None` right after `State::initialize`.
pub(in crate::snark) first_round_oracles: Option<Arc<super::FirstOracles<F>>>,
/// The `max_non_zero_domain` value obtained by threading the running maximum through
/// `AHPForR1CS::max_non_zero_domain` for every circuit.
pub(in crate::snark) max_non_zero_domain: EvaluationDomain<F>,
/// The `max_constraint_domain` value obtained by threading the running maximum through
/// `AHPForR1CS::max_constraint_domain` for every circuit.
pub(in crate::snark) max_constraint_domain: EvaluationDomain<F>,
/// Sum of the batch sizes of all circuits.
pub(in crate::snark) total_instances: usize,
}
/// Padded public inputs of a single instance.
type PaddedPubInputs<F> = Vec<F>;
/// Private inputs of a single instance.
type PrivateInputs<F> = Vec<F>;
/// The `z_a` vector of a single instance.
type Za<F> = Vec<F>;
/// The `z_b` vector of a single instance.
type Zb<F> = Vec<F>;
/// The assignments of one instance, in the order
/// `(padded public inputs, private inputs, z_a, z_b)`.
///
/// `State::initialize` destructures this tuple positionally, so the field order matters.
pub(super) struct Assignments<F>(
/// Padded public inputs.
pub(super) PaddedPubInputs<F>,
/// Private inputs.
pub(super) PrivateInputs<F>,
/// The `z_a` vector.
pub(super) Za<F>,
/// The `z_b` vector.
pub(super) Zb<F>,
);
impl<'a, F: PrimeField, MM: MarlinMode> State<'a, F, MM> {
    /// Builds the initial prover state from the batch of assignments supplied for each circuit.
    ///
    /// For every circuit this computes the constraint and non-zero domains (threading the
    /// running maxima through `AHPForR1CS::max_constraint_domain` and
    /// `AHPForR1CS::max_non_zero_domain`), derives the input domain from the first instance's
    /// padded public input, interpolates every padded public input over that domain, and
    /// stores the per-instance vectors in a `CircuitSpecificState`.
    ///
    /// # Errors
    ///
    /// * `AHPError::BatchSizeIsZero` if no circuit is supplied, or if any circuit is mapped to
    ///   an empty batch of assignments.
    /// * Any error reported by the domain helpers, converted into `AHPError`.
    ///
    /// # Panics
    ///
    /// Panics if `EvaluationDomain::new` returns `None` for the padded public input length.
    pub(super) fn initialize(
        indices_and_assignments: BTreeMap<&'a Circuit<F, MM>, Vec<Assignments<F>>>,
    ) -> Result<Self, AHPError> {
        // Reject empty batches up front: the per-circuit setup below reads the first instance
        // of every batch and would otherwise panic on an out-of-bounds index.
        if indices_and_assignments.values().any(|batch| batch.is_empty()) {
            return Err(AHPError::BatchSizeIsZero);
        }

        let mut max_constraint_domain: Option<EvaluationDomain<F>> = None;
        let mut max_non_zero_domain: Option<EvaluationDomain<F>> = None;
        let mut total_instances = 0;

        let circuit_specific_states = indices_and_assignments
            .into_iter()
            .map(|(circuit, variable_assignments)| {
                let index_info = &circuit.index_info;

                // Each helper receives the running maximum and hands back the updated one
                // together with this circuit's own domain(s).
                let constraint_domains = AHPForR1CS::<_, MM>::max_constraint_domain(index_info, max_constraint_domain)?;
                max_constraint_domain = constraint_domains.max_constraint_domain;
                let non_zero_domains = AHPForR1CS::<_, MM>::max_non_zero_domain(index_info, max_non_zero_domain)?;
                max_non_zero_domain = non_zero_domains.max_non_zero_domain;

                // Indexing is safe: empty batches were rejected before entering the iterator.
                // NOTE(review): the domain is sized from the first instance only; this assumes
                // every instance of a circuit has the same padded public input length.
                let first_padded_public_inputs = &variable_assignments[0].0;
                let input_domain = EvaluationDomain::new(first_padded_public_inputs.len()).unwrap();

                let batch_size = variable_assignments.len();
                total_instances += batch_size;

                let mut z_as = Vec::with_capacity(batch_size);
                let mut z_bs = Vec::with_capacity(batch_size);
                let mut x_polys = Vec::with_capacity(batch_size);
                let mut padded_public_variables = Vec::with_capacity(batch_size);
                let mut private_variables = Vec::with_capacity(batch_size);

                for Assignments(padded_public_input, private_input, z_a, z_b) in variable_assignments {
                    z_as.push(z_a);
                    z_bs.push(z_b);
                    // The clone is required: interpolation consumes the evaluations while the
                    // original vector is kept in `padded_public_variables`.
                    let x_poly = EvaluationsOnDomain::from_vec_and_domain(padded_public_input.clone(), input_domain)
                        .interpolate();
                    x_polys.push(x_poly);
                    padded_public_variables.push(padded_public_input);
                    private_variables.push(private_input);
                }

                let state = CircuitSpecificState {
                    input_domain,
                    constraint_domain: constraint_domains.constraint_domain,
                    non_zero_a_domain: non_zero_domains.domain_a,
                    non_zero_b_domain: non_zero_domains.domain_b,
                    non_zero_c_domain: non_zero_domains.domain_c,
                    batch_size,
                    padded_public_variables,
                    x_polys,
                    private_variables,
                    z_a: Some(z_as),
                    z_b: Some(z_bs),
                    mz_poly_randomizer: None,
                    lhs_polynomials: None,
                };
                Ok((circuit, state))
            })
            .collect::<SynthesisResult<BTreeMap<_, _>>>()?;

        // Both maxima are still `None` only when no circuit was supplied at all.
        let max_constraint_domain = max_constraint_domain.ok_or(AHPError::BatchSizeIsZero)?;
        let max_non_zero_domain = max_non_zero_domain.ok_or(AHPError::BatchSizeIsZero)?;

        Ok(Self {
            max_constraint_domain,
            max_non_zero_domain,
            circuit_specific_states,
            total_instances,
            first_round_oracles: None,
            verifier_first_message: None,
        })
    }

    /// Returns the number of instances being proved for `circuit`, or `None` if the circuit
    /// is not part of this state.
    pub fn batch_size(&self, circuit: &Circuit<F, MM>) -> Option<usize> {
        self.circuit_specific_states.get(circuit).map(|s| s.batch_size)
    }

    /// Returns the public inputs of every instance of `circuit`, each obtained by passing the
    /// stored padded public input through `ConstraintSystem::unformat_public_input`.
    ///
    /// Returns `None` if the circuit is not part of this state.
    pub fn public_inputs(&self, circuit: &Circuit<F, MM>) -> Option<Vec<Vec<F>>> {
        self.circuit_specific_states.get(circuit).map(|s| {
            s.padded_public_variables.iter().map(|v| super::ConstraintSystem::unformat_public_input(v)).collect()
        })
    }

    /// Returns the padded public inputs of every instance of `circuit` as stored, or `None`
    /// if the circuit is not part of this state.
    pub fn padded_public_inputs(&self, circuit: &Circuit<F, MM>) -> Option<&[Vec<F>]> {
        self.circuit_specific_states.get(circuit).map(|s| s.padded_public_variables.as_slice())
    }

    /// Consumes the state and yields the `lhs_polynomials` of every circuit, in the map's
    /// key order, three polynomials per circuit.
    ///
    /// # Panics
    ///
    /// The returned iterator panics when it reaches a circuit whose `lhs_polynomials` is
    /// still `None`. The check happens lazily during iteration, not when this method is called.
    pub fn lhs_polys_into_iter(self) -> impl Iterator<Item = DensePolynomial<F>> + 'a {
        self.circuit_specific_states.into_values().flat_map(|s| s.lhs_polynomials.unwrap().into_iter())
    }
}