use {
crate::noir_to_r1cs::NoirToR1CSCompiler,
ark_ff::{Field, Zero},
provekit_common::{
witness::{compute_spread, ConstantOrR1CSWitness, SumTerm, WitnessBuilder},
FieldElement,
},
std::{
collections::{BTreeMap, HashMap},
ops::Neg,
},
};
// Chunk widths (in bits, LSB-first) used to decompose 32-bit words before
// applying the SHA-256 sigma functions in the spread domain.  Each chunk
// boundary sits at a rotation/shift amount of the corresponding function, so
// rotations/shifts become simple chunk reorderings/drops:
//
// Σ0: ROTR 2, 13, 22 → cumulative boundaries at 2, 13, 22, 32.
pub(crate) const SIGMA0_CHUNKS: [u32; 4] = [2, 11, 9, 10];
// Σ1: ROTR 6, 11, 25 → boundaries at 6, 11, 25, 32.
pub(crate) const SIGMA1_CHUNKS: [u32; 4] = [6, 5, 14, 7];
// σ0: ROTR 7, 18 and SHR 3 → boundaries at 3, 7, 18, 32.
pub(crate) const SMALL_SIGMA0_CHUNKS: [u32; 4] = [3, 4, 11, 14];
// σ1: ROTR 17, 19 and SHR 10 → boundaries at 10, 17, 19, 32.
pub(crate) const SMALL_SIGMA1_CHUNKS: [u32; 4] = [10, 7, 2, 13];
// Plain byte-granular decomposition of a 32-bit word.
pub(crate) const BYTE_CHUNKS: [u32; 4] = [8, 8, 8, 8];
/// Splits a chunk of `bits` bits into sub-chunks of at most `w` bits each,
/// LSB-first: `bits / w` full-width pieces followed by an optional remainder
/// piece of `bits % w` bits.
fn subchunks(bits: u32, w: u32) -> Vec<u32> {
    if bits <= w {
        return vec![bits];
    }
    let remainder = bits % w;
    std::iter::repeat(w)
        .take((bits / w) as usize)
        .chain((remainder > 0).then_some(remainder))
        .collect()
}
/// One chunk of a 32-bit word, further split into table-width sub-chunks.
#[derive(Clone, Debug)]
pub(crate) struct SpreadChunk {
    /// Width of the whole chunk in bits (sum of `sub_bits`).
    pub total_bits: u32,
    /// Witness indices of the plain (unspread) sub-chunk values, LSB-first.
    pub sub_values: Vec<usize>,
    /// Witness indices of the spread encoding of each sub-chunk.
    pub sub_spreads: Vec<usize>,
    /// Bit width of each sub-chunk; parallel to `sub_values`/`sub_spreads`.
    pub sub_bits: Vec<u32>,
}
/// A 32-bit word together with its chunk decomposition and the spread
/// encodings of all sub-chunks.
#[derive(Clone, Debug)]
pub(crate) struct SpreadWord {
    /// Witness index holding the packed 32-bit value.
    pub packed: usize,
    /// Chunks in LSB-first order.
    pub chunks: Vec<SpreadChunk>,
}
impl SpreadWord {
pub fn spread_terms_for_rotation(&self, start_chunk: usize) -> Vec<SumTerm> {
let n = self.chunks.len();
let mut terms = Vec::new();
let mut bit_offset: u32 = 0;
for i in 0..n {
let chunk_idx = (start_chunk + i) % n;
let chunk = &self.chunks[chunk_idx];
let mut sub_offset = bit_offset;
for (j, &sub_bits) in chunk.sub_bits.iter().enumerate() {
let coeff = FieldElement::from(1u64 << (2 * sub_offset));
terms.push(SumTerm(Some(coeff), chunk.sub_spreads[j]));
sub_offset += sub_bits;
}
bit_offset += chunk.total_bits;
}
terms
}
pub fn spread_terms_for_shift(&self, dropped_chunks: usize) -> Vec<SumTerm> {
let mut terms = Vec::new();
let mut bit_offset: u32 = 0;
for chunk in self.chunks.iter().skip(dropped_chunks) {
let mut sub_offset = bit_offset;
for (j, &sub_bits) in chunk.sub_bits.iter().enumerate() {
let coeff = FieldElement::from(1u64 << (2 * sub_offset));
terms.push(SumTerm(Some(coeff), chunk.sub_spreads[j]));
sub_offset += sub_bits;
}
bit_offset += chunk.total_bits;
}
terms
}
pub fn spread_identity(&self) -> Vec<SumTerm> {
self.spread_terms_for_rotation(0)
}
}
/// Collects all spread-table lookups and pending range checks produced while
/// compiling SHA-256 gadgets, so they can be discharged in one batch at the
/// end (see `add_spread_table_constraints`).
pub(crate) struct SpreadAccumulator {
    /// Bit width `w` of the spread lookup table (the table has 2^w rows).
    pub table_bits: u32,
    /// (value, spread(value)) pairs that must each appear in the table.
    pub lookups: Vec<(ConstantOrR1CSWitness, ConstantOrR1CSWitness)>,
    /// Memoizes value-witness → spread-witness so a value is only spread once.
    pub spread_cache: HashMap<usize, usize>,
    /// Witnesses needing an explicit range check, keyed by bit width
    /// (only widths strictly below `table_bits` end up here).
    pub range_checks: BTreeMap<u32, Vec<usize>>,
}
impl SpreadAccumulator {
pub fn new(table_bits: u32) -> Self {
Self {
table_bits,
lookups: Vec::new(),
spread_cache: HashMap::new(),
range_checks: BTreeMap::new(),
}
}
}
/// Decomposes the 32-bit word in witness `packed` into the chunks given by
/// `chunk_spec` (LSB-first), splitting each chunk into sub-chunks of at most
/// `accum.table_bits` bits, and allocates a spread witness per sub-chunk.
///
/// Emits one R1CS constraint asserting the weighted sub-chunk sum recomposes
/// to `packed`.  Sub-chunks narrower than the table width are queued in
/// `accum.range_checks`; full-width sub-chunks are range-checked implicitly by
/// the spread-table lookup itself.
pub(crate) fn decompose_to_spread_word(
    compiler: &mut NoirToR1CSCompiler,
    accum: &mut SpreadAccumulator,
    packed: usize,
    chunk_spec: &[u32],
) -> SpreadWord {
    let num_chunks = chunk_spec.len();
    let w = accum.table_bits;
    // Flatten every chunk into table-width sub-chunks, recording how many
    // sub-chunks each original chunk contributed.
    let mut flat_bits: Vec<u32> = Vec::new();
    let mut chunk_sub_counts: Vec<usize> = Vec::with_capacity(num_chunks);
    for &bits in chunk_spec {
        let subs = subchunks(bits, w);
        chunk_sub_counts.push(subs.len());
        flat_bits.extend(subs);
    }
    // One witness per sub-chunk, laid out consecutively from `sub_start`.
    let sub_start = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::ChunkDecompose {
        output_start: sub_start,
        packed,
        chunk_bits: flat_bits.clone(),
    });
    // Recomposition constraint: (Σ 2^offset_i · sub_i) · 1 = packed.
    let mut recomp_terms: Vec<(FieldElement, usize)> = Vec::with_capacity(flat_bits.len());
    let mut bit_offset: u32 = 0;
    for (i, &bits) in flat_bits.iter().enumerate() {
        recomp_terms.push((FieldElement::from(1u64 << bit_offset), sub_start + i));
        bit_offset += bits;
    }
    compiler.r1cs.add_constraint(
        &recomp_terms,
        &[(FieldElement::ONE, compiler.witness_one())],
        &[(FieldElement::ONE, packed)],
    );
    // Regroup the flat sub-chunks into the caller's chunks, attaching a spread
    // witness (and, where needed, a pending range check) to each.
    let mut chunks = Vec::with_capacity(num_chunks);
    let mut flat_idx = 0usize;
    for ci in 0..num_chunks {
        let n_subs = chunk_sub_counts[ci];
        let sub_bits_slice = &flat_bits[flat_idx..flat_idx + n_subs];
        let mut sub_values = Vec::with_capacity(n_subs);
        let mut sub_spreads = Vec::with_capacity(n_subs);
        for j in 0..n_subs {
            let val_idx = sub_start + flat_idx + j;
            let spread_idx = add_spread_witness(compiler, accum, val_idx);
            // The spread table only bounds values to `w` bits, so narrower
            // sub-chunks need their own range check.
            if sub_bits_slice[j] < w {
                accum
                    .range_checks
                    .entry(sub_bits_slice[j])
                    .or_default()
                    .push(val_idx);
            }
            sub_values.push(val_idx);
            sub_spreads.push(spread_idx);
        }
        chunks.push(SpreadChunk {
            total_bits: chunk_spec[ci],
            sub_values,
            sub_spreads,
            sub_bits: sub_bits_slice.to_vec(),
        });
        flat_idx += n_subs;
    }
    SpreadWord { packed, chunks }
}
/// Returns a witness holding the spread encoding of `value_idx`, allocating a
/// new `SpreadWitness` builder plus a table-lookup entry only on first use
/// (subsequent calls hit `accum.spread_cache`).
fn add_spread_witness(
    compiler: &mut NoirToR1CSCompiler,
    accum: &mut SpreadAccumulator,
    value_idx: usize,
) -> usize {
    match accum.spread_cache.get(&value_idx) {
        Some(&existing) => existing,
        None => {
            let spread_idx = compiler.num_witnesses();
            compiler.add_witness_builder(WitnessBuilder::SpreadWitness(spread_idx, value_idx));
            let lookup_pair = (
                ConstantOrR1CSWitness::Witness(value_idx),
                ConstantOrR1CSWitness::Witness(spread_idx),
            );
            accum.lookups.push(lookup_pair);
            accum.spread_cache.insert(value_idx, spread_idx);
            spread_idx
        }
    }
}
/// Output of `spread_decompose`: the even- and odd-position bit planes of a
/// spread-domain sum, chunked into table-width pieces.
pub(crate) struct SpreadDecompResult {
    /// Witness indices of the chunks extracted from even bit positions, LSB-first.
    pub even_values: Vec<usize>,
    /// Witness indices of the chunks extracted from odd bit positions, LSB-first.
    pub odd_values: Vec<usize>,
    /// Bit width of each chunk; parallel to both vectors above.
    pub chunk_bits: Vec<u32>,
}
pub(crate) fn spread_decompose(
compiler: &mut NoirToR1CSCompiler,
accum: &mut SpreadAccumulator,
sum_terms: Vec<SumTerm>,
) -> SpreadDecompResult {
let extract_chunks = subchunks(32, accum.table_bits);
let n_chunks = extract_chunks.len();
let mut combined: HashMap<usize, FieldElement> = HashMap::new();
for SumTerm(coeff, idx) in &sum_terms {
*combined.entry(*idx).or_insert(FieldElement::zero()) += coeff.unwrap_or(FieldElement::ONE);
}
let az: Vec<(FieldElement, usize)> = combined
.into_iter()
.map(|(idx, coeff)| (coeff, idx))
.collect();
let even_start = compiler.num_witnesses();
compiler.add_witness_builder(WitnessBuilder::SpreadBitExtract {
output_start: even_start,
chunk_bits: extract_chunks.clone(),
sum_terms: sum_terms.clone(),
extract_even: true,
});
let odd_start = compiler.num_witnesses();
compiler.add_witness_builder(WitnessBuilder::SpreadBitExtract {
output_start: odd_start,
chunk_bits: extract_chunks.clone(),
sum_terms,
extract_even: false,
});
let mut even_spreads = Vec::with_capacity(n_chunks);
let mut odd_spreads = Vec::with_capacity(n_chunks);
for i in 0..n_chunks {
let even_val = even_start + i;
let odd_val = odd_start + i;
even_spreads.push(add_spread_witness(compiler, accum, even_val));
odd_spreads.push(add_spread_witness(compiler, accum, odd_val));
if extract_chunks[i] < accum.table_bits {
accum
.range_checks
.entry(extract_chunks[i])
.or_default()
.push(even_val);
accum
.range_checks
.entry(extract_chunks[i])
.or_default()
.push(odd_val);
}
}
let even_values: Vec<usize> = (even_start..even_start + n_chunks).collect();
let odd_values: Vec<usize> = (odd_start..odd_start + n_chunks).collect();
let two = FieldElement::from(2u64);
let mut cz: Vec<(FieldElement, usize)> = Vec::with_capacity(2 * n_chunks);
let mut bit_offset = 0u32;
for (i, &bits) in extract_chunks.iter().enumerate() {
let base = FieldElement::from(1u64 << (2 * bit_offset));
cz.push((two * base, odd_spreads[i]));
cz.push((base, even_spreads[i]));
bit_offset += bits;
}
compiler
.r1cs
.add_constraint(&az, &[(FieldElement::ONE, compiler.witness_one())], &cz);
SpreadDecompResult {
even_values,
odd_values,
chunk_bits: extract_chunks,
}
}
/// Packs `values[i]` (each `chunk_bits[i]` bits wide, LSB-first) into a single
/// fresh witness via a weighted sum, constrained in R1CS; returns its index.
pub(crate) fn pack_chunks(
    compiler: &mut NoirToR1CSCompiler,
    chunk_bits: &[u32],
    values: &[usize],
) -> usize {
    assert_eq!(chunk_bits.len(), values.len());
    let packed_idx = compiler.num_witnesses();
    let mut weight = FieldElement::ONE;
    let mut sum_terms = Vec::with_capacity(values.len());
    let mut lin_terms = Vec::with_capacity(values.len());
    for (&bits, &val_idx) in chunk_bits.iter().zip(values.iter()) {
        sum_terms.push(SumTerm(Some(weight), val_idx));
        lin_terms.push((weight, val_idx));
        // Next chunk sits `bits` positions higher.
        weight *= FieldElement::from(1u64 << bits);
    }
    compiler.add_witness_builder(WitnessBuilder::Sum(packed_idx, sum_terms));
    // Enforce packed = Σ weight_i · value_i  (as lin_terms · 1 = packed).
    compiler.r1cs.add_constraint(
        &lin_terms,
        &[(FieldElement::ONE, compiler.witness_one())],
        &[(FieldElement::ONE, packed_idx)],
    );
    packed_idx
}
/// Adds the u32 words in `packed_inputs` plus the literal `constants` modulo
/// 2^32, constraining sum = result + 2^32·carry.  The carry is bound into the
/// spread table via a lookup (which also bounds its magnitude to table width —
/// presumably always sufficient since the carry is at most the operand count;
/// confirm against callers), and the 32-bit result is decomposed into
/// `output_chunks` as a fresh `SpreadWord`.
pub(crate) fn add_u32_addition_spread(
    compiler: &mut NoirToR1CSCompiler,
    accum: &mut SpreadAccumulator,
    packed_inputs: &[usize],
    constants: &[u32],
    output_chunks: &[u32],
) -> SpreadWord {
    let const_sum: u64 = constants.iter().map(|&c| c as u64).sum();
    let const_field = FieldElement::from(const_sum);
    // Witness layout: result first, carry immediately after it.
    let result_witness = compiler.num_witnesses();
    let carry_witness = result_witness + 1;
    let mut wb_inputs: Vec<ConstantOrR1CSWitness> = packed_inputs
        .iter()
        .map(|&w| ConstantOrR1CSWitness::Witness(w))
        .collect();
    for &c in constants {
        wb_inputs.push(ConstantOrR1CSWitness::Constant(FieldElement::from(
            c as u64,
        )));
    }
    compiler.add_witness_builder(WitnessBuilder::U32AdditionMulti(
        result_witness,
        carry_witness,
        wb_inputs,
    ));
    // Constraint: (Σ inputs + const_sum) · 1 = result + 2^32·carry.
    let mut sum_lhs: Vec<(FieldElement, usize)> = packed_inputs
        .iter()
        .map(|&w| (FieldElement::ONE, w))
        .collect();
    if const_sum > 0 {
        sum_lhs.push((const_field, compiler.witness_one()));
    }
    let two_pow_32 = FieldElement::from(1u64 << 32);
    compiler
        .r1cs
        .add_constraint(&sum_lhs, &[(FieldElement::ONE, compiler.witness_one())], &[
            (FieldElement::ONE, result_witness),
            (two_pow_32, carry_witness),
        ]);
    // Register (carry, spread(carry)) as a table lookup; membership in the
    // table range-checks the carry.
    let carry_spread = compiler.num_witnesses();
    compiler.add_witness_builder(WitnessBuilder::SpreadWitness(carry_spread, carry_witness));
    accum.lookups.push((
        ConstantOrR1CSWitness::Witness(carry_witness),
        ConstantOrR1CSWitness::Witness(carry_spread),
    ));
    decompose_to_spread_word(compiler, accum, result_witness, output_chunks)
}
/// Builds a `SpreadWord` for a compile-time-constant 32-bit value.
///
/// The sub-chunk spread values are computed at compile time and pinned as
/// constant `Sum` witnesses over the `one` witness, so no table lookups and
/// no range checks are needed.  `packed_witness` is assumed to already hold
/// `constant_value`.
///
/// NOTE(review): `sub_values` is populated with the *spread* witness indices
/// (a clone of `sub_spreads`) rather than plain-value witnesses — presumably
/// the plain values are never consumed for constant words; confirm with
/// callers of this path.
pub(crate) fn decompose_constant_to_spread_word(
    compiler: &mut NoirToR1CSCompiler,
    packed_witness: usize,
    constant_value: u32,
    chunk_spec: &[u32],
    table_bits: u32,
) -> SpreadWord {
    let w = table_bits;
    let w_one = compiler.witness_one();
    let num_chunks = chunk_spec.len();
    // Flatten chunk_spec into table-width sub-chunks, mirroring the layout of
    // the non-constant path in `decompose_to_spread_word`.
    let mut flat_bits: Vec<u32> = Vec::new();
    let mut chunk_sub_counts: Vec<usize> = Vec::with_capacity(num_chunks);
    for &bits in chunk_spec {
        let subs = subchunks(bits, w);
        chunk_sub_counts.push(subs.len());
        flat_bits.extend(subs);
    }
    // Peel the constant into sub-chunk values, LSB-first.
    let mut flat_values: Vec<u64> = Vec::with_capacity(flat_bits.len());
    let mut remaining = constant_value as u64;
    for &bits in &flat_bits {
        let mask = (1u64 << bits) - 1;
        flat_values.push(remaining & mask);
        remaining >>= bits;
    }
    // Sanity check: the sub-chunks must recompose to the original constant
    // (catches a chunk_spec that does not cover all 32 bits).
    let mut recomposed: u64 = 0;
    let mut shift = 0u32;
    for (i, &bits) in flat_bits.iter().enumerate() {
        recomposed += flat_values[i] << shift;
        shift += bits;
    }
    assert_eq!(
        recomposed, constant_value as u64,
        "constant spread decomposition mismatch: {constant_value:#x} decomposed to {recomposed:#x}"
    );
    let mut chunks = Vec::with_capacity(num_chunks);
    let mut flat_idx = 0usize;
    for ci in 0..num_chunks {
        let n_subs = chunk_sub_counts[ci];
        let sub_bits_slice = &flat_bits[flat_idx..flat_idx + n_subs];
        let mut sub_spreads = Vec::with_capacity(n_subs);
        for j in 0..n_subs {
            // Pin spread(sub_value) as a constant via a Sum over `one`.
            let spread_val = compute_spread(flat_values[flat_idx + j]);
            let spread_idx =
                compiler.add_sum(vec![SumTerm(Some(FieldElement::from(spread_val)), w_one)]);
            sub_spreads.push(spread_idx);
        }
        chunks.push(SpreadChunk {
            total_bits: chunk_spec[ci],
            sub_values: sub_spreads.clone(),
            sub_spreads,
            sub_bits: sub_bits_slice.to_vec(),
        });
        flat_idx += n_subs;
    }
    SpreadWord {
        packed: packed_witness,
        chunks,
    }
}
/// Discharges all accumulated spread-table lookups with a batched logUp-style
/// argument over the full table {(x, spread(x)) : x < 2^table_bits}, then
/// returns the still-pending narrow range checks for the caller to handle.
///
/// Soundness sketch: with challenges `sz`, `rs`, every query contributes
/// -1/(sz - x - rs·spread(x)) and every table row contributes
/// m_x/(sz - x - rs·spread(x)), where m_x is the row's query multiplicity.
/// The final constraint forces the grand sum to zero, which holds iff each
/// queried pair is a table row.
pub(crate) fn add_spread_table_constraints(
    compiler: &mut NoirToR1CSCompiler,
    accum: SpreadAccumulator,
) -> BTreeMap<u32, Vec<usize>> {
    let range_checks = accum.range_checks;
    // No lookups means no table is needed; just hand back the range checks.
    if accum.lookups.is_empty() {
        return range_checks;
    }
    let table_size = 1u32 << accum.table_bits;
    // One multiplicity witness per table row, counting how often each input
    // value was queried.
    let mult_first = compiler.num_witnesses();
    let query_inputs: Vec<ConstantOrR1CSWitness> =
        accum.lookups.iter().map(|(input, _)| *input).collect();
    compiler.add_witness_builder(WitnessBuilder::MultiplicitiesForSpread(
        mult_first,
        accum.table_bits,
        query_inputs,
    ));
    // Challenges: `sz` shifts the denominator, `rs` compresses the
    // (value, spread) pair into a single field element.
    let sz = compiler.add_witness_builder(WitnessBuilder::Challenge(compiler.num_witnesses()));
    let rs = compiler.add_witness_builder(WitnessBuilder::Challenge(compiler.num_witnesses()));
    let mut logup_summands: Vec<(FieldElement, usize)> = Vec::new();
    // Query side: for each lookup, build and verify
    // inverse = 1/(sz - input - rs·spread).
    for (input, spread_output) in &accum.lookups {
        let denom = compiler.add_witness_builder(WitnessBuilder::SpreadLookupDenominator(
            compiler.num_witnesses(),
            sz,
            rs,
            *input,
            *spread_output,
        ));
        let (input_coeff, input_idx) = input.to_tuple();
        let mut az: Vec<(FieldElement, usize)> =
            vec![(FieldElement::ONE, sz), (input_coeff.neg(), input_idx)];
        match spread_output {
            ConstantOrR1CSWitness::Constant(val) => {
                az.push((val.neg(), rs));
            }
            ConstantOrR1CSWitness::Witness(w) => {
                // rs·spread is quadratic in witnesses, so materialize the
                // product as its own witness first.
                let prod = compiler.add_product(rs, *w);
                az.push((FieldElement::ONE.neg(), prod));
            }
        }
        let inverse =
            compiler.add_witness_builder(WitnessBuilder::Inverse(compiler.num_witnesses(), denom));
        // (sz - input - rs·spread) · inverse = 1 proves `inverse` is genuine
        // (and that the denominator is nonzero).
        compiler
            .r1cs
            .add_constraint(&az, &[(FieldElement::ONE, inverse)], &[(
                FieldElement::ONE,
                compiler.witness_one(),
            )]);
        logup_summands.push((FieldElement::ONE.neg(), inverse));
    }
    // Table side: for every row x, constrain
    // (sz - x - rs·spread(x)) · quotient = multiplicity.
    for x in 0..table_size {
        let spread_x = compute_spread(x as u64);
        let multiplicity_idx = mult_first + x as usize;
        let quotient = compiler.add_witness_builder(WitnessBuilder::SpreadTableQuotient {
            idx: compiler.num_witnesses(),
            sz,
            rs,
            input_val: FieldElement::from(x),
            spread_val: FieldElement::from(spread_x),
            multiplicity: multiplicity_idx,
        });
        compiler.r1cs.add_constraint(
            &[
                (FieldElement::ONE, sz),
                (FieldElement::from(x).neg(), compiler.witness_one()),
                (FieldElement::from(spread_x).neg(), rs),
            ],
            &[(FieldElement::ONE, quotient)],
            &[(FieldElement::ONE, multiplicity_idx)],
        );
        logup_summands.push((FieldElement::ONE, quotient));
    }
    // Grand sum must vanish: Σ quotients − Σ inverses = 0.
    compiler.r1cs.add_constraint(
        &logup_summands,
        &[(FieldElement::ONE, compiler.witness_one())],
        &[(FieldElement::zero(), compiler.witness_one())],
    );
    range_checks
}
/// Estimates the total witness count for compiling `n_sha` SHA-256 hashes
/// (`n_const_hash` of which hash constant inputs) with a spread table of
/// width `w`, so `get_optimal_spread_width` can pick the cheapest width.
pub(crate) fn calculate_spread_witness_cost(w: u32, n_sha: usize, n_const_hash: usize) -> usize {
    // Number of w-bit sub-chunks needed to cover `bits` bits.
    let sub_count = |bits: u32| bits.div_ceil(w) as usize;
    // Total sub-chunks across a whole chunk spec.
    let spec_subs = |spec: &[u32]| -> usize { spec.iter().map(|&b| sub_count(b)).sum() };
    let n_sd = sub_count(32);

    // Inline witness counts per gadget:
    // decompose: one value + one spread witness per sub-chunk.
    let decomp = |spec: &[u32]| 2 * spec_subs(spec);
    // constant decompose: spread witnesses only.
    let decomp_const = |spec: &[u32]| spec_subs(spec);
    // spread_decompose: even/odd values plus their spreads.
    let sd = 4 * n_sd;
    // pack_chunks: one packed witness.
    let pk = 1usize;
    // u32 addition: result + carry + carry spread, then a decompose.
    let add = |spec: &[u32]| 3 + 2 * spec_subs(spec);

    // Table-lookup counts per gadget (each lookup later costs ~3 witnesses).
    let decomp_lookups = |spec: &[u32]| spec_subs(spec);
    let sd_lookups = 2 * n_sd;
    let add_lookups = |spec: &[u32]| 1 + spec_subs(spec);

    // Initial decompositions (message words plus state words).
    let init_inline = 16 * decomp(&BYTE_CHUNKS)
        + 2 * decomp(&BYTE_CHUNKS)
        + 3 * decomp(&SIGMA0_CHUNKS)
        + 3 * decomp(&SIGMA1_CHUNKS);
    let init_lookups = 16 * decomp_lookups(&BYTE_CHUNKS)
        + 2 * decomp_lookups(&BYTE_CHUNKS)
        + 3 * decomp_lookups(&SIGMA0_CHUNKS)
        + 3 * decomp_lookups(&SIGMA1_CHUNKS);
    // Constant-input variant: only the message bytes need full decomposition.
    let init_const_inline = 16 * decomp(&BYTE_CHUNKS)
        + 2 * decomp_const(&BYTE_CHUNKS)
        + 3 * decomp_const(&SIGMA0_CHUNKS)
        + 3 * decomp_const(&SIGMA1_CHUNKS);
    let init_const_lookups = 16 * decomp_lookups(&BYTE_CHUNKS);

    // One message-schedule word: σ0 and σ1 (decompose + spread_decompose +
    // pack each) plus a multi-operand addition.
    let msg_inline = (decomp(&SMALL_SIGMA0_CHUNKS) + sd + pk)
        + (decomp(&SMALL_SIGMA1_CHUNKS) + sd + pk)
        + add(&BYTE_CHUNKS);
    let msg_lookups = (decomp_lookups(&SMALL_SIGMA0_CHUNKS) + sd_lookups)
        + (decomp_lookups(&SMALL_SIGMA1_CHUNKS) + sd_lookups)
        + add_lookups(&BYTE_CHUNKS);

    // One compression round: four spread-decompose+pack operations (one
    // counted double, with an extra witness) plus two chunked additions.
    let comp_inline = (sd + pk)
        + (2 * sd + 2 * pk + 1)
        + (sd + pk)
        + (sd + pk)
        + add(&SIGMA1_CHUNKS)
        + add(&SIGMA0_CHUNKS);
    let comp_lookups = sd_lookups
        + 2 * sd_lookups
        + sd_lookups
        + sd_lookups
        + add_lookups(&SIGMA1_CHUNKS)
        + add_lookups(&SIGMA0_CHUNKS);

    // Final feed-forward additions into the eight output words.
    let final_inline = 8 * add(&BYTE_CHUNKS);
    let final_lookups = 8 * add_lookups(&BYTE_CHUNKS);

    // 48 schedule words and 64 rounds per compression.
    let shared_inline = 48 * msg_inline + 64 * comp_inline + final_inline;
    let shared_lookups = 48 * msg_lookups + 64 * comp_lookups + final_lookups;

    let n_normal = n_sha - n_const_hash;
    // Each lookup costs roughly 3 extra witnesses at discharge time
    // (denominator, product, inverse), hence the factor of 3.
    let normal_cost =
        n_normal * ((init_inline + shared_inline) + 3 * (init_lookups + shared_lookups));
    let const_cost = n_const_hash
        * ((init_const_inline + shared_inline) + 3 * (init_const_lookups + shared_lookups));

    // Table-side overhead: 2^w multiplicities + 2^w quotients + 2 challenges.
    let table = 2 * (1usize << w) + 2;
    table + normal_cost + const_cost
}
/// Picks the spread-table width in 3..=20 that minimizes the estimated
/// witness cost for the given workload.  On ties the smallest width wins,
/// matching `Iterator::min_by_key`'s first-minimum behavior.
pub(crate) fn get_optimal_spread_width(n_sha: usize, n_const_hash: usize) -> u32 {
    let mut best_width = 3u32;
    let mut best_cost = usize::MAX;
    for width in 3u32..=20 {
        let cost = calculate_spread_witness_cost(width, n_sha, n_const_hash);
        if cost < best_cost {
            best_cost = cost;
            best_width = width;
        }
    }
    best_width
}
#[cfg(test)]
mod tests {
    use {
        super::*, crate::noir_to_r1cs::NoirToR1CSCompiler, provekit_common::witness::compute_spread,
    };

