use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::fmt;
use p3_field::{ExtensionField, Field};
use p3_matrix::Matrix;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_matrix::stack::ViewPair;
use crate::{
Air, AirBuilder, AirBuilderWithContext, ExtensionBuilder, Name, NamedAirBuilder,
NamedExtensionBuilder, PermutationAirBuilder, RowWindow,
};
/// A single constraint that evaluated to a non-zero value on some row.
///
/// Values of this type are produced by `DebugConstraintBuilder` and compared
/// directly (hence `PartialEq`/`Eq`) by callers that inspect reports.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConstraintFailure {
    /// Index of the trace row on which the constraint failed.
    pub row: usize,
    /// Position of the constraint in the AIR's evaluation order (0-based).
    pub constraint: usize,
    /// Human-readable name of the constraint, if one was supplied and non-empty.
    pub label: Option<String>,
}
impl fmt::Display for ConstraintFailure {
    /// Renders `#<constraint>`, followed by the label in `Debug` (quoted) form
    /// when one is present, e.g. `#1 "col_1_must_be_zero"`. The row is not
    /// included; callers that need it print it separately.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.label.as_deref() {
            Some(name) => write!(f, "#{} {:?}", self.constraint, name),
            None => write!(f, "#{}", self.constraint),
        }
    }
}
/// Aggregated outcome of checking a whole trace with [`check_all_constraints`].
#[derive(Debug, Clone)]
pub struct ConstraintReport {
    /// Recorded failures, ordered by row and then by constraint index.
    pub failures: Vec<ConstraintFailure>,
    /// Height of the main trace that was checked.
    pub total_rows: usize,
    /// Number of constraints the AIR asserted, as observed on row 0.
    pub total_constraints_per_row: usize,
}

