use super::{uninit_vector, Felt, FieldElement, NUM_RAND_ROWS};
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use miden_air::trace::main_trace::MainTrace;
use miden_air::trace::range::{M_COL_IDX, V_COL_IDX};
/// Builds the auxiliary trace column(s) for the range checker.
///
/// The only column produced, `b_range`, is a running sum over the rows of the main trace:
/// range-checker value rows add `m / (alpha - v)` (multiplicity `m` and value `v` taken from the
/// `M` and `V` main-trace columns), and every range-check request recorded for a cycle subtracts
/// `1 / (alpha - value)`. The column starts at `ONE` and is asserted to end at `ONE`.
pub struct AuxTraceBuilder {
    /// Values for which divisors `1 / (alpha - value)` are precomputed. Every value requested in
    /// `cycle_lookups`, and every `V` value with a nonzero multiplicity, must appear here.
    lookup_values: Vec<u16>,
    /// Range-check requests, keyed by the clock cycle (row index) at which they were made.
    cycle_lookups: BTreeMap<u32, Vec<u16>>,
    /// Index of the first main-trace row whose `M`/`V` columns are read. Rows before it
    /// contribute only the requests in `cycle_lookups`.
    values_start: usize,
}

impl AuxTraceBuilder {
    /// Creates a builder from the lookup values, the per-cycle range-check requests, and the row
    /// index at which the range checker's value rows begin.
    pub fn new(
        lookup_values: Vec<u16>,
        cycle_lookups: BTreeMap<u32, Vec<u16>>,
        values_start: usize,
    ) -> Self {
        Self {
            lookup_values,
            cycle_lookups,
            values_start,
        }
    }

