use alloc::vec::Vec;
use air::proof::TraceOodFrame;
use math::{FieldElement, StarkField};
use crate::{matrix::ColumnIter, ColMatrix};
/// Trace polynomials of an execution trace, grouped by trace segment.
///
/// Each column of the underlying matrices is treated as one polynomial (columns are evaluated
/// via `ColMatrix::evaluate_columns_at`). The main segment lives over the base field
/// `E::BaseField`; the optional auxiliary segment lives over `E`.
pub struct TracePolyTable<E: FieldElement> {
    /// Polynomials of the main trace segment, one per column.
    main_trace_polys: ColMatrix<E::BaseField>,
    /// Polynomials of the auxiliary trace segment; `None` until `add_aux_segment` is called.
    aux_trace_polys: Option<ColMatrix<E>>,
}

impl<E: FieldElement> TracePolyTable<E> {
    /// Creates a table holding only the main-segment trace polynomials.
    pub fn new(main_trace_polys: ColMatrix<E::BaseField>) -> Self {
        Self { main_trace_polys, aux_trace_polys: None }
    }

    /// Attaches the auxiliary-segment polynomials to this table.
    ///
    /// # Panics
    /// Panics if an auxiliary segment has already been added, or if `aux_trace_polys` does not
    /// have the same number of rows as the main-segment polynomials.
    pub fn add_aux_segment(&mut self, aux_trace_polys: ColMatrix<E>) {
        assert!(
            self.aux_trace_polys.is_none(),
            "auxiliary trace segment has already been added"
        );
        assert_eq!(
            self.main_trace_polys.num_rows(),
            aux_trace_polys.num_rows(),
            "polynomials in auxiliary segment must be of the same size as in the main segment"
        );
        self.aux_trace_polys = Some(aux_trace_polys);
    }

    /// Returns the size of each trace polynomial, i.e. the row count of the main-segment matrix.
    pub fn poly_size(&self) -> usize {
        self.main_trace_polys.num_rows()
    }

    /// Evaluates every trace polynomial at `x`.
    ///
    /// The returned vector contains the main-segment evaluations first, followed by the
    /// auxiliary-segment evaluations (if an auxiliary segment is present).
    pub fn evaluate_at(&self, x: E) -> Vec<E> {
        let mut result = self.main_trace_polys.evaluate_columns_at(x);
        if let Some(aux_polys) = &self.aux_trace_polys {
            result.extend(aux_polys.evaluate_columns_at(x));
        }
        result
    }

    /// Builds the out-of-domain frame at `z`: all polynomials evaluated at `z` (current row)
    /// and at `z * g` (next row), where `g` is the base-field root of unity whose order equals
    /// the polynomial size.
    ///
    /// # Panics
    /// Panics if the polynomial size is not a power of two (including zero).
    pub fn get_ood_frame(&self, z: E) -> TraceOodFrame<E> {
        let trace_len = self.poly_size();
        // `ilog2` rounds down (and panics on 0), so a non-power-of-two size would silently pick
        // the generator of a smaller domain and produce a wrong "next row" point.
        assert!(
            trace_len.is_power_of_two(),
            "trace polynomial size must be a power of two, but was {trace_len}"
        );
        let log_trace_len = trace_len.ilog2();
        let g = E::from(E::BaseField::get_root_of_unity(log_trace_len));

        let current_row = self.evaluate_at(z);
        let next_row = self.evaluate_at(z * g);

        let main_trace_width = self.main_trace_polys.num_cols();
        TraceOodFrame::new(current_row, next_row, main_trace_width)
    }

    /// Returns an iterator over the main-segment polynomials, one slice per column.
    pub fn main_trace_polys(&self) -> impl Iterator<Item = &[E::BaseField]> {
        self.main_trace_polys.columns()
    }

    /// Returns an iterator over the auxiliary-segment polynomials, one slice per column;
    /// the iterator is empty when no auxiliary segment has been added.
    pub fn aux_trace_polys(&self) -> impl Iterator<Item = &[E]> {
        match &self.aux_trace_polys {
            Some(aux_segment_polys) => aux_segment_polys.columns(),
            None => ColumnIter::empty(),
        }
    }

    /// Returns the number of main-segment polynomials (columns).
    #[cfg(test)]
    pub fn num_main_trace_polys(&self) -> usize {
        self.main_trace_polys.num_cols()
    }

    /// Returns the main-segment polynomial stored in column `idx`.
    #[cfg(test)]
    pub fn get_main_trace_poly(&self, idx: usize) -> &[E::BaseField] {
        self.main_trace_polys.get_column(idx)
    }
}