use alloc::vec::Vec;
use super::*;
use crate::{Felt, field::QuadFelt};
/// An arithmetic circuit: a list of binary instructions over inputs, constants, and the
/// results of earlier instructions.
///
/// Nodes are addressed by [`NodeID`]: `num_inputs` input nodes, one node per entry of
/// `constants`, and one evaluated node per entry of `instructions`.
/// [`Circuit::evaluate`] returns the value produced by the last instruction
/// ([`Circuit::new`] guarantees there is at least one).
///
/// NOTE(review): all fields are `pub`, so a `Circuit` can be constructed without going
/// through [`Circuit::new`] and therefore without its well-formedness checks.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Circuit {
/// Number of input values that must be supplied to [`Circuit::evaluate`].
pub num_inputs: usize,
/// Constant values; entry `i` is addressed as `NodeID::Const(i)`.
pub constants: Vec<Felt>,
/// Instructions in evaluation order; instruction `i` produces `NodeID::Eval(i)`.
pub instructions: Vec<Instruction>,
}
/// Identifies a node of a [`Circuit`].
///
/// The declaration order of the variants is load-bearing. The derived `Ord`/`PartialOrd`
/// compare the variant first, so every `Input` sorts before every `Const`, and every `Const`
/// sorts before every `Eval`. Within one variant, ids compare numerically.
/// `Circuit::new` checks `node < NodeID::Eval(i)` to ensure instruction `i` only reads
/// inputs, constants, and results of earlier instructions, so do not reorder these variants.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum NodeID {
/// The `i`-th circuit input.
Input(usize),
/// The `i`-th entry of `Circuit::constants`.
Const(usize),
/// The result of the `i`-th entry of `Circuit::instructions`.
Eval(usize),
}
/// A single binary gate of a [`Circuit`].
///
/// Evaluating it computes `node_l <op> node_r` (operand order matters for `Sub`). The result
/// becomes the next `NodeID::Eval` node.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Instruction {
/// Left operand.
pub node_l: NodeID,
/// Right operand.
pub node_r: NodeID,
/// Operation applied to the two operands.
pub op: Op,
}
/// Errors produced when building or evaluating a [`Circuit`].
///
/// The enum is a plain set of unit variants, so it is cheap to copy and can be compared
/// directly, for example in `assert_eq!` or `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitError {
    /// The overall shape of the circuit is unusable (e.g. it has no instructions).
    LayoutInvalid,
    /// An instruction references a node that does not exist, or one that is not computed
    /// before the instruction runs.
    InstructionMalformed,
    /// The number of inputs supplied for evaluation does not match the circuit's
    /// `num_inputs`.
    InputsWrongNumber,
}
/// Node counts of a circuit, used to map a `NodeID` to a flat index.
///
/// Nodes are laid out as all inputs, then all constants, then one slot per instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitLayout {
    /// Number of input nodes.
    pub num_inputs: usize,
    /// Number of constant nodes.
    pub num_constants: usize,
    /// Number of instruction (evaluated) nodes.
    pub num_instructions: usize,
}
impl Circuit {
    /// Builds a circuit after checking that it is well formed.
    ///
    /// # Errors
    ///
    /// - [`CircuitError::LayoutInvalid`] if `instructions` is empty. A circuit needs at least
    ///   one instruction to produce an output.
    /// - [`CircuitError::InstructionMalformed`] if any instruction references a node that
    ///   does not exist, or an `Eval` node that is not produced by an earlier instruction.
    pub fn new(
        num_inputs: usize,
        constants: Vec<Felt>,
        instructions: Vec<Instruction>,
    ) -> Result<Self, CircuitError> {
        if instructions.is_empty() {
            return Err(CircuitError::LayoutInvalid);
        }
        let layout = CircuitLayout {
            num_inputs,
            num_constants: constants.len(),
            num_instructions: instructions.len(),
        };
        // Instruction `idx` may only read inputs, constants, and results of instructions
        // before it. The derived `Ord` on `NodeID` places every `Input` and `Const` below
        // every `Eval`, so `node < Eval(idx)` expresses exactly that rule.
        let well_formed = instructions.iter().enumerate().all(|(idx, instruction)| {
            let eval_node = NodeID::Eval(idx);
            let valid_node = |node: NodeID| layout.contains_node(&node) && node < eval_node;
            valid_node(instruction.node_l) && valid_node(instruction.node_r)
        });
        if !well_formed {
            return Err(CircuitError::InstructionMalformed);
        }
        Ok(Self { num_inputs, constants, instructions })
    }

