use crate::ProgramCell;
use alloc::vec;
use alloc::vec::Vec;
use hashbrown::HashMap;
use hekate_core::errors::Error;
use hekate_math::{Flat, HardwareField, TowerField};
pub mod builder;
/// One monomial of a flat constraint: `coeff` times the product of the cells in `poly_ind`.
#[derive(Clone, Debug)]
pub struct ConstraintTerm<F> {
    /// Scalar coefficient of the monomial.
    pub coeff: F,
    /// Cells multiplied together in this monomial (empty for a pure constant term).
    pub poly_ind: Vec<ProgramCell>,
}

impl<F: TowerField> ConstraintTerm<F> {
    /// Builds a term from its coefficient and the cells forming its product.
    pub fn new(coeff: F, cells: Vec<ProgramCell>) -> Self {
        let poly_ind = cells;
        Self { coeff, poly_ind }
    }
}
/// A flat constraint polynomial, expressed as the sum of its `terms`.
#[derive(Clone, Debug)]
pub struct Constraint<F> {
    /// Monomials whose sum forms the constraint polynomial.
    pub terms: Vec<ConstraintTerm<F>>,
}

impl<F: TowerField> Constraint<F> {
    /// Wraps an already-built list of monomials.
    pub fn new(terms: Vec<ConstraintTerm<F>>) -> Self {
        Constraint { terms }
    }
}
/// Source of the expected value for a boundary constraint.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BoundaryTarget<F> {
/// Index into the program instance's public inputs; resolved by `resolve_target`.
PublicInput(usize),
/// A fixed field constant.
Constant(F),
}
/// Associates the trace position (`col_idx`, `row_idx`) with a target value.
///
/// NOTE(review): presumably the trace value at that position must equal the resolved
/// target; confirm against the code that checks boundary constraints.
#[derive(Clone, Debug)]
pub struct BoundaryConstraint<F> {
/// Column index of the constrained cell.
pub col_idx: usize,
/// Row index of the constrained cell.
pub row_idx: usize,
/// Where the expected value comes from.
pub target: BoundaryTarget<F>,
}
impl<F> BoundaryConstraint<F> {
pub fn with_public_input(col_idx: usize, row_idx: usize, public_input_idx: usize) -> Self {
Self {
col_idx,
row_idx,
target: BoundaryTarget::PublicInput(public_input_idx),
}
}
pub fn with_constant(col_idx: usize, row_idx: usize, val: F) -> Self {
Self {
col_idx,
row_idx,
target: BoundaryTarget::Constant(val),
}
}
}
impl<F: TowerField> BoundaryConstraint<F> {
/// Resolves this constraint's target to a concrete field value.
///
/// `Constant` targets are returned as-is; `PublicInput(idx)` targets are looked up via
/// `instance.public_input(idx)`.
///
/// # Errors
/// Returns `Error::Protocol` (protocol `"boundary"`) when `public_input` yields `None`
/// for the requested index.
pub fn resolve_target(
&self,
instance: &crate::ProgramInstance<F>,
) -> hekate_core::errors::Result<F> {
match &self.target {
BoundaryTarget::Constant(v) => Ok(*v),
BoundaryTarget::PublicInput(idx) => {
instance.public_input(*idx).ok_or(Error::Protocol {
protocol: "boundary",
message: "public_input_idx out of bounds",
})
}
}
}
/// Absorbs this boundary constraint into `transcript`.
///
/// Appends, in this order: the column index, the row index, a kind tag
/// (0 = public-input target, 1 = constant target), and then the payload (the
/// public-input index as a `u64`, or the constant as a field element).
///
/// NOTE(review): the order and labels presumably must match whatever re-absorbs
/// boundary constraints on the other side of the protocol; changing either would
/// likely change any values derived from the transcript.
pub fn absorb_into<H: hekate_crypto::Hasher>(
&self,
transcript: &mut hekate_crypto::transcript::Transcript<H>,
) {
transcript.append_u64(b"chiplet_bnd_col", self.col_idx as u64);
transcript.append_u64(b"chiplet_bnd_row", self.row_idx as u64);
// The kind tag distinguishes the two payload encodings below.
match &self.target {
BoundaryTarget::PublicInput(idx) => {
transcript.append_u64(b"chiplet_bnd_kind", 0);
transcript.append_u64(b"chiplet_bnd_pub", *idx as u64);
}
BoundaryTarget::Constant(v) => {
transcript.append_u64(b"chiplet_bnd_kind", 1);
transcript.append_field(b"chiplet_bnd_const", *v);
}
}
}
}
/// Handle to a node in a `ConstraintArena`: the node's index in allocation order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ExprId(pub u32);
/// A node of the constraint expression DAG stored in a `ConstraintArena`.
///
/// Child `ExprId`s must refer to nodes allocated earlier in the same arena: the
/// single-pass evaluators (`max_degree`, `evaluate`) index already-computed results
/// by child id and rely on that ordering.
#[derive(Clone, Debug)]
pub enum ConstraintExpr<F> {
/// Value of a trace cell, read from the current row or the next row.
Cell(ProgramCell),
/// A field constant.
Const(F),
/// Sum of two nodes.
Add(ExprId, ExprId),
/// Product of two nodes.
Mul(ExprId, ExprId),
/// A node multiplied by a constant coefficient.
Scale(F, ExprId),
/// Sum of any number of nodes (evaluates to zero when empty).
Sum(Vec<ExprId>),
}
/// Storage for `ConstraintExpr` nodes, addressed by `ExprId`.
///
/// Nodes are only ever appended; `shift_cells` is the only method that mutates existing
/// nodes in place. `Cell` nodes created through `cell` are deduplicated, so one
/// `ProgramCell` maps to a single node.
pub struct ConstraintArena<F> {
// Nodes in allocation order; `ExprId(i)` refers to `nodes[i]`.
nodes: Vec<ConstraintExpr<F>>,
// Node id of every cell allocated through `cell`, used for deduplication.
cell_cache: HashMap<ProgramCell, ExprId>,
}
impl<F: TowerField> Default for ConstraintArena<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: TowerField> ConstraintArena<F> {
    /// Creates an empty arena with an empty cell cache.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            cell_cache: HashMap::new(),
        }
    }

    /// Appends `expr` and returns its id, which is its index in allocation order.
    ///
    /// No deduplication is performed here; use [`ConstraintArena::cell`] for cells.
    ///
    /// # Panics
    /// Panics if the arena already holds more than `u32::MAX` nodes. Ids are `u32`, and a
    /// silent truncating cast would hand out ids that alias earlier nodes.
    pub fn alloc(&mut self, expr: ConstraintExpr<F>) -> ExprId {
        let idx = u32::try_from(self.nodes.len())
            .expect("ConstraintArena: node count exceeds u32::MAX");
        self.nodes.push(expr);
        ExprId(idx)
    }

    /// Returns the node for `id`.
    ///
    /// # Panics
    /// Panics if `id` was not allocated by this arena.
    pub fn get(&self, id: ExprId) -> &ConstraintExpr<F> {
        &self.nodes[id.0 as usize]
    }

    /// Number of allocated nodes.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Whether no node has been allocated yet.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Adds `offset` to the column index of every `Cell` node and of every cached cell.
    ///
    /// The cache is rebuilt so that `cell` keeps returning the existing node for the
    /// shifted cell.
    pub fn shift_cells(&mut self, offset: usize) {
        for node in &mut self.nodes {
            if let ConstraintExpr::Cell(cell) = node {
                cell.col_idx += offset;
            }
        }
        // Keys are hashed by value, so every entry must be re-inserted under its new key.
        let old_cache = core::mem::take(&mut self.cell_cache);
        for (mut cell, id) in old_cache {
            cell.col_idx += offset;
            self.cell_cache.insert(cell, id);
        }
    }

    /// Returns the node for `cell`, allocating it only the first time this cell is seen.
    pub fn cell(&mut self, cell: ProgramCell) -> ExprId {
        if let Some(&id) = self.cell_cache.get(&cell) {
            return id;
        }
        let id = self.alloc(ConstraintExpr::Cell(cell));
        self.cell_cache.insert(cell, id);
        id
    }

    /// Allocates a constant node. Constants are not deduplicated.
    pub fn constant(&mut self, val: F) -> ExprId {
        self.alloc(ConstraintExpr::Const(val))
    }

    /// Allocates `a + b`.
    pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
        self.alloc(ConstraintExpr::Add(a, b))
    }

    /// Allocates `a * b`.
    pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
        self.alloc(ConstraintExpr::Mul(a, b))
    }

    /// Allocates `coeff * a`.
    pub fn scale(&mut self, coeff: F, a: ExprId) -> ExprId {
        self.alloc(ConstraintExpr::Scale(coeff, a))
    }

    /// Allocates the sum of `children`; an empty list evaluates to zero.
    pub fn sum(&mut self, children: Vec<ExprId>) -> ExprId {
        self.alloc(ConstraintExpr::Sum(children))
    }
}
/// A list of constraints whose expression trees share one `ConstraintArena`.
pub struct ConstraintAst<F> {
/// Backing storage for every node reachable from `roots`.
pub arena: ConstraintArena<F>,
/// One root node per constraint; `to_constraints` yields one `Constraint` per root.
pub roots: Vec<ExprId>,
/// Optional label per root. It is meant to run parallel to `roots`.
/// NOTE(review): the equal length is not enforced by the type.
pub labels: Vec<Option<&'static str>>,
}
impl<F: TowerField> Clone for ConstraintArena<F> {
fn clone(&self) -> Self {
Self {
nodes: self.nodes.clone(),
cell_cache: self.cell_cache.clone(),
}
}
}
impl<F: TowerField> Clone for ConstraintAst<F> {
fn clone(&self) -> Self {
Self {
arena: self.arena.clone(),
roots: self.roots.clone(),
labels: self.labels.clone(),
}
}
}
impl<F: TowerField> ConstraintAst<F> {
pub fn max_degree(&self) -> usize {
if self.arena.is_empty() {
return 0;
}
let n = self.arena.len();
let mut deg: Vec<usize> = Vec::with_capacity(n);
for i in 0..n {
let d = match self.arena.get(ExprId(i as u32)) {
ConstraintExpr::Cell(_) => 1,
ConstraintExpr::Const(_) => 0,
ConstraintExpr::Add(a, b) => deg[a.0 as usize].max(deg[b.0 as usize]),
ConstraintExpr::Mul(a, b) => deg[a.0 as usize] + deg[b.0 as usize],
ConstraintExpr::Scale(_, a) => deg[a.0 as usize],
ConstraintExpr::Sum(children) => children
.iter()
.map(|c| deg[c.0 as usize])
.max()
.unwrap_or(0),
};
deg.push(d);
}
self.roots
.iter()
.map(|r| deg[r.0 as usize])
.max()
.unwrap_or(0)
}
pub fn evaluate(&self, current_row: &[Flat<F>], next_row: &[Flat<F>]) -> Vec<Flat<F>>
where
F: HardwareField,
{
let n = self.arena.len();
let mut val: Vec<Flat<F>> = Vec::with_capacity(n);
for i in 0..n {
let v = match self.arena.get(ExprId(i as u32)) {
ConstraintExpr::Cell(cell) => {
if cell.next_row {
next_row[cell.col_idx]
} else {
current_row[cell.col_idx]
}
}
ConstraintExpr::Const(k) => k.to_hardware(),
ConstraintExpr::Add(a, b) => val[a.0 as usize] + val[b.0 as usize],
ConstraintExpr::Mul(a, b) => val[a.0 as usize] * val[b.0 as usize],
ConstraintExpr::Scale(k, a) => k.to_hardware() * val[a.0 as usize],
ConstraintExpr::Sum(children) => {
let mut s = Flat::from_raw(F::ZERO);
for c in children {
s += val[c.0 as usize];
}
s
}
};
val.push(v);
}
self.roots.iter().map(|r| val[r.0 as usize]).collect()
}
pub fn evaluate_into(
&self,
current_row: &[Flat<F>],
next_row: &[Flat<F>],
buf: &mut Vec<Flat<F>>,
) where
F: HardwareField,
{
buf.clear();
let n = self.arena.len();
for i in 0..n {
let v = match self.arena.get(ExprId(i as u32)) {
ConstraintExpr::Cell(cell) => {
if cell.next_row {
next_row[cell.col_idx]
} else {
current_row[cell.col_idx]
}
}
ConstraintExpr::Const(k) => k.to_hardware(),
ConstraintExpr::Add(a, b) => buf[a.0 as usize] + buf[b.0 as usize],
ConstraintExpr::Mul(a, b) => buf[a.0 as usize] * buf[b.0 as usize],
ConstraintExpr::Scale(k, a) => k.to_hardware() * buf[a.0 as usize],
ConstraintExpr::Sum(children) => {
let mut s = Flat::from_raw(F::ZERO);
for c in children {
s += buf[c.0 as usize];
}
s
}
};
buf.push(v);
}
}
pub fn merge(&mut self, other: ConstraintAst<F>) {
let mut id_map: Vec<ExprId> = Vec::with_capacity(other.arena.len());
for node in other.arena.nodes {
let new_id = match node {
ConstraintExpr::Cell(cell) => self.arena.cell(cell),
ConstraintExpr::Const(val) => self.arena.constant(val),
ConstraintExpr::Add(a, b) => {
self.arena.add(id_map[a.0 as usize], id_map[b.0 as usize])
}
ConstraintExpr::Mul(a, b) => {
self.arena.mul(id_map[a.0 as usize], id_map[b.0 as usize])
}
ConstraintExpr::Scale(coeff, inner) => {
self.arena.scale(coeff, id_map[inner.0 as usize])
}
ConstraintExpr::Sum(children) => {
let remapped: Vec<ExprId> =
children.into_iter().map(|c| id_map[c.0 as usize]).collect();
self.arena.sum(remapped)
}
};
id_map.push(new_id);
}
for (root, label) in other.roots.into_iter().zip(other.labels) {
self.roots.push(id_map[root.0 as usize]);
self.labels.push(label);
}
}
pub fn to_constraints(&self) -> Vec<Constraint<F>> {
type FlatTerm<F> = (F, Vec<ProgramCell>);
fn expand<F: TowerField>(
arena: &ConstraintArena<F>,
id: ExprId,
cache: &mut Vec<Option<Vec<FlatTerm<F>>>>,
) -> Vec<FlatTerm<F>> {
if let Some(ref cached) = cache[id.0 as usize] {
return cached.clone();
}
let result = match arena.get(id) {
ConstraintExpr::Cell(cell) => {
vec![(F::ONE, vec![*cell])]
}
ConstraintExpr::Const(k) => {
vec![(*k, vec![])]
}
ConstraintExpr::Add(a, b) => {
let mut terms = expand(arena, *a, cache);
terms.extend(expand(arena, *b, cache));
terms
}
ConstraintExpr::Mul(a, b) => {
let left = expand(arena, *a, cache);
let right = expand(arena, *b, cache);
let mut terms = Vec::with_capacity(left.len() * right.len());
for (lc, lp) in &left {
for (rc, rp) in &right {
let coeff = *lc * *rc;
let mut cells = lp.clone();
cells.extend_from_slice(rp);
terms.push((coeff, cells));
}
}
terms
}
ConstraintExpr::Scale(k, a) => {
let inner = expand(arena, *a, cache);
inner
.into_iter()
.map(|(c, cells)| (*k * c, cells))
.collect()
}
ConstraintExpr::Sum(children) => {
let mut terms = Vec::new();
for child in children {
terms.extend(expand(arena, *child, cache));
}
terms
}
};
cache[id.0 as usize] = Some(result.clone());
result
}
let n = self.arena.len();
let mut cache: Vec<Option<Vec<FlatTerm<F>>>> = vec![None; n];
self.roots
.iter()
.map(|root| {
let flat_terms = expand(&self.arena, *root, &mut cache);
Constraint::new(
flat_terms
.into_iter()
.map(|(coeff, cells)| ConstraintTerm::new(coeff, cells))
.collect(),
)
})
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintExpr;
    use crate::constraint::builder::ConstraintSystem;
    use crate::{Air, Program};
    use hekate_core::trace::ColumnType;
    use hekate_math::{Block128, Flat};

    type F = Block128;

    /// Fibonacci-style AIR: two B32 columns gated by a selector bit.
    #[derive(Clone)]
    struct TestFibProgram;

    impl Air<F> for TestFibProgram {
        fn num_columns(&self) -> usize {
            3
        }

        fn column_layout(&self) -> &[ColumnType] {
            &[ColumnType::B32, ColumnType::B32, ColumnType::Bit]
        }

        fn constraint_ast(&self) -> ConstraintAst<F> {
            let cs = ConstraintSystem::<F>::new();
            let [a, b, q] = [cs.col(0), cs.col(1), cs.col(2)];
            let [na, nb] = [cs.next(0), cs.next(1)];
            cs.constrain(q * (na + b));
            cs.constrain(q * (nb + a + b));
            cs.build()
        }
    }

    impl Program<F> for TestFibProgram {}

    #[test]
    fn default_constraint_ast_produces_correct_roots() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();
        assert_eq!(ast.roots.len(), 2);
        assert!(!ast.arena.is_empty());
        for &root in &ast.roots {
            let node = ast.arena.get(root);
            match node {
                ConstraintExpr::Sum(children) => {
                    assert!(!children.is_empty());
                }
                ConstraintExpr::Mul(_, _) => {}
                _ => panic!("Root should be Sum or Mul, got {:?}", node),
            }
        }
    }

