//! provekit-r1cs-compiler 1.0.0
//!
//! R1CS compiler for ProveKit, translating Noir programs to R1CS
//! constraints.
use {
    crate::{
        digits::{add_digital_decomposition, DigitalDecompositionWitnessesBuilder},
        noir_to_r1cs::NoirToR1CSCompiler,
    },
    ark_std::{One, Zero},
    provekit_common::{
        witness::{ProductLinearTerm, WitnessBuilder, WitnessCoefficient},
        FieldElement,
    },
    std::{
        collections::{BTreeMap, HashSet},
        ops::Neg,
    },
};

/// Minimum base width to consider during optimization.
const MIN_BASE_WIDTH: u32 = 2;

/// Maximum base width to consider during optimization. Beyond 17 bits
/// the table side alone (2^18+ entries) always exceeds the cost of
/// decomposing into smaller digits.
const MAX_BASE_WIDTH: u32 = 17;

/// A single range check request: a witness that must be in [0, 2^bits).
struct RangeCheckRequest {
    // Index of the witness whose value is being range checked.
    witness_idx: usize,
    // Bit width k of the range: the witness must lie in [0, 2^k).
    bits:        u32,
}

/// Determines whether LogUp is cheaper than naive range checking for a bucket
/// of `num_bits`-wide checks containing `count` witnesses, based on witness
/// count (memory cost).
///
/// LogUp needs 3×table_size + count + 1 witnesses (multiplicities, fused
/// inverse/quotient pairs per table entry, one inverse per looked-up
/// witness, and the challenge), while the naive product check needs
/// (table_size - 2) intermediate products per witness.
///
/// Returns `true` if LogUp should be used, `false` for naive product checks.
fn should_use_logup(num_bits: u32, count: usize) -> bool {
    let table_size = 1usize << num_bits;
    // Table side: multiplicities (table_size) plus fused inverse/quotient
    // pairs (2 × table_size).
    let table_side = table_size.saturating_mul(3);
    // Add the per-witness inverses and the single Schwartz-Zippel
    // challenge witness.
    let logup_total = table_side.saturating_add(count).saturating_add(1);
    // Naive product check: (table_size - 2) intermediate product witnesses
    // for every value checked.
    let naive_total = count.saturating_mul(table_size.saturating_sub(2));
    naive_total > logup_total
}

/// Returns the witness cost for a single atomic bucket of `num_bits`-wide
/// checks containing `count` witnesses, choosing whichever strategy (LogUp
/// or naive) produces fewer witnesses. Returns `usize::MAX` for
/// impractically large bit widths where the table would overflow.
fn bucket_cost(num_bits: u32, count: usize) -> usize {
    // An empty bucket, or a zero-bit range, requires no work at all.
    if count == 0 || num_bits == 0 {
        return 0;
    }
    // Guard against overflow: `1usize << num_bits` is only defined for
    // shifts strictly below usize::BITS, and a table of 2^(usize::BITS - 1)
    // entries is impractical anyway, so treat such widths as infinitely
    // expensive.
    if num_bits >= (usize::BITS - 1) {
        return usize::MAX;
    }
    let table_size = 1usize << num_bits;
    // LogUp witnesses: table_size (multiplicities) + 2*table_size
    // (fused inverse + quotient per entry) + count (inverse per witness)
    // + 1 (challenge).
    let logup_cost = table_size
        .saturating_mul(3)
        .saturating_add(count)
        .saturating_add(1);
    // Naive witnesses: (table_size - 2) intermediate products per witness.
    let naive_cost = count.saturating_mul(table_size.saturating_sub(2));
    // The cheaper strategy wins. On a tie both costs are equal, so `min`
    // selects the same value the strict `logup < naive` comparison would.
    logup_cost.min(naive_cost)
}