    /// Builds and returns all auxiliary columns for the range checker.
    ///
    /// Currently this is the single running-sum column built by `build_aux_col_b_range`.
    ///
    /// # Panics
    /// Panics if `rand_elements` is empty, if a looked-up value has no precomputed divisor, or
    /// if the running sum does not return to `ONE`.
    pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
        &self,
        main_trace: &MainTrace,
        rand_elements: &[E],
    ) -> Vec<Vec<E>> {
        let b_range = self.build_aux_col_b_range(main_trace, rand_elements);
        vec![b_range]
    }

    /// Builds the range checker's running-sum column `b_range`.
    ///
    /// With `alpha = rand_elements[0]`, the column starts at `ONE` and is built in two phases:
    /// 1. Rows `0..values_start`: only requests from `cycle_lookups` are applied. A request for
    ///    value `x` at cycle `i` subtracts `1 / (alpha - x)` from row `i + 1`. Between requests
    ///    the previous value is carried forward.
    /// 2. Rows `values_start..num_rows - NUM_RAND_ROWS`: row `i + 1` equals row `i` plus
    ///    `m_i / (alpha - v_i)` (only when `m_i != 0`), minus `1 / (alpha - x)` for each request
    ///    `x` recorded at cycle `i`.
    ///
    /// Every row after the last one written is filled with `ONE`.
    ///
    /// # Panics
    /// - if `rand_elements` is empty;
    /// - if a requested value, or a `V` value with nonzero multiplicity, is not in
    ///   `lookup_values`;
    /// - if the running sum at the last processed row is not `ONE`.
    fn build_aux_col_b_range<E: FieldElement<BaseField = Felt>>(
        &self,
        main_trace: &MainTrace,
        rand_elements: &[E],
    ) -> Vec<E> {
        let divisors = get_divisors(&self.lookup_values, rand_elements[0]);

        // Returns the precomputed `1 / (alpha - value)`, reporting the offending value on
        // failure (a bare `expect` cannot interpolate it into the message).
        let divisor = |value: u16| -> E {
            *divisors
                .get(&value)
                .unwrap_or_else(|| panic!("invalid lookup value {}", value))
        };

        // SAFETY: every element is written before it is read. Index 0 is set below. The fills
        // and loops then write contiguously upward, each reading only an already-written row.
        // The remaining tail is filled with ONE before the vector is returned.
        let mut b_range = unsafe { uninit_vector(main_trace.num_rows()) };
        b_range[0] = E::ONE;

        // Index of the last row of `b_range` written so far.
        let mut b_range_idx = 0_usize;

        // Phase 1: rows before `values_start` only see range-check requests.
        for (clk, range_checks) in self.cycle_lookups.range(0..self.values_start as u32) {
            let clk = *clk as usize;
            // Carry the running sum forward over cycles that made no requests.
            if b_range_idx < clk {
                let last_value = b_range[b_range_idx];
                b_range[(b_range_idx + 1)..=clk].fill(last_value);
            }
            // Requests made at cycle `clk` are reflected in the following row.
            b_range_idx = clk + 1;
            b_range[b_range_idx] = b_range[clk];
            for &lookup in range_checks.iter() {
                b_range[b_range_idx] -= divisor(lookup);
            }
        }
        if b_range_idx < self.values_start {
            let last_value = b_range[b_range_idx];
            b_range[(b_range_idx + 1)..=self.values_start].fill(last_value);
        }

        // Phase 2: value rows add their multiplicity-weighted divisor, and any requests made in
        // the same cycle are subtracted.
        for (row_idx, (multiplicity, lookup)) in main_trace
            .get_column(M_COL_IDX)
            .iter()
            .zip(main_trace.get_column(V_COL_IDX).iter())
            .enumerate()
            .take(main_trace.num_rows() - NUM_RAND_ROWS)
            .skip(self.values_start)
        {
            b_range_idx = row_idx + 1;
            if multiplicity.as_int() != 0 {
                // NOTE(review): `as u16` assumes `V` holds 16-bit values (it would silently
                // truncate otherwise) — confirm against how the range checker fills `V`.
                let value = divisor(lookup.as_int() as u16);
                b_range[b_range_idx] = b_range[row_idx] + value.mul_base(*multiplicity);
            } else {
                b_range[b_range_idx] = b_range[row_idx];
            }
            if let Some(range_checks) = self.cycle_lookups.get(&(row_idx as u32)) {
                for &lookup in range_checks.iter() {
                    b_range[b_range_idx] -= divisor(lookup);
                }
            }
        }

        // All requests must have been balanced by value rows.
        assert_eq!(
            b_range[b_range_idx],
            E::ONE,
            "range checker running sum must end at ONE"
        );

        // Pad the remaining rows. The slice is empty (a no-op) when the last row was written.
        b_range[(b_range_idx + 1)..].fill(E::ONE);

        b_range
    }
}
/// Maps every value `v` in `lookup_values` to `1 / (alpha - v)`.
///
/// Uses Montgomery's batch-inversion trick, so only one field inversion is performed:
/// 1. record the running product of all previous denominators for each position;
/// 2. invert the full product once;
/// 3. walk backwards, combining the stored prefix product with the running inverse to obtain
///    each individual inverse, then fold that position's denominator back in.
///
/// An empty `lookup_values` yields an empty map.
///
/// NOTE(review): if `alpha - v` is zero for some `v`, the single inversion is applied to zero.
/// The result then depends on how `FieldElement::inv` treats zero — confirm this cannot happen
/// for the `alpha` values passed in.
fn get_divisors<E: FieldElement<BaseField = Felt>>(
    lookup_values: &[u16],
    alpha: E,
) -> BTreeMap<u16, E> {
    // denominators[i] = alpha - lookup_values[i]
    let denominators: Vec<E> = lookup_values.iter().map(|&v| alpha - E::from(v)).collect();

    // prefixes[i] = product of denominators[0..i]; `total` ends as the product of all of them.
    let mut prefixes: Vec<E> = Vec::with_capacity(denominators.len());
    let mut total = E::ONE;
    for &denominator in denominators.iter() {
        prefixes.push(total);
        total *= denominator;
    }

    // `suffix_inv` starts as 1 / (product of all denominators). After handling position `j` it
    // equals 1 / (product of denominators[0..j]).
    let mut suffix_inv = total.inv();
    let mut divisors = BTreeMap::new();
    for (j, &value) in lookup_values.iter().enumerate().rev() {
        divisors.insert(value, prefixes[j] * suffix_inv);
        suffix_inv *= denominators[j];
    }
    divisors
}