use super::{trace::NUM_RAND_ROWS, Felt, FieldElement, RangeCheckTrace, ZERO};
use crate::utils::uninit_vector;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
mod aux_trace;
pub use aux_trace::AuxTraceBuilder;
#[cfg(test)]
mod tests;
/// Accumulates `u16` values that must be range-checked and builds the range-checker trace
/// from them.
pub struct RangeChecker {
    // Number of lookups recorded for each distinct `u16` value; keys are kept in ascending
    // order by the `BTreeMap`.
    lookups: BTreeMap<u16, usize>,
    // For each clock cycle `clk` passed to `add_range_checks`, the values requested at that
    // cycle, in the order they were added.
    cycle_lookups: BTreeMap<u32, Vec<u16>>,
}
impl RangeChecker {
    /// Returns a new [RangeChecker] whose lookup table is pre-seeded with `0` and `u16::MAX`,
    /// each recorded with zero lookups.
    ///
    /// Seeding both endpoints guarantees that the generated table always starts at `0` and ends
    /// at `u16::MAX`, regardless of which values are later range-checked.
    pub fn new() -> Self {
        let mut lookups = BTreeMap::new();
        lookups.insert(0, 0);
        lookups.insert(u16::MAX, 0);
        Self {
            lookups,
            cycle_lookups: BTreeMap::new(),
        }
    }

    /// Records one lookup of `value`, creating its table entry if it does not exist yet.
    pub fn add_value(&mut self, value: u16) {
        *self.lookups.entry(value).or_insert(0) += 1;
    }

    /// Records the range-check lookups for `values` that were requested at clock cycle `clk`.
    ///
    /// Each value is counted in the lookup table (as with [RangeChecker::add_value]). The values
    /// are also appended, in order, to the list of lookups made at `clk`. Repeated calls with the
    /// same `clk` accumulate into that list.
    ///
    /// In debug builds, panics unless `values` holds exactly 2 or 4 elements.
    pub fn add_range_checks(&mut self, clk: u32, values: &[u16]) {
        debug_assert!(values.len() == 2 || values.len() == 4);

        for &value in values {
            self.add_value(value);
        }

        // Extend the per-cycle list in place instead of cloning `values` into a temporary
        // vector just to append (and then drop) it.
        self.cycle_lookups.entry(clk).or_default().extend_from_slice(values);
    }

    /// Consumes this range checker and builds its two-column trace, `target_len` rows long.
    ///
    /// Column 0 holds a row's lookup count and column 1 holds the row's value. Rows are laid
    /// out as follows:
    /// 1. `target_len - trace_len - num_rand_rows` padding rows filled with zeros.
    /// 2. One row per recorded value, in ascending order. Zero-lookup "bridge" rows are
    ///    inserted between consecutive values whenever their gap requires it (see
    ///    [get_num_bridge_rows]).
    /// 3. One extra row for `u16::MAX` with zero lookups.
    ///
    /// The rows written after the padding total [RangeChecker::get_number_range_checker_rows].
    /// `trace_len` should equal that value so the table ends exactly `num_rand_rows` rows
    /// before `target_len`. The last `num_rand_rows` rows are not written by this function.
    ///
    /// # Panics
    /// Panics if `target_len` is not a power of two, or if `trace_len + num_rand_rows` exceeds
    /// `target_len`.
    pub fn into_trace_with_table(
        self,
        trace_len: usize,
        target_len: usize,
        num_rand_rows: usize,
    ) -> RangeCheckTrace {
        assert!(target_len.is_power_of_two(), "target trace length is not a power of two");
        assert!(trace_len + num_rand_rows <= target_len, "target trace length too small");

        // SAFETY (NOTE(review): confirm against `uninit_vector`'s contract): the columns start
        // out uninitialized. Every row in `[0, target_len - num_rand_rows)` is written below,
        // provided `trace_len` matches `get_number_range_checker_rows`. The trailing
        // `num_rand_rows` rows are not written here and are presumably filled by the caller
        // before they are read.
        let mut trace = unsafe { [uninit_vector(target_len), uninit_vector(target_len)] };

        // Zero-fill the leading padding rows.
        let num_padding_rows = target_len - trace_len - num_rand_rows;
        trace[0][..num_padding_rows].fill(ZERO);
        trace[1][..num_padding_rows].fill(ZERO);

        // Write one row per value (plus any bridge rows), walking the keys in ascending order.
        let mut i = num_padding_rows;
        let mut prev_value = 0u16;
        for (&value, &num_lookups) in self.lookups.iter() {
            write_rows(&mut trace, &mut i, num_lookups, value, prev_value);
            prev_value = value;
        }

        // Trailing zero-lookup `u16::MAX` row; it is counted by the `1` that
        // `get_number_range_checker_rows` starts from.
        write_trace_row(&mut trace, &mut i, 0, u64::from(u16::MAX));

        RangeCheckTrace {
            trace,
            aux_builder: AuxTraceBuilder::new(
                self.lookups.keys().copied().collect(),
                self.cycle_lookups,
                num_padding_rows,
            ),
        }
    }

