use alloc::{collections::BTreeMap, vec::Vec};
use air::{Air, AuxRandElements, ConstraintDivisor};
use math::{fft, ExtensionOf, FieldElement};
use super::StarkDomain;
/// Threshold used to classify boundary constraints by the length of their assertion polynomial.
///
/// Constraints whose polynomial has more than one but fewer than this many coefficients are
/// stored as "small poly" constraints and evaluated directly (Horner's method) on every call;
/// constraints at or above this length are stored as "large poly" constraints whose values are
/// pre-computed with an FFT when the constraint is built.
const SMALL_POLY_DEGREE: usize = 63;
/// Ordered list of boundary constraint groups prepared for evaluation.
///
/// Each entry bundles the constraints that share a single divisor.
pub struct BoundaryConstraints<E: FieldElement>(Vec<BoundaryConstraintGroup<E>>);

impl<E: FieldElement> BoundaryConstraints<E> {
    /// Builds the evaluation-ready groups from the boundary constraints reported by `air`.
    ///
    /// Main-trace groups are converted one-to-one, in the order `air` returns them. Every
    /// auxiliary group is then merged into the first already-built group with an equal divisor;
    /// if no such group exists, a new group is appended for it.
    pub fn new<A: Air<BaseField = E::BaseField>>(
        air: &A,
        aux_rand_elements: Option<&AuxRandElements<E>>,
        composition_coefficients: &[E],
    ) -> Self {
        let source = air.get_boundary_constraints(aux_rand_elements, composition_coefficients);

        // Twiddles are cached by polynomial length so that large-poly constraints of the same
        // length do not recompute them.
        let mut twiddle_map = BTreeMap::new();

        let mut groups: Vec<BoundaryConstraintGroup<E>> = Vec::new();
        for main_group in source.main_constraints().iter() {
            groups.push(BoundaryConstraintGroup::from_main_constraints(
                main_group,
                air,
                &mut twiddle_map,
            ));
        }

        for aux_group in source.aux_constraints() {
            let matching = groups
                .iter()
                .position(|existing| &existing.divisor == aux_group.divisor());
            match matching {
                Some(index) => groups[index].add_aux_constraints(aux_group, air, &mut twiddle_map),
                None => groups.push(BoundaryConstraintGroup::from_aux_constraints(
                    aux_group,
                    air,
                    &mut twiddle_map,
                )),
            }
        }

        Self(groups)
    }

    /// Returns a copy of every group's divisor, in group order.
    pub fn get_divisors(&self) -> Vec<ConstraintDivisor<E::BaseField>> {
        let mut divisors = Vec::with_capacity(self.0.len());
        for group in &self.0 {
            divisors.push(group.divisor.clone());
        }
        divisors
    }

    /// Evaluates the main-trace constraints of every group at `step`.
    ///
    /// The i-th group's value is written to `result[i]`. Groups and slots are paired
    /// positionally, so pairing stops at the shorter of the two sequences.
    pub fn evaluate_main(
        &self,
        main_state: &[E::BaseField],
        domain: &StarkDomain<E::BaseField>,
        step: usize,
        result: &mut [E],
    ) {
        // The x coordinate is the same for every group, so look it up once.
        let x = domain.get_ce_x_at(step);
        self.0.iter().zip(result.iter_mut()).for_each(|(group, slot)| {
            *slot = group.evaluate_main(main_state, step, x);
        });
    }

    /// Evaluates both the main-trace and auxiliary-trace constraints of every group at `step`.
    ///
    /// The i-th group's value is written to `result[i]`. Groups and slots are paired
    /// positionally, so pairing stops at the shorter of the two sequences.
    pub fn evaluate_all(
        &self,
        main_state: &[E::BaseField],
        aux_state: &[E],
        domain: &StarkDomain<E::BaseField>,
        step: usize,
        result: &mut [E],
    ) {
        // The x coordinate is the same for every group, so look it up once.
        let x = domain.get_ce_x_at(step);
        self.0.iter().zip(result.iter_mut()).for_each(|(group, slot)| {
            *slot = group.evaluate_all(main_state, aux_state, step, x);
        });
    }
}
/// Boundary constraints that share one divisor, split by trace segment (main / auxiliary) and
/// by the length of the assertion polynomial (single value / small / large).
pub struct BoundaryConstraintGroup<E: FieldElement> {
    divisor: ConstraintDivisor<E::BaseField>,
    main_single_value: Vec<SingleValueConstraint<E::BaseField, E>>,
    main_small_poly: Vec<SmallPolyConstraint<E::BaseField, E>>,
    main_large_poly: Vec<LargePolyConstraint<E::BaseField, E>>,
    aux_single_value: Vec<SingleValueConstraint<E, E>>,
    aux_small_poly: Vec<SmallPolyConstraint<E, E>>,
    aux_large_poly: Vec<LargePolyConstraint<E, E>>,
}

