use air::{
Air, AuxRandElements, ConstraintCompositionCoefficients, EvaluationFrame, TransitionConstraints,
};
use math::FieldElement;
use tracing::instrument;
use utils::iter_mut;
#[cfg(feature = "concurrent")]
use utils::{iterators::*, rayon};
use super::{
super::EvaluationTableFragment, BoundaryConstraints, CompositionPolyTrace,
ConstraintEvaluationTable, ConstraintEvaluator, PeriodicValueTable, StarkDomain, TraceLde,
};
/// Smallest constraint evaluation domain size (2^13 rows) for which `evaluate` splits the
/// evaluation table into multiple fragments (one per rayon thread, rounded up to a power of
/// two); smaller domains are evaluated as a single fragment.
#[cfg(feature = "concurrent")]
const MIN_CONCURRENT_DOMAIN_SIZE: usize = 1 << 13; // = 8192
/// Default [`ConstraintEvaluator`] implementation.
///
/// Evaluates the transition and boundary constraints of `air` over the constraint evaluation
/// domain and combines the results into a single composition polynomial trace.
pub struct DefaultConstraintEvaluator<'a, A: Air, E: FieldElement<BaseField = A::BaseField>> {
    /// AIR describing the computation whose constraints are evaluated.
    air: &'a A,
    /// Boundary constraints, built from the boundary composition coefficients and (if present)
    /// the auxiliary random elements.
    boundary_constraints: BoundaryConstraints<E>,
    /// Transition constraints together with their composition coefficients; `divisor()` on this
    /// value supplies the divisor for the merged transition-constraint column.
    transition_constraints: TransitionConstraints<E>,
    /// Random elements for the auxiliary trace segment; must be `Some` whenever auxiliary
    /// transition constraints are evaluated.
    aux_rand_elements: Option<AuxRandElements<E>>,
    /// Periodic column values, looked up by evaluation step.
    periodic_values: PeriodicValueTable<E::BaseField>,
}
impl<A, E> ConstraintEvaluator<E> for DefaultConstraintEvaluator<'_, A, E>
where
    A: Air,
    E: FieldElement<BaseField = A::BaseField>,
{
    type Air = A;

    /// Evaluates every transition and boundary constraint of the AIR over the constraint
    /// evaluation domain and combines the results into a [`CompositionPolyTrace`].
    ///
    /// The evaluation table is built from the transition divisor followed by the
    /// boundary-constraint divisors, and is filled fragment by fragment. More than one fragment
    /// is used only when the `concurrent` feature is enabled and the domain has at least
    /// `MIN_CONCURRENT_DOMAIN_SIZE` rows.
    ///
    /// # Panics
    /// Panics if `trace.trace_len()` differs from `domain.lde_domain_size()`.
    #[instrument(
        skip_all,
        name = "evaluate_constraints",
        fields(
            ce_domain_size = %domain.ce_domain_size()
        )
    )]
    fn evaluate<T: TraceLde<E>>(
        self,
        trace: &T,
        domain: &StarkDomain<<E as FieldElement>::BaseField>,
    ) -> CompositionPolyTrace<E> {
        assert_eq!(
            trace.trace_len(),
            domain.lde_domain_size(),
            "extended trace length is not consistent with evaluation domain"
        );

        // Divisor order mirrors the row layout written by the fragment helpers: the merged
        // transition constraints occupy the first slot, boundary constraints the rest.
        let mut divisors = vec![self.transition_constraints.divisor().clone()];
        divisors.extend(self.boundary_constraints.get_divisors());

        // Debug builds additionally hand over the transition constraints, which the table
        // uses for `validate_transition_degrees` below.
        #[cfg(debug_assertions)]
        let mut evaluation_table =
            ConstraintEvaluationTable::<E>::new(domain, divisors, &self.transition_constraints);
        #[cfg(not(debug_assertions))]
        let mut evaluation_table = ConstraintEvaluationTable::<E>::new(domain, divisors);

        // Only large domains are split across threads; anything smaller is one fragment.
        #[cfg(feature = "concurrent")]
        let num_fragments = if domain.ce_domain_size() < MIN_CONCURRENT_DOMAIN_SIZE {
            1
        } else {
            rayon::current_num_threads().next_power_of_two()
        };
        #[cfg(not(feature = "concurrent"))]
        let num_fragments = 1;

        // Multi-segment traces need auxiliary constraints evaluated alongside the main ones;
        // the answer is the same for every fragment, so query it once.
        let evaluate_aux = self.air.trace_info().is_multi_segment();
        let mut fragments = evaluation_table.fragments(num_fragments);
        iter_mut!(fragments).for_each(|fragment| {
            if evaluate_aux {
                self.evaluate_fragment_full(trace, domain, fragment)
            } else {
                self.evaluate_fragment_main(trace, domain, fragment)
            }
        });

        #[cfg(debug_assertions)]
        evaluation_table.validate_transition_degrees();

        let combined = evaluation_table.combine();
        CompositionPolyTrace::new(combined)
    }
}
impl<'a, A, E> DefaultConstraintEvaluator<'a, A, E>
where
    A: Air,
    E: FieldElement<BaseField = A::BaseField>,
{
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns a new evaluator for the constraints described by `air`.
    ///
    /// Transition constraints are built from `composition_coefficients.transition`. Boundary
    /// constraints are built from `composition_coefficients.boundary` and `aux_rand_elements`
    /// (if any). `aux_rand_elements` is also kept for evaluating auxiliary transition
    /// constraints, and periodic column values are precomputed from `air`.
    pub fn new(
        air: &'a A,
        aux_rand_elements: Option<AuxRandElements<E>>,
        composition_coefficients: ConstraintCompositionCoefficients<E>,
    ) -> Self {
        let transition_constraints =
            air.get_transition_constraints(&composition_coefficients.transition);
        let periodic_values = PeriodicValueTable::new(air);
        let boundary_constraints = BoundaryConstraints::new(
            air,
            aux_rand_elements.as_ref(),
            &composition_coefficients.boundary,
        );

        Self {
            air,
            aux_rand_elements,
            boundary_constraints,
            periodic_values,
            transition_constraints,
        }
    }

    // FRAGMENT EVALUATION
    // --------------------------------------------------------------------------------------------

    /// Fills every row of `fragment` using only the main trace segment.
    ///
    /// Slot 0 of each row holds the merged transition-constraint evaluation; the remaining
    /// slots hold boundary-constraint evaluations.
    fn evaluate_fragment_main<T: TraceLde<E>>(
        &self,
        trace: &T,
        domain: &StarkDomain<A::BaseField>,
        fragment: &mut EvaluationTableFragment<E>,
    ) {
        // Scratch buffers, allocated once and reused for every row.
        let mut main_frame = EvaluationFrame::new(trace.trace_info().main_trace_width());
        let mut row = vec![E::ZERO; fragment.num_columns()];
        let mut transition_buf =
            vec![E::BaseField::ZERO; self.num_main_transition_constraints()];

        // A constraint-domain step maps to an LDE step by shifting left by
        // trailing_zeros(ce_to_lde_blowup) (assumes the blowup is a power of two — TODO confirm).
        let lde_shift = domain.ce_to_lde_blowup().trailing_zeros();
        let num_rows = fragment.num_rows();

        for i in 0..num_rows {
            let step = fragment.offset() + i;
            let lde_step = step << lde_shift;
            trace.read_main_trace_frame_into(lde_step, &mut main_frame);

            row[0] = self.evaluate_main_transition(&main_frame, step, &mut transition_buf);

            #[cfg(debug_assertions)]
            fragment.update_transition_evaluations(i, &transition_buf, &[]);

            self.boundary_constraints.evaluate_main(
                main_frame.current(),
                domain,
                step,
                &mut row[1..],
            );

            fragment.update_row(i, &row);
        }
    }

    /// Fills every row of `fragment` using both the main and the auxiliary trace segments.
    ///
    /// Slot 0 of each row holds the sum of the merged main and auxiliary transition-constraint
    /// evaluations; the remaining slots hold boundary-constraint evaluations over both segments.
    fn evaluate_fragment_full<T: TraceLde<E>>(
        &self,
        trace: &T,
        domain: &StarkDomain<A::BaseField>,
        fragment: &mut EvaluationTableFragment<E>,
    ) {
        // Scratch buffers, allocated once and reused for every row.
        let mut main_frame = EvaluationFrame::new(trace.trace_info().main_trace_width());
        let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width());
        let mut main_buf = vec![E::BaseField::ZERO; self.num_main_transition_constraints()];
        let mut aux_buf = vec![E::ZERO; self.num_aux_transition_constraints()];
        let mut row = vec![E::ZERO; fragment.num_columns()];

        // Same constraint-domain -> LDE-domain step conversion as in `evaluate_fragment_main`.
        let lde_shift = domain.ce_to_lde_blowup().trailing_zeros();
        let num_rows = fragment.num_rows();

        for i in 0..num_rows {
            let step = fragment.offset() + i;
            let lde_step = step << lde_shift;
            trace.read_main_trace_frame_into(lde_step, &mut main_frame);
            trace.read_aux_trace_frame_into(lde_step, &mut aux_frame);

            let main_eval = self.evaluate_main_transition(&main_frame, step, &mut main_buf);
            let aux_eval =
                self.evaluate_aux_transition(&main_frame, &aux_frame, step, &mut aux_buf);
            row[0] = main_eval + aux_eval;

            #[cfg(debug_assertions)]
            fragment.update_transition_evaluations(i, &main_buf, &aux_buf);

            self.boundary_constraints.evaluate_all(
                main_frame.current(),
                aux_frame.current(),
                domain,
                step,
                &mut row[1..],
            );

            fragment.update_row(i, &row);
        }
    }

    // TRANSITION CONSTRAINT EVALUATORS
    // --------------------------------------------------------------------------------------------

    /// Evaluates the main-segment transition constraints at `step` into `evaluations`, then
    /// returns their linear combination with the main composition coefficients.
    fn evaluate_main_transition(
        &self,
        main_frame: &EvaluationFrame<E::BaseField>,
        step: usize,
        evaluations: &mut [E::BaseField],
    ) -> E {
        // The buffer is reused across rows, so clear values left from the previous step.
        evaluations.fill(E::BaseField::ZERO);
        self.air
            .evaluate_transition(main_frame, self.periodic_values.get_row(step), evaluations);

        let coefficients = self.transition_constraints.main_constraint_coef();
        let mut merged = E::ZERO;
        for (&coef, &value) in coefficients.iter().zip(evaluations.iter()) {
            merged = merged + coef.mul_base(value);
        }
        merged
    }

    /// Evaluates the auxiliary-segment transition constraints at `step` into `evaluations`,
    /// then returns their linear combination with the auxiliary composition coefficients.
    ///
    /// # Panics
    /// Panics if the evaluator was created without auxiliary random elements.
    fn evaluate_aux_transition(
        &self,
        main_frame: &EvaluationFrame<E::BaseField>,
        aux_frame: &EvaluationFrame<E>,
        step: usize,
        evaluations: &mut [E],
    ) -> E {
        // The buffer is reused across rows, so clear values left from the previous step.
        evaluations.fill(E::ZERO);
        let periodic_row = self.periodic_values.get_row(step);
        let rand_elements = self
            .aux_rand_elements
            .as_ref()
            .expect("expected aux rand elements to be present");
        self.air.evaluate_aux_transition(
            main_frame,
            aux_frame,
            periodic_row,
            rand_elements,
            evaluations,
        );

        let coefficients = self.transition_constraints.aux_constraint_coef();
        let mut merged = E::ZERO;
        for (&coef, &value) in coefficients.iter().zip(evaluations.iter()) {
            merged = merged + coef * value;
        }
        merged
    }

    // ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Number of transition constraints defined over the main trace segment.
    fn num_main_transition_constraints(&self) -> usize {
        self.transition_constraints.num_main_constraints()
    }

    /// Number of transition constraints defined over the auxiliary trace segment.
    fn num_aux_transition_constraints(&self) -> usize {
        self.transition_constraints.num_aux_constraints()
    }
}