pub use crate::matrix::Matrix;
use crate::{
circuit::Var,
constraint_system::{
cs_prototype::GateRegistry, ConstraintSystem, Constraints, Gate, Val, WitnessAccess,
},
gates::Equality,
};
use std::{
cmp::Ordering,
collections::BTreeMap,
fmt::Display,
ops::{Add, Mul, Sub},
};
/// A multiset (bag) of `T`, stored as a map from each distinct element to its
/// number of occurrences.
#[derive(Clone, Debug)]
pub struct MultiSet<T>(BTreeMap<T, usize>);

// Written by hand on purpose: `#[derive(Default)]` would require `T: Default`,
// whereas an empty map needs no bound on `T` at all.
impl<T> Default for MultiSet<T> {
    /// Returns the empty multiset.
    fn default() -> Self {
        Self(BTreeMap::new())
    }
}
/// Index of a matrix, tagged by the kind of matrix it refers to.
///
/// Values are totally ordered: every `Io` index sorts before every `Selector`
/// index, and two indices of the same kind compare by their inner number.
///
/// `Debug` is derived so the type can be used with `assert_eq!`, `{:?}` and
/// inside other `#[derive(Debug)]` types, like the other public types here.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum MatrixIndex {
    Io(usize),
    Selector(usize),
}

impl MatrixIndex {
    /// Total order backing the `Ord` / `PartialOrd` impls below.
    fn order(&self, b: &Self) -> Ordering {
        match (self, b) {
            // Same kind: fall back to the numeric index.
            (Self::Io(lhs), Self::Io(rhs)) | (Self::Selector(lhs), Self::Selector(rhs)) => {
                lhs.cmp(rhs)
            }
            // Different kinds: `Io` always comes first.
            (Self::Io(_), Self::Selector(_)) => Ordering::Less,
            (Self::Selector(_), Self::Io(_)) => Ordering::Greater,
        }
    }
}

impl PartialOrd for MatrixIndex {
    /// Always `Some`: the order is total, so this just defers to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MatrixIndex {
    fn cmp(&self, other: &Self) -> Ordering {
        self.order(other)
    }
}
/// Structure produced by [`StructureBuilder::build`].
///
/// `IO` is the number of IO matrices. `S` is not used by any field; `build`
/// panics when a constraint's selector is not smaller than `S`.
#[derive(Clone, Debug)]
pub struct CcsStructure<const IO: usize, const S: usize> {
/// One matrix per IO slot; `build` appends one row to each of them per constraint.
pub io_matrices: [Matrix; IO],
/// Gate selector index of each constraint, in the order the constraints were recorded.
pub gate_selectors: Vec<usize>,
/// The `public_io_len` value passed to `build`, stored unchanged.
pub input_len: usize,
/// Constraint expressions of the registered gates, ordered by gate index.
pub gates: Vec<Constraints<Exp<usize>>>,
/// Number of witness variables the builder had allocated when `build` was called.
pub trace_len: usize,
}
impl<const IO: usize, const S: usize> CcsStructure<IO, S> {
    /// Base-2 logarithm of `trace_len` after rounding it up to a power of two.
    ///
    /// Equivalently, the smallest `k` with `2^k >= trace_len`; both a trace
    /// length of 0 and of 1 give 0.
    pub fn vars(&self) -> usize {
        self.trace_len.next_power_of_two().ilog2() as usize
    }
}
/// One recorded gate application: a fixed-size row of `IO` slots plus the
/// selector of the gate that constrains them.
#[derive(Debug)]
struct Constraint<T, const IO: usize> {
/// The row's slots. `StructureBuilder::build` only writes the first `len` of them out as values.
io: [T; IO],
/// Number of leading `io` slots emitted as single-value rows; the remaining matrices get an empty row.
len: usize,
/// Selector index of the gate, as returned by the gate registry.
selector: usize,
}
/// Records which gates are applied to which witness indices, without computing
/// any witness values, so that a [`CcsStructure`] can be built from the record.
///
/// `IO` is the number of slots stored per constraint.
#[derive(Debug, Default)]
pub struct StructureBuilder<const IO: usize> {
/// Most recently allocated witness index; 0 while nothing has been allocated.
next: usize,
/// Every allocated witness index, in allocation order.
vars: Vec<usize>,
/// Registry of the gates used so far; supplies the selector index for each gate.
registry: GateRegistry,
/// One entry per `execute` call, in call order.
constraints: Vec<Constraint<WitnessIndex, IO>>,
}
/// Position of a variable in the witness, as handed out by `StructureBuilder`.
///
/// The arithmetic operators below exist only so the type satisfies operator
/// bounds; an index is not a value, so every one of them panics when called.
#[derive(Clone, Copy, Debug)]
pub struct WitnessIndex(usize);

impl Add for WitnessIndex {
    type Output = Self;

    /// Always panics: witness indices must never be added.
    fn add(self, rhs: Self) -> Self {
        let lhs = self;
        panic!(
            "tried to add {:?} and {:?}, this type of var should never be added",
            lhs, rhs
        )
    }
}

impl Sub for WitnessIndex {
    type Output = Self;

