use alloc::vec::Vec;
use core::{fmt, marker::PhantomData};
use miden_core::field::{PrimeCharacteristicRing, QuadFelt};
use miden_crypto::{
rand::random_felt,
stark::air::{
AirBuilder, PeriodicAirBuilder, PermutationAirBuilder,
symbolic::{
BaseEntry, BaseLeaf, ExtEntry, ExtLeaf, SymbolicAirBuilder, SymbolicExpr,
SymbolicExpression, SymbolicExpressionExt, SymbolicVariable, SymbolicVariableExt,
},
},
};
use super::super::{
Challenges, Deg, LookupAir, LookupBatch, LookupBuilder, LookupColumn, LookupGroup,
LookupMessage,
};
use crate::Felt;
// Shorthand for the symbolic builder and expression types used throughout this file.
/// Symbolic AIR builder over `Felt` / `QuadFelt`; `ValidationBuilder` wraps a mutable reference to it.
type Inner = SymbolicAirBuilder<Felt, QuadFelt>;
/// Symbolic expression over `Felt`; numerically evaluated by `RowValuation::eval_base`.
type Expr = SymbolicExpression<Felt>;
/// Symbolic extension expression over `Felt` / `QuadFelt`; numerically evaluated to a `QuadFelt`
/// by `RowValuation::eval_ext`.
type ExprEF = SymbolicExpressionExt<Felt, QuadFelt>;
/// Failure reported by [`validate`].
///
/// Only the first failure recorded during a run is kept; later checks are skipped or ignored
/// once an error has been stored.
#[derive(Clone, Debug)]
pub enum ValidationError {
/// The number of `next_column` calls made during `eval` (`observed`) differs from the value
/// returned by `air.num_columns()` (`declared`).
NumColumnsMismatch { declared: usize, observed: usize },
/// The `degree_multiple()` of a column's folded `(v, u)` expressions (`observed`) differs from
/// the `Deg` passed to `next_column` (`declared`).
ColumnDegreeMismatch {
column_idx: usize,
declared: Deg,
observed: Deg,
},
/// The `degree_multiple()` of a group's `(v, u)` expressions (`observed`) differs from the `Deg`
/// passed to `group` / `group_with_cached_encoding` (`declared`).
GroupDegreeMismatch {
column_idx: usize,
group_idx: usize,
name: &'static str,
declared: Deg,
observed: Deg,
},
/// In `group_with_cached_encoding`, the canonical and encoded closures disagree: the numeric
/// evaluation of `V_c·U_e − V_e·U_c` (stored in `diff`) is non-zero.
EncodingMismatch {
column_idx: usize,
group_idx: usize,
name: &'static str,
diff: QuadFelt,
},
/// A group opened through `group` (the simple form) called `insert_encoded`.
ScopeViolation {
column_idx: usize,
group_idx: usize,
name: &'static str,
},
}
impl fmt::Display for ValidationError {
    /// Writes a single-line, human-readable description of the failure.
    ///
    /// Column and group indices are printed in `column[i]` / `group[j]` form; group names are
    /// printed with `Debug` formatting (i.e. quoted).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ScopeViolation { column_idx, group_idx, name } => f.write_fmt(format_args!(
                "column[{column_idx}] group[{group_idx}] {name:?} simple group called insert_encoded",
            )),
            Self::EncodingMismatch { column_idx, group_idx, name, diff } => {
                f.write_fmt(format_args!(
                    "column[{column_idx}] group[{group_idx}] {name:?} cached-encoding mismatch: V_c·U_e − V_e·U_c = {diff:?}",
                ))
            },
            Self::GroupDegreeMismatch {
                column_idx,
                group_idx,
                name,
                declared,
                observed,
            } => {
                // Borrow the degree components up front; they are only formatted, never moved.
                let (declared_v, declared_u) = (&declared.v, &declared.u);
                let (observed_v, observed_u) = (&observed.v, &observed.u);
                f.write_fmt(format_args!(
                    "column[{column_idx}] group[{group_idx}] {name:?} degree mismatch: declared (v={}, u={}), observed (v={}, u={})",
                    declared_v, declared_u, observed_v, observed_u,
                ))
            },
            Self::ColumnDegreeMismatch { column_idx, declared, observed } => {
                let (declared_v, declared_u) = (&declared.v, &declared.u);
                let (observed_v, observed_u) = (&observed.v, &observed.u);
                f.write_fmt(format_args!(
                    "column[{column_idx}] degree mismatch: declared (v={}, u={}), observed (v={}, u={})",
                    declared_v, declared_u, observed_v, observed_u,
                ))
            },
            Self::NumColumnsMismatch { declared, observed } => f.write_fmt(format_args!(
                "num_columns mismatch: declared {declared}, observed {observed}"
            )),
        }
    }
}
/// Dimensions handed to [`validate`]; converted into the symbolic builder's `AirLayout` by
/// `to_symbolic`.
#[derive(Clone, Copy, Debug)]
pub struct ValidateLayout {
/// Width of the main trace; becomes `AirLayout::main_width` and is the number of values
/// sampled for each of the current and next rows.
pub trace_width: usize,
/// Forwarded to `AirLayout::num_public_values`.
pub num_public_values: usize,
/// Forwarded to `AirLayout::num_periodic_columns`; also the number of periodic values sampled.
pub num_periodic_columns: usize,
/// Forwarded to `AirLayout::permutation_width`.
pub permutation_width: usize,
/// Forwarded to `AirLayout::num_permutation_challenges`.
pub num_permutation_challenges: usize,
/// Forwarded to `AirLayout::num_permutation_values`.
pub num_permutation_values: usize,
}
impl ValidateLayout {
    /// Converts this layout into the symbolic builder's `AirLayout`.
    ///
    /// `trace_width` maps to `main_width`, the preprocessed width is fixed at zero, and every
    /// other field is copied to the field of the same name.
    fn to_symbolic(self) -> miden_crypto::stark::air::symbolic::AirLayout {
        let Self {
            trace_width,
            num_public_values,
            num_periodic_columns,
            permutation_width,
            num_permutation_challenges,
            num_permutation_values,
        } = self;

        miden_crypto::stark::air::symbolic::AirLayout {
            preprocessed_width: 0,
            main_width: trace_width,
            num_public_values,
            permutation_width,
            num_permutation_challenges,
            num_permutation_values,
            num_periodic_columns,
        }
    }
}
/// Runs `air.eval` against a [`ValidationBuilder`] and returns the first failure it recorded.
///
/// A symbolic builder is created from `layout`, and a numeric row valuation is sampled with
/// `random_felt()`: `layout.trace_width` values for each of the current and next rows,
/// `layout.num_periodic_columns` periodic values, and two extension-field challenges. The
/// valuation is used to numerically evaluate the cached-encoding check performed by
/// `group_with_cached_encoding`.
///
/// # Errors
///
/// Returns the first [`ValidationError`] recorded while evaluating `air`, or
/// [`ValidationError::NumColumnsMismatch`] if no other error occurred and the number of opened
/// columns differs from `air.num_columns()`.
///
/// # Panics
///
/// Numeric evaluation panics if an evaluated expression contains a leaf the row valuation does
/// not cover (selectors, public or preprocessed values, permutation columns, a main-trace offset
/// other than 0/1, or a challenge index other than 0/1).
pub fn validate<A>(air: &A, layout: ValidateLayout) -> Result<(), ValidationError>
where
    for<'ab, 'r> A: LookupAir<ValidationBuilder<'ab, 'r>>,
{
    let current: Vec<Felt> = (0..layout.trace_width).map(|_| random_felt()).collect();
    let next: Vec<Felt> = (0..layout.trace_width).map(|_| random_felt()).collect();
    let periodic: Vec<Felt> = (0..layout.num_periodic_columns).map(|_| random_felt()).collect();
    let alpha = QuadFelt::new([random_felt(), random_felt()]);
    let beta = QuadFelt::new([random_felt(), random_felt()]);

    let mut sym = SymbolicAirBuilder::<Felt, QuadFelt>::new(layout.to_symbolic());
    let row_valuation = RowValuation {
        // Fixed: this initializer had been corrupted to `¤t` (the `&curren` HTML entity was
        // decoded inside `&current`), which does not compile.
        current: &current,
        next: &next,
        periodic: &periodic,
        alpha,
        beta,
    };

    let mut builder = ValidationBuilder::new(&mut sym, air, row_valuation);
    air.eval(&mut builder);
    builder.take_error().map_or(Ok(()), Err)
}
/// Method-call form of the free function [`validate`].
pub trait ValidateLookupAir {
/// Validates `self` against `layout`; see [`validate`] for what is checked.
fn validate(&self, layout: ValidateLayout) -> Result<(), ValidationError>;
}
/// Blanket implementation: every AIR that can be evaluated against a [`ValidationBuilder`] of
/// any lifetimes gets `.validate(layout)`.
impl<A> ValidateLookupAir for A
where
    A: for<'ab, 'r> LookupAir<ValidationBuilder<'ab, 'r>>,
{
    /// Delegates to the free function [`validate`].
    fn validate(&self, layout: ValidateLayout) -> Result<(), ValidationError> {
        let air: &A = self;
        validate(air, layout)
    }
}
/// Concrete values used to numerically evaluate symbolic expressions.
#[derive(Clone, Copy)]
struct RowValuation<'r> {
/// Values for main-trace variables at offset 0, indexed by column.
current: &'r [Felt],
/// Values for main-trace variables at offset 1, indexed by column.
next: &'r [Felt],
/// Values for periodic variables, indexed by column.
periodic: &'r [Felt],
/// Value substituted for challenge index 0.
alpha: QuadFelt,
/// Value substituted for challenge index 1.
beta: QuadFelt,
}
impl<'r> RowValuation<'r> {
    /// Numerically evaluates a base-field symbolic expression by recursing over its tree.
    ///
    /// Panics if a leaf is not covered by this valuation (see `eval_base_leaf`).
    fn eval_base(&self, expr: &Expr) -> Felt {
        match expr {
            SymbolicExpr::Leaf(leaf) => self.eval_base_leaf(leaf),
            SymbolicExpr::Neg { x, .. } => {
                let operand = self.eval_base(x);
                -operand
            },
            SymbolicExpr::Add { x, y, .. } => {
                let (lhs, rhs) = (self.eval_base(x), self.eval_base(y));
                lhs + rhs
            },
            SymbolicExpr::Sub { x, y, .. } => {
                let (lhs, rhs) = (self.eval_base(x), self.eval_base(y));
                lhs - rhs
            },
            SymbolicExpr::Mul { x, y, .. } => {
                let (lhs, rhs) = (self.eval_base(x), self.eval_base(y));
                lhs * rhs
            },
        }
    }

