// p3_uni_stark/check_constraints.rs

1use alloc::vec::Vec;
2
3use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues};
4use p3_field::Field;
5use p3_matrix::Matrix;
6use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
7use p3_matrix::stack::VerticalPair;
8use tracing::instrument;
9
10/// Runs constraint checks using a given AIR definition and trace matrix.
11///
12/// Iterates over every row in `main`, providing both the current and next row
13/// (with wraparound) to the AIR logic. Also injects public values into the builder
14/// for first/last row assertions.
15///
16/// # Arguments
17/// - `air`: The AIR logic to run
18/// - `main`: The trace matrix (rows of witness values)
19/// - `public_values`: Public values provided to the builder
20#[instrument(name = "check constraints", skip_all)]
21pub(crate) fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>, public_values: &Vec<F>)
22where
23    F: Field,
24    A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
25{
26    let height = main.height();
27
28    (0..height).for_each(|i| {
29        let i_next = (i + 1) % height;
30
31        let local = main.row_slice(i).unwrap(); // i < height so unwrap should never fail.
32        let next = main.row_slice(i_next).unwrap(); // i_next < height so unwrap should never fail.
33        let main = VerticalPair::new(
34            RowMajorMatrixView::new_row(&*local),
35            RowMajorMatrixView::new_row(&*next),
36        );
37
38        let mut builder = DebugConstraintBuilder {
39            row_index: i,
40            main,
41            public_values,
42            is_first_row: F::from_bool(i == 0),
43            is_last_row: F::from_bool(i == height - 1),
44            is_transition: F::from_bool(i != height - 1),
45        };
46
47        air.eval(&mut builder);
48    });
49}
50
51/// A builder that runs constraint assertions during testing.
52///
53/// Used in conjunction with [`check_constraints`] to simulate
54/// an execution trace and verify that the AIR logic enforces all constraints.
55#[derive(Debug)]
56pub struct DebugConstraintBuilder<'a, F: Field> {
57    /// The index of the row currently being evaluated.
58    row_index: usize,
59    /// A view of the current and next row as a vertical pair.
60    main: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
61    /// The public values provided for constraint validation (e.g. inputs or outputs).
62    public_values: &'a [F],
63    /// A flag indicating whether this is the first row.
64    is_first_row: F,
65    /// A flag indicating whether this is the last row.
66    is_last_row: F,
67    /// A flag indicating whether this is a transition row (not the last row).
68    is_transition: F,
69}
70
71impl<'a, F> AirBuilder for DebugConstraintBuilder<'a, F>
72where
73    F: Field,
74{
75    type F = F;
76    type Expr = F;
77    type Var = F;
78    type M = VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>;
79
80    fn main(&self) -> Self::M {
81        self.main
82    }
83
84    fn is_first_row(&self) -> Self::Expr {
85        self.is_first_row
86    }
87
88    fn is_last_row(&self) -> Self::Expr {
89        self.is_last_row
90    }
91
92    /// # Panics
93    /// This function panics if `size` is not `2`.
94    fn is_transition_window(&self, size: usize) -> Self::Expr {
95        if size == 2 {
96            self.is_transition
97        } else {
98            panic!("only supports a window size of 2")
99        }
100    }
101
102    fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
103        assert_eq!(
104            x.into(),
105            F::ZERO,
106            "constraints had nonzero value on row {}",
107            self.row_index
108        );
109    }
110
111    fn assert_eq<I1: Into<Self::Expr>, I2: Into<Self::Expr>>(&mut self, x: I1, y: I2) {
112        let x = x.into();
113        let y = y.into();
114        assert_eq!(
115            x, y,
116            "values didn't match on row {}: {} != {}",
117            self.row_index, x, y
118        );
119    }
120}
121
122impl<F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'_, F> {
123    type PublicVar = Self::F;
124
125    fn public_values(&self) -> &[Self::F] {
126        self.public_values
127    }
128}
129
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_air::{BaseAir, BaseAirWithPublicValues};
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;

    use super::*;

    /// A test AIR that enforces a simple linear transition logic:
    /// - Each cell in the next row must equal the current cell plus 1 (i.e., `next = current + 1`)
    /// - On the last row, the current row must match the provided public values.
    ///
    /// This is useful for validating constraint evaluation, transition logic,
    /// and row condition flags (first/last/transition).
    #[derive(Debug)]
    struct RowLogicAir<const W: usize>;

    impl<F: Field, const W: usize> BaseAir<F> for RowLogicAir<W> {
        fn width(&self) -> usize {
            W
        }
    }

    impl<F: Field, const W: usize> BaseAirWithPublicValues<F> for RowLogicAir<W> {}

    impl<F: Field, const W: usize> Air<DebugConstraintBuilder<'_, F>> for RowLogicAir<W> {
        fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
            let main = builder.main();

            for col in 0..W {
                let a = main.top.get(0, col).unwrap();
                let b = main.bottom.get(0, col).unwrap();

                // Enforce next[col] = local[col] + 1, only on transition rows.
                builder.when_transition().assert_eq(b, a + F::ONE);
            }

            // On the last row, each of the first `W` public values must equal the
            // corresponding cell of the current row. The slice reference is copied
            // out first so it does not conflict with the mutable borrow taken by
            // `when_last_row()`.
            let public_values = builder.public_values;
            let mut when_last = builder.when_last_row();
            for (col, &pv) in public_values.iter().enumerate().take(W) {
                when_last.assert_eq(main.top.get(0, col).unwrap(), pv);
            }
        }
    }

    #[test]
    fn test_incremental_rows_with_last_row_check() {
        // Each row = previous + 1, with 4 rows total, 2 columns.
        // Last row must match public values [4, 4]
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::ONE,
            BabyBear::ONE, // Row 0
            BabyBear::new(2),
            BabyBear::new(2), // Row 1
            BabyBear::new(3),
            BabyBear::new(3), // Row 2
            BabyBear::new(4),
            BabyBear::new(4), // Row 3 (last)
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(4); 2]);
    }

    #[test]
    #[should_panic(expected = "on row 1")]
    fn test_incorrect_increment_logic() {
        // Row 2 does not equal row 1 + 1, so the transition evaluated on row 1 must fail.
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::ONE,
            BabyBear::ONE, // Row 0
            BabyBear::new(2),
            BabyBear::new(2), // Row 1
            BabyBear::new(5),
            BabyBear::new(5), // Row 2 (wrong)
            BabyBear::new(6),
            BabyBear::new(6), // Row 3
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(6); 2]);
    }

    #[test]
    #[should_panic(expected = "on row 3")]
    fn test_wrong_last_row_public_value() {
        // The transition logic is fine, but the public value check fails at the last row (row 3).
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::ONE,
            BabyBear::ONE, // Row 0
            BabyBear::new(2),
            BabyBear::new(2), // Row 1
            BabyBear::new(3),
            BabyBear::new(3), // Row 2
            BabyBear::new(4),
            BabyBear::new(4), // Row 3
        ];
        let main = RowMajorMatrix::new(values, 2);
        // Wrong public value on column 1
        check_constraints(&air, &main, &vec![BabyBear::new(4), BabyBear::new(5)]);
    }

    #[test]
    fn test_single_row_wraparound_logic() {
        // A single-row matrix still performs a wraparound check with itself.
        // The transition constraint row[0] == row[0] + 1 would fail, but on a
        // one-row trace is_transition == 0, so it is disabled.
        // The row is also the last row (is_last_row == 1), so the public-value
        // check IS enforced. These public values match the row.
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::new(99),
            BabyBear::new(77), // Row 0
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(99), BabyBear::new(77)]);
    }

    #[test]
    #[should_panic(expected = "on row 0")]
    fn test_single_row_wrong_public_value() {
        // On a one-row trace row 0 is the last row, so a public-value mismatch must be caught.
        let air = RowLogicAir::<2>;
        let values = vec![
            BabyBear::new(99),
            BabyBear::new(77), // Row 0
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(99), BabyBear::new(78)]);
    }

    #[test]
    fn test_empty_trace_evaluates_no_rows() {
        // A zero-height trace has no rows, so no constraint is ever evaluated.
        let air = RowLogicAir::<2>;
        let main = RowMajorMatrix::<BabyBear>::new(vec![], 2);
        check_constraints(&air, &main, &vec![BabyBear::new(7); 2]);
    }
}
249}