impl ConstraintReport {
    /// Returns `true` when no failure was recorded.
    pub const fn is_ok(&self) -> bool {
        self.failures.is_empty()
    }
}
/// An `AirBuilder` that evaluates an AIR's constraints on concrete field values
/// for one row and records every constraint that is non-zero, rather than
/// stopping at the first one.
///
/// `F` is the base field of the trace. `EF` is the field used for permutation
/// rows, challenges and values; it defaults to `F`.
#[derive(Debug)]
pub struct DebugConstraintBuilder<'a, F: Field, EF: ExtensionField<F> = F> {
    /// Row being evaluated; copied into every recorded failure.
    row_index: usize,
    /// Index that the next asserted constraint will receive.
    constraint_index: usize,
    /// Failed constraints, in evaluation order.
    failures: Vec<ConstraintFailure>,
    /// Current (`top`) and next (`bottom`) rows of the main trace.
    main: ViewPair<'a, F>,
    /// Window over the current and next preprocessed rows. The check helpers
    /// in this module pass empty rows when the AIR has no preprocessed trace.
    preprocessed: RowWindow<'a, F>,
    /// Public inputs exposed through `public_values()`.
    public_values: &'a [F],
    /// Selector returned by `is_first_row()`.
    is_first_row: F,
    /// Selector returned by `is_last_row()`.
    is_last_row: F,
    /// Selector returned by `is_transition()`.
    is_transition: F,
    /// Current and next permutation rows; `None` when built with `new`.
    permutation: Option<ViewPair<'a, EF>>,
    /// Randomness returned by `permutation_randomness()`.
    permutation_challenges: &'a [EF],
    /// Values returned by `permutation_values()`.
    permutation_values: &'a [EF],
    /// Periodic column values for this row.
    periodic_row: &'a [F],
}
impl<'a, F: Field> DebugConstraintBuilder<'a, F> {
    /// Creates a builder for row `row_index` of an AIR that has no permutation
    /// data.
    ///
    /// `main` and `preprocessed` carry the current (`top`) and next (`bottom`)
    /// rows. The three selector values are stored as given and returned by
    /// `is_first_row`, `is_last_row` and `is_transition`. Calling
    /// `permutation()` on the result panics; use `new_with_permutation` for that.
    // Each argument maps one-to-one onto a builder input, so the long list is intended.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        row_index: usize,
        main: ViewPair<'a, F>,
        preprocessed: ViewPair<'a, F>,
        public_values: &'a [F],
        is_first_row: F,
        is_last_row: F,
        is_transition: F,
        periodic_row: &'a [F],
    ) -> Self {
        let preprocessed_window =
            RowWindow::from_two_rows(preprocessed.top.values, preprocessed.bottom.values);
        Self {
            row_index,
            main,
            preprocessed: preprocessed_window,
            public_values,
            is_first_row,
            is_last_row,
            is_transition,
            periodic_row,
            // No permutation data on this constructor.
            permutation: None,
            permutation_challenges: &[],
            permutation_values: &[],
            // Fresh bookkeeping: nothing evaluated yet.
            constraint_index: 0,
            failures: Vec::new(),
        }
    }
}
impl<'a, F: Field, EF: ExtensionField<F>> DebugConstraintBuilder<'a, F, EF> {
    /// Creates a builder like `new`, but also supplies the permutation rows,
    /// challenges and values, so the `PermutationAirBuilder` accessors work.
    // Each argument maps one-to-one onto a builder input, so the long list is intended.
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_permutation(
        row_index: usize,
        main: ViewPair<'a, F>,
        preprocessed: ViewPair<'a, F>,
        public_values: &'a [F],
        is_first_row: F,
        is_last_row: F,
        is_transition: F,
        permutation: ViewPair<'a, EF>,
        permutation_challenges: &'a [EF],
        permutation_values: &'a [EF],
        periodic_row: &'a [F],
    ) -> Self {
        let preprocessed_window =
            RowWindow::from_two_rows(preprocessed.top.values, preprocessed.bottom.values);
        Self {
            row_index,
            constraint_index: 0,
            failures: Vec::new(),
            main,
            preprocessed: preprocessed_window,
            public_values,
            is_first_row,
            is_last_row,
            is_transition,
            permutation: Some(permutation),
            permutation_challenges,
            permutation_values,
            periodic_row,
        }
    }

    /// Returns `true` once at least one constraint has failed.
    pub const fn has_failures(&self) -> bool {
        !self.failures.is_empty()
    }

    /// Borrows the failures recorded so far, in evaluation order.
    pub fn failures(&self) -> &[ConstraintFailure] {
        &self.failures
    }

    /// Renders the recorded failures as a bracketed, comma-separated list of
    /// their `Display` forms, e.g. `[#0, #1 "label"]`.
    pub fn formatted_failures(&self) -> String {
        let joined = self
            .failures
            .iter()
            .map(|failure| failure.to_string())
            .collect::<Vec<_>>()
            .join(", ");
        format!("[{joined}]")
    }