    /// Resolves a base-field leaf: constants evaluate to themselves; main-trace variables read
    /// `current` (offset 0) or `next` (offset 1); periodic variables read `periodic`.
    ///
    /// Panics on selector leaves, on preprocessed or public variables, and on any other
    /// main-trace offset.
    fn eval_base_leaf(&self, leaf: &BaseLeaf<Felt>) -> Felt {
        let (entry, index) = match leaf {
            BaseLeaf::Constant(c) => return *c,
            BaseLeaf::Variable(SymbolicVariable { entry, index, .. }) => (entry, *index),
            BaseLeaf::IsFirstRow | BaseLeaf::IsLastRow | BaseLeaf::IsTransition => {
                panic!("selector leaf {leaf:?} unexpected in LookupAir::eval")
            },
        };

        match entry {
            BaseEntry::Periodic => self.periodic[index],
            BaseEntry::Main { offset } => match *offset {
                0 => self.current[index],
                1 => self.next[index],
                _ => panic!("unexpected main offset {offset} in LookupAir::eval"),
            },
            BaseEntry::Preprocessed { .. } | BaseEntry::Public => {
                panic!("unexpected {entry:?} leaf in LookupAir::eval")
            },
        }
    }

    /// Numerically evaluates an extension-field symbolic expression by recursing over its tree.
    ///
    /// Panics if a leaf is not covered by this valuation (see `eval_ext_leaf`).
    fn eval_ext(&self, expr: &ExprEF) -> QuadFelt {
        match expr {
            SymbolicExpr::Leaf(leaf) => self.eval_ext_leaf(leaf),
            SymbolicExpr::Neg { x, .. } => {
                let operand = self.eval_ext(x);
                -operand
            },
            SymbolicExpr::Add { x, y, .. } => {
                let (lhs, rhs) = (self.eval_ext(x), self.eval_ext(y));
                lhs + rhs
            },
            SymbolicExpr::Sub { x, y, .. } => {
                let (lhs, rhs) = (self.eval_ext(x), self.eval_ext(y));
                lhs - rhs
            },
            SymbolicExpr::Mul { x, y, .. } => {
                let (lhs, rhs) = (self.eval_ext(x), self.eval_ext(y));
                lhs * rhs
            },
        }
    }

