capy_graph/
lib.rs

1//! This module provides a framework for constructing and evaluating arithmetic circuits.
2//! It supports operations such as addition, multiplication, and custom operations through hint gates.
3//! The circuits can be constructed dynamically, evaluated in parallel layers, and verified with equality constraints.
4//! This flexible architecture is suitable for applications requiring configurable computational graphs, such as in
5//! cryptographic schemes or complex algorithm simulations.
6//!
7//! # Example Usage:
8//! ```rust
9//! use capy_graph::Circuit;
10//! use std::sync::Arc;
11//!
12//! let mut circuit = Circuit::new();
13//! let x = circuit.constant(10);
14//! let y = circuit.add(x, x);
15//! let custom_operation = Arc::new(|val: u32| val * 2);
16//! let z = circuit.hint(x, custom_operation);
17//! circuit.assert_equal(y, z);
18//!
19//! let input_values = vec![10];
20//! let debug = true;
21//! assert!(circuit.evaluate(&input_values, debug).is_ok());
22//! assert!(circuit.check_constraints().is_ok());
23//! ```
24//!
25//! This example demonstrates creating a circuit with constant inputs, adding two nodes,
26//! applying a custom doubling operation, and asserting equality conditions. It also shows how
27//! to evaluate the circuit with debugging enabled to trace computation values and performance.
28mod tests;
29
30use rand::distributions::{Distribution, Uniform};
31use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
32use std::{
33    collections::{HashSet, VecDeque},
34    fmt,
35    panic::{catch_unwind, AssertUnwindSafe},
36    sync::{
37        atomic::{AtomicUsize, Ordering},
38        Arc, Mutex,
39    },
40    time::{Duration, Instant},
41};
42use thiserror::Error;
43
44/// Enum to represent errors that can occur within the Circuit operations.
45#[derive(Error, Debug)]
46pub enum CircuitError {
47    #[error("Cannot evaluate an empty circuit")]
48    EmptyCircuit,
49    #[error("Error evaluating node: {0}")]
50    NodeEvaluationError(String),
51    #[error("Constraint check failed")]
52    ConstraintCheckFailure,
53    #[error("Failed to acquire a necessary lock: {0}")]
54    LockAcquisitionError(String),
55    #[error("Internal error: expected non-empty layers")]
56    EmptyLayersError,
57}
58
/// Represents a gate in an arithmetic circuit.
///
/// Every variant stores the indices of the nodes it reads from; the indices refer to
/// positions in the owning circuit's node list.
#[derive(Clone)]
pub enum Gate {
    /// Adds the values of the two nodes identified by the given indices.
    Add(usize, usize),
    /// Multiplies the values of the two nodes identified by the given indices.
    Multiply(usize, usize),
    /// A custom gate applying a user-defined closure to the value of a single node.
    /// The closure takes one `u32` and returns one `u32`; it must be `Send + Sync`
    /// so that it can be shared across threads.
    Hint(usize, Arc<dyn Fn(u32) -> u32 + Send + Sync>),
}

