use alloc::{collections::BTreeMap, vec::Vec};
use miden_air::{
RowIndex,
trace::{
main_trace::MainTrace,
range::{M_COL_IDX, V_COL_IDX},
},
};
use miden_core::ZERO;
use super::{Felt, FieldElement, NUM_RAND_ROWS, uninit_vector};
/// Builder for the range checker's auxiliary running-sum column (`b_range`).
///
/// Holds the lookup data needed to construct the column from the main trace. The column itself
/// is produced by `AuxTraceBuilder::build_aux_columns`.
#[derive(Debug)]
pub struct AuxTraceBuilder {
    /// Values for which a divisor `1 / (alpha + x)` is precomputed. Every value requested in
    /// `cycle_lookups`, and every value appearing in the main trace's range-checker value
    /// column, must be present here; otherwise column construction panics with
    /// "invalid lookup value".
    lookup_values: Vec<u16>,
    /// Range-check requests made by operations, keyed by the row at which they were made. The
    /// divisors of the requested values are subtracted from the running sum in the next row.
    cycle_lookups: BTreeMap<RowIndex, Vec<u16>>,
    /// First row from which the range checker's own multiplicity/value columns contribute to
    /// the running sum. Rows before it only reflect operation requests.
    values_start: usize,
}
impl AuxTraceBuilder {
    /// Creates a builder from the lookup values, the per-row operation range-check requests,
    /// and the first row at which the range checker's own trace values start contributing.
    pub fn new(
        lookup_values: Vec<u16>,
        cycle_lookups: BTreeMap<RowIndex, Vec<u16>>,
        values_start: usize,
    ) -> Self {
        Self {
            lookup_values,
            cycle_lookups,
            values_start,
        }
    }

    /// Builds and returns the auxiliary columns for the range checker.
    ///
    /// Currently this is a single column, `b_range`.
    ///
    /// # Panics
    /// Under the same conditions as `build_aux_col_b_range`: empty `rand_elements`, a value
    /// missing from the lookup table, or an unbalanced running sum.
    pub fn build_aux_columns<E: FieldElement<BaseField = Felt>>(
        &self,
        main_trace: &MainTrace,
        rand_elements: &[E],
    ) -> Vec<Vec<E>> {
        let b_range = self.build_aux_col_b_range(main_trace, rand_elements);
        vec![b_range]
    }

