Skip to main content

hekate_program/constraint/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! SYMBOLIC CONSTRAINTS DEFINITION
19
20use crate::ProgramCell;
21use alloc::vec;
22use alloc::vec::Vec;
23use hashbrown::HashMap;
24use hekate_core::errors::Error;
25use hekate_math::{Flat, HardwareField, TowerField};
26
27pub mod builder;
28
29/// Represents a single term in a polynomial constraint.
30/// Form: `coeff * product(cells)`
31/// Example: `5 * x_curr * y_next`
32#[derive(Clone, Debug)]
33pub struct ConstraintTerm<F> {
34    pub coeff: F,
35    pub poly_ind: Vec<ProgramCell>, // Multiplicands (variables)
36}
37
38impl<F: TowerField> ConstraintTerm<F> {
39    pub fn new(coeff: F, cells: Vec<ProgramCell>) -> Self {
40        Self {
41            coeff,
42            poly_ind: cells,
43        }
44    }
45}
46
47/// Represents a full constraint equation:
48/// sum(terms) == 0.
49#[derive(Clone, Debug)]
50pub struct Constraint<F> {
51    pub terms: Vec<ConstraintTerm<F>>,
52}
53
54impl<F: TowerField> Constraint<F> {
55    pub fn new(terms: Vec<ConstraintTerm<F>>) -> Self {
56        Self { terms }
57    }
58}
59
/// Where the value that a boundary constraint pins
/// `Trace(row_idx, col_idx)` to comes from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BoundaryTarget<F> {
    /// Read from `instance.public_inputs[idx]`.
    /// Only valid for the main program:
    /// `ChipletDef::from_air` rejects this
    /// variant on chiplets.
    PublicInput(usize),

    /// An inline field value. Chiplets must use
    /// this variant; main programs may use it.
    Constant(F),
}
73
74/// Pins a single trace cell to a
75/// target value at a specific row.
76#[derive(Clone, Debug)]
77pub struct BoundaryConstraint<F> {
78    pub col_idx: usize,
79    pub row_idx: usize,
80    pub target: BoundaryTarget<F>,
81}
82
83impl<F> BoundaryConstraint<F> {
84    pub fn with_public_input(col_idx: usize, row_idx: usize, public_input_idx: usize) -> Self {
85        Self {
86            col_idx,
87            row_idx,
88            target: BoundaryTarget::PublicInput(public_input_idx),
89        }
90    }
91
92    pub fn with_constant(col_idx: usize, row_idx: usize, val: F) -> Self {
93        Self {
94            col_idx,
95            row_idx,
96            target: BoundaryTarget::Constant(val),
97        }
98    }
99}
100
101impl<F: TowerField> BoundaryConstraint<F> {
102    /// Read the pin value out of `target`.
103    /// `Constant` returns the literal;
104    /// `PublicInput(idx)` reads
105    /// `instance.public_inputs[idx]`.
106    pub fn resolve_target(
107        &self,
108        instance: &crate::ProgramInstance<F>,
109    ) -> hekate_core::errors::Result<F> {
110        match &self.target {
111            BoundaryTarget::Constant(v) => Ok(*v),
112            BoundaryTarget::PublicInput(idx) => {
113                instance.public_input(*idx).ok_or(Error::Protocol {
114                    protocol: "boundary",
115                    message: "public_input_idx out of bounds",
116                })
117            }
118        }
119    }
120
121    /// Constant values are load-bearing:
122    /// a malicious prover could swap them between
123    /// sessions while keeping the chiplet root valid
124    /// unless the transcript binds them.
125    pub fn absorb_into<H: hekate_crypto::Hasher>(
126        &self,
127        transcript: &mut hekate_crypto::transcript::Transcript<H>,
128    ) {
129        transcript.append_u64(b"chiplet_bnd_col", self.col_idx as u64);
130        transcript.append_u64(b"chiplet_bnd_row", self.row_idx as u64);
131
132        match &self.target {
133            BoundaryTarget::PublicInput(idx) => {
134                transcript.append_u64(b"chiplet_bnd_kind", 0);
135                transcript.append_u64(b"chiplet_bnd_pub", *idx as u64);
136            }
137            BoundaryTarget::Constant(v) => {
138                transcript.append_u64(b"chiplet_bnd_kind", 1);
139                transcript.append_field(b"chiplet_bnd_const", *v);
140            }
141        }
142    }
143}
144
145// =================================================================
146// CONSTRAINT IR
147// =================================================================
148
/// Handle to a node stored in a `ConstraintArena`.
/// The wrapped `u32` is the node's position in that
/// arena. The type is `Copy`, so pass it by value.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct ExprId(pub u32);
152
153/// A single node in the constraint AST-DAG.
154/// All variants are algebraic operations over GF(2^128).
155#[derive(Clone, Debug)]
156pub enum ConstraintExpr<F> {
157    /// A trace cell reference (column + current/next row).
158    Cell(ProgramCell),
159
160    /// A field constant.
161    Const(F),
162
163    /// Addition of two sub-expressions.
164    /// In GF(2^N) this is XOR.
165    Add(ExprId, ExprId),
166
167    /// Multiplication of two sub-expressions.
168    Mul(ExprId, ExprId),
169
170    /// Scalar multiplication:
171    /// coeff * expr.
172    ///
173    /// Avoids creating a Const node + Mul
174    /// pair for the common case.
175    Scale(F, ExprId),
176
177    /// Sum of N sub-expressions.
178    /// Exists specifically for Theta-style linear
179    /// combinations to avoid deep Add chains.
180    /// Evaluates to sum of children[i].
181    Sum(Vec<ExprId>),
182}
183
184/// Arena-allocated constraint DAG.
185/// Nodes reference each other by
186/// `ExprId` (index into `nodes`).
187///
188/// Cell nodes are automatically deduplicated:
189/// calling `cell()` twice with the same `ProgramCell`
190/// returns the same `ExprId`. This is mandatory, the
191/// downstream compiler maps ExprId to poly index,
192/// so duplicate cells would create duplicate polys
193/// in VirtualPoly and bloat Sumcheck evaluation.
194pub struct ConstraintArena<F> {
195    nodes: Vec<ConstraintExpr<F>>,
196
197    /// Dedup cache for Cell nodes.
198    /// Same ProgramCell -> same ExprId.
199    cell_cache: HashMap<ProgramCell, ExprId>,
200}
201
202impl<F: TowerField> Default for ConstraintArena<F> {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208impl<F: TowerField> ConstraintArena<F> {
209    pub fn new() -> Self {
210        Self {
211            nodes: Vec::new(),
212            cell_cache: HashMap::new(),
213        }
214    }
215
216    /// Allocate a new expression node.
217    /// Returns its ID.
218    pub fn alloc(&mut self, expr: ConstraintExpr<F>) -> ExprId {
219        let id = ExprId(self.nodes.len() as u32);
220        self.nodes.push(expr);
221
222        id
223    }
224
225    /// Read a node by ID.
226    pub fn get(&self, id: ExprId) -> &ConstraintExpr<F> {
227        &self.nodes[id.0 as usize]
228    }
229
230    /// Total number of nodes.
231    pub fn len(&self) -> usize {
232        self.nodes.len()
233    }
234
235    pub fn is_empty(&self) -> bool {
236        self.nodes.is_empty()
237    }
238
239    /// Shift all Cell node col_idx by `offset`.
240    /// Used when embedding a chiplet's AST into
241    /// a combined program where column indices
242    /// are offset.
243    pub fn shift_cells(&mut self, offset: usize) {
244        for node in &mut self.nodes {
245            if let ConstraintExpr::Cell(cell) = node {
246                cell.col_idx += offset;
247            }
248        }
249
250        let old_cache = core::mem::take(&mut self.cell_cache);
251        for (mut cell, id) in old_cache {
252            cell.col_idx += offset;
253            self.cell_cache.insert(cell, id);
254        }
255    }
256
257    /// Create or reuse a cell reference.
258    /// Automatically deduplicates:
259    /// same ProgramCell -> same ExprId.
260    pub fn cell(&mut self, cell: ProgramCell) -> ExprId {
261        if let Some(&id) = self.cell_cache.get(&cell) {
262            return id;
263        }
264
265        let id = self.alloc(ConstraintExpr::Cell(cell));
266        self.cell_cache.insert(cell, id);
267
268        id
269    }
270
271    /// Create a constant.
272    pub fn constant(&mut self, val: F) -> ExprId {
273        self.alloc(ConstraintExpr::Const(val))
274    }
275
276    /// a + b
277    pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
278        self.alloc(ConstraintExpr::Add(a, b))
279    }
280
281    /// a * b
282    pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
283        self.alloc(ConstraintExpr::Mul(a, b))
284    }
285
286    /// coeff * a
287    pub fn scale(&mut self, coeff: F, a: ExprId) -> ExprId {
288        self.alloc(ConstraintExpr::Scale(coeff, a))
289    }
290
291    /// Sum of multiple expressions.
292    pub fn sum(&mut self, children: Vec<ExprId>) -> ExprId {
293        self.alloc(ConstraintExpr::Sum(children))
294    }
295}
296
297/// A full constraint system in AST form.
298/// `roots` are the top-level constraint expressions.
299/// Each root must evaluate to 0 for a valid trace.
300pub struct ConstraintAst<F> {
301    pub arena: ConstraintArena<F>,
302    pub roots: Vec<ExprId>,
303    pub labels: Vec<Option<&'static str>>,
304}
305
306impl<F: TowerField> Clone for ConstraintArena<F> {
307    fn clone(&self) -> Self {
308        Self {
309            nodes: self.nodes.clone(),
310            cell_cache: self.cell_cache.clone(),
311        }
312    }
313}
314
315impl<F: TowerField> Clone for ConstraintAst<F> {
316    fn clone(&self) -> Self {
317        Self {
318            arena: self.arena.clone(),
319            roots: self.roots.clone(),
320            labels: self.labels.clone(),
321        }
322    }
323}
324
325impl<F: TowerField> ConstraintAst<F> {
326    /// Maximum polynomial degree
327    /// across all constraint roots.
328    pub fn max_degree(&self) -> usize {
329        if self.arena.is_empty() {
330            return 0;
331        }
332
333        let n = self.arena.len();
334        let mut deg: Vec<usize> = Vec::with_capacity(n);
335
336        for i in 0..n {
337            let d = match self.arena.get(ExprId(i as u32)) {
338                ConstraintExpr::Cell(_) => 1,
339                ConstraintExpr::Const(_) => 0,
340                ConstraintExpr::Add(a, b) => deg[a.0 as usize].max(deg[b.0 as usize]),
341                ConstraintExpr::Mul(a, b) => deg[a.0 as usize] + deg[b.0 as usize],
342                ConstraintExpr::Scale(_, a) => deg[a.0 as usize],
343                ConstraintExpr::Sum(children) => children
344                    .iter()
345                    .map(|c| deg[c.0 as usize])
346                    .max()
347                    .unwrap_or(0),
348            };
349
350            deg.push(d);
351        }
352
353        self.roots
354            .iter()
355            .map(|r| deg[r.0 as usize])
356            .max()
357            .unwrap_or(0)
358    }
359
360    /// Hardware-basis `Const`/`Scale` coefficients
361    /// indexed by `ExprId`; other slots hold zero.
362    pub fn precompute_hardware_consts(&self) -> Vec<Flat<F>>
363    where
364        F: HardwareField,
365    {
366        let n = self.arena.len();
367
368        let mut consts: Vec<Flat<F>> = Vec::with_capacity(n);
369        for i in 0..n {
370            let c = match self.arena.get(ExprId(i as u32)) {
371                ConstraintExpr::Const(k) => k.to_hardware(),
372                ConstraintExpr::Scale(k, _) => k.to_hardware(),
373                _ => Flat::from_raw(F::ZERO),
374            };
375
376            consts.push(c);
377        }
378
379        consts
380    }
381
382    /// Evaluate each constraint
383    /// root at a single point.
384    pub fn evaluate(&self, current_row: &[Flat<F>], next_row: &[Flat<F>]) -> Vec<Flat<F>>
385    where
386        F: HardwareField,
387    {
388        let consts = self.precompute_hardware_consts();
389        let mut buf: Vec<Flat<F>> = Vec::with_capacity(self.arena.len());
390
391        self.evaluate_into(&consts, current_row, next_row, &mut buf);
392
393        self.roots.iter().map(|r| buf[r.0 as usize]).collect()
394    }
395
396    /// `consts` must come from the same arena
397    /// `self.precompute_hardware_consts()`.
398    pub fn evaluate_into(
399        &self,
400        consts: &[Flat<F>],
401        current_row: &[Flat<F>],
402        next_row: &[Flat<F>],
403        buf: &mut Vec<Flat<F>>,
404    ) where
405        F: HardwareField,
406    {
407        buf.clear();
408
409        for (i, &coeff) in consts.iter().enumerate() {
410            let v = match self.arena.get(ExprId(i as u32)) {
411                ConstraintExpr::Cell(cell) => {
412                    if cell.next_row {
413                        next_row[cell.col_idx]
414                    } else {
415                        current_row[cell.col_idx]
416                    }
417                }
418                ConstraintExpr::Const(_) => coeff,
419                ConstraintExpr::Add(a, b) => buf[a.0 as usize] + buf[b.0 as usize],
420                ConstraintExpr::Mul(a, b) => buf[a.0 as usize] * buf[b.0 as usize],
421                ConstraintExpr::Scale(_, a) => coeff * buf[a.0 as usize],
422                ConstraintExpr::Sum(children) => {
423                    let mut s = Flat::from_raw(F::ZERO);
424                    for c in children {
425                        s += buf[c.0 as usize];
426                    }
427
428                    s
429                }
430            };
431
432            buf.push(v);
433        }
434    }
435
436    /// Merge another constraint AST into this one.
437    pub fn merge(&mut self, other: ConstraintAst<F>) {
438        let mut id_map: Vec<ExprId> = Vec::with_capacity(other.arena.len());
439        for node in other.arena.nodes {
440            let new_id = match node {
441                ConstraintExpr::Cell(cell) => self.arena.cell(cell),
442                ConstraintExpr::Const(val) => self.arena.constant(val),
443                ConstraintExpr::Add(a, b) => {
444                    self.arena.add(id_map[a.0 as usize], id_map[b.0 as usize])
445                }
446                ConstraintExpr::Mul(a, b) => {
447                    self.arena.mul(id_map[a.0 as usize], id_map[b.0 as usize])
448                }
449                ConstraintExpr::Scale(coeff, inner) => {
450                    self.arena.scale(coeff, id_map[inner.0 as usize])
451                }
452                ConstraintExpr::Sum(children) => {
453                    let remapped: Vec<ExprId> =
454                        children.into_iter().map(|c| id_map[c.0 as usize]).collect();
455                    self.arena.sum(remapped)
456                }
457            };
458
459            id_map.push(new_id);
460        }
461
462        for (root, label) in other.roots.into_iter().zip(other.labels) {
463            self.roots.push(id_map[root.0 as usize]);
464            self.labels.push(label);
465        }
466    }
467
468    /// Convert AST to flat `Vec<Constraint<F>>`.
469    /// Expands the DAG into sum-of-products form.
470    pub fn to_constraints(&self) -> Vec<Constraint<F>> {
471        /// A flat term:
472        /// coefficient × product of cells.
473        type FlatTerm<F> = (F, Vec<ProgramCell>);
474
475        fn expand<F: TowerField>(
476            arena: &ConstraintArena<F>,
477            id: ExprId,
478            cache: &mut Vec<Option<Vec<FlatTerm<F>>>>,
479        ) -> Vec<FlatTerm<F>> {
480            if let Some(ref cached) = cache[id.0 as usize] {
481                return cached.clone();
482            }
483
484            let result = match arena.get(id) {
485                ConstraintExpr::Cell(cell) => {
486                    vec![(F::ONE, vec![*cell])]
487                }
488                ConstraintExpr::Const(k) => {
489                    vec![(*k, vec![])]
490                }
491                ConstraintExpr::Add(a, b) => {
492                    let mut terms = expand(arena, *a, cache);
493                    terms.extend(expand(arena, *b, cache));
494
495                    terms
496                }
497                ConstraintExpr::Mul(a, b) => {
498                    let left = expand(arena, *a, cache);
499                    let right = expand(arena, *b, cache);
500
501                    let mut terms = Vec::with_capacity(left.len() * right.len());
502                    for (lc, lp) in &left {
503                        for (rc, rp) in &right {
504                            let coeff = *lc * *rc;
505                            let mut cells = lp.clone();
506
507                            cells.extend_from_slice(rp);
508                            terms.push((coeff, cells));
509                        }
510                    }
511
512                    terms
513                }
514                ConstraintExpr::Scale(k, a) => {
515                    let inner = expand(arena, *a, cache);
516                    inner
517                        .into_iter()
518                        .map(|(c, cells)| (*k * c, cells))
519                        .collect()
520                }
521                ConstraintExpr::Sum(children) => {
522                    let mut terms = Vec::new();
523                    for child in children {
524                        terms.extend(expand(arena, *child, cache));
525                    }
526
527                    terms
528                }
529            };
530
531            cache[id.0 as usize] = Some(result.clone());
532
533            result
534        }
535
536        let n = self.arena.len();
537        let mut cache: Vec<Option<Vec<FlatTerm<F>>>> = vec![None; n];
538
539        self.roots
540            .iter()
541            .map(|root| {
542                let flat_terms = expand(&self.arena, *root, &mut cache);
543                Constraint::new(
544                    flat_terms
545                        .into_iter()
546                        .map(|(coeff, cells)| ConstraintTerm::new(coeff, cells))
547                        .collect(),
548                )
549            })
550            .collect()
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constraint::ConstraintExpr;
    use crate::constraint::builder::ConstraintSystem;
    use crate::{Air, Program};
    use hekate_core::trace::ColumnType;
    use hekate_math::{Block128, Flat};

    type F = Block128;

    /// Fibonacci-style AIR over columns 0 and 1,
    /// gated by the selector bit in column 2.
    #[derive(Clone)]
    struct TestFibProgram;

    impl Air<F> for TestFibProgram {
        fn num_columns(&self) -> usize {
            3
        }

        fn column_layout(&self) -> &[ColumnType] {
            &[ColumnType::B32, ColumnType::B32, ColumnType::Bit]
        }

        fn constraint_ast(&self) -> ConstraintAst<F> {
            let cs = ConstraintSystem::<F>::new();

            let [a, b, q] = [cs.col(0), cs.col(1), cs.col(2)];
            let [na, nb] = [cs.next(0), cs.next(1)];

            cs.constrain(q * (na + b));
            cs.constrain(q * (nb + a + b));

            cs.build()
        }
    }

    impl Program<F> for TestFibProgram {}

    #[test]
    fn default_constraint_ast_produces_correct_roots() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        // 2 constraints -> 2 roots
        assert_eq!(ast.roots.len(), 2);

        // The exact node layout is up to the builder;
        // only check that something was allocated and
        // that each root has a sensible shape.
        assert!(!ast.arena.is_empty());

        // Each root is a Sum of terms or a single product (Mul).
        for &root in &ast.roots {
            let node = ast.arena.get(root);
            match node {
                ConstraintExpr::Sum(children) => {
                    assert!(!children.is_empty());
                }
                ConstraintExpr::Mul(_, _) => {
                    // single-term constraint would be a Mul
                }
                _ => panic!("Root should be Sum or Mul, got {:?}", node),
            }
        }
    }

    #[test]
    fn cell_dedup_works() {
        let mut arena = ConstraintArena::<F>::new();

        let cell_a = ProgramCell::current(0);
        let cell_b = ProgramCell::current(0);
        let cell_c = ProgramCell::next(0);

        let id_a = arena.cell(cell_a);
        let id_b = arena.cell(cell_b);
        let id_c = arena.cell(cell_c);

        // Same cell -> same ExprId
        assert_eq!(id_a, id_b);
        // Different cell -> different ExprId
        assert_ne!(id_a, id_c);
        // Only 2 nodes allocated, not 3
        assert_eq!(arena.len(), 2);
    }

    #[test]
    fn dag_sharing_reduces_node_count() {
        let mut arena = ConstraintArena::<F>::new();

        // Build a shared sub-expression: theta = a + b + c
        let a = arena.cell(ProgramCell::current(0));
        let b = arena.cell(ProgramCell::current(1));
        let c = arena.cell(ProgramCell::current(2));
        let theta = arena.sum(vec![a, b, c]);

        // Use theta in two different constraints (DAG sharing)
        let d = arena.cell(ProgramCell::current(3));
        let expr1 = arena.mul(theta, d);
        let expr2 = arena.mul(theta, a); // reuses both theta and a

        let dag_node_count = arena.len();
        // 4 cells + 1 sum + 2 muls = 7 nodes
        assert_eq!(dag_node_count, 7);

        // As a tree, (a + b + c) * d and (a + b + c) * a need
        // 6 nodes each (4 cells, 1 sum, 1 mul): 12 in total.
        let tree_node_count = 12;
        assert!(dag_node_count < tree_node_count);

        // Verify both expressions reference the same theta
        match arena.get(expr1) {
            ConstraintExpr::Mul(lhs, _) => assert_eq!(*lhs, theta),
            _ => panic!("Expected Mul"),
        }

        match arena.get(expr2) {
            ConstraintExpr::Mul(lhs, rhs) => {
                assert_eq!(*lhs, theta);
                assert_eq!(*rhs, a);
            }
            _ => panic!("Expected Mul"),
        }
    }

    #[test]
    fn default_constraint_ast_node_count_matches_flat() {
        let program = TestFibProgram;
        let flat = program.constraints();
        let ast = program.constraint_ast();

        // Count total cells across flat constraints
        let mut flat_cell_count = 0;
        for c in &flat {
            for t in &c.terms {
                flat_cell_count += t.poly_ind.len();
            }
        }

        // With dedup, AST cell nodes <= flat cell references
        let ast_cell_count = ast
            .arena
            .nodes
            .iter()
            .filter(|n| matches!(n, ConstraintExpr::Cell(_)))
            .count();

        assert!(ast_cell_count <= flat_cell_count);

        // Fib shares ProgramCell::current(2) across all terms.
        // 5 total term-cells reference current(2), but dedup
        // means only 1 Cell node for it.
        // Unique cells: current(0), current(1), current(2), next(0), next(1) = 5
        assert_eq!(ast_cell_count, 5);
    }

    #[test]
    fn empty_constraint_produces_empty_ast() {
        #[derive(Clone)]
        struct EmptyProgram;

        impl Air<F> for EmptyProgram {
            fn num_columns(&self) -> usize {
                0
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                ConstraintSystem::<F>::new().build()
            }
        }

        impl Program<F> for EmptyProgram {}

        let ast = EmptyProgram.constraint_ast();
        assert!(ast.roots.is_empty());
        assert!(ast.arena.is_empty());
    }

    #[test]
    fn single_term_constraint_no_sum_wrapper() {
        #[derive(Clone)]
        struct SingleTermProgram;

        impl Air<F> for SingleTermProgram {
            fn num_columns(&self) -> usize {
                2
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[ColumnType::B128, ColumnType::B128]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                let cs = ConstraintSystem::<F>::new();
                cs.constrain(cs.col(0) * cs.col(1));

                cs.build()
            }
        }

        impl Program<F> for SingleTermProgram {}

        let ast = SingleTermProgram.constraint_ast();
        assert_eq!(ast.roots.len(), 1);

        // A single product term is rooted at a Mul, with no Sum wrapper.
        match ast.arena.get(ast.roots[0]) {
            ConstraintExpr::Mul(_, _) => {} // correct
            other => panic!("Expected Mul for single-term, got {:?}", other),
        }
    }

    // =========================================================
    // ConstraintAst method tests
    // =========================================================

    #[test]
    fn max_degree_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        // Both Fib constraints, q * (na + b) and q * (nb + a + b),
        // expand to degree-2 terms. The AST's max degree must equal
        // the largest per-term cell count of the flat form.
        let flat = program.constraints();
        let flat_max = flat
            .iter()
            .flat_map(|c| c.terms.iter())
            .map(|t| t.poly_ind.len())
            .max()
            .unwrap_or(0);

        assert_eq!(ast.max_degree(), flat_max);
    }

    #[test]
    fn max_degree_empty() {
        let ast = ConstraintAst::<F> {
            arena: ConstraintArena::new(),
            roots: Vec::new(),
            labels: Vec::new(),
        };
        assert_eq!(ast.max_degree(), 0);
    }

    #[test]
    fn max_degree_builder_mul_chain() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        let c = cs.col(2);

        // a * b * c = degree 3
        cs.constrain(a * b * c);
        // a + b = degree 1
        cs.constrain(a + b);

        let ast = cs.build();
        assert_eq!(ast.max_degree(), 3);
    }

    #[test]
    fn to_constraints_roundtrip_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();
        let flat_from_ast = ast.to_constraints();
        let flat_direct = program.constraints();

        // Same number of constraints
        assert_eq!(flat_from_ast.len(), flat_direct.len());

        // Same number of terms per constraint
        for (a, d) in flat_from_ast.iter().zip(flat_direct.iter()) {
            assert_eq!(a.terms.len(), d.terms.len());
        }
    }

    #[test]
    fn to_constraints_from_builder() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        // a + b = 0  →  should produce 2 flat terms: (ONE, [a]) + (ONE, [b])
        cs.constrain(a + b);

        let ast = cs.build();
        let flat = ast.to_constraints();

        assert_eq!(flat.len(), 1);
        assert_eq!(flat[0].terms.len(), 2);

        // Both terms should have exactly 1 cell each
        for term in &flat[0].terms {
            assert_eq!(term.coeff, F::ONE);
            assert_eq!(term.poly_ind.len(), 1);
        }
    }

    #[test]
    fn evaluate_simple_constraint() {
        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        // a + b = 0
        cs.constrain(a + b);

        let ast = cs.build();

        // Evaluate at a=3, b=3 -> 3 + 3 = 0 in GF(2^k) (XOR)
        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(3u128)),
        ];
        let next = vec![Flat::from_raw(F::ZERO); 2];

        let evals = ast.evaluate(&current, &next);
        assert_eq!(evals.len(), 1);
        assert_eq!(evals[0], Flat::from_raw(F::ZERO)); // 3 XOR 3 = 0

        // Evaluate at a=3, b=5 -> 3 XOR 5 = 6 ≠ 0
        let current2 = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let evals2 = ast.evaluate(&current2, &next);
        assert_ne!(evals2[0], Flat::from_raw(F::ZERO));
    }

    #[test]
    fn evaluate_into_matches_evaluate() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);
        let na = cs.next(0);

        cs.constrain(a + b);
        cs.constrain(a * b);
        cs.constrain(na + a);

        let ast = cs.build();
        let zero = Flat::from_raw(F::ZERO);

        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let next = vec![Flat::from_raw(F::from(7u128)), zero];

        let expected = ast.evaluate(&current, &next);

        let consts = ast.precompute_hardware_consts();

        let mut buf = Vec::new();
        ast.evaluate_into(&consts, &current, &next, &mut buf);

        for (i, root) in ast.roots.iter().enumerate() {
            assert_eq!(buf[root.0 as usize], expected[i]);
        }
    }
}