// `Debug` cannot be derived because the hint closure is opaque, so it is written by
// hand: operand indices are printed and the closure is shown as a `<closure>` placeholder.
impl fmt::Debug for Gate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Gate::Add(left, right) => f.debug_tuple("Add").field(left).field(right).finish(),
            Gate::Multiply(left, right) => f
                .debug_tuple("Multiply")
                .field(left)
                .field(right)
                .finish(),
            Gate::Hint(idx, _) => f
                .debug_tuple("Hint")
                .field(idx)
                .field(&format_args!("<closure>"))
                .finish(),
        }
    }
}
72
73/// Represents a node within an arithmetic circuit.
74///
75/// Variants:
76/// - `Input(u32)`: A constant input value to the circuit. Set once and used during execution.
77/// - `Variable`: Represents a variable whose value is determined during the circuit's execution.
78/// - `Operation`: Applies an operation defined by a `Gate` to inputs, dynamically during execution.
79#[derive(Clone)]
80pub enum Node {
81    Input(u32),
82    Variable,
83    Operation(Gate, Vec<usize>),
84}
85
86/// Represents an arithmetic circuit.
87/// This struct manages nodes, gates, and the evaluation process of the circuit.
88/// It supports adding various types of nodes, including constants, variables, and operations,
89/// and provides methods to evaluate the circuit, check constraints, and apply custom operations.
90pub struct Circuit {
91    nodes: Vec<Node>, // Holds all nodes within the circuit, including inputs, variables, or operations.
92    equalities: Vec<(usize, usize)>, // Tracks equality constraints between pairs of node indices.
93    layers: Option<Vec<Vec<usize>>>, // Organized layers for efficient evaluation and parallel processing.
94    results: Vec<u32>, // Stores the computation results of the circuit nodes after evaluation.
95    total_duration: Duration, // Total time taken to evaluate the circuit.
96    number_of_layers: usize, // Number of layers used during the parallel evaluation of the circuit.
97    number_of_constraints: usize, // Total number of equality constraints defined in the circuit.
98    total_hint_gates: AtomicUsize, // Counter for the number of hint gates processed during evaluation.
99    total_gates_processed: AtomicUsize, // Counter for the total number of gates processed during evaluation.
100    gates_per_second: f64, // Computation throughput: number of gates processed per second.
101}
102
103// Default implementation of `Circuit` that clippy realllly wants to keep adding in...
104impl Default for Circuit {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110impl Circuit {
111    /// Create a new circuit and initialize all fields to empty.
112    pub fn new() -> Self {
113        Circuit {
114            nodes: Vec::new(),
115            equalities: Vec::new(),
116            layers: None,
117            results: Vec::new(),
118            total_gates_processed: AtomicUsize::new(0),
119            total_hint_gates: AtomicUsize::new(0),
120            total_duration: Duration::new(0, 0),
121            number_of_layers: 0,
122            number_of_constraints: 0,
123            gates_per_second: 0.0,
124        }
125    }
126
127    // Insert a gate into the circuit. Returns the index
128    // of the newly inserted node.
129    fn insert_gate(&mut self, gate: Gate) -> usize {
130        let dependencies = match &gate {
131            Gate::Add(left, right) => vec![*left, *right],
132            Gate::Multiply(left, right) => vec![*left, *right],
133            Gate::Hint(idx, _) => vec![*idx],
134        };
135
136        self.nodes.push(Node::Operation(gate, dependencies));
137        self.nodes.len() - 1
138    }
139
140    /// Inserts an input node into the circuit. Any number of input
141    /// nodes can be inserted. Their values are then set in sequence
142    /// by passing a list of `u32`s to `circuit.evaluate(&[])`.
143    /// Returns the index of the newly inserted node.
144    ///
145    /// ### Usage:
146    /// ```
147    /// let mut circuit = capy_graph::Circuit::new();
148    /// let x = circuit.init();
149    /// let y = circuit.init();
150    /// let z = circuit.init();
151    /// let debug = true;
152    /// assert!(circuit.evaluate(&[1, 2, 3], debug).is_ok());
153    /// ```
154    pub fn init(&mut self) -> usize {
155        self.nodes.push(Node::Variable);
156        self.nodes.len() - 1
157    }
158
159    /// Inserts a constant-valued node into the circuit.
160    /// Returns the index of the newly inserted node.
161    /// ### Usage:
162    /// ```
163    /// let mut circuit = capy_graph::Circuit::new();
164    /// let x = circuit.constant(42);
165    /// let y = circuit.mul(x, x);
166    /// ```
167    pub fn constant(&mut self, value: u32) -> usize {
168        self.nodes.push(Node::Input(value));
169        self.nodes.len() - 1
170    }
171
172    /// Inserts an `addition` node into the circuit. It has max fan-in
173    /// of 2 and accepts the indices of the nodes to add. Addition is
174    /// saturated; overflow yields the max `u32` value, underflow yields
175    /// the minimum.
176
177    ///
178    /// ### Usage:
179    /// ```
180    /// let mut circuit = capy_graph::Circuit::new();
181    /// let x = circuit.constant(42);
182    /// let y = circuit.add(x, x);
183    /// ```
184    pub fn add(&mut self, idx: usize, idx2: usize) -> usize {
185        self.insert_gate(Gate::Add(idx, idx2))
186    }
187
188    /// Inserts a `multiplication` node into the circuit. It has max fan-in
189    /// of 2 and accepts the indices of the nodes to multiply. Multiplication
190    /// is saturated; overflow yields the max `u32` value, underflow yields
191    /// the minimum.
192    ///
193    /// ### Usage:
194    /// ```
195    /// let mut circuit = capy_graph::Circuit::new();
196    /// let x = circuit.constant(42);
197    /// let y = circuit.mul(x, x);
198    /// ```
199    pub fn mul(&mut self, idx1: usize, idx2: usize) -> usize {
200        self.insert_gate(Gate::Multiply(idx1, idx2))
201    }
202
203    /// Inserts a custom function into the circuit. This function
204    /// is passed as a closure with trait bounds restricted to
205    /// `Send` + `Sync` in order to support layerization
206    /// and parallel circuit evaluation. The circuit will
207    /// catch any panics (i.e. dividing by zero) as a `CircuitError`.
208    ///
209    /// ### Usage:
210    /// ```
211    /// use capy_graph::Circuit;
212    /// use std::sync::Arc;
213    /// let mut circuit = Circuit::new();
214    /// let two = circuit.init();
215    /// let b = circuit.constant(16);
216    /// // the circuit doesn't support division, so we hint it
217    /// let c = circuit.hint(
218    ///     b,
219    ///     Arc::new(|x: u32| x / 8) as Arc<dyn Fn(u32) -> u32 + Send + Sync>
220    /// );
221    /// // then we establish a constraint to ensure the hint is executed correctly
222    /// let constraint = circuit.mul(c, two);
223    /// circuit.assert_equal(two, c);
224    /// let debug = true;
225    /// assert!(circuit.evaluate(&[2], debug).is_ok());
226    /// assert!(circuit.check_constraints().is_ok());
227    /// ```
228    pub fn hint(&mut self, idx: usize, func: Arc<dyn Fn(u32) -> u32 + Send + Sync>) -> usize {
229        self.insert_gate(Gate::Hint(idx, func))
230    }
231
232    /// Inserts a constraint-check between two nodes into the circuit.
233    /// This is useful for asserting that custom functions were executed
234    /// correctly.
235    ///
236    /// ### Usage:
237    /// ```
238    /// let mut circuit = capy_graph::Circuit::new();
239    /// let x = circuit.init();
240    /// let y = circuit.constant(42);
241    /// circuit.assert_equal(x, y);
242    /// circuit.evaluate(&[42], false);
243    /// assert!(circuit.check_constraints().is_ok());
244    /// ```
245    pub fn assert_equal(&mut self, idx1: usize, idx2: usize) {
246        self.equalities.push((idx1, idx2));
247    }
248
249    /// Checks if all constraints in the circuit are satisfied.
250    /// Returns `Ok(())` if all constraints are satisfied, or
251    /// `Err(CircuitError::ConstraintCheckFailure)` if any constraint fails.
252    /// ### Usage:
253    /// ```
254    /// let mut circuit = capy_graph::Circuit::new();
255    /// let x = circuit.init();
256    /// let y = circuit.constant(42);
257    /// circuit.assert_equal(x, y);
258    /// circuit.evaluate(&[42], false);
259    /// assert!(circuit.check_constraints().is_ok());
260    /// ```
261    pub fn check_constraints(&self) -> Result<(), CircuitError> {
262        if self
263            .equalities
264            .iter()
265            .all(|&(idx1, idx2)| self.results[idx1] == self.results[idx2])
266        {
267            Ok(())
268        } else {
269            Err(CircuitError::ConstraintCheckFailure)
270        }
271    }
272
273    /// Evaluates the circuit, initializing all `init` nodes to a list of input `u32` values.
274    /// Inputs are initialized sequentially as they appear in the input list.
275    /// Gracefully errors on any panic introduced from a custom hint or when attempting
276    /// to evaluate an empty circuit.
277    /// Optionally print debug information from the evaluation. These details include:
278    /// - evaluation circuit evaluation time
279    /// - number of layers
280    /// - number of gates
281    /// - number of hints
282    /// - number of constraints
283    /// - number of gates processed per second
284    ///
285    /// # Example Usage:
286    /// ```rust
287    /// use capy_graph::Circuit;
288    /// use std::sync::Arc;
289    ///
290    /// let mut circuit = Circuit::new();
291    /// let x = circuit.constant(10);
292    /// let y = circuit.add(x, x);
293    /// let custom_operation = Arc::new(|val: u32| val * 2);
294    /// let z = circuit.hint(x, custom_operation);
295    /// circuit.assert_equal(y, z);
296    ///
297    /// let input_values = vec![10];
298    /// let debug = true;
299    /// assert!(circuit.evaluate(&input_values, debug).is_ok());
300    /// assert!(circuit.check_constraints().is_ok());
301    /// ```
302    pub fn evaluate(&mut self, input_vals: &[u32], debug: bool) -> Result<(), CircuitError> {
303        if self.nodes.is_empty() {
304            return Err(CircuitError::EmptyCircuit);
305        }
306
307        let mut results = vec![0; self.nodes.len()];
308        let start_time = Instant::now();
309        let total_gates_processed = AtomicUsize::new(0);
310        let total_hint_gates = AtomicUsize::new(0);
311
312        // Use parallel Kahn's to split the graph into its requisite layers
313        self.layerize()?;
314        self.number_of_layers = self.layers.as_ref().map_or(0, Vec::len);
315        self.number_of_constraints = self.equalities.len();
316
317        if let Some(layers) = &self.layers {
318            for (i, layer) in layers.iter().enumerate() {
319                let layer_start = Instant::now();
320
321                let layer_results: Result<Vec<_>, CircuitError> = layer
322                    .par_iter() // Use Rayon's parallel iterator here
323                    .map(|&node_idx| {
324                        let node = &self.nodes[node_idx];
325                        match node {
326                            Node::Input(value) => Ok(*value),
327                            Node::Variable => Ok(input_vals[node_idx]),
328                            Node::Operation(gate, _) => {
329                                if matches!(gate, Gate::Hint(_, _)) {
330                                    total_hint_gates.fetch_add(1, Ordering::Relaxed);
331                                }
332                                total_gates_processed.fetch_add(1, Ordering::Relaxed);
333                                self.evaluate_gate(gate, &results)
334                            }
335                        }
336                    })
337                    .collect();
338
339                let layer_results = layer_results?;
340                let layer_duration = layer_start.elapsed();
341
342                // Update results after processing each layer
343                for (&node_idx, &result) in layer.iter().zip(layer_results.iter()) {
344                    results[node_idx] = result;
345                }
346
347                if debug {
348                    println!("Layer {}: Processed in {:?}", i + 1, layer_duration);
349                }
350            }
351        }
352        // Collect debug information into self
353        self.total_hint_gates = total_hint_gates;
354        self.results = results;
355        self.total_duration = start_time.elapsed();
356        if self.total_duration > Duration::ZERO {
357            self.gates_per_second = total_gates_processed.load(Ordering::Relaxed) as f64
358                / self.total_duration.as_secs_f64();
359        }
360        self.total_gates_processed = total_gates_processed;
361
362        if debug {
363            println!("{}", self)
364        }
365
366        Ok(())
367    }
368
369    // Defines the operations of the gates in the circuit and specifies the format of the
370    // `hint` gate. Returns `CircuitError` if hint function panics for any reason.
371    fn evaluate_gate(&self, gate: &Gate, results: &[u32]) -> Result<u32, CircuitError> {
372        match gate {
373            Gate::Add(left, right) => Ok(results[*left].saturating_add(results[*right])),
374            Gate::Multiply(left, right) => Ok(results[*left].saturating_mul(results[*right])),
375            Gate::Hint(idx, func) => {
376                let result = catch_unwind(AssertUnwindSafe(|| func(results[*idx])));
377                result.map_err(|_| CircuitError::NodeEvaluationError("Function panic".to_string()))
378            }
379        }
380    }
381
382    // Topological sort using parallel Kahn's algorithm: https://dl.acm.org/doi/10.1145/368996.369025
383    // Finds and seperates all nodes and gates into layers based on their dependencies. This facilitates
384    // parallel evaluation of the circuit, hypothetically increasing evaluation performance and throughput.
385    fn layerize(&mut self) -> Result<(), CircuitError> {
386        let nodes = &self.nodes;
387        let num_nodes = nodes.len();
388        let in_degree = Arc::new(Mutex::new(vec![0; num_nodes]));
389        let graph = Arc::new(Mutex::new(vec![vec![]; num_nodes]));
390
391        // Initialize graph and in_degree using rayon
392        (0..num_nodes)
393            .into_par_iter()
394            .try_for_each(|node_idx| -> Result<(), CircuitError> {
395                if let Node::Operation(_, deps) = &nodes[node_idx] {
396                    let mut graph_lock = graph.lock().map_err(|e| {
397                        CircuitError::LockAcquisitionError(format!("Failed to lock graph: {}", e))
398                    })?;
399                    let mut in_degree_lock = in_degree.lock().map_err(|e| {
400                        CircuitError::LockAcquisitionError(format!(
401                            "Failed to lock in_degree: {}",
402                            e
403                        ))
404                    })?;
405                    for &dep in deps {
406                        graph_lock[dep].push(node_idx);
407                        in_degree_lock[node_idx] += 1;
408                    }
409                }
410                Ok(())
411            })?;
412
413        // Determine the initial layer of nodes with zero in-degree
414        let mut queue = VecDeque::new();
415        let mut layers = Vec::new();
416        {
417            let in_deg = in_degree.lock().map_err(|e| {
418                CircuitError::LockAcquisitionError(format!(
419                    "Failed to lock in_degree for reading: {}",
420                    e
421                ))
422            })?;
423            for (i, &degree) in in_deg.iter().enumerate() {
424                if degree == 0 {
425                    queue.push_back(i);
426                }
427            }
428        }
429
430        // Process the layers
431        while !queue.is_empty() {
432            let current_layer = queue.drain(..).collect::<Vec<_>>();
433            layers.push(current_layer);
434
435            let mut next_layer = HashSet::new();
436            {
437                let graph_lock = graph.lock().map_err(|e| {
438                    CircuitError::LockAcquisitionError(format!(
439                        "Failed to lock graph for processing: {}",
440                        e
441                    ))
442                })?;
443                let mut in_deg_lock = in_degree.lock().map_err(|e| {
444                    CircuitError::LockAcquisitionError(format!(
445                        "Failed to lock in_degree for updating: {}",
446                        e
447                    ))
448                })?;
449
450                // Safely access the last element of layers, handling the error if layers is somehow empty.
451                let last_layer = layers.last().ok_or(CircuitError::EmptyLayersError)?;
452                for &node_idx in last_layer {
453                    for &dependent in &graph_lock[node_idx] {
454                        in_deg_lock[dependent] -= 1;
455                        if in_deg_lock[dependent] == 0 {
456                            next_layer.insert(dependent);
457                        }
458                    }
459                }
460            }
461
462            for node in next_layer {
463                queue.push_back(node);
464            }
465        }
466
467        self.layers = Some(layers);
468        Ok(())
469    }
470    /// Generate an arbitrary-size, random combination of nodes and gates for testing purposes.
471    /// Includes a mix of variants from the `Gate` enum as well as a custom hint. Automatically
472    /// constrains the hint and creates random dependencies between nodes.
473    /// ### Usage:
474    /// ```
475    /// use capy_graph::Circuit;
476    /// let mut circuit = Circuit::new();
477    /// // Generate a large random circuit
478    /// let num_gates = 100000;
479    /// circuit.generate_random(num_gates);
480    /// // Mock input
481    /// let inputs = vec![42; 10];
482    /// // Evaluate the circuit
483    /// assert!(circuit.evaluate(&inputs, true).is_ok());
484    /// // check all random constraints
485    /// assert!(circuit.check_constraints().is_ok());
486    /// ```
487    pub fn generate_random(&mut self, num_gates: usize) {
488        let num_inputs = 10; // Fixed number of input variables
489
490        // Initialize input nodes with random values
491        for _ in 0..num_inputs {
492            self.constant(rand::random::<u32>() % 100);
493        }
494
495        let custom_funcs: Vec<Arc<dyn Fn(u32) -> u32 + Send + Sync>> =
496            vec![Arc::new(|x| (x as f32).sqrt().round() as u32)];
497
498        let mut rng = rand::thread_rng();
499        let gate_dist = Uniform::from(0..3); // For Add, Multiply, Custom
500        let index_dist = Uniform::from(0..self.nodes.len());
501        let func_dist = Uniform::from(0..custom_funcs.len());
502
503        for _ in 0..num_gates {
504            // Sample some random gates
505            let gate_type = gate_dist.sample(&mut rng);
506            let idx1 = index_dist.sample(&mut rng);
507
508            match gate_type {
509                0 => {
510                    self.add(idx1, index_dist.sample(&mut rng));
511                }
512                1 => {
513                    self.mul(idx1, index_dist.sample(&mut rng));
514                }
515                2 => {
516                    // Hint function
517                    let func_idx = func_dist.sample(&mut rng);
518                    // Clone on an Arc just increments reference counter and is cheap, so this is ok for now
519                    let func_node = self.hint(idx1, custom_funcs[func_idx].clone());
520
521                    // If we add a hint into the circuit, let's also add an accompanying
522                    // equality check to enforce the constraint automatically
523                    let verification_node =
524                        self.apply_equality_constraint(func_node, func_idx, idx1);
525
526                    // Assert that original and verified are equal
527                    self.assert_equal(idx1, verification_node);
528                }
529                _ => unreachable!(),
530            }
531        }
532    }
533
534    // Helper function applies an equality constraint to a randomly generated hint
535    fn apply_equality_constraint(
536        &mut self,
537        func_node: usize,
538        func_idx: usize,
539        original_idx: usize,
540    ) -> usize {
541        match func_idx {
542            3 => self.hint(func_node, Arc::new(|x| x * x)), // Check sqrt by squaring
543            _ => original_idx,
544        }
545    }
546}
547
548impl fmt::Display for Circuit {
549    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
550        writeln!(f, "Circuit Evaluation Summary:")?;
551        writeln!(f, "Total evaluation time: {:?}", self.total_duration)?;
552        writeln!(f, "Number of layers: {}", self.number_of_layers)?;
553        writeln!(f, "Number of constraints: {}", self.number_of_constraints)?;
554        writeln!(
555            f,
556            "Number of hint gates processed: {}",
557            self.total_hint_gates.load(Ordering::Relaxed)
558        )?;
559        writeln!(
560            f,
561            "Total gates processed: {}",
562            self.total_gates_processed.load(Ordering::Relaxed)
563        )?;
564        writeln!(
565            f,
566            "Gates processed per second: {:.2}",
567            self.gates_per_second
568        )?;
569
570        Ok(())
571    }
572}