use alloc::vec::Vec;
use p3_field::{ExtensionField, Field};
use p3_matrix::Matrix;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::stack::ViewPair;
use crate::{
Air, AirBuilder, AirBuilderWithContext, ExtensionBuilder, PermutationAirBuilder, RowWindow,
};
/// A single constraint that evaluated to a non-zero value during debugging.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConstraintFailure {
    /// Index of the trace row on which the constraint was evaluated.
    pub row: usize,
    /// Position of the constraint in evaluation order on that row: the number
    /// of assertions evaluated before it, starting at 0.
    pub constraint: usize,
}
/// Constraint builder for debugging AIRs on concrete values.
///
/// It evaluates an AIR's constraints for one row of a trace. Every constraint
/// that is non-zero is recorded as a [`ConstraintFailure`]; evaluation does not
/// stop at the first one.
///
/// `F` is the field of the main and preprocessed traces. `EF` (defaults to `F`)
/// is the field used for the permutation trace, its challenges and its values.
#[derive(Debug)]
pub struct DebugConstraintBuilder<'a, F: Field, EF: ExtensionField<F> = F> {
    /// Index of the row being evaluated; copied into every recorded failure.
    row_index: usize,
    /// Number of assertions evaluated so far; the index given to the next one.
    constraint_index: usize,
    /// Failures recorded so far, in evaluation order.
    failures: Vec<ConstraintFailure>,
    /// Current (`top`) and next (`bottom`) rows of the main trace.
    main: ViewPair<'a, F>,
    /// Current and next rows of the preprocessed trace (may be empty rows when
    /// the AIR has no preprocessed columns).
    preprocessed: RowWindow<'a, F>,
    /// Public inputs returned by `AirBuilder::public_values`.
    public_values: &'a [F],
    /// Value returned by `AirBuilder::is_first_row`.
    is_first_row: F,
    /// Value returned by `AirBuilder::is_last_row`.
    is_last_row: F,
    /// Value returned by `AirBuilder::is_transition_window`.
    is_transition: F,
    /// Current and next permutation rows; `None` when built with `new`.
    permutation: Option<ViewPair<'a, EF>>,
    /// Returned by `permutation_randomness`; empty when built with `new`.
    permutation_challenges: &'a [EF],
    /// Returned by `permutation_values`; empty when built with `new`.
    permutation_values: &'a [EF],
}
impl<'a, F: Field> DebugConstraintBuilder<'a, F> {
    /// Creates a builder for one row of a trace without permutation columns.
    ///
    /// `main` and `preprocessed` hold the current (`top`) and next (`bottom`)
    /// rows. The three selector values are returned unchanged by
    /// `is_first_row`, `is_last_row` and `is_transition_window`. No permutation
    /// data is stored, so calling `permutation()` on the result panics.
    pub fn new(
        row_index: usize,
        main: ViewPair<'a, F>,
        preprocessed: ViewPair<'a, F>,
        public_values: &'a [F],
        is_first_row: F,
        is_last_row: F,
        is_transition: F,
    ) -> Self {
        let preprocessed_window =
            RowWindow::from_two_rows(preprocessed.top.values, preprocessed.bottom.values);

        Self {
            // Fresh evaluation state: nothing checked, nothing failed.
            constraint_index: 0,
            failures: Vec::new(),
            // Trace data for this row.
            row_index,
            main,
            preprocessed: preprocessed_window,
            public_values,
            is_first_row,
            is_last_row,
            is_transition,
            // No permutation trace in this constructor.
            permutation: None,
            permutation_challenges: &[],
            permutation_values: &[],
        }
    }
}
impl<'a, F: Field, EF: ExtensionField<F>> DebugConstraintBuilder<'a, F, EF> {
    /// Creates a builder for one row of a trace that also has a permutation
    /// trace over the extension field `EF`.
    ///
    /// `permutation` holds the current (`top`) and next (`bottom`) permutation
    /// rows, returned by `PermutationAirBuilder::permutation`.
    /// `permutation_challenges` and `permutation_values` are returned verbatim by
    /// `permutation_randomness` and `permutation_values`. The remaining
    /// arguments mean the same as in `DebugConstraintBuilder::new`.
    // The arguments mirror, one to one, the data an AIR can read from the
    // builder; bundling them into a parameter struct would add no clarity.
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_permutation(
        row_index: usize,
        main: ViewPair<'a, F>,
        preprocessed: ViewPair<'a, F>,
        public_values: &'a [F],
        is_first_row: F,
        is_last_row: F,
        is_transition: F,
        permutation: ViewPair<'a, EF>,
        permutation_challenges: &'a [EF],
        permutation_values: &'a [EF],
    ) -> Self {
        Self {
            row_index,
            constraint_index: 0,
            failures: Vec::new(),
            main,
            preprocessed: RowWindow::from_two_rows(
                preprocessed.top.values,
                preprocessed.bottom.values,
            ),
            public_values,
            is_first_row,
            is_last_row,
            is_transition,
            permutation: Some(permutation),
            permutation_challenges,
            permutation_values,
        }
    }