/// Calculates the total witness cost for a given `base_width`.
///
/// Requests wider than `base_width` are digitally decomposed, paying one
/// digit witness per digit; the resulting digits — together with requests
/// that already fit in the base — are bucketed by bit width, and every
/// bucket is priced at the cheaper of LogUp lookup and naive product
/// checks, measured in witness count.
fn calculate_witness_cost(base_width: u32, collected: &[RangeCheckRequest]) -> usize {
    let mut buckets: BTreeMap<u32, usize> = BTreeMap::new();
    let mut total: usize = 0;

    for request in collected {
        if request.bits <= base_width {
            // Fits in one digit: no decomposition, straight to its bucket.
            *buckets.entry(request.bits).or_default() += 1;
            continue;
        }
        let full_digits = (request.bits / base_width) as usize;
        let remainder = request.bits % base_width;

        *buckets.entry(base_width).or_default() += full_digits;
        let mut digit_count = full_digits;
        if remainder > 0 {
            *buckets.entry(remainder).or_default() += 1;
            digit_count += 1;
        }
        // Every digit of the decomposition is a freshly created witness.
        total += digit_count;
    }

    // Add each bucket's strategy cost on top of the decomposition cost.
    buckets
        .into_iter()
        .fold(total, |acc, (bits, count)| {
            acc.saturating_add(bucket_cost(bits, count))
        })
}

/// Finds the base width that minimizes the total witness count for the
/// given set of range check requests.
///
/// Scans every candidate in [MIN_BASE_WIDTH, MAX_BASE_WIDTH]; on a cost
/// tie the smaller (earlier-scanned) width is kept. Base widths above 17
/// are never beneficial because the table side alone would require 2^18+
/// witnesses, which always exceeds the cost of decomposing into smaller
/// digits.
fn get_optimal_base_width(collected: &[RangeCheckRequest]) -> u32 {
    // Fold over candidates, keeping the first width that achieves the
    // minimum cost. The initial width of 8 is only returned if every
    // candidate costs usize::MAX.
    let (_, best_width) = (MIN_BASE_WIDTH..=MAX_BASE_WIDTH).fold(
        (usize::MAX, 8u32),
        |(best_cost, best_width), candidate| {
            let cost = calculate_witness_cost(candidate, collected);
            if cost < best_cost {
                (cost, candidate)
            } else {
                (best_cost, best_width)
            }
        },
    );
    best_width
}