    /// Always panics: witness indices must never be subtracted.
    fn sub(self, rhs: Self) -> Self {
        let lhs = self;
        panic!(
            "tried to subtract {:?} and {:?}, this type of var should never be subtracted",
            lhs, rhs
        )
    }
}

impl Mul for WitnessIndex {
    type Output = Self;

    /// Always panics: witness indices must never be multiplied.
    fn mul(self, rhs: Self) -> Self {
        let lhs = self;
        panic!(
            "tried to multiply {:?} and {:?}, this type of var should never be multiplied",
            lhs, rhs
        )
    }
}
// Opts `WitnessIndex` into the crate's `Val` trait (imported from `constraint_system`;
// its requirements are not visible in this file). The impl body is empty.
impl Val for WitnessIndex {}
impl<const MAX_IO: usize> StructureBuilder<MAX_IO> {
    /// Every witness index allocated so far, in allocation order.
    pub(crate) fn vars(&self) -> &[usize] {
        self.vars.as_slice()
    }

    /// Counts how many recorded constraints use each registered gate.
    ///
    /// The result holds one `(name, count)` entry per registered gate, placed
    /// at the gate's index.
    ///
    /// # Panics
    ///
    /// Panics if a constraint's selector or a gate's index is not smaller than
    /// the number of registered gates.
    pub fn gate_counts(&self) -> Vec<(&'static str, usize)> {
        let gates = &self.registry.gate_registry;
        // Start every slot with an empty name and a zero count.
        let mut tally: Vec<(&'static str, usize)> = vec![("", 0); gates.len()];
        for constraint in &self.constraints {
            tally[constraint.selector].1 += 1;
        }
        for gate in gates.values() {
            tally[gate.0].0 = gate.2;
        }
        tally
    }

    /// Allocates a fresh witness index.
    ///
    /// The counter is bumped before it is handed out, so the first index is 1
    /// and this method never returns index 0.
    fn var(&mut self) -> WitnessIndex {
        let index = self.next + 1;
        self.next = index;
        self.vars.push(index);
        WitnessIndex(index)
    }

    /// Creates an empty builder and allocates `I` input variables in it.
    pub fn with_inputs<const I: usize>() -> (Self, [WitnessIndex; I]) {
        let mut builder = Self::default();
        let inputs = std::array::from_fn(|_| builder.var());
        (builder, inputs)
    }

    /// Allocates `O` variables and drops their handles.
    pub fn reserve_outputs<const O: usize>(&mut self) {
        (0..O).for_each(|_| {
            self.var();
        });
    }

    /// Adds one `Equality` gate per output, tying `outputs[i]` to the variable
    /// with index `i + I + 1`.
    ///
    /// That addressing matches the variables allocated by `reserve_outputs`
    /// when it is called directly after `with_inputs::<I>()`; this method does
    /// not verify that this was the case.
    pub fn link_outputs<const I: usize, const O: usize>(&mut self, outputs: [WitnessIndex; O]) {
        for (offset, produced) in outputs.into_iter().enumerate() {
            let reserved = WitnessIndex(offset + I + 1);
            <Self as ConstraintSystem<(), WitnessIndex>>::execute::<Equality, 2, 2, 0>(
                self,
                [reserved, produced].map(Var),
            );
        }
    }

    /// Consumes the builder and lays the recorded constraints out as matrices.
    ///
    /// Each constraint adds one row to every one of the `MAX_IO` matrices: the
    /// first `len` matrices receive the witness index of the matching slot,
    /// the others receive an empty row. `public_io_len` is stored as
    /// `input_len` without further checks.
    ///
    /// # Panics
    ///
    /// Panics if a constraint's selector is not smaller than `S`, if a
    /// constraint's `len` exceeds `MAX_IO`, or if the number of allocated
    /// variables disagrees with the allocation counter.
    pub fn build<const S: usize>(self, public_io_len: usize) -> CcsStructure<MAX_IO, S> {
        let Self {
            next,
            vars,
            registry,
            constraints,
        } = self;
        let row_count = constraints.len();
        let mut io_matrices = [(); MAX_IO].map(|_| Matrix::with_capacity(row_count));
        let mut gate_selectors = Vec::new();
        for Constraint { io, len, selector } in constraints {
            // Populated slots first ...
            (0..len).for_each(|slot| {
                io_matrices[slot].push_row_single_value(io[slot].0);
            });
            // ... then pad the remaining matrices so all keep the same row count.
            for slot in len..MAX_IO {
                io_matrices[slot].push_row_empty();
            }
            gate_selectors.push(selector);
            assert!(selector < S, "not enough selectors for all gates, increase S");
        }
        let gates = registry.expressions_sorted();
        let trace_len = vars.len();
        assert_eq!(trace_len, next);
        CcsStructure {
            input_len: public_io_len,
            io_matrices,
            gate_selectors,
            gates,
            trace_len,
        }
    }
}
impl<F, const MAX_IO: usize> ConstraintSystem<F, WitnessIndex> for StructureBuilder<MAX_IO> {
    /// Records one application of gate `G` and returns its freshly allocated outputs.
    ///
    /// The stored row holds the `I` inputs first, then the `O` new outputs;
    /// any slots after that stay `WitnessIndex(0)`. The row's length is
    /// recorded as `IO`.
    ///
    /// # Panics
    ///
    /// Panics if `I + O` exceeds `MAX_IO`.
    fn execute<G, const IO: usize, const I: usize, const O: usize>(
        &mut self,
        inputs: [Var<WitnessIndex>; I],
    ) -> [Var<WitnessIndex>; O]
    where
        G: Gate<IO, I, O> + 'static,
    {
        let consumed = inputs.map(Var::unwrap);
        let mut slots = [WitnessIndex(0); MAX_IO];
        slots[..I].copy_from_slice(&consumed);
        // Outputs are allocated only after the inputs have been placed.
        let produced = [(); O].map(|_| self.var());
        slots[I..I + O].copy_from_slice(&produced);
        let selector = self.registry.selector::<G, IO, I, O>();
        self.constraints.push(Constraint {
            io: slots,
            len: IO,
            selector,
        });
        produced.map(Var)
    }