    /// Resolves an extension-field leaf: embedded base expressions are evaluated and lifted,
    /// constants evaluate to themselves, and challenges 0 / 1 map to `alpha` / `beta`.
    ///
    /// Panics on any other challenge index and on permutation variables or values.
    fn eval_ext_leaf(&self, leaf: &ExtLeaf<Felt, QuadFelt>) -> QuadFelt {
        let (entry, index) = match leaf {
            ExtLeaf::Base(inner) => return self.eval_base(inner).into(),
            ExtLeaf::ExtConstant(c) => return *c,
            ExtLeaf::ExtVariable(SymbolicVariableExt { entry, index, .. }) => (entry, *index),
        };

        match entry {
            ExtEntry::Challenge => {
                if index == 0 {
                    self.alpha
                } else if index == 1 {
                    self.beta
                } else {
                    let i = index;
                    panic!("unexpected challenge index {i} in LookupAir::eval")
                }
            },
            ExtEntry::Permutation { .. } | ExtEntry::PermutationValue => {
                panic!("unexpected {entry:?} leaf in LookupAir::eval")
            },
        }
    }
}
/// `LookupBuilder` implementation used by [`validate`] to check an AIR's declarations while it
/// is being evaluated.
pub struct ValidationBuilder<'ab, 'r> {
/// Underlying symbolic builder; `main()` and `periodic_values()` delegate to it.
ab: &'ab mut Inner,
/// Symbolic challenges built from the builder's first two permutation-randomness entries.
sym_challenges: Challenges<ExprEF>,
/// Numeric valuation copied into every column opened by `next_column`.
row_valuation: RowValuation<'r>,
/// Number of `next_column` calls so far; also the index assigned to the next column.
column_idx: usize,
/// Value of `air.num_columns()` captured at construction.
declared_columns: usize,
/// First error recorded, if any.
error: Option<ValidationError>,
}
impl<'ab, 'r> ValidationBuilder<'ab, 'r> {
    /// Creates a builder over `ab` for `air`.
    ///
    /// The symbolic challenges are derived from the first two entries of
    /// `ab.permutation_randomness()` together with `air.max_message_width()` and
    /// `air.num_bus_ids()`; the declared column count is read from `air.num_columns()`.
    fn new<A>(ab: &'ab mut Inner, air: &A, row_valuation: RowValuation<'r>) -> Self
    where
        A: LookupAir<Self>,
    {
        // Keep the borrow of `ab` confined to this block so `ab` can be moved into `Self` below.
        let sym_challenges = {
            let randomness = ab.permutation_randomness();
            let alpha: ExprEF = randomness[0].into();
            let beta: ExprEF = randomness[1].into();
            Challenges::<ExprEF>::new(alpha, beta, air.max_message_width(), air.num_bus_ids())
        };
        let declared_columns = air.num_columns();

        Self {
            ab,
            sym_challenges,
            row_valuation,
            column_idx: 0,
            declared_columns,
            error: None,
        }
    }

