Skip to main content

hekate_program/constraint/
builder.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
//! Constraint Builder DSL.
//!
//! Provides `ConstraintSystem` and `Expr` for writing
//! algebraic constraints with operator overloading.
//!
//! # Example
//!
//! ```ignore
//! let cs = ConstraintSystem::<F>::new();
//! let [a, b, q] = [cs.col(0), cs.col(1), cs.col(2)];
//! let [na, nb] = [cs.next(0), cs.next(1)];
//!
//! cs.constrain(q * (na + b));       // next_a = b
//! cs.constrain(q * (nb + a + b));   // next_b = a + b
//!
//! let ast = cs.build();
//! ```
//!
//! `Expr` is `Copy`: reusing the same expression in
//! multiple constraints shares the underlying DAG node.
//! Cell references are auto-deduplicated.
//!
//! `Sub` delegates to `Add` because, in GF(2^k),
//! subtraction is addition (XOR). This is correct
//! for binary tower fields only.
43
44use crate::ProgramCell;
45use crate::constraint::{ConstraintArena, ConstraintAst, ExprId};
46use alloc::vec::Vec;
47use core::cell::RefCell;
48use core::ops::{Add, Mul, Sub};
49use hekate_math::TowerField;
50
51// =================================================================
52// CONSTRAINT SYSTEM
53// =================================================================
54
55/// Builder context for algebraic constraints.
56pub struct ConstraintSystem<F: TowerField> {
57    inner: RefCell<Inner<F>>,
58}
59
60struct Inner<F: TowerField> {
61    arena: ConstraintArena<F>,
62    roots: Vec<ExprId>,
63    labels: Vec<Option<&'static str>>,
64}
65
66impl<F: TowerField> ConstraintSystem<F> {
67    /// Create a new constraint system.
68    pub fn new() -> Self {
69        Self {
70            inner: RefCell::new(Inner {
71                arena: ConstraintArena::new(),
72                roots: Vec::new(),
73                labels: Vec::new(),
74            }),
75        }
76    }
77
78    /// Create a builder from an existing AST.
79    pub fn from_ast(ast: ConstraintAst<F>) -> Self {
80        Self {
81            inner: RefCell::new(Inner {
82                arena: ast.arena,
83                roots: ast.roots,
84                labels: ast.labels,
85            }),
86        }
87    }
88
89    // ===========================================================
90    // Cell References
91    // ===========================================================
92
93    /// Reference to a column in the current row.
94    pub fn col(&self, idx: usize) -> Expr<'_, F> {
95        let id = self
96            .inner
97            .borrow_mut()
98            .arena
99            .cell(ProgramCell::current(idx));
100        Expr { id, cs: self }
101    }
102
103    /// Reference to a column in the next row.
104    pub fn next(&self, idx: usize) -> Expr<'_, F> {
105        let id = self.inner.borrow_mut().arena.cell(ProgramCell::next(idx));
106        Expr { id, cs: self }
107    }
108
109    // ===========================================================
110    // Constants
111    // ===========================================================
112
113    /// Field constant.
114    pub fn constant(&self, val: F) -> Expr<'_, F> {
115        let id = self.inner.borrow_mut().arena.constant(val);
116        Expr { id, cs: self }
117    }
118
119    /// The multiplicative identity.
120    pub fn one(&self) -> Expr<'_, F> {
121        self.constant(F::ONE)
122    }
123
124    // ===========================================================
125    // Arithmetic
126    // ===========================================================
127
128    /// Scalar multiplication:
129    /// `coeff * expr`.
130    ///
131    /// Use this for powers of 2, coefficients, etc.
132    /// Orphan rules prevent implementing `F * Expr`
133    /// via operator overloading.
134    pub fn scale(&self, coeff: F, expr: Expr<'_, F>) -> Expr<'_, F> {
135        let id = self.inner.borrow_mut().arena.scale(coeff, expr.id);
136        Expr { id, cs: self }
137    }
138
139    /// N-ary sum. More efficient than chaining binary `+`
140    /// for linear combinations (avoids deep Add chains).
141    pub fn sum(&self, children: &[Expr<'_, F>]) -> Expr<'_, F> {
142        let ids: Vec<ExprId> = children.iter().map(|e| e.id).collect();
143        let id = self.inner.borrow_mut().arena.sum(ids);
144
145        Expr { id, cs: self }
146    }
147
148    // ===========================================================
149    // Constraint Registration
150    // ===========================================================
151
152    /// Register a constraint:
153    /// `expr = 0`.
154    ///
155    /// The expression must evaluate to zero
156    /// for every valid row in the execution trace.
157    pub fn constrain(&self, expr: Expr<'_, F>) {
158        let mut inner = self.inner.borrow_mut();
159        inner.roots.push(expr.id);
160        inner.labels.push(None);
161    }
162
163    pub fn constrain_named(&self, label: &'static str, expr: Expr<'_, F>) {
164        let mut inner = self.inner.borrow_mut();
165        inner.roots.push(expr.id);
166        inner.labels.push(Some(label));
167    }
168
169    // ===========================================================
170    // Built-in Gadgets
171    // ===========================================================
172
173    /// Assert that `s` is boolean:
174    /// `s * (s + 1) = 0`.
175    ///
176    /// In GF(2^k), `s + 1` equals `s - 1`.
177    /// Enforces `s ∈ {0, 1}`.
178    pub fn assert_boolean(&self, s: Expr<'_, F>) {
179        // s * s + s = 0  (expanded form of s*(s+1)=0)
180        let sq = s * s;
181        let expr = sq + s;
182
183        self.constrain_named("boolean", expr);
184    }
185
186    /// Assert that `body = 0` whenever `sel = 1`.
187    ///
188    /// Registers `sel * body = 0`.
189    /// When `sel = 0`, the constraint
190    /// is trivially satisfied.
191    pub fn assert_zero_when(&self, sel: Expr<'_, F>, body: Expr<'_, F>) {
192        self.constrain_named("zero_when", sel * body);
193    }
194
195    /// Assert that exactly
196    /// one selector is active.
197    /// Enforces:
198    /// sum(selectors) = 1.
199    ///
200    /// In GF(2^k):
201    /// sum(s_i) + 1 = 0.
202    pub fn assert_one_hot(&self, selectors: &[Expr<'_, F>]) {
203        let s = self.sum(selectors);
204        let one = self.one();
205
206        self.constrain_named("one_hot", s + one);
207    }
208
209    /// Emit the `s_send · s_recv = 0` mutex root
210    /// plus boolean checks on both selectors.
211    pub fn assert_paired_bus_mutex(&self, s_send: usize, s_recv: usize) {
212        let send = self.col(s_send);
213        let recv = self.col(s_recv);
214
215        self.assert_boolean(send);
216        self.assert_boolean(recv);
217
218        self.constrain_named("paired_bus_mutex", send * recv);
219    }
220
221    // ===========================================================
222    // Compile
223    // ===========================================================
224
225    /// Consume the builder and
226    /// produce a `ConstraintAst`.
227    ///
228    /// All `Expr` handles must be
229    /// dropped before calling this
230    /// (enforced by the borrow checker,
231    /// `build` takes `self`).
232    pub fn build(self) -> ConstraintAst<F> {
233        let inner = self.inner.into_inner();
234        ConstraintAst {
235            arena: inner.arena,
236            roots: inner.roots,
237            labels: inner.labels,
238        }
239    }
240}
241
242impl<F: TowerField> Default for ConstraintSystem<F> {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
248// =================================================================
249// EXPRESSION HANDLE
250// =================================================================
251
252/// Lightweight handle to a DAG node in a `ConstraintSystem`.
253///
254/// Supports `+`, `*`, `-` via operator overloading.
255/// `Sub` delegates to `Add` (correct for GF(2^k) only).
256#[derive(Clone, Copy)]
257pub struct Expr<'a, F: TowerField> {
258    pub(crate) id: ExprId,
259    pub(crate) cs: &'a ConstraintSystem<F>,
260}
261
262// a + b
263impl<'a, F: TowerField> Add for Expr<'a, F> {
264    type Output = Expr<'a, F>;
265
266    fn add(self, rhs: Self) -> Self::Output {
267        let id = self.cs.inner.borrow_mut().arena.add(self.id, rhs.id);
268        Expr { id, cs: self.cs }
269    }
270}
271
272// a * b
273impl<'a, F: TowerField> Mul for Expr<'a, F> {
274    type Output = Expr<'a, F>;
275
276    fn mul(self, rhs: Self) -> Self::Output {
277        let id = self.cs.inner.borrow_mut().arena.mul(self.id, rhs.id);
278        Expr { id, cs: self.cs }
279    }
280}
281
282// a - b  (in GF(2^k), subtraction = addition)
283impl<'a, F: TowerField> Sub for Expr<'a, F> {
284    type Output = Expr<'a, F>;
285
286    fn sub(self, rhs: Self) -> Self::Output {
287        // In characteristic 2:
288        // -1 = 1, so a - b = a + b = a XOR b.
289        // This is correct for binary tower fields only.
290        let id = self.cs.inner.borrow_mut().arena.add(self.id, rhs.id);
291        Expr { id, cs: self.cs }
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintExpr;
    use hekate_math::Block128;

    type F = Block128;

    /// Two Fibonacci transition constraints give two `Mul` roots.
    #[test]
    fn basic_fibonacci_builder() {
        let cs = ConstraintSystem::<F>::new();

        let [a, b, q] = [cs.col(0), cs.col(1), cs.col(2)];
        let [na, nb] = [cs.next(0), cs.next(1)];

        // q * (next_a + b) = 0
        cs.constrain(q * (na + b));
        // q * (next_b + a + b) = 0
        cs.constrain(q * (nb + a + b));

        let ast = cs.build();

        assert_eq!(ast.roots.len(), 2);
        assert!(!ast.arena.is_empty());

        // Every root has the shape selector * sum.
        for root in ast.roots.iter().copied() {
            let node = ast.arena.get(root);
            if !matches!(node, ConstraintExpr::Mul(_, _)) {
                panic!("Expected Mul root, got {:?}", node);
            }
        }
    }

    /// Requesting a column twice must not create a second node.
    #[test]
    fn cell_dedup_through_builder() {
        let cs = ConstraintSystem::<F>::new();

        let [a1, a2, b] = [cs.col(0), cs.col(0), cs.col(1)];

        // Same column: one shared ExprId.
        assert_eq!(a1.id, a2.id);
        // Distinct columns: distinct ExprIds.
        assert_ne!(a1.id, b.id);
    }

    /// `-` and `+` build `Add` nodes over identical operands.
    #[test]
    fn sub_equals_add_in_char2() {
        let cs = ConstraintSystem::<F>::new();

        let (a, b) = (cs.col(0), cs.col(1));

        let sum = a + b;
        let diff = a - b;

        let state = cs.inner.borrow();
        let sum_node = state.arena.get(sum.id);
        let diff_node = state.arena.get(diff.id);

        let (ConstraintExpr::Add(la, ra), ConstraintExpr::Add(lb, rb)) = (sum_node, diff_node)
        else {
            panic!("Expected Add nodes for both + and -");
        };

        assert_eq!(la, lb);
        assert_eq!(ra, rb);
    }

    /// The boolean gadget registers exactly `Add(Mul(s, s), s)`.
    #[test]
    fn assert_boolean_structure() {
        let cs = ConstraintSystem::<F>::new();
        let s = cs.col(5);

        cs.assert_boolean(s);

        let ast = cs.build();
        assert_eq!(ast.roots.len(), 1);

        let root = ast.arena.get(ast.roots[0]);
        let ConstraintExpr::Add(lhs, rhs) = root else {
            panic!("Expected Add for s²+s, got {:?}", root);
        };

        // Left operand: s * s, both factors being the same cell.
        let square = ast.arena.get(*lhs);
        let ConstraintExpr::Mul(a, b) = square else {
            panic!("Expected Mul for s², got {:?}", square);
        };
        assert_eq!(a, b);

        // Right operand: the bare cell s.
        let linear = ast.arena.get(*rhs);
        let ConstraintExpr::Cell(cell) = linear else {
            panic!("Expected Cell for s, got {:?}", linear);
        };
        assert_eq!(cell.col_idx, 5);
        assert!(!cell.next_row);
    }

    /// The conditional-zero gadget registers a single `Mul` root.
    #[test]
    fn assert_zero_when_structure() {
        let cs = ConstraintSystem::<F>::new();
        let sel = cs.col(0);
        let body = cs.col(1) + cs.col(2);

        cs.assert_zero_when(sel, body);

        let ast = cs.build();
        assert_eq!(ast.roots.len(), 1);

        // Expected shape: Mul(sel, Add(col1, col2)).
        let root = ast.arena.get(ast.roots[0]);
        if !matches!(root, ConstraintExpr::Mul(_, _)) {
            panic!("Expected Mul, got {:?}", root);
        }
    }

    /// `scale` stores the coefficient next to the scaled node's id.
    #[test]
    fn scale_produces_scale_node() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let scaled = cs.scale(F::from(8u128), a);

        // Copy the ids out: build() consumes the system.
        let (a_id, scaled_id) = (a.id, scaled.id);

        let ast = cs.build();

        let node = ast.arena.get(scaled_id);
        let ConstraintExpr::Scale(coeff, inner) = node else {
            panic!("Expected Scale, got {:?}", node);
        };
        assert_eq!(*coeff, F::from(8u128));
        assert_eq!(*inner, a_id);
    }

    /// `sum` keeps all children, in call order, in one `Sum` node.
    #[test]
    fn sum_produces_sum_node() {
        let cs = ConstraintSystem::<F>::new();
        let [a, b, c] = [cs.col(0), cs.col(1), cs.col(2)];
        let s = cs.sum(&[a, b, c]);

        // Copy the ids out: build() consumes the system.
        let (a_id, b_id, c_id) = (a.id, b.id, c.id);
        let s_id = s.id;

        let ast = cs.build();

        let node = ast.arena.get(s_id);
        let ConstraintExpr::Sum(children) = node else {
            panic!("Expected Sum, got {:?}", node);
        };
        assert_eq!(children.len(), 3);
        assert_eq!(children[0], a_id);
        assert_eq!(children[1], b_id);
        assert_eq!(children[2], c_id);
    }

    /// Reusing one `Expr` in two constraints shares its node.
    #[test]
    fn dag_sharing_via_expr_reuse() {
        let cs = ConstraintSystem::<F>::new();

        let [a, b, c] = [cs.col(0), cs.col(1), cs.col(2)];

        // Sub-expression that both constraints reuse.
        let theta = cs.sum(&[a, b, c]);

        let d = cs.col(3);
        cs.constrain(theta * d);
        cs.constrain(theta * a);

        let ast = cs.build();
        assert_eq!(ast.roots.len(), 2);

        let first = ast.arena.get(ast.roots[0]);
        let second = ast.arena.get(ast.roots[1]);

        let (ConstraintExpr::Mul(lhs0, _), ConstraintExpr::Mul(lhs1, _)) = (first, second) else {
            panic!("Expected Mul roots");
        };

        // Both left operands are the one theta ExprId.
        assert_eq!(lhs0, lhs1);
    }

    /// Building right away yields no roots and an empty arena.
    #[test]
    fn empty_system_produces_empty_ast() {
        let ast = ConstraintSystem::<F>::new().build();

        assert!(ast.roots.is_empty());
        assert!(ast.arena.is_empty());
    }

    /// `q * (next_a + b)` lowers to
    /// `Mul(Cell(2, curr), Add(Cell(0, next), Cell(1, curr)))`.
    #[test]
    fn builder_matches_manual_structure() {
        let cs = ConstraintSystem::<F>::new();
        let _a = cs.col(0);
        let b = cs.col(1);
        let q = cs.col(2);
        let na = cs.next(0);

        cs.constrain(q * (na + b));

        let ast = cs.build();

        assert_eq!(ast.roots.len(), 1);

        let root = ast.arena.get(ast.roots[0]);
        let ConstraintExpr::Mul(lhs, rhs) = root else {
            panic!("Expected Mul root, got {:?}", root);
        };

        // Selector: column 2, current row.
        let selector = ast.arena.get(*lhs);
        let ConstraintExpr::Cell(cell) = selector else {
            panic!("Expected Cell for q, got {:?}", selector);
        };
        assert_eq!(cell.col_idx, 2);
        assert!(!cell.next_row);

        let body = ast.arena.get(*rhs);
        let ConstraintExpr::Add(add_lhs, add_rhs) = body else {
            panic!("Expected Add, got {:?}", body);
        };

        // First addend: column 0, next row.
        let next_a = ast.arena.get(*add_lhs);
        let ConstraintExpr::Cell(cell) = next_a else {
            panic!("Expected Cell for next_a, got {:?}", next_a);
        };
        assert_eq!(cell.col_idx, 0);
        assert!(cell.next_row);

        // Second addend: column 1, current row.
        let curr_b = ast.arena.get(*add_rhs);
        let ConstraintExpr::Cell(cell) = curr_b else {
            panic!("Expected Cell for b, got {:?}", curr_b);
        };
        assert_eq!(cell.col_idx, 1);
        assert!(!cell.next_row);
    }

    /// Labels come out of `build` in registration order.
    #[test]
    fn labels_round_trip_through_build() {
        let cs = ConstraintSystem::<F>::new();
        let (a, b) = (cs.col(0), cs.col(1));

        cs.constrain(a + b);
        cs.constrain_named("transition", a * b);
        cs.assert_boolean(a);

        let ast = cs.build();

        assert_eq!(ast.roots.len(), 3);
        assert_eq!(ast.labels.len(), 3);

        let expected = [None, Some("transition"), Some("boolean")];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(ast.labels[idx], *want);
        }
    }

    /// Merging two ASTs keeps the labels of both, first AST first.
    #[test]
    fn labels_preserved_through_merge() {
        let cs1 = ConstraintSystem::<F>::new();
        cs1.constrain_named("first", cs1.col(0));

        let mut ast1 = cs1.build();

        let cs2 = ConstraintSystem::<F>::new();
        cs2.constrain(cs2.col(0));
        cs2.constrain_named("second", cs2.col(1));

        let ast2 = cs2.build();

        ast1.merge(ast2);

        assert_eq!(ast1.roots.len(), 3);
        assert_eq!(ast1.labels.len(), 3);

        let expected = [Some("first"), None, Some("second")];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(ast1.labels[idx], *want);
        }
    }

    /// Labels survive `from_ast`, and new ones are appended after them.
    #[test]
    fn labels_preserved_through_from_ast() {
        let cs = ConstraintSystem::<F>::new();
        cs.constrain_named("original", cs.col(0));

        let ast = cs.build();

        let cs2 = ConstraintSystem::from_ast(ast);
        cs2.constrain_named("added", cs2.col(1));

        let ast2 = cs2.build();

        assert_eq!(ast2.labels.len(), 2);

        let expected = [Some("original"), Some("added")];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(ast2.labels[idx], *want);
        }
    }

    /// Each built-in gadget tags its root with a fixed label.
    #[test]
    fn builtin_gadgets_have_labels() {
        let cs = ConstraintSystem::<F>::new();

        let (a, b) = (cs.col(0), cs.col(1));

        cs.assert_boolean(a);
        cs.assert_zero_when(a, b);
        cs.assert_one_hot(&[a, b]);

        let ast = cs.build();

        assert_eq!(ast.labels.len(), 3);

        let expected = [Some("boolean"), Some("zero_when"), Some("one_hot")];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(ast.labels[idx], *want);
        }
    }
}