use super::{constraints::CompositionPoly, StarkDomain, TracePolyTable};
use air::{Air, DeepCompositionCoefficients};
use math::{add_in_place, fft, log2, mul_acc, polynom, ExtensionOf, FieldElement, StarkField};
use utils::{collections::Vec, iter_mut};
#[cfg(feature = "concurrent")]
use utils::iterators::*;
/// DEEP composition polynomial, built incrementally and stored in coefficient form.
///
/// The methods in this file are expected to be called in this order:
/// `new` -> `add_trace_polys` -> `add_composition_poly` -> `adjust_degree` -> `evaluate`.
/// `add_trace_polys` requires an empty polynomial, `add_composition_poly` requires a
/// non-empty one of degree `poly_size() - 2`, and `adjust_degree` raises the degree to
/// `poly_size() - 1`.
pub struct DeepCompositionPoly<E: FieldElement> {
/// Coefficients of the polynomial built so far; empty until `add_trace_polys` runs.
coefficients: Vec<E>,
/// Coefficients used to linearly combine the individual quotient terms.
cc: DeepCompositionCoefficients<E>,
/// The point `z` whose out-of-domain values are divided out of the trace/composition terms.
z: E,
/// `true` when the AIR's proof options specify a field extension; enables the conjugate term.
field_extension: bool,
}
impl<E: FieldElement> DeepCompositionPoly<E> {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns an empty DEEP composition polynomial for the point `z`.
    ///
    /// `cc` supplies the coefficients used to linearly combine the individual quotient terms.
    /// The conjugate (third) trace term is enabled only when the proof options of `air`
    /// specify a field extension.
    pub fn new<A>(air: &A, z: E, cc: DeepCompositionCoefficients<E>) -> Self
    where
        A: Air<BaseField = E::BaseField>,
    {
        DeepCompositionPoly {
            coefficients: vec![],
            cc,
            z,
            field_extension: !air.options().field_extension().is_none(),
        }
    }

    // ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the number of stored coefficients.
    ///
    /// This can exceed `degree() + 1` because high-order coefficients may be zero.
    pub fn poly_size(&self) -> usize {
        self.coefficients.len()
    }

    /// Returns the degree of this polynomial, as computed by `polynom::degree_of`.
    pub fn degree(&self) -> usize {
        polynom::degree_of(&self.coefficients)
    }

    // TRACE POLYNOMIAL COMPOSITION
    // --------------------------------------------------------------------------------------------

