Skip to main content

hekate_program/constraint/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// This file is part of the hekate project.
3// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
4// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>. All rights reserved.
5//
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10//     http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! SYMBOLIC CONSTRAINTS DEFINITION
19
20use crate::ProgramCell;
21use alloc::vec;
22use alloc::vec::Vec;
23use hashbrown::HashMap;
24use hekate_core::errors::Error;
25use hekate_math::{Flat, HardwareField, TowerField};
26
27pub mod builder;
28
29/// Represents a single term in a polynomial constraint.
30/// Form: `coeff * product(cells)`
31/// Example: `5 * x_curr * y_next`
32#[derive(Clone, Debug)]
33pub struct ConstraintTerm<F> {
34    pub coeff: F,
35    pub poly_ind: Vec<ProgramCell>, // Multiplicands (variables)
36}
37
38impl<F: TowerField> ConstraintTerm<F> {
39    pub fn new(coeff: F, cells: Vec<ProgramCell>) -> Self {
40        Self {
41            coeff,
42            poly_ind: cells,
43        }
44    }
45}
46
47/// Represents a full constraint equation:
48/// sum(terms) == 0.
49#[derive(Clone, Debug)]
50pub struct Constraint<F> {
51    pub terms: Vec<ConstraintTerm<F>>,
52}
53
54impl<F: TowerField> Constraint<F> {
55    pub fn new(terms: Vec<ConstraintTerm<F>>) -> Self {
56        Self { terms }
57    }
58}
59
/// Source of the value a boundary constraint
/// pins `Trace(row_idx, col_idx)` to.
///
/// Equality is derived: two targets are equal when they
/// are the same variant carrying equal payloads.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BoundaryTarget<F> {
    /// `instance.public_inputs[idx]`. Main-program
    /// use only; `ChipletDef::from_air` rejects
    /// this variant on chiplets.
    PublicInput(usize),

    /// Literal field value. Required for
    /// chiplets, optional for main programs.
    Constant(F),
}
73
74/// Pins a single trace cell to a
75/// target value at a specific row.
76#[derive(Clone, Debug)]
77pub struct BoundaryConstraint<F> {
78    pub col_idx: usize,
79    pub row_idx: usize,
80    pub target: BoundaryTarget<F>,
81}
82
83impl<F> BoundaryConstraint<F> {
84    pub fn with_public_input(col_idx: usize, row_idx: usize, public_input_idx: usize) -> Self {
85        Self {
86            col_idx,
87            row_idx,
88            target: BoundaryTarget::PublicInput(public_input_idx),
89        }
90    }
91
92    pub fn with_constant(col_idx: usize, row_idx: usize, val: F) -> Self {
93        Self {
94            col_idx,
95            row_idx,
96            target: BoundaryTarget::Constant(val),
97        }
98    }
99}
100
101impl<F: TowerField> BoundaryConstraint<F> {
102    /// Read the pin value out of `target`.
103    /// `Constant` returns the literal;
104    /// `PublicInput(idx)` reads
105    /// `instance.public_inputs[idx]`.
106    pub fn resolve_target(
107        &self,
108        instance: &crate::ProgramInstance<F>,
109    ) -> hekate_core::errors::Result<F> {
110        match &self.target {
111            BoundaryTarget::Constant(v) => Ok(*v),
112            BoundaryTarget::PublicInput(idx) => {
113                instance.public_input(*idx).ok_or(Error::Protocol {
114                    protocol: "boundary",
115                    message: "public_input_idx out of bounds",
116                })
117            }
118        }
119    }
120
121    /// Constant values are load-bearing:
122    /// a malicious prover could swap them between
123    /// sessions while keeping the chiplet root valid
124    /// unless the transcript binds them.
125    pub fn absorb_into<H: hekate_crypto::Hasher>(
126        &self,
127        transcript: &mut hekate_crypto::transcript::Transcript<H>,
128    ) {
129        transcript.append_u64(b"chiplet_bnd_col", self.col_idx as u64);
130        transcript.append_u64(b"chiplet_bnd_row", self.row_idx as u64);
131
132        match &self.target {
133            BoundaryTarget::PublicInput(idx) => {
134                transcript.append_u64(b"chiplet_bnd_kind", 0);
135                transcript.append_u64(b"chiplet_bnd_pub", *idx as u64);
136            }
137            BoundaryTarget::Constant(v) => {
138                transcript.append_u64(b"chiplet_bnd_kind", 1);
139                transcript.append_field(b"chiplet_bnd_const", *v);
140            }
141        }
142    }
143}
144
145// =================================================================
146// CONSTRAINT IR
147// =================================================================
148
/// Index into a `ConstraintArena`. Cheap to copy.
///
/// The wrapped `u32` is the node's position in the arena's
/// node list, so an id is only meaningful relative to the
/// arena it was obtained from.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ExprId(pub u32);
152
/// A single node in the constraint AST-DAG.
/// Leaves are cells and constants; every other variant
/// is an algebraic operation over GF(2^128) whose operands
/// are referenced by `ExprId`.
#[derive(Clone, Debug)]
pub enum ConstraintExpr<F> {
    /// A trace cell reference (column + current/next row).
    Cell(ProgramCell),

