p3_lookup/logup.rs
1//! Core LogUp Implementation
2//!
3//! ## Mathematical Foundation
4//!
5//! LogUp transforms the standard lookup equation:
6//! ```text
7//! ∏(α - a_i)^(m_i) = ∏(α - b_j)^(m'_j)
8//! ```
9//!
10//! Into an equivalent sum-based form using logarithmic derivatives:
11//! ```text
12//! ∑(m_i/(α - a_i)) = ∑(m'_j/(α - b_j))
13//! ```
14//!
15//! Where:
16//! - `α` is a random challenge
17//! - `m_i, m'_j` are multiplicities (how many times each element appears)
18//! - The transformation eliminates expensive exponentiation operations
19
20use alloc::vec;
21use alloc::vec::Vec;
22
23use p3_air::{ExtensionBuilder, PermutationAirBuilder, WindowAccess};
24use p3_field::{Field, PrimeCharacteristicRing};
25use p3_matrix::Matrix;
26use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
27use p3_matrix::stack::VerticalPair;
28use p3_maybe_rayon::prelude::*;
29use p3_uni_stark::{StarkGenericConfig, Val};
30use tracing::instrument;
31
32use crate::lookup_traits::{
33 Kind, Lookup, LookupData, LookupGadget, LookupTraceBuilder, symbolic_to_expr,
34};
35use crate::types::{LookupError, LookupEvaluator};
36
/// LogUp lookup gadget: proves lookup relations through logarithmic derivatives.
///
/// The multiplicative lookup identity
/// ```text
/// ∏(α - a_i)^(m_i) = ∏(α - b_j)^(m'_j)
/// ```
///
/// is replaced by its additive counterpart, obtained by logarithmic differentiation:
/// ```text
/// ∑(m_i/(α - a_i)) = ∑(m'_j/(α - b_j))
/// ```
///
/// The argument is carried by a single auxiliary running-sum column `s`:
/// ```text
/// s[i+1] = s[i] + ∑(m_a/(α - a)) - ∑(m_b/(α - b))
/// ```
///
/// The implementation makes no distinction between the `a` and `b` sides:
/// it works on one list of `elements` whose `multiplicities` may be negative.
///
/// Constraints:
/// - **Initial**: `s[0] = 0`
/// - **Transition**: `s[i+1] = s[i] + contribution[i]`
/// - **Final**: `s[n-1] + contribution[n-1] = 0` for local lookups; for global
///   lookups the right-hand side is the expected cumulated value instead of `0`.
#[derive(Debug, Clone, Default)]
pub struct LogUpGadget;
63
64impl LogUpGadget {
65 /// Creates a new LogUp gadget instance.
66 pub const fn new() -> Self {
67 Self {}
68 }
69
70 /// Computes the combined elements for each tuple using the challenge `beta`:
71 /// `combined_elements[i] = ∑elements[i][n-j] * β^j`
72 fn combine_elements<AB, E>(
73 &self,
74 elements: &[Vec<E>],
75 alpha: &AB::ExprEF,
76 beta: &AB::ExprEF,
77 ) -> Vec<AB::ExprEF>
78 where
79 AB: PermutationAirBuilder,
80 E: Into<AB::ExprEF> + Clone,
81 {
82 elements
83 .iter()
84 .map(|elts| {
85 // Combine the elements in the tuple using beta.
86 let combined_elt = elts.iter().fold(AB::ExprEF::ZERO, |acc, elt| {
87 elt.clone().into() + acc * beta.clone()
88 });
89
90 // Compute (α - combined_elt)
91 alpha.clone() - combined_elt
92 })
93 .collect()
94 }
95
96 /// Computes the numerator and denominator of the fraction:
97 /// `∑(m_i / (α - combined_elements[i]))`, where
98 /// `combined_elements[i] = ∑elements[i][n-j] * β^j
99 pub(crate) fn compute_combined_sum_terms<AB, E, M>(
100 &self,
101 elements: &[Vec<E>],
102 multiplicities: &[M],
103 alpha: &AB::ExprEF,
104 beta: &AB::ExprEF,
105 ) -> (AB::ExprEF, AB::ExprEF)
106 where
107 AB: PermutationAirBuilder,
108 E: Into<AB::ExprEF> + Clone,
109 M: Into<AB::ExprEF> + Clone,
110 {
111 if elements.is_empty() {
112 return (AB::ExprEF::ZERO, AB::ExprEF::ONE);
113 }
114
115 let n = elements.len();
116
117 // Precompute all (α - ∑e_{i, j} β^j) terms
118 let terms = self.combine_elements::<AB, E>(elements, alpha, beta);
119
120 // Build prefix products: pref[i] = ∏_{j=0}^{i-1}(α - e_j)
121 let mut pref = Vec::with_capacity(n + 1);
122 pref.push(AB::ExprEF::ONE);
123 for t in &terms {
124 pref.push(pref.last().unwrap().clone() * t.clone());
125 }
126
127 // Build suffix products: suff[i] = ∏_{j=i}^{n-1}(α - e_j)
128 let mut suff = vec![AB::ExprEF::ONE; n + 1];
129 for i in (0..n).rev() {
130 suff[i] = suff[i + 1].clone() * terms[i].clone();
131 }
132
133 // Common denominator is the product of all terms
134 let common_denominator = pref[n].clone();
135
136 // Compute numerator: ∑(m_i * ∏_{j≠i}(α - e_j))
137 //
138 // The product without i is: pref[i] * suff[i+1]
139 let numerator = (0..n).fold(AB::ExprEF::ZERO, |acc, i| {
140 acc + multiplicities[i].clone().into() * pref[i].clone() * suff[i + 1].clone()
141 });
142
143 (numerator, common_denominator)
144 }
145
146 /// Evaluates the transition and boundary constraints for a lookup argument.
147 ///
148 /// # Arguments:
149 /// * builder - The AIR builder to construct expressions.
150 /// * context - The lookup context containing:
151 /// * the kind of lookup (local or global),
152 /// * elements,
153 /// * multiplicities,
154 /// * and auxiliary column indices.
155 /// * opt_expected_cumulated - Optional expected cumulative value for global lookups. For local lookups, this should be `None`.
156 fn eval_update<AB>(
157 &self,
158 builder: &mut AB,
159 context: &Lookup<AB::F>,
160 opt_expected_cumulated: Option<AB::ExprEF>,
161 ) where
162 AB: PermutationAirBuilder,
163 {
164 let Lookup {
165 kind,
166 element_exprs,
167 multiplicities_exprs,
168 columns,
169 } = context;
170
171 assert!(
172 element_exprs.len() == multiplicities_exprs.len(),
173 "Mismatched lengths: elements and multiplicities must have same length"
174 );
175 assert_eq!(
176 columns.len(),
177 self.num_aux_cols(),
178 "There is exactly one auxiliary column for LogUp"
179 );
180 let column = columns[0];
181
182 // First, turn the symbolic expressions into builder expressions, for elements and multiplicities.
183 let elements = element_exprs
184 .iter()
185 .map(|exprs| {
186 exprs
187 .iter()
188 .map(|expr| symbolic_to_expr(builder, expr).into())
189 .collect::<Vec<_>>()
190 })
191 .collect::<Vec<_>>();
192
193 let multiplicities = multiplicities_exprs
194 .iter()
195 .map(|expr| symbolic_to_expr(builder, expr).into())
196 .collect::<Vec<_>>();
197
198 // Access the permutation (aux) table. It carries the running sum column `s`.
199 let permutation = builder.permutation();
200
201 let permutation_challenges = builder.permutation_randomness();
202
203 assert!(
204 permutation_challenges.len() >= self.num_challenges() * (column + 1),
205 "Insufficient permutation challenges"
206 );
207
208 // Challenge for the running sum.
209 let alpha = permutation_challenges[self.num_challenges() * column];
210 // Challenge for combining the lookup tuples.
211 let beta = permutation_challenges[self.num_challenges() * column + 1];
212
213 assert!(
214 permutation.current_slice().len() > column,
215 "Permutation trace has insufficient width"
216 );
217
218 // Read s[i] from the local row at the specified column.
219 let s_local = permutation.current(column).unwrap().into();
220 // Read s[i+1] from the next row (or a zero-padded view on the last row).
221 let s_next = permutation.next(column).unwrap().into();
222
223 // Anchor s[0] = 0 at the start.
224 //
225 // Avoids a high-degree boundary constraint.
226 // Telescoping is enforced by the last-row check (s[n−1] + contribution[n-1] = 0).
227 // This keeps aux and main traces aligned in length.
228 builder.when_first_row().assert_zero_ext(s_local.clone());
229
230 // Build the fraction: ∑ m_i/(α - combined_elements[i]) = numerator / denominator .
231 let (numerator, common_denominator) = self
232 .compute_combined_sum_terms::<AB, AB::ExprEF, AB::ExprEF>(
233 &elements,
234 &multiplicities,
235 &alpha.into(),
236 &beta.into(),
237 );
238
239 if let Some(expected_cumulated) = opt_expected_cumulated {
240 // If there is an `expected_cumulated`, we are in a global lookup update.
241 assert!(
242 matches!(kind, Kind::Global(_)),
243 "Expected cumulated value provided for a non-global lookup"
244 );
245
246 // Transition constraint:
247 builder.when_transition().assert_zero_ext(
248 (s_next - s_local.clone()) * common_denominator.clone() - numerator.clone(),
249 );
250
251 // Final constraint:
252 let final_val = (expected_cumulated - s_local) * common_denominator - numerator;
253 builder.when_last_row().assert_zero_ext(final_val);
254 } else {
255 // If we don't have an `expected_cumulated`, we are in a local lookup update.
256 assert!(
257 matches!(kind, Kind::Local),
258 "No expected cumulated value provided for a global lookup"
259 );
260
261 // If we are in a local lookup, the previous transition constraint doesn't have to be limited to transition rows:
262 // - we are already ensuring that the first row is 0,
263 // - at point `g^{n - 1}` (where `n` is the domain size), the next point is `g^0`, so that the constraint still holds
264 // on the last row.
265 builder.assert_zero_ext((s_next - s_local) * common_denominator - numerator);
266 }
267 }
268}
269
270impl LookupEvaluator for LogUpGadget {
271 fn num_aux_cols(&self) -> usize {
272 1
273 }
274
275 fn num_challenges(&self) -> usize {
276 2
277 }
278
279 /// # Mathematical Details
280 /// The constraint enforces:
281 /// ```text
282 /// ∑_i(multiplicities[i] / (α - combined_elements[i])) = 0
283 /// ```
284 ///
285 /// where `multiplicities` can be negative, and
286 /// `combined_elements[i] = ∑elements[i][n-j] * β^j`.
287 ///
288 /// This is implemented using a running sum column that should sum to zero.
289 fn eval_local_lookup<AB>(&self, builder: &mut AB, context: &Lookup<AB::F>)
290 where
291 AB: PermutationAirBuilder,
292 {
293 if let Kind::Global(_) = context.kind {
294 panic!("Global lookups are not supported in local evaluation")
295 }
296
297 self.eval_update(builder, context, None);
298 }
299
300 /// # Mathematical Details
301 /// The constraint enforces:
302 /// ```text
303 /// ∑_i(multiplicities[i] / (α - combined_elements[i])) = `expected_cumulated`
304 /// ```
305 ///
306 /// where `multiplicities` can be negative, and
307 /// `combined_elements[i] = ∑elements[i][n-j] * β^j`.
308 ///
309 /// `expected_cumulated` is provided by the prover, and the sum of all `expected_cumulated` for this global interaction
310 /// should be 0. The latter is checked as the final step, after all AIRS have been verified.
311 ///
312 /// This is implemented using a running sum column that should sum to `expected_cumulated`.
313 fn eval_global_update<AB>(
314 &self,
315 builder: &mut AB,
316 context: &Lookup<AB::F>,
317 expected_cumulated: AB::ExprEF,
318 ) where
319 AB: PermutationAirBuilder,
320 {
321 self.eval_update(builder, context, Some(expected_cumulated));
322 }
323}
324
325impl LookupGadget for LogUpGadget {
326 fn verify_global_final_value<EF: Field>(
327 &self,
328 all_expected_cumulative: &[EF],
329 ) -> Result<(), LookupError> {
330 let total = all_expected_cumulative.iter().cloned().sum::<EF>();
331
332 if !total.is_zero() {
333 // We set the name associated to the lookup to None because we don't have access to the actual name here.
334 // The actual name will be set in the verifier directly.
335 return Err(LookupError::GlobalCumulativeMismatch(None));
336 }
337
338 Ok(())
339 }
340
341 /// We need to compute the degree of the transition constraint,
342 /// as it is the constraint with highest degree:
343 /// `(s[n + 1] - s[n]) * common_denominator - numerator = 0`
344 ///
345 /// But in `common_denominator`, each combined element e_i = ∑e_{i, j} β^j
346 /// contributes (α - e_i). So we need to sum the degree of all
347 /// combined elements to find the degree of the common denominator.
348 ///
349 /// `numerator = ∑(m_i * ∏_{j≠i}(α - e_j))`, where the e_j are the combined elements.
350 /// So we have to compute the max of all m_i * ∏_{j≠i}(α - e_j).
351 ///
352 /// The constraint degree is then:
353 /// `1 + max(deg(numerator), deg(common_denominator))`
354 fn constraint_degree<F: Field>(&self, context: &Lookup<F>) -> usize {
355 assert!(context.multiplicities_exprs.len() == context.element_exprs.len());
356
357 let n = context.multiplicities_exprs.len();
358
359 // Compute degrees in a single pass.
360 let mut degs = Vec::with_capacity(n);
361 let mut deg_sum = 0;
362 for elems in &context.element_exprs {
363 let deg = elems
364 .iter()
365 .map(|elt| elt.degree_multiple())
366 .max()
367 .unwrap_or(0);
368 degs.push(deg);
369 deg_sum += deg;
370 }
371
372 // Compute 1 + degree(denominator).
373 let deg_denom_constr = 1 + deg_sum;
374
375 // Compute degree(numerator).
376 let multiplicities = &context.multiplicities_exprs;
377 let deg_num = (0..n)
378 .map(|i| multiplicities[i].degree_multiple() + deg_sum - degs[i])
379 .max()
380 .unwrap_or(0);
381
382 deg_denom_constr.max(deg_num)
383 }
384
385 #[instrument(name = "generate lookup permutation", skip_all, level = "debug")]
386 fn generate_permutation<SC: StarkGenericConfig>(
387 &self,
388 main: &RowMajorMatrix<Val<SC>>,
389 preprocessed: &Option<RowMajorMatrix<Val<SC>>>,
390 public_values: &[Val<SC>],
391 lookups: &[Lookup<Val<SC>>],
392 lookup_data: &mut [LookupData<SC::Challenge>],
393 permutation_challenges: &[SC::Challenge],
394 ) -> RowMajorMatrix<SC::Challenge> {
395 let height = main.height();
396 let width = self.num_aux_cols() * lookups.len();
397
398 // Validate challenge count matches number of lookups.
399 debug_assert_eq!(
400 permutation_challenges.len(),
401 lookups.len() * self.num_challenges(),
402 "perm challenge count must be per-lookup"
403 );
404
405 // Enforce uniqueness of auxiliary column indices across lookups.
406 #[cfg(debug_assertions)]
407 {
408 use alloc::collections::btree_set::BTreeSet;
409
410 let mut seen = BTreeSet::new();
411 for ctx in lookups {
412 let a = ctx.columns[0];
413 if !seen.insert(a) {
414 panic!("duplicate aux column index {a} across lookups");
415 }
416 }
417 }
418
419 // 1. PRE-COMPUTE DENOMINATORS
420 // We flatten all denominators from all rows/lookups into one giant vector.
421 // Order: Row -> Lookup -> Element Tuple
422 let denoms_per_row: usize = lookups.iter().map(|l| l.element_exprs.len()).sum();
423 let mut lookup_denom_offsets = Vec::with_capacity(lookups.len() + 1);
424 lookup_denom_offsets.push(0);
425 for l in lookups.iter() {
426 lookup_denom_offsets
427 .push(lookup_denom_offsets.last().copied().unwrap() + l.element_exprs.len());
428 }
429 let num_lookups = lookups.len();
430
431 let mut all_denominators = vec![SC::Challenge::ZERO; height * denoms_per_row];
432 let mut all_multiplicities = vec![Val::<SC>::ZERO; height * denoms_per_row];
433
434 all_denominators
435 .par_chunks_mut(denoms_per_row)
436 .zip(all_multiplicities.par_chunks_mut(denoms_per_row))
437 .enumerate()
438 .for_each(|(i, (denom_row, mult_row))| {
439 let local_main_row = main.row_slice(i).unwrap();
440 let next_main_row = main.row_slice((i + 1) % height).unwrap();
441 let main_rows = VerticalPair::new(
442 RowMajorMatrixView::new_row(&local_main_row),
443 RowMajorMatrixView::new_row(&next_main_row),
444 );
445 let preprocessed_rows_data = preprocessed.as_ref().map(|prep| {
446 (
447 prep.row_slice(i).unwrap(),
448 prep.row_slice((i + 1) % height).unwrap(),
449 )
450 });
451 let preprocessed_rows = match preprocessed_rows_data.as_ref() {
452 Some((local_preprocessed_row, next_preprocessed_row)) => VerticalPair::new(
453 RowMajorMatrixView::new_row(local_preprocessed_row),
454 RowMajorMatrixView::new_row(next_preprocessed_row),
455 ),
456 None => VerticalPair::new(
457 RowMajorMatrixView::new(&[], 0),
458 RowMajorMatrixView::new(&[], 0),
459 ),
460 };
461
462 let row_builder: LookupTraceBuilder<'_, SC> = LookupTraceBuilder::new(
463 main_rows,
464 preprocessed_rows,
465 public_values,
466 permutation_challenges,
467 height,
468 i,
469 );
470
471 let mut offset = 0;
472 for context in lookups.iter() {
473 let alpha = permutation_challenges[self.num_challenges() * context.columns[0]];
474 let beta =
475 permutation_challenges[self.num_challenges() * context.columns[0] + 1];
476
477 // Evaluate each tuple's elements and combine them via Horner's method
478 // in a single pass. This avoids allocating a temporary vector of
479 // evaluated elements per tuple, then another vector of combined results.
480 //
481 // For a tuple (e_0, e_1, …, e_{k-1}), computes:
482 //
483 // combined = e_0 + e_1·β + e_2·β^2 + … + e_{k-1}·β^{k-1}
484 //
485 // Then stores (α − combined) as the denominator.
486 for (j, elts) in context.element_exprs.iter().enumerate() {
487 let combined_elt = elts.iter().fold(SC::Challenge::ZERO, |acc, e| {
488 acc * beta + symbolic_to_expr(&row_builder, e)
489 });
490 denom_row[offset] = alpha - combined_elt;
491 mult_row[offset] =
492 symbolic_to_expr(&row_builder, &context.multiplicities_exprs[j]);
493 offset += 1;
494 }
495 }
496 });
497
498 debug_assert_eq!(all_denominators.len(), height * denoms_per_row);
499
500 // 2. BATCH INVERSION
501 // This turns O(N) inversions into O(1) inversion + O(N) multiplications.
502 // Recomputing multiplicities during trace building is cheaper than recomputing inversions,
503 // or storing them beforehand (as they could possibly constitute quite a large amount of data).
504 let all_inverses = p3_field::batch_multiplicative_inverse(&all_denominators);
505
506 #[cfg(debug_assertions)]
507 let mut inv_cursor = 0;
508 #[cfg(debug_assertions)]
509 let _debug_check: Vec<_> = (0..height)
510 .map(|_| {
511 lookups.iter().for_each(|context| {
512 inv_cursor += context.multiplicities_exprs.len();
513 });
514 })
515 .collect();
516
517 // 3. BUILD TRACE
518 let mut row_sums = SC::Challenge::zero_vec(height * num_lookups);
519 row_sums
520 .par_chunks_mut(num_lookups)
521 .enumerate()
522 .for_each(|(i, row_sums_i)| {
523 let inv_base = i * denoms_per_row;
524 for (lookup_idx, _context) in lookups.iter().enumerate() {
525 let start = lookup_denom_offsets[lookup_idx];
526 let end = lookup_denom_offsets[lookup_idx + 1];
527 let sum = (start..end)
528 .map(|k| all_inverses[inv_base + k] * all_multiplicities[inv_base + k])
529 .sum();
530 row_sums_i[lookup_idx] = sum;
531 }
532 });
533
534 let mut aux_trace = SC::Challenge::zero_vec(height * width);
535 let mut permutation_counter = 0;
536
537 // Each lookup column gets its own running sum.
538 // Since these columns are independent, we build them one at a time.
539 //
540 // The running sum is an *exclusive* prefix sum of the per-row contributions:
541 //
542 // s[0] = 0
543 // s[i] = row_sum[0] + row_sum[1] + … + row_sum[i-1]
544 //
545 // A naive serial loop would be O(height). Instead we use a three-phase
546 // parallel prefix sum, splitting the work across threads:
547 //
548 // Phase A — Each thread computes a local prefix sum on its chunk.
549 // Phase B — A tiny sequential pass (one entry per thread) combines
550 // the chunk totals into global offsets.
551 // Phase C — Each thread adds its global offset back into its chunk.
552 //
553 // After the three phases, we have an *inclusive* prefix sum.
554 // Shifting by one position turns it into the exclusive sum we need.
555 let num_threads = current_num_threads();
556 let chunk_size = height.div_ceil(num_threads);
557
558 // Reuse a single buffer across all lookup columns to avoid re-allocating on every iteration.
559 let mut prefix = SC::Challenge::zero_vec(height);
560
561 for (lookup_idx, context) in lookups.iter().enumerate() {
562 let aux_idx = context.columns[0];
563
564 // Fill the buffer with this column's per-row contributions.
565 for (i, val) in prefix.iter_mut().enumerate() {
566 *val = row_sums[i * num_lookups + lookup_idx];
567 }
568
569 // Phase A — Local inclusive prefix sums, one chunk per thread.
570 prefix.par_chunks_mut(chunk_size).for_each(|chunk| {
571 for i in 1..chunk.len() {
572 chunk[i] += chunk[i - 1];
573 }
574 });
575
576 // Phase B — Combine chunk totals into cumulative offsets.
577 // Only as many entries as there are chunks (one per thread), so this is tiny.
578 let mut offsets = SC::Challenge::zero_vec(height.div_ceil(chunk_size));
579 for i in 1..offsets.len() {
580 offsets[i] = offsets[i - 1] + prefix[i * chunk_size - 1];
581 }
582
583 // Phase C — Fold global offsets back into each chunk.
584 prefix
585 .par_chunks_mut(chunk_size)
586 .enumerate()
587 .for_each(|(chunk_idx, chunk)| {
588 let offset = offsets[chunk_idx];
589 if !offset.is_zero() {
590 for val in chunk.iter_mut() {
591 *val += offset;
592 }
593 }
594 });
595
596 // At this point we hold an *inclusive* prefix sum.
597 //
598 // The auxiliary trace needs the *exclusive* version (shifted right by one, starting at zero).
599 //
600 // - Row 0 is already zero from initialization;
601 // - Each subsequent row gets the inclusive sum of all *previous* rows.
602 aux_trace
603 .par_chunks_mut(width)
604 .skip(1)
605 .enumerate()
606 .for_each(|(i, row)| {
607 row[aux_idx] = prefix[i];
608 });
609
610 // For global lookups, record the total sum across all rows.
611 if matches!(context.kind, Kind::Global(_)) {
612 lookup_data[permutation_counter].expected_cumulated = prefix[height - 1];
613 permutation_counter += 1;
614 }
615 }
616
617 // Check that we have updated all `lookup_data` entries.
618 debug_assert_eq!(permutation_counter, lookup_data.len());
619 #[cfg(debug_assertions)] // Compiler complains about inv_cursor despite being under a `debug_assert`
620 debug_assert_eq!(inv_cursor, all_inverses.len());
621 RowMajorMatrix::new(aux_trace, width)
622 }
623}