    #[test]
    fn cell_dedup_works() {
        let mut arena = ConstraintArena::<F>::new();
        let cell_a = ProgramCell::current(0);
        let cell_b = ProgramCell::current(0);
        let cell_c = ProgramCell::next(0);
        let id_a = arena.cell(cell_a);
        let id_b = arena.cell(cell_b);
        let id_c = arena.cell(cell_c);
        assert_eq!(id_a, id_b);
        assert_ne!(id_a, id_c);
        assert_eq!(arena.len(), 2);
    }

    #[test]
    fn dag_sharing_reduces_node_count() {
        let mut arena = ConstraintArena::<F>::new();
        let a = arena.cell(ProgramCell::current(0));
        let b = arena.cell(ProgramCell::current(1));
        let c = arena.cell(ProgramCell::current(2));
        let theta = arena.sum(vec![a, b, c]);
        let d = arena.cell(ProgramCell::current(3));
        let expr1 = arena.mul(theta, d);
        let expr2 = arena.mul(theta, a);
        let dag_node_count = arena.len();
        assert_eq!(dag_node_count, 7);
        assert!(dag_node_count < 10);
        match arena.get(expr1) {
            ConstraintExpr::Mul(lhs, _) => assert_eq!(*lhs, theta),
            _ => panic!("Expected Mul"),
        }
        match arena.get(expr2) {
            ConstraintExpr::Mul(lhs, rhs) => {
                assert_eq!(*lhs, theta);
                assert_eq!(*rhs, a);
            }
            _ => panic!("Expected Mul"),
        }
    }