    /// A field constant.
    Const(F),

    /// Addition of two sub-expressions.
    /// In GF(2^N) this is XOR.
    Add(ExprId, ExprId),

    /// Multiplication of two sub-expressions.
    Mul(ExprId, ExprId),

    /// Scalar multiplication:
    /// coeff * expr.
    ///
    /// Avoids creating a Const node + Mul
    /// pair for the common case.
    Scale(F, ExprId),

    /// Sum of N sub-expressions.
    /// Exists specifically for Theta-style linear
    /// combinations to avoid deep Add chains.
    /// Evaluates to sum of children[i];
    /// an empty child list evaluates to zero.
    Sum(Vec<ExprId>),
}
183
/// Arena-allocated constraint DAG.
/// Nodes reference each other by
/// `ExprId` (index into `nodes`).
///
/// Cell nodes are automatically deduplicated:
/// calling `cell()` twice with the same `ProgramCell`
/// returns the same `ExprId`. This is mandatory, the
/// downstream compiler maps ExprId to poly index,
/// so duplicate cells would create duplicate polys
/// in VirtualPoly and bloat Sumcheck evaluation.
pub struct ConstraintArena<F> {
    /// Node storage; `ExprId(i)` refers to `nodes[i]`.
    nodes: Vec<ConstraintExpr<F>>,