    /// Initializes this polynomial from the trace polynomials.
    ///
    /// Let `g` be the root of unity returned by `get_root_of_unity(log2(trace_length))`, and
    /// let `(a_i, b_i, c_i) = cc.trace[i]`. Main-segment columns are visited first, then
    /// auxiliary columns, with `i` running continuously across both. The result is
    ///
    /// ```text
    ///   sum_i a_i * (T_i(x) - s0_i) / (x - z)
    /// + sum_i b_i * (T_i(x) - s1_i) / (x - z * g)
    /// + sum_j c_j * (T_j(x) - conj(s0_j)) / (x - conj(z))   // main columns, field extension only
    /// ```
    ///
    /// Here `s0_i = ood_trace_states[0][i]` and `s1_i = ood_trace_states[1][i]`. For each
    /// quotient to be exact, these must be the column values at `z` and `z * g` respectively.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// - this polynomial has already been initialized;
    /// - `ood_trace_states` or `cc.trace` has fewer entries than there are trace columns;
    /// - the degree of the result is not `poly_size() - 2`.
    pub fn add_trace_polys(
        &mut self,
        trace_polys: TracePolyTable<E>,
        ood_trace_states: Vec<Vec<E>>,
    ) {
        assert!(self.coefficients.is_empty());

        let trace_length = trace_polys.poly_size();
        let g = E::from(E::BaseField::get_root_of_unity(log2(trace_length)));
        let next_z = self.z * g;

        // Accumulators for the numerators of the three quotient families; T3 stays empty
        // (and is skipped by `merge_trace_compositions`) when no field extension is used.
        let mut t1_composition = E::zeroed_vector(trace_length);
        let mut t2_composition = E::zeroed_vector(trace_length);
        let mut t3_composition = if self.field_extension {
            E::zeroed_vector(trace_length)
        } else {
            Vec::new()
        };

        // Column index shared by the main and auxiliary segments, so that `cc.trace` and the
        // OOD states are indexed consistently across both.
        let mut i = 0;

        // Main-segment columns are over the base field.
        for poly in trace_polys.main_trace_polys() {
            acc_trace_poly::<E::BaseField, E>(
                &mut t1_composition,
                poly,
                ood_trace_states[0][i],
                self.cc.trace[i].0,
            );
            acc_trace_poly::<E::BaseField, E>(
                &mut t2_composition,
                poly,
                ood_trace_states[1][i],
                self.cc.trace[i].1,
            );
            if self.field_extension {
                // A base-field polynomial evaluated at conj(z) yields conj(T_i(z)), so the
                // conjugated OOD value pairs with the conj(z) divisor used below.
                acc_trace_poly::<E::BaseField, E>(
                    &mut t3_composition,
                    poly,
                    ood_trace_states[0][i].conjugate(),
                    self.cc.trace[i].2,
                );
            }
            i += 1;
        }

        // Auxiliary-segment columns are over `E` and get no conjugate term.
        for poly in trace_polys.aux_trace_polys() {
            acc_trace_poly::<E, E>(
                &mut t1_composition,
                poly,
                ood_trace_states[0][i],
                self.cc.trace[i].0,
            );
            acc_trace_poly::<E, E>(
                &mut t2_composition,
                poly,
                ood_trace_states[1][i],
                self.cc.trace[i].1,
            );
            i += 1;
        }

        // Divide each numerator by its linear factor and sum the quotients.
        let trace_poly = merge_trace_compositions(
            vec![t1_composition, t2_composition, t3_composition],
            vec![self.z, next_z, self.z.conjugate()],
        );

        self.coefficients = trace_poly;
        assert_eq!(self.poly_size() - 2, self.degree());
    }

    // CONSTRAINT POLYNOMIAL COMPOSITION
    // --------------------------------------------------------------------------------------------

    /// Adds the columns of the composition polynomial to this polynomial.
    ///
    /// Let `m` be the number of columns. Each column `H_i` is replaced by
    /// `(H_i(x) - ood_evaluations[i]) / (x - z^m)`, scaled by `cc.constraints[i]`, and added
    /// to this polynomial.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// - `add_trace_polys` has not been called yet;
    /// - `ood_evaluations` does not contain exactly one value per column;
    /// - `cc.constraints` has fewer entries than there are columns;
    /// - the degree of the result is not `poly_size() - 2`.
    pub fn add_composition_poly(
        &mut self,
        composition_poly: CompositionPoly<E>,
        ood_evaluations: Vec<E>,
    ) {
        assert!(!self.coefficients.is_empty());

        // z^m, where m is the number of composition polynomial columns.
        let num_columns = composition_poly.num_columns() as u32;
        let z_m = self.z.exp(num_columns.into());

        let mut column_polys = composition_poly.into_columns();

        // `zip` below silently stops at the shorter side. A mismatch would leave a column
        // divided by (x - z^m) without its evaluation subtracted, or dropped entirely, so
        // reject it up front.
        assert_eq!(
            column_polys.len(),
            ood_evaluations.len(),
            "expected exactly one out-of-domain evaluation per composition polynomial column"
        );

        // H'_i(x) = (H_i(x) - H_i(z^m)) / (x - z^m)
        iter_mut!(column_polys)
            .zip(ood_evaluations)
            .for_each(|(poly, value_at_z_m)| {
                poly[0] -= value_at_z_m;
                polynom::syn_div_in_place(poly, 1, z_m);
            });

        // Accumulate H'_i(x) * cc_i for all i.
        for (i, poly) in column_polys.into_iter().enumerate() {
            mul_acc::<E, E>(&mut self.coefficients, &poly, self.cc.constraints[i]);
        }

        assert_eq!(self.poly_size() - 2, self.degree());
    }