    /// Evaluates the circuit on `inputs` and returns the value of its last instruction.
    ///
    /// Inputs are stored first, then the constants (lifted to `QuadFelt`). Each instruction
    /// then appends its result, so the value storage matches the `CircuitLayout` indexing.
    ///
    /// # Errors
    ///
    /// - [`CircuitError::InputsWrongNumber`] if `inputs.len() != self.num_inputs`.
    /// - [`CircuitError::LayoutInvalid`] if the circuit has no instructions.
    /// - [`CircuitError::InstructionMalformed`] if an instruction references a node that
    ///   does not exist or has not been computed yet.
    ///
    /// The last two cannot happen for circuits built with [`Circuit::new`]. However, the
    /// fields are public, so a hand-built circuit is reported as an error instead of
    /// causing a panic.
    pub fn evaluate(&self, inputs: &[QuadFelt]) -> Result<QuadFelt, CircuitError> {
        let layout = self.layout();
        if inputs.len() != layout.num_inputs {
            return Err(CircuitError::InputsWrongNumber);
        }
        if self.instructions.is_empty() {
            return Err(CircuitError::LayoutInvalid);
        }
        let mut nodes = Vec::with_capacity(layout.num_nodes());
        nodes.extend(inputs.iter().copied());
        nodes.extend(self.constants.iter().map(|c| QuadFelt::from(*c)));
        for instruction in &self.instructions {
            let v_l = Self::operand_value(&layout, &nodes, &instruction.node_l)?;
            let v_r = Self::operand_value(&layout, &nodes, &instruction.node_r)?;
            let v_out = match instruction.op {
                Op::Add => v_l + v_r,
                Op::Sub => v_l - v_r,
                Op::Mul => v_l * v_r,
            };
            nodes.push(v_out);
        }
        // At least one instruction pushed a value above, so this is its result.
        nodes.last().copied().ok_or(CircuitError::LayoutInvalid)
    }

    /// Returns the node counts of this circuit (inputs, constants, instructions).
    pub fn layout(&self) -> CircuitLayout {
        CircuitLayout {
            num_inputs: self.num_inputs,
            num_constants: self.constants.len(),
            num_instructions: self.instructions.len(),
        }
    }

    /// Looks up the already-computed value of `node`.
    ///
    /// Returns [`CircuitError::InstructionMalformed`] if `node` is outside the layout, or if
    /// it is an `Eval` node whose instruction has not run yet (a forward or self reference).
    fn operand_value(
        layout: &CircuitLayout,
        nodes: &[QuadFelt],
        node: &NodeID,
    ) -> Result<QuadFelt, CircuitError> {
        layout
            .node_index(node)
            .and_then(|idx| nodes.get(idx))
            .copied()
            .ok_or(CircuitError::InstructionMalformed)
    }
}
impl CircuitLayout {
    /// Number of leaf nodes, i.e. inputs plus constants.
    pub fn num_vars(&self) -> usize {
        let Self { num_inputs, num_constants, .. } = *self;
        num_inputs + num_constants
    }

    /// Total number of nodes: leaves followed by one node per instruction.
    pub fn num_nodes(&self) -> usize {
        let leaves = self.num_vars();
        leaves + self.num_instructions
    }

    /// Whether `node`'s id is within the count for its kind.
    pub fn contains_node(&self, node: &NodeID) -> bool {
        let (id, limit) = match *node {
            NodeID::Input(id) => (id, self.num_inputs),
            NodeID::Const(id) => (id, self.num_constants),
            NodeID::Eval(id) => (id, self.num_instructions),
        };
        id < limit
    }

    /// Flat position of `node` in the order inputs, then constants, then instruction results.
    ///
    /// Returns `None` when `node` is outside this layout. The offset arithmetic runs only for
    /// in-range nodes.
    pub fn node_index(&self, node: &NodeID) -> Option<usize> {
        match *node {
            NodeID::Input(id) => (id < self.num_inputs).then_some(id),
            NodeID::Const(id) => (id < self.num_constants).then(|| self.num_inputs + id),
            NodeID::Eval(id) => (id < self.num_instructions).then(|| self.num_vars() + id),
        }
    }
}