use alloc::vec::Vec;
use air::proof::QuotientOodFrame;
use math::{fft, polynom::degree_of, FieldElement, StarkField};
use super::{ColMatrix, StarkDomain};
/// Evaluations of a composition polynomial, stored as a flat vector.
///
/// The number of stored evaluations is always a power of two; this is enforced by
/// [`CompositionPolyTrace::new`].
pub struct CompositionPolyTrace<E>(Vec<E>);

impl<E: FieldElement> CompositionPolyTrace<E> {
    /// Wraps `evaluations` without copying them.
    ///
    /// # Panics
    /// Panics if `evaluations.len()` is not a power of two (an empty vector is rejected too,
    /// since 0 is not a power of two).
    pub fn new(evaluations: Vec<E>) -> Self {
        let len = evaluations.len();
        assert!(
            len.is_power_of_two(),
            "length of composition polynomial trace must be a power of 2, but was {}",
            len,
        );
        Self(evaluations)
    }

    /// Returns the number of stored evaluations.
    pub fn num_rows(&self) -> usize {
        let Self(evaluations) = self;
        evaluations.len()
    }

    /// Consumes the trace and hands back the underlying evaluation vector.
    pub fn into_inner(self) -> Vec<E> {
        let Self(evaluations) = self;
        evaluations
    }
}
/// A composition polynomial held in coefficient form and split into columns.
///
/// Column `i` holds coefficients `i * n .. (i + 1) * n` of the interpolated polynomial, where
/// `n` is the trace length of the [`StarkDomain`] passed to [`CompositionPoly::new`].
pub struct CompositionPoly<E: FieldElement> {
    data: ColMatrix<E>,
}

impl<E: FieldElement> CompositionPoly<E> {
    /// Interpolates `composition_trace` into coefficient form and splits the coefficients into
    /// columns of `domain.trace_length()` coefficients each, requesting `num_cols` columns.
    ///
    /// Interpolation is done with `fft::interpolate_poly_with_offset`, using `domain.offset()`
    /// as the offset of the evaluation domain.
    ///
    /// # Panics
    /// Panics if `domain.trace_length()` is not strictly smaller than the number of rows in
    /// `composition_trace`.
    pub fn new(
        composition_trace: CompositionPolyTrace<E>,
        domain: &StarkDomain<E::BaseField>,
        num_cols: usize,
    ) -> Self {
        let trace_len = domain.trace_length();
        assert!(
            trace_len < composition_trace.num_rows(),
            "trace length must be smaller than length of composition polynomial trace"
        );

        // Interpolation happens in place: afterwards the vector holds coefficients, not
        // evaluations.
        let mut coefficients = composition_trace.into_inner();
        let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(coefficients.len());
        fft::interpolate_poly_with_offset(&mut coefficients, &inv_twiddles, domain.offset());

        let columns = segment(coefficients, trace_len, num_cols);
        Self { data: ColMatrix::new(columns) }
    }

    /// Returns the number of column polynomials.
    pub fn num_columns(&self) -> usize {
        self.data.num_cols()
    }

    /// Returns the number of coefficients stored in each column.
    pub fn column_len(&self) -> usize {
        self.data.num_rows()
    }

    /// Returns `column_len() - 1`, the largest degree a column polynomial can have.
    ///
    /// If the columns are empty the subtraction underflows (panic in debug builds, wrap in
    /// release builds).
    // NOTE(review): `allow(unused)` silences the dead-code lint; presumably this accessor has no
    // caller in some build configurations — confirm before removing.
    #[allow(unused)]
    pub fn column_degree(&self) -> usize {
        self.column_len() - 1
    }

    /// Evaluates every column at `z` (current row) and at `z * g` (next row), and packs both
    /// rows into a [`QuotientOodFrame`].
    ///
    /// `g` is `E::BaseField::get_root_of_unity(log2(column_len()))`, lifted into `E`.
    pub fn get_ood_frame(&self, z: E) -> QuotientOodFrame<E> {
        // NOTE: `ilog2` panics when `column_len()` is 0 and rounds down when it is not a power
        // of two; the column length is presumably a power-of-two trace length.
        let g = E::from(E::BaseField::get_root_of_unity(self.column_len().ilog2()));
        let (current_row, next_row) =
            (self.data.evaluate_columns_at(z), self.data.evaluate_columns_at(z * g));
        QuotientOodFrame::new(current_row, next_row)
    }