    /// Consumes the builder and returns the first recorded error.
    ///
    /// If nothing was recorded but the number of opened columns differs from the declared
    /// count, a `NumColumnsMismatch` is returned instead.
    fn take_error(self) -> Option<ValidationError> {
        let Self { error, column_idx, declared_columns, .. } = self;
        error.or_else(|| {
            if column_idx == declared_columns {
                None
            } else {
                Some(ValidationError::NumColumnsMismatch {
                    declared: declared_columns,
                    observed: column_idx,
                })
            }
        })
    }
}
impl<'ab, 'r> LookupBuilder for ValidationBuilder<'ab, 'r> {
    type F = Felt;
    type Expr = Expr;
    type Var = SymbolicVariable<Felt>;
    type EF = QuadFelt;
    type ExprEF = ExprEF;
    type VarEF = SymbolicVariableExt<Felt, QuadFelt>;
    type PeriodicVar = SymbolicVariable<Felt>;
    type MainWindow = <Inner as AirBuilder>::MainWindow;
    type Column<'c>
        = ValidationColumn<'c, 'r>
    where
        Self: 'c;

    /// Returns the main-trace window of the wrapped symbolic builder.
    fn main(&self) -> Self::MainWindow {
        self.ab.main()
    }

    /// Returns the periodic variables of the wrapped symbolic builder.
    fn periodic_values(&self) -> &[Self::PeriodicVar] {
        self.ab.periodic_values()
    }

