use alloc::vec::Vec;
use air::{EvaluationFrame, TraceInfo};
use math::StarkField;
use utils::uninit_vector;
#[cfg(feature = "concurrent")]
use utils::{iterators::*, rayon};
use super::{ColMatrix, Trace};
/// Smallest number of rows a [`TraceTableFragment`] may span; fragment construction rejects any
/// shorter fragment length.
const MIN_FRAGMENT_LENGTH: usize = 2;
/// An execution trace stored as a collection of columns.
///
/// Each column holds one value per step (row) of the trace; all columns have the same length.
#[derive(Debug, Clone)]
pub struct TraceTable<B: StarkField> {
    /// Width, length and custom metadata describing this trace.
    info: TraceInfo,
    /// The trace values, held as one vector per column.
    trace: ColMatrix<B>,
}
impl<B: StarkField> TraceTable<B> {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Creates a new execution trace with `width` columns and `length` rows and no custom
    /// metadata.
    ///
    /// The trace memory is allocated but not initialized; see [`Self::with_meta`].
    ///
    /// # Panics
    /// Under the same conditions as [`Self::with_meta`].
    pub fn new(width: usize, length: usize) -> Self {
        Self::with_meta(width, length, Vec::new())
    }

    /// Creates a new execution trace with `width` columns, `length` rows, and the provided
    /// custom metadata.
    ///
    /// The column buffers are not initialized here: every cell should be written (via
    /// [`Self::fill`], [`Self::update_row`], [`Self::set`], or [`Self::fragments`]) before it
    /// is read.
    ///
    /// # Panics
    /// - if [`TraceInfo::with_meta`] panics for the given `width`, `length`, or `meta`;
    /// - if `length` is zero or greater than `2^B::TWO_ADICITY`.
    pub fn with_meta(width: usize, length: usize, meta: Vec<u8>) -> Self {
        let info = TraceInfo::with_meta(width, length, meta);
        Self::assert_length_supported(length);

        // SAFETY: NOTE(review) `uninit_vector` presumably returns vectors of `length` elements
        // that have not been initialized — confirm in `utils`. This type does not track which
        // cells were written, so callers must populate every row before reading it (e.g. via
        // `get`, `read_row_into`, or `Trace::read_main_frame`).
        let columns = unsafe { (0..width).map(|_| uninit_vector(length)).collect() };
        Self { info, trace: ColMatrix::new(columns) }
    }

    /// Creates an execution trace from already-populated columns; the trace carries no custom
    /// metadata.
    ///
    /// # Panics
    /// - if `columns` is empty;
    /// - if [`TraceInfo::with_meta`] panics for the resulting width/length;
    /// - if the trace length is zero or greater than `2^B::TWO_ADICITY`;
    /// - if the columns do not all have the same length.
    pub fn init(columns: Vec<Vec<B>>) -> Self {
        assert!(!columns.is_empty(), "execution trace must consist of at least one column");
        let trace_length = columns[0].len();
        let info = TraceInfo::with_meta(columns.len(), trace_length, Vec::new());
        Self::assert_length_supported(trace_length);
        for column in columns.iter().skip(1) {
            assert_eq!(column.len(), trace_length, "all columns traces must have the same length");
        }
        Self { info, trace: ColMatrix::new(columns) }
    }

    /// Asserts that a trace of `length` steps does not exceed `2^B::TWO_ADICITY` steps.
    ///
    /// Shared by [`Self::with_meta`] and [`Self::init`] so both report the same message.
    ///
    /// # Panics
    /// Panics if `length` is zero (from `usize::ilog2`) or if `log2(length)` is greater than
    /// `B::TWO_ADICITY`.
    fn assert_length_supported(length: usize) {
        let log_length = length.ilog2();
        assert!(
            log_length <= B::TWO_ADICITY,
            "execution trace length cannot exceed 2^{} steps, but was 2^{}",
            B::TWO_ADICITY,
            log_length
        );
    }