/// Add witnesses and constraints that ensure that the values of the witness
/// belong to a range 0..2^k (for some k).
///
/// Uses dynamic base width optimization: all range check requests are
/// collected, and the optimal decomposition base width is determined by
/// minimizing the total witness count (memory cost). The search evaluates
/// every base width from [MIN_BASE_WIDTH] to [MAX_BASE_WIDTH]. For each
/// candidate, the cost model picks the cheaper of LogUp and naive for
/// every atomic bucket.
///
/// Values with bit widths larger than the chosen base are digitally
/// decomposed; the resulting digits (and values already ≤ the base) are then
/// range checked via LogUp lookup or naive product checks, whichever is
/// cheaper per bucket.
///
/// `range_checks` is a map from the number of bits k to the vector of
/// witness indices that are to be constrained within the range [0..2^k].
///
/// Returns the chosen base width, or `None` when there is nothing to check.
pub(crate) fn add_range_checks(
    r1cs: &mut NoirToR1CSCompiler,
    range_checks: BTreeMap<u32, Vec<usize>>,
) -> Option<u32> {
    if range_checks.is_empty() {
        return None;
    }

    // Phase 1: Flatten all range checks into individual requests and
    // deduplicate per bit-width group.
    let collected: Vec<RangeCheckRequest> = range_checks
        .into_iter()
        .flat_map(|(num_bits, values)| {
            // `seen` is scoped per bit-width group: the same witness index
            // may legitimately appear under two different bit widths.
            let mut seen = HashSet::new();
            values
                .into_iter()
                .filter(move |v| seen.insert(*v))
                .map(move |witness_idx| RangeCheckRequest {
                    witness_idx,
                    bits: num_bits,
                })
        })
        .collect();

    if collected.is_empty() {
        return None;
    }

    // Phase 2: Find the optimal base width that minimizes total witness
    // count.
    let base_width = get_optimal_base_width(&collected);

    // Phase 3: Decompose values larger than base_width and collect atomic
    // range check buckets.
    // Bucket index = bit width, so widths 0..=base_width need
    // base_width + 1 slots. Each slot starts with one empty group; that is
    // harmless because Phase 4 flattens the groups and skips empty buckets.
    let max_bucket = base_width as usize + 1;
    let mut atomic_range_checks: Vec<Vec<Vec<usize>>> = vec![vec![vec![]]; max_bucket];

    // Group collected requests by bit width for batch decomposition.
    let mut by_bits: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
    for req in &collected {
        by_bits.entry(req.bits).or_default().push(req.witness_idx);
    }

    for (num_bits, values_to_lookup) in by_bits {
        if num_bits > base_width {
            // Split into full base-width digits plus an optional narrower
            // top digit holding the remaining bits.
            let num_full_digits = num_bits / base_width;
            let remainder = num_bits % base_width;
            let mut log_bases = vec![base_width as usize; num_full_digits as usize];
            if remainder > 0 {
                log_bases.push(remainder as usize);
            }
            let dd_struct = add_digital_decomposition(r1cs, log_bases, values_to_lookup);

            // Route every digit witness into the bucket matching its
            // digit's bit width, so digits are range checked like any
            // other value of that width.
            dd_struct
                .log_bases
                .iter()
                .enumerate()
                .map(|(digit_place, log_base)| {
                    (
                        *log_base as u32,
                        (0..dd_struct.num_witnesses_to_decompose)
                            .map(|i| dd_struct.get_digit_witness_index(digit_place, i))
                            .collect::<Vec<_>>(),
                    )
                })
                .for_each(|(log_base, digit_witnesses)| {
                    atomic_range_checks[log_base as usize].push(digit_witnesses);
                });
        } else {
            // Already fits within the base width: no decomposition needed.
            atomic_range_checks[num_bits as usize].push(values_to_lookup);
        }
    }

    // Phase 4: For each atomic bucket, add range check constraints.
    // Choose LogUp or naive based on whichever produces fewer witnesses.
    atomic_range_checks
        .iter()
        .enumerate()
        .for_each(|(num_bits, all_values_to_lookup)| {
            // Deduplicate across digit groups.
            let values_to_lookup: Vec<usize> = {
                let mut seen = HashSet::new();
                all_values_to_lookup
                    .iter()
                    .flat_map(|v| v.iter())
                    .copied()
                    .filter(|v| seen.insert(*v))
                    .collect()
            };
            if values_to_lookup.is_empty() {
                return;
            }
            let num_bits = num_bits as u32;
            if should_use_logup(num_bits, values_to_lookup.len()) {
                add_range_check_via_lookup(r1cs, num_bits, &values_to_lookup);
            } else {
                values_to_lookup.iter().for_each(|value| {
                    add_naive_range_check(r1cs, num_bits, *value);
                })
            }
        });

    Some(base_width)
}