    /// Opens the next column, runs `f` on it, and checks the column's observed degrees
    /// against `deg`.
    ///
    /// The column starts with `u = 1`, `v = 0`. Once `f` returns, and only if no error had been
    /// recorded before this column, the column's own error (if any) is adopted; otherwise a
    /// `ColumnDegreeMismatch` is recorded when the `degree_multiple()` of the folded `v` / `u`
    /// differs from `deg`. The closure's return value is passed through unchanged.
    fn next_column<'c, R>(&'c mut self, f: impl FnOnce(&mut Self::Column<'c>) -> R, deg: Deg) -> R {
        let column_idx = self.column_idx;
        self.column_idx = column_idx + 1;
        let had_prior_error = self.error.is_some();

        let mut column = ValidationColumn {
            challenges: &self.sym_challenges,
            row_valuation: self.row_valuation,
            u: ExprEF::ONE,
            v: ExprEF::ZERO,
            column_idx,
            next_group_idx: 0,
            error: None,
            _phantom: PhantomData,
        };
        let output = f(&mut column);

        // An earlier column already failed: keep that error and skip all checks here.
        if had_prior_error {
            return output;
        }

        // `self.error` is `None` at this point, so assigning the outcome directly is safe.
        self.error = column.error.take().or_else(|| {
            let observed = Deg {
                v: column.v.degree_multiple(),
                u: column.u.degree_multiple(),
            };
            if observed != deg {
                Some(ValidationError::ColumnDegreeMismatch {
                    column_idx,
                    declared: deg,
                    observed,
                })
            } else {
                None
            }
        });
        output
    }
}
/// Per-column state handed to the closure passed to `ValidationBuilder::next_column`.
pub struct ValidationColumn<'c, 'r> {
/// Symbolic challenges shared with the owning builder.
challenges: &'c Challenges<ExprEF>,
/// Numeric valuation used for the cached-encoding check.
row_valuation: RowValuation<'r>,
/// Running `u` expression; starts at `ONE` and is multiplied by each group's `u` in `fold_group`.
u: ExprEF,
/// Running `v` expression; starts at `ZERO` and is updated as `v·u_g + v_g·u` in `fold_group`.
v: ExprEF,
/// Index of this column within the builder.
column_idx: usize,
/// Index that will be assigned to the next group opened on this column.
next_group_idx: usize,
/// First error recorded while processing this column, if any.
error: Option<ValidationError>,
_phantom: PhantomData<&'c ()>,
}
impl<'c, 'r> ValidationColumn<'c, 'r> {
    /// Folds a finished group's `(u_g, v_g)` pair into the column's running pair:
    /// `v ← v·u_g + v_g·u`, then `u ← u·u_g`.
    fn fold_group(&mut self, u_g: ExprEF, v_g: ExprEF) {
        let prev_u = self.u.clone();
        let prev_v = self.v.clone();
        self.v = prev_v * u_g.clone() + v_g * prev_u.clone();
        self.u = prev_u * u_g;
    }