    /// Builds the `b_range` running-sum column, a LogUp-style log-derivative sum.
    ///
    /// Let `alpha = rand_elements[0]`. The column starts at zero, and row `i + 1` equals row `i`:
    /// - plus `m_i / (alpha + v_i)`, where `m_i` and `v_i` are the range checker's multiplicity
    ///   and value (cast to `u16`) at row `i`. This term applies only for
    ///   `values_start <= i < num_rows - NUM_RAND_ROWS`;
    /// - minus `1 / (alpha + x)` for every value `x` requested by operations at row `i`.
    ///
    /// All rows after the last processed row are set to zero.
    ///
    /// # Panics
    /// - If `rand_elements` is empty.
    /// - With "invalid lookup value" if a requested or range-checker value is not in
    ///   `lookup_values`.
    /// - If the running sum is not zero after the last processed row. That means operation
    ///   requests are not balanced by the range checker's multiplicities.
    fn build_aux_col_b_range<E: FieldElement<BaseField = Felt>>(
        &self,
        main_trace: &MainTrace,
        rand_elements: &[E],
    ) -> Vec<E> {
        // Precompute 1 / (alpha + x) for every lookup value with a single batch inversion.
        let divisors = get_divisors(&self.lookup_values, rand_elements[0]);

        // `uninit_vector` is unsafe, so this function must write every row before reading it:
        // - row 0 is written here;
        // - rows 1..=values_start are written by the first phase (including its carry-forward
        //   fill);
        // - rows up to num_rows - NUM_RAND_ROWS are written by the second phase;
        // - every remaining row is zero-filled at the end.
        let mut b_range = unsafe { uninit_vector(main_trace.num_rows()) };
        b_range[0] = E::ZERO;

        // Index of the last row of `b_range` that has been written. The effect of a row's
        // lookups is recorded in the *next* row.
        let mut b_range_idx = 0_usize;

        // Phase 1: before `values_start` only operation requests affect the sum. The BTreeMap
        // yields rows in increasing order, so `b_range_idx <= clk` always holds here.
        for (clk, range_checks) in
            self.cycle_lookups.range(RowIndex::from(0)..RowIndex::from(self.values_start))
        {
            let clk: usize = (*clk).into();

            // Rows with no requests since the last update carry the previous value forward.
            if b_range_idx < clk {
                let last_value = b_range[b_range_idx];
                b_range[(b_range_idx + 1)..=clk].fill(last_value);
            }

            // Requests made at row `clk` are reflected in row `clk + 1`.
            b_range_idx = clk + 1;
            b_range[b_range_idx] = b_range[clk];
            for lookup in range_checks.iter() {
                let value = divisors.get(lookup).expect("invalid lookup value");
                b_range[b_range_idx] -= *value;
            }
        }

        // Carry the last value forward up to (and including) `values_start`, so the second
        // phase can read `b_range[values_start]`.
        if b_range_idx < self.values_start {
            let last_value = b_range[b_range_idx];
            b_range[(b_range_idx + 1)..=self.values_start].fill(last_value);
        }

        // Phase 2: from `values_start` onward, each row also adds the range checker's own
        // contribution. The last NUM_RAND_ROWS rows are excluded.
        for (row_idx, (multiplicity, lookup)) in main_trace
            .get_column(M_COL_IDX)
            .iter()
            .zip(main_trace.get_column(V_COL_IDX).iter())
            .enumerate()
            .take(main_trace.num_rows() - NUM_RAND_ROWS)
            .skip(self.values_start)
        {
            b_range_idx = row_idx + 1;

            // Add multiplicity / (alpha + v) for the value in the range checker's column. A zero
            // multiplicity contributes nothing, so the divisor lookup is skipped.
            if *multiplicity != ZERO {
                let value = divisors.get(&(lookup.as_int() as u16)).expect("invalid lookup value");
                b_range[b_range_idx] = b_range[row_idx] + value.mul_base(*multiplicity);
            } else {
                b_range[b_range_idx] = b_range[row_idx];
            }

            // Subtract any operation requests made at this row.
            if let Some(range_checks) = self.cycle_lookups.get(&(row_idx as u32).into()) {
                for lookup in range_checks.iter() {
                    let value = divisors.get(lookup).expect("invalid lookup value");
                    b_range[b_range_idx] -= *value;
                }
            }
        }

        // Every request must be matched by the range checker's multiplicities, so the sum must
        // have returned to zero.
        assert_eq!(b_range[b_range_idx], E::ZERO);

        // Initialize all rows past the last processed one (see the note on `uninit_vector` above).
        if b_range_idx < b_range.len() - 1 {
            b_range[(b_range_idx + 1)..].fill(E::ZERO);
        }

        b_range
    }
}
/// Returns a map from each lookup value `x` to its LogUp divisor `1 / (alpha + x)`.
///
/// Uses batch inversion (Montgomery's trick): one field inversion of the product of all
/// denominators, followed by a linear number of multiplications. An empty `lookup_values`
/// yields an empty map.
fn get_divisors<E: FieldElement<BaseField = Felt>>(
    lookup_values: &[u16],
    alpha: E,
) -> BTreeMap<u16, E> {
    let count = lookup_values.len();

    // Forward pass: `denominators[i] = alpha + x_i`, and `prefix_products[i]` holds the product
    // of all denominators strictly before index `i`.
    let mut denominators = Vec::with_capacity(count);
    let mut prefix_products = Vec::with_capacity(count);
    let mut running_product = E::ONE;
    for &x in lookup_values {
        let denominator = alpha + E::from(x);
        prefix_products.push(running_product);
        denominators.push(denominator);
        running_product *= denominator;
    }

    // The single inversion: the inverse of the product of every denominator.
    let mut suffix_inverse = running_product.inv();

    // Backward pass: at index `i`, `suffix_inverse` is the inverse of the product of
    // denominators 0..=i. Multiplying by the prefix product isolates 1 / denominators[i].
    // Multiplying by denominators[i] then drops that factor for the next (lower) index.
    let mut divisors = BTreeMap::new();
    let paired = lookup_values.iter().zip(prefix_products.iter()).zip(denominators.iter());
    for ((&x, &prefix), &denominator) in paired.rev() {
        divisors.insert(x, prefix * suffix_inverse);
        suffix_inverse *= denominator;
    }

    divisors
}