use alloc::{collections::BTreeMap, vec::Vec};
use core::mem::MaybeUninit;
use miden_air::trace::RANGE_CHECK_TRACE_WIDTH;
use super::RowIndex;
use crate::{
Felt, ZERO,
utils::{assume_init_vec, uninit_vector},
};
mod aux_trace;
pub use aux_trace::AuxTraceBuilder;
#[cfg(test)]
mod tests;
/// Range checker execution trace produced by `RangeChecker::into_trace_with_table`, together
/// with the builder for its auxiliary trace.
pub struct RangeCheckTrace {
    /// Main trace columns: column 0 holds the lookup multiplicity of each row and column 1 holds
    /// the row's value.
    pub(crate) trace: [Vec<Felt>; RANGE_CHECK_TRACE_WIDTH],
    /// Builder for the auxiliary trace, initialized from the recorded lookup values, the
    /// per-cycle lookups and the number of padding rows.
    pub(crate) aux_builder: AuxTraceBuilder,
}
/// Collects 16-bit values that must be range-checked and builds the range checker trace from
/// them.
pub struct RangeChecker {
    /// Number of lookups recorded for each value. A `BTreeMap` is used so values are visited in
    /// ascending order when the trace is built.
    lookups: BTreeMap<u16, usize>,
    /// Values range-checked at each clock cycle, in the order they were added.
    cycle_lookups: BTreeMap<RowIndex, Vec<u16>>,
}
impl RangeChecker {
    /// Creates a new range checker with no recorded lookups.
    ///
    /// The values `0` and `u16::MAX` are always present in the lookup table (with zero lookups),
    /// so the generated table always starts at 0 and reaches `u16::MAX`.
    /// NOTE(review): presumably required by the range checker's boundary constraints — confirm.
    pub fn new() -> Self {
        let mut lookups = BTreeMap::new();
        lookups.insert(0, 0);
        lookups.insert(u16::MAX, 0);
        Self { lookups, cycle_lookups: BTreeMap::new() }
    }

    /// Records one lookup of `value` in the range-check table.
    pub fn add_value(&mut self, value: u16) {
        *self.lookups.entry(value).or_insert(0) += 1;
    }

    /// Records range checks for `values` requested at clock cycle `clk`.
    ///
    /// Each value is counted in the lookup table. The values are also appended, in order, to the
    /// list of lookups for `clk`, which is later handed to the auxiliary trace builder. In debug
    /// builds, `values` must contain exactly 2 or 4 elements.
    pub fn add_range_checks(&mut self, clk: RowIndex, values: &[u16]) {
        debug_assert!(values.len() == 2 || values.len() == 4);

        for &value in values {
            self.add_value(value);
        }

        // Extend in place rather than building a temporary Vec just to append it.
        self.cycle_lookups.entry(clk).or_default().extend_from_slice(values);
    }

    /// Consumes the range checker and builds its execution trace, padded to `target_len` rows.
    ///
    /// Column 0 holds lookup multiplicities and column 1 holds values. The layout is:
    /// - the first `target_len - trace_len` rows are padding rows filled with `ZERO`;
    /// - then, for every recorded value in ascending order, the bridge rows (multiplicity 0)
    ///   needed to reach it from the previous value, followed by the value's own row with its
    ///   lookup count;
    /// - finally, one extra row `(0, u16::MAX)`.
    ///
    /// `trace_len` must equal [`Self::get_number_range_checker_rows`].
    ///
    /// # Panics
    /// - if `target_len` is not a power of two;
    /// - if `trace_len > target_len`;
    /// - if `trace_len` does not match the number of rows actually produced by the table.
    pub fn into_trace_with_table(self, trace_len: usize, target_len: usize) -> RangeCheckTrace {
        assert!(target_len.is_power_of_two(), "target trace length is not a power of two");
        assert!(trace_len <= target_len, "target trace length too small");
        // A mismatch would otherwise surface as an opaque index-out-of-bounds panic while writing
        // rows, or as the late initialization assert below.
        debug_assert_eq!(
            trace_len,
            self.get_number_range_checker_rows(),
            "trace_len does not match the number of range checker rows"
        );

        let mut trace = [uninit_vector(target_len), uninit_vector(target_len)];

        // Padding rows are placed at the start of the trace.
        let num_padding_rows = target_len - trace_len;
        trace[0][..num_padding_rows].fill(MaybeUninit::new(ZERO));
        trace[1][..num_padding_rows].fill(MaybeUninit::new(ZERO));

        // Write the table rows sequentially right after the padding. `lookups` is a BTreeMap, so
        // values are visited in ascending order and `value - prev_value` cannot underflow.
        let mut i = num_padding_rows;
        let mut prev_value = 0u16;
        for (&value, &num_lookups) in self.lookups.iter() {
            write_rows(&mut trace, &mut i, num_lookups, value, prev_value);
            prev_value = value;
        }
        // Closing row: `u16::MAX` with zero lookups.
        write_trace_row(&mut trace, &mut i, 0, u64::from(u16::MAX));

        // This check is what makes the `assume_init_vec` below sound; keep it a hard assert.
        assert_eq!(i, target_len, "range checker trace not fully initialized; trace_len mismatch");

        let [t0, t1] = trace;
        // SAFETY: rows `[0, num_padding_rows)` of both columns were filled with `ZERO` above, and
        // `write_rows`/`write_trace_row` write both columns one row at a time, starting at
        // `num_padding_rows` and advancing `i` by one per row. The assertion above proves the
        // writes ended exactly at `target_len`, so every element of both columns is initialized.
        let trace = unsafe { [assume_init_vec(t0), assume_init_vec(t1)] };

        RangeCheckTrace {
            trace,
            aux_builder: AuxTraceBuilder::new(
                self.lookups.keys().copied().collect(),
                self.cycle_lookups,
                num_padding_rows,
            ),
        }
    }