    /// Records a `GroupDegreeMismatch` when the `degree_multiple()` of `v` / `u` differs from
    /// `declared`. Does nothing if this column already holds an error.
    fn check_group_degree(
        &mut self,
        name: &'static str,
        group_idx: usize,
        declared: Deg,
        u: &ExprEF,
        v: &ExprEF,
    ) {
        if self.error.is_none() {
            let observed = Deg {
                v: v.degree_multiple(),
                u: u.degree_multiple(),
            };
            if observed != declared {
                let column_idx = self.column_idx;
                self.error = Some(ValidationError::GroupDegreeMismatch {
                    column_idx,
                    group_idx,
                    name,
                    declared,
                    observed,
                });
            }
        }
    }
}
/// Builds an empty group accumulator (`u = 1`, `v = 0`) with `used_insert_encoded` cleared.
///
/// `inside_encoded_closure` is stored as given; when it is `true`, calls to `insert_encoded`
/// on the returned group do not set `used_insert_encoded`.
fn fresh_group<'g>(
    challenges: &'g Challenges<ExprEF>,
    inside_encoded_closure: bool,
) -> ValidationGroup<'g> {
    let (u, v) = (ExprEF::ONE, ExprEF::ZERO);
    ValidationGroup {
        challenges,
        u,
        v,
        inside_encoded_closure,
        used_insert_encoded: false,
    }
}
impl<'c, 'r> LookupColumn for ValidationColumn<'c, 'r> {
    type Expr = Expr;
    type ExprEF = ExprEF;
    type Group<'g>
        = ValidationGroup<'g>
    where
        Self: 'g;

