use comp_cat_rs::collapse::free_category::{Edge, FreeCategoryError, Graph, Vertex};
use crate::field::Field;
use crate::wire::WireCount;
/// A primitive gate that can label an edge of a [`PlonkishGraph`].
///
/// The number of wires each variant consumes and produces is reported by
/// [`PrimitiveGate::input_count`] and [`PrimitiveGate::output_count`].
#[derive(Debug, Clone)]
pub enum PrimitiveGate<F: Field> {
    /// Two input wires, one output wire (presumably field addition).
    Add,
    /// Two input wires, one output wire (presumably field multiplication).
    Mul,
    /// No input wires, one output wire; carries a constant value of type `F`.
    Const(F),
    /// One input wire, one output wire (presumably a booleanity check — TODO confirm).
    Bool,
    /// One input wire, two output wires (presumably fans its input out to both outputs).
    Dup,
}
impl<F: Field> PrimitiveGate<F> {
    /// Number of input wires the gate consumes.
    ///
    /// `Const` takes none, `Bool` and `Dup` take one, and `Add` and `Mul`
    /// take two.
    #[must_use]
    pub fn input_count(&self) -> WireCount {
        let arity = match self {
            Self::Const(_) => 0,
            Self::Bool | Self::Dup => 1,
            Self::Add | Self::Mul => 2,
        };
        WireCount::new(arity)
    }

    /// Number of output wires the gate produces.
    ///
    /// Every gate yields a single wire except `Dup`, which yields two.
    #[must_use]
    pub fn output_count(&self) -> WireCount {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let arity = match self {
            Self::Dup => 2,
            Self::Add | Self::Mul | Self::Const(_) | Self::Bool => 1,
        };
        WireCount::new(arity)
    }
}
/// A labelled edge: a [`PrimitiveGate`] together with the source and target
/// [`Vertex`] it connects in a [`PlonkishGraph`].
#[derive(Debug, Clone)]
pub struct GateSpec<F: Field> {
    /// The gate labelling this edge.
    gate: PrimitiveGate<F>,
    /// The vertex the edge starts from.
    source: Vertex,
    /// The vertex the edge ends at.
    target: Vertex,
}
impl<F: Field> GateSpec<F> {
    /// Creates a specification for an edge `source -> target` labelled by `gate`.
    ///
    /// No consistency between the gate's arity and the vertices is checked.
    #[must_use]
    pub fn new(gate: PrimitiveGate<F>, source: Vertex, target: Vertex) -> Self {
        Self {
            source,
            target,
            gate,
        }
    }

    /// Borrows the primitive gate labelling this edge.
    #[must_use]
    pub fn gate(&self) -> &PrimitiveGate<F> {
        &self.gate
    }

    /// Returns the vertex this edge starts from.
    #[must_use]
    pub fn source(&self) -> Vertex {
        self.source
    }

    /// Returns the vertex this edge ends at.
    #[must_use]
    pub fn target(&self) -> Vertex {
        self.target
    }
}
/// A finite graph whose vertices carry wire counts and whose edges are
/// [`GateSpec`]s. It implements the free-category [`Graph`] trait.
#[derive(Debug)]
pub struct PlonkishGraph<F: Field> {
    /// Wire count of each vertex, indexed by `Vertex::index()`.
    vertices: Vec<WireCount>,
    /// Gate specification of each edge, indexed by `Edge::index()`.
    edges: Vec<GateSpec<F>>,
}
impl<F: Field> PlonkishGraph<F> {
    /// Builds the standard Plonkish signature graph.
    ///
    /// Vertex `i` (for `i` in `0..=2`) carries `WireCount::new(i)`. The edges
    /// are, in index order:
    ///
    /// | edge | gate   | source | target |
    /// |------|--------|--------|--------|
    /// | 0    | `Add`  | 2      | 1      |
    /// | 1    | `Mul`  | 2      | 1      |
    /// | 2    | `Bool` | 1      | 1      |
    /// | 3    | `Dup`  | 1      | 2      |
    ///
    /// For each edge, the source/target vertex carries the gate's
    /// input/output wire count.
    #[must_use]
    pub fn standard() -> Self {
        Self {
            vertices: vec![WireCount::new(0), WireCount::new(1), WireCount::new(2)],
            edges: vec![
                GateSpec::new(PrimitiveGate::Add, Vertex::new(2), Vertex::new(1)),
                GateSpec::new(PrimitiveGate::Mul, Vertex::new(2), Vertex::new(1)),
                GateSpec::new(PrimitiveGate::Bool, Vertex::new(1), Vertex::new(1)),
                GateSpec::new(PrimitiveGate::Dup, Vertex::new(1), Vertex::new(2)),
            ],
        }
    }