    // DATA MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Writes `value` into the cell at (`column`, `step`); delegates to `ColMatrix::set`.
    pub fn set(&mut self, column: usize, step: usize, value: B) {
        self.trace.set(column, step, value)
    }

    /// Populates the whole trace row by row.
    ///
    /// A state row of `width` zeros is handed to `init`, and the result is written to row 0.
    /// Then, for each `i` in `0..length - 1`, `update(i, state)` mutates the state and the
    /// result is written to row `i + 1`.
    pub fn fill<I, U>(&mut self, init: I, mut update: U)
    where
        I: FnOnce(&mut [B]),
        U: FnMut(usize, &mut [B]),
    {
        let mut state = vec![B::ZERO; self.info.main_trace_width()];
        init(&mut state);
        self.update_row(0, &state);
        for i in 0..self.info.length() - 1 {
            update(i, &mut state);
            self.update_row(i + 1, &state);
        }
    }

    /// Overwrites row `step` with the values in `state`; delegates to `ColMatrix::update_row`.
    pub fn update_row(&mut self, step: usize, state: &[B]) {
        self.trace.update_row(step, state);
    }

    // FRAGMENTS
    // --------------------------------------------------------------------------------------------

    /// Splits the trace into fragments of `fragment_length` consecutive rows each.
    ///
    /// Fragment `i` covers rows `[i * fragment_length, (i + 1) * fragment_length)` of every
    /// column, so fragments can be filled independently of each other.
    ///
    /// # Panics
    /// Panics if `fragment_length` is smaller than 2, greater than the trace length, or not a
    /// power of two.
    #[cfg(not(feature = "concurrent"))]
    pub fn fragments(
        &mut self,
        fragment_length: usize,
    ) -> alloc::vec::IntoIter<TraceTableFragment<'_, B>> {
        self.build_fragments(fragment_length).into_iter()
    }

    /// Splits the trace into fragments of `fragment_length` consecutive rows each and returns a
    /// rayon parallel iterator over them.
    ///
    /// Fragment `i` covers rows `[i * fragment_length, (i + 1) * fragment_length)` of every
    /// column, so fragments can be filled in parallel.
    ///
    /// # Panics
    /// Panics if `fragment_length` is smaller than 2, greater than the trace length, or not a
    /// power of two.
    #[cfg(feature = "concurrent")]
    pub fn fragments(
        &mut self,
        fragment_length: usize,
    ) -> rayon::vec::IntoIter<TraceTableFragment<'_, B>> {
        self.build_fragments(fragment_length).into_par_iter()
    }

    /// Validates `fragment_length` and carves each column into disjoint mutable row chunks,
    /// grouping the `i`-th chunk of every column into fragment `i`.
    fn build_fragments(&mut self, fragment_length: usize) -> Vec<TraceTableFragment<'_, B>> {
        assert!(
            fragment_length >= MIN_FRAGMENT_LENGTH,
            "fragment length must be at least {MIN_FRAGMENT_LENGTH}, but was {fragment_length}"
        );
        assert!(
            fragment_length <= self.info.length(),
            "length of a fragment cannot exceed {}, but was {}",
            self.info.length(),
            fragment_length
        );
        assert!(fragment_length.is_power_of_two(), "fragment length must be a power of 2");

        let num_fragments = self.info.length() / fragment_length;
        let width = self.info.main_trace_width();
        // Each fragment ends up holding exactly one slice per column.
        let mut fragment_data: Vec<Vec<&mut [B]>> =
            (0..num_fragments).map(|_| Vec::with_capacity(width)).collect();
        self.trace.columns_mut().for_each(|column| {
            for (i, fragment) in column.chunks_mut(fragment_length).enumerate() {
                fragment_data[i].push(fragment);
            }
        });

        fragment_data
            .into_iter()
            .enumerate()
            .map(|(i, data)| TraceTableFragment {
                index: i,
                offset: i * fragment_length,
                data,
            })
            .collect()
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the number of columns in this trace.
    pub fn width(&self) -> usize {
        self.info.main_trace_width()
    }

    /// Returns the column at `col_idx`; delegates to `ColMatrix::get_column`.
    pub fn get_column(&self, col_idx: usize) -> &[B] {
        self.trace.get_column(col_idx)
    }

    /// Returns the value at (`column`, `step`); delegates to `ColMatrix::get`.
    pub fn get(&self, column: usize, step: usize) -> B {
        self.trace.get(column, step)
    }

    /// Copies row `step` into `target`; delegates to `ColMatrix::read_row_into`.
    pub fn read_row_into(&self, step: usize, target: &mut [B]) {
        self.trace.read_row_into(step, target);
    }
}
impl<B: StarkField> Trace for TraceTable<B> {
    type BaseField = B;

