Skip to main content

p3_lookup/
logup.rs

//! Core LogUp Implementation
//!
//! ## Mathematical Foundation
//!
//! LogUp transforms the standard lookup equation:
//! ```text
//! ∏(α - a_i)^(m_i) = ∏(α - b_j)^(m'_j)
//! ```
//!
//! into an equivalent sum-based form using logarithmic derivatives:
//! ```text
//! ∑(m_i/(α - a_i)) = ∑(m'_j/(α - b_j))
//! ```
//!
//! where:
//! - `α` is a random challenge;
//! - `m_i, m'_j` are multiplicities (how many times each element appears).
//!
//! The sum form eliminates the expensive exponentiation by the multiplicities.
19
20use alloc::vec;
21use alloc::vec::Vec;
22
23use p3_air::{ExtensionBuilder, PermutationAirBuilder, WindowAccess};
24use p3_field::{Field, PrimeCharacteristicRing};
25use p3_matrix::Matrix;
26use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
27use p3_matrix::stack::VerticalPair;
28use p3_maybe_rayon::prelude::*;
29use p3_uni_stark::{StarkGenericConfig, Val};
30use tracing::instrument;
31
32use crate::lookup_traits::{
33    Kind, Lookup, LookupData, LookupGadget, LookupTraceBuilder, symbolic_to_expr,
34};
35use crate::types::{LookupError, LookupEvaluator};
36
/// Core LogUp gadget implementing lookup arguments via logarithmic derivatives.
///
/// The LogUp gadget transforms the multiplicative lookup constraint:
/// ```text
/// ∏(α - a_i)^(m_i) = ∏(α - b_j)^(m'_j)
/// ```
///
/// into an equivalent additive constraint using logarithmic differentiation:
/// ```text
/// ∑(m_i/(α - a_i)) = ∑(m'_j/(α - b_j))
/// ```
///
/// This is implemented with a running-sum auxiliary column `s` that accumulates:
/// ```text
/// s[i+1] = s[i] + ∑(m_a/(α - a)) - ∑(m_b/(α - b))
/// ```
///
/// The implementation does not distinguish between `a` and `b`: a lookup is simply a
/// list of `elements` with possibly negative `multiplicities`. When an element is a
/// tuple, its entries are first folded into a single value with a second challenge `β`.
///
/// Constraints, for a trace of height `n`:
/// - **Initial constraint**: `s[0] = 0`
/// - **Transition constraint**: `s[i+1] = s[i] + contribution[i]`
/// - **Final constraint**:
///   - local lookups: `s[n-1] + contribution[n-1] = 0`;
///   - global lookups: `s[n-1] + contribution[n-1] = expected_cumulated`. The
///     `expected_cumulated` values of all AIRs taking part in the same global
///     interaction must then sum to zero.
#[derive(Debug, Clone, Default)]
pub struct LogUpGadget;
63
64impl LogUpGadget {
65    /// Creates a new LogUp gadget instance.
66    pub const fn new() -> Self {
67        Self {}
68    }
69
70    /// Computes the combined elements for each tuple using the challenge `beta`:
71    /// `combined_elements[i] = ∑elements[i][n-j] * β^j`
72    fn combine_elements<AB, E>(
73        &self,
74        elements: &[Vec<E>],
75        alpha: &AB::ExprEF,
76        beta: &AB::ExprEF,
77    ) -> Vec<AB::ExprEF>
78    where
79        AB: PermutationAirBuilder,
80        E: Into<AB::ExprEF> + Clone,
81    {
82        elements
83            .iter()
84            .map(|elts| {
85                // Combine the elements in the tuple using beta.
86                let combined_elt = elts.iter().fold(AB::ExprEF::ZERO, |acc, elt| {
87                    elt.clone().into() + acc * beta.clone()
88                });
89
90                // Compute (α - combined_elt)
91                alpha.clone() - combined_elt
92            })
93            .collect()
94    }
95
96    /// Computes the numerator and denominator of the fraction:
97    /// `∑(m_i / (α - combined_elements[i]))`, where
98    /// `combined_elements[i] = ∑elements[i][n-j] * β^j
99    pub(crate) fn compute_combined_sum_terms<AB, E, M>(
100        &self,
101        elements: &[Vec<E>],
102        multiplicities: &[M],
103        alpha: &AB::ExprEF,
104        beta: &AB::ExprEF,
105    ) -> (AB::ExprEF, AB::ExprEF)
106    where
107        AB: PermutationAirBuilder,
108        E: Into<AB::ExprEF> + Clone,
109        M: Into<AB::ExprEF> + Clone,
110    {
111        if elements.is_empty() {
112            return (AB::ExprEF::ZERO, AB::ExprEF::ONE);
113        }
114
115        let n = elements.len();
116
117        // Precompute all (α - ∑e_{i, j} β^j) terms
118        let terms = self.combine_elements::<AB, E>(elements, alpha, beta);
119
120        // Build prefix products: pref[i] = ∏_{j=0}^{i-1}(α - e_j)
121        let mut pref = Vec::with_capacity(n + 1);
122        pref.push(AB::ExprEF::ONE);
123        for t in &terms {
124            pref.push(pref.last().unwrap().clone() * t.clone());
125        }
126
127        // Build suffix products: suff[i] = ∏_{j=i}^{n-1}(α - e_j)
128        let mut suff = vec![AB::ExprEF::ONE; n + 1];
129        for i in (0..n).rev() {
130            suff[i] = suff[i + 1].clone() * terms[i].clone();
131        }
132
133        // Common denominator is the product of all terms
134        let common_denominator = pref[n].clone();
135
136        // Compute numerator: ∑(m_i * ∏_{j≠i}(α - e_j))
137        //
138        // The product without i is: pref[i] * suff[i+1]
139        let numerator = (0..n).fold(AB::ExprEF::ZERO, |acc, i| {
140            acc + multiplicities[i].clone().into() * pref[i].clone() * suff[i + 1].clone()
141        });
142
143        (numerator, common_denominator)
144    }
145
146    /// Evaluates the transition and boundary constraints for a lookup argument.
147    ///
148    /// # Arguments:
149    /// * builder - The AIR builder to construct expressions.
150    /// * context - The lookup context containing:
151    ///     * the kind of lookup (local or global),
152    ///     * elements,
153    ///     * multiplicities,
154    ///     * and auxiliary column indices.
155    /// * opt_expected_cumulated - Optional expected cumulative value for global lookups. For local lookups, this should be `None`.
156    fn eval_update<AB>(
157        &self,
158        builder: &mut AB,
159        context: &Lookup<AB::F>,
160        opt_expected_cumulated: Option<AB::ExprEF>,
161    ) where
162        AB: PermutationAirBuilder,
163    {
164        let Lookup {
165            kind,
166            element_exprs,
167            multiplicities_exprs,
168            columns,
169        } = context;
170
171        assert!(
172            element_exprs.len() == multiplicities_exprs.len(),
173            "Mismatched lengths: elements and multiplicities must have same length"
174        );
175        assert_eq!(
176            columns.len(),
177            self.num_aux_cols(),
178            "There is exactly one auxiliary column for LogUp"
179        );
180        let column = columns[0];
181
182        // First, turn the symbolic expressions into builder expressions, for elements and multiplicities.
183        let elements = element_exprs
184            .iter()
185            .map(|exprs| {
186                exprs
187                    .iter()
188                    .map(|expr| symbolic_to_expr(builder, expr).into())
189                    .collect::<Vec<_>>()
190            })
191            .collect::<Vec<_>>();
192
193        let multiplicities = multiplicities_exprs
194            .iter()
195            .map(|expr| symbolic_to_expr(builder, expr).into())
196            .collect::<Vec<_>>();
197
198        // Access the permutation (aux) table. It carries the running sum column `s`.
199        let permutation = builder.permutation();
200
201        let permutation_challenges = builder.permutation_randomness();
202
203        assert!(
204            permutation_challenges.len() >= self.num_challenges() * (column + 1),
205            "Insufficient permutation challenges"
206        );
207
208        // Challenge for the running sum.
209        let alpha = permutation_challenges[self.num_challenges() * column];
210        // Challenge for combining the lookup tuples.
211        let beta = permutation_challenges[self.num_challenges() * column + 1];
212
213        assert!(
214            permutation.current_slice().len() > column,
215            "Permutation trace has insufficient width"
216        );
217
218        // Read s[i] from the local row at the specified column.
219        let s_local = permutation.current(column).unwrap().into();
220        // Read s[i+1] from the next row (or a zero-padded view on the last row).
221        let s_next = permutation.next(column).unwrap().into();
222
223        // Anchor s[0] = 0 at the start.
224        //
225        // Avoids a high-degree boundary constraint.
226        // Telescoping is enforced by the last-row check (s[n−1] + contribution[n-1] = 0).
227        // This keeps aux and main traces aligned in length.
228        builder.when_first_row().assert_zero_ext(s_local.clone());
229
230        // Build the fraction:  ∑ m_i/(α - combined_elements[i])  =  numerator / denominator .
231        let (numerator, common_denominator) = self
232            .compute_combined_sum_terms::<AB, AB::ExprEF, AB::ExprEF>(
233                &elements,
234                &multiplicities,
235                &alpha.into(),
236                &beta.into(),
237            );
238
239        if let Some(expected_cumulated) = opt_expected_cumulated {
240            // If there is an `expected_cumulated`, we are in a global lookup update.
241            assert!(
242                matches!(kind, Kind::Global(_)),
243                "Expected cumulated value provided for a non-global lookup"
244            );
245
246            // Transition constraint:
247            builder.when_transition().assert_zero_ext(
248                (s_next - s_local.clone()) * common_denominator.clone() - numerator.clone(),
249            );
250
251            // Final constraint:
252            let final_val = (expected_cumulated - s_local) * common_denominator - numerator;
253            builder.when_last_row().assert_zero_ext(final_val);
254        } else {
255            // If we don't have an `expected_cumulated`, we are in a local lookup update.
256            assert!(
257                matches!(kind, Kind::Local),
258                "No expected cumulated value provided for a global lookup"
259            );
260
261            // If we are in a local lookup, the previous transition constraint doesn't have to be limited to transition rows:
262            // - we are already ensuring that the first row is 0,
263            // - at point `g^{n - 1}` (where `n` is the domain size), the next point is `g^0`, so that the constraint still holds
264            // on the last row.
265            builder.assert_zero_ext((s_next - s_local) * common_denominator - numerator);
266        }
267    }
268}
269
270impl LookupEvaluator for LogUpGadget {
271    fn num_aux_cols(&self) -> usize {
272        1
273    }
274
275    fn num_challenges(&self) -> usize {
276        2
277    }
278
279    /// # Mathematical Details
280    /// The constraint enforces:
281    /// ```text
282    /// ∑_i(multiplicities[i] / (α - combined_elements[i])) = 0
283    /// ```
284    ///
285    /// where `multiplicities` can be negative, and
286    /// `combined_elements[i] = ∑elements[i][n-j] * β^j`.
287    ///
288    /// This is implemented using a running sum column that should sum to zero.
289    fn eval_local_lookup<AB>(&self, builder: &mut AB, context: &Lookup<AB::F>)
290    where
291        AB: PermutationAirBuilder,
292    {
293        if let Kind::Global(_) = context.kind {
294            panic!("Global lookups are not supported in local evaluation")
295        }
296
297        self.eval_update(builder, context, None);
298    }
299
300    /// # Mathematical Details
301    /// The constraint enforces:
302    /// ```text
303    /// ∑_i(multiplicities[i] / (α - combined_elements[i])) = `expected_cumulated`
304    /// ```
305    ///
306    /// where `multiplicities` can be negative, and
307    /// `combined_elements[i] = ∑elements[i][n-j] * β^j`.
308    ///
309    /// `expected_cumulated` is provided by the prover, and the sum of all `expected_cumulated` for this global interaction
310    /// should be 0. The latter is checked as the final step, after all AIRS have been verified.
311    ///
312    /// This is implemented using a running sum column that should sum to `expected_cumulated`.
313    fn eval_global_update<AB>(
314        &self,
315        builder: &mut AB,
316        context: &Lookup<AB::F>,
317        expected_cumulated: AB::ExprEF,
318    ) where
319        AB: PermutationAirBuilder,
320    {
321        self.eval_update(builder, context, Some(expected_cumulated));
322    }
323}
324
325impl LookupGadget for LogUpGadget {
326    fn verify_global_final_value<EF: Field>(
327        &self,
328        all_expected_cumulative: &[EF],
329    ) -> Result<(), LookupError> {
330        let total = all_expected_cumulative.iter().cloned().sum::<EF>();
331
332        if !total.is_zero() {
333            // We set the name associated to the lookup to None because we don't have access to the actual name here.
334            // The actual name will be set in the verifier directly.
335            return Err(LookupError::GlobalCumulativeMismatch(None));
336        }
337
338        Ok(())
339    }
340
341    /// We need to compute the degree of the transition constraint,
342    /// as it is the constraint with highest degree:
343    /// `(s[n + 1] - s[n]) * common_denominator - numerator = 0`
344    ///
345    /// But in `common_denominator`, each combined element e_i = ∑e_{i, j} β^j
346    /// contributes (α - e_i). So we need to sum the degree of all
347    /// combined elements to find the degree of the common denominator.
348    ///
349    /// `numerator = ∑(m_i * ∏_{j≠i}(α - e_j))`, where the e_j are the combined elements.
350    /// So we have to compute the max of all m_i * ∏_{j≠i}(α - e_j).
351    ///
352    /// The constraint degree is then:
353    /// `1 + max(deg(numerator), deg(common_denominator))`
354    fn constraint_degree<F: Field>(&self, context: &Lookup<F>) -> usize {
355        assert!(context.multiplicities_exprs.len() == context.element_exprs.len());
356
357        let n = context.multiplicities_exprs.len();
358
359        // Compute degrees in a single pass.
360        let mut degs = Vec::with_capacity(n);
361        let mut deg_sum = 0;
362        for elems in &context.element_exprs {
363            let deg = elems
364                .iter()
365                .map(|elt| elt.degree_multiple())
366                .max()
367                .unwrap_or(0);
368            degs.push(deg);
369            deg_sum += deg;
370        }
371
372        // Compute 1 + degree(denominator).
373        let deg_denom_constr = 1 + deg_sum;
374
375        // Compute degree(numerator).
376        let multiplicities = &context.multiplicities_exprs;
377        let deg_num = (0..n)
378            .map(|i| multiplicities[i].degree_multiple() + deg_sum - degs[i])
379            .max()
380            .unwrap_or(0);
381
382        deg_denom_constr.max(deg_num)
383    }
384
385    #[instrument(name = "generate lookup permutation", skip_all, level = "debug")]
386    fn generate_permutation<SC: StarkGenericConfig>(
387        &self,
388        main: &RowMajorMatrix<Val<SC>>,
389        preprocessed: &Option<RowMajorMatrix<Val<SC>>>,
390        public_values: &[Val<SC>],
391        lookups: &[Lookup<Val<SC>>],
392        lookup_data: &mut [LookupData<SC::Challenge>],
393        permutation_challenges: &[SC::Challenge],
394    ) -> RowMajorMatrix<SC::Challenge> {
395        let height = main.height();
396        let width = self.num_aux_cols() * lookups.len();
397
398        // Validate challenge count matches number of lookups.
399        debug_assert_eq!(
400            permutation_challenges.len(),
401            lookups.len() * self.num_challenges(),
402            "perm challenge count must be per-lookup"
403        );
404
405        // Enforce uniqueness of auxiliary column indices across lookups.
406        #[cfg(debug_assertions)]
407        {
408            use alloc::collections::btree_set::BTreeSet;
409
410            let mut seen = BTreeSet::new();
411            for ctx in lookups {
412                let a = ctx.columns[0];
413                if !seen.insert(a) {
414                    panic!("duplicate aux column index {a} across lookups");
415                }
416            }
417        }
418
419        // 1. PRE-COMPUTE DENOMINATORS
420        // We flatten all denominators from all rows/lookups into one giant vector.
421        // Order: Row -> Lookup -> Element Tuple
422        let denoms_per_row: usize = lookups.iter().map(|l| l.element_exprs.len()).sum();
423        let mut lookup_denom_offsets = Vec::with_capacity(lookups.len() + 1);
424        lookup_denom_offsets.push(0);
425        for l in lookups.iter() {
426            lookup_denom_offsets
427                .push(lookup_denom_offsets.last().copied().unwrap() + l.element_exprs.len());
428        }
429        let num_lookups = lookups.len();
430
431        let mut all_denominators = vec![SC::Challenge::ZERO; height * denoms_per_row];
432        let mut all_multiplicities = vec![Val::<SC>::ZERO; height * denoms_per_row];
433
434        all_denominators
435            .par_chunks_mut(denoms_per_row)
436            .zip(all_multiplicities.par_chunks_mut(denoms_per_row))
437            .enumerate()
438            .for_each(|(i, (denom_row, mult_row))| {
439                let local_main_row = main.row_slice(i).unwrap();
440                let next_main_row = main.row_slice((i + 1) % height).unwrap();
441                let main_rows = VerticalPair::new(
442                    RowMajorMatrixView::new_row(&local_main_row),
443                    RowMajorMatrixView::new_row(&next_main_row),
444                );
445                let preprocessed_rows_data = preprocessed.as_ref().map(|prep| {
446                    (
447                        prep.row_slice(i).unwrap(),
448                        prep.row_slice((i + 1) % height).unwrap(),
449                    )
450                });
451                let preprocessed_rows = match preprocessed_rows_data.as_ref() {
452                    Some((local_preprocessed_row, next_preprocessed_row)) => VerticalPair::new(
453                        RowMajorMatrixView::new_row(local_preprocessed_row),
454                        RowMajorMatrixView::new_row(next_preprocessed_row),
455                    ),
456                    None => VerticalPair::new(
457                        RowMajorMatrixView::new(&[], 0),
458                        RowMajorMatrixView::new(&[], 0),
459                    ),
460                };
461
462                let row_builder: LookupTraceBuilder<'_, SC> = LookupTraceBuilder::new(
463                    main_rows,
464                    preprocessed_rows,
465                    public_values,
466                    permutation_challenges,
467                    height,
468                    i,
469                );
470
471                let mut offset = 0;
472                for context in lookups.iter() {
473                    let alpha = permutation_challenges[self.num_challenges() * context.columns[0]];
474                    let beta =
475                        permutation_challenges[self.num_challenges() * context.columns[0] + 1];
476
477                    // Evaluate each tuple's elements and combine them via Horner's method
478                    // in a single pass. This avoids allocating a temporary vector of
479                    // evaluated elements per tuple, then another vector of combined results.
480                    //
481                    // For a tuple (e_0, e_1, …, e_{k-1}), computes:
482                    //
483                    //   combined = e_0 + e_1·β + e_2·β^2 + … + e_{k-1}·β^{k-1}
484                    //
485                    // Then stores (α − combined) as the denominator.
486                    for (j, elts) in context.element_exprs.iter().enumerate() {
487                        let combined_elt = elts.iter().fold(SC::Challenge::ZERO, |acc, e| {
488                            acc * beta + symbolic_to_expr(&row_builder, e)
489                        });
490                        denom_row[offset] = alpha - combined_elt;
491                        mult_row[offset] =
492                            symbolic_to_expr(&row_builder, &context.multiplicities_exprs[j]);
493                        offset += 1;
494                    }
495                }
496            });
497
498        debug_assert_eq!(all_denominators.len(), height * denoms_per_row);
499
500        // 2. BATCH INVERSION
501        // This turns O(N) inversions into O(1) inversion + O(N) multiplications.
502        // Recomputing multiplicities during trace building is cheaper than recomputing inversions,
503        // or storing them beforehand (as they could possibly constitute quite a large amount of data).
504        let all_inverses = p3_field::batch_multiplicative_inverse(&all_denominators);
505
506        #[cfg(debug_assertions)]
507        let mut inv_cursor = 0;
508        #[cfg(debug_assertions)]
509        let _debug_check: Vec<_> = (0..height)
510            .map(|_| {
511                lookups.iter().for_each(|context| {
512                    inv_cursor += context.multiplicities_exprs.len();
513                });
514            })
515            .collect();
516
517        // 3. BUILD TRACE
518        let mut row_sums = SC::Challenge::zero_vec(height * num_lookups);
519        row_sums
520            .par_chunks_mut(num_lookups)
521            .enumerate()
522            .for_each(|(i, row_sums_i)| {
523                let inv_base = i * denoms_per_row;
524                for (lookup_idx, _context) in lookups.iter().enumerate() {
525                    let start = lookup_denom_offsets[lookup_idx];
526                    let end = lookup_denom_offsets[lookup_idx + 1];
527                    let sum = (start..end)
528                        .map(|k| all_inverses[inv_base + k] * all_multiplicities[inv_base + k])
529                        .sum();
530                    row_sums_i[lookup_idx] = sum;
531                }
532            });
533
534        let mut aux_trace = SC::Challenge::zero_vec(height * width);
535        let mut permutation_counter = 0;
536
537        // Each lookup column gets its own running sum.
538        // Since these columns are independent, we build them one at a time.
539        //
540        // The running sum is an *exclusive* prefix sum of the per-row contributions:
541        //
542        //   s[0] = 0
543        //   s[i] = row_sum[0] + row_sum[1] + … + row_sum[i-1]
544        //
545        // A naive serial loop would be O(height). Instead we use a three-phase
546        // parallel prefix sum, splitting the work across threads:
547        //
548        //   Phase A — Each thread computes a local prefix sum on its chunk.
549        //   Phase B — A tiny sequential pass (one entry per thread) combines
550        //             the chunk totals into global offsets.
551        //   Phase C — Each thread adds its global offset back into its chunk.
552        //
553        // After the three phases, we have an *inclusive* prefix sum.
554        // Shifting by one position turns it into the exclusive sum we need.
555        let num_threads = current_num_threads();
556        let chunk_size = height.div_ceil(num_threads);
557
558        // Reuse a single buffer across all lookup columns to avoid re-allocating on every iteration.
559        let mut prefix = SC::Challenge::zero_vec(height);
560
561        for (lookup_idx, context) in lookups.iter().enumerate() {
562            let aux_idx = context.columns[0];
563
564            // Fill the buffer with this column's per-row contributions.
565            for (i, val) in prefix.iter_mut().enumerate() {
566                *val = row_sums[i * num_lookups + lookup_idx];
567            }
568
569            // Phase A — Local inclusive prefix sums, one chunk per thread.
570            prefix.par_chunks_mut(chunk_size).for_each(|chunk| {
571                for i in 1..chunk.len() {
572                    chunk[i] += chunk[i - 1];
573                }
574            });
575
576            // Phase B — Combine chunk totals into cumulative offsets.
577            // Only as many entries as there are chunks (one per thread), so this is tiny.
578            let mut offsets = SC::Challenge::zero_vec(height.div_ceil(chunk_size));
579            for i in 1..offsets.len() {
580                offsets[i] = offsets[i - 1] + prefix[i * chunk_size - 1];
581            }
582
583            // Phase C — Fold global offsets back into each chunk.
584            prefix
585                .par_chunks_mut(chunk_size)
586                .enumerate()
587                .for_each(|(chunk_idx, chunk)| {
588                    let offset = offsets[chunk_idx];
589                    if !offset.is_zero() {
590                        for val in chunk.iter_mut() {
591                            *val += offset;
592                        }
593                    }
594                });
595
596            // At this point we hold an *inclusive* prefix sum.
597            //
598            // The auxiliary trace needs the *exclusive* version (shifted right by one, starting at zero).
599            //
600            // - Row 0 is already zero from initialization;
601            // - Each subsequent row gets the inclusive sum of all *previous* rows.
602            aux_trace
603                .par_chunks_mut(width)
604                .skip(1)
605                .enumerate()
606                .for_each(|(i, row)| {
607                    row[aux_idx] = prefix[i];
608                });
609
610            // For global lookups, record the total sum across all rows.
611            if matches!(context.kind, Kind::Global(_)) {
612                lookup_data[permutation_counter].expected_cumulated = prefix[height - 1];
613                permutation_counter += 1;
614            }
615        }
616
617        // Check that we have updated all `lookup_data` entries.
618        debug_assert_eq!(permutation_counter, lookup_data.len());
619        #[cfg(debug_assertions)] // Compiler complains about inv_cursor despite being under a `debug_assert`
620        debug_assert_eq!(inv_cursor, all_inverses.len());
621        RowMajorMatrix::new(aux_trace, width)
622    }
623}