use alloc::{vec, vec::Vec};
use miden_core::{
field::{ExtensionField, Field},
utils::{Matrix, RowMajorMatrix},
};
use miden_crypto::stark::air::LiftedAir;
use super::{Challenges, LookupAir, ProverLookupBuilder, prover::build_lookup_fractions};
/// Number of trace rows processed together in phase 1 of `accumulate`.
///
/// Each chunk of this many rows allocates its own scratch buffer and performs a single
/// batched inversion over all of the chunk's denominators.
const ACCUMULATE_ROWS_PER_CHUNK: usize = 512;
/// Builds the LogUp auxiliary trace for `main`, together with its committed final value.
///
/// `challenges[0]` and `challenges[1]` are taken as `alpha` and `beta`. The lookup fractions
/// for every row are collected with `build_lookup_fractions` and folded with `accumulate`,
/// which yields `main.height() + 1` rows. The extra last row is split off: its column-0 entry
/// (the running sum over all rows) is returned as the single committed final value, and the
/// remaining `main.height()` rows form the auxiliary trace.
///
/// # Panics
/// Panics if `challenges` holds fewer than two elements.
pub fn build_logup_aux_trace<A, F, EF>(
    air: &A,
    main: &RowMajorMatrix<F>,
    challenges: &[EF],
) -> (RowMajorMatrix<EF>, Vec<EF>)
where
    F: Field,
    EF: ExtensionField<F>,
    A: LiftedAir<F, EF>,
    for<'a> A: LookupAir<ProverLookupBuilder<'a, F, EF>>,
{
    let _span = tracing::info_span!("build_aux_trace_logup").entered();

    let (alpha, beta) = (challenges[0], challenges[1]);
    let lookup_challenges =
        Challenges::<EF>::new(alpha, beta, air.max_message_width(), air.num_bus_ids());
    let periodic = air.periodic_columns();
    let fractions = build_lookup_fractions(air, main, &periodic, &lookup_challenges);

    let accumulated = accumulate(&fractions);
    let width = accumulated.width;
    let height = main.height();
    debug_assert_eq!(
        accumulated.values.len(),
        (height + 1) * width,
        "accumulate output buffer is sized for num_rows + 1 rows",
    );

    // The element right after the last trace row is column 0 of the extra row.
    let trace_len = height * width;
    let mut values = accumulated.values;
    let committed_final = values[trace_len];
    values.truncate(trace_len);

    (RowMajorMatrix::new(values, width), vec![committed_final])
}
/// Flat store of the LogUp fractions collected for a trace.
///
/// All fractions live in one contiguous buffer; `counts` records how many of them belong to
/// each `(row, column)` cell so that consumers can walk the buffer with a running cursor.
pub struct LookupFractions<F: Field, EF: ExtensionField<F>> {
    /// `(m, d)` pairs, each standing for the fraction `m / d`, ordered by row and then by
    /// column.
    pub(super) fractions: Vec<(F, EF)>,
    /// Number of fractions in each `(row, column)` cell, in row-major order. The accumulators
    /// expect `num_rows * num_cols` entries.
    pub(super) counts: Vec<usize>,
    /// One entry per column. `from_shape` reserves `num_rows * sum(shape)` fraction slots, so
    /// each entry is presumably the maximum number of fractions per row for that column
    /// (confirm against the code that fills this structure).
    pub(super) shape: Vec<usize>,
    /// Number of trace rows the fractions are collected for.
    num_rows: usize,
    /// Number of columns, equal to `shape.len()` when built through `from_shape`.
    num_cols: usize,
}
impl<F, EF> LookupFractions<F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// Creates an empty store for `num_rows` rows with one column per entry of `shape`.
    ///
    /// Capacity is reserved up front: `num_rows * sum(shape)` fraction slots and
    /// `num_rows * shape.len()` count slots. Both buffers start out empty.
    pub fn from_shape(shape: Vec<usize>, num_rows: usize) -> Self {
        let num_cols = shape.len();
        let fractions_per_row: usize = shape.iter().sum();
        Self {
            fractions: Vec::with_capacity(num_rows * fractions_per_row),
            counts: Vec::with_capacity(num_rows * num_cols),
            shape,
            num_rows,
            num_cols,
        }
    }

    /// Number of columns.
    pub fn num_columns(&self) -> usize {
        self.num_cols
    }

    /// Number of trace rows.
    pub fn num_rows(&self) -> usize {
        self.num_rows
    }

    /// Per-column shape this store was created with.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// All `(m, d)` fraction pairs recorded so far, in insertion order.
    pub fn fractions(&self) -> &[(F, EF)] {
        self.fractions.as_slice()
    }

    /// Fraction counts per `(row, column)` cell recorded so far, in row-major order.
    pub fn counts(&self) -> &[usize] {
        self.counts.as_slice()
    }
}
/// Reference LogUp accumulator: straightforward, one field inversion per fraction.
///
/// Returns the auxiliary columns in column-major form, `aux[col][row]`, where every column
/// has `num_rows + 1` entries:
/// - `aux[0][r]` is the running sum of every fraction `m / d` recorded on rows `0..r`, across
///   all columns. `aux[0][0]` is zero and `aux[0][num_rows]` is the grand total.
/// - `aux[c][r]` for `c >= 1` is the sum of the fractions recorded for column `c` on row `r`.
///   The extra entry `aux[c][num_rows]` stays zero.
///
/// When there are no columns the result is an empty vector, matching the empty matrix that
/// `accumulate` produces for the same input.
///
/// # Panics
/// Panics if a denominator has no inverse, or if `counts` refers to more fractions than the
/// fraction buffer holds.
pub fn accumulate_slow<F, EF>(fractions: &LookupFractions<F, EF>) -> Vec<Vec<EF>>
where
    F: Field,
    EF: ExtensionField<F>,
{
    let num_cols = fractions.num_columns();
    let num_rows = fractions.num_rows();
    let mut aux: Vec<Vec<EF>> = (0..num_cols).map(|_| vec![EF::ZERO; num_rows + 1]).collect();

    let flat_fractions = fractions.fractions();
    let flat_counts = fractions.counts();
    debug_assert_eq!(
        flat_counts.len(),
        num_rows * num_cols,
        "counts length {} != num_rows * num_cols {}",
        flat_counts.len(),
        num_rows * num_cols,
    );

    // `chunks(0)` panics even on an empty slice, so a column-less input needs its own exit.
    if num_cols == 0 {
        debug_assert!(flat_fractions.is_empty(), "fractions recorded without any column");
        return aux;
    }

    let mut running_sum = EF::ZERO;
    let mut cursor = 0usize;
    for (row, row_counts) in flat_counts.chunks(num_cols).enumerate() {
        let mut row_total = EF::ZERO;
        for (col, &count) in row_counts.iter().enumerate() {
            let end = cursor + count;
            let mut sum = EF::ZERO;
            for &(m, d) in &flat_fractions[cursor..end] {
                let d_inv = d
                    .try_inverse()
                    .expect("LogUp denominator must be non-zero (bus_prefix is never zero)");
                sum += d_inv * m;
            }
            cursor = end;

            // Column 0 is reserved for the running sum; its per-row value only feeds the total.
            if col > 0 {
                aux[col][row] = sum;
            }
            row_total += sum;
        }
        running_sum += row_total;
        aux[0][row + 1] = running_sum;
    }

    debug_assert_eq!(
        cursor,
        flat_fractions.len(),
        "cursor {cursor} != total fractions {}",
        flat_fractions.len(),
    );
    aux
}
/// Builds the LogUp auxiliary columns as a row-major matrix with `num_rows + 1` rows and
/// `fractions.num_columns()` columns.
///
/// Output layout:
/// - Column 0, row `r + 1`: running sum of every fraction `m / d` recorded on rows `0..=r`,
///   across all columns. Column 0 of row 0 is zero.
/// - Column `c >= 1`, row `r < num_rows`: sum of the fractions recorded for column `c` on
///   row `r`. The extra last row of these columns stays zero.
///
/// Phase 1 processes rows in chunks of `ACCUMULATE_ROWS_PER_CHUNK`, with one batched
/// inversion per chunk via `invert_and_scale`. With the `concurrent` feature the chunks are
/// driven through `par_chunks_mut` instead of `chunks_mut`. Phase 2 is a sequential prefix sum
/// over the per-row totals.
///
/// # Panics
/// Panics if the product of the denominators of a chunk has no inverse (see
/// `invert_and_scale`).
pub fn accumulate<F, EF>(fractions: &LookupFractions<F, EF>) -> RowMajorMatrix<EF>
where
F: Field,
EF: ExtensionField<F>,
{
let num_cols = fractions.num_columns();
let num_rows = fractions.num_rows();
// One extra row so the running sum after the last trace row has a slot.
let out_rows = num_rows + 1;
let mut output_data = vec![EF::ZERO; out_rows * num_cols];
let flat_fractions = fractions.fractions();
let flat_counts = fractions.counts();
debug_assert_eq!(
flat_counts.len(),
num_rows * num_cols,
"counts length {} != num_rows * num_cols {}",
flat_counts.len(),
num_rows * num_cols,
);
// Nothing to accumulate: every entry, including the running sum, stays zero.
if num_rows == 0 || flat_fractions.is_empty() {
return RowMajorMatrix::new(output_data, num_cols);
}
// row_frac_offsets[r] is the index in `flat_fractions` where row `r` starts; the last entry
// is the total number of fractions.
let row_frac_offsets = compute_row_frac_offsets(flat_counts, num_rows, num_cols);
debug_assert_eq!(row_frac_offsets.len(), num_rows + 1);
debug_assert_eq!(row_frac_offsets[num_rows], flat_fractions.len());
// Phase 1 only writes the first `num_rows` rows; the extra last row is touched in phase 2.
let frac_region = &mut output_data[..num_rows * num_cols];
let mut row_totals: Vec<EF> = vec![EF::ZERO; num_rows];
let rows_per_chunk = ACCUMULATE_ROWS_PER_CHUNK;
// Phase 1 worker: fills columns 1.. of one chunk of output rows and the matching slice of
// per-row totals. Chunks touch disjoint output slices, so they can run independently.
let phase1 = |(chunk_idx, (chunk_out, totals_slice)): (usize, (&mut [EF], &mut [EF]))| {
let row_lo = chunk_idx * rows_per_chunk;
// The last chunk may be shorter than `rows_per_chunk`.
let row_hi = (row_lo + rows_per_chunk).min(num_rows);
let chunk_rows = row_hi - row_lo;
let frac_lo = row_frac_offsets[row_lo];
let frac_hi = row_frac_offsets[row_hi];
let chunk_fracs = &flat_fractions[frac_lo..frac_hi];
let chunk_counts = &flat_counts[row_lo * num_cols..row_hi * num_cols];
debug_assert_eq!(chunk_out.len(), chunk_rows * num_cols);
debug_assert_eq!(totals_slice.len(), chunk_rows);
// A chunk without fractions contributes only zeros, which are already in place.
if chunk_fracs.is_empty() {
return;
}
// After this call scratch[i] holds m_i / d_i for the i-th fraction of the chunk.
let mut scratch: Vec<EF> = vec![EF::ZERO; chunk_fracs.len()];
invert_and_scale(chunk_fracs, &mut scratch);
let mut per_row_value: Vec<EF> = vec![EF::ZERO; num_cols];
// Position in `scratch` of the next unconsumed fraction.
let mut cursor = 0usize;
for row_in_chunk in 0..chunk_rows {
let row_counts = &chunk_counts[row_in_chunk * num_cols..(row_in_chunk + 1) * num_cols];
let out_row_base = row_in_chunk * num_cols;
for (col, &count) in row_counts.iter().enumerate() {
let end = cursor + count;
let sum = scratch[cursor..end].iter().copied().sum();
per_row_value[col] = sum;
cursor = end;
}
let out_row = &mut chunk_out[out_row_base..out_row_base + num_cols];
// Column 0 is reserved for the running sum, so only columns 1.. get per-row values.
out_row[1..].copy_from_slice(&per_row_value[1..]);
// The row total includes column 0's own contribution.
totals_slice[row_in_chunk] = per_row_value.iter().copied().sum();
}
debug_assert_eq!(cursor, chunk_fracs.len());
};
#[cfg(not(feature = "concurrent"))]
{
frac_region
.chunks_mut(rows_per_chunk * num_cols)
.zip(row_totals.chunks_mut(rows_per_chunk))
.enumerate()
.for_each(phase1);
}
#[cfg(feature = "concurrent")]
{
use miden_crypto::parallel::*;
frac_region
.par_chunks_mut(rows_per_chunk * num_cols)
.zip(row_totals.par_chunks_mut(rows_per_chunk))
.enumerate()
.for_each(phase1);
}
// Phase 2: prefix sum of the row totals into column 0, shifted down by one row so that
// row `r + 1` holds the sum over rows `0..=r`.
let mut acc = EF::ZERO;
for r in 0..num_rows {
acc += row_totals[r];
output_data[(r + 1) * num_cols] = acc;
}
RowMajorMatrix::new(output_data, num_cols)
}
/// Computes, for every row, the index of its first fraction in the flat fraction buffer.
///
/// `flat_counts` holds `num_cols` counts per row, in row-major order. The returned vector
/// starts with `0` and is followed by one cumulative total per row, so entry `r` is where row
/// `r` starts and the last entry is the total number of fractions.
///
/// # Panics
/// Panics if `num_cols` is zero, because `chunks` rejects a zero chunk size.
fn compute_row_frac_offsets(flat_counts: &[usize], num_rows: usize, num_cols: usize) -> Vec<usize> {
    debug_assert_eq!(flat_counts.len(), num_rows * num_cols);

    let mut offsets = Vec::with_capacity(num_rows + 1);
    offsets.push(0);

    let mut running = 0usize;
    offsets.extend(flat_counts.chunks(num_cols).map(|row_counts| {
        running += row_counts.iter().sum::<usize>();
        running
    }));
    offsets
}
/// Writes `m_i / d_i` into `scratch[i]` for every `(m_i, d_i)` in `chunk_fracs`, using a
/// single field inversion for the whole slice.
///
/// Forward sweep: `scratch[i]` temporarily holds the prefix product `d_0 * ... * d_i`.
/// Backward sweep: the inverse of the full product is peeled one denominator at a time, so at
/// index `i` it equals `1 / (d_0 * ... * d_i)`. Multiplying it by the prefix product up to
/// `i - 1` leaves `1 / d_i`, which is then scaled by `m_i`.
///
/// `scratch` must be as long as `chunk_fracs`, and `chunk_fracs` must not be empty.
///
/// # Panics
/// Panics if the product of all denominators has no inverse.
fn invert_and_scale<F, EF>(chunk_fracs: &[(F, EF)], scratch: &mut [EF])
where
    F: Field,
    EF: ExtensionField<F>,
{
    debug_assert_eq!(scratch.len(), chunk_fracs.len());
    debug_assert!(!chunk_fracs.is_empty());

    let (first_num, first_den) = chunk_fracs[0];

    // Forward sweep: prefix products of the denominators.
    let mut prefix = first_den;
    scratch[0] = prefix;
    for (idx, &(_, den)) in chunk_fracs.iter().enumerate().skip(1) {
        prefix *= den;
        scratch[idx] = prefix;
    }

    // The only inversion: the product of every denominator in the slice.
    let last = scratch.len() - 1;
    let mut inv_prefix = scratch[last]
        .try_inverse()
        .expect("LogUp denominator product must be non-zero (bus_prefix is never zero)");

    // Backward sweep: scratch[idx - 1] still holds the prefix product up to idx - 1.
    for (idx, &(num, den)) in chunk_fracs.iter().enumerate().skip(1).rev() {
        scratch[idx] = scratch[idx - 1] * inv_prefix * num;
        inv_prefix *= den;
    }
    scratch[0] = inv_prefix * first_num;
}
// Unit tests for the accumulators: a hand-computed case for `accumulate_slow`, and
// cross-checks of `accumulate` against `accumulate_slow` on deterministic random fixtures.
#[cfg(test)]
mod tests {
use miden_core::{
field::{PrimeCharacteristicRing, QuadFelt},
utils::Matrix,
};
use super::*;
use crate::{
Felt,
lookup::{LookupAir, LookupBuilder},
};
/// Minimal linear congruential generator, so fixtures are reproducible from a seed.
struct Lcg(u64);
impl Lcg {
/// Advances the state and returns it.
fn next(&mut self) -> u64 {
self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
self.0
}
/// Base-field element built from the upper 32 bits of the next state.
fn felt(&mut self) -> Felt {
Felt::new_unchecked(self.next() >> 32)
}
/// Extension-field element built from two consecutive `felt` draws.
fn quad(&mut self) -> QuadFelt {
QuadFelt::new([self.felt(), self.felt()])
}
}
/// Builds a `LookupFractions` in which every `(row, column)` cell holds between zero and
/// `shape[column]` fractions with random numerators and non-zero denominators.
fn random_fixture(
shape: &[usize],
num_rows: usize,
seed: u64,
) -> LookupFractions<Felt, QuadFelt> {
let mut rng = Lcg(seed);
let mut fx: LookupFractions<Felt, QuadFelt> = LookupFractions {
fractions: Vec::with_capacity(num_rows * shape.iter().sum::<usize>()),
counts: Vec::with_capacity(num_rows * shape.len()),
shape: shape.to_vec(),
num_rows,
num_cols: shape.len(),
};
for _row in 0..num_rows {
for &max_count in shape {
// Inclusive upper bound: a cell may hold anywhere from 0 to `max_count` fractions.
let count = (rng.next() as usize) % (max_count + 1);
for _ in 0..count {
let m = rng.felt();
// Redraw until the denominator is non-zero, since the accumulators invert it.
let d = loop {
let candidate = rng.quad();
if candidate != QuadFelt::ZERO {
break candidate;
}
};
fx.fractions.push((m, d));
}
fx.counts.push(count);
}
}
fx
}
/// Asserts that the column-major output of `accumulate_slow` and the row-major output of
/// `accumulate` agree element by element, including the extra last row.
fn assert_matrix_matches_slow(
slow: &[Vec<QuadFelt>],
fast: &RowMajorMatrix<QuadFelt>,
num_cols: usize,
num_rows: usize,
) {
assert_eq!(fast.width(), num_cols, "fast.width() mismatch");
assert_eq!(fast.height(), num_rows + 1, "fast.height() mismatch");
assert_eq!(slow.len(), num_cols, "slow column count mismatch");
for (col, slow_col) in slow.iter().enumerate() {
assert_eq!(slow_col.len(), num_rows + 1, "slow col {col} row count mismatch");
for (row, &s) in slow_col.iter().enumerate() {
let f = fast.values[row * num_cols + col];
assert_eq!(s, f, "row {row} col {col} differs: slow={s:?} fast={f:?}",);
}
}
}
/// Stub air whose `eval` does nothing; the tests in this module only read its `shape`.
struct FakeAir {
shape: [usize; 2],
}
impl<LB: LookupBuilder> LookupAir<LB> for FakeAir {
fn num_columns(&self) -> usize {
self.shape.len()
}
fn column_shape(&self) -> &[usize] {
&self.shape
}
fn max_message_width(&self) -> usize {
0
}
fn num_bus_ids(&self) -> usize {
0
}
fn eval(&self, _builder: &mut LB) {}
}
/// Empty two-column `LookupFractions` created through `from_shape`.
fn fixture(shape: [usize; 2], num_rows: usize) -> LookupFractions<Felt, QuadFelt> {
LookupFractions::from_shape(shape.to_vec(), num_rows)
}
/// Two rows, two columns, with values small enough to verify by hand.
#[test]
fn accumulate_slow_hand_crafted() {
let one = Felt::new_unchecked(1);
let two = Felt::new_unchecked(2);
let d1 = QuadFelt::new([Felt::new_unchecked(3), Felt::new_unchecked(0)]);
let d2 = QuadFelt::new([Felt::new_unchecked(5), Felt::new_unchecked(0)]);
let mut fx = fixture([2, 1], 2);
// Row 0: column 0 holds 1/d1 and 2/d2, column 1 holds nothing.
fx.fractions.push((one, d1));
fx.fractions.push((two, d2));
// The line below records the two row-0 counts and then starts row 1, whose column 0
// holds 1/d1.
fx.counts.push(2); fx.counts.push(0); fx.fractions.push((one, d1));
// Row 1, column 1 holds 2/d2.
fx.counts.push(1); fx.fractions.push((two, d2));
fx.counts.push(1);
let aux = accumulate_slow(&fx);
assert_eq!(aux.len(), 2);
assert_eq!(aux[0].len(), 3);
assert_eq!(aux[1].len(), 3);
let d1_inv = d1.try_inverse().unwrap();
let d2_inv = d2.try_inverse().unwrap();
let row0_col0 = d1_inv + d2_inv.double();
let row1_col0 = d1_inv;
let row1_col1 = d2_inv.double();
// Column 0 is the running sum over all columns, shifted down by one row.
assert_eq!(aux[0][0], QuadFelt::ZERO);
assert_eq!(aux[0][1], row0_col0);
assert_eq!(aux[0][2], row0_col0 + row1_col0 + row1_col1);
// Column 1 holds its own per-row value on the row it belongs to.
assert_eq!(aux[1][0], QuadFelt::ZERO);
assert_eq!(aux[1][1], row1_col1);
}
/// `LookupFractions::from_shape` reports the shape it was given and reserves capacity
/// without adding any elements.
#[test]
fn new_reserves_capacity() {
let air = FakeAir { shape: [3, 5] };
let fx: LookupFractions<Felt, QuadFelt> =
LookupFractions::from_shape(air.shape.to_vec(), 10);
assert_eq!(fx.num_columns(), 2);
assert_eq!(fx.num_rows(), 10);
assert_eq!(fx.shape(), &[3, 5]);
assert!(fx.fractions.capacity() >= 10 * (3 + 5));
assert!(fx.counts.capacity() >= 10 * 2);
assert!(fx.fractions.is_empty());
assert!(fx.counts.is_empty());
}
/// Single-chunk cross-check of `accumulate` against `accumulate_slow`.
#[test]
fn accumulate_matches_accumulate_slow_random() {
const SHAPE: [usize; 3] = [2, 1, 3];
const NUM_ROWS: usize = 32;
// Compile-time guard: the fixture must fit in one phase-1 chunk.
const _: () = assert!(
NUM_ROWS < ACCUMULATE_ROWS_PER_CHUNK,
"must stay in one chunk to test phase 1",
);
let fx = random_fixture(&SHAPE, NUM_ROWS, 0x00c0_ffee_beef_c0de);
let slow = accumulate_slow(&fx);
let fast = accumulate(&fx);
assert_matrix_matches_slow(&slow, &fast, SHAPE.len(), NUM_ROWS);
}
/// Cross-check spanning three full chunks plus a 7-row remainder chunk.
#[test]
fn accumulate_multi_chunk_matches_accumulate_slow() {
const SHAPE: [usize; 4] = [1, 2, 3, 1];
const NUM_ROWS: usize = ACCUMULATE_ROWS_PER_CHUNK * 3 + 7;
let fx = random_fixture(&SHAPE, NUM_ROWS, 0xdead_beef_cafe_babe);
let slow = accumulate_slow(&fx);
let fast = accumulate(&fx);
assert_matrix_matches_slow(&slow, &fast, SHAPE.len(), NUM_ROWS);
}
/// With zero rows `accumulate` returns a single all-zero row.
#[test]
fn accumulate_empty_trace() {
let shape = vec![2usize, 3, 1];
let num_cols = shape.len();
let fx: LookupFractions<Felt, QuadFelt> = LookupFractions {
fractions: Vec::new(),
counts: Vec::new(),
shape,
num_rows: 0,
num_cols,
};
let aux = accumulate(&fx);
assert_eq!(aux.width(), num_cols);
assert_eq!(aux.height(), 1);
assert!(aux.values.iter().all(|v| *v == QuadFelt::ZERO));
}
}