    /// Returns the descriptor (width, length, metadata) of this trace.
    fn info(&self) -> &TraceInfo {
        &self.info
    }

    /// Loads row `row_idx` into the frame's current row and the row after it into the frame's
    /// next row; the row after the last one wraps around to row 0.
    fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame<Self::BaseField>) {
        let trace_length = self.info.length();
        let wrapped_next = (row_idx + 1) % trace_length;

        let (current_row, next_row) = (row_idx, wrapped_next);
        self.trace.read_row_into(current_row, frame.current_mut());
        self.trace.read_row_into(next_row, frame.next_mut());
    }

    /// Returns the full column matrix backing this trace.
    fn main_segment(&self) -> &ColMatrix<B> {
        let segment = &self.trace;
        segment
    }
}
/// A mutable view over a contiguous range of rows of a trace table, spanning every column.
///
/// Fragments are produced by splitting each column with `chunks_mut`, so the row ranges of
/// different fragments never overlap and each fragment can be written independently.
pub struct TraceTableFragment<'a, B: StarkField> {
    /// Zero-based position of this fragment among all fragments of the trace.
    index: usize,
    /// Index of the first trace row covered by this fragment.
    offset: usize,
    /// One mutable slice per trace column, each covering this fragment's rows.
    data: Vec<&'a mut [B]>,
}
impl<B: StarkField> TraceTableFragment<'_, B> {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Zero-based position of this fragment within the trace.
    pub fn index(&self) -> usize {
        let idx = self.index;
        idx
    }

    /// Trace row at which this fragment begins.
    pub fn offset(&self) -> usize {
        let first_row = self.offset;
        first_row
    }

    /// Number of rows in this fragment, taken from its first column.
    pub fn length(&self) -> usize {
        let first_column = &self.data[0];
        first_column.len()
    }

    /// Number of columns in this fragment.
    pub fn width(&self) -> usize {
        let columns = &self.data;
        columns.len()
    }

    // DATA MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Populates every row of this fragment.
    ///
    /// A zeroed state row is passed to `init_state` and written to row 0. Then, for each local
    /// step `i` in `0..length - 1`, `update_state(i, state)` advances the state and the result
    /// is written to row `i + 1`.
    pub fn fill<I, T>(&mut self, init_state: I, mut update_state: T)
    where
        I: FnOnce(&mut [B]),
        T: FnMut(usize, &mut [B]),
    {
        let mut row = vec![B::ZERO; self.width()];
        init_state(&mut row);
        self.update_row(0, &row);

        let last_step = self.length() - 1;
        for step in 0..last_step {
            update_state(step, &mut row);
            self.update_row(step + 1, &row);
        }
    }

    /// Writes `row_data` into local row `row_idx`, one value per column; values are paired with
    /// columns positionally and any surplus on either side is ignored.
    pub fn update_row(&mut self, row_idx: usize, row_data: &[B]) {
        self.data
            .iter_mut()
            .zip(row_data.iter().copied())
            .for_each(|(column, value)| column[row_idx] = value);
    }
}