    #[test]
    fn default_constraint_ast_node_count_matches_flat() {
        let program = TestFibProgram;
        let flat = program.constraints();
        let ast = program.constraint_ast();
        let mut flat_cell_count = 0;
        for c in &flat {
            for t in &c.terms {
                flat_cell_count += t.poly_ind.len();
            }
        }
        let ast_cell_count = ast
            .arena
            .nodes
            .iter()
            .filter(|n| matches!(n, ConstraintExpr::Cell(_)))
            .count();
        assert!(ast_cell_count <= flat_cell_count);
        assert_eq!(ast_cell_count, 5);
    }

    #[test]
    fn empty_constraint_produces_empty_ast() {
        #[derive(Clone)]
        struct EmptyProgram;

        impl Air<F> for EmptyProgram {
            fn num_columns(&self) -> usize {
                0
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                ConstraintSystem::<F>::new().build()
            }
        }

        impl Program<F> for EmptyProgram {}

        let ast = EmptyProgram.constraint_ast();
        assert!(ast.roots.is_empty());
        assert!(ast.arena.is_empty());
    }

    #[test]
    fn single_term_constraint_no_sum_wrapper() {
        #[derive(Clone)]
        struct SingleTermProgram;

        impl Air<F> for SingleTermProgram {
            fn num_columns(&self) -> usize {
                2
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[ColumnType::B128, ColumnType::B128]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                let cs = ConstraintSystem::<F>::new();
                cs.constrain(cs.col(0) * cs.col(1));
                cs.build()
            }
        }

        impl Program<F> for SingleTermProgram {}

        let ast = SingleTermProgram.constraint_ast();
        assert_eq!(ast.roots.len(), 1);
        match ast.arena.get(ast.roots[0]) {
            ConstraintExpr::Mul(_, _) => {}
            other => panic!("Expected Mul for single-term, got {:?}", other),
        }
    }

