use {
crate::{
digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder},
noir_to_r1cs::NoirToR1CSCompiler,
uints::U8,
},
ark_ff::PrimeField,
ark_std::One,
provekit_common::{
witness::{ConstantOrR1CSWitness, SumTerm, WitnessBuilder},
FieldElement,
},
std::{
collections::{BTreeMap, BTreeSet, HashMap},
ops::Neg,
},
};
/// Bitwise byte operations that this module checks through the combined
/// AND/XOR lookup.
#[derive(Debug, Clone, Copy)]
pub enum BinOp {
    /// Bitwise AND of the two operands.
    And,
    /// Bitwise XOR of the two operands.
    Xor,
}
/// Witness indices of the challenge values shared by every term of the
/// combined AND/XOR lookup argument.
///
/// Both the per-operation denominators and the per-table-entry factors have
/// the form `sz - (lhs + rs * rhs + rs_sqrd * and_out + rs_cubed * xor_out)`.
struct LookupChallenges {
// Index of the challenge witness from which the combined entry is subtracted.
sz: usize,
// Index of the challenge witness that multiplies the right-hand operand.
rs: usize,
// Index of the witness holding `rs * rs`; multiplies the AND output.
rs_sqrd: usize,
// Index of the witness holding `rs_sqrd * rs`; multiplies the XOR output.
rs_cubed: usize,
}
/// Per-operand-pair bookkeeping used while merging the AND and XOR operation
/// lists, laid out as `(and_output, xor_output, lhs, rhs)`.
///
/// Each output slot is `None` until an operation of that kind has been seen
/// for the pair; the operands are stored so a missing output can be created
/// later.
type PairMapEntry = (
// Witness index of the AND output, if one was recorded for this pair.
Option<usize>,
// Witness index of the XOR output, if one was recorded for this pair.
Option<usize>,
// Left-hand operand of the pair.
ConstantOrR1CSWitness,
// Right-hand operand of the pair.
ConstantOrR1CSWitness,
);
/// Estimates how many witnesses the combined AND/XOR lookup needs when bytes
/// are split into `w`-bit digits and `n` distinct operand pairs are checked.
///
/// The estimate is the sum of:
/// * `3 * 2^(2w)` for the lookup table,
/// * `4 * n * d` for the per-digit lookups, with `d = ceil(8 / w)`,
/// * another `4 * n * d` for digit decomposition, charged only when `w != 8`,
/// * `n` for the complementary output of each pair,
/// * a fixed overhead of `6`.
///
/// # Panics
///
/// Panics if `w` is not one of `2`, `4` or `8`.
fn calculate_binop_witness_cost(w: u32, n: usize) -> usize {
    assert!(
        matches!(w, 2 | 4 | 8),
        "width must be in {{2, 4, 8}} to evenly divide 8, got {w}"
    );

    let digits_per_byte = 8u32.div_ceil(w) as usize;
    let table_cost = 3 * (1usize << (2 * w));
    let per_digit_cost = 4 * n * digits_per_byte;
    // With a full-byte table the operands are looked up directly, so no
    // decomposition cost is charged.
    let decomposition_cost = if w == 8 { 0 } else { per_digit_cost };

    table_cost + per_digit_cost + decomposition_cost + n + 6
}
fn get_optimal_binop_width(n: usize) -> u32 {
[2u32, 4, 8]
.into_iter()
.min_by_key(|&w| calculate_binop_witness_cost(w, n))
.unwrap()
}
/// Returns the `digit_i`-th `atomic_bits`-wide digit of `cow`.
///
/// * For a constant, the digit is extracted from the lowest 64-bit limb of
///   the constant and returned as a new constant.
/// * For a witness, the digit's witness index is read from `dd`, using the
///   position recorded for that witness in `witness_to_offset`.
///
/// # Panics
///
/// Panics if `cow` is a witness that has no entry in `witness_to_offset`.
fn cow_to_digit(
    cow: ConstantOrR1CSWitness,
    digit_i: usize,
    atomic_bits: u32,
    dd: &provekit_common::witness::DigitalDecompositionWitnesses,
    witness_to_offset: &HashMap<usize, usize>,
) -> ConstantOrR1CSWitness {
    match cow {
        ConstantOrR1CSWitness::Witness(witness_idx) => {
            let position = witness_to_offset[&witness_idx];
            let digit_witness = dd.get_digit_witness_index(digit_i, position);
            ConstantOrR1CSWitness::Witness(digit_witness)
        }
        ConstantOrR1CSWitness::Constant(constant) => {
            // Only the least-significant limb is inspected.
            let low_limb = constant.into_bigint().0[0];
            let shifted = low_limb >> (digit_i as u64 * atomic_bits as u64);
            let digit_mask = (1u64 << atomic_bits) - 1;
            ConstantOrR1CSWitness::Constant(FieldElement::from(shifted & digit_mask))
        }
    }
}
/// Allocates the witness for `a <op> b` and records the operation in `ops`
/// so it can be constrained later.
///
/// No constraint is added here; only a witness builder is registered and the
/// `(lhs, rhs, result)` triple is appended to `ops`.
///
/// In debug builds this asserts that both inputs are flagged as
/// range-checked. The returned `U8` is created with its flag set to `true`.
pub(crate) fn add_byte_binop(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    op: BinOp,
    ops: &mut Vec<(ConstantOrR1CSWitness, ConstantOrR1CSWitness, usize)>,
    a: U8,
    b: U8,
) -> U8 {
    debug_assert!(
        a.range_checked && b.range_checked,
        "Byte binop requires inputs to be range-checked U8s"
    );

    let lhs = ConstantOrR1CSWitness::Witness(a.idx);
    let rhs = ConstantOrR1CSWitness::Witness(b.idx);

    // The builder is told the index it is about to occupy.
    let next_idx = r1cs_compiler.num_witnesses();
    let builder = match op {
        BinOp::And => WitnessBuilder::And(next_idx, lhs, rhs),
        BinOp::Xor => WitnessBuilder::Xor(next_idx, lhs, rhs),
    };
    let result = r1cs_compiler.add_witness_builder(builder);

    ops.push((lhs, rhs, result));
    U8::new(result, true)
}
/// Adds the constraints that check all recorded AND and XOR operations with a
/// single combined lookup.
///
/// Each entry of `and_ops` / `xor_ops` is `(lhs, rhs, output_witness)`.
/// Operations are merged per `(lhs, rhs)` operand pair so that every pair is
/// looked up once as `(lhs, rhs, lhs & rhs, lhs ^ rhs)`; a missing AND or XOR
/// output is created with a fresh witness builder.
///
/// The digit width is chosen by `get_optimal_binop_width`. For widths below 8
/// the witness operands and outputs are split into digits and the lookup is
/// performed per digit.
///
/// The final constraint equates the sum over all lookups of
/// `1 / denominator` with the sum over all table entries of
/// `multiplicity / (table-entry factor)`.
///
/// Returns `None` (adding nothing) when both lists are empty, otherwise
/// `Some(width)` with the chosen digit width.
pub(crate) fn add_combined_binop_constraints(
r1cs_compiler: &mut NoirToR1CSCompiler,
and_ops: Vec<(ConstantOrR1CSWitness, ConstantOrR1CSWitness, usize)>,
xor_ops: Vec<(ConstantOrR1CSWitness, ConstantOrR1CSWitness, usize)>,
) -> Option<u32> {
if and_ops.is_empty() && xor_ops.is_empty() {
return None;
}
// Orderable key for an operand so that operand pairs can index a `BTreeMap`.
// Constants are keyed by their big-integer limbs.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum OperandKey {
Witness(usize),
Constant([u64; 4]),
}
fn operand_key(op: &ConstantOrR1CSWitness) -> OperandKey {
match op {
ConstantOrR1CSWitness::Witness(idx) => OperandKey::Witness(*idx),
ConstantOrR1CSWitness::Constant(fe) => OperandKey::Constant(fe.into_bigint().0),
}
}
// Maps (lhs, rhs) -> (AND output, XOR output, lhs, rhs). The key is ordered,
// so (a, b) and (b, a) are distinct entries.
let mut pair_map: BTreeMap<(OperandKey, OperandKey), PairMapEntry> = BTreeMap::new();
for (lhs, rhs, and_out) in &and_ops {
let key = (operand_key(lhs), operand_key(rhs));
pair_map
.entry(key)
.and_modify(|e| {
// The pair already has an AND output: constrain `existing * 1 == and_out`
// so both outputs agree, then keep the newest one as the representative.
if let Some(existing) = e.0 {
r1cs_compiler.r1cs.add_constraint(
&[(FieldElement::one(), existing)],
&[(FieldElement::one(), r1cs_compiler.witness_one())],
&[(FieldElement::one(), *and_out)],
);
}
e.0 = Some(*and_out);
})
.or_insert((Some(*and_out), None, *lhs, *rhs));
}
for (lhs, rhs, xor_out) in &xor_ops {
let key = (operand_key(lhs), operand_key(rhs));
pair_map
.entry(key)
.and_modify(|e| {
// Same treatment as above, for the XOR slot of the entry.
if let Some(existing) = e.1 {
r1cs_compiler.r1cs.add_constraint(
&[(FieldElement::one(), existing)],
&[(FieldElement::one(), r1cs_compiler.witness_one())],
&[(FieldElement::one(), *xor_out)],
);
}
e.1 = Some(*xor_out);
})
.or_insert((None, Some(*xor_out), *lhs, *rhs));
}
// Complete every pair to a full (lhs, rhs, and_out, xor_out) tuple, creating
// the output that was not requested by the caller.
let mut combined_ops_atomic = Vec::with_capacity(pair_map.len());
for (_key, (and_opt, xor_opt, lhs, rhs)) in pair_map {
let and_out = and_opt.unwrap_or_else(|| {
r1cs_compiler.add_witness_builder(WitnessBuilder::And(
r1cs_compiler.num_witnesses(),
lhs,
rhs,
))
});
let xor_out = xor_opt.unwrap_or_else(|| {
r1cs_compiler.add_witness_builder(WitnessBuilder::Xor(
r1cs_compiler.num_witnesses(),
lhs,
rhs,
))
});
combined_ops_atomic.push((lhs, rhs, and_out, xor_out));
}
// Choose the digit width from the number of distinct operand pairs.
let atomic_bits = get_optimal_binop_width(combined_ops_atomic.len());
let lookup_ops: Vec<(
ConstantOrR1CSWitness,
ConstantOrR1CSWitness,
ConstantOrR1CSWitness,
ConstantOrR1CSWitness,
)> = if atomic_bits == 8 {
// Full-byte table: look the operands and outputs up directly.
combined_ops_atomic
.iter()
.map(|(lhs, rhs, and_out, xor_out)| {
(
*lhs,
*rhs,
ConstantOrR1CSWitness::Witness(*and_out),
ConstantOrR1CSWitness::Witness(*xor_out),
)
})
.collect()
} else {
// Narrower table: every byte is split into `d` digits of `atomic_bits` bits.
let d = 8u32.div_ceil(atomic_bits) as usize;
let log_bases = vec![atomic_bits as usize; d];
// Collect each distinct witness once; the sorted set fixes the order in
// which witnesses are handed to the decomposition.
let mut witness_set: BTreeSet<usize> = BTreeSet::new();
for (lhs, rhs, and_out, xor_out) in &combined_ops_atomic {
if let ConstantOrR1CSWitness::Witness(w) = lhs {
witness_set.insert(*w);
}
if let ConstantOrR1CSWitness::Witness(w) = rhs {
witness_set.insert(*w);
}
witness_set.insert(*and_out);
witness_set.insert(*xor_out);
}
let witness_bytes: Vec<usize> = witness_set.into_iter().collect();
// Position of each witness in `witness_bytes`, used to address its digits.
let witness_to_offset: HashMap<usize, usize> = witness_bytes
.iter()
.enumerate()
.map(|(i, &w)| (w, i))
.collect();
// NOTE(review): presumably this also constrains each witness to equal the
// recomposition of its digits — confirm against `add_digital_decomposition`.
let dd = add_digital_decomposition(r1cs_compiler, log_bases, witness_bytes);
// One lookup per digit position of every operand pair.
let mut digit_ops = Vec::with_capacity(combined_ops_atomic.len() * d);
for (lhs, rhs, and_out, xor_out) in &combined_ops_atomic {
for digit_i in 0..d {
let lhs_digit = cow_to_digit(*lhs, digit_i, atomic_bits, &dd, &witness_to_offset);
let rhs_digit = cow_to_digit(*rhs, digit_i, atomic_bits, &dd, &witness_to_offset);
let and_digit = ConstantOrR1CSWitness::Witness(
dd.get_digit_witness_index(digit_i, witness_to_offset[and_out]),
);
let xor_digit = ConstantOrR1CSWitness::Witness(
dd.get_digit_witness_index(digit_i, witness_to_offset[xor_out]),
);
digit_ops.push((lhs_digit, rhs_digit, and_digit, xor_digit));
}
}
digit_ops
};
// Multiplicity witnesses for the table, built from the (lhs, rhs) of every
// lookup. Table entry (lhs, rhs) is addressed below at offset
// `(lhs << atomic_bits) + rhs` from the first multiplicity witness.
let multiplicities_wb = WitnessBuilder::MultiplicitiesForBinOp(
r1cs_compiler.num_witnesses(),
atomic_bits,
lookup_ops.iter().map(|(lh, rh, ..)| (*lh, *rh)).collect(),
);
let multiplicities_first_witness = r1cs_compiler.add_witness_builder(multiplicities_wb);
// Challenges are allocated after the multiplicities; `rs` squared and cubed
// are materialised as product witnesses.
let sz =
r1cs_compiler.add_witness_builder(WitnessBuilder::Challenge(r1cs_compiler.num_witnesses()));
let rs =
r1cs_compiler.add_witness_builder(WitnessBuilder::Challenge(r1cs_compiler.num_witnesses()));
let rs_sqrd = r1cs_compiler.add_product(rs, rs);
let rs_cubed = r1cs_compiler.add_product(rs_sqrd, rs);
let challenges = LookupChallenges {
sz,
rs,
rs_sqrd,
rs_cubed,
};
// Lookup side: sum of the inverse denominators of all lookups.
let summands_for_ops = lookup_ops
.into_iter()
.map(|(lhs, rhs, and_out, xor_out)| {
add_combined_lookup_summand(r1cs_compiler, &challenges, lhs, rhs, and_out, xor_out)
})
.map(|coeff| SumTerm(None, coeff))
.collect();
let sum_for_ops = r1cs_compiler.add_sum(summands_for_ops);
// Table side: enumerate every (lhs, rhs) of the chosen width together with
// its AND and XOR results, and sum the per-entry quotients.
let summands_for_table = (0..1u32 << atomic_bits)
.flat_map(|lhs| (0..1u32 << atomic_bits).map(move |rhs| (lhs, rhs, lhs & rhs, lhs ^ rhs)))
.map(|(lhs, rhs, and_out, xor_out)| {
let multiplicity_idx =
multiplicities_first_witness + ((lhs << atomic_bits) as usize) + rhs as usize;
add_table_entry_quotient(
r1cs_compiler,
&challenges,
lhs,
rhs,
and_out,
xor_out,
multiplicity_idx,
)
})
.map(|quotient| SumTerm(None, quotient))
.collect();
let sum_for_table = r1cs_compiler.add_sum(summands_for_table);
// Final check: `1 * sum_for_ops == sum_for_table`.
r1cs_compiler.r1cs.add_constraint(
&[(FieldElement::one(), r1cs_compiler.witness_one())],
&[(FieldElement::one(), sum_for_ops)],
&[(FieldElement::one(), sum_for_table)],
);
Some(atomic_bits)
}
/// Adds the witnesses and the constraint for one table entry
/// `(lhs, rhs, and_out, xor_out)` of the combined lookup, returning the index
/// of the entry's quotient witness.
///
/// Two witnesses are allocated, in this order: a
/// `CombinedTableEntryInverse` witness for the entry, then a `Product`
/// witness of `multiplicity_witness` and that inverse (the quotient).
///
/// The quotient is tied to the multiplicity by the constraint
/// `(sz - lhs - rhs * rs - and_out * rs_sqrd - xor_out * rs_cubed) * quotient
/// == multiplicity`.
fn add_table_entry_quotient(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    c: &LookupChallenges,
    lhs: u32,
    rhs: u32,
    and_out: u32,
    xor_out: u32,
    multiplicity_witness: usize,
) -> usize {
    use provekit_common::witness::CombinedTableEntryInverseData;

    let lhs_fe = FieldElement::from(lhs);
    let rhs_fe = FieldElement::from(rhs);
    let and_fe = FieldElement::from(and_out);
    let xor_fe = FieldElement::from(xor_out);

    let inverse_idx = r1cs_compiler.num_witnesses();
    let inverse_builder = WitnessBuilder::CombinedTableEntryInverse(CombinedTableEntryInverseData {
        idx: inverse_idx,
        sz_challenge: c.sz,
        rs_challenge: c.rs,
        rs_sqrd: c.rs_sqrd,
        rs_cubed: c.rs_cubed,
        lhs: lhs_fe,
        rhs: rhs_fe,
        and_out: and_fe,
        xor_out: xor_fe,
    });
    let inverse = r1cs_compiler.add_witness_builder(inverse_builder);

    let quotient_idx = r1cs_compiler.num_witnesses();
    let quotient = r1cs_compiler.add_witness_builder(WitnessBuilder::Product(
        quotient_idx,
        multiplicity_witness,
        inverse,
    ));

    let one = FieldElement::one();
    let constant_one = r1cs_compiler.witness_one();
    // The constant `lhs` term is attached to the constant-one witness.
    let entry_factor = [
        (one, c.sz),
        (lhs_fe.neg(), constant_one),
        (rhs_fe.neg(), c.rs),
        (and_fe.neg(), c.rs_sqrd),
        (xor_fe.neg(), c.rs_cubed),
    ];
    r1cs_compiler.r1cs.add_constraint(
        &entry_factor,
        &[(one, quotient)],
        &[(one, multiplicity_witness)],
    );

    quotient
}
/// Adds the witnesses and constraints for one lookup
/// `(lhs, rhs, and_out, xor_out)` and returns the index of the witness
/// holding the inverse of its denominator.
///
/// Witnesses are allocated in this order: the denominator, a product
/// `rs_sqrd * and_out` (only if `and_out` is a witness), a product
/// `rs_cubed * xor_out` (only if `xor_out` is a witness), and the inverse.
///
/// Two constraints are added:
/// * `(-rs) * rhs == denominator - sz + lhs + rs_sqrd * and_out
///   + rs_cubed * xor_out`, which fixes the denominator;
/// * `denominator * inverse == 1`.
fn add_combined_lookup_summand(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    c: &LookupChallenges,
    lhs: ConstantOrR1CSWitness,
    rhs: ConstantOrR1CSWitness,
    and_out: ConstantOrR1CSWitness,
    xor_out: ConstantOrR1CSWitness,
) -> usize {
    let denominator_idx = r1cs_compiler.num_witnesses();
    let denominator =
        r1cs_compiler.add_witness_builder(WitnessBuilder::CombinedBinOpLookupDenominator(
            denominator_idx,
            c.sz,
            c.rs,
            c.rs_sqrd,
            c.rs_cubed,
            lhs,
            rhs,
            and_out,
            xor_out,
        ));

    // A constant output becomes the coefficient of the challenge power; a
    // witness output needs a dedicated product witness.
    let mut scaled_output_term = |power_idx: usize, output: ConstantOrR1CSWitness| match output {
        ConstantOrR1CSWitness::Constant(value) => (value, power_idx),
        ConstantOrR1CSWitness::Witness(witness) => (
            FieldElement::one(),
            r1cs_compiler.add_product(power_idx, witness),
        ),
    };
    let and_term = scaled_output_term(c.rs_sqrd, and_out);
    let xor_term = scaled_output_term(c.rs_cubed, xor_out);

    let one = FieldElement::one();
    let minus_one = FieldElement::one().neg();
    r1cs_compiler.r1cs.add_constraint(
        &[(minus_one, c.rs)],
        &[rhs.to_tuple()],
        &[
            (one, denominator),
            (minus_one, c.sz),
            lhs.to_tuple(),
            and_term,
            xor_term,
        ],
    );

    let inverse_idx = r1cs_compiler.num_witnesses();
    let inverse =
        r1cs_compiler.add_witness_builder(WitnessBuilder::Inverse(inverse_idx, denominator));
    let constant_one = r1cs_compiler.witness_one();
    r1cs_compiler.r1cs.add_constraint(
        &[(one, denominator)],
        &[(one, inverse)],
        &[(one, constant_one)],
    );

    inverse
}
#[cfg(test)]
mod tests {
    use {super::*, crate::digits::add_digital_decomposition};