    /// Dedup cache for Cell nodes.
    /// Same ProgramCell -> same ExprId.
    cell_cache: HashMap<ProgramCell, ExprId>,
}
201
202impl<F: TowerField> Default for ConstraintArena<F> {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208impl<F: TowerField> ConstraintArena<F> {
209    pub fn new() -> Self {
210        Self {
211            nodes: Vec::new(),
212            cell_cache: HashMap::new(),
213        }
214    }
215
216    /// Allocate a new expression node.
217    /// Returns its ID.
218    pub fn alloc(&mut self, expr: ConstraintExpr<F>) -> ExprId {
219        let id = ExprId(self.nodes.len() as u32);
220        self.nodes.push(expr);
221
222        id
223    }
224
225    /// Read a node by ID.
226    pub fn get(&self, id: ExprId) -> &ConstraintExpr<F> {
227        &self.nodes[id.0 as usize]
228    }
229
230    /// Total number of nodes.
231    pub fn len(&self) -> usize {
232        self.nodes.len()
233    }
234
235    pub fn is_empty(&self) -> bool {
236        self.nodes.is_empty()
237    }
238
239    /// Shift all Cell node col_idx by `offset`.
240    /// Used when embedding a chiplet's AST into
241    /// a combined program where column indices
242    /// are offset.
243    pub fn shift_cells(&mut self, offset: usize) {
244        for node in &mut self.nodes {
245            if let ConstraintExpr::Cell(cell) = node {
246                cell.col_idx += offset;
247            }
248        }
249
250        let old_cache = core::mem::take(&mut self.cell_cache);
251        for (mut cell, id) in old_cache {
252            cell.col_idx += offset;
253            self.cell_cache.insert(cell, id);
254        }
255    }
256
257    /// Create or reuse a cell reference.
258    /// Automatically deduplicates:
259    /// same ProgramCell -> same ExprId.
260    pub fn cell(&mut self, cell: ProgramCell) -> ExprId {
261        if let Some(&id) = self.cell_cache.get(&cell) {
262            return id;
263        }
264
265        let id = self.alloc(ConstraintExpr::Cell(cell));
266        self.cell_cache.insert(cell, id);
267
268        id
269    }
270
271    /// Create a constant.
272    pub fn constant(&mut self, val: F) -> ExprId {
273        self.alloc(ConstraintExpr::Const(val))
274    }
275
276    /// a + b
277    pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
278        self.alloc(ConstraintExpr::Add(a, b))
279    }
280
281    /// a * b
282    pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
283        self.alloc(ConstraintExpr::Mul(a, b))
284    }
285
286    /// coeff * a
287    pub fn scale(&mut self, coeff: F, a: ExprId) -> ExprId {
288        self.alloc(ConstraintExpr::Scale(coeff, a))
289    }
290
291    /// Sum of multiple expressions.
292    pub fn sum(&mut self, children: Vec<ExprId>) -> ExprId {
293        self.alloc(ConstraintExpr::Sum(children))
294    }
295}
296
/// A full constraint system in AST form.
/// `roots` are the top-level constraint expressions.
/// Each root must evaluate to 0 for a valid trace.
pub struct ConstraintAst<F> {
    /// Node storage that every `ExprId` in `roots` indexes into.
    pub arena: ConstraintArena<F>,
    /// Top-level constraint expressions.
    pub roots: Vec<ExprId>,
    /// Optional label per root; `merge` pairs
    /// `labels[i]` with `roots[i]`.
    pub labels: Vec<Option<&'static str>>,
}
305
306impl<F: TowerField> Clone for ConstraintArena<F> {
307    fn clone(&self) -> Self {
308        Self {
309            nodes: self.nodes.clone(),
310            cell_cache: self.cell_cache.clone(),
311        }
312    }
313}
314
315impl<F: TowerField> Clone for ConstraintAst<F> {
316    fn clone(&self) -> Self {
317        Self {
318            arena: self.arena.clone(),
319            roots: self.roots.clone(),
320            labels: self.labels.clone(),
321        }
322    }
323}
324
325impl<F: TowerField> ConstraintAst<F> {
326    /// Maximum polynomial degree
327    /// across all constraint roots.
328    pub fn max_degree(&self) -> usize {
329        if self.arena.is_empty() {
330            return 0;
331        }
332
333        let n = self.arena.len();
334        let mut deg: Vec<usize> = Vec::with_capacity(n);
335
336        for i in 0..n {
337            let d = match self.arena.get(ExprId(i as u32)) {
338                ConstraintExpr::Cell(_) => 1,
339                ConstraintExpr::Const(_) => 0,
340                ConstraintExpr::Add(a, b) => deg[a.0 as usize].max(deg[b.0 as usize]),
341                ConstraintExpr::Mul(a, b) => deg[a.0 as usize] + deg[b.0 as usize],
342                ConstraintExpr::Scale(_, a) => deg[a.0 as usize],
343                ConstraintExpr::Sum(children) => children
344                    .iter()
345                    .map(|c| deg[c.0 as usize])
346                    .max()
347                    .unwrap_or(0),
348            };
349
350            deg.push(d);
351        }
352
353        self.roots
354            .iter()
355            .map(|r| deg[r.0 as usize])
356            .max()
357            .unwrap_or(0)
358    }
359
360    /// Evaluate each constraint
361    /// root at a single point.
362    pub fn evaluate(&self, current_row: &[Flat<F>], next_row: &[Flat<F>]) -> Vec<Flat<F>>
363    where
364        F: HardwareField,
365    {
366        let n = self.arena.len();
367        let mut val: Vec<Flat<F>> = Vec::with_capacity(n);
368
369        for i in 0..n {
370            let v = match self.arena.get(ExprId(i as u32)) {
371                ConstraintExpr::Cell(cell) => {
372                    if cell.next_row {
373                        next_row[cell.col_idx]
374                    } else {
375                        current_row[cell.col_idx]
376                    }
377                }
378                ConstraintExpr::Const(k) => k.to_hardware(),
379                ConstraintExpr::Add(a, b) => val[a.0 as usize] + val[b.0 as usize],
380                ConstraintExpr::Mul(a, b) => val[a.0 as usize] * val[b.0 as usize],
381                ConstraintExpr::Scale(k, a) => k.to_hardware() * val[a.0 as usize],
382                ConstraintExpr::Sum(children) => {
383                    let mut s = Flat::from_raw(F::ZERO);
384                    for c in children {
385                        s += val[c.0 as usize];
386                    }
387
388                    s
389                }
390            };
391
392            val.push(v);
393        }
394
395        self.roots.iter().map(|r| val[r.0 as usize]).collect()
396    }
397
398    /// Buffer-reuse variant of `evaluate()`.
399    /// Caller owns `buf`; reused across
400    /// rows to avoid per-row allocation.
401    pub fn evaluate_into(
402        &self,
403        current_row: &[Flat<F>],
404        next_row: &[Flat<F>],
405        buf: &mut Vec<Flat<F>>,
406    ) where
407        F: HardwareField,
408    {
409        buf.clear();
410
411        let n = self.arena.len();
412        for i in 0..n {
413            let v = match self.arena.get(ExprId(i as u32)) {
414                ConstraintExpr::Cell(cell) => {
415                    if cell.next_row {
416                        next_row[cell.col_idx]
417                    } else {
418                        current_row[cell.col_idx]
419                    }
420                }
421                ConstraintExpr::Const(k) => k.to_hardware(),
422                ConstraintExpr::Add(a, b) => buf[a.0 as usize] + buf[b.0 as usize],
423                ConstraintExpr::Mul(a, b) => buf[a.0 as usize] * buf[b.0 as usize],
424                ConstraintExpr::Scale(k, a) => k.to_hardware() * buf[a.0 as usize],
425                ConstraintExpr::Sum(children) => {
426                    let mut s = Flat::from_raw(F::ZERO);
427                    for c in children {
428                        s += buf[c.0 as usize];
429                    }
430
431                    s
432                }
433            };
434
435            buf.push(v);
436        }
437    }
438
439    /// Merge another constraint AST into this one.
440    pub fn merge(&mut self, other: ConstraintAst<F>) {
441        let mut id_map: Vec<ExprId> = Vec::with_capacity(other.arena.len());
442        for node in other.arena.nodes {
443            let new_id = match node {
444                ConstraintExpr::Cell(cell) => self.arena.cell(cell),
445                ConstraintExpr::Const(val) => self.arena.constant(val),
446                ConstraintExpr::Add(a, b) => {
447                    self.arena.add(id_map[a.0 as usize], id_map[b.0 as usize])
448                }
449                ConstraintExpr::Mul(a, b) => {
450                    self.arena.mul(id_map[a.0 as usize], id_map[b.0 as usize])
451                }
452                ConstraintExpr::Scale(coeff, inner) => {
453                    self.arena.scale(coeff, id_map[inner.0 as usize])
454                }
455                ConstraintExpr::Sum(children) => {
456                    let remapped: Vec<ExprId> =
457                        children.into_iter().map(|c| id_map[c.0 as usize]).collect();
458                    self.arena.sum(remapped)
459                }
460            };
461
462            id_map.push(new_id);
463        }
464
465        for (root, label) in other.roots.into_iter().zip(other.labels) {
466            self.roots.push(id_map[root.0 as usize]);
467            self.labels.push(label);
468        }
469    }
470
471    /// Convert AST to flat `Vec<Constraint<F>>`.
472    /// Expands the DAG into sum-of-products form.
473    pub fn to_constraints(&self) -> Vec<Constraint<F>> {
474        /// A flat term:
475        /// coefficient × product of cells.
476        type FlatTerm<F> = (F, Vec<ProgramCell>);
477
478        fn expand<F: TowerField>(
479            arena: &ConstraintArena<F>,
480            id: ExprId,
481            cache: &mut Vec<Option<Vec<FlatTerm<F>>>>,
482        ) -> Vec<FlatTerm<F>> {
483            if let Some(ref cached) = cache[id.0 as usize] {
484                return cached.clone();
485            }
486
487            let result = match arena.get(id) {
488                ConstraintExpr::Cell(cell) => {
489                    vec![(F::ONE, vec![*cell])]
490                }
491                ConstraintExpr::Const(k) => {
492                    vec![(*k, vec![])]
493                }
494                ConstraintExpr::Add(a, b) => {
495                    let mut terms = expand(arena, *a, cache);
496                    terms.extend(expand(arena, *b, cache));
497
498                    terms
499                }
500                ConstraintExpr::Mul(a, b) => {
501                    let left = expand(arena, *a, cache);
502                    let right = expand(arena, *b, cache);
503
504                    let mut terms = Vec::with_capacity(left.len() * right.len());
505                    for (lc, lp) in &left {
506                        for (rc, rp) in &right {
507                            let coeff = *lc * *rc;
508                            let mut cells = lp.clone();
509
510                            cells.extend_from_slice(rp);
511                            terms.push((coeff, cells));
512                        }
513                    }
514
515                    terms
516                }
517                ConstraintExpr::Scale(k, a) => {
518                    let inner = expand(arena, *a, cache);
519                    inner
520                        .into_iter()
521                        .map(|(c, cells)| (*k * c, cells))
522                        .collect()
523                }
524                ConstraintExpr::Sum(children) => {
525                    let mut terms = Vec::new();
526                    for child in children {
527                        terms.extend(expand(arena, *child, cache));
528                    }
529
530                    terms
531                }
532            };
533
534            cache[id.0 as usize] = Some(result.clone());
535
536            result
537        }
538
539        let n = self.arena.len();
540        let mut cache: Vec<Option<Vec<FlatTerm<F>>>> = vec![None; n];
541
542        self.roots
543            .iter()
544            .map(|root| {
545                let flat_terms = expand(&self.arena, *root, &mut cache);
546                Constraint::new(
547                    flat_terms
548                        .into_iter()
549                        .map(|(coeff, cells)| ConstraintTerm::new(coeff, cells))
550                        .collect(),
551                )
552            })
553            .collect()
554    }
555}
556
557#[cfg(test)]
558mod tests {
559    use super::*;
560    use crate::constraint::ConstraintExpr;
561    use crate::constraint::builder::ConstraintSystem;
562    use crate::{Air, Program};
563    use hekate_core::trace::ColumnType;
564    use hekate_math::{Block128, Flat};
565
566    type F = Block128;
567
    /// Three-column Fibonacci-style test AIR:
    /// columns 0 and 1 hold the sequence values,
    /// column 2 is a selector bit.
    #[derive(Clone)]
    struct TestFibProgram;

    impl Air<F> for TestFibProgram {
        fn num_columns(&self) -> usize {
            3
        }

        fn column_layout(&self) -> &[ColumnType] {
            &[ColumnType::B32, ColumnType::B32, ColumnType::Bit]
        }

        fn constraint_ast(&self) -> ConstraintAst<F> {
            let cs = ConstraintSystem::<F>::new();

            let [a, b, q] = [cs.col(0), cs.col(1), cs.col(2)];
            let [na, nb] = [cs.next(0), cs.next(1)];

            // Gated by selector q: next_a == b
            cs.constrain(q * (na + b));
            // Gated by selector q: next_b == a + b
            cs.constrain(q * (nb + a + b));

            cs.build()
        }
    }

    impl Program<F> for TestFibProgram {}
594
    /// The Fibonacci AIR declares two constraints, so its AST
    /// must expose two roots, each a `Sum` or `Mul` node.
    #[test]
    fn default_constraint_ast_produces_correct_roots() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        // 2 constraints -> 2 roots
        assert_eq!(ast.roots.len(), 2);

        // NOTE(review): exact node counts depend on how the builder
        // lowers expressions, which is not visible from this file,
        // so only non-emptiness and the root shapes are checked.
        assert!(!ast.arena.is_empty());

        // Both roots should be Sum or product nodes
        for &root in &ast.roots {
            let node = ast.arena.get(root);
            match node {
                ConstraintExpr::Sum(children) => {
                    assert!(!children.is_empty());
                }
                ConstraintExpr::Mul(_, _) => {
                    // single-term constraint would be a Mul
                }
                _ => panic!("Root should be Sum or Mul, got {:?}", node),
            }
        }
    }