    /// Pins the width-selection heuristic for two representative workloads.
    #[test]
    fn test_optimal_spread_width() {
        assert_eq!(get_optimal_spread_width(1, 0), 11);
        assert_eq!(get_optimal_spread_width(35, 0), 16);
    }

    /// A constant word decomposed into bytes (table width 8, so one sub-chunk
    /// per byte) must pin exactly spread(byte) as a one-term constant Sum.
    #[test]
    fn constant_decomp_pins_correct_spread_values() {
        let value: u32 = 0xcafe_babe;
        let mut compiler = NoirToR1CSCompiler::new();
        let w_one = compiler.witness_one();
        let packed = compiler.add_sum(vec![SumTerm(Some(FieldElement::from(value as u64)), w_one)]);
        let sw = decompose_constant_to_spread_word(&mut compiler, packed, value, &BYTE_CHUNKS, 8);
        // Chunks are LSB-first, matching little-endian byte order.
        let bytes = value.to_le_bytes();
        for (i, chunk) in sw.chunks.iter().enumerate() {
            let spread_idx = chunk.sub_spreads[0];
            match &compiler.witness_builders[spread_idx] {
                WitnessBuilder::Sum(idx, terms) => {
                    assert_eq!(*idx, spread_idx);
                    assert_eq!(terms.len(), 1);
                    let expected = compute_spread(bytes[i] as u64);
                    assert_eq!(
                        terms[0].0.unwrap(),
                        FieldElement::from(expected),
                        "chunk {i}: spread mismatch for byte {:#04x}",
                        bytes[i],
                    );
                }
                other => panic!("chunk {i}: expected Sum, got {other:?}"),
            }
        }
    }

