use super::{utils::build_lookup_table_row_values, Felt, FieldElement, Matrix, Vec};
use crate::decoder::{AuxTraceHints, BlockTableUpdate, OpGroupTableUpdate};
use vm_core::{utils::uninit_vector, DECODER_TRACE_OFFSET};
#[cfg(test)]
mod tests;
/// Index of the decoder's address column within the full main trace: the decoder-local
/// `ADDR_COL_IDX` shifted by the offset at which the decoder columns begin.
const ADDR_COL_IDX: usize = vm_core::decoder::ADDR_COL_IDX + DECODER_TRACE_OFFSET;
/// Builds the decoder's auxiliary columns and returns them in the order `[p1, p2, p3]`.
///
/// - `p1` is built from the block stack table rows,
/// - `p2` is built from the block hash table rows,
/// - `p3` is built from the op group table rows.
///
/// Every column has `main_trace.num_rows()` entries.
pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
    main_trace: &Matrix<Felt>,
    aux_trace_hints: &AuxTraceHints,
    rand_elements: &[E],
) -> Vec<Vec<E>> {
    let trace_len = main_trace.num_rows();
    vec![
        build_aux_col_p1(main_trace, aux_trace_hints, rand_elements),
        build_aux_col_p2(main_trace, aux_trace_hints, rand_elements),
        build_aux_col_p3(main_trace, trace_len, aux_trace_hints, rand_elements),
    ]
}
/// Builds the running-product column `p1` over the block stack table rows.
///
/// The column starts at `E::ONE`. For each block-execution hint recorded at clock cycle `clk`,
/// the value in row `clk + 1` is derived from the value in row `clk`:
/// - `BlockStarted` multiplies in the `row_values` entry of the next not-yet-added row;
/// - `SpanExtended` multiplies in the `inv_row_values` entry of the most recently added row and
///   the `row_values` entry of the next not-yet-added row;
/// - `BlockEnded` multiplies in the `inv_row_values` entry of the row looked up by the block
///   address found at row `clk`;
/// - `LoopRepeated` leaves the value unchanged.
///
/// Rows between hints repeat the previous value; all rows after the last hint are `E::ONE`.
///
/// # Panics
/// Panics if no block stack row exists for an ended block, or if the value after the last
/// hint is not `E::ONE`.
fn build_aux_col_p1<E: FieldElement<BaseField = Felt>>(
    main_trace: &Matrix<Felt>,
    aux_trace_hints: &AuxTraceHints,
    alphas: &[E],
) -> Vec<E> {
    let (row_values, inv_row_values) = build_lookup_table_row_values(
        aux_trace_hints.block_stack_table_rows(),
        main_trace,
        alphas,
    );

    // SAFETY: no element is read before it is written. Row 0 is set immediately, each hint
    // writes the row after its clock cycle, gaps between hints are filled with the carried
    // value, and the tail is filled after the loop.
    let mut result = unsafe { uninit_vector(main_trace.num_rows()) };
    result[0] = E::ONE;

    // position of the next block stack table row that has not been added yet
    let mut next_row = 0;
    // last row of `result` that holds a computed value
    let mut last_written = 0_usize;

    for (clk, update) in aux_trace_hints.block_exec_hints() {
        let cur = *clk as usize;
        let next = cur + 1;

        // rows without table updates carry the running product forward
        if last_written < cur {
            let carried = result[last_written];
            result[last_written + 1..=cur].fill(carried);
        }
        last_written = next;

        let prev = result[cur];
        result[next] = match update {
            BlockTableUpdate::BlockStarted(_) => {
                let added = row_values[next_row];
                next_row += 1;
                prev * added
            }
            BlockTableUpdate::SpanExtended => {
                // replace the most recently added row with the next one
                let removed_inv = inv_row_values[next_row - 1];
                let added = row_values[next_row];
                next_row += 1;
                prev * removed_inv * added
            }
            BlockTableUpdate::BlockEnded(_) => {
                // the ended block is identified by the address stored in the row at `clk`
                let block_id = get_block_addr(main_trace, cur as u32);
                let row_idx = aux_trace_hints
                    .get_block_stack_row_idx(block_id)
                    .expect("block stack row not found");
                prev * inv_row_values[row_idx]
            }
            BlockTableUpdate::LoopRepeated => prev,
        };
    }

    assert_eq!(result[last_written], E::ONE);
    // an empty range here (last hint wrote the final row) is a no-op
    result[last_written + 1..].fill(E::ONE);
    result
}
/// Builds the running-product column `p2` over the block hash table rows.
///
/// The column starts at the `row_values` entry of table row 0; subsequent rows are added in
/// order starting from row 1. For each block-execution hint recorded at clock cycle `clk`, the
/// value in row `clk + 1` is derived from the value in row `clk`:
/// - `BlockStarted(n)` multiplies in the next `n` (0, 1 or 2) not-yet-added rows;
/// - `LoopRepeated` multiplies in the `row_values` entry of the row looked up by the address at
///   row `clk + 1` (with `is_first_child == false`);
/// - `BlockEnded(is_first_child)` multiplies in the `inv_row_values` entry of the row looked up
///   by the address at row `clk + 1` and `is_first_child`;
/// - `SpanExtended` leaves the value unchanged.
///
/// Rows between hints repeat the previous value; all rows after the last hint are `E::ONE`.
///
/// # Panics
/// Panics on a child count other than 0, 1 or 2, if a block hash row lookup fails, or if the
/// value after the last hint is not `E::ONE`.
fn build_aux_col_p2<E: FieldElement<BaseField = Felt>>(
    main_trace: &Matrix<Felt>,
    aux_trace_hints: &AuxTraceHints,
    alphas: &[E],
) -> Vec<E> {
    let table_rows = aux_trace_hints.block_hash_table_rows();
    let (row_values, inv_row_values) =
        build_lookup_table_row_values(table_rows, main_trace, alphas);

    // SAFETY: no element is read before it is written. Row 0 is set immediately, each hint
    // writes the row after its clock cycle, gaps between hints are filled with the carried
    // value, and the tail is filled after the loop.
    let mut result = unsafe { uninit_vector(main_trace.num_rows()) };
    // table row 0 is already accounted for in the initial value
    result[0] = row_values[0];

    // position of the next block hash table row that has not been added yet
    let mut next_row = 1;
    // last row of `result` that holds a computed value
    let mut last_written = 0_usize;

    for (clk, update) in aux_trace_hints.block_exec_hints() {
        let cur = *clk as usize;
        let next = cur + 1;

        // rows without table updates carry the running product forward
        if last_written < cur {
            let carried = result[last_written];
            result[last_written + 1..=cur].fill(carried);
        }
        last_written = next;

        let prev = result[cur];
        result[next] = match update {
            BlockTableUpdate::BlockStarted(num_children) => {
                let value = match *num_children {
                    0 => prev,
                    1 => {
                        debug_assert!(!table_rows[next_row].is_first_child());
                        prev * row_values[next_row]
                    }
                    2 => {
                        debug_assert!(table_rows[next_row].is_first_child());
                        debug_assert!(!table_rows[next_row + 1].is_first_child());
                        prev * row_values[next_row] * row_values[next_row + 1]
                    }
                    _ => panic!("invalid number of children for a block"),
                };
                next_row += *num_children as usize;
                value
            }
            BlockTableUpdate::SpanExtended => prev,
            BlockTableUpdate::LoopRepeated => {
                // the parent ID is the address stored in the row following the update
                let parent_id = get_block_addr(main_trace, next as u32);
                let row_idx = aux_trace_hints
                    .get_block_hash_row_idx(parent_id, false)
                    .expect("block hash row not found");
                prev * row_values[row_idx]
            }
            BlockTableUpdate::BlockEnded(is_first_child) => {
                // same parent lookup as above; `is_first_child` disambiguates siblings
                let parent_id = get_block_addr(main_trace, next as u32);
                let row_idx = aux_trace_hints
                    .get_block_hash_row_idx(parent_id, *is_first_child)
                    .expect("block hash row not found");
                prev * inv_row_values[row_idx]
            }
        };
    }

    assert_eq!(result[last_written], E::ONE);
    // an empty range here (last hint wrote the final row) is a no-op
    result[last_written + 1..].fill(E::ONE);
    result
}
/// Builds the running-product column `p3` over the op group table rows.
///
/// The column has `trace_len` entries and starts at `E::ONE`. For each op group table hint
/// recorded at clock cycle `clk`, the value in row `clk + 1` is derived from the value in row
/// `clk`:
/// - `InsertRows(n)` multiplies in the `row_values` entries of the next `n` not-yet-inserted
///   rows (an `n` of zero leaves the value unchanged);
/// - `RemoveRow` multiplies in the `inv_row_values` entry of the next not-yet-removed row.
///
/// Rows between hints repeat the previous value; all rows after the last hint are `E::ONE`.
///
/// # Panics
/// Panics if a hint references more rows than the table contains, or if the value after the
/// last hint is not `E::ONE`.
fn build_aux_col_p3<E: FieldElement<BaseField = Felt>>(
    main_trace: &Matrix<Felt>,
    trace_len: usize,
    aux_trace_hints: &AuxTraceHints,
    alphas: &[E],
) -> Vec<E> {
    // SAFETY: no element is read before it is written. Row 0 is set immediately, each hint
    // writes the row after its clock cycle, gaps between hints are filled with the carried
    // value, and the tail is filled after the loop.
    let mut result = unsafe { uninit_vector(trace_len) };
    result[0] = E::ONE;

    let (row_values, inv_row_values) =
        build_lookup_table_row_values(aux_trace_hints.op_group_table_rows(), main_trace, alphas);

    let mut inserted_group_idx = 0_usize;
    let mut removed_group_idx = 0_usize;
    let mut result_idx = 0_usize;

    for (clk, update) in aux_trace_hints.op_group_table_hints() {
        let clk = *clk as usize;

        // rows without table updates carry the running product forward
        if result_idx < clk {
            let last_value = result[result_idx];
            result[(result_idx + 1)..=clk].fill(last_value);
        }
        result_idx = clk + 1;

        match update {
            OpGroupTableUpdate::InsertRows(num_op_groups) => {
                let num_op_groups = *num_op_groups as usize;
                // fold over exactly the inserted rows; starting from ONE means an insert of
                // zero groups no longer multiplies in an unrelated row's value (or indexes
                // past the end of the table)
                let inserted =
                    &row_values[inserted_group_idx..inserted_group_idx + num_op_groups];
                let value = inserted.iter().fold(E::ONE, |acc, &row_value| acc * row_value);
                result[result_idx] = result[clk] * value;
                inserted_group_idx += num_op_groups;
            }
            OpGroupTableUpdate::RemoveRow => {
                result[result_idx] = result[clk] * inv_row_values[removed_group_idx];
                removed_group_idx += 1;
            }
        }
    }

    assert_eq!(result[result_idx], E::ONE);
    // an empty range here (last hint wrote the final row) is a no-op
    result[(result_idx + 1)..].fill(E::ONE);
    result
}
/// Returns the value of the decoder's address column (`ADDR_COL_IDX`) at row `row_idx` of the
/// main trace.
fn get_block_addr(main_trace: &Matrix<Felt>, row_idx: u32) -> Felt {
    let row = row_idx as usize;
    main_trace.get(ADDR_COL_IDX, row)
}