    /// Returns the number of rows in the range checker table, excluding padding.
    ///
    /// Each recorded value contributes its own row plus the bridge rows needed to reach it from
    /// the previous value (starting from 0). One additional row accounts for the closing
    /// `(0, u16::MAX)` row written by [`Self::into_trace_with_table`].
    pub fn get_number_range_checker_rows(&self) -> usize {
        let mut num_rows = 1;
        let mut prev_value = 0u16;
        // Keys are ascending, so the subtraction below cannot underflow.
        for &value in self.lookups.keys() {
            num_rows += 1 + get_num_bridge_rows(value - prev_value);
            prev_value = value;
        }
        num_rows
    }

    /// Test helper: number of table rows, excluding padding.
    #[cfg(test)]
    pub fn trace_len(&self) -> usize {
        self.get_number_range_checker_rows()
    }

    /// Test helper: builds the trace padded to `target_len`, computing `trace_len` automatically.
    #[cfg(test)]
    pub fn into_trace(self, target_len: usize) -> RangeCheckTrace {
        let table_len = self.get_number_range_checker_rows();
        self.into_trace_with_table(table_len, target_len)
    }
}
impl Default for RangeChecker {
fn default() -> Self {
Self::new()
}
}
/// Returns how many bridge rows are needed between two consecutive table values that differ by
/// `delta`.
///
/// The gap is walked greedily with strides `3^7, 3^6, ..., 1`. While the remaining gap exceeds
/// the current stride, one bridge row is counted and the gap shrinks by that stride. The walk
/// stops as soon as the remaining gap equals a stride, because that last step lands on the value
/// itself. As a result, a `delta` of 0, or any power of three up to `3^7`, needs no bridge rows.
pub fn get_num_bridge_rows(delta: u16) -> usize {
    let mut remaining = delta;
    let mut count = 0_usize;
    for exponent in (0..=7_u32).rev() {
        let step = 3_u16.pow(exponent);
        while remaining > step {
            remaining -= step;
            count += 1;
        }
        if remaining == step {
            break;
        }
    }
    count
}
/// Writes the rows for one lookup-table entry, starting at row `*step`.
///
/// First writes the bridge rows (zero lookups) needed to move from `prev_value` to `value`. These
/// step greedily by strides `3^7, 3^6, ..., 1`, the same walk used by `get_num_bridge_rows`.
/// Then it writes `value` itself with `num_lookups`. `*step` is advanced past every row written.
///
/// Expects `prev_value <= value`; the caller in this module iterates values in ascending order.
fn write_rows(
    trace: &mut [Vec<MaybeUninit<Felt>>],
    step: &mut usize,
    num_lookups: usize,
    value: u16,
    prev_value: u16,
) {
    let mut remaining = value - prev_value;
    let mut current = prev_value;
    for exponent in (0..=7_u32).rev() {
        let stride = 3_u16.pow(exponent);
        while remaining > stride {
            remaining -= stride;
            current += stride;
            // Bridge row: an intermediate value that was not itself looked up.
            write_trace_row(trace, step, 0, u64::from(current));
        }
        if remaining == stride {
            break;
        }
    }
    write_trace_row(trace, step, num_lookups, u64::from(value));
}
/// Writes a single trace row at index `*step` and then advances `*step` by one.
///
/// Column 0 receives `num_lookups` and column 1 receives `value`, each converted with
/// `Felt::new`. Panics if `*step` is out of bounds for either column.
fn write_trace_row(
    trace: &mut [Vec<MaybeUninit<Felt>>],
    step: &mut usize,
    num_lookups: usize,
    value: u64,
) {
    let row = *step;
    trace[0][row].write(Felt::new(num_lookups as u64));
    trace[1][row].write(Felt::new(value));
    *step = row + 1;
}