    /// With table width 8, the Σ0 chunk spec [2, 11, 9, 10] splits into
    /// multi-sub-chunk pieces; every sub-chunk spread must still be pinned
    /// correctly and together consume all 32 bits.
    #[test]
    fn constant_decomp_multi_sub_chunk() {
        let value: u32 = 0xa5a5_a5a5;
        let mut compiler = NoirToR1CSCompiler::new();
        let w_one = compiler.witness_one();
        let packed = compiler.add_sum(vec![SumTerm(Some(FieldElement::from(value as u64)), w_one)]);
        let sw = decompose_constant_to_spread_word(
            &mut compiler,
            packed,
            value,
            &SIGMA0_CHUNKS, 8,
        );
        // Expected sub-chunk layout of [2, 11, 9, 10] at width 8.
        assert_eq!(sw.chunks[0].sub_bits, vec![2]);
        assert_eq!(sw.chunks[1].sub_bits, vec![8, 3]);
        assert_eq!(sw.chunks[2].sub_bits, vec![8, 1]);
        assert_eq!(sw.chunks[3].sub_bits, vec![8, 2]);
        // Walk the sub-chunks LSB-first, peeling bits off `remaining`.
        let mut remaining = value as u64;
        for chunk in &sw.chunks {
            for (j, &bits) in chunk.sub_bits.iter().enumerate() {
                let mask = (1u64 << bits) - 1;
                let sub_val = remaining & mask;
                remaining >>= bits;
                let spread_idx = chunk.sub_spreads[j];
                match &compiler.witness_builders[spread_idx] {
                    WitnessBuilder::Sum(_, terms) => {
                        let expected = compute_spread(sub_val);
                        assert_eq!(
                            terms[0].0.unwrap(),
                            FieldElement::from(expected),
                            "sub-chunk value {sub_val:#x} ({bits} bits): spread mismatch",
                        );
                    }
                    other => panic!("expected Sum, got {other:?}"),
                }
            }
        }
        assert_eq!(remaining, 0, "all 32 bits should be consumed");
    }
}