    /// Consumes the graph and appends a `Const(c)` edge from vertex 0 to
    /// vertex 1.
    ///
    /// Returns the extended graph together with the [`Edge`] of the new gate.
    /// Its index equals the number of edges before the call. The vertex
    /// indices 0 and 1 are hard-coded to match the layout built by
    /// [`PlonkishGraph::standard`], where they carry 0 and 1 wires.
    #[must_use]
    pub fn with_const(self, c: F) -> (Self, Edge) {
        let edge_index = Edge::new(self.edges.len());
        let new_spec = GateSpec::new(PrimitiveGate::Const(c), Vertex::new(0), Vertex::new(1));
        let graph = Self {
            vertices: self.vertices,
            edges: self
                .edges
                .into_iter()
                .chain(core::iter::once(new_spec))
                .collect(),
        };
        (graph, edge_index)
    }

    /// Wire count of every vertex, indexed by `Vertex::index()`.
    #[must_use]
    pub fn vertices(&self) -> &[WireCount] {
        &self.vertices
    }

    /// Gate specification of every edge, indexed by `Edge::index()`.
    #[must_use]
    pub fn gate_specs(&self) -> &[GateSpec<F>] {
        &self.edges
    }

    /// Looks up the gate specification labelling `edge`.
    ///
    /// # Errors
    ///
    /// Returns [`FreeCategoryError::EdgeOutOfBounds`], converted into
    /// `crate::error::Error`, when `edge.index()` is not less than the number
    /// of edges.
    pub fn gate_spec_at(&self, edge: Edge) -> Result<&GateSpec<F>, crate::error::Error> {
        // `get` performs the only bounds check and leaves no panicking index path.
        self.edges.get(edge.index()).ok_or_else(|| {
            FreeCategoryError::EdgeOutOfBounds {
                edge,
                count: self.edges.len(),
            }
            .into()
        })
    }

    /// Looks up the wire count carried by `vertex`.
    ///
    /// # Errors
    ///
    /// Returns [`FreeCategoryError::VertexOutOfBounds`], converted into
    /// `crate::error::Error`, when `vertex.index()` is not less than the
    /// number of vertices.
    pub fn wire_count_at(&self, vertex: Vertex) -> Result<WireCount, crate::error::Error> {
        self.vertices
            .get(vertex.index())
            .copied()
            .ok_or_else(|| {
                FreeCategoryError::VertexOutOfBounds {
                    vertex,
                    count: self.vertices.len(),
                }
                .into()
            })
    }
}
impl<F: Field> Graph for PlonkishGraph<F> {
fn vertex_count(&self) -> usize {
self.vertices.len()
}
fn edge_count(&self) -> usize {
self.edges.len()
}
fn source(&self, edge: Edge) -> Result<Vertex, FreeCategoryError> {
if edge.index() < self.edges.len() {
Ok(self.edges[edge.index()].source())
} else {
Err(FreeCategoryError::EdgeOutOfBounds {
edge,
count: self.edges.len(),
})
}
}
fn target(&self, edge: Edge) -> Result<Vertex, FreeCategoryError> {
if edge.index() < self.edges.len() {
Ok(self.edges[edge.index()].target())
} else {
Err(FreeCategoryError::EdgeOutOfBounds {
edge,
count: self.edges.len(),
})
}
}
}