623
624    #[test]
625    fn cell_dedup_works() {
626        let mut arena = ConstraintArena::<F>::new();
627
628        let cell_a = ProgramCell::current(0);
629        let cell_b = ProgramCell::current(0);
630        let cell_c = ProgramCell::next(0);
631
632        let id_a = arena.cell(cell_a);
633        let id_b = arena.cell(cell_b);
634        let id_c = arena.cell(cell_c);
635
636        // Same cell -> same ExprId
637        assert_eq!(id_a, id_b);
638        // Different cell -> different ExprId
639        assert_ne!(id_a, id_c);
640        // Only 2 nodes allocated, not 3
641        assert_eq!(arena.len(), 2);
642    }
643
644    #[test]
645    fn dag_sharing_reduces_node_count() {
646        let mut arena = ConstraintArena::<F>::new();
647
648        // Build a shared sub-expression: theta = a + b + c
649        let a = arena.cell(ProgramCell::current(0));
650        let b = arena.cell(ProgramCell::current(1));
651        let c = arena.cell(ProgramCell::current(2));
652        let theta = arena.sum(vec![a, b, c]);
653
654        // Use theta in two different constraints (DAG sharing)
655        let d = arena.cell(ProgramCell::current(3));
656        let expr1 = arena.mul(theta, d);
657        let expr2 = arena.mul(theta, a); // reuses both theta and a
658
659        let dag_node_count = arena.len();
660        // 4 cells + 1 sum + 2 muls = 7 nodes
661        assert_eq!(dag_node_count, 7);
662
663        // Without sharing (tree), theta would be duplicated:
664        // 4 cells + 2 sums + 2 muls = 8 nodes minimum
665        // (and the cells inside the second theta would also duplicate)
666        // Actually: 3 cells * 2 + 1 extra cell + 2 sums + 2 muls = 10 nodes
667        // DAG: 7 < 10 tree nodes
668        assert!(dag_node_count < 10);
669
670        // Verify both expressions reference the same theta
671        match arena.get(expr1) {
672            ConstraintExpr::Mul(lhs, _) => assert_eq!(*lhs, theta),
673            _ => panic!("Expected Mul"),
674        }
675
676        match arena.get(expr2) {
677            ConstraintExpr::Mul(lhs, rhs) => {
678                assert_eq!(*lhs, theta);
679                assert_eq!(*rhs, a);
680            }
681            _ => panic!("Expected Mul"),
682        }
683    }
684
685    #[test]
686    fn default_constraint_ast_node_count_matches_flat() {
687        let program = TestFibProgram;
688        let flat = program.constraints();
689        let ast = program.constraint_ast();
690
691        // Count total cells across flat constraints
692        let mut flat_cell_count = 0;
693        for c in &flat {
694            for t in &c.terms {
695                flat_cell_count += t.poly_ind.len();
696            }
697        }
698
699        // With dedup, AST cell nodes <= flat cell references
700        let ast_cell_count = ast
701            .arena
702            .nodes
703            .iter()
704            .filter(|n| matches!(n, ConstraintExpr::Cell(_)))
705            .count();
706
707        assert!(ast_cell_count <= flat_cell_count);
708
709        // Fib shares ProgramCell::current(2) across all terms.
710        // 5 total term-cells reference current(2), but dedup
711        // means only 1 Cell node for it.
712        // Unique cells: current(0), current(1), current(2), next(0), next(1) = 5
713        assert_eq!(ast_cell_count, 5);
714    }
715
    /// A constraint system with nothing constrained must
    /// build an AST with no roots and no nodes.
    #[test]
    fn empty_constraint_produces_empty_ast() {
        #[derive(Clone)]
        struct EmptyProgram;

        impl Air<F> for EmptyProgram {
            fn num_columns(&self) -> usize {
                0
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                ConstraintSystem::<F>::new().build()
            }
        }

        impl Program<F> for EmptyProgram {}

        let ast = EmptyProgram.constraint_ast();
        assert!(ast.roots.is_empty());
        assert!(ast.arena.is_empty());
    }
