1use crate::ProgramCell;
21use alloc::vec;
22use alloc::vec::Vec;
23use hashbrown::HashMap;
24use hekate_core::errors::Error;
25use hekate_math::{Flat, HardwareField, TowerField};
26
27pub mod builder;
28
29#[derive(Clone, Debug)]
33pub struct ConstraintTerm<F> {
34 pub coeff: F,
35 pub poly_ind: Vec<ProgramCell>, }
37
38impl<F: TowerField> ConstraintTerm<F> {
39 pub fn new(coeff: F, cells: Vec<ProgramCell>) -> Self {
40 Self {
41 coeff,
42 poly_ind: cells,
43 }
44 }
45}
46
47#[derive(Clone, Debug)]
50pub struct Constraint<F> {
51 pub terms: Vec<ConstraintTerm<F>>,
52}
53
54impl<F: TowerField> Constraint<F> {
55 pub fn new(terms: Vec<ConstraintTerm<F>>) -> Self {
56 Self { terms }
57 }
58}
59
/// Where the expected value of a boundary constraint comes from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BoundaryTarget<F> {
    /// Index into the program instance's public inputs, resolved lazily.
    PublicInput(usize),
    /// A fixed field element known when the constraint is built.
    Constant(F),
}
73
74#[derive(Clone, Debug)]
77pub struct BoundaryConstraint<F> {
78 pub col_idx: usize,
79 pub row_idx: usize,
80 pub target: BoundaryTarget<F>,
81}
82
83impl<F> BoundaryConstraint<F> {
84 pub fn with_public_input(col_idx: usize, row_idx: usize, public_input_idx: usize) -> Self {
85 Self {
86 col_idx,
87 row_idx,
88 target: BoundaryTarget::PublicInput(public_input_idx),
89 }
90 }
91
92 pub fn with_constant(col_idx: usize, row_idx: usize, val: F) -> Self {
93 Self {
94 col_idx,
95 row_idx,
96 target: BoundaryTarget::Constant(val),
97 }
98 }
99}
100
101impl<F: TowerField> BoundaryConstraint<F> {
102 pub fn resolve_target(
107 &self,
108 instance: &crate::ProgramInstance<F>,
109 ) -> hekate_core::errors::Result<F> {
110 match &self.target {
111 BoundaryTarget::Constant(v) => Ok(*v),
112 BoundaryTarget::PublicInput(idx) => {
113 instance.public_input(*idx).ok_or(Error::Protocol {
114 protocol: "boundary",
115 message: "public_input_idx out of bounds",
116 })
117 }
118 }
119 }
120
121 pub fn absorb_into<H: hekate_crypto::Hasher>(
126 &self,
127 transcript: &mut hekate_crypto::transcript::Transcript<H>,
128 ) {
129 transcript.append_u64(b"chiplet_bnd_col", self.col_idx as u64);
130 transcript.append_u64(b"chiplet_bnd_row", self.row_idx as u64);
131
132 match &self.target {
133 BoundaryTarget::PublicInput(idx) => {
134 transcript.append_u64(b"chiplet_bnd_kind", 0);
135 transcript.append_u64(b"chiplet_bnd_pub", *idx as u64);
136 }
137 BoundaryTarget::Constant(v) => {
138 transcript.append_u64(b"chiplet_bnd_kind", 1);
139 transcript.append_field(b"chiplet_bnd_const", *v);
140 }
141 }
142 }
143}
144
/// Index of a node inside a `ConstraintArena`.
///
/// Ids are dense and handed out in allocation order, so the wrapped `u32`
/// doubles as a position in the arena's node vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ExprId(pub u32);
152
153#[derive(Clone, Debug)]
156pub enum ConstraintExpr<F> {
157 Cell(ProgramCell),
159
160 Const(F),
162
163 Add(ExprId, ExprId),
166
167 Mul(ExprId, ExprId),
169
170 Scale(F, ExprId),
176
177 Sum(Vec<ExprId>),
182}
183
184pub struct ConstraintArena<F> {
195 nodes: Vec<ConstraintExpr<F>>,
196
197 cell_cache: HashMap<ProgramCell, ExprId>,
200}
201
202impl<F: TowerField> Default for ConstraintArena<F> {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl<F: TowerField> ConstraintArena<F> {
209 pub fn new() -> Self {
210 Self {
211 nodes: Vec::new(),
212 cell_cache: HashMap::new(),
213 }
214 }
215
216 pub fn alloc(&mut self, expr: ConstraintExpr<F>) -> ExprId {
219 let id = ExprId(self.nodes.len() as u32);
220 self.nodes.push(expr);
221
222 id
223 }
224
225 pub fn get(&self, id: ExprId) -> &ConstraintExpr<F> {
227 &self.nodes[id.0 as usize]
228 }
229
230 pub fn len(&self) -> usize {
232 self.nodes.len()
233 }
234
235 pub fn is_empty(&self) -> bool {
236 self.nodes.is_empty()
237 }
238
239 pub fn shift_cells(&mut self, offset: usize) {
244 for node in &mut self.nodes {
245 if let ConstraintExpr::Cell(cell) = node {
246 cell.col_idx += offset;
247 }
248 }
249
250 let old_cache = core::mem::take(&mut self.cell_cache);
251 for (mut cell, id) in old_cache {
252 cell.col_idx += offset;
253 self.cell_cache.insert(cell, id);
254 }
255 }
256
257 pub fn cell(&mut self, cell: ProgramCell) -> ExprId {
261 if let Some(&id) = self.cell_cache.get(&cell) {
262 return id;
263 }
264
265 let id = self.alloc(ConstraintExpr::Cell(cell));
266 self.cell_cache.insert(cell, id);
267
268 id
269 }
270
271 pub fn constant(&mut self, val: F) -> ExprId {
273 self.alloc(ConstraintExpr::Const(val))
274 }
275
276 pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
278 self.alloc(ConstraintExpr::Add(a, b))
279 }
280
281 pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
283 self.alloc(ConstraintExpr::Mul(a, b))
284 }
285
286 pub fn scale(&mut self, coeff: F, a: ExprId) -> ExprId {
288 self.alloc(ConstraintExpr::Scale(coeff, a))
289 }
290
291 pub fn sum(&mut self, children: Vec<ExprId>) -> ExprId {
293 self.alloc(ConstraintExpr::Sum(children))
294 }
295}
296
297pub struct ConstraintAst<F> {
301 pub arena: ConstraintArena<F>,
302 pub roots: Vec<ExprId>,
303 pub labels: Vec<Option<&'static str>>,
304}
305
306impl<F: TowerField> Clone for ConstraintArena<F> {
307 fn clone(&self) -> Self {
308 Self {
309 nodes: self.nodes.clone(),
310 cell_cache: self.cell_cache.clone(),
311 }
312 }
313}
314
315impl<F: TowerField> Clone for ConstraintAst<F> {
316 fn clone(&self) -> Self {
317 Self {
318 arena: self.arena.clone(),
319 roots: self.roots.clone(),
320 labels: self.labels.clone(),
321 }
322 }
323}
324
325impl<F: TowerField> ConstraintAst<F> {
326 pub fn max_degree(&self) -> usize {
329 if self.arena.is_empty() {
330 return 0;
331 }
332
333 let n = self.arena.len();
334 let mut deg: Vec<usize> = Vec::with_capacity(n);
335
336 for i in 0..n {
337 let d = match self.arena.get(ExprId(i as u32)) {
338 ConstraintExpr::Cell(_) => 1,
339 ConstraintExpr::Const(_) => 0,
340 ConstraintExpr::Add(a, b) => deg[a.0 as usize].max(deg[b.0 as usize]),
341 ConstraintExpr::Mul(a, b) => deg[a.0 as usize] + deg[b.0 as usize],
342 ConstraintExpr::Scale(_, a) => deg[a.0 as usize],
343 ConstraintExpr::Sum(children) => children
344 .iter()
345 .map(|c| deg[c.0 as usize])
346 .max()
347 .unwrap_or(0),
348 };
349
350 deg.push(d);
351 }
352
353 self.roots
354 .iter()
355 .map(|r| deg[r.0 as usize])
356 .max()
357 .unwrap_or(0)
358 }
359
360 pub fn precompute_hardware_consts(&self) -> Vec<Flat<F>>
363 where
364 F: HardwareField,
365 {
366 let n = self.arena.len();
367
368 let mut consts: Vec<Flat<F>> = Vec::with_capacity(n);
369 for i in 0..n {
370 let c = match self.arena.get(ExprId(i as u32)) {
371 ConstraintExpr::Const(k) => k.to_hardware(),
372 ConstraintExpr::Scale(k, _) => k.to_hardware(),
373 _ => Flat::from_raw(F::ZERO),
374 };
375
376 consts.push(c);
377 }
378
379 consts
380 }
381
382 pub fn evaluate(&self, current_row: &[Flat<F>], next_row: &[Flat<F>]) -> Vec<Flat<F>>
385 where
386 F: HardwareField,
387 {
388 let consts = self.precompute_hardware_consts();
389 let mut buf: Vec<Flat<F>> = Vec::with_capacity(self.arena.len());
390
391 self.evaluate_into(&consts, current_row, next_row, &mut buf);
392
393 self.roots.iter().map(|r| buf[r.0 as usize]).collect()
394 }
395
396 pub fn evaluate_into(
399 &self,
400 consts: &[Flat<F>],
401 current_row: &[Flat<F>],
402 next_row: &[Flat<F>],
403 buf: &mut Vec<Flat<F>>,
404 ) where
405 F: HardwareField,
406 {
407 buf.clear();
408
409 for (i, &coeff) in consts.iter().enumerate() {
410 let v = match self.arena.get(ExprId(i as u32)) {
411 ConstraintExpr::Cell(cell) => {
412 if cell.next_row {
413 next_row[cell.col_idx]
414 } else {
415 current_row[cell.col_idx]
416 }
417 }
418 ConstraintExpr::Const(_) => coeff,
419 ConstraintExpr::Add(a, b) => buf[a.0 as usize] + buf[b.0 as usize],
420 ConstraintExpr::Mul(a, b) => buf[a.0 as usize] * buf[b.0 as usize],
421 ConstraintExpr::Scale(_, a) => coeff * buf[a.0 as usize],
422 ConstraintExpr::Sum(children) => {
423 let mut s = Flat::from_raw(F::ZERO);
424 for c in children {
425 s += buf[c.0 as usize];
426 }
427
428 s
429 }
430 };
431
432 buf.push(v);
433 }
434 }
435
436 pub fn merge(&mut self, other: ConstraintAst<F>) {
438 let mut id_map: Vec<ExprId> = Vec::with_capacity(other.arena.len());
439 for node in other.arena.nodes {
440 let new_id = match node {
441 ConstraintExpr::Cell(cell) => self.arena.cell(cell),
442 ConstraintExpr::Const(val) => self.arena.constant(val),
443 ConstraintExpr::Add(a, b) => {
444 self.arena.add(id_map[a.0 as usize], id_map[b.0 as usize])
445 }
446 ConstraintExpr::Mul(a, b) => {
447 self.arena.mul(id_map[a.0 as usize], id_map[b.0 as usize])
448 }
449 ConstraintExpr::Scale(coeff, inner) => {
450 self.arena.scale(coeff, id_map[inner.0 as usize])
451 }
452 ConstraintExpr::Sum(children) => {
453 let remapped: Vec<ExprId> =
454 children.into_iter().map(|c| id_map[c.0 as usize]).collect();
455 self.arena.sum(remapped)
456 }
457 };
458
459 id_map.push(new_id);
460 }
461
462 for (root, label) in other.roots.into_iter().zip(other.labels) {
463 self.roots.push(id_map[root.0 as usize]);
464 self.labels.push(label);
465 }
466 }
467
468 pub fn to_constraints(&self) -> Vec<Constraint<F>> {
471 type FlatTerm<F> = (F, Vec<ProgramCell>);
474
475 fn expand<F: TowerField>(
476 arena: &ConstraintArena<F>,
477 id: ExprId,
478 cache: &mut Vec<Option<Vec<FlatTerm<F>>>>,
479 ) -> Vec<FlatTerm<F>> {
480 if let Some(ref cached) = cache[id.0 as usize] {
481 return cached.clone();
482 }
483
484 let result = match arena.get(id) {
485 ConstraintExpr::Cell(cell) => {
486 vec![(F::ONE, vec![*cell])]
487 }
488 ConstraintExpr::Const(k) => {
489 vec![(*k, vec![])]
490 }
491 ConstraintExpr::Add(a, b) => {
492 let mut terms = expand(arena, *a, cache);
493 terms.extend(expand(arena, *b, cache));
494
495 terms
496 }
497 ConstraintExpr::Mul(a, b) => {
498 let left = expand(arena, *a, cache);
499 let right = expand(arena, *b, cache);
500
501 let mut terms = Vec::with_capacity(left.len() * right.len());
502 for (lc, lp) in &left {
503 for (rc, rp) in &right {
504 let coeff = *lc * *rc;
505 let mut cells = lp.clone();
506
507 cells.extend_from_slice(rp);
508 terms.push((coeff, cells));
509 }
510 }
511
512 terms
513 }
514 ConstraintExpr::Scale(k, a) => {
515 let inner = expand(arena, *a, cache);
516 inner
517 .into_iter()
518 .map(|(c, cells)| (*k * c, cells))
519 .collect()
520 }
521 ConstraintExpr::Sum(children) => {
522 let mut terms = Vec::new();
523 for child in children {
524 terms.extend(expand(arena, *child, cache));
525 }
526
527 terms
528 }
529 };
530
531 cache[id.0 as usize] = Some(result.clone());
532
533 result
534 }
535
536 let n = self.arena.len();
537 let mut cache: Vec<Option<Vec<FlatTerm<F>>>> = vec![None; n];
538
539 self.roots
540 .iter()
541 .map(|root| {
542 let flat_terms = expand(&self.arena, *root, &mut cache);
543 Constraint::new(
544 flat_terms
545 .into_iter()
546 .map(|(coeff, cells)| ConstraintTerm::new(coeff, cells))
547 .collect(),
548 )
549 })
550 .collect()
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintExpr;
    use crate::constraint::builder::ConstraintSystem;
    use crate::{Air, Program};
    use hekate_core::trace::ColumnType;
    use hekate_math::{Block128, Flat};

    type F = Block128;

    /// Two-constraint Fibonacci-style program gated by selector column `q`.
    #[derive(Clone)]
    struct TestFibProgram;

    impl Air<F> for TestFibProgram {
        fn num_columns(&self) -> usize {
            3
        }

        fn column_layout(&self) -> &[ColumnType] {
            &[ColumnType::B32, ColumnType::B32, ColumnType::Bit]
        }

        fn constraint_ast(&self) -> ConstraintAst<F> {
            let cs = ConstraintSystem::<F>::new();

            let [a, b, q] = [cs.col(0), cs.col(1), cs.col(2)];
            let [na, nb] = [cs.next(0), cs.next(1)];

            cs.constrain(q * (na + b));
            cs.constrain(q * (nb + a + b));

            cs.build()
        }
    }

    impl Program<F> for TestFibProgram {}

    #[test]
    fn default_constraint_ast_produces_correct_roots() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        assert_eq!(ast.roots.len(), 2);

        assert!(!ast.arena.is_empty());

        for &root in &ast.roots {
            let node = ast.arena.get(root);
            match node {
                ConstraintExpr::Sum(children) => {
                    assert!(!children.is_empty());
                }
                ConstraintExpr::Mul(_, _) => {}
                _ => panic!("Root should be Sum or Mul, got {:?}", node),
            }
        }
    }

    #[test]
    fn cell_dedup_works() {
        let mut arena = ConstraintArena::<F>::new();

        let cell_a = ProgramCell::current(0);
        let cell_b = ProgramCell::current(0);
        let cell_c = ProgramCell::next(0);

        let id_a = arena.cell(cell_a);
        let id_b = arena.cell(cell_b);
        let id_c = arena.cell(cell_c);

        // Equal cells share one node; current vs next row are distinct.
        assert_eq!(id_a, id_b);
        assert_ne!(id_a, id_c);
        assert_eq!(arena.len(), 2);
    }

    #[test]
    fn dag_sharing_reduces_node_count() {
        let mut arena = ConstraintArena::<F>::new();

        let a = arena.cell(ProgramCell::current(0));
        let b = arena.cell(ProgramCell::current(1));
        let c = arena.cell(ProgramCell::current(2));
        let theta = arena.sum(vec![a, b, c]);

        let d = arena.cell(ProgramCell::current(3));

        // `theta` is shared by both products instead of being rebuilt.
        let expr1 = arena.mul(theta, d);
        let expr2 = arena.mul(theta, a);

        let dag_node_count = arena.len();
        assert_eq!(dag_node_count, 7);

        assert!(dag_node_count < 10);

        match arena.get(expr1) {
            ConstraintExpr::Mul(lhs, _) => assert_eq!(*lhs, theta),
            _ => panic!("Expected Mul"),
        }

        match arena.get(expr2) {
            ConstraintExpr::Mul(lhs, rhs) => {
                assert_eq!(*lhs, theta);
                assert_eq!(*rhs, a);
            }
            _ => panic!("Expected Mul"),
        }
    }

    #[test]
    fn default_constraint_ast_node_count_matches_flat() {
        let program = TestFibProgram;
        let flat = program.constraints();
        let ast = program.constraint_ast();

        let mut flat_cell_count = 0;
        for c in &flat {
            for t in &c.terms {
                flat_cell_count += t.poly_ind.len();
            }
        }

        let ast_cell_count = ast
            .arena
            .nodes
            .iter()
            .filter(|n| matches!(n, ConstraintExpr::Cell(_)))
            .count();

        assert!(ast_cell_count <= flat_cell_count);

        assert_eq!(ast_cell_count, 5);
    }

    #[test]
    fn empty_constraint_produces_empty_ast() {
        #[derive(Clone)]
        struct EmptyProgram;

        impl Air<F> for EmptyProgram {
            fn num_columns(&self) -> usize {
                0
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                ConstraintSystem::<F>::new().build()
            }
        }

        impl Program<F> for EmptyProgram {}

        let ast = EmptyProgram.constraint_ast();
        assert!(ast.roots.is_empty());
        assert!(ast.arena.is_empty());
    }

    #[test]
    fn single_term_constraint_no_sum_wrapper() {
        #[derive(Clone)]
        struct SingleTermProgram;

        impl Air<F> for SingleTermProgram {
            fn num_columns(&self) -> usize {
                2
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[ColumnType::B128, ColumnType::B128]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                let cs = ConstraintSystem::<F>::new();
                cs.constrain(cs.col(0) * cs.col(1));

                cs.build()
            }
        }

        impl Program<F> for SingleTermProgram {}

        let ast = SingleTermProgram.constraint_ast();
        assert_eq!(ast.roots.len(), 1);

        match ast.arena.get(ast.roots[0]) {
            ConstraintExpr::Mul(_, _) => {}
            other => panic!("Expected Mul for single-term, got {:?}", other),
        }
    }

    #[test]
    fn max_degree_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        let flat = program.constraints();
        let flat_max = flat
            .iter()
            .flat_map(|c| c.terms.iter())
            .map(|t| t.poly_ind.len())
            .max()
            .unwrap_or(0);

        assert_eq!(ast.max_degree(), flat_max);
    }

    #[test]
    fn max_degree_empty() {
        let ast = ConstraintAst::<F> {
            arena: ConstraintArena::new(),
            roots: Vec::new(),
            labels: Vec::new(),
        };
        assert_eq!(ast.max_degree(), 0);
    }

    #[test]
    fn max_degree_builder_mul_chain() {
        use crate::constraint::builder::ConstraintSystem;

        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        let c = cs.col(2);

        cs.constrain(a * b * c);
        cs.constrain(a + b);

        let ast = cs.build();
        assert_eq!(ast.max_degree(), 3);
    }

    #[test]
    fn to_constraints_roundtrip_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();
        let flat_from_ast = ast.to_constraints();
        let flat_direct = program.constraints();

        assert_eq!(flat_from_ast.len(), flat_direct.len());

        for (a, d) in flat_from_ast.iter().zip(flat_direct.iter()) {
            assert_eq!(a.terms.len(), d.terms.len());
        }
    }

    #[test]
    fn to_constraints_from_builder() {
        use crate::constraint::builder::ConstraintSystem;

        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        cs.constrain(a + b);

        let ast = cs.build();
        let flat = ast.to_constraints();

        assert_eq!(flat.len(), 1);
        assert_eq!(flat[0].terms.len(), 2);

        for term in &flat[0].terms {
            assert_eq!(term.coeff, F::ONE);
            assert_eq!(term.poly_ind.len(), 1);
        }
    }

    #[test]
    fn evaluate_simple_constraint() {
        use crate::constraint::builder::ConstraintSystem;
        use hekate_math::Flat;

        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        cs.constrain(a + b);

        let ast = cs.build();

        // This test relies on x + x == 0 in this field.
        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(3u128)),
        ];
        let next = vec![Flat::from_raw(F::ZERO); 2];

        let evals = ast.evaluate(&current, &next);
        assert_eq!(evals.len(), 1);
        assert_eq!(evals[0], Flat::from_raw(F::ZERO));

        let current2 = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let evals2 = ast.evaluate(&current2, &next);
        assert_ne!(evals2[0], Flat::from_raw(F::ZERO));
    }

    #[test]
    fn evaluate_into_matches_evaluate() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);
        let na = cs.next(0);

        cs.constrain(a + b);
        cs.constrain(a * b);
        cs.constrain(na + a);

        let ast = cs.build();
        let zero = Flat::from_raw(F::ZERO);

        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let next = vec![Flat::from_raw(F::from(7u128)), zero];

        let expected = ast.evaluate(&current, &next);

        let consts = ast.precompute_hardware_consts();

        let mut buf = Vec::new();
        ast.evaluate_into(&consts, &current, &next, &mut buf);

        for (i, root) in ast.roots.iter().enumerate() {
            assert_eq!(buf[root.0 as usize], expected[i]);
        }
    }
}