use alloc::{collections::BTreeMap, vec::Vec};
use core::mem::MaybeUninit;
use miden_air::trace::{
Challenges, MainTrace, RowIndex,
range::{M_COL_IDX, V_COL_IDX},
};
use crate::{
Felt, ZERO,
field::ExtensionField,
utils::{assume_init_vec, uninit_vector},
};
/// Builds the range checker's auxiliary trace column (`b_range`).
///
/// The column is a running sum over field elements of the form `1 / (alpha + value)`.
/// Lookups requested by other components are subtracted. Values provided by the range
/// checker's main-trace columns are added, weighted by their multiplicity.
#[derive(Debug, Clone)]
pub struct AuxTraceBuilder {
    /// Values for which a divisor `1 / (alpha + value)` is precomputed. Any lookup that is not
    /// present here makes column construction panic with "invalid lookup value".
    lookup_values: Vec<u16>,
    /// Range-check lookups grouped by the row at which they are requested. The divisors of the
    /// lookups at row `r` are subtracted when computing row `r + 1` of the auxiliary column.
    cycle_lookups: BTreeMap<RowIndex, Vec<u16>>,
    /// First row from which the main trace's multiplicity (`M_COL_IDX`) and value
    /// (`V_COL_IDX`) columns are folded into the auxiliary column.
    values_start: usize,
}
impl AuxTraceBuilder {
    /// Creates a new [`AuxTraceBuilder`].
    ///
    /// - `lookup_values`: values for which divisors `1 / (alpha + value)` are precomputed.
    ///   Every value referenced by `cycle_lookups`, or by the main trace's value column at rows
    ///   with non-zero multiplicity, must appear here.
    /// - `cycle_lookups`: range-check lookups grouped by the row at which they are requested.
    /// - `values_start`: first main-trace row whose multiplicity/value columns are processed.
    pub fn new(
        lookup_values: Vec<u16>,
        cycle_lookups: BTreeMap<RowIndex, Vec<u16>>,
        values_start: usize,
    ) -> Self {
        Self {
            lookup_values,
            cycle_lookups,
            values_start,
        }
    }

    /// Builds the auxiliary columns contributed by the range checker.
    ///
    /// Currently this is a single column, `b_range`; see `build_aux_col_b_range`.
    pub fn build_aux_columns<E: ExtensionField<Felt>>(
        &self,
        main_trace: &MainTrace,
        challenges: &Challenges<E>,
    ) -> Vec<Vec<E>> {
        let b_range = self.build_aux_col_b_range(main_trace, challenges);
        vec![b_range]
    }

