1use crate::ProgramCell;
45use crate::constraint::{ConstraintArena, ConstraintAst, ExprId};
46use alloc::vec::Vec;
47use core::cell::RefCell;
48use core::ops::{Add, Mul, Sub};
49use hekate_math::TowerField;
50
51pub struct ConstraintSystem<F: TowerField> {
57 inner: RefCell<Inner<F>>,
58}
59
60struct Inner<F: TowerField> {
61 arena: ConstraintArena<F>,
62 roots: Vec<ExprId>,
63 labels: Vec<Option<&'static str>>,
64}
65
66impl<F: TowerField> ConstraintSystem<F> {
67 pub fn new() -> Self {
69 Self {
70 inner: RefCell::new(Inner {
71 arena: ConstraintArena::new(),
72 roots: Vec::new(),
73 labels: Vec::new(),
74 }),
75 }
76 }
77
78 pub fn from_ast(ast: ConstraintAst<F>) -> Self {
80 Self {
81 inner: RefCell::new(Inner {
82 arena: ast.arena,
83 roots: ast.roots,
84 labels: ast.labels,
85 }),
86 }
87 }
88
89 pub fn col(&self, idx: usize) -> Expr<'_, F> {
95 let id = self
96 .inner
97 .borrow_mut()
98 .arena
99 .cell(ProgramCell::current(idx));
100 Expr { id, cs: self }
101 }
102
103 pub fn next(&self, idx: usize) -> Expr<'_, F> {
105 let id = self.inner.borrow_mut().arena.cell(ProgramCell::next(idx));
106 Expr { id, cs: self }
107 }
108
109 pub fn constant(&self, val: F) -> Expr<'_, F> {
115 let id = self.inner.borrow_mut().arena.constant(val);
116 Expr { id, cs: self }
117 }
118
119 pub fn one(&self) -> Expr<'_, F> {
121 self.constant(F::ONE)
122 }
123
124 pub fn scale(&self, coeff: F, expr: Expr<'_, F>) -> Expr<'_, F> {
135 let id = self.inner.borrow_mut().arena.scale(coeff, expr.id);
136 Expr { id, cs: self }
137 }
138
139 pub fn sum(&self, children: &[Expr<'_, F>]) -> Expr<'_, F> {
142 let ids: Vec<ExprId> = children.iter().map(|e| e.id).collect();
143 let id = self.inner.borrow_mut().arena.sum(ids);
144
145 Expr { id, cs: self }
146 }
147
148 pub fn constrain(&self, expr: Expr<'_, F>) {
158 let mut inner = self.inner.borrow_mut();
159 inner.roots.push(expr.id);
160 inner.labels.push(None);
161 }
162
163 pub fn constrain_named(&self, label: &'static str, expr: Expr<'_, F>) {
164 let mut inner = self.inner.borrow_mut();
165 inner.roots.push(expr.id);
166 inner.labels.push(Some(label));
167 }
168
169 pub fn assert_boolean(&self, s: Expr<'_, F>) {
179 let sq = s * s;
181 let expr = sq + s;
182
183 self.constrain_named("boolean", expr);
184 }
185
186 pub fn assert_zero_when(&self, sel: Expr<'_, F>, body: Expr<'_, F>) {
192 self.constrain_named("zero_when", sel * body);
193 }
194
195 pub fn assert_one_hot(&self, selectors: &[Expr<'_, F>]) {
203 let s = self.sum(selectors);
204 let one = self.one();
205
206 self.constrain_named("one_hot", s + one);
207 }
208
209 pub fn assert_paired_bus_mutex(&self, s_send: usize, s_recv: usize) {
212 let send = self.col(s_send);
213 let recv = self.col(s_recv);
214
215 self.assert_boolean(send);
216 self.assert_boolean(recv);
217
218 self.constrain_named("paired_bus_mutex", send * recv);
219 }
220
221 pub fn build(self) -> ConstraintAst<F> {
233 let inner = self.inner.into_inner();
234 ConstraintAst {
235 arena: inner.arena,
236 roots: inner.roots,
237 labels: inner.labels,
238 }
239 }
240}
241
242impl<F: TowerField> Default for ConstraintSystem<F> {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
248#[derive(Clone, Copy)]
257pub struct Expr<'a, F: TowerField> {
258 pub(crate) id: ExprId,
259 pub(crate) cs: &'a ConstraintSystem<F>,
260}
261
262impl<'a, F: TowerField> Add for Expr<'a, F> {
264 type Output = Expr<'a, F>;
265
266 fn add(self, rhs: Self) -> Self::Output {
267 let id = self.cs.inner.borrow_mut().arena.add(self.id, rhs.id);
268 Expr { id, cs: self.cs }
269 }
270}
271
272impl<'a, F: TowerField> Mul for Expr<'a, F> {
274 type Output = Expr<'a, F>;
275
276 fn mul(self, rhs: Self) -> Self::Output {
277 let id = self.cs.inner.borrow_mut().arena.mul(self.id, rhs.id);
278 Expr { id, cs: self.cs }
279 }
280}
281
282impl<'a, F: TowerField> Sub for Expr<'a, F> {
284 type Output = Expr<'a, F>;
285
286 fn sub(self, rhs: Self) -> Self::Output {
287 let id = self.cs.inner.borrow_mut().arena.add(self.id, rhs.id);
291 Expr { id, cs: self.cs }
292 }
293}
294
#[cfg(test)]
mod tests {
    // Unit tests for the constraint builder. Most of them assert the exact node
    // shapes recorded in the arena and the order of roots/labels, so they depend
    // on the order in which expressions and constraints are created.
    use super::*;
    use crate::constraint::ConstraintExpr;
    use hekate_math::Block128;

    // Concrete field used by every test.
    type F = Block128;

    // Two selector-gated transition constraints; both roots must be the top-level
    // `q * (...)` product.
    #[test]
    fn basic_fibonacci_builder() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);
        let q = cs.col(2);
        let na = cs.next(0);
        let nb = cs.next(1);

        cs.constrain(q * (na + b));
        cs.constrain(q * (nb + a + b));

        let ast = cs.build();

        assert_eq!(ast.roots.len(), 2);
        assert!(!ast.arena.is_empty());

        for &root in &ast.roots {
            // Each root is the outermost multiplication by the selector.
            match ast.arena.get(root) {
                ConstraintExpr::Mul(_, _) => {}
                other => panic!("Expected Mul root, got {:?}", other),
            }
        }
    }

    // Requesting the same column twice must yield the same node id.
    #[test]
    fn cell_dedup_through_builder() {
        let cs = ConstraintSystem::<F>::new();

        let a1 = cs.col(0);
        let a2 = cs.col(0);
        let b = cs.col(1);

        // Same column: deduplicated to a single node.
        assert_eq!(a1.id, a2.id);
        // Different column: distinct node.
        assert_ne!(a1.id, b.id);
    }

    // In characteristic 2, `a - b` is recorded with the same `Add` shape as `a + b`.
    #[test]
    fn sub_equals_add_in_char2() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);

        let sum = a + b;
        let diff = a - b;

        // Inspect the arena directly (without `build`) so both ids stay in scope.
        let ast_sum = cs.inner.borrow();
        match (ast_sum.arena.get(sum.id), ast_sum.arena.get(diff.id)) {
            (ConstraintExpr::Add(la, ra), ConstraintExpr::Add(lb, rb)) => {
                assert_eq!(la, lb);
                assert_eq!(ra, rb);
            }
            _ => panic!("Expected Add nodes for both + and -"),
        }
    }

    // `assert_boolean(s)` must produce exactly one root `Add(Mul(s, s), s)`.
    #[test]
    fn assert_boolean_structure() {
        let cs = ConstraintSystem::<F>::new();
        let s = cs.col(5);

        cs.assert_boolean(s);

        let ast = cs.build();
        assert_eq!(ast.roots.len(), 1);

        match ast.arena.get(ast.roots[0]) {
            // Root: s² + s
            ConstraintExpr::Add(lhs, rhs) => {
                // Left operand: s², i.e. a product of a node with itself.
                match ast.arena.get(*lhs) {
                    ConstraintExpr::Mul(a, b) => {
                        assert_eq!(a, b);
                    }
                    other => panic!("Expected Mul for s², got {:?}", other),
                }

                // Right operand: the current-row cell for column 5.
                match ast.arena.get(*rhs) {
                    ConstraintExpr::Cell(cell) => {
                        assert_eq!(cell.col_idx, 5);
                        assert!(!cell.next_row);
                    }
                    other => panic!("Expected Cell for s, got {:?}", other),
                }
            }
            other => panic!("Expected Add for s²+s, got {:?}", other),
        }
    }

    // `assert_zero_when(sel, body)` must produce a single product root.
    #[test]
    fn assert_zero_when_structure() {
        let cs = ConstraintSystem::<F>::new();
        let sel = cs.col(0);
        let body = cs.col(1) + cs.col(2);

        cs.assert_zero_when(sel, body);

        let ast = cs.build();
        assert_eq!(ast.roots.len(), 1);

        match ast.arena.get(ast.roots[0]) {
            // Root: sel * body
            ConstraintExpr::Mul(_, _) => {}
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    // `scale` records a dedicated `Scale(coeff, inner)` node rather than a `Mul`.
    #[test]
    fn scale_produces_scale_node() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let scaled = cs.scale(F::from(8u128), a);

        // Copy the ids out before `build` consumes the system.
        let a_id = a.id;
        let scaled_id = scaled.id;

        let ast = cs.build();
        match ast.arena.get(scaled_id) {
            ConstraintExpr::Scale(coeff, inner) => {
                assert_eq!(*coeff, F::from(8u128));
                assert_eq!(*inner, a_id);
            }
            other => panic!("Expected Scale, got {:?}", other),
        }
    }

    // `sum` records one n-ary `Sum` node that keeps its operands in the given order.
    #[test]
    fn sum_produces_sum_node() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        let c = cs.col(2);
        let s = cs.sum(&[a, b, c]);

        // Copy the ids out before `build` consumes the system.
        let (a_id, b_id, c_id) = (a.id, b.id, c.id);

        let s_id = s.id;
        let ast = cs.build();

        match ast.arena.get(s_id) {
            ConstraintExpr::Sum(children) => {
                assert_eq!(children.len(), 3);
                assert_eq!(children[0], a_id);
                assert_eq!(children[1], b_id);
                assert_eq!(children[2], c_id);
            }
            other => panic!("Expected Sum, got {:?}", other),
        }
    }

    // Reusing an `Expr` handle shares one node between constraints (a DAG, not a tree).
    #[test]
    fn dag_sharing_via_expr_reuse() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);
        let c = cs.col(2);

        // Shared sub-expression used by both constraints below.
        let theta = cs.sum(&[a, b, c]);

        let d = cs.col(3);
        cs.constrain(theta * d);
        cs.constrain(theta * a);

        let ast = cs.build();
        assert_eq!(ast.roots.len(), 2);

        match (ast.arena.get(ast.roots[0]), ast.arena.get(ast.roots[1])) {
            // Both products must point at the same `theta` node on the left.
            (ConstraintExpr::Mul(lhs0, _), ConstraintExpr::Mul(lhs1, _)) => {
                assert_eq!(lhs0, lhs1);
            }
            _ => panic!("Expected Mul roots"),
        }
    }

    // A fresh system with no calls builds to an empty AST.
    #[test]
    fn empty_system_produces_empty_ast() {
        let cs = ConstraintSystem::<F>::new();
        let ast = cs.build();
        assert!(ast.roots.is_empty());
        assert!(ast.arena.is_empty());
    }

    // Checks the full tree of `q * (next_a + b)` node by node.
    #[test]
    fn builder_matches_manual_structure() {
        let cs = ConstraintSystem::<F>::new();
        // `_a` is intentionally unused in the constraint below.
        let _a = cs.col(0);
        let b = cs.col(1);
        let q = cs.col(2);
        let na = cs.next(0);

        cs.constrain(q * (na + b));

        let ast = cs.build();

        assert_eq!(ast.roots.len(), 1);

        match ast.arena.get(ast.roots[0]) {
            ConstraintExpr::Mul(lhs, rhs) => {
                // Left: selector q = column 2, current row.
                match ast.arena.get(*lhs) {
                    ConstraintExpr::Cell(cell) => {
                        assert_eq!(cell.col_idx, 2);
                        assert!(!cell.next_row);
                    }
                    other => panic!("Expected Cell for q, got {:?}", other),
                }
                // Right: next_a + b.
                match ast.arena.get(*rhs) {
                    ConstraintExpr::Add(a, b) => {
                        // next_a = column 0, next row.
                        match ast.arena.get(*a) {
                            ConstraintExpr::Cell(cell) => {
                                assert_eq!(cell.col_idx, 0);
                                assert!(cell.next_row);
                            }
                            other => panic!("Expected Cell for next_a, got {:?}", other),
                        }
                        // b = column 1, current row.
                        match ast.arena.get(*b) {
                            ConstraintExpr::Cell(cell) => {
                                assert_eq!(cell.col_idx, 1);
                                assert!(!cell.next_row);
                            }
                            other => panic!("Expected Cell for b, got {:?}", other),
                        }
                    }
                    other => panic!("Expected Add, got {:?}", other),
                }
            }
            other => panic!("Expected Mul root, got {:?}", other),
        }
    }

    // Labels are recorded per root, in insertion order, including gadget labels.
    #[test]
    fn labels_round_trip_through_build() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        cs.constrain(a + b);
        cs.constrain_named("transition", a * b);
        cs.assert_boolean(a);

        let ast = cs.build();

        assert_eq!(ast.roots.len(), 3);
        assert_eq!(ast.labels.len(), 3);
        assert_eq!(ast.labels[0], None);
        assert_eq!(ast.labels[1], Some("transition"));
        assert_eq!(ast.labels[2], Some("boolean"));
    }

    // `ConstraintAst::merge` appends the second AST's roots and labels after the first's.
    #[test]
    fn labels_preserved_through_merge() {
        let cs1 = ConstraintSystem::<F>::new();
        cs1.constrain_named("first", cs1.col(0));

        let mut ast1 = cs1.build();

        let cs2 = ConstraintSystem::<F>::new();

        cs2.constrain(cs2.col(0));
        cs2.constrain_named("second", cs2.col(1));

        let ast2 = cs2.build();

        ast1.merge(ast2);

        assert_eq!(ast1.roots.len(), 3);
        assert_eq!(ast1.labels.len(), 3);
        assert_eq!(ast1.labels[0], Some("first"));
        assert_eq!(ast1.labels[1], None);
        assert_eq!(ast1.labels[2], Some("second"));
    }

    // Reopening an AST with `from_ast` keeps existing labels and appends new ones.
    #[test]
    fn labels_preserved_through_from_ast() {
        let cs = ConstraintSystem::<F>::new();
        cs.constrain_named("original", cs.col(0));

        let ast = cs.build();

        let cs2 = ConstraintSystem::from_ast(ast);
        cs2.constrain_named("added", cs2.col(1));

        let ast2 = cs2.build();

        assert_eq!(ast2.labels.len(), 2);
        assert_eq!(ast2.labels[0], Some("original"));
        assert_eq!(ast2.labels[1], Some("added"));
    }

    // Each built-in gadget adds exactly one root here, carrying its fixed label.
    #[test]
    fn builtin_gadgets_have_labels() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);

        cs.assert_boolean(a);
        cs.assert_zero_when(a, b);
        cs.assert_one_hot(&[a, b]);

        let ast = cs.build();

        assert_eq!(ast.labels.len(), 3);
        assert_eq!(ast.labels[0], Some("boolean"));
        assert_eq!(ast.labels[1], Some("zero_when"));
        assert_eq!(ast.labels[2], Some("one_hot"));
    }
}