use super::{CompositionPoly, ConstraintDivisor, ProverError, StarkDomain};
use math::{batch_inversion, fft, FieldElement, StarkField};
use utils::{batch_iter_mut, collections::Vec, iter_mut, uninit_vector};
#[cfg(debug_assertions)]
use air::TransitionConstraints;
#[cfg(feature = "concurrent")]
use utils::iterators::*;
/// Minimum number of rows a single table fragment may contain; enforced by
/// `ConstraintEvaluationTable::fragments`.
const MIN_FRAGMENT_SIZE: usize = 16;
/// Column-oriented table of constraint evaluations over the constraint evaluation domain.
///
/// There is one column per constraint divisor, and every column holds
/// `domain.ce_domain_size()` rows. In debug builds the table also records raw transition
/// constraint evaluations so that their degrees can be checked against expectations.
pub struct ConstraintEvaluationTable<'a, E: FieldElement> {
    /// One column of evaluations per entry in `divisors` (paired positionally).
    evaluations: Vec<Vec<E>>,
    /// Divisors by which the corresponding columns of `evaluations` are divided.
    divisors: Vec<ConstraintDivisor<E::BaseField>>,
    /// Domain over which the constraints are evaluated.
    domain: &'a StarkDomain<E::BaseField>,
    /// Debug-only: evaluations of main-trace transition constraints, one column per constraint.
    #[cfg(debug_assertions)]
    main_transition_evaluations: Vec<Vec<E::BaseField>>,
    /// Debug-only: evaluations of auxiliary-trace transition constraints, one column per
    /// constraint.
    #[cfg(debug_assertions)]
    aux_transition_evaluations: Vec<Vec<E>>,
    /// Debug-only: expected degree of each transition constraint after division by the
    /// transition divisor (main constraints first, then auxiliary).
    #[cfg(debug_assertions)]
    expected_transition_degrees: Vec<usize>,
}
impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> {
    /// Returns a new table with one column per divisor and `domain.ce_domain_size()` rows
    /// (release builds).
    ///
    /// Column memory is allocated via `uninit_matrix`, so cells are presumably expected to be
    /// written (e.g. through [`EvaluationTableFragment::update_row`]) before they are read.
    #[cfg(not(debug_assertions))]
    pub fn new(
        domain: &'a StarkDomain<E::BaseField>,
        divisors: Vec<ConstraintDivisor<E::BaseField>>,
    ) -> Self {
        let num_columns = divisors.len();
        let num_rows = domain.ce_domain_size();
        ConstraintEvaluationTable {
            evaluations: uninit_matrix(num_columns, num_rows),
            divisors,
            domain,
        }
    }

    /// Returns a new table with one column per divisor and `domain.ce_domain_size()` rows
    /// (debug builds).
    ///
    /// In addition to the release layout, this allocates one column per main and per auxiliary
    /// transition constraint for recording raw transition evaluations, and precomputes the
    /// expected post-division degree of every transition constraint so that
    /// [`Self::validate_transition_degrees`] can check them later.
    #[cfg(debug_assertions)]
    pub fn new(
        domain: &'a StarkDomain<E::BaseField>,
        divisors: Vec<ConstraintDivisor<E::BaseField>>,
        transition_constraints: &TransitionConstraints<E>,
    ) -> Self {
        let num_columns = divisors.len();
        let num_rows = domain.ce_domain_size();
        let num_tm_columns = transition_constraints.num_main_constraints();
        let num_ta_columns = transition_constraints.num_aux_constraints();
        let expected_transition_degrees =
            build_transition_constraint_degrees(transition_constraints, domain.trace_length());
        ConstraintEvaluationTable {
            evaluations: uninit_matrix(num_columns, num_rows),
            divisors,
            domain,
            main_transition_evaluations: uninit_matrix(num_tm_columns, num_rows),
            aux_transition_evaluations: uninit_matrix(num_ta_columns, num_rows),
            expected_transition_degrees,
        }
    }

    /// Returns the number of rows in the table (the length of its first column).
    ///
    /// # Panics
    /// Panics if the table has no columns.
    pub fn num_rows(&self) -> usize {
        self.evaluations[0].len()
    }

    /// Returns the number of columns in the table (one per divisor).
    // NOTE(review): `dead_code` is allowed, presumably because not every build calls this.
    #[allow(dead_code)]
    pub fn num_columns(&self) -> usize {
        self.evaluations.len()
    }

    /// Splits the table into `num_fragments` mutable fragments, each covering
    /// `num_rows() / num_fragments` consecutive rows of every column.
    ///
    /// Fragment `i` starts at row `i * fragment_size` of the table. In debug builds each
    /// fragment also borrows the matching rows of the transition-evaluation buffers.
    ///
    /// # Panics
    /// Panics if `num_fragments` is zero, if the resulting fragment size is smaller than
    /// `MIN_FRAGMENT_SIZE`, or (inside `make_fragments`) if the rows cannot be split evenly.
    pub fn fragments(&mut self, num_fragments: usize) -> Vec<EvaluationTableFragment<E>> {
        let fragment_size = self.num_rows() / num_fragments;
        assert!(
            fragment_size >= MIN_FRAGMENT_SIZE,
            "fragment size must be at least {}, but was {}",
            MIN_FRAGMENT_SIZE,
            fragment_size
        );
        let evaluation_data = make_fragments(&mut self.evaluations, num_fragments);
        // debug builds: pair each evaluation fragment with the matching slices of the main and
        // auxiliary transition-evaluation buffers
        #[cfg(debug_assertions)]
        let result = {
            let tm_evaluation_data =
                make_fragments(&mut self.main_transition_evaluations, num_fragments);
            let ta_evaluation_data =
                make_fragments(&mut self.aux_transition_evaluations, num_fragments);
            evaluation_data
                .into_iter()
                .zip(tm_evaluation_data)
                .zip(ta_evaluation_data)
                .enumerate()
                .map(|(i, ((evaluations, tm_evaluations), ta_evaluations))| {
                    EvaluationTableFragment {
                        offset: i * fragment_size,
                        evaluations,
                        tm_evaluations,
                        ta_evaluations,
                    }
                })
                .collect()
        };
        #[cfg(not(debug_assertions))]
        let result = evaluation_data
            .into_iter()
            .enumerate()
            .map(|(i, evaluations)| EvaluationTableFragment {
                offset: i * fragment_size,
                evaluations,
            })
            .collect();
        result
    }

    /// Consumes the table and returns the combined composition polynomial in coefficient form.
    ///
    /// Each column is divided by its divisor, and the quotients are summed into a single
    /// vector of evaluations. That vector is then interpolated over the domain (using the
    /// domain offset).
    ///
    /// # Errors
    /// In debug builds, returns `ProverError::MismatchedConstraintPolynomialDegree` if any
    /// column, after division by its divisor, does not have degree `column.len() - 1`.
    pub fn into_poly(self) -> Result<CompositionPoly<E>, ProverError> {
        let mut combined_poly = E::zeroed_vector(self.num_rows());
        for (column, divisor) in self.evaluations.into_iter().zip(self.divisors.iter()) {
            #[cfg(debug_assertions)]
            validate_column_degree(&column, divisor, self.domain, column.len() - 1)?;
            // divide the column by its divisor and add the result into combined_poly
            acc_column(column, divisor, self.domain, &mut combined_poly);
        }
        // combined_poly holds evaluations; convert them to coefficient form
        let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(combined_poly.len());
        fft::interpolate_poly_with_offset(&mut combined_poly, &inv_twiddles, self.domain.offset());
        let trace_length = self.domain.trace_length();
        Ok(CompositionPoly::new(combined_poly, trace_length))
    }

    /// Debug-only check that recorded transition constraint evaluations have their expected
    /// degrees, and that the evaluation domain size matches the maximum degree.
    ///
    /// `self.divisors[0]` is treated as the transition constraint divisor.
    ///
    /// # Panics
    /// Panics if any actual degree differs from the expected one. Also panics if the table's
    /// row count differs from `max(max_degree, trace_length + 1).next_power_of_two()`.
    #[cfg(debug_assertions)]
    pub fn validate_transition_degrees(&mut self) {
        // evaluations of the (assumed) transition divisor over the whole evaluation domain
        let div_values = evaluate_divisor::<E::BaseField>(
            &self.divisors[0],
            self.num_rows(),
            self.domain.offset(),
        );
        let mut actual_degrees = Vec::with_capacity(self.expected_transition_degrees.len());
        let mut max_degree = 0;
        let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(self.num_rows());
        // main-trace constraints first, then auxiliary: same order as expected degrees
        for evaluations in self.main_transition_evaluations.iter() {
            let degree = get_transition_poly_degree(evaluations, &inv_twiddles, &div_values);
            actual_degrees.push(degree);
            max_degree = core::cmp::max(max_degree, degree);
        }
        for evaluations in self.aux_transition_evaluations.iter() {
            let degree = get_transition_poly_degree(evaluations, &inv_twiddles, &div_values);
            actual_degrees.push(degree);
            max_degree = core::cmp::max(max_degree, degree);
        }
        assert_eq!(
            self.expected_transition_degrees, actual_degrees,
            "transition constraint degrees didn't match\nexpected: {:>3?}\nactual: {:>3?}",
            self.expected_transition_degrees, actual_degrees
        );
        // the domain must be exactly the smallest power of two covering the max degree
        // (and at least trace_length + 1)
        let expected_domain_size =
            core::cmp::max(max_degree, self.domain.trace_length() + 1).next_power_of_two();
        assert_eq!(
            expected_domain_size,
            self.num_rows(),
            "incorrect constraint evaluation domain size; expected {}, but was {}",
            expected_domain_size,
            self.num_rows()
        );
    }
}
/// A mutable view over a contiguous range of rows of a constraint evaluation table.
///
/// Every column slice in a fragment covers the same row range; `offset` is the index, within
/// the full table, of the first row in that range.
pub struct EvaluationTableFragment<'a, E: FieldElement> {
    /// Table row index at which this fragment begins.
    offset: usize,
    /// One slice per table column, all spanning the same rows.
    evaluations: Vec<&'a mut [E]>,
    /// Debug-only: slices of the main-trace transition evaluation buffers.
    #[cfg(debug_assertions)]
    tm_evaluations: Vec<&'a mut [E::BaseField]>,
    /// Debug-only: slices of the auxiliary-trace transition evaluation buffers.
    #[cfg(debug_assertions)]
    ta_evaluations: Vec<&'a mut [E]>,
}

