use alloc::{vec, vec::Vec};
use core::borrow::Borrow;
use miden_core::{
field::{ExtensionField, Field},
utils::{Matrix, RowMajorMatrix},
};
use miden_crypto::stark::air::RowWindow;
use super::{
Challenges, Deg, LookupAir, LookupBatch, LookupBuilder, LookupColumn, LookupFractions,
LookupGroup, LookupMessage,
};
/// Prover-side [`LookupBuilder`] that evaluates an AIR's lookup interactions on one concrete
/// row window and appends the resulting `(multiplicity, encoded message)` pairs to a shared
/// [`LookupFractions`] buffer.
///
/// `build_lookup_fractions` creates one builder per trace row; `column_idx` tracks how many
/// lookup columns the AIR has opened so far on that row.
pub struct ProverLookupBuilder<'a, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// Shared output buffer: fractions and per-column counts are appended here.
    fractions: &'a mut LookupFractions<F, EF>,
    /// Index of the lookup column that the next `next_column` call will open.
    column_idx: usize,
    /// Window over the current and next main-trace rows.
    main: RowWindow<'a, F>,
    /// Values of the periodic columns at the current row.
    periodic_values: &'a [F],
    /// Challenges used to encode messages.
    challenges: &'a Challenges<EF>,
}
impl<'a, F, EF> ProverLookupBuilder<'a, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// Creates a builder positioned at the first lookup column of a row.
    ///
    /// `fractions` must already be sized for `air`: its column count has to equal
    /// `air.num_columns()`. This is only checked in debug builds.
    pub fn new<A>(
        main: RowWindow<'a, F>,
        periodic_values: &'a [F],
        challenges: &'a Challenges<EF>,
        air: &A,
        fractions: &'a mut LookupFractions<F, EF>,
    ) -> Self
    where
        A: LookupAir<Self>,
    {
        let builder = Self {
            column_idx: 0,
            main,
            periodic_values,
            challenges,
            fractions,
        };
        debug_assert_eq!(
            builder.fractions.num_columns(),
            air.num_columns(),
            "fractions buffer must be pre-sized to air.num_columns()",
        );
        builder
    }
}
/// Runs `air`'s lookup evaluation over every row of `main_trace` and returns the collected
/// fractions.
///
/// For row `r`, the main window pairs row `r` with row `(r + 1) % num_rows`, so the last row
/// wraps around to the first. Each periodic column `col` contributes `col[r % col.len()]`.
///
/// # Panics
///
/// - Panics if `main_trace` has at least one row and some entry of `periodic_columns` is empty,
///   because the row index is reduced modulo that column's length.
/// - In debug builds, also panics if the total number of recorded per-column counts is not
///   `num_rows * fractions.num_columns()`.
pub fn build_lookup_fractions<A, F, EF>(
    air: &A,
    main_trace: &RowMajorMatrix<F>,
    periodic_columns: &[Vec<F>],
    challenges: &Challenges<EF>,
) -> LookupFractions<F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
    for<'a> A: LookupAir<ProverLookupBuilder<'a, F, EF>>,
{
    let height = main_trace.height();
    let row_len = main_trace.width();
    let values: &[F] = main_trace.values.borrow();
    // The `idx`-th row of the row-major backing storage.
    let row_at = |idx: usize| &values[idx * row_len..(idx + 1) * row_len];

    let mut fractions = LookupFractions::from_shape(air.column_shape().to_vec(), height);
    // Scratch buffer reused across rows; holds the current row's periodic values.
    let mut periodic_row: Vec<F> = vec![F::ZERO; periodic_columns.len()];

    for row in 0..height {
        let window = RowWindow::from_two_rows(row_at(row), row_at((row + 1) % height));
        for (slot, column) in periodic_row.iter_mut().zip(periodic_columns) {
            *slot = column[row % column.len()];
        }
        let mut builder =
            ProverLookupBuilder::new(window, &periodic_row, challenges, air, &mut fractions);
        air.eval(&mut builder);
    }

    debug_assert_eq!(
        fractions.counts().len(),
        height * fractions.num_columns(),
        "counts buffer should have exactly num_rows * num_cols entries after collection",
    );
    fractions
}
impl<'a, F, EF> LookupBuilder for ProverLookupBuilder<'a, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    // On the prover side every expression and variable is a concrete field element, so the
    // base-field types are all `F` and the extension-field types are all `EF`.
    type F = F;
    type Expr = F;
    type Var = F;
    type EF = EF;
    type ExprEF = EF;
    type VarEF = EF;
    type PeriodicVar = F;
    type MainWindow = RowWindow<'a, F>;
    type Column<'c>
        = ProverColumn<'c, F, EF>
    where
        Self: 'c;

    fn main(&self) -> Self::MainWindow {
        self.main
    }

    fn periodic_values(&self) -> &[Self::PeriodicVar] {
        self.periodic_values
    }

    /// Opens the next lookup column and lets `f` append fractions to it. Afterwards, records
    /// how many fractions `f` produced for this column on the current row.
    ///
    /// `shape` is indexed with the current column index, so the AIR must not open more columns
    /// than the shape declares. In debug builds, this panics if `f` pushes more fractions than
    /// the column's shape entry allows.
    fn next_column<'c, R>(
        &'c mut self,
        f: impl FnOnce(&mut Self::Column<'c>) -> R,
        _deg: Deg,
    ) -> R {
        let idx = self.column_idx;
        let shape_col = self.fractions.shape[idx];
        let len_before = self.fractions.fractions.len();

        // The column borrows only the `fractions` field; `counts` and `column_idx` are disjoint
        // fields and stay writable below.
        let mut column = ProverColumn {
            challenges: self.challenges,
            fractions: &mut self.fractions.fractions,
        };
        let output = f(&mut column);
        let pushed = column.fractions.len() - len_before;

        debug_assert!(
            pushed <= shape_col,
            "column {idx} exceeded its shape bound: pushed {pushed}, shape says {shape_col}",
        );
        self.fractions.counts.push(pushed);
        self.column_idx = idx + 1;
        output
    }
}
/// Handle for a single lookup column during prover-side collection.
///
/// Groups opened on this column append directly to the shared fraction vector. The owning
/// builder computes the column's per-row count afterwards from the change in vector length.
pub struct ProverColumn<'c, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// Destination for `(multiplicity, encoded message)` pairs.
    fractions: &'c mut Vec<(F, EF)>,
    /// Challenges used to encode messages.
    challenges: &'c Challenges<EF>,
}
impl<'c, F, EF> LookupColumn for ProverColumn<'c, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    type Expr = F;
    type ExprEF = EF;
    type Group<'g>
        = ProverGroup<'g, F, EF>
    where
        Self: 'g;

    /// Runs `f` against a group that writes into this column's fraction vector.
    ///
    /// The group name and degree hint are not used on the prover side.
    fn group<'g>(
        &'g mut self,
        _name: &'static str,
        f: impl FnOnce(&mut Self::Group<'g>),
        _deg: Deg,
    ) {
        let challenges = self.challenges;
        let fractions = &mut *self.fractions;
        f(&mut ProverGroup { challenges, fractions });
    }

    /// Delegates to [`Self::group`] with the `canonical` closure; `_encoded` is never called
    /// here.
    fn group_with_cached_encoding<'g>(
        &'g mut self,
        name: &'static str,
        canonical: impl FnOnce(&mut Self::Group<'g>),
        _encoded: impl FnOnce(&mut Self::Group<'g>),
        deg: Deg,
    ) {
        self.group(name, canonical, deg)
    }
}
/// Prover-side lookup group that appends `(multiplicity, encoded message)` pairs to the
/// enclosing column's fraction vector.
pub struct ProverGroup<'g, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// Fraction vector borrowed from the enclosing column.
    fractions: &'g mut Vec<(F, EF)>,
    /// Challenges used to encode messages.
    challenges: &'g Challenges<EF>,
}
impl<'g, F, EF> LookupGroup for ProverGroup<'g, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    type Expr = F;
    type ExprEF = EF;
    type Batch<'b>
        = ProverBatch<'b, F, EF>
    where
        Self: 'b;

    /// Records `(multiplicity, msg().encode(challenges))` unless `flag` is zero.
    ///
    /// `msg` is only evaluated when `flag` is non-zero. Any non-zero `flag` counts as "on", and
    /// `multiplicity` is pushed without being scaled by it.
    /// NOTE(review): this presumably relies on flags being boolean selectors; confirm against
    /// the constraint-side builder.
    fn insert<M>(
        &mut self,
        _name: &'static str,
        flag: F,
        multiplicity: F,
        msg: impl FnOnce() -> M,
        _deg: Deg,
    ) where
        M: LookupMessage<F, EF>,
    {
        if flag != F::ZERO {
            let encoded = msg().encode(self.challenges);
            self.fractions.push((multiplicity, encoded));
        }
    }

    /// Runs `build` against a batch that shares this group's fraction vector.
    ///
    /// `build` is always invoked. The batch only records insertions when `flag` is non-zero.
    fn batch<'b>(
        &'b mut self,
        _name: &'static str,
        flag: F,
        build: impl FnOnce(&mut Self::Batch<'b>),
        _deg: Deg,
    ) {
        build(&mut ProverBatch {
            active: flag != F::ZERO,
            challenges: self.challenges,
            fractions: &mut *self.fractions,
        });
    }
}
/// Prover-side lookup batch that shares its group's fraction vector.
///
/// `active` is fixed when the batch is opened (set when the batch flag is non-zero). An
/// inactive batch discards every insertion.
pub struct ProverBatch<'b, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// Whether insertions into this batch are recorded.
    active: bool,
    /// Fraction vector borrowed from the enclosing group.
    fractions: &'b mut Vec<(F, EF)>,
    /// Challenges used to encode messages.
    challenges: &'b Challenges<EF>,
}
impl<'b, F, EF> LookupBatch for ProverBatch<'b, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    type Expr = F;
    type ExprEF = EF;

    /// Encodes `msg` and records it with `multiplicity` if the batch is active.
    ///
    /// `msg` is taken by value, so the caller has already constructed it even when the batch is
    /// inactive; only the encoding is skipped.
    fn insert<M>(&mut self, _name: &'static str, multiplicity: F, msg: M, _deg: Deg)
    where
        M: LookupMessage<F, EF>,
    {
        if self.active {
            self.fractions.push((multiplicity, msg.encode(self.challenges)));
        }
    }

    /// Records `(multiplicity, encoded())` if the batch is active. `encoded` is called only in
    /// that case.
    fn insert_encoded(
        &mut self,
        _name: &'static str,
        multiplicity: F,
        encoded: impl FnOnce() -> EF,
        _deg: Deg,
    ) {
        if self.active {
            self.fractions.push((multiplicity, encoded()));
        }
    }
}
#[cfg(test)]
mod tests {
    extern crate std;

