1extern crate alloc;
7use alloc::vec;
8use alloc::vec::Vec;
9
10use lib_q_stark_air::{
11 Air,
12 AirBuilder,
13 BaseAir,
14 WindowAccess,
15};
16use lib_q_stark_field::Field;
17
/// Handle to one wire of an arithmetic circuit, identified by its index.
///
/// The index doubles as the trace column that `CircuitAir` assigns to the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Wire {
    /// Zero-based wire (and trace column) index.
    pub index: usize,
}

impl Wire {
    /// Wraps a raw wire index; no range check is performed.
    pub fn new(index: usize) -> Self {
        Wire { index }
    }
}
31
32#[derive(Debug, Clone)]
34pub enum Constraint<F: Field> {
35 AssertZero(Wire),
37 AssertEqual(Wire, Wire),
39 AssertConstant(Wire, F),
41 AssertAdd(Wire, Wire, Wire),
43 AssertMul(Wire, Wire, Wire),
45}
46
47#[derive(Debug, Clone)]
49pub struct ArithmeticCircuit<F: Field> {
50 pub constraints: Vec<Constraint<F>>,
52 pub witness_size: usize,
54 pub public_input_size: usize,
56}
57
58impl<F: Field> ArithmeticCircuit<F> {
59 pub fn new(witness_size: usize, public_input_size: usize) -> Self {
61 Self {
62 constraints: Vec::new(),
63 witness_size,
64 public_input_size,
65 }
66 }
67
68 pub fn total_wires(&self) -> usize {
70 self.witness_size + self.public_input_size
71 }
72
73 pub fn add_constraint(&mut self, constraint: Constraint<F>) {
75 self.constraints.push(constraint);
76 }
77}
78
79pub struct CircuitBuilder<F: Field> {
81 circuit: ArithmeticCircuit<F>,
82 next_wire: usize,
83}
84
85impl<F: Field> CircuitBuilder<F> {
86 pub fn new(witness_size: usize, public_input_size: usize) -> Self {
110 Self {
111 circuit: ArithmeticCircuit::new(witness_size, public_input_size),
112 next_wire: witness_size + public_input_size,
113 }
114 }
115
116 pub fn alloc_wire(&mut self) -> Wire {
118 let wire = Wire::new(self.next_wire);
119 self.next_wire += 1;
120 wire
121 }
122
123 pub fn wire(&self, index: usize) -> Wire {
125 Wire::new(index)
126 }
127
128 pub fn assert_zero(&mut self, wire: Wire) {
130 self.circuit.add_constraint(Constraint::AssertZero(wire));
131 }
132
133 pub fn assert_eq(&mut self, left: Wire, right: Wire) {
135 self.circuit
136 .add_constraint(Constraint::AssertEqual(left, right));
137 }
138
139 pub fn assert_constant(&mut self, wire: Wire, constant: F) {
141 self.circuit
142 .add_constraint(Constraint::AssertConstant(wire, constant));
143 }
144
145 pub fn add(&mut self, left: Wire, right: Wire) -> Wire {
147 let result = self.alloc_wire();
148 self.circuit
149 .add_constraint(Constraint::AssertAdd(result, left, right));
150 result
151 }
152
153 pub fn mul(&mut self, left: Wire, right: Wire) -> Wire {
155 let result = self.alloc_wire();
156 self.circuit
157 .add_constraint(Constraint::AssertMul(result, left, right));
158 result
159 }
160
161 pub fn build(self) -> ArithmeticCircuit<F> {
163 self.circuit
164 }
165}
166
167pub struct CircuitAir<F: Field> {
172 circuit: ArithmeticCircuit<F>,
173}
174
175impl<F: Field> CircuitAir<F> {
176 pub fn new(circuit: ArithmeticCircuit<F>) -> Self {
178 Self { circuit }
179 }
180
181 pub fn circuit(&self) -> &ArithmeticCircuit<F> {
183 &self.circuit
184 }
185
186 pub fn generate_trace(
202 &self,
203 witness: &[F],
204 public: &[F],
205 ) -> Result<lib_q_stark_matrix::dense::RowMajorMatrix<F>, lib_q_core::Error> {
206 use lib_q_stark_matrix::dense::RowMajorMatrix;
207
208 if witness.len() != self.circuit.witness_size {
210 return Err(lib_q_core::Error::InvalidState {
211 operation: "CircuitAir::generate_trace".into(),
212 reason: alloc::format!(
213 "Witness size mismatch: expected {}, got {}",
214 self.circuit.witness_size,
215 witness.len()
216 ),
217 });
218 }
219
220 if public.len() != self.circuit.public_input_size {
221 return Err(lib_q_core::Error::InvalidState {
222 operation: "CircuitAir::generate_trace".into(),
223 reason: alloc::format!(
224 "Public input size mismatch: expected {}, got {}",
225 self.circuit.public_input_size,
226 public.len()
227 ),
228 });
229 }
230
231 let width = self.width();
232
233 let mut trace_values = F::zero_vec(width);
235
236 for (i, val) in witness.iter().enumerate() {
238 if i < width {
239 trace_values[i] = *val;
240 }
241 }
242
243 for (i, val) in public.iter().enumerate() {
245 let idx = self.circuit.witness_size + i;
246 if idx < width {
247 trace_values[idx] = *val;
248 }
249 }
250
251 for constraint in &self.circuit.constraints {
253 match constraint {
254 Constraint::AssertAdd(out, l, r)
255 if out.index < width && l.index < width && r.index < width =>
256 {
257 trace_values[out.index] = trace_values[l.index] + trace_values[r.index];
258 }
259 Constraint::AssertMul(out, l, r)
260 if out.index < width && l.index < width && r.index < width =>
261 {
262 trace_values[out.index] = trace_values[l.index] * trace_values[r.index];
263 }
264 _ => {}
266 }
267 }
268
269 const MIN_TRACE_ROWS: usize = 64;
271 if MIN_TRACE_ROWS > 1 {
272 let mut padded = trace_values.clone();
273 for _ in 1..MIN_TRACE_ROWS {
274 padded.extend_from_slice(&trace_values);
275 }
276 Ok(RowMajorMatrix::new(padded, width))
277 } else {
278 Ok(RowMajorMatrix::new(trace_values, width))
279 }
280 }
281}
282
283impl<F: Field> BaseAir<F> for CircuitAir<F> {
284 fn width(&self) -> usize {
285 let max_wire = self
288 .circuit
289 .constraints
290 .iter()
291 .flat_map(|c| match c {
292 Constraint::AssertZero(w) => vec![w.index],
293 Constraint::AssertEqual(l, r) => vec![l.index, r.index],
294 Constraint::AssertConstant(w, _) => vec![w.index],
295 Constraint::AssertAdd(r, l, r2) => vec![r.index, l.index, r2.index],
296 Constraint::AssertMul(r, l, r2) => vec![r.index, l.index, r2.index],
297 })
298 .max()
299 .unwrap_or(0);
300 (max_wire + 1).max(self.circuit.total_wires())
301 }
302}
303
304impl<F: Field, AB: AirBuilder<F = F>> Air<AB> for CircuitAir<F> {
305 fn eval(&self, builder: &mut AB) {
306 let main = builder.main();
307 let row = main.current_slice();
308
309 for constraint in &self.circuit.constraints {
311 match constraint {
312 Constraint::AssertZero(w) => {
313 if w.index < row.len() {
315 builder.assert_zero(row[w.index]);
316 }
317 }
318 Constraint::AssertEqual(l, r) => {
319 if l.index < row.len() && r.index < row.len() {
321 builder.assert_eq(row[l.index], row[r.index]);
322 }
323 }
324 Constraint::AssertConstant(w, c) => {
325 if w.index < row.len() {
327 builder.assert_eq(row[w.index], *c);
328 }
329 }
330 Constraint::AssertAdd(out, l, r) => {
331 if out.index < row.len() && l.index < row.len() && r.index < row.len() {
333 let sum = row[l.index] + row[r.index];
334 builder.assert_eq(row[out.index], sum);
335 }
336 }
337 Constraint::AssertMul(out, l, r) => {
338 if out.index < row.len() && l.index < row.len() && r.index < row.len() {
340 let product = row[l.index] * row[r.index];
341 builder.assert_eq(row[out.index], product);
342 }
343 }
344 }
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use lib_q_stark_air::BaseAir;
    use lib_q_stark_field::PrimeCharacteristicRing;
    use lib_q_stark_field::extension::Complex;
    use lib_q_stark_mersenne31::Mersenne31;

    use super::*;

    type TestField = Complex<Mersenne31>;

    /// Returns `n` as a field element, built from `ONE` by repeated addition so
    /// the tests rely only on the ring constants.
    fn small(n: u32) -> TestField {
        (0..n).fold(<TestField as PrimeCharacteristicRing>::ZERO, |acc, _| {
            acc + <TestField as PrimeCharacteristicRing>::ONE
        })
    }

    #[test]
    fn test_circuit_builder_new() {
        let builder = CircuitBuilder::<TestField>::new(5, 2);
        let circuit = builder.build();
        assert_eq!(circuit.witness_size, 5);
        assert_eq!(circuit.public_input_size, 2);
        assert_eq!(circuit.total_wires(), 7);
    }

    #[test]
    fn test_circuit_builder_alloc_wire() {
        let mut builder = CircuitBuilder::<TestField>::new(3, 2);
        let wire1 = builder.alloc_wire();
        let wire2 = builder.alloc_wire();
        assert_eq!(wire1.index, 5);
        assert_eq!(wire2.index, 6);
    }

    #[test]
    fn test_circuit_builder_constraints() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let w0 = builder.wire(0);
        let w1 = builder.wire(1);
        let w2 = builder.wire(2);

        builder.assert_zero(w0);
        builder.assert_eq(w1, w2);
        builder.assert_constant(w0, <TestField as PrimeCharacteristicRing>::ONE);

        let circuit = builder.build();
        assert_eq!(circuit.constraints.len(), 3);
    }

    #[test]
    fn test_circuit_builder_add_mul() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let a = builder.wire(0);
        let b = builder.wire(1);
        let sum = builder.add(a, b);
        let product = builder.mul(a, b);

        // Gate outputs are allocated right after the 3 input wires, in call order.
        assert_eq!(sum.index, 3);
        assert_eq!(product.index, 4);

        let circuit = builder.build();
        assert_eq!(circuit.constraints.len(), 2);
    }

    #[test]
    fn test_arithmetic_circuit() {
        let mut circuit = ArithmeticCircuit::<TestField>::new(3, 2);
        circuit.add_constraint(Constraint::AssertZero(Wire::new(0)));
        circuit.add_constraint(Constraint::AssertEqual(Wire::new(1), Wire::new(2)));

        assert_eq!(circuit.constraints.len(), 2);
        assert_eq!(circuit.total_wires(), 5);
    }

    #[test]
    fn test_circuit_air_width() {
        let mut circuit = ArithmeticCircuit::<TestField>::new(2, 1);
        circuit.add_constraint(Constraint::AssertZero(Wire::new(0)));
        circuit.add_constraint(Constraint::AssertEqual(Wire::new(1), Wire::new(2)));

        let air = CircuitAir::new(circuit);
        assert_eq!(BaseAir::<TestField>::width(&air), 3);
    }

    #[test]
    fn test_circuit_air_width_covers_intermediate_wires() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let a = builder.wire(0);
        let b = builder.wire(1);
        let sum = builder.add(a, b);
        let _product = builder.mul(sum, a);

        let air = CircuitAir::new(builder.build());
        assert_eq!(BaseAir::<TestField>::width(&air), 5);
    }

    #[test]
    fn test_generate_trace_evaluates_gates() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let a = builder.wire(0);
        let b = builder.wire(1);
        let c = builder.wire(2);
        let sum = builder.add(a, b);
        let product = builder.mul(sum, c);
        builder.assert_eq(sum, c);
        let air = CircuitAir::new(builder.build());

        let witness = [small(2), small(3)];
        let public = [small(5)];
        let trace = match air.generate_trace(&witness, &public) {
            Ok(trace) => trace,
            Err(_) => panic!("correctly sized inputs must produce a trace"),
        };

        // Columns: witness (2, 3), public (5), sum = 2 + 3, product = sum * 5.
        let expected = [small(2), small(3), small(5), small(5), small(25)];
        assert_eq!(product.index, 4);
        assert_eq!(BaseAir::<TestField>::width(&air), expected.len());
        assert_eq!(trace.values.len(), expected.len() * 64);
        for row in trace.values.chunks(expected.len()) {
            assert_eq!(row, &expected[..]);
        }
    }

    #[test]
    fn test_generate_trace_rejects_wrong_input_sizes() {
        let air = CircuitAir::new(CircuitBuilder::<TestField>::new(2, 1).build());
        let one = <TestField as PrimeCharacteristicRing>::ONE;

        assert!(air.generate_trace(&[one], &[one]).is_err());
        assert!(air.generate_trace(&[one, one], &[]).is_err());
        assert!(air.generate_trace(&[one, one], &[one]).is_ok());
    }
}