    /// Runs `f` on a fresh simple group, then checks and folds the result.
    ///
    /// Records a `ScopeViolation` if the closure called `insert_encoded`, then a
    /// `GroupDegreeMismatch` if the observed degrees differ from `deg` (each only when no error
    /// is already recorded on this column). The group is folded into the column either way.
    fn group<'g>(&'g mut self, name: &'static str, f: impl FnOnce(&mut Self::Group<'g>), deg: Deg) {
        let group_idx = self.next_group_idx;
        self.next_group_idx = group_idx + 1;

        let mut scratch = fresh_group(self.challenges, false);
        f(&mut scratch);

        if scratch.used_insert_encoded && self.error.is_none() {
            self.error = Some(ValidationError::ScopeViolation {
                column_idx: self.column_idx,
                group_idx,
                name,
            });
        }

        let (u, v) = (scratch.u, scratch.v);
        self.check_group_degree(name, group_idx, deg, &u, &v);
        self.fold_group(u, v);
    }

    /// Runs `canonical` and `encoded` on separate fresh groups and checks that they agree.
    ///
    /// When no error is recorded yet, `V_c·U_e − V_e·U_c` is evaluated numerically with the
    /// row valuation; a non-zero result records an `EncodingMismatch`. The `encoded` group's
    /// pair is then degree-checked against `deg` and folded into the column.
    ///
    /// NOTE(review): the `canonical` group's `used_insert_encoded` flag is not inspected here —
    /// confirm whether `insert_encoded` is meant to be allowed inside `canonical`.
    fn group_with_cached_encoding<'g>(
        &'g mut self,
        name: &'static str,
        canonical: impl FnOnce(&mut Self::Group<'g>),
        encoded: impl FnOnce(&mut Self::Group<'g>),
        deg: Deg,
    ) {
        let group_idx = self.next_group_idx;
        self.next_group_idx = group_idx + 1;

        let mut canon = fresh_group(self.challenges, false);
        canonical(&mut canon);
        let ValidationGroup { u: canon_u, v: canon_v, .. } = canon;

        let mut enc = fresh_group(self.challenges, true);
        encoded(&mut enc);
        let ValidationGroup { u, v, .. } = enc;

        if self.error.is_none() {
            // Cross-multiplied difference of the two (v, u) pairs.
            let diff_expr = canon_v * u.clone() - v.clone() * canon_u;
            let diff = self.row_valuation.eval_ext(&diff_expr);
            if diff != QuadFelt::ZERO {
                self.error = Some(ValidationError::EncodingMismatch {
                    column_idx: self.column_idx,
                    group_idx,
                    name,
                    diff,
                });
            }
        }

        self.check_group_degree(name, group_idx, deg, &u, &v);
        self.fold_group(u, v);
    }
}
/// Per-group accumulator handed to the closures passed to `group` and
/// `group_with_cached_encoding`.
pub struct ValidationGroup<'g> {
/// Symbolic challenges used to encode messages.
challenges: &'g Challenges<ExprEF>,
/// Running `u` expression; starts at `ONE`, and every interaction adds `(encoded − 1)·flag`.
u: ExprEF,
/// Running `v` expression; starts at `ZERO`, and every interaction adds its flagged multiplicity.
v: ExprEF,
/// `true` only for the group handed to the `encoded` closure of `group_with_cached_encoding`.
inside_encoded_closure: bool,
/// Set when `insert_encoded` is called while `inside_encoded_closure` is `false`.
used_insert_encoded: bool,
}
impl<'g> ValidationGroup<'g> {
    /// Adds `(encoded − 1)·flag` to the running `u` expression.
    fn accumulate_u(&mut self, encoded: ExprEF, flag: &Expr) {
        self.u += (encoded - ExprEF::ONE) * flag.clone();
    }
}

impl<'g> LookupGroup for ValidationGroup<'g> {
    type Expr = Expr;
    type ExprEF = ExprEF;
    type Batch<'b>
        = ValidationBatch<'b>
    where
        Self: 'b;