/// Adds a LogUp (log-derivative) range check for `values_to_lookup`, each of
/// which must lie in [0, 2^num_bits).
///
/// Builds both sides of the multiset check
/// Σ_j multiplicity_j / (X - j)  =  Σ_i 1 / (X - w_i)
/// and enforces their equality with one fused constraint on the difference.
fn add_range_check_via_lookup(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    num_bits: u32,
    values_to_lookup: &[usize],
) {
    // Witnesses recording how often each table entry is looked up.
    let multiplicities_first_witness =
        r1cs_compiler.add_witness_builder(WitnessBuilder::MultiplicitiesForRange(
            r1cs_compiler.num_witnesses(),
            1 << num_bits,
            values_to_lookup.into(),
        ));

    // Schwartz-Zippel challenge X for the log-derivative multiset check.
    let sz_challenge =
        r1cs_compiler.add_witness_builder(WitnessBuilder::Challenge(r1cs_compiler.num_witnesses()));

    // Table side: one term multiplicity_j / (X - j) per table entry, each
    // realized via a single fused constraint
    // (X - table_value) × quotient = multiplicity.
    let mut summands: Vec<(FieldElement, usize)> = Vec::new();
    for table_value in 0..(1 << num_bits) {
        let multiplicity_witness = multiplicities_first_witness + table_value;
        let quotient = add_range_table_entry_quotient(
            r1cs_compiler,
            sz_challenge,
            table_value as u64,
            multiplicity_witness,
        );
        summands.push((FieldElement::one(), quotient));
    }

    // Witness side: subtract 1/(X - w_i) for every looked-up witness.
    for &value in values_to_lookup {
        let inverse = add_lookup_factor(r1cs_compiler, sz_challenge, FieldElement::one(), value);
        summands.push((FieldElement::one().neg(), inverse));
    }

    // Constraint: (Σ table_terms - Σ witness_terms) * 1 = 0
    r1cs_compiler.r1cs.add_constraint(
        &summands,
        &[(FieldElement::one(), r1cs_compiler.witness_one())],
        &[(FieldElement::zero(), r1cs_compiler.witness_one())],
    );
}

/// Computes the inverse of the LogUp denominator — 1/(X - t_j) for table
/// values or 1/(X - w_i) for witness values — and proves it with a single
/// fused constraint.
///
/// Returns the witness index holding the inverse.
pub(crate) fn add_lookup_factor(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    sz_challenge: usize,
    value_coeff: FieldElement,
    value_witness: usize,
) -> usize {
    // Witness carrying 1/(X - c·v), computed directly via LogUpInverse.
    let inverse_builder = WitnessBuilder::LogUpInverse(
        r1cs_compiler.num_witnesses(),
        sz_challenge,
        WitnessCoefficient(value_coeff, value_witness),
    );
    let inverse = r1cs_compiler.add_witness_builder(inverse_builder);

    // One fused constraint verifies the inverse: (X - c·v) * inverse = 1.
    let denominator = [
        (FieldElement::one(), sz_challenge),
        (value_coeff.neg(), value_witness),
    ];
    r1cs_compiler.r1cs.add_constraint(
        &denominator,
        &[(FieldElement::one(), inverse)],
        &[(FieldElement::one(), r1cs_compiler.witness_one())],
    );

    inverse
}

/// A naive range check, enforcing $\prod_{i = 0}^{2^{num\_bits} - 1}(a - i) = 0$
/// for the witness $a$ at `index_witness`, which holds exactly when
/// $a \in [0, 2^{num\_bits})$.
///
/// Builds the product one factor at a time, adding one intermediate product
/// witness and one constraint per factor, then checks the final factor
/// against zero.
fn add_naive_range_check(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    num_bits: u32,
    index_witness: usize,
) {
    // Largest value in the range; its factor is folded into the final
    // zero-check rather than materialized as a product witness.
    let top = (1 << num_bits) - 1_u32;

    // Accumulate a·(a-1)·…·(a-i), starting from the factor (a - 0) = a.
    let mut running_product = index_witness;
    for i in 1..top {
        // Fresh witness for running_product · (a - i).
        let next_product = r1cs_compiler.add_witness_builder(WitnessBuilder::ProductLinearOperation(
            r1cs_compiler.num_witnesses(),
            ProductLinearTerm(
                running_product,
                FieldElement::one(),
                FieldElement::zero(),
            ),
            ProductLinearTerm(
                index_witness,
                FieldElement::one(),
                FieldElement::from(i).neg(),
            ),
        ));
        // Constrain: running_product × (a - i) = next_product.
        r1cs_compiler.r1cs.add_constraint(
            &[(FieldElement::one(), running_product)],
            &[
                (FieldElement::one(), index_witness),
                (FieldElement::from(i).neg(), r1cs_compiler.witness_one()),
            ],
            &[(FieldElement::one(), next_product)],
        );
        running_product = next_product;
    }

    // Final factor checked directly: running_product × (a - top) = 0.
    r1cs_compiler.r1cs.add_constraint(
        &[(FieldElement::one(), running_product)],
        &[
            (FieldElement::one(), index_witness),
            (FieldElement::from(top).neg(), r1cs_compiler.witness_one()),
        ],
        &[(FieldElement::zero(), r1cs_compiler.witness_one())],
    );
}