    /// Returns `true` if at least one evaluated constraint was non-zero.
    pub const fn has_failures(&self) -> bool {
        !self.failures.is_empty()
    }

    /// Returns the failures recorded so far, in the order the constraints were evaluated.
    pub fn failures(&self) -> &[ConstraintFailure] {
        &self.failures
    }

    /// Consumes the builder and returns the recorded failures.
    pub fn into_failures(self) -> Vec<ConstraintFailure> {
        self.failures
    }
}
impl<'a, F, EF> AirBuilder for DebugConstraintBuilder<'a, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    // Constraints are evaluated on concrete values, so every symbolic type
    // collapses to the base field itself.
    type F = F;
    type Expr = F;
    type Var = F;
    type PreprocessedWindow = RowWindow<'a, F>;
    type MainWindow = RowWindow<'a, F>;
    type PublicVar = F;

    /// Window over the current and next rows of the main trace.
    fn main(&self) -> Self::MainWindow {
        let current = self.main.top.values;
        let next = self.main.bottom.values;
        RowWindow::from_two_rows(current, next)
    }

    /// Window over the current and next rows of the preprocessed trace.
    fn preprocessed(&self) -> &Self::PreprocessedWindow {
        &self.preprocessed
    }

    /// First-row selector supplied at construction.
    fn is_first_row(&self) -> Self::Expr {
        self.is_first_row
    }

    /// Last-row selector supplied at construction.
    fn is_last_row(&self) -> Self::Expr {
        self.is_last_row
    }

    /// Transition selector supplied at construction.
    ///
    /// # Panics
    ///
    /// If `size` is larger than 2, because the builder only holds two rows.
    // NOTE(review): sizes 0 and 1 also return the two-row transition selector;
    // confirm that is intended.
    fn is_transition_window(&self, size: usize) -> Self::Expr {
        if size > 2 {
            panic!("only two-row windows are supported, got {size}");
        }
        self.is_transition
    }

    /// Records a failure for the current row if `x` is non-zero. The
    /// constraint counter advances whether or not the check passed.
    fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
        let value: Self::Expr = x.into();
        if value != F::ZERO {
            let failure = ConstraintFailure {
                row: self.row_index,
                constraint: self.constraint_index,
            };
            self.failures.push(failure);
        }
        self.constraint_index += 1;
    }

    /// Public inputs supplied at construction.
    fn public_values(&self) -> &[Self::PublicVar] {
        self.public_values
    }
}
// The debug builder carries no extra evaluation context, so the context is
// the unit type.
impl<F, EF> AirBuilderWithContext for DebugConstraintBuilder<'_, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    type EvalContext = ();

    /// Always returns a reference to `()`.
    fn eval_context(&self) -> &Self::EvalContext {
        &()
    }
}
impl<F, EF> ExtensionBuilder for DebugConstraintBuilder<'_, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    // Extension-field expressions are also concrete values.
    type EF = EF;
    type ExprEF = EF;
    type VarEF = EF;

    /// Extension-field version of `assert_zero`. It records a failure for the
    /// current row if `x` is non-zero and shares the same constraint counter.
    fn assert_zero_ext<I>(&mut self, x: I)
    where
        I: Into<Self::ExprEF>,
    {
        let value: Self::ExprEF = x.into();
        if value != EF::ZERO {
            let failure = ConstraintFailure {
                row: self.row_index,
                constraint: self.constraint_index,
            };
            self.failures.push(failure);
        }
        self.constraint_index += 1;
    }
}
impl<'a, F, EF> PermutationAirBuilder for DebugConstraintBuilder<'a, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    type MP = RowWindow<'a, EF>;
    type RandomVar = EF;
    type PermutationVar = EF;

    /// Window over the current and next rows of the permutation trace.
    ///
    /// # Panics
    ///
    /// If the builder was created with `new`, which stores no permutation rows.
    fn permutation(&self) -> Self::MP {
        match &self.permutation {
            Some(rows) => RowWindow::from_two_rows(rows.top.values, rows.bottom.values),
            None => panic!(
                "permutation() called on a builder created without permutation data; use new_with_permutation()"
            ),
        }
    }

    /// Permutation challenges supplied at construction (empty for `new`).
    fn permutation_randomness(&self) -> &[Self::RandomVar] {
        self.permutation_challenges
    }

    /// Permutation values supplied at construction (empty for `new`).
    fn permutation_values(&self) -> &[Self::PermutationVar] {
        self.permutation_values
    }
}
/// Evaluates `air` on every row of `main` with a [`DebugConstraintBuilder`] and
/// panics on the first row that violates at least one constraint.
///
/// Row `i` is paired with row `(i + 1) % height`, so the last row wraps around
/// to row 0. The builder receives `is_first_row = (i == 0)`,
/// `is_last_row = (i == height - 1)` and `is_transition = (i != height - 1)`.
/// If the AIR provides a preprocessed trace, the same two preprocessed rows are
/// exposed; otherwise empty rows are used. An empty `main` trace evaluates
/// nothing and returns normally.
///
/// # Panics
///
/// - If the AIR's preprocessed trace height differs from `main.height()`.
/// - If any constraint on a row is non-zero. The message names the row and lists
///   the indices of every failed constraint on that row.
// NOTE(review): the reason for `allow(unused)` is not visible here; presumably it
// silences dead-code warnings in builds where nothing calls this helper.
#[allow(unused)]
pub fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>, public_values: &[F])
where
    F: Field,
    A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
{
    let height = main.height();
    let preprocessed = air.preprocessed_trace();

    // Preprocessed rows are read below with unchecked indexing at indices
    // bounded by the *main* height. A shorter preprocessed trace would be an
    // out-of-bounds read, so the heights must agree.
    if let Some(prep) = preprocessed.as_ref() {
        assert_eq!(
            prep.height(),
            height,
            "preprocessed trace height must equal main trace height"
        );
    }

    for row_index in 0..height {
        let row_index_next = (row_index + 1) % height;

        // SAFETY: `row_index < height` by the loop bound.
        let local = unsafe { main.row_slice_unchecked(row_index) };
        // SAFETY: `row_index_next` is reduced modulo `height`, so it is `< height`.
        let next = unsafe { main.row_slice_unchecked(row_index_next) };
        let main_pair = ViewPair::new(
            RowMajorMatrixView::new_row(&*local),
            RowMajorMatrixView::new_row(&*next),
        );

        // SAFETY: the assertion above guarantees the preprocessed trace has
        // exactly `height` rows, and both indices are `< height`.
        let (prep_local, prep_next) = preprocessed.as_ref().map_or((None, None), |prep| unsafe {
            (
                Some(prep.row_slice_unchecked(row_index)),
                Some(prep.row_slice_unchecked(row_index_next)),
            )
        });
        let preprocessed_pair = match (prep_local.as_ref(), prep_next.as_ref()) {
            (Some(l), Some(n)) => ViewPair::new(
                RowMajorMatrixView::new_row(&**l),
                RowMajorMatrixView::new_row(&**n),
            ),
            // No preprocessed trace: expose empty rows.
            _ => ViewPair::new(
                RowMajorMatrixView::new(&[], 0),
                RowMajorMatrixView::new(&[], 0),
            ),
        };

        let mut builder = DebugConstraintBuilder::new(
            row_index,
            main_pair,
            preprocessed_pair,
            public_values,
            F::from_bool(row_index == 0),
            F::from_bool(row_index == height - 1),
            F::from_bool(row_index != height - 1),
        );
        air.eval(&mut builder);

        if builder.has_failures() {
            let indices: Vec<usize> = builder.failures().iter().map(|f| f.constraint).collect();
            panic!(
                "constraints not satisfied on row {row_index}: \
                failed constraint indices = {indices:?}"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;

    use super::*;
    use crate::{BaseAir, WindowAccess};

    /// AIR of width `W` with two groups of constraints. On transition rows,
    /// every column must grow by exactly one. On the last row, the first `W`
    /// columns must equal the public values.
    #[derive(Debug)]
    struct RowLogicAir<const W: usize>;

    impl<F: Field, const W: usize> BaseAir<F> for RowLogicAir<W> {
        fn width(&self) -> usize {
            W
        }
    }

    impl<F: Field, const W: usize> Air<DebugConstraintBuilder<'_, F>> for RowLogicAir<W> {
        fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
            let main = builder.main();
            for col in 0..W {
                let current = main.current(col).unwrap();
                let next = main.next(col).unwrap();
                builder.when_transition().assert_eq(next, current + F::ONE);
            }
            // Read the field directly. The `public_values()` method would tie the
            // slice to a borrow of `builder`, which clashes with `when_last` below.
            let public_values = builder.public_values;
            let mut when_last = builder.when(builder.is_last_row);
            for (i, &pv) in public_values.iter().enumerate().take(W) {
                when_last.assert_eq(main.current(i).unwrap(), pv);
            }
        }
    }

    #[test]
    fn test_incremental_rows_with_last_row_check() {
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::ONE, BabyBear::ONE, // row 0
            BabyBear::new(2), BabyBear::new(2), // row 1
            BabyBear::new(3), BabyBear::new(3), // row 2
            BabyBear::new(4), BabyBear::new(4), // row 3
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &[BabyBear::new(4); 2]);
    }

    #[test]
    #[should_panic]
    fn test_incorrect_increment_logic() {
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::ONE, BabyBear::ONE, // row 0
            BabyBear::new(2), BabyBear::new(2), // row 1
            BabyBear::new(5), BabyBear::new(5), // row 2: should be 3
            BabyBear::new(6), BabyBear::new(6), // row 3
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &[BabyBear::new(6); 2]);
    }

    #[test]
    #[should_panic]
    fn test_wrong_last_row_public_value() {
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::ONE, BabyBear::ONE, // row 0
            BabyBear::new(2), BabyBear::new(2), // row 1
            BabyBear::new(3), BabyBear::new(3), // row 2
            BabyBear::new(4), BabyBear::new(4), // row 3
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &[BabyBear::new(4), BabyBear::new(5)]);
    }

    #[test]
    fn test_single_row_wraparound_logic() {
        let air = RowLogicAir::<2>;
        // A single row is both first and last and has no transition.
        let values = vec![BabyBear::new(99), BabyBear::new(77)];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &[BabyBear::new(99), BabyBear::new(77)]);
    }

    #[test]
    fn test_empty_trace_is_accepted() {
        let main = RowMajorMatrix::<BabyBear>::new(Vec::new(), 2);
        check_constraints(&RowLogicAir::<2>, &main, &[]);
    }

    /// Evaluates `air` on `row`, which serves as both the current and the next
    /// row. The selectors are `is_first_row = is_last_row = 1` and
    /// `is_transition = 0`. Returns the builder so tests can inspect the
    /// recorded failures.
    ///
    /// The builder borrows `row`, so no allocation has to be leaked to obtain
    /// a `'static` lifetime.
    fn eval_single_row<'a, A>(air: &A, row: &'a [BabyBear]) -> DebugConstraintBuilder<'a, BabyBear>
    where
        A: for<'b> Air<DebugConstraintBuilder<'b, BabyBear>>,
    {
        let view = RowMajorMatrixView::new_row(row);
        let main = ViewPair::new(view, view);
        let empty_view = RowMajorMatrixView::new(&[], 0);
        let mut builder = DebugConstraintBuilder::new(
            0,
            main,
            ViewPair::new(empty_view, empty_view),
            &[],
            BabyBear::ONE,  // is_first_row
            BabyBear::ONE,  // is_last_row
            BabyBear::ZERO, // is_transition
        );
        air.eval(&mut builder);
        builder
    }

    /// AIR of width `W` asserting that every column of the current row is zero.
    #[derive(Debug)]
    struct AllZeroAir<const W: usize>;

    impl<F: Field, const W: usize> BaseAir<F> for AllZeroAir<W> {
        fn width(&self) -> usize {
            W
        }
    }

    impl<F: Field, const W: usize> Air<DebugConstraintBuilder<'_, F>> for AllZeroAir<W> {
        fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
            let main = builder.main();
            for col in 0..W {
                builder.assert_zero(main.current(col).unwrap());
            }
        }
    }

    #[test]
    fn test_no_failures_when_all_constraints_pass() {
        let row = [BabyBear::ZERO; 3];
        let builder = eval_single_row(&AllZeroAir::<3>, &row);
        assert!(!builder.has_failures());
        assert!(builder.failures().is_empty());
    }

    #[test]
    fn test_multiple_failures_collected() {
        let row = [BabyBear::ONE, BabyBear::ZERO, BabyBear::new(42)];
        let builder = eval_single_row(&AllZeroAir::<3>, &row);
        assert!(builder.has_failures());
        let failures = builder.failures();
        assert_eq!(failures.len(), 2);
        assert_eq!(failures[0].row, 0);
        assert_eq!(failures[0].constraint, 0);
        assert_eq!(failures[1].row, 0);
        assert_eq!(failures[1].constraint, 2);
    }

    #[test]
    fn test_into_failures() {
        let row = [BabyBear::ONE, BabyBear::ONE];
        let builder = eval_single_row(&AllZeroAir::<2>, &row);
        let failures = builder.into_failures();
        assert_eq!(failures.len(), 2);
        assert_eq!(failures[0].constraint, 0);
        assert_eq!(failures[1].constraint, 1);
    }

    #[test]
    #[should_panic(expected = "permutation() called on a builder created without permutation data")]
    fn test_permutation_requires_permutation_data() {
        let row = [BabyBear::ZERO];
        let builder = eval_single_row(&AllZeroAir::<1>, &row);
        let _ = builder.permutation();
    }

    #[test]
    #[should_panic(expected = "failed constraint indices = [0, 2]")]
    fn test_panic_message_lists_all_failed_indices() {
        let air = AllZeroAir::<3>;
        let values = vec![BabyBear::ONE, BabyBear::ZERO, BabyBear::new(7)];
        let main = RowMajorMatrix::new(values, 3);
        check_constraints(&air, &main, &[]);
    }
}