    /// Consumes the builder and hands back the recorded failures.
    pub fn into_failures(self) -> Vec<ConstraintFailure> {
        self.failures
    }
}
impl<'a, F, EF> AirBuilder for DebugConstraintBuilder<'a, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    // Evaluation happens on concrete values, so expressions, variables, public
    // inputs and periodic values are all plain base-field elements.
    type F = F;
    type Expr = F;
    type Var = F;
    type PublicVar = F;
    type PeriodicVar = F;
    type MainWindow = RowWindow<'a, F>;
    type PreprocessedWindow = RowWindow<'a, F>;

    /// Builds a window over the stored current and next main rows.
    fn main(&self) -> Self::MainWindow {
        let top = self.main.top.values;
        let bottom = self.main.bottom.values;
        RowWindow::from_two_rows(top, bottom)
    }

    /// Returns the window over the preprocessed rows built at construction.
    fn preprocessed(&self) -> &Self::PreprocessedWindow {
        &self.preprocessed
    }

    /// Returns the public inputs supplied at construction.
    fn public_values(&self) -> &[Self::PublicVar] {
        self.public_values
    }

    /// Returns the periodic column values supplied for this row.
    fn periodic_values(&self) -> &[Self::PeriodicVar] {
        self.periodic_row
    }

    /// Returns the stored first-row selector.
    fn is_first_row(&self) -> Self::Expr {
        self.is_first_row
    }

    /// Returns the stored last-row selector.
    fn is_last_row(&self) -> Self::Expr {
        self.is_last_row
    }

    /// Returns the stored transition selector.
    fn is_transition(&self) -> Self::Expr {
        self.is_transition
    }

    /// Records `x` as an unlabeled constraint (delegates with an empty name).
    fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
        self.assert_zero_named(x, "");
    }
}
/// The debug builder carries no extra evaluation context.
impl<F: Field, EF: ExtensionField<F>> AirBuilderWithContext
    for DebugConstraintBuilder<'_, F, EF>
{
    type EvalContext = ();

    fn eval_context(&self) -> &Self::EvalContext {
        // `()` is a constant expression, so this reference is promoted to `'static`.
        &()
    }
}
/// Extension-field expressions are concrete `EF` values.
impl<F: Field, EF: ExtensionField<F>> ExtensionBuilder for DebugConstraintBuilder<'_, F, EF> {
    type EF = EF;
    type VarEF = EF;
    type ExprEF = EF;

    /// Records `x` as an unlabeled extension-field constraint (delegates with
    /// an empty name).
    fn assert_zero_ext<I: Into<Self::ExprEF>>(&mut self, x: I) {
        self.assert_zero_ext_named(x, "");
    }
}
impl<'a, F: Field, EF: ExtensionField<F>> PermutationAirBuilder
    for DebugConstraintBuilder<'a, F, EF>
{
    type MP = RowWindow<'a, EF>;
    type RandomVar = EF;
    type PermutationVar = EF;

    /// Builds a window over the current and next permutation rows.
    ///
    /// # Panics
    ///
    /// Panics if the builder was created with `new`, which stores no
    /// permutation rows.
    fn permutation(&self) -> Self::MP {
        let pair = self.permutation.expect(
            "permutation() called on a builder created without permutation data; use new_with_permutation()",
        );
        RowWindow::from_two_rows(pair.top.values, pair.bottom.values)
    }

    /// Returns the challenges supplied at construction (empty for `new`).
    fn permutation_randomness(&self) -> &[Self::RandomVar] {
        self.permutation_challenges
    }

    /// Returns the permutation values supplied at construction (empty for `new`).
    fn permutation_values(&self) -> &[Self::PermutationVar] {
        self.permutation_values
    }
}
impl<F: Field, EF: ExtensionField<F>> DebugConstraintBuilder<'_, F, EF> {
    /// Records the outcome of one constraint and advances the constraint counter.
    ///
    /// When `holds` is `false`, a [`ConstraintFailure`] is pushed with the
    /// current row and constraint index. `name` is only evaluated in that case,
    /// and an empty label is stored as `None`. The counter advances whether or
    /// not the constraint held, so a constraint's index is its position in
    /// evaluation order.
    fn record_constraint<N: Name>(&mut self, holds: bool, name: N) {
        if !holds {
            let label = name.evaluate().to_string();
            self.failures.push(ConstraintFailure {
                row: self.row_index,
                constraint: self.constraint_index,
                label: if label.is_empty() { None } else { Some(label) },
            });
        }
        self.constraint_index += 1;
    }
}

impl<F: Field, EF: ExtensionField<F>> NamedAirBuilder for DebugConstraintBuilder<'_, F, EF> {
    /// Checks that the base-field value `x` is zero on this row. If it is not,
    /// records a failure labelled by `name`.
    fn assert_zero_named<I, N>(&mut self, x: I, name: N)
    where
        I: Into<Self::Expr>,
        N: Name,
    {
        let value: F = x.into();
        self.record_constraint(value == F::ZERO, name);
    }
}