/// Computes quotient = multiplicity / (X - table_value) using a single R1CS
/// constraint: (X - table_value) × quotient = multiplicity.
///
/// Internally creates an inverse witness (so it can join the batch
/// inversion) and a product witness (inverse × multiplicity), but emits
/// only one constraint instead of the usual two (inverse constraint plus
/// product constraint). Returns the quotient witness index.
fn add_range_table_entry_quotient(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    sz_challenge: usize,
    table_value: u64,
    multiplicity_witness: usize,
) -> usize {
    // Step 1: inverse witness 1/(X - table_value), created so the solver
    // can fold it into batch inversion with the other LogUp inverses.
    let inverse_witness = r1cs_compiler.add_witness_builder(WitnessBuilder::LogUpInverse(
        r1cs_compiler.num_witnesses(),
        sz_challenge,
        WitnessCoefficient(FieldElement::from(table_value), r1cs_compiler.witness_one()),
    ));

    // Step 2: quotient witness = multiplicity × inverse. Deliberately a
    // bare witness builder (not add_product) so no extra constraint is
    // emitted for the product itself.
    let quotient = r1cs_compiler.add_witness_builder(WitnessBuilder::Product(
        r1cs_compiler.num_witnesses(),
        multiplicity_witness,
        inverse_witness,
    ));

    // Step 3: the single fused constraint replacing the usual pair
    // ((X - t)·inv = 1 and inv·mult = quotient):
    // (X - table_value) × quotient = multiplicity.
    let denominator = [
        (FieldElement::one(), sz_challenge),
        (
            FieldElement::from(table_value).neg(),
            r1cs_compiler.witness_one(),
        ),
    ];
    r1cs_compiler.r1cs.add_constraint(
        &denominator,
        &[(FieldElement::one(), quotient)],
        &[(FieldElement::one(), multiplicity_witness)],
    );

    quotient
}

#[cfg(test)]
mod tests {
    //! Unit tests for the witness-count cost model that drives range-check
    //! strategy selection (LogUp lookup vs. naive product checks).
    use super::*;

    /// Verifies that bucket_cost returns 0 for edge cases where no work is
    /// needed.
    #[test]
    fn bucket_cost_zero_cases() {
        assert_eq!(bucket_cost(0, 100), 0);
        assert_eq!(bucket_cost(5, 0), 0);
    }

    /// Verifies the overflow guard returns usize::MAX for impractically large
    /// bit widths (>= usize::BITS - 1, which is 63 on 64-bit systems).
    #[test]
    fn bucket_cost_overflow_guard() {
        assert_eq!(bucket_cost(63, 1), usize::MAX);
    }

    /// Verifies should_use_logup witness-based decision logic.
    #[test]
    fn should_use_logup_decision() {
        // 1-bit, 1 witness: naive=1×(2-2)=0, logup=3×2+1+1=8 → naive wins
        assert!(!should_use_logup(1, 1));
        // 8-bit, 5 witnesses: naive=5×254=1270, logup=3×256+5+1=774 → logup
        assert!(should_use_logup(8, 5));
        // 8-bit, 1 witness: naive=1×254=254, logup=3×256+1+1=770 → naive
        assert!(!should_use_logup(8, 1));
        // 8-bit, 256 witnesses: naive=256×254=65024, logup=3×256+256+1=1025
        assert!(should_use_logup(8, 256));
    }
}