    /// Records a flagged interaction: `u += (encode(msg) − 1)·flag`, `v += flag`.
    fn add<M>(&mut self, _name: &'static str, flag: Expr, msg: impl FnOnce() -> M, _deg: Deg)
    where
        M: LookupMessage<Expr, ExprEF>,
    {
        let encoded_msg = msg().encode(self.challenges);
        self.accumulate_u(encoded_msg, &flag);
        self.v += flag;
    }

    /// Records a flagged interaction: `u += (encode(msg) − 1)·flag`, `v −= flag`.
    fn remove<M>(&mut self, _name: &'static str, flag: Expr, msg: impl FnOnce() -> M, _deg: Deg)
    where
        M: LookupMessage<Expr, ExprEF>,
    {
        let encoded_msg = msg().encode(self.challenges);
        self.accumulate_u(encoded_msg, &flag);
        self.v -= flag;
    }

    /// Records a flagged interaction with an explicit multiplicity:
    /// `u += (encode(msg) − 1)·flag`, `v += flag·multiplicity`.
    fn insert<M>(
        &mut self,
        _name: &'static str,
        flag: Expr,
        multiplicity: Expr,
        msg: impl FnOnce() -> M,
        _deg: Deg,
    ) where
        M: LookupMessage<Expr, ExprEF>,
    {
        let encoded_msg = msg().encode(self.challenges);
        self.accumulate_u(encoded_msg, &flag);
        self.v += flag * multiplicity;
    }

    /// Runs `build` on a fresh batch (`n = 0`, `d = 1`) and absorbs its result:
    /// `u += (d − 1)·flag`, `v += n·flag`.
    fn batch<'b>(
        &'b mut self,
        _name: &'static str,
        flag: Expr,
        build: impl FnOnce(&mut Self::Batch<'b>),
        _deg: Deg,
    ) {
        let mut pending = ValidationBatch {
            challenges: self.challenges,
            n: ExprEF::ZERO,
            d: ExprEF::ONE,
        };
        build(&mut pending);

        let (n, d) = (pending.n, pending.d);
        self.accumulate_u(d, &flag);
        self.v += n * flag;
    }

    /// Returns the challenge set's `beta_powers` as a slice.
    fn beta_powers(&self) -> &[ExprEF] {
        let challenges = self.challenges;
        &challenges.beta_powers[..]
    }

    /// Returns a clone of the challenge set's `bus_prefix[bus_id]`; panics if `bus_id` is out
    /// of range.
    fn bus_prefix(&self, bus_id: usize) -> ExprEF {
        let prefix = &self.challenges.bus_prefix[bus_id];
        prefix.clone()
    }

    /// Like `insert`, but takes the already-encoded message from `encoded`.
    ///
    /// Calling this on a group that is not the `encoded` closure's group sets
    /// `used_insert_encoded`.
    fn insert_encoded(
        &mut self,
        _name: &'static str,
        flag: Expr,
        multiplicity: Expr,
        encoded: impl FnOnce() -> ExprEF,
        _deg: Deg,
    ) {
        self.used_insert_encoded |= !self.inside_encoded_closure;
        let encoded_msg = encoded();
        self.accumulate_u(encoded_msg, &flag);
        self.v += flag * multiplicity;
    }
}
/// Accumulator handed to the closure passed to `ValidationGroup::batch`.
pub struct ValidationBatch<'b> {
/// Symbolic challenges used to encode messages.
challenges: &'b Challenges<ExprEF>,
/// Running `n` expression; starts at `ZERO` and is updated as `n·v_msg + d·multiplicity`.
n: ExprEF,
/// Running `d` expression; starts at `ONE` and is multiplied by each encoded message.
d: ExprEF,
}
impl<'b> ValidationBatch<'b> {
    /// Absorbs one interaction into the running pair:
    /// `n ← n·v_msg + d·multiplicity`, `d ← d·v_msg` (both computed from the previous `d`).
    fn absorb(&mut self, v_msg: ExprEF, multiplicity: Expr) {
        let next_n = self.n.clone() * v_msg.clone() + self.d.clone() * multiplicity;
        let next_d = self.d.clone() * v_msg;
        self.n = next_n;
        self.d = next_d;
    }
}

impl<'b> LookupBatch for ValidationBatch<'b> {
    type Expr = Expr;
    type ExprEF = ExprEF;

    /// Encodes `msg` with the batch's challenges and absorbs it with the given multiplicity.
    fn insert<M>(&mut self, _name: &'static str, multiplicity: Expr, msg: M, _deg: Deg)
    where
        M: LookupMessage<Expr, ExprEF>,
    {
        let encoded_msg = msg.encode(self.challenges);
        self.absorb(encoded_msg, multiplicity);
    }

    /// Absorbs an already-encoded message produced by `encoded` with the given multiplicity.
    fn insert_encoded(
        &mut self,
        _name: &'static str,
        multiplicity: Expr,
        encoded: impl FnOnce() -> ExprEF,
        _deg: Deg,
    ) {
        let encoded_msg = encoded();
        self.absorb(encoded_msg, multiplicity);
    }
}