741
    /// A constraint made of a single product must have a
    /// `Mul` root rather than being wrapped in a `Sum`.
    #[test]
    fn single_term_constraint_no_sum_wrapper() {
        #[derive(Clone)]
        struct SingleTermProgram;

        impl Air<F> for SingleTermProgram {
            fn num_columns(&self) -> usize {
                2
            }

            fn column_layout(&self) -> &[ColumnType] {
                &[ColumnType::B128, ColumnType::B128]
            }

            fn constraint_ast(&self) -> ConstraintAst<F> {
                let cs = ConstraintSystem::<F>::new();
                cs.constrain(cs.col(0) * cs.col(1));

                cs.build()
            }
        }

        impl Program<F> for SingleTermProgram {}

        let ast = SingleTermProgram.constraint_ast();
        assert_eq!(ast.roots.len(), 1);

        // Single term: cell0 * cell1 -> Mul root, no Sum.
        // NOTE(review): whether the builder also emits a leading
        // Const factor is not visible here; only the root kind is checked.
        match ast.arena.get(ast.roots[0]) {
            ConstraintExpr::Mul(_, _) => {} // correct
            other => panic!("Expected Mul for single-term, got {:?}", other),
        }
    }
775
776    // =========================================================
777    // ConstraintAst method tests
778    // =========================================================
779
    /// `max_degree()` on the Fibonacci AST must equal the largest
    /// number of cells in any term of the flat constraints.
    #[test]
    fn max_degree_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();

        // Fib constraints expand to selector-gated terms such as
        // q * next_a and q * curr_b, each of degree 2.
        // Const nodes have degree 0, so they never raise the degree.
        // The AST max degree should match the flat form.
        let flat = program.constraints();
        let flat_max = flat
            .iter()
            .flat_map(|c| c.terms.iter())
            .map(|t| t.poly_ind.len())
            .max()
            .unwrap_or(0);

        assert_eq!(ast.max_degree(), flat_max);
    }
