1use alloc::vec::Vec;
2
3use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues};
4use p3_field::Field;
5use p3_matrix::Matrix;
6use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
7use p3_matrix::stack::VerticalPair;
8use tracing::instrument;
9
10#[instrument(name = "check constraints", skip_all)]
21pub(crate) fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>, public_values: &Vec<F>)
22where
23 F: Field,
24 A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
25{
26 let height = main.height();
27
28 (0..height).for_each(|i| {
29 let i_next = (i + 1) % height;
30
31 let local = main.row_slice(i).unwrap(); let next = main.row_slice(i_next).unwrap(); let main = VerticalPair::new(
34 RowMajorMatrixView::new_row(&*local),
35 RowMajorMatrixView::new_row(&*next),
36 );
37
38 let mut builder = DebugConstraintBuilder {
39 row_index: i,
40 main,
41 public_values,
42 is_first_row: F::from_bool(i == 0),
43 is_last_row: F::from_bool(i == height - 1),
44 is_transition: F::from_bool(i != height - 1),
45 };
46
47 air.eval(&mut builder);
48 });
49}
50
/// An [`AirBuilder`] that evaluates an AIR's constraints on concrete field values for a single
/// row window, panicking as soon as an asserted constraint is nonzero.
///
/// Instances are constructed by `check_constraints`, one per trace row.
#[derive(Debug)]
pub struct DebugConstraintBuilder<'a, F: Field> {
    /// Index of the row being checked; reported in assertion panic messages.
    row_index: usize,
    /// Two-row window: `top` holds the current row and `bottom` the next row
    /// (as built by `check_constraints`, which wraps the last row around to row 0).
    main: VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>,
    /// Public values exposed to the AIR through `AirBuilderWithPublicValues`.
    public_values: &'a [F],
    /// Returned by `is_first_row`; `check_constraints` sets it to one on row 0 only.
    is_first_row: F,
    /// Returned by `is_last_row`; `check_constraints` sets it to one on the final row only.
    is_last_row: F,
    /// Returned by `is_transition_window(2)`; `check_constraints` sets it to zero on the final
    /// row and one elsewhere.
    is_transition: F,
}
70
71impl<'a, F> AirBuilder for DebugConstraintBuilder<'a, F>
72where
73 F: Field,
74{
75 type F = F;
76 type Expr = F;
77 type Var = F;
78 type M = VerticalPair<RowMajorMatrixView<'a, F>, RowMajorMatrixView<'a, F>>;
79
80 fn main(&self) -> Self::M {
81 self.main
82 }
83
84 fn is_first_row(&self) -> Self::Expr {
85 self.is_first_row
86 }
87
88 fn is_last_row(&self) -> Self::Expr {
89 self.is_last_row
90 }
91
92 fn is_transition_window(&self, size: usize) -> Self::Expr {
95 if size == 2 {
96 self.is_transition
97 } else {
98 panic!("only supports a window size of 2")
99 }
100 }
101
102 fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
103 assert_eq!(
104 x.into(),
105 F::ZERO,
106 "constraints had nonzero value on row {}",
107 self.row_index
108 );
109 }
110
111 fn assert_eq<I1: Into<Self::Expr>, I2: Into<Self::Expr>>(&mut self, x: I1, y: I2) {
112 let x = x.into();
113 let y = y.into();
114 assert_eq!(
115 x, y,
116 "values didn't match on row {}: {} != {}",
117 self.row_index, x, y
118 );
119 }
120}
121
122impl<F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'_, F> {
123 type PublicVar = Self::F;
124
125 fn public_values(&self) -> &[Self::F] {
126 self.public_values
127 }
128}
129
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_air::{BaseAir, BaseAirWithPublicValues};
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;

    use super::*;

    /// Test AIR with `W` columns and two kinds of constraints:
    /// - on every transition, each column of the next row equals the same column of the current
    ///   row plus one;
    /// - on the last row, column `i` equals public value `i` (for at most `W` public values).
    #[derive(Debug)]
    struct RowLogicAir<const W: usize>;

    impl<F: Field, const W: usize> BaseAir<F> for RowLogicAir<W> {
        fn width(&self) -> usize {
            W
        }
    }

    impl<F: Field, const W: usize> BaseAirWithPublicValues<F> for RowLogicAir<W> {}

    impl<F: Field, const W: usize> Air<DebugConstraintBuilder<'_, F>> for RowLogicAir<W> {
        fn eval(&self, builder: &mut DebugConstraintBuilder<'_, F>) {
            let main = builder.main();

            for col in 0..W {
                let a = main.top.get(0, col).unwrap();
                let b = main.bottom.get(0, col).unwrap();

                builder.when_transition().assert_eq(b, a + F::ONE);
            }

            // Copy the slice reference out of the field so it does not borrow `builder`, which
            // the filtered builder below borrows mutably.
            let public_values = builder.public_values;
            let mut when_last = builder.when(builder.is_last_row);
            for (i, &pv) in public_values.iter().enumerate().take(W) {
                when_last.assert_eq(main.top.get(0, i).unwrap(), pv);
            }
        }
    }

    #[test]
    fn test_incremental_rows_with_last_row_check() {
        let air = RowLogicAir::<2>;
        // Rows: [1, 1], [2, 2], [3, 3], [4, 4].
        let values = vec![
            BabyBear::ONE,
            BabyBear::ONE,
            BabyBear::new(2),
            BabyBear::new(2),
            BabyBear::new(3),
            BabyBear::new(3),
            BabyBear::new(4),
            BabyBear::new(4),
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(4); 2]);
    }

    #[test]
    // The transition from row 1 ([2, 2]) to row 2 ([5, 5]) is the first violation.
    #[should_panic(expected = "on row 1")]
    fn test_incorrect_increment_logic() {
        let air = RowLogicAir::<2>;
        // Rows: [1, 1], [2, 2], [5, 5], [6, 6].
        let values = vec![
            BabyBear::ONE,
            BabyBear::ONE,
            BabyBear::new(2),
            BabyBear::new(2),
            BabyBear::new(5),
            BabyBear::new(5),
            BabyBear::new(6),
            BabyBear::new(6),
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(6); 2]);
    }

    #[test]
    // Column 1 of the last row (row 3) is 4, but the public value is 5.
    #[should_panic(expected = "on row 3")]
    fn test_wrong_last_row_public_value() {
        let air = RowLogicAir::<2>;
        // Rows: [1, 1], [2, 2], [3, 3], [4, 4].
        let values = vec![
            BabyBear::ONE,
            BabyBear::ONE,
            BabyBear::new(2),
            BabyBear::new(2),
            BabyBear::new(3),
            BabyBear::new(3),
            BabyBear::new(4),
            BabyBear::new(4),
        ];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(4), BabyBear::new(5)]);
    }

    #[test]
    fn test_single_row_wraparound_logic() {
        let air = RowLogicAir::<2>;
        // A single row is both first and last; its "next" row wraps to itself, but the
        // transition selector is zero there, so only the last-row public-value check applies.
        let values = vec![BabyBear::new(99), BabyBear::new(77)];
        let main = RowMajorMatrix::new(values, 2);
        check_constraints(&air, &main, &vec![BabyBear::new(99), BabyBear::new(77)]);
    }

    #[test]
    fn test_empty_trace_has_no_rows_to_check() {
        let air = RowLogicAir::<2>;
        // Zero rows: the row loop never runs, so the `% height` wraparound is never reached.
        let main = RowMajorMatrix::new(Vec::<BabyBear>::new(), 2);
        check_constraints(&air, &main, &Vec::<BabyBear>::new());
    }
}