    /// Always panics: no witness values exist while only the structure is being built.
    fn read(&self, _var: &Var<WitnessIndex>, _token: &WitnessAccess) -> F {
        panic!("can't access witness during constraint building");
    }

    /// Allocates a fresh variable; the value closure is ignored and never called.
    fn free_variable<W>(&mut self, _value: W) -> Var<WitnessIndex>
    where
        W: for<'a> FnOnce(&mut Self, &'a WitnessAccess) -> F,
    {
        let fresh = self.var();
        Var(fresh)
    }
}
/// Arithmetic expression tree over atoms of type `T`.
///
/// The `+`, `*` and `-` operators do not evaluate anything; they wrap both
/// operands in a new tree node.
#[derive(Debug, Clone)]
pub enum Exp<T> {
    Atom(T),
    Add(Box<Self>, Box<Self>),
    Mul(Box<Self>, Box<Self>),
    Sub(Box<Self>, Box<Self>),
}

impl<T> Add<Self> for Exp<T> {
    type Output = Self;

    /// Builds an `Add` node with `self` on the left and `rhs` on the right.
    fn add(self, rhs: Self) -> Self {
        let (left, right) = (Box::new(self), Box::new(rhs));
        Exp::Add(left, right)
    }
}

impl<T> Mul<Self> for Exp<T> {
    type Output = Self;

    /// Builds a `Mul` node with `self` on the left and `rhs` on the right.
    fn mul(self, rhs: Self) -> Self {
        let (left, right) = (Box::new(self), Box::new(rhs));
        Exp::Mul(left, right)
    }
}

impl<T> Sub<Self> for Exp<T> {
    type Output = Self;

    /// Builds a `Sub` node with `self` on the left and `rhs` on the right.
    fn sub(self, rhs: Self) -> Self {
        let (left, right) = (Box::new(self), Box::new(rhs));
        Exp::Sub(left, right)
    }
}
// Opts `Exp<T>` into the crate's `Val` trait (imported from `constraint_system`; its
// requirements are not visible in this file). Only provided when `T: Clone`.
impl<T: Clone> Val for Exp<T> {}
impl<T: Display> Display for MultiSet<T> {
    /// Writes `v<element>` once per occurrence of every element, in the map's
    /// iteration order, and finishes with a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Expand every (element, count) entry into `count` copies of the element.
        let occurrences = self
            .0
            .iter()
            .flat_map(|(element, count)| std::iter::repeat(element).take(*count));
        for i in occurrences {
            write!(f, "v{i}")?;
        }
        writeln!(f)
    }
}
impl<T: Ord + Clone> Exp<T> {
pub fn map<V, F>(self, f: &F) -> Exp<V>
where
F: Fn(T) -> V,
{
use Exp::*;
match self {
Atom(v) => Atom(f(v)),
Add(e1, e2) => Add(Box::new(e1.map(f)), Box::new(e2.map(f))),
Mul(e1, e2) => Mul(Box::new(e1.map(f)), Box::new(e2.map(f))),
Sub(e1, e2) => Sub(Box::new(e1.map(f)), Box::new(e2.map(f))),
}
}
}
impl GateRegistry {
    /// Consumes the registry and returns the gates' constraint expressions
    /// ordered by gate index.
    ///
    /// # Panics
    ///
    /// Panics with "unexpected index" unless the registered indices are
    /// exactly `0..n`, so that position `i` of the result belongs to gate `i`.
    fn expressions_sorted(self) -> Vec<Constraints<Exp<usize>>> {
        let mut indexed = self
            .gate_registry
            .into_values()
            .map(|(index, constraints, _name)| (index, constraints))
            .collect::<Vec<_>>();
        indexed.sort_by_key(|entry| entry.0);
        // After sorting, every gate must sit at the position named by its index.
        indexed
            .iter()
            .enumerate()
            .for_each(|(position, (index, _))| {
                assert_eq!(position, *index, "unexpected index");
            });
        indexed
            .into_iter()
            .map(|(_, constraints)| constraints)
            .collect()
    }
}