799
800    #[test]
801    fn max_degree_empty() {
802        let ast = ConstraintAst::<F> {
803            arena: ConstraintArena::new(),
804            roots: Vec::new(),
805            labels: Vec::new(),
806        };
807        assert_eq!(ast.max_degree(), 0);
808    }
809
    /// The reported degree is the maximum over all roots:
    /// a cubic product next to a linear sum yields 3.
    #[test]
    fn max_degree_builder_mul_chain() {
        use crate::constraint::builder::ConstraintSystem;

        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);
        let c = cs.col(2);

        // a * b * c = degree 3
        cs.constrain(a * b * c);
        // a + b = degree 1
        cs.constrain(a + b);

        let ast = cs.build();
        assert_eq!(ast.max_degree(), 3);
    }
827
    /// Flattening the Fibonacci AST must give the same shape as
    /// `program.constraints()`: equal constraint count and equal
    /// term count per constraint. Term contents are not compared.
    #[test]
    fn to_constraints_roundtrip_fibonacci() {
        let program = TestFibProgram;
        let ast = program.constraint_ast();
        let flat_from_ast = ast.to_constraints();
        let flat_direct = program.constraints();

        // Same number of constraints
        assert_eq!(flat_from_ast.len(), flat_direct.len());

        // Same number of terms per constraint
        for (a, d) in flat_from_ast.iter().zip(flat_direct.iter()) {
            assert_eq!(a.terms.len(), d.terms.len());
        }
    }