impl<F: Field, EF: ExtensionField<F>> NamedExtensionBuilder for DebugConstraintBuilder<'_, F, EF> {
    /// Checks that the extension-field value `x` is zero on this row. If it is
    /// not, records a failure labelled by `name`. Base- and extension-field
    /// constraints share one index sequence.
    fn assert_zero_ext_named<I, N>(&mut self, x: I, name: N)
    where
        I: Into<Self::ExprEF>,
        N: Name,
    {
        let value: EF = x.into();
        self.record_constraint(value == EF::ZERO, name);
    }
}
#[allow(unused)] pub fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>, public_values: &[F])
where
F: Field,
A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
{
let height = main.height();
let preprocessed = air.preprocessed_trace();
if let Some(prep) = preprocessed.as_ref() {
assert_eq!(
prep.height(),
height,
"debug constraint check requires preprocessed trace height ({}) to match main trace height ({})",
prep.height(),
height
);
}
for row_index in 0..height {
let row_index_next = (row_index + 1) % height;
let local = unsafe { main.row_slice_unchecked(row_index) };
let next = unsafe { main.row_slice_unchecked(row_index_next) };
let main_pair = ViewPair::new(
RowMajorMatrixView::new_row(&*local),
RowMajorMatrixView::new_row(&*next),
);
let (prep_local, prep_next) = preprocessed.as_ref().map_or((None, None), |prep| unsafe {
(
Some(prep.row_slice_unchecked(row_index)),
Some(prep.row_slice_unchecked(row_index_next)),
)
});
let preprocessed_pair = match (prep_local.as_ref(), prep_next.as_ref()) {
(Some(l), Some(n)) => ViewPair::new(
RowMajorMatrixView::new_row(&**l),
RowMajorMatrixView::new_row(&**n),
),
_ => ViewPair::new(
RowMajorMatrixView::new(&[], 0),
RowMajorMatrixView::new(&[], 0),
),
};
let periodic_row = air.periodic_values(row_index);
let mut builder = DebugConstraintBuilder::new(
row_index,
main_pair,
preprocessed_pair,
public_values,
F::from_bool(row_index == 0),
F::from_bool(row_index == height - 1),
F::from_bool(row_index != height - 1),
&periodic_row,
);
air.eval(&mut builder);
if builder.has_failures() {
let rendered = builder.formatted_failures();
panic!(
"constraints not satisfied on row {row_index}: \
failed constraints = {rendered}"
);
}
}
}
#[allow(unused)] pub fn check_all_constraints<F, A>(
air: &A,
main: &RowMajorMatrix<F>,
public_values: &[F],
max_failures: Option<usize>,
) -> ConstraintReport
where
F: Field,
A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
{
let height = main.height();
let preprocessed = air.preprocessed_trace();
if let Some(prep) = preprocessed.as_ref() {
assert_eq!(
prep.height(),
height,
"debug constraint check requires preprocessed trace height ({}) to match main trace height ({})",
prep.height(),
height
);
}
let mut all_failures = Vec::new();
let mut total_constraints_per_row = 0;
for row_index in 0..height {
if let Some(cap) = max_failures
&& all_failures.len() >= cap
{
break;
}
let row_index_next = (row_index + 1) % height;
let local = unsafe { main.row_slice_unchecked(row_index) };
let next = unsafe { main.row_slice_unchecked(row_index_next) };
let main_pair = ViewPair::new(
RowMajorMatrixView::new_row(&*local),
RowMajorMatrixView::new_row(&*next),
);
let (prep_local, prep_next) = preprocessed.as_ref().map_or((None, None), |prep| unsafe {
(
Some(prep.row_slice_unchecked(row_index)),
Some(prep.row_slice_unchecked(row_index_next)),
)
});
let preprocessed_pair = match (prep_local.as_ref(), prep_next.as_ref()) {
(Some(l), Some(n)) => ViewPair::new(
RowMajorMatrixView::new_row(&**l),
RowMajorMatrixView::new_row(&**n),
),
_ => ViewPair::new(
RowMajorMatrixView::new(&[], 0),
RowMajorMatrixView::new(&[], 0),
),
};
let periodic_row = air.periodic_values(row_index);
let mut builder = DebugConstraintBuilder::new(
row_index,
main_pair,
preprocessed_pair,
public_values,
F::from_bool(row_index == 0),
F::from_bool(row_index == height - 1),
F::from_bool(row_index != height - 1),
&periodic_row,
);
air.eval(&mut builder);
if row_index == 0 {
total_constraints_per_row = builder.constraint_index;
}
all_failures.extend(builder.into_failures());
}
ConstraintReport {
failures: all_failures,
total_rows: height,
total_constraints_per_row,
}
}
// Unit tests for the debug constraint builder and the `check_constraints` /
// `check_all_constraints` entry points, using BabyBear as the field.
#[cfg(test)]
mod tests {
use alloc::{format, vec};
use p3_baby_bear::BabyBear;
use p3_field::PrimeCharacteristicRing;
use super::*;
use crate::{BaseAir, WindowAccess};
// Test AIR with `W` columns. On transition rows each column must increase by
// one from the current row to the next. On the last row, the first W columns
// must equal the corresponding public values.
#[derive(Debug)]
struct RowLogicAir<const W: usize>;
impl<F: Field, const W: usize> BaseAir<F> for RowLogicAir<W> {
fn width(&self) -> usize {
W
}
}
impl<F: Field, const W: usize> Air<DebugConstraintBuilder<'_, F>> for RowLogicAir<W> {
fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
let main = builder.main();
// next == current + 1, enforced only where `is_transition` is set.
for col in 0..W {
let current = main.current(col).unwrap();
let next = main.next(col).unwrap();
builder.when_transition().assert_eq(next, current + F::ONE);
}
// Reads the private field directly (same module) so the slice is copied out
// before `builder` is borrowed mutably by `when`.
let public_values = builder.public_values;
let mut when_last = builder.when(builder.is_last_row);
for (i, &pv) in public_values.iter().enumerate().take(W) {
when_last.assert_eq(main.current(i).unwrap(), pv);
}
}
}
// Rows [1, 1], [2, 2], [3, 3], [4, 4]; the last row matches public values [4, 4].
#[test]
fn test_incremental_rows_with_last_row_check() {
let air = RowLogicAir::<2>;
let values = vec![
BabyBear::ONE,
BabyBear::ONE, BabyBear::new(2),
BabyBear::new(2), BabyBear::new(3),
BabyBear::new(3), BabyBear::new(4),
BabyBear::new(4), ];
let main = RowMajorMatrix::new(values, 2);
check_constraints(&air, &main, &[BabyBear::new(4); 2]);
}
// Rows [1, 1], [2, 2], [5, 5], [6, 6]: the 2 -> 5 step breaks the +1 rule.
#[test]
#[should_panic]
fn test_incorrect_increment_logic() {
let air = RowLogicAir::<2>;
let values = vec![
BabyBear::ONE,
BabyBear::ONE, BabyBear::new(2),
BabyBear::new(2), BabyBear::new(5),
BabyBear::new(5), BabyBear::new(6),
BabyBear::new(6), ];
let main = RowMajorMatrix::new(values, 2);
check_constraints(&air, &main, &[BabyBear::new(6); 2]);
}
// Valid increments, but public value 5 for column 1 does not match the last
// row's value 4.
#[test]
#[should_panic]
fn test_wrong_last_row_public_value() {
let air = RowLogicAir::<2>;
let values = vec![
BabyBear::ONE,
BabyBear::ONE, BabyBear::new(2),
BabyBear::new(2), BabyBear::new(3),
BabyBear::new(3), BabyBear::new(4),
BabyBear::new(4), ];
let main = RowMajorMatrix::new(values, 2);
check_constraints(&air, &main, &[BabyBear::new(4), BabyBear::new(5)]);
}
// A single row is both first and last. `is_transition` is 0 there, so the
// wrapped-around "next" row (itself) is not required to be current + 1.
#[test]
fn test_single_row_wraparound_logic() {
let air = RowLogicAir::<2>;
let values = vec![
BabyBear::new(99),
BabyBear::new(77), ];
let main = RowMajorMatrix::new(values, 2);
check_constraints(&air, &main, &[BabyBear::new(99), BabyBear::new(77)]);
}
// Evaluates `air` on a single row, used as both current and next, with
// selectors first = 1, last = 1, transition = 0, and no preprocessed,
// public or periodic data. The row is leaked to obtain a `'static` slice so
// the builder can be returned to the caller; acceptable in tests.
fn eval_single_row<const W: usize, A>(
air: &A,
row: [BabyBear; W],
) -> DebugConstraintBuilder<'static, BabyBear>
where
A: for<'a> Air<DebugConstraintBuilder<'a, BabyBear>>,
{
let row: &'static [BabyBear] = Vec::from(row).leak();
let view = RowMajorMatrixView::new_row(row);
let main = ViewPair::new(view, view);
let empty_view = RowMajorMatrixView::new(&[], 0);
let mut builder = DebugConstraintBuilder::new(
0,
main,
ViewPair::new(empty_view, empty_view),
&[],
BabyBear::ONE, BabyBear::ONE, BabyBear::ZERO, &[],
);
air.eval(&mut builder);
builder
}
// Test AIR with `W` unlabeled constraints: constraint `col` asserts that
// column `col` of the current row is zero.
#[derive(Debug)]
struct AllZeroAir<const W: usize>;
impl<F: Field, const W: usize> BaseAir<F> for AllZeroAir<W> {
fn width(&self) -> usize {
W
}
}
impl<F: Field, const W: usize> Air<DebugConstraintBuilder<'_, F>> for AllZeroAir<W> {
fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
let main = builder.main();
for col in 0..W {
builder.assert_zero(main.current(col).unwrap());
}
}
}
#[test]
fn test_no_failures_when_all_constraints_pass() {
let builder = eval_single_row(&AllZeroAir::<3>, [BabyBear::ZERO; 3]);
assert!(!builder.has_failures());
assert!(builder.failures().is_empty());
}
// Columns 0 and 2 are non-zero, so constraints #0 and #2 fail and #1 passes.
#[test]
fn test_multiple_failures_collected() {
let builder = eval_single_row(
&AllZeroAir::<3>,
[BabyBear::ONE, BabyBear::ZERO, BabyBear::new(42)],
);
assert!(builder.has_failures());
let failures = builder.failures();
assert_eq!(failures.len(), 2);
assert_eq!(failures[0].row, 0);
assert_eq!(failures[0].constraint, 0);
assert_eq!(failures[1].row, 0);
assert_eq!(failures[1].constraint, 2);
}
#[test]
fn test_into_failures() {
let builder = eval_single_row(&AllZeroAir::<2>, [BabyBear::ONE, BabyBear::ONE]);
let failures = builder.into_failures();
assert_eq!(failures.len(), 2);
assert_eq!(failures[0].constraint, 0);
assert_eq!(failures[1].constraint, 1);
}
// The panic message must list every failing constraint of the row, not just the first.
#[test]
#[should_panic(expected = "failed constraints = [#0, #2]")]
fn test_panic_message_lists_all_failed_indices() {
let air = AllZeroAir::<3>;
let values = vec![BabyBear::ONE, BabyBear::ZERO, BabyBear::new(7)];
let main = RowMajorMatrix::new(values, 3);
check_constraints(&air, &main, &[]);
}
// Labels are rendered in Debug (quoted) form after the constraint index.
#[test]
#[should_panic(expected = "failed constraints = [#0, #1 \"col_1_must_be_zero\"]")]
fn test_panic_message_includes_label_when_available() {
let air = NamedConstraintAir;
let values = vec![BabyBear::ONE, BabyBear::ONE];
let main = RowMajorMatrix::new(values, 2);
check_constraints(&air, &main, &[]);
}
// 2 rows x 2 columns of zeros.
#[test]
fn test_check_all_constraints_no_failures() {
let air = AllZeroAir::<2>;
let values = BabyBear::zero_vec(4); let main = RowMajorMatrix::new(values, 2);
let report = check_all_constraints(&air, &main, &[], None);
assert!(report.is_ok());
assert_eq!(report.total_rows, 2);
assert_eq!(report.total_constraints_per_row, 2);
}
// Rows [1, 0] and [0, 1]: one failure per row, at different constraint indices.
#[test]
fn test_check_all_constraints_multiple_rows_fail() {
let air = AllZeroAir::<2>;
let values = vec![BabyBear::ONE, BabyBear::ZERO, BabyBear::ZERO, BabyBear::ONE];
let main = RowMajorMatrix::new(values, 2);
let report = check_all_constraints(&air, &main, &[], None);
assert!(!report.is_ok());
assert_eq!(report.failures.len(), 2);
assert_eq!(report.failures[0].row, 0);
assert_eq!(report.failures[0].constraint, 0);
assert_eq!(report.failures[1].row, 1);
assert_eq!(report.failures[1].constraint, 1);
}
// 4 rows of ones, so every constraint fails (8 in total). The cap may be
// enforced at row granularity, so the assertions accept either exactly the
// cap or one extra failure from the row that crossed it.
#[test]
fn test_check_all_constraints_max_failures_cap() {
let air = AllZeroAir::<2>;
let values = vec![BabyBear::ONE; 8];
let main = RowMajorMatrix::new(values, 2);
let report = check_all_constraints(&air, &main, &[], Some(3));
assert!(report.failures.len() >= 3);
assert!(report.failures.len() <= 4);
}
// Test AIR with two constraints: #0 unlabeled (column 0 is zero), #1 labelled
// "col_1_must_be_zero" (column 1 is zero).
#[derive(Debug)]
struct NamedConstraintAir;
impl<F: Field> BaseAir<F> for NamedConstraintAir {
fn width(&self) -> usize {
2
}
}
impl<F: Field> Air<DebugConstraintBuilder<'_, F>> for NamedConstraintAir {
fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
let main = builder.main();
builder.assert_zero(main.current(0).unwrap());
builder.assert_zero_named(main.current(1).unwrap(), "col_1_must_be_zero");
}
}
#[test]
fn test_named_constraint_label_captured() {
let builder = eval_single_row(
&NamedConstraintAir,
[BabyBear::ONE, BabyBear::ONE], );
let failures = builder.failures();
assert_eq!(failures.len(), 2);
assert_eq!(failures[0].constraint, 0);
assert!(failures[0].label.is_none());
assert_eq!(failures[1].constraint, 1);
assert_eq!(failures[1].label.as_deref(), Some("col_1_must_be_zero"));
}
#[test]
fn test_named_constraint_no_label_when_passing() {
let builder = eval_single_row(
&NamedConstraintAir,
[BabyBear::ZERO, BabyBear::ZERO], );
assert!(!builder.has_failures());
}
// 2 rows of ones: both constraints fail on each row, and only the second
// constraint of each row carries a label.
#[test]
fn test_named_constraint_in_full_report() {
let air = NamedConstraintAir;
let values = vec![BabyBear::ONE; 4];
let main = RowMajorMatrix::new(values, 2);
let report = check_all_constraints(&air, &main, &[], None);
assert_eq!(report.failures.len(), 4);
let labeled: Vec<_> = report
.failures
.iter()
.filter(|f| f.label.is_some())
.collect();
assert_eq!(labeled.len(), 2); assert!(
labeled
.iter()
.all(|f| f.label.as_deref() == Some("col_1_must_be_zero"))
);
}
// Test AIR that exercises the `Name` forms: `NamespaceExt::join`,
// `NamespaceExt::name` with a closure, and a bare closure.
#[derive(Debug)]
struct NamespacedAir;
impl<F: Field> BaseAir<F> for NamespacedAir {
fn width(&self) -> usize {
3
}
}
impl<F: Field> Air<DebugConstraintBuilder<'_, F>> for NamespacedAir {
fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
use crate::NamespaceExt;
let main = builder.main();
let ns = "range_check";
builder.assert_zero_named(main.current(0).unwrap(), ns.join("limb_0"));
let i = 1;
builder.assert_zero_named(main.current(1).unwrap(), ns.name(|| format!("limb_{i}")));
builder.assert_zero_named(main.current(2).unwrap(), || format!("col_{}", 2));
}
}
#[test]
fn test_namespace_join_labels() {
let builder = eval_single_row(
&NamespacedAir,
[BabyBear::ONE, BabyBear::ONE, BabyBear::ONE],
);
let failures = builder.into_failures();
assert_eq!(failures.len(), 3);
assert_eq!(failures[0].label.as_deref(), Some("range_check::limb_0"));
assert_eq!(failures[1].label.as_deref(), Some("range_check::limb_1"));
assert_eq!(failures[2].label.as_deref(), Some("col_2"));
}
// Test AIR with no constraints whose only job is to expose an all-zero
// preprocessed trace of a chosen shape (`None` when `prep_height == 0`).
#[derive(Debug)]
struct ShapeProbeAir {
prep_height: usize,
prep_width: usize,
}
impl<F: Field> BaseAir<F> for ShapeProbeAir {
fn width(&self) -> usize {
1
}
fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
if self.prep_height == 0 {
return None;
}
let total = self.prep_height * self.prep_width;
Some(RowMajorMatrix::new(F::zero_vec(total), self.prep_width))
}
}
impl<F: Field> Air<DebugConstraintBuilder<'_, F>> for ShapeProbeAir {
fn eval(&self, _builder: &mut DebugConstraintBuilder<'_, F>) {
}
}
#[test]
fn test_preprocessed_height_matches_main_passes() {
let air = ShapeProbeAir {
prep_height: 4,
prep_width: 1,
};
let main = RowMajorMatrix::new(BabyBear::zero_vec(4), 1);
check_constraints(&air, &main, &[]);
}
// Preprocessed height 8 vs main height 4 must be rejected before any row is read.
#[test]
#[should_panic(expected = "preprocessed trace height")]
fn test_preprocessed_height_mismatch_panics_in_check_constraints() {
let air = ShapeProbeAir {
prep_height: 8,
prep_width: 1,
};
let main = RowMajorMatrix::new(BabyBear::zero_vec(4), 1);
check_constraints(&air, &main, &[]);
}
#[test]
fn test_preprocessed_height_matches_main_passes_in_check_all_constraints() {
let air = ShapeProbeAir {
prep_height: 4,
prep_width: 1,
};
let main = RowMajorMatrix::new(BabyBear::zero_vec(4), 1);
let report = check_all_constraints(&air, &main, &[], None);
assert!(report.is_ok());
assert_eq!(report.total_rows, 4);
}
// Same mismatch as above, but through the report-producing entry point.
#[test]
#[should_panic(expected = "preprocessed trace height")]
fn test_preprocessed_height_mismatch_panics_in_check_all_constraints() {
let air = ShapeProbeAir {
prep_height: 8,
prep_width: 1,
};
let main = RowMajorMatrix::new(BabyBear::zero_vec(4), 1);
let _ = check_all_constraints(&air, &main, &[], None);
}
}