    /// Builds the `b_range` running-sum column, one element per main-trace row.
    ///
    /// - Row 0 is zero.
    /// - For every row `r`, row `r + 1` equals row `r` minus `1 / (alpha + v)` for each lookup
    ///   `v` requested at row `r` (from `cycle_lookups`).
    /// - From `values_start` onward, row `r + 1` also adds
    ///   `multiplicity[r] * 1 / (alpha + value[r])`, read from the main trace's `M_COL_IDX`
    ///   and `V_COL_IDX` columns.
    /// - Rows with nothing to apply carry the previous value forward.
    ///
    /// # Panics
    /// - If a lookup value has no precomputed divisor ("invalid lookup value").
    /// - If the running sum is not zero after the last processed row, i.e. requested lookups and
    ///   provided values do not balance.
    fn build_aux_col_b_range<E: ExtensionField<Felt>>(
        &self,
        main_trace: &MainTrace,
        challenges: &Challenges<E>,
    ) -> Vec<E> {
        let divisors = get_divisors(&self.lookup_values, challenges.alpha);
        let mut b_range: Vec<MaybeUninit<E>> = uninit_vector(main_trace.num_rows());

        // The running sum starts at zero.
        b_range[0].write(E::ZERO);
        // Index of the most recently written row of `b_range`.
        let mut b_range_idx = 0_usize;
        let mut current_value = E::ZERO;

        // Phase 1: rows before `values_start` only have lookups subtracted from them.
        for (clk, range_checks) in
            self.cycle_lookups.range(RowIndex::from(0)..RowIndex::from(self.values_start))
        {
            let clk: usize = (*clk).into();
            // Carry the current value over the rows that had no lookups.
            if b_range_idx < clk {
                b_range[(b_range_idx + 1)..=clk].fill(MaybeUninit::new(current_value));
            }
            // Lookups requested at row `clk` are reflected in row `clk + 1`.
            b_range_idx = clk + 1;
            let mut new_value = current_value;
            for lookup in range_checks.iter() {
                let value = divisors.get(lookup).expect("invalid lookup value");
                new_value -= *value;
            }
            b_range[b_range_idx].write(new_value);
            current_value = new_value;
        }
        // Carry the value forward up to and including row `values_start`.
        if b_range_idx < self.values_start {
            b_range[(b_range_idx + 1)..=self.values_start].fill(MaybeUninit::new(current_value));
        }

        // Phase 2: from `values_start`, each row also adds the multiplicity-weighted divisor of
        // the value column, in addition to subtracting any lookups requested at that row.
        for (row_idx, (multiplicity, lookup)) in main_trace
            .get_column(M_COL_IDX)
            .iter()
            .zip(main_trace.get_column(V_COL_IDX).iter())
            .enumerate()
            .take(main_trace.num_rows() - 1)
            .skip(self.values_start)
        {
            b_range_idx = row_idx + 1;
            let mut new_value = current_value;
            if *multiplicity != ZERO {
                // NOTE(review): `as u16` truncates; this assumes the value column only holds
                // 16-bit values — confirm against the main-trace construction.
                let value = divisors
                    .get(&(lookup.as_canonical_u64() as u16))
                    .expect("invalid lookup value");
                new_value = current_value + *value * *multiplicity;
            }
            // Use the same `usize -> RowIndex` conversion as the range bound above, rather than
            // a silently truncating `as u32` cast.
            if let Some(range_checks) = self.cycle_lookups.get(&RowIndex::from(row_idx)) {
                for lookup in range_checks.iter() {
                    let value = divisors.get(lookup).expect("invalid lookup value");
                    new_value -= *value;
                }
            }
            b_range[b_range_idx].write(new_value);
            current_value = new_value;
        }

        // Every requested lookup must be matched by a provided value.
        assert_eq!(current_value, E::ZERO);
        // Pad any rows after the last processed one with zero (the balanced final value).
        if b_range_idx < b_range.len() - 1 {
            b_range[(b_range_idx + 1)..].fill(MaybeUninit::new(E::ZERO));
        }

        // SAFETY: every index of `b_range` has been written by this point:
        // - row 0 is written explicitly;
        // - rows `1..=values_start` are written by the phase-1 writes and fills (otherwise the
        //   slicing above would have panicked);
        // - rows `values_start + 1..=b_range_idx` are written consecutively by phase 2;
        // - all rows after `b_range_idx` are written by the final zero fill.
        unsafe { assume_init_vec(b_range) }
    }
}
/// Computes the divisor `1 / (alpha + value)` for every value in `lookup_values`.
///
/// All inverses are obtained with a single field inversion (batch inversion):
/// 1. A forward pass records the prefix products of the denominators.
/// 2. The total product is inverted once.
/// 3. A backward pass peels off one denominator at a time.
///
/// Duplicate entries in `lookup_values` map to the same key and produce the same divisor, so
/// the returned map holds one entry per distinct value.
///
/// NOTE(review): assumes no `alpha + value` is zero, since a zero denominator would make the
/// single `inverse()` call operate on a zero product — confirm `inverse()`'s behavior on zero.
fn get_divisors<E: ExtensionField<Felt>>(lookup_values: &[u16], alpha: E) -> BTreeMap<u16, E> {
    let num_values = lookup_values.len();

    // `denominators[i] = alpha + lookup_values[i]`
    // `prefix_products[i] = denominators[0] * ... * denominators[i - 1]` (empty product is 1)
    let mut denominators: Vec<E> = Vec::with_capacity(num_values);
    let mut prefix_products: Vec<E> = Vec::with_capacity(num_values);
    let mut acc = E::ONE;
    for &value in lookup_values {
        prefix_products.push(acc);
        let denominator = alpha + E::from_u16(value);
        denominators.push(denominator);
        acc *= denominator;
    }

    // Invert the product of all denominators once. While walking backwards, `acc` is the
    // inverse of `denominators[0] * ... * denominators[i]`.
    let mut acc = acc.inverse();
    let mut divisors = BTreeMap::new();
    for ((&value, denominator), prefix) in
        lookup_values.iter().zip(denominators).zip(prefix_products).rev()
    {
        // prefix(0..i) * inverse(product 0..=i) == 1 / denominators[i]
        divisors.insert(value, prefix * acc);
        // Drop `denominators[i]` from the inverted product for the next (smaller) index.
        acc *= denominator;
    }
    divisors
}