use super::{BTreeMap, Felt, FieldElement, Vec, ONE, ZERO};
use crate::RangeCheckTrace;
use vm_core::utils::uninit_vector;
mod aux_trace;
pub use aux_trace::AuxTraceBuilder;
mod request;
use request::CycleRangeChecks;
#[cfg(test)]
mod tests;
/// Number of entries in the 8-bit lookup table (one per possible 8-bit value).
pub const RANGE_CHECK_TRACE_TABLE_WIDTH: usize = 256;

/// Pre-computed 8-bit lookup counts together with the total trace length they imply.
///
/// Produced by `RangeChecker::build_8bit_lookup()` and consumed by
/// `RangeChecker::into_trace_with_table()`.
pub struct RangeCheckTraceTable {
    // lookups_8bit[v] = number of times 8-bit value v must be looked up in the
    // 8-bit section of the trace.
    pub lookups_8bit: [usize; RANGE_CHECK_TRACE_TABLE_WIDTH],
    // Total number of trace rows required (8-bit section + 16-bit section),
    // excluding padding and random rows.
    pub len: usize,
}
/// Accumulates all 16-bit range checks performed during program execution and converts
/// them into the range checker's trace columns.
pub struct RangeChecker {
    // Number of times each 16-bit value was range-checked; always contains entries for
    // 0 and u16::MAX (seeded in `new()`), possibly with zero counts.
    lookups: BTreeMap<u16, usize>,
    // Range checks recorded per clock cycle; used to build the auxiliary trace.
    cycle_range_checks: BTreeMap<u32, CycleRangeChecks>,
}
impl RangeChecker {
/// Creates a new, empty [RangeChecker].
///
/// The lookup map is seeded with zero-count entries for 0 and u16::MAX so that the
/// 16-bit trace section always starts at 0 and ends at u16::MAX.
pub fn new() -> Self {
    let mut lookups = BTreeMap::new();
    // Seed both ends of the 16-bit range with zero-count entries.
    for &boundary in [0u16, u16::MAX].iter() {
        lookups.insert(boundary, 0);
    }
    Self {
        lookups,
        cycle_range_checks: BTreeMap::new(),
    }
}
/// Records one range check of the specified 16-bit value.
pub fn add_value(&mut self, value: u16) {
    // Bump the lookup count, starting at zero for a value seen for the first time.
    *self.lookups.entry(value).or_insert(0) += 1;
}
/// Records 16-bit range checks for the 4 values checked by a stack operation at clock
/// cycle `clk`, and stores them for the auxiliary trace builder.
///
/// NOTE(review): any cycle entry previously recorded at `clk` is overwritten here —
/// presumably at most one stack operation occurs per cycle; confirm with callers.
pub fn add_stack_checks(&mut self, clk: u32, values: &[u16; 4]) {
    for &value in values.iter() {
        self.add_value(value);
    }
    self.cycle_range_checks
        .insert(clk, CycleRangeChecks::new_from_stack(values));
}
/// Records 16-bit range checks for the 2 values checked by a memory operation at clock
/// cycle `clk`, merging with any range checks already recorded at the same cycle.
pub fn add_mem_checks(&mut self, clk: u32, values: &[u16; 2]) {
    for &value in values.iter() {
        self.add_value(value);
    }
    // Extend an existing cycle entry if present; otherwise start a new one.
    match self.cycle_range_checks.get_mut(&clk) {
        Some(cycle_checks) => cycle_checks.add_memory_checks(values),
        None => {
            self.cycle_range_checks
                .insert(clk, CycleRangeChecks::new_from_memory(values));
        }
    }
}
/// Converts this [RangeChecker] into the main range checker trace columns.
///
/// The trace is laid out as: zero-value padding rows, then the 8-bit section (a run of
/// rows per 8-bit value, counts taken from `table`), then the 16-bit section (all
/// checked values in increasing order, with zero-lookup "bridge" rows inserted so that
/// consecutive row values never differ by more than 255).
///
/// `table` must be the result of `build_8bit_lookup()` on this same checker — its `len`
/// field determines the padding, and its seeded counts match the terminal row written
/// at the end of this function.
///
/// # Panics
/// Panics if `target_len` is not a power of two, or if `table.len + num_rand_rows`
/// exceeds `target_len`.
pub fn into_trace_with_table(
    self,
    table: RangeCheckTraceTable,
    target_len: usize,
    num_rand_rows: usize,
) -> RangeCheckTrace {
    assert!(
        target_len.is_power_of_two(),
        "target trace length is not a power of two"
    );
    // Number of rows actually needed by the 8-bit + 16-bit sections.
    let trace_len = table.len;
    assert!(
        trace_len + num_rand_rows <= target_len,
        "target trace length too small"
    );
    // SAFETY-NOTE(review): columns are allocated uninitialized and filled below, but
    // the last `num_rand_rows` entries of columns 1..3 (and of `row_flags`) are never
    // written in this function — presumably they are overwritten downstream with
    // random values; confirm before relying on their contents.
    let mut trace = unsafe {
        [
            uninit_vector(target_len),
            uninit_vector(target_len),
            uninit_vector(target_len),
            uninit_vector(target_len),
        ]
    };
    // Per-row flag (F0/F1/F2/F3) recording how many lookups each row contributes;
    // consumed by the auxiliary trace builder.
    let mut row_flags = unsafe { uninit_vector(target_len) };
    // Pad the front of the trace with zero-value, zero-lookup rows.
    let num_padding_rows = target_len - trace_len - num_rand_rows;
    trace[1][..num_padding_rows].fill(ZERO);
    trace[2][..num_padding_rows].fill(ZERO);
    trace[3][..num_padding_rows].fill(ZERO);
    row_flags[..num_padding_rows].fill(RangeCheckFlag::F0);
    // ---- 8-bit section: one run of rows per 8-bit value, counts from `table`.
    let mut i = num_padding_rows;
    for (value, num_lookups) in table.lookups_8bit.into_iter().enumerate() {
        write_value(
            &mut trace,
            &mut i,
            num_lookups,
            value as u64,
            &mut row_flags,
        );
    }
    // Column 0 is the section selector: 0 for padding + 8-bit rows, 1 from the first
    // 16-bit row onward (including the random rows at the end).
    trace[0][..i].fill(ZERO);
    trace[0][i..].fill(ONE);
    let start_16bit = i;
    // ---- 16-bit section: emit each checked value in increasing order.
    let mut prev_value = 0u16;
    for (&value, &num_lookups) in self.lookups.iter() {
        // Bridge rows at prev_value + 255, + 510, ... (strictly below `value`) keep
        // consecutive row deltas within 255.
        for value in (prev_value..value).step_by(255).skip(1) {
            write_value(&mut trace, &mut i, 0, value as u64, &mut row_flags);
        }
        write_value(
            &mut trace,
            &mut i,
            num_lookups,
            value as u64,
            &mut row_flags,
        );
        prev_value = value;
    }
    // Terminal row: the 16-bit section always ends with one extra u16::MAX row; this
    // row (and its zero delta) is pre-counted by the initial 1s in `build_8bit_lookup`.
    write_value(&mut trace, &mut i, 0, (u16::MAX).into(), &mut row_flags);
    RangeCheckTrace {
        trace,
        aux_builder: AuxTraceBuilder::new(self.cycle_range_checks, row_flags, start_16bit),
    }
}
/// Builds the 8-bit lookup table and computes the total number of trace rows needed.
///
/// The 8-bit table records how many times each delta (difference between consecutive
/// row values in the 16-bit section) occurs; each such delta must itself be looked up
/// in the 8-bit section. The returned `len` is the exact row count that
/// `into_trace_with_table()` will write.
pub fn build_8bit_lookup(&self) -> RangeCheckTraceTable {
    let mut lookups_8bit = [0; 256];
    // Start both counters at 1 to account for the extra terminal u16::MAX row that
    // `into_trace_with_table()` appends (one more 16-bit row, with a zero delta from
    // the preceding u16::MAX row).
    let mut num_16bit_rows = 1;
    lookups_8bit[0] = 1;
    let mut prev_value = 0u16;
    for (&value, &num_lookups) in self.lookups.iter() {
        // Rows needed to encode `num_lookups` occurrences of `value` (4/2/1 per row).
        let num_rows = lookups_to_rows(num_lookups);
        // Repeated rows of the same value contribute zero deltas.
        lookups_8bit[0] += num_rows - 1;
        num_16bit_rows += num_rows;
        // Distance from the previously emitted value, split into full 255-steps plus a
        // remainder — mirroring the bridge rows written by `into_trace_with_table()`.
        let delta = value - prev_value;
        let (delta_q, delta_r) = div_rem(delta as usize, 255);
        if delta_q != 0 {
            lookups_8bit[255] += delta_q;
            // When the delta is an exact multiple of 255, the last 255-step lands on
            // `value` itself, so one fewer bridge row is needed.
            let num_bridge_rows = if delta_r == 0 { delta_q - 1 } else { delta_q };
            num_16bit_rows += num_bridge_rows;
        }
        if delta_r != 0 {
            lookups_8bit[delta_r] += 1;
        }
        prev_value = value;
    }
    // Rows needed by the 8-bit section itself (every 8-bit value occupies >= 1 row).
    let num_8bit_rows = get_num_8bit_rows(&lookups_8bit);
    let len = num_8bit_rows + num_16bit_rows;
    RangeCheckTraceTable { lookups_8bit, len }
}
/// Returns the number of trace rows this checker currently requires (test helper).
#[cfg(test)]
pub fn trace_len(&self) -> usize {
    let table = self.build_8bit_lookup();
    table.len
}
/// Converts this checker into trace columns, building the 8-bit lookup table
/// internally (test helper).
#[cfg(test)]
pub fn into_trace(self, target_len: usize, num_rand_rows: usize) -> RangeCheckTrace {
    // Build the lookup table first: `into_trace_with_table` consumes `self`.
    let lookup_table = self.build_8bit_lookup();
    self.into_trace_with_table(lookup_table, target_len, num_rand_rows)
}
}
impl Default for RangeChecker {
fn default() -> Self {
Self::new()
}
}
/// Flag indicating how many lookups of its value a single trace row represents.
///
/// The multiplicities follow from `RangeCheckFlag::to_value`: F0 = 0 lookups,
/// F1 = 1, F2 = 2, F3 = 4.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum RangeCheckFlag {
    F0,
    F1,
    F2,
    F3,
}
impl RangeCheckFlag {
    /// Reduces a trace row to a single field element for the running-product column:
    /// the row contributes `(alpha + value)` raised to the number of lookups it
    /// represents (0, 1, 2, or 4 for F0, F1, F2, F3 respectively).
    pub fn to_value<E: FieldElement<BaseField = Felt>>(&self, value: Felt, alphas: &[E]) -> E {
        // Randomized encoding of a single lookup of `value`.
        let term: E = alphas[0] + value.into();
        match self {
            RangeCheckFlag::F0 => E::ONE,
            RangeCheckFlag::F1 => term,
            RangeCheckFlag::F2 => term.square(),
            RangeCheckFlag::F3 => term.square().square(),
        }
    }
}
/// Returns the number of trace rows needed to encode `num_lookups` lookups of a single
/// value: 4 lookups per F3 row, 2 per F2 row, and 1 per F1 row.
///
/// A value with no lookups still occupies one (F0) row.
fn lookups_to_rows(num_lookups: usize) -> usize {
    if num_lookups == 0 {
        return 1;
    }
    let rows_of_four = num_lookups / 4;
    let leftover = num_lookups % 4;
    // `leftover` is 0..=3: at most one 2-lookup row and one 1-lookup row.
    rows_of_four + leftover / 2 + leftover % 2
}
/// Returns the total number of rows required by the 8-bit section of the trace.
///
/// Each 8-bit value occupies `n / 4 + (n % 4) / 2 + (n % 4) % 2` rows for `n` lookups
/// (packing 4, then 2, then 1 per row), or a single zero-lookup row when `n == 0` —
/// the same row count computed by `lookups_to_rows`, inlined here.
fn get_num_8bit_rows(lookups: &[usize; 256]) -> usize {
    lookups
        .iter()
        .map(|&n| {
            if n == 0 {
                1
            } else {
                n / 4 + (n % 4) / 2 + (n % 4) % 2
            }
        })
        .sum()
}
/// Writes rows encoding `num_lookups` lookups of `value` into the trace at `*step`,
/// advancing `*step` past everything written.
///
/// Rows are emitted greedily: F3 rows (4 lookups each), then at most one F2 row
/// (2 lookups), then at most one F1 row (1 lookup). Zero lookups yields a single
/// F0 row.
fn write_value(
    trace: &mut [Vec<Felt>],
    step: &mut usize,
    num_lookups: usize,
    value: u64,
    row_flags: &mut [RangeCheckFlag],
) {
    if num_lookups == 0 {
        row_flags[*step] = RangeCheckFlag::F0;
        write_trace_row(trace, step, ZERO, ZERO, value);
        return;
    }
    let mut remaining = num_lookups;
    // Rows counting 4 lookups each (selectors s0 = 1, s1 = 1).
    while remaining >= 4 {
        row_flags[*step] = RangeCheckFlag::F3;
        write_trace_row(trace, step, ONE, ONE, value);
        remaining -= 4;
    }
    // At most one row counting 2 lookups (s0 = 0, s1 = 1).
    if remaining >= 2 {
        row_flags[*step] = RangeCheckFlag::F2;
        write_trace_row(trace, step, ZERO, ONE, value);
        remaining -= 2;
    }
    // At most one row counting a single lookup (s0 = 1, s1 = 0).
    if remaining == 1 {
        row_flags[*step] = RangeCheckFlag::F1;
        write_trace_row(trace, step, ONE, ZERO, value);
    }
}
/// Writes a single row (selector bits `s0`, `s1` and the looked-up `value`) into trace
/// columns 1-3 at position `*step`, then advances `*step` to the next row.
fn write_trace_row(trace: &mut [Vec<Felt>], step: &mut usize, s0: Felt, s1: Felt, value: u64) {
    let row = *step;
    trace[1][row] = s0;
    trace[2][row] = s1;
    trace[3][row] = Felt::new(value);
    *step = row + 1;
}
/// Returns the quotient and remainder of `value / divisor`.
///
/// Panics on division by zero; all call sites pass non-zero divisors (255, 4, 2).
fn div_rem(value: usize, divisor: usize) -> (usize, usize) {
    (value / divisor, value % divisor)
}