1use crate::ProgramCell;
21use alloc::vec;
22use alloc::vec::Vec;
23use hashbrown::HashMap;
24use hekate_core::errors::Error;
25use hekate_math::{Flat, HardwareField, TowerField};
26
27pub mod builder;
28
29#[derive(Clone, Debug)]
33pub struct ConstraintTerm<F> {
34 pub coeff: F,
35 pub poly_ind: Vec<ProgramCell>, }
37
38impl<F: TowerField> ConstraintTerm<F> {
39 pub fn new(coeff: F, cells: Vec<ProgramCell>) -> Self {
40 Self {
41 coeff,
42 poly_ind: cells,
43 }
44 }
45}
46
47#[derive(Clone, Debug)]
50pub struct Constraint<F> {
51 pub terms: Vec<ConstraintTerm<F>>,
52}
53
54impl<F: TowerField> Constraint<F> {
55 pub fn new(terms: Vec<ConstraintTerm<F>>) -> Self {
56 Self { terms }
57 }
58}
59
/// Where the expected value of a boundary constraint comes from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BoundaryTarget<F> {
    /// The value is the program instance's public input at this index.
    PublicInput(usize),

    /// The value is this fixed field element.
    Constant(F),
}
73
74#[derive(Clone, Debug)]
77pub struct BoundaryConstraint<F> {
78 pub col_idx: usize,
79 pub row_idx: usize,
80 pub target: BoundaryTarget<F>,
81}
82
83impl<F> BoundaryConstraint<F> {
84 pub fn with_public_input(col_idx: usize, row_idx: usize, public_input_idx: usize) -> Self {
85 Self {
86 col_idx,
87 row_idx,
88 target: BoundaryTarget::PublicInput(public_input_idx),
89 }
90 }
91
92 pub fn with_constant(col_idx: usize, row_idx: usize, val: F) -> Self {
93 Self {
94 col_idx,
95 row_idx,
96 target: BoundaryTarget::Constant(val),
97 }
98 }
99}
100
101impl<F: TowerField> BoundaryConstraint<F> {
102 pub fn resolve_target(
107 &self,
108 instance: &crate::ProgramInstance<F>,
109 ) -> hekate_core::errors::Result<F> {
110 match &self.target {
111 BoundaryTarget::Constant(v) => Ok(*v),
112 BoundaryTarget::PublicInput(idx) => {
113 instance.public_input(*idx).ok_or(Error::Protocol {
114 protocol: "boundary",
115 message: "public_input_idx out of bounds",
116 })
117 }
118 }
119 }
120
121 pub fn absorb_into<H: hekate_crypto::Hasher>(
126 &self,
127 transcript: &mut hekate_crypto::transcript::Transcript<H>,
128 ) {
129 transcript.append_u64(b"chiplet_bnd_col", self.col_idx as u64);
130 transcript.append_u64(b"chiplet_bnd_row", self.row_idx as u64);
131
132 match &self.target {
133 BoundaryTarget::PublicInput(idx) => {
134 transcript.append_u64(b"chiplet_bnd_kind", 0);
135 transcript.append_u64(b"chiplet_bnd_pub", *idx as u64);
136 }
137 BoundaryTarget::Constant(v) => {
138 transcript.append_u64(b"chiplet_bnd_kind", 1);
139 transcript.append_field(b"chiplet_bnd_const", *v);
140 }
141 }
142 }
143}
144
/// Handle to a node stored in a `ConstraintArena`.
///
/// The wrapped `u32` is the node's position in the arena's allocation order.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ExprId(pub u32);
152
153#[derive(Clone, Debug)]
156pub enum ConstraintExpr<F> {
157 Cell(ProgramCell),
159
160 Const(F),
162
163 Add(ExprId, ExprId),
166
167 Mul(ExprId, ExprId),
169
170 Scale(F, ExprId),
176
177 Sum(Vec<ExprId>),
182}
183
184pub struct ConstraintArena<F> {
195 nodes: Vec<ConstraintExpr<F>>,
196
197 cell_cache: HashMap<ProgramCell, ExprId>,
200}
201
202impl<F: TowerField> Default for ConstraintArena<F> {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl<F: TowerField> ConstraintArena<F> {
209 pub fn new() -> Self {
210 Self {
211 nodes: Vec::new(),
212 cell_cache: HashMap::new(),
213 }
214 }
215
216 pub fn alloc(&mut self, expr: ConstraintExpr<F>) -> ExprId {
219 let id = ExprId(self.nodes.len() as u32);
220 self.nodes.push(expr);
221
222 id
223 }
224
225 pub fn get(&self, id: ExprId) -> &ConstraintExpr<F> {
227 &self.nodes[id.0 as usize]
228 }
229
230 pub fn len(&self) -> usize {
232 self.nodes.len()
233 }
234
235 pub fn is_empty(&self) -> bool {
236 self.nodes.is_empty()
237 }
238
239 pub fn shift_cells(&mut self, offset: usize) {
244 for node in &mut self.nodes {
245 if let ConstraintExpr::Cell(cell) = node {
246 cell.col_idx += offset;
247 }
248 }
249
250 let old_cache = core::mem::take(&mut self.cell_cache);
251 for (mut cell, id) in old_cache {
252 cell.col_idx += offset;
253 self.cell_cache.insert(cell, id);
254 }
255 }
256
257 pub fn cell(&mut self, cell: ProgramCell) -> ExprId {
261 if let Some(&id) = self.cell_cache.get(&cell) {
262 return id;
263 }
264
265 let id = self.alloc(ConstraintExpr::Cell(cell));
266 self.cell_cache.insert(cell, id);
267
268 id
269 }
270
271 pub fn constant(&mut self, val: F) -> ExprId {
273 self.alloc(ConstraintExpr::Const(val))
274 }
275
276 pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
278 self.alloc(ConstraintExpr::Add(a, b))
279 }
280
281 pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
283 self.alloc(ConstraintExpr::Mul(a, b))
284 }
285
286 pub fn scale(&mut self, coeff: F, a: ExprId) -> ExprId {
288 self.alloc(ConstraintExpr::Scale(coeff, a))
289 }
290
291 pub fn sum(&mut self, children: Vec<ExprId>) -> ExprId {
293 self.alloc(ConstraintExpr::Sum(children))
294 }
295}
296
297pub struct ConstraintAst<F> {
301 pub arena: ConstraintArena<F>,
302 pub roots: Vec<ExprId>,
303 pub labels: Vec<Option<&'static str>>,
304}
305
306impl<F: TowerField> Clone for ConstraintArena<F> {
307 fn clone(&self) -> Self {
308 Self {
309 nodes: self.nodes.clone(),
310 cell_cache: self.cell_cache.clone(),
311 }
312 }
313}
314
315impl<F: TowerField> Clone for ConstraintAst<F> {
316 fn clone(&self) -> Self {
317 Self {
318 arena: self.arena.clone(),
319 roots: self.roots.clone(),
320 labels: self.labels.clone(),
321 }
322 }
323}
324
325impl<F: TowerField> ConstraintAst<F> {
326 pub fn max_degree(&self) -> usize {
329 if self.arena.is_empty() {
330 return 0;
331 }
332
333 let n = self.arena.len();
334 let mut deg: Vec<usize> = Vec::with_capacity(n);
335
336 for i in 0..n {
337 let d = match self.arena.get(ExprId(i as u32)) {
338 ConstraintExpr::Cell(_) => 1,
339 ConstraintExpr::Const(_) => 0,
340 ConstraintExpr::Add(a, b) => deg[a.0 as usize].max(deg[b.0 as usize]),
341 ConstraintExpr::Mul(a, b) => deg[a.0 as usize] + deg[b.0 as usize],
342 ConstraintExpr::Scale(_, a) => deg[a.0 as usize],
343 ConstraintExpr::Sum(children) => children
344 .iter()
345 .map(|c| deg[c.0 as usize])
346 .max()
347 .unwrap_or(0),
348 };
349
350 deg.push(d);
351 }
352
353 self.roots
354 .iter()
355 .map(|r| deg[r.0 as usize])
356 .max()
357 .unwrap_or(0)
358 }
359
360 pub fn evaluate(&self, current_row: &[Flat<F>], next_row: &[Flat<F>]) -> Vec<Flat<F>>
363 where
364 F: HardwareField,
365 {
366 let n = self.arena.len();
367 let mut val: Vec<Flat<F>> = Vec::with_capacity(n);
368
369 for i in 0..n {
370 let v = match self.arena.get(ExprId(i as u32)) {
371 ConstraintExpr::Cell(cell) => {
372 if cell.next_row {
373 next_row[cell.col_idx]
374 } else {
375 current_row[cell.col_idx]
376 }
377 }
378 ConstraintExpr::Const(k) => k.to_hardware(),
379 ConstraintExpr::Add(a, b) => val[a.0 as usize] + val[b.0 as usize],
380 ConstraintExpr::Mul(a, b) => val[a.0 as usize] * val[b.0 as usize],
381 ConstraintExpr::Scale(k, a) => k.to_hardware() * val[a.0 as usize],
382 ConstraintExpr::Sum(children) => {
383 let mut s = Flat::from_raw(F::ZERO);
384 for c in children {
385 s += val[c.0 as usize];
386 }
387
388 s
389 }
390 };
391
392 val.push(v);
393 }
394
395 self.roots.iter().map(|r| val[r.0 as usize]).collect()
396 }
397
398 pub fn evaluate_into(
402 &self,
403 current_row: &[Flat<F>],
404 next_row: &[Flat<F>],
405 buf: &mut Vec<Flat<F>>,
406 ) where
407 F: HardwareField,
408 {
409 buf.clear();
410
411 let n = self.arena.len();
412 for i in 0..n {
413 let v = match self.arena.get(ExprId(i as u32)) {
414 ConstraintExpr::Cell(cell) => {
415 if cell.next_row {
416 next_row[cell.col_idx]
417 } else {
418 current_row[cell.col_idx]
419 }
420 }
421 ConstraintExpr::Const(k) => k.to_hardware(),
422 ConstraintExpr::Add(a, b) => buf[a.0 as usize] + buf[b.0 as usize],
423 ConstraintExpr::Mul(a, b) => buf[a.0 as usize] * buf[b.0 as usize],
424 ConstraintExpr::Scale(k, a) => k.to_hardware() * buf[a.0 as usize],
425 ConstraintExpr::Sum(children) => {
426 let mut s = Flat::from_raw(F::ZERO);
427 for c in children {
428 s += buf[c.0 as usize];
429 }
430
431 s
432 }
433 };
434
435 buf.push(v);
436 }
437 }
438
439 pub fn merge(&mut self, other: ConstraintAst<F>) {
441 let mut id_map: Vec<ExprId> = Vec::with_capacity(other.arena.len());
442 for node in other.arena.nodes {
443 let new_id = match node {
444 ConstraintExpr::Cell(cell) => self.arena.cell(cell),
445 ConstraintExpr::Const(val) => self.arena.constant(val),
446 ConstraintExpr::Add(a, b) => {
447 self.arena.add(id_map[a.0 as usize], id_map[b.0 as usize])
448 }
449 ConstraintExpr::Mul(a, b) => {
450 self.arena.mul(id_map[a.0 as usize], id_map[b.0 as usize])
451 }
452 ConstraintExpr::Scale(coeff, inner) => {
453 self.arena.scale(coeff, id_map[inner.0 as usize])
454 }
455 ConstraintExpr::Sum(children) => {
456 let remapped: Vec<ExprId> =
457 children.into_iter().map(|c| id_map[c.0 as usize]).collect();
458 self.arena.sum(remapped)
459 }
460 };
461
462 id_map.push(new_id);
463 }
464
465 for (root, label) in other.roots.into_iter().zip(other.labels) {
466 self.roots.push(id_map[root.0 as usize]);
467 self.labels.push(label);
468 }
469 }
470
471 pub fn to_constraints(&self) -> Vec<Constraint<F>> {
474 type FlatTerm<F> = (F, Vec<ProgramCell>);
477
478 fn expand<F: TowerField>(
479 arena: &ConstraintArena<F>,
480 id: ExprId,
481 cache: &mut Vec<Option<Vec<FlatTerm<F>>>>,
482 ) -> Vec<FlatTerm<F>> {
483 if let Some(ref cached) = cache[id.0 as usize] {
484 return cached.clone();
485 }
486
487 let result = match arena.get(id) {
488 ConstraintExpr::Cell(cell) => {
489 vec![(F::ONE, vec![*cell])]
490 }
491 ConstraintExpr::Const(k) => {
492 vec![(*k, vec![])]
493 }
494 ConstraintExpr::Add(a, b) => {
495 let mut terms = expand(arena, *a, cache);
496 terms.extend(expand(arena, *b, cache));
497
498 terms
499 }
500 ConstraintExpr::Mul(a, b) => {
501 let left = expand(arena, *a, cache);
502 let right = expand(arena, *b, cache);
503
504 let mut terms = Vec::with_capacity(left.len() * right.len());
505 for (lc, lp) in &left {
506 for (rc, rp) in &right {
507 let coeff = *lc * *rc;
508 let mut cells = lp.clone();
509
510 cells.extend_from_slice(rp);
511 terms.push((coeff, cells));
512 }
513 }
514
515 terms
516 }
517 ConstraintExpr::Scale(k, a) => {
518 let inner = expand(arena, *a, cache);
519 inner
520 .into_iter()
521 .map(|(c, cells)| (*k * c, cells))
522 .collect()
523 }
524 ConstraintExpr::Sum(children) => {
525 let mut terms = Vec::new();
526 for child in children {
527 terms.extend(expand(arena, *child, cache));
528 }
529
530 terms
531 }
532 };
533
534 cache[id.0 as usize] = Some(result.clone());
535
536 result
537 }
538
539 let n = self.arena.len();
540 let mut cache: Vec<Option<Vec<FlatTerm<F>>>> = vec![None; n];
541
542 self.roots
543 .iter()
544 .map(|root| {
545 let flat_terms = expand(&self.arena, *root, &mut cache);
546 Constraint::new(
547 flat_terms
548 .into_iter()
549 .map(|(coeff, cells)| ConstraintTerm::new(coeff, cells))
550 .collect(),
551 )
552 })
553 .collect()
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintExpr;
    use crate::constraint::builder::ConstraintSystem;
    use crate::{Air, Program};
    use hekate_core::trace::ColumnType;
    use hekate_math::{Block128, Flat};

    type F = Block128;

    /// Two-column Fibonacci-style AIR gated by a selector column `q`.
    #[derive(Clone)]
    struct TestFibProgram;

    impl Air<F> for TestFibProgram {
        fn num_columns(&self) -> usize {
            3
        }

        fn column_layout(&self) -> &[ColumnType] {
            &[ColumnType::B32, ColumnType::B32, ColumnType::Bit]
        }

        fn constraint_ast(&self) -> ConstraintAst<F> {
            let cs = ConstraintSystem::<F>::new();

            let [a, b, q] = [cs.col(0), cs.col(1), cs.col(2)];
            let [na, nb] = [cs.next(0), cs.next(1)];

            cs.constrain(q * (na + b));
            cs.constrain(q * (nb + a + b));

            cs.build()
        }
    }

    impl Program<F> for TestFibProgram {}

    #[test]
    fn default_constraint_ast_produces_correct_roots() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        assert_eq!(ast.roots.len(), 2);

        assert!(!ast.arena.is_empty());

        for &root in &ast.roots {
            let node = ast.arena.get(root);
            match node {
                ConstraintExpr::Sum(children) => {
                    assert!(!children.is_empty());
                }
                ConstraintExpr::Mul(_, _) => {}
                _ => panic!("Root should be Sum or Mul, got {:?}", node),
            }
        }
    }

    #[test]
    fn cell_dedup_works() {
        let mut arena = ConstraintArena::<F>::new();

        let cell_a = ProgramCell::current(0);
        let cell_b = ProgramCell::current(0);
        let cell_c = ProgramCell::next(0);

        let id_a = arena.cell(cell_a);
        let id_b = arena.cell(cell_b);
        let id_c = arena.cell(cell_c);

        // Same column and row share a node; next-row access is a distinct cell.
        assert_eq!(id_a, id_b);
        assert_ne!(id_a, id_c);
        assert_eq!(arena.len(), 2);
    }

    #[test]
    fn dag_sharing_reduces_node_count() {
        let mut arena = ConstraintArena::<F>::new();

        let a = arena.cell(ProgramCell::current(0));
        let b = arena.cell(ProgramCell::current(1));
        let c = arena.cell(ProgramCell::current(2));
        let theta = arena.sum(vec![a, b, c]);

        let d = arena.cell(ProgramCell::current(3));
        let expr1 = arena.mul(theta, d);
        let expr2 = arena.mul(theta, a);

        // 4 cells + 1 shared sum + 2 products.
        let dag_node_count = arena.len();
        assert_eq!(dag_node_count, 7);

        assert!(dag_node_count < 10);

        match arena.get(expr1) {
            ConstraintExpr::Mul(lhs, _) => assert_eq!(*lhs, theta),
            _ => panic!("Expected Mul"),
        }

        match arena.get(expr2) {
            ConstraintExpr::Mul(lhs, rhs) => {
                assert_eq!(*lhs, theta);
                assert_eq!(*rhs, a);
            }
            _ => panic!("Expected Mul"),
        }
    }

    #[test]
    fn default_constraint_ast_node_count_matches_flat() {
        let program = TestFibProgram;
        let flat = program.constraints();
        let ast = program.constraint_ast();

        let mut flat_cell_count = 0;
        for c in &flat {
            for t in &c.terms {
                flat_cell_count += t.poly_ind.len();
            }
        }

        let ast_cell_count = ast
            .arena
            .nodes
            .iter()
            .filter(|n| matches!(n, ConstraintExpr::Cell(_)))
            .count();

        assert!(ast_cell_count <= flat_cell_count);

        // a, b, q, next(a), next(b).
        assert_eq!(ast_cell_count, 5);
    }

    #[test]
    fn empty_constraint_produces_empty_ast() {
        #[derive(Clone)]
        struct EmptyProgram;

        impl Air<F> for EmptyProgram {
            fn num_columns(&self) -> usize {
                0
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                ConstraintSystem::<F>::new().build()
            }
        }

        impl Program<F> for EmptyProgram {}

        let ast = EmptyProgram.constraint_ast();
        assert!(ast.roots.is_empty());
        assert!(ast.arena.is_empty());
    }

    #[test]
    fn single_term_constraint_no_sum_wrapper() {
        #[derive(Clone)]
        struct SingleTermProgram;

        impl Air<F> for SingleTermProgram {
            fn num_columns(&self) -> usize {
                2
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[ColumnType::B128, ColumnType::B128]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                let cs = ConstraintSystem::<F>::new();
                cs.constrain(cs.col(0) * cs.col(1));

                cs.build()
            }
        }

        impl Program<F> for SingleTermProgram {}

        let ast = SingleTermProgram.constraint_ast();
        assert_eq!(ast.roots.len(), 1);

        match ast.arena.get(ast.roots[0]) {
            ConstraintExpr::Mul(_, _) => {}
            other => panic!("Expected Mul for single-term, got {:?}", other),
        }
    }

    #[test]
    fn max_degree_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        let flat = program.constraints();
        let flat_max = flat
            .iter()
            .flat_map(|c| c.terms.iter())
            .map(|t| t.poly_ind.len())
            .max()
            .unwrap_or(0);

        assert_eq!(ast.max_degree(), flat_max);
    }

    #[test]
    fn max_degree_empty() {
        let ast = ConstraintAst::<F> {
            arena: ConstraintArena::new(),
            roots: Vec::new(),
            labels: Vec::new(),
        };
        assert_eq!(ast.max_degree(), 0);
    }

    #[test]
    fn max_degree_builder_mul_chain() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        let c = cs.col(2);

        cs.constrain(a * b * c);
        cs.constrain(a + b);

        let ast = cs.build();
        assert_eq!(ast.max_degree(), 3);
    }

    #[test]
    fn to_constraints_roundtrip_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();
        let flat_from_ast = ast.to_constraints();
        let flat_direct = program.constraints();

        assert_eq!(flat_from_ast.len(), flat_direct.len());

        for (a, d) in flat_from_ast.iter().zip(flat_direct.iter()) {
            assert_eq!(a.terms.len(), d.terms.len());
        }
    }

    #[test]
    fn to_constraints_from_builder() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        cs.constrain(a + b);

        let ast = cs.build();
        let flat = ast.to_constraints();

        assert_eq!(flat.len(), 1);
        assert_eq!(flat[0].terms.len(), 2);

        for term in &flat[0].terms {
            assert_eq!(term.coeff, F::ONE);
            assert_eq!(term.poly_ind.len(), 1);
        }
    }

    #[test]
    fn evaluate_simple_constraint() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        cs.constrain(a + b);

        let ast = cs.build();

        // Equal inputs cancel: 3 + 3 == 0 in this field.
        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(3u128)),
        ];
        let next = vec![Flat::from_raw(F::ZERO); 2];

        let evals = ast.evaluate(&current, &next);
        assert_eq!(evals.len(), 1);
        assert_eq!(evals[0], Flat::from_raw(F::ZERO));

        // Distinct inputs leave a non-zero residue.
        let current2 = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let evals2 = ast.evaluate(&current2, &next);
        assert_ne!(evals2[0], Flat::from_raw(F::ZERO));
    }

    #[test]
    fn evaluate_into_matches_evaluate() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);
        let na = cs.next(0);

        cs.constrain(a + b);
        cs.constrain(a * b);
        cs.constrain(na + a);

        let ast = cs.build();
        let zero = Flat::from_raw(F::ZERO);

        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let next = vec![Flat::from_raw(F::from(7u128)), zero];

        let expected = ast.evaluate(&current, &next);

        let mut buf = Vec::new();
        ast.evaluate_into(&current, &next, &mut buf);

        for (i, root) in ast.roots.iter().enumerate() {
            assert_eq!(buf[root.0 as usize], expected[i]);
        }
    }
}