    use std::{vec, vec::Vec};

    use miden_core::field::{PrimeCharacteristicRing, QuadFelt};
    use miden_crypto::stark::air::RowWindow;

    use super::*;
    use crate::{
        Felt,
        lookup::{Deg, LookupAir, accumulate_slow, message::LookupMessage},
    };

    /// Degree hint passed to every interaction in these tests.
    const NO_DEG: Deg = Deg { v: 0, u: 0 };

    /// Single-element message encoded as `bus_prefix[0] + beta_powers[0] * value`.
    #[derive(Clone, Copy, Debug)]
    struct SmokeMsg {
        value: Felt,
    }

    impl LookupMessage<Felt, QuadFelt> for SmokeMsg {
        fn encode(&self, challenges: &Challenges<QuadFelt>) -> QuadFelt {
            challenges.bus_prefix[0] + challenges.beta_powers[0] * self.value
        }
    }

    /// Two-column AIR. Column 0 holds an add and a remove in one group. Column 1 holds one
    /// always-active batched insert.
    struct SmokeAir;

    /// Per-column fraction bounds for `SmokeAir`.
    const SMOKE_SHAPE: [usize; 2] = [2, 1];

    impl<LB> LookupAir<LB> for SmokeAir
    where
        LB: LookupBuilder<F = Felt, EF = QuadFelt, Expr = Felt, ExprEF = QuadFelt>,
    {
        fn num_columns(&self) -> usize {
            SMOKE_SHAPE.len()
        }

        fn column_shape(&self) -> &[usize] {
            &SMOKE_SHAPE
        }

        fn max_message_width(&self) -> usize {
            1
        }

        fn num_bus_ids(&self) -> usize {
            1
        }

        fn eval(&self, builder: &mut LB) {
            // Column 0: add message 1, remove message 2.
            builder.next_column(
                |col| {
                    col.group(
                        "smoke_grp_0",
                        |g| {
                            g.add(
                                "smoke_add",
                                Felt::ONE,
                                || SmokeMsg { value: Felt::ONE },
                                NO_DEG,
                            );
                            g.remove(
                                "smoke_remove",
                                Felt::ONE,
                                || SmokeMsg { value: Felt::new_unchecked(2) },
                                NO_DEG,
                            );
                        },
                        NO_DEG,
                    );
                },
                NO_DEG,
            );

            // Column 1: an always-active batch inserting message 3 with multiplicity one.
            builder.next_column(
                |col| {
                    col.group(
                        "smoke_grp_1",
                        |g| {
                            g.batch(
                                "smoke_batch",
                                Felt::ONE,
                                |b| {
                                    b.insert(
                                        "smoke_batch_insert",
                                        Felt::ONE,
                                        SmokeMsg { value: Felt::new_unchecked(3) },
                                        NO_DEG,
                                    );
                                },
                                NO_DEG,
                            );
                        },
                        NO_DEG,
                    );
                },
                NO_DEG,
            );
        }
    }