    /// Borrows the underlying matrix of coefficient columns.
    pub fn data(&self) -> &ColMatrix<E> {
        &self.data
    }

    /// Consumes `self` and returns the coefficient columns.
    pub fn into_columns(self) -> Vec<Vec<E>> {
        let Self { data } = self;
        data.into_columns()
    }
}
/// Splits polynomial `coefficients` into exactly `num_cols` columns of exactly `trace_len`
/// coefficients each; column `i` receives coefficients `i * trace_len .. (i + 1) * trace_len`.
///
/// If `coefficients` is shorter than `trace_len * num_cols`, the missing (higher-order)
/// coefficients are treated as zero. Every column therefore has the same length, and the
/// requested number of columns is always returned. Coefficients at index
/// `trace_len * num_cols` or beyond are dropped. In debug builds, the `debug_assert!` checks
/// that the polynomial's degree is below `trace_len * num_cols`, so only zeros can be dropped.
///
/// # Panics
/// Panics if `trace_len` is 0.
fn segment<E: FieldElement>(
    coefficients: Vec<E>,
    trace_len: usize,
    num_cols: usize,
) -> Vec<Vec<E>> {
    debug_assert!(degree_of(&coefficients) < trace_len * num_cols);

    let mut columns: Vec<Vec<E>> = coefficients
        .chunks(trace_len)
        .take(num_cols)
        .map(|chunk| {
            let mut column = chunk.to_vec();
            // Only the final chunk can be short; its missing high-order coefficients are zero.
            column.resize(trace_len, E::ZERO);
            column
        })
        .collect();

    // Fewer than `num_cols` chunks means the remaining columns carry only zero coefficients.
    // Without this, callers would silently receive fewer columns than they asked for.
    columns.resize_with(num_cols, || (0..trace_len).map(|_| E::ZERO).collect());
    columns
}

#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use math::fields::f128::BaseElement;

    #[test]
    fn segment() {
        let values = (0u128..16).map(BaseElement::new).collect::<Vec<_>>();
        let actual = super::segment(values, 4, 4);
        #[rustfmt::skip]
        let expected = vec![
            vec![BaseElement::new(0), BaseElement::new(1), BaseElement::new(2), BaseElement::new(3)],
            vec![BaseElement::new(4), BaseElement::new(5), BaseElement::new(6), BaseElement::new(7)],
            vec![BaseElement::new(8), BaseElement::new(9), BaseElement::new(10), BaseElement::new(11)],
            vec![BaseElement::new(12), BaseElement::new(13), BaseElement::new(14), BaseElement::new(15)],
        ];
        assert_eq!(expected, actual)
    }

    #[test]
    fn segment_zero_pads_short_last_column() {
        let values = (0u128..6).map(BaseElement::new).collect::<Vec<_>>();
        let actual = super::segment(values, 4, 2);
        #[rustfmt::skip]
        let expected = vec![
            vec![BaseElement::new(0), BaseElement::new(1), BaseElement::new(2), BaseElement::new(3)],
            vec![BaseElement::new(4), BaseElement::new(5), BaseElement::new(0), BaseElement::new(0)],
        ];
        assert_eq!(expected, actual)
    }

    #[test]
    fn segment_zero_pads_missing_columns() {
        let values = (0u128..8).map(BaseElement::new).collect::<Vec<_>>();
        let actual = super::segment(values, 4, 3);
        let zero = BaseElement::new(0);
        #[rustfmt::skip]
        let expected = vec![
            vec![BaseElement::new(0), BaseElement::new(1), BaseElement::new(2), BaseElement::new(3)],
            vec![BaseElement::new(4), BaseElement::new(5), BaseElement::new(6), BaseElement::new(7)],
            vec![zero; 4],
        ];
        assert_eq!(expected, actual)
    }
}