use air::{Air, AuxRandElements, EvaluationFrame, TraceInfo};
use math::{polynom, FieldElement, StarkField};
use super::ColMatrix;
mod trace_lde;
pub use trace_lde::{DefaultTraceLde, TraceLde};
mod poly_table;
pub use poly_table::TracePolyTable;
mod trace_table;
pub use trace_table::{TraceTable, TraceTableFragment};
#[cfg(test)]
mod tests;
/// An auxiliary trace segment together with the random elements that accompany it.
///
/// `Trace::validate` uses `aux_rand_elements` when querying the AIR for auxiliary assertions
/// and when evaluating auxiliary transition constraints over `aux_trace`.
pub struct AuxTraceWithMetadata<E>
where
    E: FieldElement,
{
    /// The auxiliary trace segment, stored column-wise.
    pub aux_trace: ColMatrix<E>,
    /// Random elements passed to `Air::get_aux_assertions` and `Air::evaluate_aux_transition`.
    pub aux_rand_elements: AuxRandElements<E>,
}
pub trait Trace: Sized {
type BaseField: StarkField;
fn info(&self) -> &TraceInfo;
fn main_segment(&self) -> &ColMatrix<Self::BaseField>;
fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame<Self::BaseField>);
fn length(&self) -> usize {
self.info().length()
}
fn main_trace_width(&self) -> usize {
self.info().main_trace_width()
}
fn aux_trace_width(&self) -> usize {
self.info().aux_segment_width()
}
fn validate<A, E>(&self, air: &A, aux_trace_with_metadata: Option<&AuxTraceWithMetadata<E>>)
where
A: Air<BaseField = Self::BaseField>,
E: FieldElement<BaseField = Self::BaseField>,
{
assert_eq!(
self.main_trace_width(),
air.trace_info().main_trace_width(),
"inconsistent trace width: expected {}, but was {}",
self.main_trace_width(),
air.trace_info().main_trace_width(),
);
for assertion in air.get_assertions() {
assertion.apply(self.length(), |step, value| {
assert!(
value == self.main_segment().get(assertion.column(), step),
"trace does not satisfy assertion main_trace({}, {}) == {}",
assertion.column(),
step,
value
);
});
}
if let Some(aux_trace_with_metadata) = aux_trace_with_metadata {
let aux_trace = &aux_trace_with_metadata.aux_trace;
let aux_rand_elements = &aux_trace_with_metadata.aux_rand_elements;
for assertion in air.get_aux_assertions(aux_rand_elements) {
assertion.apply(self.length(), |step, value| {
assert!(
value == aux_trace.get(assertion.column(), step),
"trace does not satisfy assertion aux_trace({}, {}) == {}",
assertion.column(),
step,
value
);
});
}
}
let g = air.trace_domain_generator();
let periodic_values_polys = air.get_periodic_column_polys();
let mut periodic_values = vec![Self::BaseField::ZERO; periodic_values_polys.len()];
let mut x = Self::BaseField::ONE;
let mut main_frame = EvaluationFrame::new(self.main_trace_width());
let mut aux_frame = if air.trace_info().is_multi_segment() {
Some(EvaluationFrame::<E>::new(self.aux_trace_width()))
} else {
None
};
let mut main_evaluations =
vec![Self::BaseField::ZERO; air.context().num_main_transition_constraints()];
let mut aux_evaluations = vec![E::ZERO; air.context().num_aux_transition_constraints()];
for step in 0..self.length() - air.context().num_transition_exemptions() {
for (p, v) in periodic_values_polys.iter().zip(periodic_values.iter_mut()) {
let num_cycles = air.trace_length() / p.len();
let x = x.exp((num_cycles as u32).into());
*v = polynom::eval(p, x);
}
self.read_main_frame(step, &mut main_frame);
air.evaluate_transition(&main_frame, &periodic_values, &mut main_evaluations);
for (i, &evaluation) in main_evaluations.iter().enumerate() {
assert!(
evaluation == Self::BaseField::ZERO,
"main transition constraint {i} did not evaluate to ZERO at step {step}"
);
}
if let Some(ref mut aux_frame) = aux_frame {
let aux_trace_with_metadata =
aux_trace_with_metadata.expect("expected aux trace to be present");
let aux_trace = &aux_trace_with_metadata.aux_trace;
let aux_rand_elements = &aux_trace_with_metadata.aux_rand_elements;
read_aux_frame(aux_trace, step, aux_frame);
air.evaluate_aux_transition(
&main_frame,
aux_frame,
&periodic_values,
aux_rand_elements,
&mut aux_evaluations,
);
for (i, &evaluation) in aux_evaluations.iter().enumerate() {
assert!(
evaluation == E::ZERO,
"auxiliary transition constraint {i} did not evaluate to ZERO at step {step}"
);
}
}
x *= g;
}
}
}
/// Copies rows of `aux_segment` into `frame`.
///
/// The current row of `frame` receives row `row_idx`. The next row receives row `row_idx + 1`,
/// wrapping around to row 0 past the end of the segment. Columns are paired positionally, so
/// only `min(frame width, number of columns)` cells are written in each row.
fn read_aux_frame<E>(aux_segment: &ColMatrix<E>, row_idx: usize, frame: &mut EvaluationFrame<E>)
where
    E: FieldElement,
{
    frame
        .current_mut()
        .iter_mut()
        .zip(aux_segment.columns())
        .for_each(|(cell, column)| *cell = column[row_idx]);

    // The successor of the last row is row 0 (the trace is treated cyclically).
    let wrapped_next = (row_idx + 1) % aux_segment.num_rows();
    frame
        .next_mut()
        .iter_mut()
        .zip(aux_segment.columns())
        .for_each(|(cell, column)| *cell = column[wrapped_next]);
}