impl<E: FieldElement> BoundaryConstraintGroup<E> {
    /// Creates a group for `divisor` that holds no constraints yet.
    fn new(divisor: ConstraintDivisor<E::BaseField>) -> Self {
        Self {
            divisor,
            main_single_value: Vec::new(),
            main_small_poly: Vec::new(),
            main_large_poly: Vec::new(),
            aux_single_value: Vec::new(),
            aux_small_poly: Vec::new(),
            aux_large_poly: Vec::new(),
        }
    }

    /// Builds a group from a group of main-trace constraints.
    ///
    /// Each constraint is classified by the number of coefficients in its polynomial: exactly
    /// one coefficient is a single-value constraint, fewer than `SMALL_POLY_DEGREE` is a small
    /// poly constraint, and anything else is a large poly constraint. `twiddle_map` is passed
    /// through to the large poly constructor.
    pub fn from_main_constraints<A: Air<BaseField = E::BaseField>>(
        source: &air::BoundaryConstraintGroup<E::BaseField, E>,
        air: &A,
        twiddle_map: &mut BTreeMap<usize, Vec<E::BaseField>>,
    ) -> Self {
        let mut group = Self::new(source.divisor().clone());

        for constraint in source.constraints() {
            match constraint.poly().len() {
                1 => group.main_single_value.push(SingleValueConstraint::new(constraint)),
                len if len < SMALL_POLY_DEGREE => {
                    group.main_small_poly.push(SmallPolyConstraint::new(constraint));
                },
                _ => {
                    group
                        .main_large_poly
                        .push(LargePolyConstraint::new(constraint, air, twiddle_map));
                },
            }
        }

        group
    }

    /// Builds a group that contains only the auxiliary-trace constraints from `group`.
    pub fn from_aux_constraints<A: Air<BaseField = E::BaseField>>(
        group: &air::BoundaryConstraintGroup<E, E>,
        air: &A,
        twiddle_map: &mut BTreeMap<usize, Vec<E::BaseField>>,
    ) -> Self {
        let mut built = Self::new(group.divisor().clone());
        built.add_aux_constraints(group, air, twiddle_map);
        built
    }

    /// Adds the auxiliary-trace constraints from `group` to this group, classifying them by
    /// polynomial length in the same way as main-trace constraints.
    ///
    /// # Panics
    /// Panics if the divisor of `group` differs from this group's divisor.
    pub fn add_aux_constraints<A: Air<BaseField = E::BaseField>>(
        &mut self,
        group: &air::BoundaryConstraintGroup<E, E>,
        air: &A,
        twiddle_map: &mut BTreeMap<usize, Vec<E::BaseField>>,
    ) {
        assert_eq!(group.divisor(), &self.divisor, "inconsistent constraint divisor");

        for constraint in group.constraints() {
            match constraint.poly().len() {
                1 => self.aux_single_value.push(SingleValueConstraint::new(constraint)),
                len if len < SMALL_POLY_DEGREE => {
                    self.aux_small_poly.push(SmallPolyConstraint::new(constraint));
                },
                _ => {
                    self.aux_large_poly
                        .push(LargePolyConstraint::new(constraint, air, twiddle_map));
                },
            }
        }
    }

    /// Returns the sum of all main-trace constraint evaluations in this group.
    ///
    /// Single-value constraints need only `state`; small poly constraints also use `x`; large
    /// poly constraints use `ce_step` to look up their pre-computed value.
    pub fn evaluate_main(&self, state: &[E::BaseField], ce_step: usize, x: E::BaseField) -> E {
        let total = self
            .main_single_value
            .iter()
            .fold(E::ZERO, |sum, constraint| sum + constraint.evaluate(state));
        let total = self
            .main_small_poly
            .iter()
            .fold(total, |sum, constraint| sum + constraint.evaluate(state, x));
        self.main_large_poly
            .iter()
            .fold(total, |sum, constraint| sum + constraint.evaluate(state, ce_step))
    }

    /// Returns the sum of all main-trace and auxiliary-trace constraint evaluations in this
    /// group.
    pub fn evaluate_all(
        &self,
        main_state: &[E::BaseField],
        aux_state: &[E],
        ce_step: usize,
        x: E::BaseField,
    ) -> E {
        // Start from the main-trace total and keep accumulating the auxiliary constraints.
        let total = self.evaluate_main(main_state, ce_step, x);
        let total = self
            .aux_single_value
            .iter()
            .fold(total, |sum, constraint| sum + constraint.evaluate(aux_state));
        let total = self
            .aux_small_poly
            .iter()
            .fold(total, |sum, constraint| sum + constraint.evaluate(aux_state, x));
        self.aux_large_poly
            .iter()
            .fold(total, |sum, constraint| sum + constraint.evaluate(aux_state, ce_step))
    }
}
/// Boundary constraint whose assertion polynomial consists of a single coefficient, i.e. the
/// asserted value is a constant.
struct SingleValueConstraint<F, E>
where
    F: FieldElement,
    E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
{
    column: usize,
    value: F,
    coefficients: E,
}