    #[test]
    fn prover_lookup_builder_collects_into_fractions() {
        const NUM_ROWS: usize = 8;

        let air = SmokeAir;
        let challenges = Challenges::<QuadFelt>::new(
            QuadFelt::new([Felt::new_unchecked(7), Felt::new_unchecked(11)]),
            QuadFelt::new([Felt::new_unchecked(13), Felt::new_unchecked(17)]),
            1,
            1,
        );

        // SmokeAir never reads the main trace or periodic columns, so empty rows suffice.
        let no_columns: Vec<Felt> = vec![];
        let no_periodic: Vec<Felt> = vec![];

        let shape =
            <SmokeAir as LookupAir<ProverLookupBuilder<'_, Felt, QuadFelt>>>::column_shape(&air)
                .to_vec();
        let mut fractions = LookupFractions::<Felt, QuadFelt>::from_shape(shape, NUM_ROWS);

        // Drive the builder row by row, the same way `build_lookup_fractions` does.
        for _ in 0..NUM_ROWS {
            let mut builder = ProverLookupBuilder::new(
                RowWindow::from_two_rows(&no_columns, &no_columns),
                &no_periodic,
                &challenges,
                &air,
                &mut fractions,
            );
            air.eval(&mut builder);
        }

        // Shape and per-row counts.
        assert_eq!(fractions.num_columns(), 2);
        assert_eq!(fractions.shape(), &SMOKE_SHAPE);
        assert_eq!(fractions.counts().len(), NUM_ROWS * 2);
        for per_row in fractions.counts().chunks_exact(2) {
            assert_eq!(per_row, &[2, 1]);
        }

        // Each row contributes (+1, -1) to column 0 and then (+1) to column 1.
        let expected_multiplicities = [Felt::ONE, Felt::NEG_ONE, Felt::ONE];
        assert_eq!(fractions.fractions().len(), expected_multiplicities.len() * NUM_ROWS);
        for (i, (m, d)) in fractions.fractions().iter().enumerate() {
            assert_ne!(*d, QuadFelt::ZERO);
            assert_eq!(*m, expected_multiplicities[i % expected_multiplicities.len()]);
        }

        let aux = accumulate_slow(&fractions);
        assert_eq!(aux.len(), 2);
        aux.iter().for_each(|col_aux| assert_eq!(col_aux.len(), NUM_ROWS + 1));
        assert_eq!(aux[0][0], QuadFelt::ZERO, "accumulator initial must be zero");

        let inv_add = SmokeMsg { value: Felt::ONE }.encode(&challenges).try_inverse().unwrap();
        let inv_remove =
            SmokeMsg { value: Felt::new_unchecked(2) }.encode(&challenges).try_inverse().unwrap();
        let inv_batch =
            SmokeMsg { value: Felt::new_unchecked(3) }.encode(&challenges).try_inverse().unwrap();
        let col0_delta = inv_add - inv_remove;
        let col1_delta = inv_batch;

        // Column 0 grows by both columns' per-row contributions on every row.
        for (prev, next) in aux[0].iter().zip(aux[0].iter().skip(1)).take(NUM_ROWS) {
            assert_eq!(*next - *prev, col0_delta + col1_delta);
        }
        // Column 1 holds its per-row contribution directly.
        aux[1].iter().take(NUM_ROWS).for_each(|entry| assert_eq!(*entry, col1_delta));
    }
}