    /// Returns `true` when every row satisfies `(A·w) * (B·w) == C·w`.
    fn constraints_satisfied(r1cs: &provekit_common::R1CS, witness: &[FieldElement]) -> bool {
        let a_rows = r1cs.a() * witness;
        let b_rows = r1cs.b() * witness;
        let c_rows = r1cs.c() * witness;
        for ((a_val, b_val), c_val) in a_rows.iter().zip(b_rows.iter()).zip(c_rows.iter()) {
            if *a_val * *b_val != *c_val {
                return false;
            }
        }
        true
    }

    /// Splits `value` into `d` little-endian digits of `w` bits each.
    fn decompose(value: u64, w: u32, d: usize) -> Vec<u64> {
        let mask = (1u64 << w) - 1;
        let mut digits = Vec::with_capacity(d);
        for i in 0..d {
            digits.push((value >> (i as u64 * w as u64)) & mask);
        }
        digits
    }

    #[test]
    fn optimal_binop_width_always_divides_8() {
        (1..=1024).for_each(|n| {
            let w = get_optimal_binop_width(n);
            assert!(
                8 % w == 0,
                "get_optimal_binop_width({n}) returned {w}, which does not divide 8"
            );
        });
    }

    #[test]
    fn non_canonical_byte_rejected_by_recomposition() {
        const CANONICAL_BYTE: u64 = 200;
        const OVERSIZED_BYTE: u64 = 256;

        for atomic_bits in [2u32, 4] {
            let d = (8 / atomic_bits) as usize;
            let log_bases = vec![atomic_bits as usize; d];

            // A compiler holding a single constant byte witness.
            let mut compiler = NoirToR1CSCompiler::new();
            let byte_idx = compiler.num_witnesses();
            compiler.r1cs.add_witnesses(1);
            compiler.witness_builders.push(WitnessBuilder::Constant(
                provekit_common::witness::ConstantTerm(
                    byte_idx,
                    FieldElement::from(CANONICAL_BYTE),
                ),
            ));
            let dd = add_digital_decomposition(&mut compiler, log_bases, vec![byte_idx]);

            // Hand-built assignment: constant one, the byte, and its digits.
            let witness_count = compiler.num_witnesses();
            let mut witness = vec![FieldElement::from(0u64); witness_count];
            witness[0] = FieldElement::from(1u64);
            witness[byte_idx] = FieldElement::from(CANONICAL_BYTE);
            let digits = decompose(CANONICAL_BYTE, atomic_bits, d);
            for (i, digit) in digits.into_iter().enumerate() {
                witness[dd.get_digit_witness_index(i, 0)] = FieldElement::from(digit);
            }
            assert!(
                constraints_satisfied(&compiler.r1cs, &witness),
                "Canonical byte=200 with w={atomic_bits} must satisfy recomposition"
            );

            // Keep the digits of 200 but claim the byte is 256.
            witness[byte_idx] = FieldElement::from(OVERSIZED_BYTE);
            assert!(
                !constraints_satisfied(&compiler.r1cs, &witness),
                "byte=256 with digits for 200 (w={atomic_bits}) must NOT satisfy recomposition"
            );

            let digit_max = (1u64 << atomic_bits) - 1;
            let max_representable: u64 = (0..d)
                .map(|i| digit_max << (i as u64 * atomic_bits as u64))
                .sum();
            assert_eq!(
                max_representable, 255,
                "Max representable value with w={atomic_bits}, d={d} digits must be exactly 255"
            );
        }
    }
}