use air::Air;
use math::{fft, StarkField};
use utils::{
collections::{BTreeMap, Vec},
uninit_vector,
};
/// Table of periodic column values laid out in row-major order.
///
/// Row `i` holds the value of every periodic column at constraint evaluation step `i`
/// (modulo the table length); all values of one row are stored contiguously.
pub struct PeriodicValueTable<B: StarkField> {
    /// Number of rows in the table; zero when there are no periodic columns.
    length: usize,
    /// Number of periodic columns, i.e. how many values each row contains.
    width: usize,
    /// Flattened table contents: `length` rows of `width` values each.
    values: Vec<B>,
}
impl<B: StarkField> PeriodicValueTable<B> {
    /// Builds a table of periodic column values for the constraint evaluation domain of `air`.
    ///
    /// Each polynomial returned by `air.get_periodic_column_polys()` is evaluated with
    /// `fft::evaluate_poly_with_offset`, using `air.ce_blowup_factor()` and an offset equal to
    /// `air.domain_offset()` raised to the number of times the column cycles over the trace
    /// (`trace_length / poly_len`). The table has `max_poly_len * ce_blowup_factor` rows;
    /// evaluations of shorter columns are repeated cyclically to fill every row.
    ///
    /// If `air` defines no periodic columns, an empty table (zero rows, zero columns) is
    /// returned.
    pub fn new<A: Air<BaseField = B>>(air: &A) -> PeriodicValueTable<B> {
        let polys = air.get_periodic_column_polys();

        // The longest column determines the number of rows; no columns means an empty table.
        let max_poly_size = match polys.iter().map(|poly| poly.len()).max() {
            Some(size) => size,
            None => {
                return PeriodicValueTable {
                    values: Vec::new(),
                    length: 0,
                    width: 0,
                };
            }
        };

        // These are the same for every column, so query them once.
        let blowup_factor = air.ce_blowup_factor();
        let trace_length = air.trace_length();
        let domain_offset = air.domain_offset();

        // Columns of equal length share twiddles, so cache them by polynomial size.
        let mut twiddle_map = BTreeMap::new();
        let evaluations = polys
            .iter()
            .map(|poly| {
                let poly_size = poly.len();
                // How many times this column repeats over the execution trace.
                let num_cycles = (trace_length / poly_size) as u64;
                let offset = domain_offset.exp(num_cycles.into());
                let twiddles = twiddle_map
                    .entry(poly_size)
                    .or_insert_with(|| fft::get_twiddles(poly_size));
                fft::evaluate_poly_with_offset(poly, twiddles, offset, blowup_factor)
            })
            .collect::<Vec<_>>();

        let row_width = polys.len();
        let column_length = max_poly_size * blowup_factor;

        // SAFETY: the vector has exactly `row_width * column_length` slots, and the loop below
        // splits it into `column_length` rows of `row_width` slots (`row_width > 0` here
        // because `polys` is non-empty) and writes every slot before the vector is read or
        // returned. `B` is `Copy` (values are copied out of `column` by index), so
        // overwriting an uninitialized slot never runs a destructor on garbage.
        let mut values = unsafe { uninit_vector(row_width * column_length) };
        for (i, row) in values.chunks_mut(row_width).enumerate() {
            for (slot, column) in row.iter_mut().zip(evaluations.iter()) {
                // Shorter columns have fewer evaluations; wrap around to repeat them.
                *slot = column[i % column.len()];
            }
        }

        PeriodicValueTable {
            values,
            length: column_length,
            width: row_width,
        }
    }

    /// Returns `true` if the table contains no periodic columns.
    pub fn is_empty(&self) -> bool {
        self.width == 0
    }

    /// Returns the values of all periodic columns at constraint evaluation step `ce_step`.
    ///
    /// Steps wrap around modulo the number of rows in the table. An empty table always
    /// yields an empty slice.
    pub fn get_row(&self, ce_step: usize) -> &[B] {
        if self.is_empty() {
            return &[];
        }
        let start = (ce_step % self.length) * self.width;
        &self.values[start..start + self.width]
    }
}
#[cfg(test)]
mod tests {
    use crate::tests::MockAir;
    use air::Air;
    use math::{
        fields::f128::BaseElement, get_power_series_with_offset, log2, polynom, FieldElement,
        StarkField,
    };
    use utils::collections::Vec;

    /// Compares every table row against a direct evaluation of each periodic column
    /// polynomial over the constraint evaluation domain.
    #[test]
    fn periodic_value_table() {
        let trace_length = 32;
        let columns = vec![to_elements(&[1, 2]), to_elements(&[3, 4, 5, 6])];
        let air = MockAir::with_periodic_columns(columns, trace_length);
        let table = super::PeriodicValueTable::new(&air);

        // two columns; the longer one (4 values) determines the number of rows
        assert_eq!(2, table.width);
        assert_eq!(4 * air.ce_blowup_factor(), table.length);

        // evaluate each polynomial directly at x^(trace_length / poly_len) for every domain x
        let domain = build_ce_domain(air.ce_domain_size(), air.domain_offset());
        let mut expected = Vec::new();
        for poly in air.get_periodic_column_polys().iter() {
            let power = (trace_length / poly.len()) as u32;
            let column: Vec<BaseElement> = domain
                .iter()
                .map(|&x| polynom::eval(poly, x.exp(power.into())))
                .collect();
            expected.push(column);
        }

        // read the same values back out of the table, column by column
        let actual: Vec<Vec<BaseElement>> = (0..table.width)
            .map(|col| {
                (0..air.ce_domain_size())
                    .map(|step| table.get_row(step)[col])
                    .collect()
            })
            .collect();

        assert_eq!(expected, actual);
    }

    /// Converts raw integers into base field elements.
    fn to_elements(raw: &[u128]) -> Vec<BaseElement> {
        raw.iter().copied().map(BaseElement::new).collect()
    }

    /// Builds `domain_size` powers of a root of unity of order `domain_size`, shifted by
    /// `domain_offset` via `get_power_series_with_offset`.
    fn build_ce_domain(domain_size: usize, domain_offset: BaseElement) -> Vec<BaseElement> {
        let root = BaseElement::get_root_of_unity(log2(domain_size));
        get_power_series_with_offset(root, domain_offset, domain_size)
    }
}