843
    /// A plain sum of two cells flattens into two terms,
    /// each with coefficient ONE and exactly one cell.
    #[test]
    fn to_constraints_from_builder() {
        use crate::constraint::builder::ConstraintSystem;

        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        // a + b = 0  →  should produce 2 flat terms: (ONE, [a]) + (ONE, [b])
        cs.constrain(a + b);

        let ast = cs.build();
        let flat = ast.to_constraints();

        assert_eq!(flat.len(), 1);
        assert_eq!(flat[0].terms.len(), 2);

        // Both terms should have exactly 1 cell each
        for term in &flat[0].terms {
            assert_eq!(term.coeff, F::ONE);
            assert_eq!(term.poly_ind.len(), 1);
        }
    }
867
    /// Evaluating `a + b` gives zero when both cells hold the
    /// same value and a non-zero result when they differ.
    #[test]
    fn evaluate_simple_constraint() {
        use crate::constraint::builder::ConstraintSystem;
        use hekate_math::Flat;

        let cs = ConstraintSystem::<F>::new();
        let a = cs.col(0);
        let b = cs.col(1);

        // a + b = 0
        cs.constrain(a + b);

        let ast = cs.build();

        // Evaluate at a=3, b=3 -> 3 + 3 = 0 in GF(2^k) (XOR)
        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(3u128)),
        ];
        // The constraint reads no next-row cell; zeros are placeholders.
        let next = vec![Flat::from_raw(F::ZERO); 2];

        let evals = ast.evaluate(&current, &next);
        assert_eq!(evals.len(), 1);
        assert_eq!(evals[0], Flat::from_raw(F::ZERO)); // 3 XOR 3 = 0

        // Evaluate at a=3, b=5 -> 3 XOR 5 = 6 ≠ 0
        let current2 = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let evals2 = ast.evaluate(&current2, &next);
        assert_ne!(evals2[0], Flat::from_raw(F::ZERO));
    }
901
    /// `evaluate_into()` must agree with `evaluate()`: the buffer
    /// entry at each root's id equals the matching returned value.
    #[test]
    fn evaluate_into_matches_evaluate() {
        let cs = ConstraintSystem::<F>::new();

        let a = cs.col(0);
        let b = cs.col(1);
        let na = cs.next(0);

        // Cover an addition, a product and a next-row reference.
        cs.constrain(a + b);
        cs.constrain(a * b);
        cs.constrain(na + a);

        let ast = cs.build();
        let zero = Flat::from_raw(F::ZERO);

        let current = vec![
            Flat::from_raw(F::from(3u128)),
            Flat::from_raw(F::from(5u128)),
        ];
        let next = vec![Flat::from_raw(F::from(7u128)), zero];

        let expected = ast.evaluate(&current, &next);

        let mut buf = Vec::new();
        ast.evaluate_into(&current, &next, &mut buf);

        // `buf` is indexed by node id, `expected` by root position.
        for (i, root) in ast.roots.iter().enumerate() {
            assert_eq!(buf[root.0 as usize], expected[i]);
        }
    }
932}