    /// Returns the number of rows [RangeChecker::into_trace_with_table] writes for this table,
    /// excluding padding rows and random rows.
    ///
    /// This is one row per recorded value, plus the bridge rows needed between consecutive
    /// values (starting from `0`), plus one trailing `u16::MAX` row.
    pub fn get_number_range_checker_rows(&self) -> usize {
        // Start at 1 to account for the trailing `u16::MAX` row.
        let mut num_rows = 1;
        let mut prev_value = 0u16;
        for &value in self.lookups.keys() {
            // Keys ascend, so `value - prev_value` cannot underflow.
            num_rows += 1 + get_num_bridge_rows(value - prev_value);
            prev_value = value;
        }
        num_rows
    }

    /// Test helper: the table length, equal to
    /// [RangeChecker::get_number_range_checker_rows].
    #[cfg(test)]
    pub fn trace_len(&self) -> usize {
        self.get_number_range_checker_rows()
    }

    /// Test helper: builds the trace using this checker's own table length as `trace_len`.
    #[cfg(test)]
    pub fn into_trace(self, target_len: usize, num_rand_rows: usize) -> RangeCheckTrace {
        let table_len = self.get_number_range_checker_rows();
        self.into_trace_with_table(table_len, target_len, num_rand_rows)
    }
}
impl Default for RangeChecker {
fn default() -> Self {
Self::new()
}
}
/// Returns how many zero-lookup "bridge" rows must be inserted between two consecutive table
/// values that differ by `delta`.
///
/// The gap is covered greedily, largest step first, using steps that are powers of three no
/// larger than `3^7` (presumably the only transitions allowed between adjacent table rows).
/// Every intermediate stop is a bridge row; the final step lands on the target value itself.
/// A `delta` of `0`, or one that is already such a power of three, needs no bridge rows.
pub fn get_num_bridge_rows(delta: u16) -> usize {
    let mut remaining = delta;
    let mut bridges = 0usize;
    for exponent in (0..=7u32).rev() {
        let step = 3_u16.pow(exponent);
        // Take this step while it still leaves a positive gap to close.
        while remaining > step {
            remaining -= step;
            bridges += 1;
        }
        // The leftover gap is exactly one step: that final hop is the real row, not a bridge.
        if remaining == step {
            break;
        }
    }
    bridges
}
/// Writes the table row for `value` (carrying `num_lookups`), preceded by any zero-lookup
/// bridge rows needed to climb from `prev_value` up to `value`.
///
/// Bridge values come from greedily adding the largest power of three (up to `3^7`) that still
/// leaves a positive gap. This is the same decomposition `get_num_bridge_rows` counts.
/// `*step` is advanced past every row written.
///
/// Expects `value >= prev_value`; otherwise the initial subtraction underflows.
fn write_rows(
    trace: &mut [Vec<Felt>],
    step: &mut usize,
    num_lookups: usize,
    value: u16,
    prev_value: u16,
) {
    let mut current = prev_value;
    let mut remaining = value - prev_value;
    for exponent in (0..=7u32).rev() {
        let stride = 3_u16.pow(exponent);
        // Each hop that still leaves a positive gap lands on a bridge row.
        while remaining > stride {
            remaining -= stride;
            current += stride;
            write_trace_row(trace, step, 0, u64::from(current));
        }
        if remaining == stride {
            break;
        }
    }
    write_trace_row(trace, step, num_lookups, u64::from(value));
}
/// Writes one row at index `*step` and then advances `*step` to the following row.
///
/// Column 0 receives `num_lookups` and column 1 receives `value`, each converted with
/// `Felt::new`. Panics if `trace` has fewer than two columns or if `*step` is out of bounds.
fn write_trace_row(trace: &mut [Vec<Felt>], step: &mut usize, num_lookups: usize, value: u64) {
    let row = *step;
    trace[0][row] = Felt::new(num_lookups as u64);
    trace[1][row] = Felt::new(value);
    *step = row + 1;
}