use {
crate::{
digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder},
noir_to_r1cs::NoirToR1CSCompiler,
},
ark_std::{One, Zero},
provekit_common::{
witness::{ProductLinearTerm, WitnessBuilder, WitnessCoefficient},
FieldElement,
},
std::{
collections::{BTreeMap, HashSet},
ops::Neg,
},
};
/// Smallest digit width (in bits) tried by the base-width search in
/// `get_optimal_base_width`.
const MIN_BASE_WIDTH: u32 = 2;
/// Largest digit width (in bits) tried by the base-width search; bounds the
/// lookup-table size (`1 << width`) that a LogUp bucket may require.
const MAX_BASE_WIDTH: u32 = 17;
/// One requested range check: constrain the value held by witness
/// `witness_idx` to lie in `[0, 2^bits)`.
struct RangeCheckRequest {
    // Index of the witness to be range-checked.
    witness_idx: usize,
    // Required bit width of the witness value.
    bits: u32,
}
/// Decides whether a LogUp lookup argument is cheaper than the naive
/// product-based range check for `count` values of `num_bits` bits each.
///
/// Cost model (in extra witnesses/constraints, matching `bucket_cost`):
/// - LogUp: `3 * 2^num_bits` for the table side, plus one inverse per
///   looked-up value, plus one final balancing constraint.
/// - Naive: each value needs a chain of `2^num_bits - 2` products.
///
/// Returns `false` when `1 << num_bits` would overflow `usize`: neither
/// strategy is feasible at such widths, and callers only reach them through
/// guarded paths anyway.
fn should_use_logup(num_bits: u32, count: usize) -> bool {
    // `checked_shl` avoids the debug-mode panic (and release-mode masked
    // shift) that `1usize << num_bits` exhibits for num_bits >= usize::BITS.
    let table_size = match 1usize.checked_shl(num_bits) {
        Some(size) => size,
        None => return false,
    };
    let logup_cost = table_size
        .saturating_mul(3)
        .saturating_add(count)
        .saturating_add(1);
    let naive_cost = count.saturating_mul(table_size.saturating_sub(2));
    // LogUp is selected only when it is strictly cheaper.
    logup_cost < naive_cost
}
/// Cost (in extra witnesses/constraints) of range-checking `count` values of
/// `num_bits` bits each, using whichever of LogUp or the naive product check
/// is cheaper for this bucket.
///
/// Returns 0 when there is nothing to constrain (zero count or zero bits),
/// and `usize::MAX` when the table size `2^num_bits` would not fit in a
/// `usize`, so the base-width search never selects such widths.
fn bucket_cost(num_bits: u32, count: usize) -> usize {
    if count == 0 || num_bits == 0 {
        return 0;
    }
    if num_bits >= (usize::BITS - 1) {
        return usize::MAX;
    }
    let table_size = 1usize << num_bits;
    // LogUp: 3 witnesses per table entry, one inverse per looked-up value,
    // plus the final balancing constraint.
    let logup_cost = table_size
        .saturating_mul(3)
        .saturating_add(count)
        .saturating_add(1);
    // Naive: a chain of `2^num_bits - 2` product witnesses per value.
    let naive_cost = count.saturating_mul(table_size.saturating_sub(2));
    // `should_use_logup` picks LogUp exactly when it is strictly cheaper, so
    // the selected strategy always costs the minimum of the two.
    logup_cost.min(naive_cost)
}
/// Total witness cost of serving every range check in `collected` when
/// checks wider than `base_width` bits are decomposed into digits of at most
/// `base_width` bits.
///
/// Small checks are tallied directly into a per-width bucket; wide checks
/// contribute one witness per decomposition digit plus atomic checks on each
/// digit (full `base_width`-bit digits and, when the width is not an exact
/// multiple, a single remainder digit).
fn calculate_witness_cost(base_width: u32, collected: &[RangeCheckRequest]) -> usize {
    let mut buckets: BTreeMap<u32, usize> = BTreeMap::new();
    let mut digit_witnesses = 0usize;
    for request in collected {
        if request.bits <= base_width {
            // Narrow enough to check atomically.
            *buckets.entry(request.bits).or_default() += 1;
        } else {
            let full_digits = (request.bits / base_width) as usize;
            let remainder_bits = request.bits % base_width;
            // One witness is materialised per digit of the decomposition.
            digit_witnesses += full_digits + usize::from(remainder_bits > 0);
            *buckets.entry(base_width).or_default() += full_digits;
            if remainder_bits > 0 {
                *buckets.entry(remainder_bits).or_default() += 1;
            }
        }
    }
    // Sum per-bucket costs on top of the decomposition witnesses.
    buckets
        .iter()
        .fold(digit_witnesses, |total, (&bits, &count)| {
            total.saturating_add(bucket_cost(bits, count))
        })
}
/// Scans all allowed digit widths and returns the one minimising the total
/// witness cost over `collected`.
///
/// A width is adopted only on strict improvement, so ties go to the
/// narrowest candidate; the default of 8 bits survives only if every width
/// prices out at `usize::MAX`.
fn get_optimal_base_width(collected: &[RangeCheckRequest]) -> u32 {
    // (best cost so far, width achieving it)
    let mut best: (usize, u32) = (usize::MAX, 8);
    for width in MIN_BASE_WIDTH..=MAX_BASE_WIDTH {
        let cost = calculate_witness_cost(width, collected);
        if cost < best.0 {
            best = (cost, width);
        }
    }
    best.1
}
/// Adds all requested range checks to the R1CS.
///
/// `range_checks` maps a bit width to the witness indices that must be
/// proven to lie in `[0, 2^bits)`. A cost-optimal base width is chosen;
/// checks wider than it are digit-decomposed first, and every resulting
/// bucket of atomic checks is discharged either via a LogUp lookup or via
/// naive product constraints, whichever `should_use_logup` deems cheaper.
///
/// Returns the selected base width, or `None` if there was nothing to check.
pub(crate) fn add_range_checks(
    r1cs: &mut NoirToR1CSCompiler,
    range_checks: BTreeMap<u32, Vec<usize>>,
) -> Option<u32> {
    if range_checks.is_empty() {
        return None;
    }
    // Flatten into individual requests, dropping duplicate witness indices
    // within the same bit width (re-checking a witness proves nothing new).
    let collected: Vec<RangeCheckRequest> = range_checks
        .into_iter()
        .flat_map(|(num_bits, values)| {
            let mut seen = HashSet::new();
            values
                .into_iter()
                .filter(move |v| seen.insert(*v))
                .map(move |witness_idx| RangeCheckRequest {
                    witness_idx,
                    bits: num_bits,
                })
        })
        .collect();
    if collected.is_empty() {
        return None;
    }
    let base_width = get_optimal_base_width(&collected);
    // Buckets indexed by bit width (0..=base_width); each holds groups of
    // witness indices awaiting an atomic range check of that width.
    // (Previously seeded as `vec![vec![vec![]]; max_bucket]`, which put a
    // spurious empty group in every bucket; start each bucket empty instead.)
    let max_bucket = base_width as usize + 1;
    let mut atomic_range_checks: Vec<Vec<Vec<usize>>> = vec![Vec::new(); max_bucket];
    let mut by_bits: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
    for req in &collected {
        by_bits.entry(req.bits).or_default().push(req.witness_idx);
    }
    for (num_bits, values_to_lookup) in by_bits {
        if num_bits > base_width {
            // Too wide for an atomic check: decompose into digits of at most
            // `base_width` bits and range-check each digit instead.
            let num_full_digits = num_bits / base_width;
            let remainder = num_bits % base_width;
            let mut log_bases = vec![base_width as usize; num_full_digits as usize];
            if remainder > 0 {
                log_bases.push(remainder as usize);
            }
            let dd_struct = add_digital_decomposition(r1cs, log_bases, values_to_lookup);
            dd_struct
                .log_bases
                .iter()
                .enumerate()
                .map(|(digit_place, log_base)| {
                    (
                        *log_base as u32,
                        (0..dd_struct.num_witnesses_to_decompose)
                            .map(|i| dd_struct.get_digit_witness_index(digit_place, i))
                            .collect::<Vec<_>>(),
                    )
                })
                .for_each(|(log_base, digit_witnesses)| {
                    atomic_range_checks[log_base as usize].push(digit_witnesses);
                });
        } else {
            atomic_range_checks[num_bits as usize].push(values_to_lookup);
        }
    }
    atomic_range_checks
        .iter()
        .enumerate()
        .for_each(|(num_bits, all_values_to_lookup)| {
            // Merge all groups for this width and deduplicate across them
            // before picking a checking strategy.
            let values_to_lookup: Vec<usize> = {
                let mut seen = HashSet::new();
                all_values_to_lookup
                    .iter()
                    .flat_map(|v| v.iter())
                    .copied()
                    .filter(|v| seen.insert(*v))
                    .collect()
            };
            if values_to_lookup.is_empty() {
                return;
            }
            let num_bits = num_bits as u32;
            if should_use_logup(num_bits, values_to_lookup.len()) {
                add_range_check_via_lookup(r1cs, num_bits, &values_to_lookup);
            } else {
                values_to_lookup.iter().for_each(|value| {
                    add_naive_range_check(r1cs, num_bits, *value);
                })
            }
        });
    Some(base_width)
}
/// Range-checks every witness in `values_to_lookup` against the table
/// `{0, 1, ..., 2^num_bits - 1}` using a LogUp-style lookup argument.
///
/// Builds: one multiplicity witness per table entry, a challenge witness
/// `z`, per-table-entry quotients `m_t / (z - t)`, and per-value inverses
/// `1 / (z - v)`. The final constraint enforces
/// `sum_t m_t / (z - t) - sum_v 1 / (z - v) = 0`.
fn add_range_check_via_lookup(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    num_bits: u32,
    values_to_lookup: &[usize],
) {
    // Multiplicities: how often each table value appears among the inputs.
    let wb = WitnessBuilder::MultiplicitiesForRange(
        r1cs_compiler.num_witnesses(),
        1 << num_bits,
        values_to_lookup.into(),
    );
    let multiplicities_first_witness = r1cs_compiler.add_witness_builder(wb);
    let sz_challenge =
        r1cs_compiler.add_witness_builder(WitnessBuilder::Challenge(r1cs_compiler.num_witnesses()));
    // Table side of the LogUp sum: +1 coefficient per quotient m_t/(z - t).
    let mut logup_summands: Vec<(FieldElement, usize)> = (0..(1 << num_bits))
        .map(|table_value| {
            // Multiplicity witnesses are laid out contiguously from the
            // first index, in table order.
            let multiplicity_witness = multiplicities_first_witness + table_value;
            (
                FieldElement::one(),
                add_range_table_entry_quotient(
                    r1cs_compiler,
                    sz_challenge,
                    table_value as u64,
                    multiplicity_witness,
                ),
            )
        })
        .collect();
    // Value side: -1 coefficient per inverse 1/(z - v).
    for value in values_to_lookup {
        let witness_idx =
            add_lookup_factor(r1cs_compiler, sz_challenge, FieldElement::one(), *value);
        logup_summands.push((FieldElement::one().neg(), witness_idx));
    }
    // (sum of summands) * 1 = 0: table and value sides must cancel exactly.
    r1cs_compiler.r1cs.add_constraint(
        &logup_summands,
        &[(FieldElement::one(), r1cs_compiler.witness_one())],
        &[(FieldElement::zero(), r1cs_compiler.witness_one())],
    );
}
/// Adds a witness holding `1 / (sz_challenge - value_coeff * value_witness)`
/// and constrains it by
/// `(sz_challenge - value_coeff * value_witness) * inverse = 1`.
///
/// Returns the index of the new inverse witness.
pub(crate) fn add_lookup_factor(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    sz_challenge: usize,
    value_coeff: FieldElement,
    value_witness: usize,
) -> usize {
    let inverse = r1cs_compiler.add_witness_builder(WitnessBuilder::LogUpInverse(
        r1cs_compiler.num_witnesses(),
        sz_challenge,
        WitnessCoefficient(value_coeff, value_witness),
    ));
    // A = z - coeff * v, B = inverse, C = 1; forces inverse = 1 / (z - coeff*v)
    // and, implicitly, z != coeff * v.
    r1cs_compiler.r1cs.add_constraint(
        &[
            (FieldElement::one(), sz_challenge),
            (value_coeff.neg(), value_witness),
        ],
        &[(FieldElement::one(), inverse)],
        &[(FieldElement::one(), r1cs_compiler.witness_one())],
    );
    inverse
}
/// Range-checks a single witness `x` (at `index_witness`) to `[0, 2^num_bits)`
/// by building the product chain `x * (x - 1) * ... * (x - (2^num_bits - 1))`
/// and constraining the full product to zero — it vanishes iff `x` equals one
/// of the table values.
///
/// Assumes `num_bits` is small enough that `1 << num_bits` fits in `u32`
/// (widths here never exceed `MAX_BASE_WIDTH`).
fn add_naive_range_check(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    num_bits: u32,
    index_witness: usize,
) {
    // Running product starts as x itself (the factor for table value 0).
    let mut current_product_witness = index_witness;
    // Multiply in factors (x - 1) .. (x - (2^num_bits - 2)), one witness and
    // one constraint per factor: current * (x - index) = next.
    (1..(1 << num_bits) - 1).for_each(|index: u32| {
        let next_product_witness =
            r1cs_compiler.add_witness_builder(WitnessBuilder::ProductLinearOperation(
                r1cs_compiler.num_witnesses(),
                ProductLinearTerm(
                    current_product_witness,
                    FieldElement::one(),
                    FieldElement::zero(),
                ),
                ProductLinearTerm(
                    index_witness,
                    FieldElement::one(),
                    FieldElement::from(index).neg(),
                ),
            ));
        r1cs_compiler.r1cs.add_constraint(
            &[(FieldElement::one(), current_product_witness)],
            &[
                (FieldElement::one(), index_witness),
                (FieldElement::from(index).neg(), r1cs_compiler.witness_one()),
            ],
            &[(FieldElement::one(), next_product_witness)],
        );
        current_product_witness = next_product_witness;
    });
    // Final factor (x - (2^num_bits - 1)) needs no new witness: constrain
    // current * (x - (2^num_bits - 1)) = 0 directly.
    r1cs_compiler.r1cs.add_constraint(
        &[(FieldElement::one(), current_product_witness)],
        &[
            (FieldElement::one(), index_witness),
            (
                FieldElement::from((1 << num_bits) - 1_u32).neg(),
                r1cs_compiler.witness_one(),
            ),
        ],
        &[(FieldElement::zero(), r1cs_compiler.witness_one())],
    );
}
/// Adds the table-side LogUp quotient `m / (z - t)` for one table entry `t`
/// with multiplicity witness `m` and challenge witness `z`.
///
/// Materialises an inverse witness for `1 / (z - t)`, a product witness
/// `quotient = m * inverse`, and the constraint `(z - t) * quotient = m`
/// that ties the quotient to the multiplicity. Returns the quotient's index.
fn add_range_table_entry_quotient(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    sz_challenge: usize,
    table_value: u64,
    multiplicity_witness: usize,
) -> usize {
    // Inverse of (z - t); the table value enters as a constant via the
    // one-witness.
    let inverse = r1cs_compiler.add_witness_builder(WitnessBuilder::LogUpInverse(
        r1cs_compiler.num_witnesses(),
        sz_challenge,
        WitnessCoefficient(FieldElement::from(table_value), r1cs_compiler.witness_one()),
    ));
    let quotient = r1cs_compiler.add_witness_builder(WitnessBuilder::Product(
        r1cs_compiler.num_witnesses(),
        multiplicity_witness,
        inverse,
    ));
    // A = z - t, B = quotient, C = m: enforces quotient = m / (z - t).
    r1cs_compiler.r1cs.add_constraint(
        &[
            (FieldElement::one(), sz_challenge),
            (
                FieldElement::from(table_value).neg(),
                r1cs_compiler.witness_one(),
            ),
        ],
        &[(FieldElement::one(), quotient)],
        &[(FieldElement::one(), multiplicity_witness)],
    );
    quotient
}
#[cfg(test)]
mod tests {
    use super::*;
    // Degenerate buckets cost nothing: zero bit width or zero values.
    #[test]
    fn bucket_cost_zero_cases() {
        assert_eq!(bucket_cost(0, 100), 0);
        assert_eq!(bucket_cost(5, 0), 0);
    }
    // Widths whose table size would overflow `usize` price out at MAX so the
    // base-width search can never select them.
    #[test]
    fn bucket_cost_overflow_guard() {
        assert_eq!(bucket_cost(63, 1), usize::MAX);
    }
    // LogUp wins only when its setup cost (3 * 2^bits + count + 1) strictly
    // beats the naive per-value cost (count * (2^bits - 2)):
    // (8, 5): 774 < 1270 -> LogUp; (8, 1): 770 vs 254 -> naive.
    #[test]
    fn should_use_logup_decision() {
        assert!(!should_use_logup(1, 1));
        assert!(should_use_logup(8, 5));
        assert!(!should_use_logup(8, 1));
        assert!(should_use_logup(8, 256));
    }
}