use crate::{
base::{if_rayon, scalar::Scalar},
proof_primitive::sumcheck::ProverState,
utils::log,
};
use alloc::{vec, vec::Vec};
#[cfg(feature = "rayon")]
use rayon::prelude::*;
/// Produces the prover's message for the next round of the sumcheck protocol.
///
/// When `r_maybe` holds the verifier's challenge from the previous round, the leading
/// variable of every table in `prover_state.flattened_ml_extensions` is first fixed to that
/// challenge, in place. The round counter is then advanced and the message for the new
/// round is computed.
///
/// Let `degree = prover_state.max_multiplicands` and `round_length = 2^(num_vars - round)`,
/// using the advanced round counter. The returned vector has `degree + 1` entries. Entry `t`
/// is a sum taken over every `(coefficient, multiplicand_indices)` term of
/// `prover_state.list_of_products` and every `b` in `0..round_length`. Each summand is
/// `coefficient * Π_i (table_i[2b] + t * (table_i[2b + 1] - table_i[2b]))`, where `table_i`
/// ranges over the tables named by `multiplicand_indices`. In other words, the result is this
/// round's univariate polynomial evaluated at `t = 0, 1, ..., degree`.
///
/// # Panics
///
/// * If `r_maybe` is `Some` in the very first round (`round == 0`).
/// * If `r_maybe` is `None` in any later round.
/// * If the prover is no longer active, i.e. advancing the round would make `round`
///   exceed `num_vars`.
/// * If a table is too short for the current round (out-of-bounds indexing).
#[expect(clippy::ref_option)]
#[tracing::instrument(level = "debug", skip_all)]
pub fn prove_round<S: Scalar>(prover_state: &mut ProverState<S>, r_maybe: &Option<S>) -> Vec<S> {
    log::log_memory_usage("Start");

    if let Some(r) = r_maybe {
        assert!(
            prover_state.round != 0,
            "first round should be prover first."
        );

        // Bind the leading variable of every table to the verifier's challenge. Each adjacent
        // pair of entries is folded into the front half of its table.
        if_rayon!(
            prover_state.flattened_ml_extensions.par_iter_mut(),
            prover_state.flattened_ml_extensions.iter_mut()
        )
        .for_each(|multiplicand| {
            in_place_fix_variable(multiplicand, *r, prover_state.num_vars - prover_state.round);
        });
    } else if prover_state.round > 0 {
        panic!("verifier message is empty");
    }

    prover_state.round += 1;

    assert!(
        prover_state.round <= prover_state.num_vars,
        "Prover is not active"
    );

    // Degree of the univariate polynomial sent this round; it is evaluated at 0..=degree.
    let degree = prover_state.max_multiplicands;
    // Number of (table[2b], table[2b + 1]) pairs still live in each table this round.
    let round_length = 1usize << (prover_state.num_vars - prover_state.round);

    let sums_iter = if_rayon!(
        prover_state.list_of_products.par_iter(),
        prover_state.list_of_products.iter()
    )
    .map(|(coefficient, multiplicand_indices)| {
        let products_iter =
            if_rayon!((0..round_length).into_par_iter(), 0..round_length).map(|b| {
                let mut products = vec![*coefficient; degree + 1];
                for &multiplicand_index in multiplicand_indices {
                    let table = &prover_state.flattened_ml_extensions[multiplicand_index];
                    // Walk the line through (table[2b], table[2b + 1]) one step per
                    // evaluation point t = 0, 1, ..., degree, multiplying as we go.
                    let mut start = table[b << 1];
                    let step = table[(b << 1) + 1] - start;
                    products.iter_mut().take(degree).for_each(|product| {
                        *product *= start;
                        start += step;
                    });
                    products[degree] *= start;
                }
                products
            });
        // Sum the per-`b` evaluation vectors of this product term.
        if_rayon!(
            products_iter.reduce(|| vec![S::zero(); degree + 1], vec_elementwise_add),
            products_iter.fold(vec![S::zero(); degree + 1], vec_elementwise_add)
        )
    });
    // Sum the evaluation vectors of all product terms.
    let res = if_rayon!(
        sums_iter.reduce(|| vec![S::zero(); degree + 1], vec_elementwise_add),
        sums_iter.fold(vec![S::zero(); degree + 1], vec_elementwise_add)
    );

    log::log_memory_usage("End");

    res
}
/// Fixes the leading variable of a multilinear table to `r_as_field`, in place.
///
/// `num_vars` is the number of variables that remain after fixing. For every `dst` in
/// `0..2^num_vars`, the adjacent pair `(multiplicand[2 * dst], multiplicand[2 * dst + 1])`
/// is linearly interpolated at `r_as_field`, and the result is written to `multiplicand[dst]`.
/// Only the first `2^num_vars` entries are overwritten; the tail of the slice keeps stale values.
///
/// Folding in place is sound because a destination index never exceeds its source index
/// (`dst <= 2 * dst`). Every pair is therefore read before its slot can be overwritten.
///
/// # Panics
/// Panics if `num_vars == 0`, or if `multiplicand` holds fewer than `2^(num_vars + 1)` entries.
fn in_place_fix_variable<S: Scalar>(multiplicand: &mut [S], r_as_field: S, num_vars: usize) {
    assert!(num_vars > 0, "invalid size of partial point");
    let folded_len = 1usize << num_vars;
    for dst in 0..folded_len {
        let src = dst << 1;
        let low = multiplicand[src];
        let high = multiplicand[src + 1];
        multiplicand[dst] = low + r_as_field * (high - low);
    }
}
/// Adds two vectors entry by entry, reusing `a`'s buffer for the result.
///
/// The result has `min(a.len(), b.len())` entries. Surplus entries of the longer input are
/// dropped, matching `Iterator::zip` semantics. This function is the combiner for the
/// fixed-length (`degree + 1`) evaluation vectors summed in `prove_round`.
fn vec_elementwise_add<S: Scalar>(mut a: Vec<S>, b: Vec<S>) -> Vec<S> {
    a.truncate(b.len());
    for (acc, addend) in a.iter_mut().zip(b) {
        *acc = *acc + addend;
    }
    a
}