Skip to main content

lib_q_zkp/
circuit.rs

1//! Circuit builder for arithmetic constraints
2//!
3//! This module provides a circuit abstraction for building arithmetic constraints
4//! that can be compiled into AIR (Algebraic Intermediate Representation) for STARK proofs.
5
6extern crate alloc;
7use alloc::vec;
8use alloc::vec::Vec;
9
10use lib_q_stark_air::{
11    Air,
12    AirBuilder,
13    BaseAir,
14    WindowAccess,
15};
16use lib_q_stark_field::Field;
17
/// A handle to one slot in the circuit's wire vector.
///
/// A wire carries no value of its own; it is only a position (`index`) into the
/// vector of field elements that makes up a trace row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Wire {
    /// Position of this wire in the wire vector: witness wires come first, then
    /// public inputs, then intermediate wires allocated by the builder.
    pub index: usize,
}

impl Wire {
    /// Wraps `index` as a wire handle.
    pub fn new(index: usize) -> Self {
        Wire { index }
    }
}
31
32/// A constraint in the circuit
33#[derive(Debug, Clone)]
34pub enum Constraint<F: Field> {
35    /// Assert that a wire equals zero: `wire == 0`
36    AssertZero(Wire),
37    /// Assert that two wires are equal: `left == right`
38    AssertEqual(Wire, Wire),
39    /// Assert that a wire equals a constant: `wire == constant`
40    AssertConstant(Wire, F),
41    /// Assert that a wire equals the sum of two wires: `wire == left + right`
42    AssertAdd(Wire, Wire, Wire),
43    /// Assert that a wire equals the product of two wires: `wire == left * right`
44    AssertMul(Wire, Wire, Wire),
45}
46
47/// An arithmetic circuit containing constraints and metadata
48#[derive(Debug, Clone)]
49pub struct ArithmeticCircuit<F: Field> {
50    /// The constraints in the circuit
51    pub constraints: Vec<Constraint<F>>,
52    /// The number of witness wires (excluding public inputs)
53    pub witness_size: usize,
54    /// The number of public input wires
55    pub public_input_size: usize,
56}
57
58impl<F: Field> ArithmeticCircuit<F> {
59    /// Create a new empty circuit
60    pub fn new(witness_size: usize, public_input_size: usize) -> Self {
61        Self {
62            constraints: Vec::new(),
63            witness_size,
64            public_input_size,
65        }
66    }
67
68    /// Get the total number of wires (witness + public inputs)
69    pub fn total_wires(&self) -> usize {
70        self.witness_size + self.public_input_size
71    }
72
73    /// Add a constraint to the circuit
74    pub fn add_constraint(&mut self, constraint: Constraint<F>) {
75        self.constraints.push(constraint);
76    }
77}
78
79/// Builder for constructing arithmetic circuits
80pub struct CircuitBuilder<F: Field> {
81    circuit: ArithmeticCircuit<F>,
82    next_wire: usize,
83}
84
85impl<F: Field> CircuitBuilder<F> {
86    /// Create a new circuit builder
87    ///
88    /// # Arguments
89    ///
90    /// * `witness_size` - Number of witness wires (private inputs)
91    /// * `public_input_size` - Number of public input wires
92    ///
93    /// # Example
94    ///
95    /// ```rust,ignore
96    /// use lib_q_zkp::circuit::CircuitBuilder;
97    /// use lib_q_stark_field::extension::Complex;
98    /// use lib_q_stark_mersenne31::Mersenne31;
99    ///
100    /// type Val = Complex<Mersenne31>;
101    ///
102    /// let mut builder = CircuitBuilder::<Val>::new(2, 1);
103    /// let a = builder.wire(0);  // witness wire 0
104    /// let b = builder.wire(1);  // witness wire 1
105    /// let sum = builder.add(a, b);
106    /// builder.assert_zero(sum);
107    /// let circuit = builder.build();
108    /// ```
109    pub fn new(witness_size: usize, public_input_size: usize) -> Self {
110        Self {
111            circuit: ArithmeticCircuit::new(witness_size, public_input_size),
112            next_wire: witness_size + public_input_size,
113        }
114    }
115
116    /// Allocate a new intermediate wire
117    pub fn alloc_wire(&mut self) -> Wire {
118        let wire = Wire::new(self.next_wire);
119        self.next_wire += 1;
120        wire
121    }
122
123    /// Get a wire by index (for public inputs and witness)
124    pub fn wire(&self, index: usize) -> Wire {
125        Wire::new(index)
126    }
127
128    /// Assert that a wire equals zero
129    pub fn assert_zero(&mut self, wire: Wire) {
130        self.circuit.add_constraint(Constraint::AssertZero(wire));
131    }
132
133    /// Assert that two wires are equal
134    pub fn assert_eq(&mut self, left: Wire, right: Wire) {
135        self.circuit
136            .add_constraint(Constraint::AssertEqual(left, right));
137    }
138
139    /// Assert that a wire equals a constant
140    pub fn assert_constant(&mut self, wire: Wire, constant: F) {
141        self.circuit
142            .add_constraint(Constraint::AssertConstant(wire, constant));
143    }
144
145    /// Add two wires and return the result wire
146    pub fn add(&mut self, left: Wire, right: Wire) -> Wire {
147        let result = self.alloc_wire();
148        self.circuit
149            .add_constraint(Constraint::AssertAdd(result, left, right));
150        result
151    }
152
153    /// Multiply two wires and return the result wire
154    pub fn mul(&mut self, left: Wire, right: Wire) -> Wire {
155        let result = self.alloc_wire();
156        self.circuit
157            .add_constraint(Constraint::AssertMul(result, left, right));
158        result
159    }
160
161    /// Build the circuit
162    pub fn build(self) -> ArithmeticCircuit<F> {
163        self.circuit
164    }
165}
166
167/// AIR implementation for an arithmetic circuit
168///
169/// This converts a circuit into an AIR that can be used with STARK proving.
170/// The trace represents all wire values, with one row containing all wire values.
171pub struct CircuitAir<F: Field> {
172    circuit: ArithmeticCircuit<F>,
173}
174
175impl<F: Field> CircuitAir<F> {
176    /// Create a new CircuitAir from an ArithmeticCircuit
177    pub fn new(circuit: ArithmeticCircuit<F>) -> Self {
178        Self { circuit }
179    }
180
181    /// Get a reference to the underlying circuit
182    pub fn circuit(&self) -> &ArithmeticCircuit<F> {
183        &self.circuit
184    }
185
186    /// Generate an execution trace from witness values
187    ///
188    /// The witness values should include all wire values in the circuit.
189    /// Wire indices 0..witness_size are witness wires,
190    /// indices witness_size..witness_size+public_input_size are public inputs,
191    /// and remaining indices are intermediate wires.
192    ///
193    /// # Arguments
194    ///
195    /// * `witness` - Private witness values (witness wires)
196    /// * `public` - Public input values
197    ///
198    /// # Returns
199    ///
200    /// A RowMajorMatrix containing the trace, or an error if validation fails
201    pub fn generate_trace(
202        &self,
203        witness: &[F],
204        public: &[F],
205    ) -> Result<lib_q_stark_matrix::dense::RowMajorMatrix<F>, lib_q_core::Error> {
206        use lib_q_stark_matrix::dense::RowMajorMatrix;
207
208        // Validate input sizes
209        if witness.len() != self.circuit.witness_size {
210            return Err(lib_q_core::Error::InvalidState {
211                operation: "CircuitAir::generate_trace".into(),
212                reason: alloc::format!(
213                    "Witness size mismatch: expected {}, got {}",
214                    self.circuit.witness_size,
215                    witness.len()
216                ),
217            });
218        }
219
220        if public.len() != self.circuit.public_input_size {
221            return Err(lib_q_core::Error::InvalidState {
222                operation: "CircuitAir::generate_trace".into(),
223                reason: alloc::format!(
224                    "Public input size mismatch: expected {}, got {}",
225                    self.circuit.public_input_size,
226                    public.len()
227                ),
228            });
229        }
230
231        let width = self.width();
232
233        // Allocate trace for a single row (power of 2)
234        let mut trace_values = F::zero_vec(width);
235
236        // Fill witness wires
237        for (i, val) in witness.iter().enumerate() {
238            if i < width {
239                trace_values[i] = *val;
240            }
241        }
242
243        // Fill public input wires
244        for (i, val) in public.iter().enumerate() {
245            let idx = self.circuit.witness_size + i;
246            if idx < width {
247                trace_values[idx] = *val;
248            }
249        }
250
251        // Compute intermediate wire values by evaluating constraints
252        for constraint in &self.circuit.constraints {
253            match constraint {
254                Constraint::AssertAdd(out, l, r)
255                    if out.index < width && l.index < width && r.index < width =>
256                {
257                    trace_values[out.index] = trace_values[l.index] + trace_values[r.index];
258                }
259                Constraint::AssertMul(out, l, r)
260                    if out.index < width && l.index < width && r.index < width =>
261                {
262                    trace_values[out.index] = trace_values[l.index] * trace_values[r.index];
263                }
264                // Other constraints don't compute new values
265                _ => {}
266            }
267        }
268
269        // Pad to at least MIN_TRACE_ROWS so FRI has sufficient two-adic height (degree >= 1)
270        const MIN_TRACE_ROWS: usize = 64;
271        if MIN_TRACE_ROWS > 1 {
272            let mut padded = trace_values.clone();
273            for _ in 1..MIN_TRACE_ROWS {
274                padded.extend_from_slice(&trace_values);
275            }
276            Ok(RowMajorMatrix::new(padded, width))
277        } else {
278            Ok(RowMajorMatrix::new(trace_values, width))
279        }
280    }
281}
282
283impl<F: Field> BaseAir<F> for CircuitAir<F> {
284    fn width(&self) -> usize {
285        // The width is the total number of wires (witness + public inputs + intermediate wires)
286        // We need to compute this from the constraints
287        let max_wire = self
288            .circuit
289            .constraints
290            .iter()
291            .flat_map(|c| match c {
292                Constraint::AssertZero(w) => vec![w.index],
293                Constraint::AssertEqual(l, r) => vec![l.index, r.index],
294                Constraint::AssertConstant(w, _) => vec![w.index],
295                Constraint::AssertAdd(r, l, r2) => vec![r.index, l.index, r2.index],
296                Constraint::AssertMul(r, l, r2) => vec![r.index, l.index, r2.index],
297            })
298            .max()
299            .unwrap_or(0);
300        (max_wire + 1).max(self.circuit.total_wires())
301    }
302}
303
304impl<F: Field, AB: AirBuilder<F = F>> Air<AB> for CircuitAir<F> {
305    fn eval(&self, builder: &mut AB) {
306        let main = builder.main();
307        let row = main.current_slice();
308
309        // Evaluate each constraint in the circuit
310        for constraint in &self.circuit.constraints {
311            match constraint {
312                Constraint::AssertZero(w) => {
313                    // Constraint: wire[w.index] == 0
314                    if w.index < row.len() {
315                        builder.assert_zero(row[w.index]);
316                    }
317                }
318                Constraint::AssertEqual(l, r) => {
319                    // Constraint: wire[l.index] == wire[r.index]
320                    if l.index < row.len() && r.index < row.len() {
321                        builder.assert_eq(row[l.index], row[r.index]);
322                    }
323                }
324                Constraint::AssertConstant(w, c) => {
325                    // Constraint: wire[w.index] == constant
326                    if w.index < row.len() {
327                        builder.assert_eq(row[w.index], *c);
328                    }
329                }
330                Constraint::AssertAdd(out, l, r) => {
331                    // Constraint: wire[out.index] == wire[l.index] + wire[r.index]
332                    if out.index < row.len() && l.index < row.len() && r.index < row.len() {
333                        let sum = row[l.index] + row[r.index];
334                        builder.assert_eq(row[out.index], sum);
335                    }
336                }
337                Constraint::AssertMul(out, l, r) => {
338                    // Constraint: wire[out.index] == wire[l.index] * wire[r.index]
339                    if out.index < row.len() && l.index < row.len() && r.index < row.len() {
340                        let product = row[l.index] * row[r.index];
341                        builder.assert_eq(row[out.index], product);
342                    }
343                }
344            }
345        }
346    }
347}
348
#[cfg(test)]
mod tests {
    use lib_q_stark_air::BaseAir;
    use lib_q_stark_field::PrimeCharacteristicRing;
    use lib_q_stark_field::extension::Complex;
    use lib_q_stark_mersenne31::Mersenne31;