impl<F, E> SingleValueConstraint<F, E>
where
    F: FieldElement,
    E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
{
    /// Copies the column index, the single polynomial coefficient, and the composition
    /// coefficient out of `source`.
    ///
    /// In debug builds, panics if the polynomial of `source` does not have exactly one
    /// coefficient.
    pub fn new(source: &air::BoundaryConstraint<F, E>) -> Self {
        let poly = source.poly();
        debug_assert!(poly.len() == 1, "not a single constraint");

        let column = source.column();
        let value = poly[0];
        let coefficients = *source.cc();
        Self { column, value, coefficients }
    }

    /// Returns `coefficients * (state[column] - value)`.
    pub fn evaluate(&self, state: &[F]) -> E {
        self.coefficients.mul_base(state[self.column] - self.value)
    }
}
/// Boundary constraint whose assertion polynomial is short enough to be evaluated directly on
/// every call.
struct SmallPolyConstraint<F, E>
where
    F: FieldElement,
    E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
{
    column: usize,
    poly: Vec<F>,
    x_offset: F::BaseField,
    coefficients: E,
}

impl<F, E> SmallPolyConstraint<F, E>
where
    F: FieldElement,
    E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
{
    /// Copies the column index, polynomial coefficients, x offset (second component of the
    /// source's polynomial offset), and composition coefficient out of `source`.
    ///
    /// In debug builds, panics if the polynomial length is not strictly between 1 and
    /// `SMALL_POLY_DEGREE`.
    pub fn new(source: &air::BoundaryConstraint<F, E>) -> Self {
        debug_assert!(
            source.poly().len() > 1 && source.poly().len() < SMALL_POLY_DEGREE,
            "not a small poly constraint"
        );

        let column = source.column();
        let poly = source.poly().to_vec();
        let (_, x_offset) = source.poly_offset();
        let coefficients = *source.cc();
        Self { column, poly, x_offset, coefficients }
    }

    /// Returns `coefficients * (state[column] - poly(x * x_offset))`.
    pub fn evaluate(&self, state: &[F], x: F::BaseField) -> E {
        let point = x * self.x_offset;

        // Horner's method: walk the coefficients from the highest power down.
        let mut assertion_value = F::ZERO;
        for &coeff in self.poly.iter().rev() {
            assertion_value = assertion_value.mul_base(point) + coeff;
        }

        self.coefficients.mul_base(state[self.column] - assertion_value)
    }
}
/// Boundary constraint whose assertion polynomial is pre-evaluated when the constraint is
/// built, so that evaluating the constraint at a step is a table lookup.
struct LargePolyConstraint<F, E>
where
    F: FieldElement,
    E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
{
    column: usize,
    values: Vec<F>,
    step_offset: usize,
    coefficients: E,
}

impl<F, E> LargePolyConstraint<F, E>
where
    F: FieldElement,
    E: FieldElement<BaseField = F::BaseField> + ExtensionOf<F>,
{
    /// Pre-evaluates the polynomial of `source` and stores the resulting values.
    ///
    /// Twiddles for the polynomial's length are taken from `twiddle_map`, and are computed and
    /// inserted first if that length has no entry yet. The polynomial is evaluated with
    /// `air.domain_offset()` as the offset and `air.ce_domain_size() / poly_length` as the
    /// blowup factor. The step offset is the first component of the source's polynomial offset
    /// multiplied by `air.ce_blowup_factor()`.
    ///
    /// In debug builds, panics if the polynomial has fewer than `SMALL_POLY_DEGREE`
    /// coefficients.
    pub fn new<A: Air<BaseField = F::BaseField>>(
        source: &air::BoundaryConstraint<F, E>,
        air: &A,
        twiddle_map: &mut BTreeMap<usize, Vec<F::BaseField>>,
    ) -> Self {
        let poly_length = source.poly().len();
        debug_assert!(poly_length >= SMALL_POLY_DEGREE, "not a large poly constraint");

        let twiddles = twiddle_map
            .entry(poly_length)
            .or_insert_with(|| fft::get_twiddles(poly_length));
        let domain_offset = air.domain_offset();
        let blowup_factor = air.ce_domain_size() / poly_length;
        let values =
            fft::evaluate_poly_with_offset(source.poly(), twiddles, domain_offset, blowup_factor);

        let (trace_step_offset, _) = source.poly_offset();
        Self {
            column: source.column(),
            values,
            step_offset: trace_step_offset * air.ce_blowup_factor(),
            coefficients: *source.cc(),
        }
    }

    /// Returns `coefficients * (state[column] - values[i])`, where `i` is `ce_step` shifted
    /// back by the step offset.
    pub fn evaluate(&self, state: &[F], ce_step: usize) -> E {
        // Shift the step back by the offset; when that would go below zero, wrap around by
        // adding the number of stored values. A zero offset leaves the index equal to `ce_step`.
        let value_index = if self.step_offset > ce_step {
            self.values.len() + ce_step - self.step_offset
        } else {
            ce_step - self.step_offset
        };

        let evaluation = state[self.column] - self.values[value_index];
        self.coefficients.mul_base(evaluation)
    }
}