    #[test]
    fn max_degree_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();
        let flat = program.constraints();
        let flat_max = flat
            .iter()
            .flat_map(|c| c.terms.iter())
            .map(|t| t.poly_ind.len())
            .max()
            .unwrap_or(0);
        assert_eq!(ast.max_degree(), flat_max);
    }

    #[test]
    fn max_degree_empty() {
        let ast = ConstraintAst::<F> {
            arena: ConstraintArena::new(),
            roots: Vec::new(),
            labels: Vec::new(),
        };
        assert_eq!(ast.max_degree(), 0);
    }

    #[test]
    fn max_degree_builder_mul_chain() {
        use crate::constraint::builder::ConstraintSystem;
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        let c = cs.col(2);
        cs.constrain(a * b * c);
        cs.constrain(a + b);
        let ast = cs.build();
        assert_eq!(ast.max_degree(), 3);
    }

    #[test]
    fn to_constraints_roundtrip_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();
        let flat_from_ast = ast.to_constraints();
        let flat_direct = program.constraints();
        assert_eq!(flat_from_ast.len(), flat_direct.len());
        for (a, d) in flat_from_ast.iter().zip(flat_direct.iter()) {
            assert_eq!(a.terms.len(), d.terms.len());
        }
    }

    #[test]
    fn to_constraints_from_builder() {
        use crate::constraint::builder::ConstraintSystem;
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        cs.constrain(a + b);
        let ast = cs.build();
        let flat = ast.to_constraints();
        assert_eq!(flat.len(), 1);
        assert_eq!(flat[0].terms.len(), 2);
        for term in &flat[0].terms {
            assert_eq!(term.coeff, F::ONE);
            assert_eq!(term.poly_ind.len(), 1);
        }
    }

    #[test]
    fn evaluate_simple_constraint() {
        use crate::constraint::builder::ConstraintSystem;
        use hekate_math::Flat;
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        cs.constrain(a + b);
        let ast = cs.build();
        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(3u128)),
        ];
        let next = vec![Flat::from_raw(F::ZERO); 2];
        let evals = ast.evaluate(&current, &next);
        assert_eq!(evals.len(), 1);
        assert_eq!(evals[0], Flat::from_raw(F::ZERO));
        let current2 = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let evals2 = ast.evaluate(&current2, &next);
        assert_ne!(evals2[0], Flat::from_raw(F::ZERO));
    }

    #[test]
    fn evaluate_into_matches_evaluate() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        let na = cs.next(0);
        cs.constrain(a + b);
        cs.constrain(a * b);
        cs.constrain(na + a);
        let ast = cs.build();
        let zero = Flat::from_raw(F::ZERO);
        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let next = vec![Flat::from_raw(F::from(7u128)), zero];
        let expected = ast.evaluate(&current, &next);
        let mut buf = Vec::new();
        ast.evaluate_into(&current, &next, &mut buf);
        for (i, root) in ast.roots.iter().enumerate() {
            assert_eq!(buf[root.0 as usize], expected[i]);
        }
    }
}