    use super::*;

    type TestField = Complex<Mersenne31>;

    /// Shorthand for the multiplicative identity of the test field.
    fn one() -> TestField {
        <TestField as PrimeCharacteristicRing>::ONE
    }

    #[test]
    fn test_circuit_builder_new() {
        let builder = CircuitBuilder::<TestField>::new(5, 2);
        let circuit = builder.build();
        assert_eq!(circuit.witness_size, 5);
        assert_eq!(circuit.public_input_size, 2);
        assert_eq!(circuit.total_wires(), 7);
    }

    #[test]
    fn test_circuit_builder_alloc_wire() {
        let mut builder = CircuitBuilder::<TestField>::new(3, 2);
        let wire1 = builder.alloc_wire();
        let wire2 = builder.alloc_wire();
        assert_eq!(wire1.index, 5); // 3 witness + 2 public = 5
        assert_eq!(wire2.index, 6);
    }

    #[test]
    fn test_circuit_builder_constraints() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let w0 = builder.wire(0);
        let w1 = builder.wire(1);
        let w2 = builder.wire(2);

        builder.assert_zero(w0);
        builder.assert_eq(w1, w2);
        builder.assert_constant(w0, <TestField as PrimeCharacteristicRing>::ONE);

        let circuit = builder.build();
        assert_eq!(circuit.constraints.len(), 3);
    }

    #[test]
    fn test_circuit_builder_add_mul() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let a = builder.wire(0);
        let b = builder.wire(1);
        let sum = builder.add(a, b);
        let product = builder.mul(a, b);

        assert!(sum.index >= 3);
        assert!(product.index >= 3);
        assert!(product.index > sum.index);

        let circuit = builder.build();
        assert_eq!(circuit.constraints.len(), 2);
    }

    #[test]
    fn test_arithmetic_circuit() {
        let mut circuit = ArithmeticCircuit::<TestField>::new(3, 2);
        circuit.add_constraint(Constraint::AssertZero(Wire::new(0)));
        circuit.add_constraint(Constraint::AssertEqual(Wire::new(1), Wire::new(2)));

        assert_eq!(circuit.constraints.len(), 2);
        assert_eq!(circuit.total_wires(), 5);
    }

    #[test]
    fn test_circuit_air_width() {
        let mut circuit = ArithmeticCircuit::<TestField>::new(2, 1);
        circuit.add_constraint(Constraint::AssertZero(Wire::new(0)));
        circuit.add_constraint(Constraint::AssertEqual(Wire::new(1), Wire::new(2)));

        let air = CircuitAir::new(circuit);
        // No intermediate wires: width is exactly total_wires.
        assert_eq!(BaseAir::<TestField>::width(&air), 3);
    }

    #[test]
    fn test_circuit_air_width_counts_intermediate_wires() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let a = builder.wire(0);
        let b = builder.wire(1);
        let sum = builder.add(a, b);
        let _ = builder.mul(sum, b);

        let air = CircuitAir::new(builder.build());
        // 2 witness + 1 public + 2 intermediate wires.
        assert_eq!(BaseAir::<TestField>::width(&air), 5);
    }

    #[test]
    fn test_generate_trace_rejects_witness_size_mismatch() {
        let air = CircuitAir::new(ArithmeticCircuit::<TestField>::new(2, 1));
        assert!(air.generate_trace(&[one()], &[one()]).is_err());
        assert!(air.generate_trace(&[one(), one(), one()], &[one()]).is_err());
    }

    #[test]
    fn test_generate_trace_rejects_public_size_mismatch() {
        let air = CircuitAir::new(ArithmeticCircuit::<TestField>::new(2, 1));
        let no_public: [TestField; 0] = [];
        assert!(air.generate_trace(&[one(), one()], &no_public).is_err());
        assert!(air.generate_trace(&[one(), one()], &[one(), one()]).is_err());
    }

    #[test]
    fn test_generate_trace_accepts_matching_sizes() {
        let mut builder = CircuitBuilder::<TestField>::new(2, 1);
        let a = builder.wire(0);
        let b = builder.wire(1);
        let sum = builder.add(a, b);
        let _ = builder.mul(sum, b);

        let air = CircuitAir::new(builder.build());
        let two = one() + one();
        assert!(air.generate_trace(&[one(), two], &[one()]).is_ok());
    }
}