impl<'a, E: FieldElement> EvaluationTableFragment<'a, E> {
    /// Returns the table row index at which this fragment begins.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Returns the number of rows in this fragment, measured on its first column.
    ///
    /// # Panics
    /// Panics if the fragment holds no columns.
    pub fn num_rows(&self) -> usize {
        let first_column = &self.evaluations[0];
        first_column.len()
    }

    /// Returns how many column slices this fragment holds.
    pub fn num_columns(&self) -> usize {
        let columns = &self.evaluations;
        columns.len()
    }

    /// Writes `row_data` into row `row_idx` (fragment-relative) of each column.
    ///
    /// Values are paired with columns positionally; surplus values or surplus columns are
    /// left untouched.
    pub fn update_row(&mut self, row_idx: usize, row_data: &[E]) {
        self.evaluations
            .iter_mut()
            .zip(row_data.iter())
            .for_each(|(column, value)| column[row_idx] = *value);
    }

    /// Debug-only: records main and auxiliary transition constraint evaluations for row
    /// `row_idx` (fragment-relative), pairing values with columns positionally.
    #[cfg(debug_assertions)]
    pub fn update_transition_evaluations(
        &mut self,
        row_idx: usize,
        main_evaluations: &[E::BaseField],
        aux_evaluations: &[E],
    ) {
        self.tm_evaluations
            .iter_mut()
            .zip(main_evaluations.iter())
            .for_each(|(column, value)| column[row_idx] = *value);
        self.ta_evaluations
            .iter_mut()
            .zip(aux_evaluations.iter())
            .for_each(|(column, value)| column[row_idx] = *value);
    }
}
/// Allocates `num_cols` columns of `num_rows` elements each, via `uninit_vector`.
///
/// The contents are presumably uninitialized; callers must write every cell before reading it.
fn uninit_matrix<E: FieldElement>(num_cols: usize, num_rows: usize) -> Vec<Vec<E>> {
    let mut matrix = Vec::with_capacity(num_cols);
    for _ in 0..num_cols {
        // SAFETY: NOTE(review) soundness relies on every cell being written (e.g. through
        // `EvaluationTableFragment::update_row`) before it is read; this is not enforced here.
        let column = unsafe { uninit_vector(num_rows) };
        matrix.push(column);
    }
    matrix
}
/// Splits every column in `source` into `num_fragments` equally sized, contiguous row slices
/// and groups the slices by fragment.
///
/// The result always has `num_fragments` entries. Entry `i` holds the `i`-th slice of every
/// column, in column order. If `source` is empty, every entry is an empty vector.
///
/// # Panics
/// Panics when `source` is non-empty and any of the following holds:
/// - `num_fragments` is zero;
/// - the column length is not a non-zero multiple of `num_fragments`;
/// - the columns do not all have the same length.
fn make_fragments<E: FieldElement>(
    source: &mut [Vec<E>],
    num_fragments: usize,
) -> Vec<Vec<&mut [E]>> {
    let num_columns = source.len();
    let mut result = (0..num_fragments)
        .map(|_| Vec::with_capacity(num_columns))
        .collect::<Vec<_>>();

    if let Some(num_rows) = source.first().map(|column| column.len()) {
        // Checking these up front replaces an opaque out-of-bounds panic (uneven split),
        // a division by zero, or a zero chunk size with a clear message.
        assert!(num_fragments > 0, "number of fragments must be greater than zero");
        assert!(
            num_rows >= num_fragments && num_rows % num_fragments == 0,
            "number of rows ({}) must be a non-zero multiple of the number of fragments ({})",
            num_rows,
            num_fragments
        );
        let fragment_size = num_rows / num_fragments;

        for column in source.iter_mut() {
            // a shorter column would otherwise silently leave some fragments without it
            assert_eq!(
                column.len(),
                num_rows,
                "all columns must have the same number of rows"
            );
            // given the checks above, this yields exactly `num_fragments` chunks
            for (fragment, chunk) in result.iter_mut().zip(column.chunks_exact_mut(fragment_size))
            {
                fragment.push(chunk);
            }
        }
    }

    result
}
/// Divides `column` point-wise by `divisor` and adds the quotients into `result`.
///
/// The divisor numerator must consist of a single term `(x^a - b)`. Its inverse evaluations
/// `z` come from `get_inv_evaluation` and repeat with period `z.len()` over the domain rows.
/// When the divisor has exemption points, each quotient is additionally multiplied by the
/// exemption polynomial evaluated at the row's domain point.
///
/// # Panics
/// Panics if the divisor numerator has more than one term.
#[allow(clippy::many_single_char_names)]
fn acc_column<E: FieldElement>(
    column: Vec<E>,
    divisor: &ConstraintDivisor<E::BaseField>,
    domain: &StarkDomain<E::BaseField>,
    result: &mut [E],
) {
    let numerator = divisor.numerator();
    assert_eq!(numerator.len(), 1, "complex divisors are not yet supported");
    // z[j] = 1 / (x_j^a - b) for the first z.len() domain points
    let z = get_inv_evaluation(divisor, domain);
    if divisor.exemptions().is_empty() {
        iter_mut!(result, 1024)
            .zip(column)
            .enumerate()
            .for_each(|(i, (acc_value, value))| {
                let z = z[i % z.len()];
                *acc_value += value.mul_base(z);
            });
    } else {
        batch_iter_mut!(
            result,
            128, // min batch size
            |batch: &mut [E], batch_offset: usize| {
                for (i, acc_value) in batch.iter_mut().enumerate() {
                    // z must be indexed by the global row (as in the branch above and as for
                    // `get_ce_x_at`); the batch-local index is only correct when z.len()
                    // happens to divide batch_offset.
                    let row = batch_offset + i;
                    let x = domain.get_ce_x_at(row);
                    let e = divisor.evaluate_exemptions_at(x);
                    let z = z[row % z.len()];
                    *acc_value += column[row].mul_base(z * e);
                }
            }
        );
    }
}
/// Returns the inverses `1 / (x^a - b)` of the divisor numerator over the first
/// `ce_domain_size / a` points of the constraint evaluation domain.
///
/// `a` and `b` are the degree and constant of the divisor's first numerator term.
///
/// # Panics
/// Panics if `a` exceeds `u32::MAX`.
fn get_inv_evaluation<B: StarkField>(
    divisor: &ConstraintDivisor<B>,
    domain: &StarkDomain<B>,
) -> Vec<B> {
    let numerator = divisor.numerator();
    let degree = numerator[0].0 as u64;
    let shift = numerator[0].1;

    // the bound lets `get_ce_x_power_at` be used below
    assert!(
        degree <= u32::MAX as u64,
        "constraint divisor numerator degree cannot exceed {}, but was {}",
        u32::MAX,
        degree
    );

    let num_points = domain.ce_domain_size() / degree as usize;
    let offset_pow = domain.offset().exp(degree.into());

    // SAFETY: NOTE(review) every element is overwritten by the batch loop below before the
    // vector is read by `batch_inversion`.
    let mut numerator_values = unsafe { uninit_vector(num_points) };
    batch_iter_mut!(
        &mut numerator_values,
        128, // min batch size
        |batch: &mut [B], batch_offset: usize| {
            for (pos, slot) in batch.iter_mut().enumerate() {
                let x_pow = domain.get_ce_x_power_at(batch_offset + pos, degree, offset_pow);
                *slot = x_pow - shift;
            }
        }
    );

    batch_inversion(&numerator_values)
}
/// Debug-only: returns the expected degree of every transition constraint after division by
/// the transition divisor. Main-trace constraints come first, then auxiliary-trace ones.
#[cfg(debug_assertions)]
fn build_transition_constraint_degrees<E: FieldElement>(
    constraints: &TransitionConstraints<E>,
    trace_length: usize,
) -> Vec<usize> {
    constraints
        .main_constraint_degrees()
        .iter()
        .chain(constraints.aux_constraint_degrees().iter())
        .map(|degree| degree.get_evaluation_degree(trace_length) - constraints.divisor().degree())
        .collect()
}
/// Debug-only: divides `evaluations` point-wise by `div_values`, interpolates the quotients
/// with `inv_twiddles`, and returns the degree of the resulting polynomial.
#[cfg(debug_assertions)]
fn get_transition_poly_degree<E: FieldElement>(
    evaluations: &[E],
    inv_twiddles: &[E::BaseField],
    div_values: &[E::BaseField],
) -> usize {
    let mut quotients: Vec<E> = Vec::with_capacity(evaluations.len());
    for (value, divisor_value) in evaluations.iter().zip(div_values.iter()) {
        quotients.push(*value / E::from(*divisor_value));
    }
    fft::interpolate_poly(&mut quotients, inv_twiddles);
    math::polynom::degree_of(&quotients)
}
/// Debug-only: checks that `column`, after point-wise division by `divisor`, interpolates
/// (over `domain`'s offset coset) to a polynomial of exactly `expected_degree`.
///
/// # Errors
/// Returns `ProverError::MismatchedConstraintPolynomialDegree(expected, actual)` on mismatch.
#[cfg(debug_assertions)]
fn validate_column_degree<B: StarkField, E: FieldElement<BaseField = B>>(
    column: &[E],
    divisor: &ConstraintDivisor<B>,
    domain: &StarkDomain<B>,
    expected_degree: usize,
) -> Result<(), ProverError> {
    let div_values = evaluate_divisor(divisor, column.len(), domain.offset());
    let mut evaluations = column
        .iter()
        .zip(div_values)
        .map(|(&c, d)| c / d)
        .collect::<Vec<_>>();
    let inv_twiddles = fft::get_inv_twiddles::<B>(evaluations.len());
    fft::interpolate_poly_with_offset(&mut evaluations, &inv_twiddles, domain.offset());

    // compute the degree once (it is an O(n) scan) and reuse it for both check and error
    let actual_degree = math::polynom::degree_of(&evaluations);
    if actual_degree != expected_degree {
        return Err(ProverError::MismatchedConstraintPolynomialDegree(
            expected_degree,
            actual_degree,
        ));
    }
    Ok(())
}
/// Debug-only: evaluates `divisor` at every point of the offset domain of size `domain_size`,
/// lifting each result into `E`.
///
/// NOTE(review): assumes `domain_size` is a power of two, since its trailing-zero count is used
/// as the log2 of the root of unity.
#[cfg(debug_assertions)]
fn evaluate_divisor<E: FieldElement>(
    divisor: &ConstraintDivisor<E::BaseField>,
    domain_size: usize,
    domain_offset: E::BaseField,
) -> Vec<E> {
    let log_size = domain_size.trailing_zeros();
    let generator = E::BaseField::get_root_of_unity(log_size);
    let points = math::get_power_series_with_offset(generator, domain_offset, domain_size);

    let mut values = Vec::with_capacity(points.len());
    for point in points {
        values.push(E::from(divisor.evaluate_at(point)));
    }
    values
}