use core::marker::PhantomData;
use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{ExtensionBuilder, LiftedAirBuilder, WindowAccess};
use super::{
Challenges, Deg, LookupAir, LookupBatch, LookupBuilder, LookupColumn, LookupGroup,
LookupMessage,
};
/// Lookup builder that turns the interactions collected by a lookup AIR into constraints on
/// an underlying [`LiftedAirBuilder`].
///
/// Each call to `next_column` consumes the next permutation column, in call order; column 0
/// is treated as the running accumulator.
pub struct ConstraintLookupBuilder<'ab, AB: LiftedAirBuilder + 'ab> {
    /// Wrapped AIR builder that receives every emitted constraint.
    ab: &'ab mut AB,
    /// Challenges used to encode messages (built from the permutation randomness in `new`).
    challenges: Challenges<AB::ExprEF>,
    /// Index of the permutation column that the next `next_column` call will constrain.
    column_idx: usize,
}
impl<'ab, AB> ConstraintLookupBuilder<'ab, AB>
where
    AB: LiftedAirBuilder,
{
    /// Creates a builder that emits lookup constraints into `ab`.
    ///
    /// The first two permutation-randomness values of `ab` become `alpha` and `beta`. The
    /// challenge tables are sized by `air.max_message_width()` and `air.num_bus_ids()`.
    /// Column numbering starts at 0, which is the accumulator column.
    ///
    /// # Panics
    ///
    /// The randomness is indexed at positions 0 and 1, so this panics if fewer than two
    /// values are available.
    pub fn new<A>(ab: &'ab mut AB, air: &A) -> Self
    where
        A: LookupAir<Self>,
    {
        let randomness = ab.permutation_randomness();
        let alpha: AB::ExprEF = randomness[0].into();
        let beta: AB::ExprEF = randomness[1].into();

        let max_width = air.max_message_width();
        let bus_count = air.num_bus_ids();

        Self {
            challenges: Challenges::<AB::ExprEF>::new(alpha, beta, max_width, bus_count),
            ab,
            column_idx: 0,
        }
    }
}
impl<'ab, AB> LookupBuilder for ConstraintLookupBuilder<'ab, AB>
where
    AB: LiftedAirBuilder,
{
    type F = AB::F;
    type Expr = AB::Expr;
    type Var = AB::Var;
    type EF = AB::EF;
    type ExprEF = AB::ExprEF;
    type VarEF = AB::VarEF;
    type PeriodicVar = AB::PeriodicVar;
    type MainWindow = AB::MainWindow;
    type Column<'a>
        = ConstraintColumn<'a, AB>
    where
        Self: 'a,
        AB: 'a;

    /// Returns the main-trace window of the wrapped builder.
    fn main(&self) -> Self::MainWindow {
        self.ab.main()
    }

    /// Returns the periodic column values of the wrapped builder.
    fn periodic_values(&self) -> &[Self::PeriodicVar] {
        self.ab.periodic_values()
    }

    /// Runs `f` to collect the interactions of the next permutation column. It then emits the
    /// constraints that bind that column to the collected fraction `v / u`.
    ///
    /// Columns are numbered in call order. The column index is used to index the
    /// permutation window, so requesting more columns than the permutation trace has will
    /// panic.
    ///
    /// Column 0 is the running accumulator `acc`. Let `aux_i` (`i >= 1`) be the current-row
    /// values of the other permutation columns. Column 0 is constrained by:
    /// - first row: `acc = 0`;
    /// - transition: `u * (acc' - acc - Σ aux_i) = v`, i.e. `acc' = acc + Σ aux_i + v / u`;
    /// - last row: `acc = permutation_values()[0]` (the committed final value).
    ///
    /// Every other column `i` is constrained by:
    /// - transition: `u * aux_i = v`, i.e. `aux_i = v / u`;
    /// - last row: `aux_i = 0`.
    ///
    /// NOTE(review): this assumes the usual `when_transition` semantics (disabled on the last
    /// row). Under that assumption, fractions produced on the last row never reach the
    /// accumulator. Presumably the trace guarantees the last row carries no interactions —
    /// confirm.
    fn next_column<'a, R>(
        &'a mut self,
        f: impl FnOnce(&mut Self::Column<'a>) -> R,
        _deg: Deg,
    ) -> R {
        let mut col = ConstraintColumn {
            challenges: &self.challenges,
            u: AB::ExprEF::ONE,
            v: AB::ExprEF::ZERO,
            _phantom: PhantomData,
        };
        let result = f(&mut col);
        let ConstraintColumn { u, v, .. } = col;

        let col_idx = self.column_idx;
        self.column_idx += 1;

        if col_idx == 0 {
            // Read everything from the permutation window in one scope. This ensures the
            // shared borrow of `self.ab` ends before the `&mut` assertion calls below.
            let (acc, acc_next, all_curr_sum, committed_final) = {
                let mp = self.ab.permutation();
                let current = mp.current_slice();
                let acc: AB::ExprEF = current[0].into();
                let acc_next: AB::ExprEF = mp.next_slice()[0].into();

                // Sum of every permutation column on the current row, accumulator included.
                let mut all_curr_sum = acc.clone();
                for &aux in &current[1..] {
                    let aux: AB::ExprEF = aux.into();
                    all_curr_sum += aux;
                }

                let committed_final: AB::ExprEF =
                    self.ab.permutation_values()[0].clone().into();
                (acc, acc_next, all_curr_sum, committed_final)
            };

            self.ab.when_first_row().assert_zero_ext(acc.clone());
            self.ab.when_transition().assert_zero_ext(u * (acc_next - all_curr_sum) - v);
            self.ab.when_last_row().assert_eq_ext(acc, committed_final);
        } else {
            let acc_curr: AB::ExprEF = {
                let mp = self.ab.permutation();
                mp.current_slice()[col_idx].into()
            };
            self.ab.when_transition().assert_zero_ext(u * acc_curr.clone() - v);
            self.ab.when_last_row().assert_zero_ext(acc_curr);
        }

        result
    }
}
/// Per-column accumulator handed to the closure passed to `next_column`.
///
/// The column's interactions are combined into the single fraction `v / u`, which starts as
/// `0 / 1`.
pub struct ConstraintColumn<'a, AB: LiftedAirBuilder + 'a> {
    /// Challenges shared with the owning builder, used to encode messages.
    challenges: &'a Challenges<AB::ExprEF>,
    /// Denominator of the column's combined fraction.
    u: AB::ExprEF,
    /// Numerator of the column's combined fraction.
    v: AB::ExprEF,
    _phantom: PhantomData<AB>,
}
impl<'a, AB> ConstraintColumn<'a, AB>
where
    AB: LiftedAirBuilder,
{
    /// Adds the group fraction `v_g / u_g` to the column fraction `v / u`.
    ///
    /// This uses the cross-multiplied form `(v * u_g + v_g * u) / (u * u_g)`, which
    /// involves no division.
    fn fold_group(&mut self, u_g: AB::ExprEF, v_g: AB::ExprEF) {
        let (u_prev, v_prev) = (self.u.clone(), self.v.clone());
        self.u = u_prev.clone() * u_g.clone();
        self.v = v_prev * u_g + v_g * u_prev;
    }
}
impl<'a, AB> LookupColumn for ConstraintColumn<'a, AB>
where
    AB: LiftedAirBuilder,
{
    type Expr = AB::Expr;
    type ExprEF = AB::ExprEF;
    type Group<'g>
        = ConstraintGroup<'g, AB>
    where
        Self: 'g,
        AB: 'g;

    /// Collects one group's interactions via `f`, starting from the empty fraction `0 / 1`.
    /// The resulting group fraction is then folded into this column.
    fn group<'g>(
        &'g mut self,
        _name: &'static str,
        f: impl FnOnce(&mut Self::Group<'g>),
        _deg: Deg,
    ) {
        let mut scratch = ConstraintGroup {
            challenges: self.challenges,
            u: AB::ExprEF::ONE,
            v: AB::ExprEF::ZERO,
            _phantom: PhantomData,
        };
        f(&mut scratch);
        self.fold_group(scratch.u, scratch.v);
    }

    /// Only the `encoded` builder is evaluated on the constraint path; `_canonical` is
    /// ignored. This is therefore exactly `group` run with `encoded`.
    fn group_with_cached_encoding<'g>(
        &'g mut self,
        name: &'static str,
        _canonical: impl FnOnce(&mut Self::Group<'g>),
        encoded: impl FnOnce(&mut Self::Group<'g>),
        deg: Deg,
    ) {
        self.group(name, encoded, deg);
    }
}
/// Accumulator for one group of flagged interactions inside a column.
///
/// Starting from `u = 1`, `v = 0`:
/// - each interaction adds `flag * (encoding - 1)` to `u`;
/// - each interaction adds its signed, flag-weighted multiplicity to `v`.
pub struct ConstraintGroup<'a, AB: LiftedAirBuilder + 'a> {
    /// Challenges shared with the owning builder, used to encode messages.
    challenges: &'a Challenges<AB::ExprEF>,
    /// Selector-weighted denominator.
    u: AB::ExprEF,
    /// Selector-weighted numerator.
    v: AB::ExprEF,
    _phantom: PhantomData<AB>,
}
impl<'a, AB> ConstraintGroup<'a, AB>
where
    AB: LiftedAirBuilder,
{
    /// Adds `flag * (denominator - 1)` to the selector-weighted denominator `u`.
    ///
    /// `u` starts at 1. If at most one flag in the group is non-zero on a row and flags are
    /// 0/1, `u` therefore ends up equal to the active interaction's denominator, or 1 if none
    /// is active. This exclusivity is presumably guaranteed by the caller's selectors —
    /// TODO confirm.
    fn gate_denominator(&mut self, flag: AB::Expr, denominator: AB::ExprEF) {
        self.u += (denominator - AB::ExprEF::ONE) * flag;
    }
}

impl<'a, AB> LookupGroup for ConstraintGroup<'a, AB>
where
    AB: LiftedAirBuilder,
{
    type Expr = AB::Expr;
    type ExprEF = AB::ExprEF;
    type Batch<'b>
        = ConstraintBatch<'b, AB>
    where
        Self: 'b,
        AB: 'b;

    /// Flagged interaction with numerator `+flag`.
    fn add<M>(&mut self, _name: &'static str, flag: Self::Expr, msg: impl FnOnce() -> M, _deg: Deg)
    where
        M: LookupMessage<Self::Expr, Self::ExprEF>,
    {
        let denominator = msg().encode(self.challenges);
        self.gate_denominator(flag.clone(), denominator);
        self.v += flag;
    }

    /// Flagged interaction with numerator `-flag`.
    fn remove<M>(
        &mut self,
        _name: &'static str,
        flag: Self::Expr,
        msg: impl FnOnce() -> M,
        _deg: Deg,
    ) where
        M: LookupMessage<Self::Expr, Self::ExprEF>,
    {
        let denominator = msg().encode(self.challenges);
        self.gate_denominator(flag.clone(), denominator);
        self.v -= flag;
    }

    /// Flagged interaction with numerator `flag * multiplicity`.
    fn insert<M>(
        &mut self,
        _name: &'static str,
        flag: Self::Expr,
        multiplicity: Self::Expr,
        msg: impl FnOnce() -> M,
        _deg: Deg,
    ) where
        M: LookupMessage<Self::Expr, Self::ExprEF>,
    {
        let denominator = msg().encode(self.challenges);
        self.gate_denominator(flag.clone(), denominator);
        self.v += flag * multiplicity;
    }

    /// Collects several always-active interactions into a single fraction `n / d`, which
    /// starts as `0 / 1`. That fraction is then gated by `flag` like a single interaction.
    fn batch<'b>(
        &'b mut self,
        _name: &'static str,
        flag: Self::Expr,
        build: impl FnOnce(&mut Self::Batch<'b>),
        _deg: Deg,
    ) {
        let mut inner = ConstraintBatch {
            challenges: self.challenges,
            n: AB::ExprEF::ZERO,
            d: AB::ExprEF::ONE,
            _phantom: PhantomData,
        };
        build(&mut inner);
        let ConstraintBatch { n, d, .. } = inner;
        self.gate_denominator(flag.clone(), d);
        self.v += n * flag;
    }

    /// Returns the powers of `beta` held by the shared challenges.
    fn beta_powers(&self) -> &[Self::ExprEF] {
        &self.challenges.beta_powers[..]
    }

    /// Returns the bus prefix for `bus_id`. Panics if `bus_id` is out of range.
    fn bus_prefix(&self, bus_id: usize) -> Self::ExprEF {
        self.challenges.bus_prefix[bus_id].clone()
    }

    /// Same as `insert`, but with a caller-supplied, pre-encoded denominator.
    fn insert_encoded(
        &mut self,
        _name: &'static str,
        flag: Self::Expr,
        multiplicity: Self::Expr,
        encoded: impl FnOnce() -> Self::ExprEF,
        _deg: Deg,
    ) {
        let denominator = encoded();
        self.gate_denominator(flag.clone(), denominator);
        self.v += flag * multiplicity;
    }
}
/// Accumulator for a batch of interactions, kept as the exact fraction sum `n / d`.
///
/// The batch starts as `0 / 1`; every message multiplies `d` by its encoding.
pub struct ConstraintBatch<'a, AB: LiftedAirBuilder + 'a> {
    /// Challenges shared with the owning builder, used to encode messages.
    challenges: &'a Challenges<AB::ExprEF>,
    /// Numerator of the batch fraction.
    n: AB::ExprEF,
    /// Denominator of the batch fraction.
    d: AB::ExprEF,
    _phantom: PhantomData<AB>,
}
impl<'a, AB> ConstraintBatch<'a, AB>
where
    AB: LiftedAirBuilder,
{
    /// Folds one more fraction with denominator `v` into the running fraction `n / d`.
    ///
    /// The new denominator is `d * v`. The new numerator is `combine(n * v, d)`, where `d` is
    /// the cross-multiplied numerator of the new term; `combine` decides its sign or scaling.
    fn fold_fraction(
        &mut self,
        v: AB::ExprEF,
        combine: impl FnOnce(AB::ExprEF, AB::ExprEF) -> AB::ExprEF,
    ) {
        let scaled = self.n.clone() * v.clone();
        self.n = combine(scaled, self.d.clone());
        self.d = self.d.clone() * v;
    }
}

impl<'a, AB> LookupBatch for ConstraintBatch<'a, AB>
where
    AB: LiftedAirBuilder,
{
    type Expr = AB::Expr;
    type ExprEF = AB::ExprEF;

    /// Adds `+1 / encode(msg)`.
    fn add<M>(&mut self, _name: &'static str, msg: M, _deg: Deg)
    where
        M: LookupMessage<Self::Expr, Self::ExprEF>,
    {
        let encoding = msg.encode(self.challenges);
        self.fold_fraction(encoding, |nv, d| nv + d);
    }

    /// Adds `-1 / encode(msg)`.
    fn remove<M>(&mut self, _name: &'static str, msg: M, _deg: Deg)
    where
        M: LookupMessage<Self::Expr, Self::ExprEF>,
    {
        let encoding = msg.encode(self.challenges);
        self.fold_fraction(encoding, |nv, d| nv - d);
    }

    /// Adds `multiplicity / encode(msg)`.
    fn insert<M>(&mut self, _name: &'static str, multiplicity: Self::Expr, msg: M, _deg: Deg)
    where
        M: LookupMessage<Self::Expr, Self::ExprEF>,
    {
        let encoding = msg.encode(self.challenges);
        self.fold_fraction(encoding, move |nv, d| nv + d * multiplicity);
    }

    /// Adds `multiplicity / encoded()`, with a caller-supplied, pre-encoded denominator.
    fn insert_encoded(
        &mut self,
        _name: &'static str,
        multiplicity: Self::Expr,
        encoded: impl FnOnce() -> Self::ExprEF,
        _deg: Deg,
    ) {
        let encoding = encoded();
        self.fold_fraction(encoding, move |nv, d| nv + d * multiplicity);
    }
}