    // DEGREE ADJUSTMENT
    // --------------------------------------------------------------------------------------------

    /// Raises the degree of this polynomial by exactly one.
    ///
    /// Replaces `C(x)` with `C(x) * (cc.degree.0 + cc.degree.1 * x)`. The top stored
    /// coefficient is zero (the degree is asserted to be `poly_size() - 2` on entry), so the
    /// product still fits in the same number of coefficients.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// - fewer than two coefficients are stored (for example, `add_trace_polys` was never
    ///   called);
    /// - the degree is not `poly_size() - 2` on entry;
    /// - the degree is not `poly_size() - 1` on exit.
    pub fn adjust_degree(&mut self) {
        // Guard the `- 2` / `- 1` arithmetic below; on an uninitialized polynomial it would
        // underflow instead of reporting the real problem.
        assert!(
            self.poly_size() >= 2,
            "DEEP composition polynomial must be initialized before adjusting its degree"
        );
        assert_eq!(self.poly_size() - 2, self.degree());

        let size = self.poly_size();
        let mut result = E::zeroed_vector(size);

        // result = C(x) * cc.degree.0
        mul_acc::<E, E>(&mut result, &self.coefficients, self.cc.degree.0);

        // result += C(x) * x * cc.degree.1. The dropped top coefficient of C(x) is zero.
        mul_acc::<E, E>(
            &mut result[1..],
            &self.coefficients[..(size - 1)],
            self.cc.degree.1,
        );

        self.coefficients = result;
        assert_eq!(self.poly_size() - 1, self.degree());
    }

    // EVALUATION
    // --------------------------------------------------------------------------------------------

    /// Evaluates this polynomial over `domain`, consuming it.
    ///
    /// Delegates to `fft::evaluate_poly_with_offset`. It passes the domain's trace twiddles,
    /// its offset, and its trace-to-LDE blowup factor.
    pub fn evaluate(self, domain: &StarkDomain<E::BaseField>) -> Vec<E> {
        fft::evaluate_poly_with_offset(
            &self.coefficients,
            domain.trace_twiddles(),
            domain.offset(),
            domain.trace_to_lde_blowup(),
        )
    }
}
/// Divides every non-empty polynomial in `polys` by `(x - divisor)` and sums the quotients.
///
/// `polys[k]` is paired with `divisors[k]`. The division is done in place via
/// `polynom::syn_div_in_place(poly, 1, divisor)`. Empty polynomials are neither divided nor
/// summed; for example, the conjugate term is empty when no field extension is used. The sum
/// is accumulated into the first polynomial, which is returned.
///
/// # Panics
/// Panics if `polys` is empty.
fn merge_trace_compositions<E: FieldElement>(mut polys: Vec<Vec<E>>, divisors: Vec<E>) -> Vec<E> {
    iter_mut!(polys).zip(divisors).for_each(|(poly, divisor)| {
        if poly.is_empty() {
            return;
        }
        polynom::syn_div_in_place(poly, 1, divisor);
    });

    let mut merged = polys.remove(0);
    polys
        .iter()
        .filter(|quotient| !quotient.is_empty())
        .for_each(|quotient| add_in_place(&mut merged, quotient));
    merged
}
/// Adds `k * (poly(x) - value)` to the coefficient-form polynomial in `accumulator`.
///
/// `poly` may be defined over `E` itself or over a field `E` extends; its coefficients are
/// lifted into `E` by `mul_acc`.
///
/// # Panics
/// Panics if `accumulator` is empty.
fn acc_trace_poly<F, E>(accumulator: &mut [E], poly: &[F], value: E, k: E)
where
    F: FieldElement,
    E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
{
    // Coefficient-wise: accumulator += k * poly(x).
    mul_acc(accumulator, poly, k);
    // The constant "- value" part, scaled by k, only affects the free coefficient.
    accumulator[0] -= value * k;
}