Skip to main content

kizzasi_logic/
compiler.rs

1//! Constraint Compilation to Optimized Bytecode IR
2//!
3//! Compiles high-level constraint expressions into an optimized stack-based
4//! bytecode IR for fast evaluation. The compiler supports constant folding,
5//! dead code elimination, and batch constraint programs.
6//!
7//! # Example
8//!
9//! ```
10//! use kizzasi_logic::compiler::{ConstraintExpr, ConstraintProgram};
11//! use scirs2_core::ndarray::Array1;
12//!
13//! let expr = ConstraintExpr::between(0, -1.0, 1.0);
14//! let compiled = expr.compile("bound", 1);
15//! let x = Array1::from_vec(vec![0.5_f32]);
16//! assert!(compiled.evaluate(&x).unwrap());
17//! ```
18
19use crate::error::{LogicError, LogicResult};
20use scirs2_core::ndarray::Array1;
21use std::collections::HashMap;
22
23// ============================================================================
24// Opcode — stack-based bytecode instruction set
25// ============================================================================
26
/// Instruction set of the constraint stack machine.
///
/// Programs operate on a stack of `f32` values. Most instructions remove their
/// operands from the top of the stack and push a single result. The two load
/// instructions consume nothing, `Dup` copies the top value without removing
/// it, and `Pop` only removes. Boolean results are encoded as `1.0` (true) and
/// `0.0` (false); boolean inputs treat any non-zero value as true.
#[derive(Debug, Clone, PartialEq)]
pub enum Opcode {
    /// Push the input component `x[dim]`.
    LoadDim(usize),
    /// Push a literal value.
    LoadConst(f32),
    /// Pop `b`, then `a`; push `a + b`.
    Add,
    /// Pop `b`, then `a`; push `a - b`.
    Sub,
    /// Pop `b`, then `a`; push `a * b`.
    Mul,
    /// Pop `b`, then `a`; push `a / b`. Evaluation fails when `b` is zero.
    Div,
    /// Pop `a`; push `-a`.
    Neg,
    /// Pop `a`; push `|a|`.
    Abs,
    /// Pop `a`; push `sqrt(a)`. Evaluation fails when `a` is negative.
    Sqrt,
    /// Pop `b`, then `a`; push `a.min(b)`.
    Min,
    /// Pop `b`, then `a`; push `a.max(b)`.
    Max,
    /// Pop `b`, then `a`; push `1.0` when `a <= b`, otherwise `0.0`.
    CmpLe,
    /// Pop `b`, then `a`; push `1.0` when `a >= b`, otherwise `0.0`.
    CmpGe,
    /// Pop `b`, then `a`; push `1.0` when both are non-zero, otherwise `0.0`.
    And,
    /// Pop `b`, then `a`; push `1.0` when at least one is non-zero, otherwise `0.0`.
    Or,
    /// Pop `a`; push `1.0` when `a == 0.0`, otherwise `0.0`.
    Not,
    /// Push a copy of the current top value (the stack must be non-empty).
    Dup,
    /// Remove and discard the top value.
    Pop,
}
70
71// ============================================================================
72// CompiledConstraint — executable bytecode program
73// ============================================================================
74
75/// A compiled constraint: a linear sequence of `Opcode`s evaluated on a
76/// stack machine. The result of evaluation is the top of the stack after
77/// all instructions have been executed.
78#[derive(Debug, Clone)]
79pub struct CompiledConstraint {
80    /// The bytecode instruction sequence
81    pub ops: Vec<Opcode>,
82    /// Human-readable name for this constraint
83    pub name: String,
84    /// Expected dimensionality of the input vector
85    pub num_dims: usize,
86}
87
88impl CompiledConstraint {
89    /// Execute the bytecode on `x` and return feasibility.
90    ///
91    /// Feasible iff the top of the stack after execution is non-zero.
92    pub fn evaluate(&self, x: &Array1<f32>) -> LogicResult<bool> {
93        let raw = self.evaluate_raw(x)?;
94        Ok(raw != 0.0)
95    }
96
97    /// Execute the bytecode on `x` and return the raw top-of-stack value.
98    pub fn evaluate_raw(&self, x: &Array1<f32>) -> LogicResult<f32> {
99        if x.len() < self.num_dims {
100            return Err(LogicError::DimensionMismatch {
101                expected: self.num_dims,
102                got: x.len(),
103            });
104        }
105
106        let mut stack: Vec<f32> = Vec::with_capacity(self.ops.len());
107
108        for op in &self.ops {
109            match op {
110                Opcode::LoadDim(dim) => {
111                    let val = x.get(*dim).copied().ok_or_else(|| {
112                        LogicError::InvalidInput(format!(
113                            "LoadDim: dimension {} out of bounds (len={})",
114                            dim,
115                            x.len()
116                        ))
117                    })?;
118                    stack.push(val);
119                }
120                Opcode::LoadConst(v) => {
121                    stack.push(*v);
122                }
123                Opcode::Add => {
124                    let b = stack_pop(&mut stack, "Add")?;
125                    let a = stack_pop(&mut stack, "Add")?;
126                    stack.push(a + b);
127                }
128                Opcode::Sub => {
129                    let b = stack_pop(&mut stack, "Sub")?;
130                    let a = stack_pop(&mut stack, "Sub")?;
131                    stack.push(a - b);
132                }
133                Opcode::Mul => {
134                    let b = stack_pop(&mut stack, "Mul")?;
135                    let a = stack_pop(&mut stack, "Mul")?;
136                    stack.push(a * b);
137                }
138                Opcode::Div => {
139                    let b = stack_pop(&mut stack, "Div")?;
140                    let a = stack_pop(&mut stack, "Div")?;
141                    if b == 0.0 {
142                        return Err(LogicError::InvalidInput(
143                            "Div: division by zero".to_string(),
144                        ));
145                    }
146                    stack.push(a / b);
147                }
148                Opcode::Neg => {
149                    let a = stack_pop(&mut stack, "Neg")?;
150                    stack.push(-a);
151                }
152                Opcode::Abs => {
153                    let a = stack_pop(&mut stack, "Abs")?;
154                    stack.push(a.abs());
155                }
156                Opcode::Sqrt => {
157                    let a = stack_pop(&mut stack, "Sqrt")?;
158                    if a < 0.0 {
159                        return Err(LogicError::InvalidInput(format!(
160                            "Sqrt: negative argument {a}"
161                        )));
162                    }
163                    stack.push(a.sqrt());
164                }
165                Opcode::Min => {
166                    let b = stack_pop(&mut stack, "Min")?;
167                    let a = stack_pop(&mut stack, "Min")?;
168                    stack.push(a.min(b));
169                }
170                Opcode::Max => {
171                    let b = stack_pop(&mut stack, "Max")?;
172                    let a = stack_pop(&mut stack, "Max")?;
173                    stack.push(a.max(b));
174                }
175                Opcode::CmpLe => {
176                    let b = stack_pop(&mut stack, "CmpLe")?;
177                    let a = stack_pop(&mut stack, "CmpLe")?;
178                    stack.push(if a <= b { 1.0 } else { 0.0 });
179                }
180                Opcode::CmpGe => {
181                    let b = stack_pop(&mut stack, "CmpGe")?;
182                    let a = stack_pop(&mut stack, "CmpGe")?;
183                    stack.push(if a >= b { 1.0 } else { 0.0 });
184                }
185                Opcode::And => {
186                    let b = stack_pop(&mut stack, "And")?;
187                    let a = stack_pop(&mut stack, "And")?;
188                    stack.push(if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 });
189                }
190                Opcode::Or => {
191                    let b = stack_pop(&mut stack, "Or")?;
192                    let a = stack_pop(&mut stack, "Or")?;
193                    stack.push(if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 });
194                }
195                Opcode::Not => {
196                    let a = stack_pop(&mut stack, "Not")?;
197                    stack.push(if a == 0.0 { 1.0 } else { 0.0 });
198                }
199                Opcode::Dup => {
200                    let a = stack.last().copied().ok_or_else(|| {
201                        LogicError::InvalidInput("Dup: stack underflow".to_string())
202                    })?;
203                    stack.push(a);
204                }
205                Opcode::Pop => {
206                    stack_pop(&mut stack, "Pop")?;
207                }
208            }
209        }
210
211        stack.last().copied().ok_or_else(|| {
212            LogicError::InvalidInput("evaluate_raw: stack is empty after execution".to_string())
213        })
214    }
215
216    /// Optimize the bytecode via constant folding and dead code elimination.
217    ///
218    /// Constant folding: sequences of two `LoadConst` instructions followed by
219    /// a binary arithmetic/comparison opcode are collapsed into a single
220    /// `LoadConst` with the pre-computed result.
221    pub fn optimize(&self) -> Self {
222        let folded = constant_fold(&self.ops);
223        let dce = dead_code_eliminate(&folded);
224        Self {
225            ops: dce,
226            name: self.name.clone(),
227            num_dims: self.num_dims,
228        }
229    }
230
231    /// Return the number of opcodes (before optimization).
232    pub fn complexity(&self) -> usize {
233        self.ops.len()
234    }
235}
236
237// ============================================================================
238// Internal stack helpers
239// ============================================================================
240
241#[inline]
242fn stack_pop(stack: &mut Vec<f32>, op: &str) -> LogicResult<f32> {
243    stack
244        .pop()
245        .ok_or_else(|| LogicError::InvalidInput(format!("{op}: stack underflow")))
246}
247
248// ============================================================================
249// Optimizer passes
250// ============================================================================
251
252/// Constant folding pass: collapse pairs of LoadConst + binary/unary op.
253fn constant_fold(ops: &[Opcode]) -> Vec<Opcode> {
254    let mut out: Vec<Opcode> = Vec::with_capacity(ops.len());
255
256    let mut i = 0;
257    while i < ops.len() {
258        // Attempt binary constant fold: LoadConst(a), LoadConst(b), BinaryOp → LoadConst(result)
259        if i + 2 < ops.len() {
260            if let (Opcode::LoadConst(a), Opcode::LoadConst(b)) = (&ops[i], &ops[i + 1]) {
261                let a = *a;
262                let b = *b;
263                let folded = match &ops[i + 2] {
264                    Opcode::Add => Some(a + b),
265                    Opcode::Sub => Some(a - b),
266                    Opcode::Mul => Some(a * b),
267                    Opcode::Div => {
268                        if b != 0.0 {
269                            Some(a / b)
270                        } else {
271                            None
272                        }
273                    }
274                    Opcode::Min => Some(a.min(b)),
275                    Opcode::Max => Some(a.max(b)),
276                    Opcode::CmpLe => Some(if a <= b { 1.0 } else { 0.0 }),
277                    Opcode::CmpGe => Some(if a >= b { 1.0 } else { 0.0 }),
278                    Opcode::And => Some(if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 }),
279                    Opcode::Or => Some(if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 }),
280                    _ => None,
281                };
282                if let Some(result) = folded {
283                    out.push(Opcode::LoadConst(result));
284                    i += 3;
285                    continue;
286                }
287            }
288        }
289
290        // Attempt unary constant fold: LoadConst(a), UnaryOp → LoadConst(result)
291        if i + 1 < ops.len() {
292            if let Opcode::LoadConst(a) = &ops[i] {
293                let a = *a;
294                let folded = match &ops[i + 1] {
295                    Opcode::Neg => Some(-a),
296                    Opcode::Abs => Some(a.abs()),
297                    Opcode::Sqrt => {
298                        if a >= 0.0 {
299                            Some(a.sqrt())
300                        } else {
301                            None
302                        }
303                    }
304                    Opcode::Not => Some(if a == 0.0 { 1.0 } else { 0.0 }),
305                    _ => None,
306                };
307                if let Some(result) = folded {
308                    out.push(Opcode::LoadConst(result));
309                    i += 2;
310                    continue;
311                }
312            }
313        }
314
315        out.push(ops[i].clone());
316        i += 1;
317    }
318
319    // Run again if any folding happened (handles nested folds)
320    if out.len() < ops.len() {
321        constant_fold(&out)
322    } else {
323        out
324    }
325}
326
327/// Dead code elimination: remove `Pop` immediately after a `LoadConst`
328/// (the value is never used).
329fn dead_code_eliminate(ops: &[Opcode]) -> Vec<Opcode> {
330    let mut out: Vec<Opcode> = Vec::with_capacity(ops.len());
331    let mut i = 0;
332    while i < ops.len() {
333        if i + 1 < ops.len() {
334            if let Opcode::LoadConst(_) = &ops[i] {
335                if let Opcode::Pop = &ops[i + 1] {
336                    // LoadConst followed by Pop — skip both
337                    i += 2;
338                    continue;
339                }
340            }
341        }
342        out.push(ops[i].clone());
343        i += 1;
344    }
345    out
346}
347
348// ============================================================================
349// ConstraintExpr — high-level AST
350// ============================================================================
351
/// High-level constraint expression tree.
///
/// Build expressions in one of two ways, then call [`ConstraintExpr::compile`]
/// to lower the tree to a [`CompiledConstraint`]:
///
/// * with the helper constructors (`dim`, `constant`, `between`,
///   `l2_norm_le`, `affine_le`);
/// * by nesting variants directly.
///
/// Comparison and logical nodes evaluate to `1.0` for true and `0.0` for
/// false. Logical nodes treat any non-zero operand as true. Trees can be
/// compared structurally with `==`.
#[derive(Debug, Clone, PartialEq)]
pub enum ConstraintExpr {
    /// The input component `x[i]`.
    Dim(usize),
    /// A literal scalar.
    Const(f32),
    /// `a + b`
    Add(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `a - b`
    Sub(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `a * b`
    Mul(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `a / b` (evaluation fails when `b` is zero)
    Div(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `-a`
    Neg(Box<ConstraintExpr>),
    /// `|a|`
    Abs(Box<ConstraintExpr>),
    /// `sqrt(a)` (evaluation fails when `a` is negative)
    Sqrt(Box<ConstraintExpr>),
    /// `a <= b` (evaluates to 1.0 or 0.0)
    Le(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `a >= b` (evaluates to 1.0 or 0.0)
    Ge(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `a && b`
    And(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `a || b`
    Or(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// `!a`
    Not(Box<ConstraintExpr>),
}
388
389impl ConstraintExpr {
390    // ------------------------------------------------------------------
391    // Compilation
392    // ------------------------------------------------------------------
393
394    /// Compile this AST into a [`CompiledConstraint`].
395    pub fn compile(&self, name: &str, num_dims: usize) -> CompiledConstraint {
396        let mut ops = Vec::new();
397        emit(self, &mut ops);
398        CompiledConstraint {
399            ops,
400            name: name.to_string(),
401            num_dims,
402        }
403    }
404
405    // ------------------------------------------------------------------
406    // Builder helpers
407    // ------------------------------------------------------------------
408
409    /// Reference `x[i]`
410    pub fn dim(i: usize) -> Self {
411        ConstraintExpr::Dim(i)
412    }
413
414    /// A scalar constant
415    pub fn constant(v: f32) -> Self {
416        ConstraintExpr::Const(v)
417    }
418
419    /// `lo <= x[dim] <= hi`
420    pub fn between(dim: usize, lo: f32, hi: f32) -> Self {
421        let x = ConstraintExpr::Dim(dim);
422        let lo_le = ConstraintExpr::Le(Box::new(ConstraintExpr::Const(lo)), Box::new(x.clone()));
423        let hi_le = ConstraintExpr::Le(Box::new(x), Box::new(ConstraintExpr::Const(hi)));
424        ConstraintExpr::And(Box::new(lo_le), Box::new(hi_le))
425    }
426
427    /// `||x[dims]||_2 <= radius`
428    ///
429    /// Compiles to: `sqrt(sum_i(x[dims[i]]^2)) <= radius`
430    pub fn l2_norm_le(dims: &[usize], radius: f32) -> Self {
431        assert!(!dims.is_empty(), "l2_norm_le: dims must not be empty");
432
433        // Build sum of squares
434        let mut sum_sq: ConstraintExpr = ConstraintExpr::Mul(
435            Box::new(ConstraintExpr::Dim(dims[0])),
436            Box::new(ConstraintExpr::Dim(dims[0])),
437        );
438        for &d in &dims[1..] {
439            let sq = ConstraintExpr::Mul(
440                Box::new(ConstraintExpr::Dim(d)),
441                Box::new(ConstraintExpr::Dim(d)),
442            );
443            sum_sq = ConstraintExpr::Add(Box::new(sum_sq), Box::new(sq));
444        }
445
446        let norm = ConstraintExpr::Sqrt(Box::new(sum_sq));
447        ConstraintExpr::Le(Box::new(norm), Box::new(ConstraintExpr::Const(radius)))
448    }
449
450    /// `sum_i(coeffs[i].1 * x[coeffs[i].0]) <= rhs`
451    pub fn affine_le(coeffs: &[(usize, f32)], rhs: f32) -> Self {
452        assert!(!coeffs.is_empty(), "affine_le: coeffs must not be empty");
453
454        let term = |(dim, c): &(usize, f32)| -> ConstraintExpr {
455            ConstraintExpr::Mul(
456                Box::new(ConstraintExpr::Const(*c)),
457                Box::new(ConstraintExpr::Dim(*dim)),
458            )
459        };
460
461        let mut sum = term(&coeffs[0]);
462        for coeff in &coeffs[1..] {
463            sum = ConstraintExpr::Add(Box::new(sum), Box::new(term(coeff)));
464        }
465
466        ConstraintExpr::Le(Box::new(sum), Box::new(ConstraintExpr::Const(rhs)))
467    }
468}
469
470// ============================================================================
471// Code emission (AST → opcodes)
472// ============================================================================
473
474fn emit(expr: &ConstraintExpr, ops: &mut Vec<Opcode>) {
475    match expr {
476        ConstraintExpr::Dim(i) => {
477            ops.push(Opcode::LoadDim(*i));
478        }
479        ConstraintExpr::Const(v) => {
480            ops.push(Opcode::LoadConst(*v));
481        }
482        ConstraintExpr::Add(a, b) => {
483            emit(a, ops);
484            emit(b, ops);
485            ops.push(Opcode::Add);
486        }
487        ConstraintExpr::Sub(a, b) => {
488            emit(a, ops);
489            emit(b, ops);
490            ops.push(Opcode::Sub);
491        }
492        ConstraintExpr::Mul(a, b) => {
493            emit(a, ops);
494            emit(b, ops);
495            ops.push(Opcode::Mul);
496        }
497        ConstraintExpr::Div(a, b) => {
498            emit(a, ops);
499            emit(b, ops);
500            ops.push(Opcode::Div);
501        }
502        ConstraintExpr::Neg(a) => {
503            emit(a, ops);
504            ops.push(Opcode::Neg);
505        }
506        ConstraintExpr::Abs(a) => {
507            emit(a, ops);
508            ops.push(Opcode::Abs);
509        }
510        ConstraintExpr::Sqrt(a) => {
511            emit(a, ops);
512            ops.push(Opcode::Sqrt);
513        }
514        ConstraintExpr::Le(a, b) => {
515            emit(a, ops);
516            emit(b, ops);
517            ops.push(Opcode::CmpLe);
518        }
519        ConstraintExpr::Ge(a, b) => {
520            emit(a, ops);
521            emit(b, ops);
522            ops.push(Opcode::CmpGe);
523        }
524        ConstraintExpr::And(a, b) => {
525            emit(a, ops);
526            emit(b, ops);
527            ops.push(Opcode::And);
528        }
529        ConstraintExpr::Or(a, b) => {
530            emit(a, ops);
531            emit(b, ops);
532            ops.push(Opcode::Or);
533        }
534        ConstraintExpr::Not(a) => {
535            emit(a, ops);
536            ops.push(Opcode::Not);
537        }
538    }
539}
540
541// ============================================================================
542// ConstraintProgram — named collection of compiled constraints
543// ============================================================================
544
545/// A named collection of compiled constraints that can be evaluated together.
546///
547/// All constraints share the same input vector but may have different
548/// dimensionality requirements.
549pub struct ConstraintProgram {
550    constraints: HashMap<String, CompiledConstraint>,
551}
552
553impl Default for ConstraintProgram {
554    fn default() -> Self {
555        Self::new()
556    }
557}
558
559impl ConstraintProgram {
560    /// Create an empty program.
561    pub fn new() -> Self {
562        Self {
563            constraints: HashMap::new(),
564        }
565    }
566
567    /// Compile and add a constraint expression to the program.
568    pub fn add(&mut self, expr: ConstraintExpr, name: &str, num_dims: usize) {
569        let compiled = expr.compile(name, num_dims);
570        self.constraints.insert(name.to_string(), compiled);
571    }
572
573    /// Evaluate all constraints and return a map from name → feasibility.
574    pub fn evaluate_all(&self, x: &Array1<f32>) -> LogicResult<HashMap<String, bool>> {
575        let mut results = HashMap::with_capacity(self.constraints.len());
576        for (name, constraint) in &self.constraints {
577            let feasible = constraint.evaluate(x)?;
578            results.insert(name.clone(), feasible);
579        }
580        Ok(results)
581    }
582
583    /// Return the names of all violated (infeasible) constraints.
584    pub fn violated(&self, x: &Array1<f32>) -> LogicResult<Vec<String>> {
585        let all = self.evaluate_all(x)?;
586        let mut names: Vec<String> = all
587            .into_iter()
588            .filter_map(|(name, feasible)| if feasible { None } else { Some(name) })
589            .collect();
590        names.sort(); // deterministic order
591        Ok(names)
592    }
593
594    /// Return `true` iff all constraints are satisfied.
595    pub fn is_feasible(&self, x: &Array1<f32>) -> LogicResult<bool> {
596        for constraint in self.constraints.values() {
597            if !constraint.evaluate(x)? {
598                return Ok(false);
599            }
600        }
601        Ok(true)
602    }
603
604    /// Return the number of constraints in this program.
605    pub fn num_constraints(&self) -> usize {
606        self.constraints.len()
607    }
608}
609
610// ============================================================================
611// Tests
612// ============================================================================
613
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LogicError;
    use scirs2_core::ndarray::Array1;

    /// Build an input vector from `values`.
    fn arr(values: Vec<f32>) -> Array1<f32> {
        Array1::from_vec(values)
    }

    /// An empty input vector, for programs that read no dimensions.
    fn empty() -> Array1<f32> {
        Array1::from_vec(vec![])
    }

    #[test]
    fn test_compile_constant() {
        let expr = ConstraintExpr::constant(3.0);
        let compiled = expr.compile("c", 0);
        let raw = compiled.evaluate_raw(&empty()).expect("evaluate_raw failed");
        assert!((raw - 3.0).abs() < 1e-6, "expected 3.0, got {raw}");
    }

    #[test]
    fn test_compile_load_dim() {
        let expr = ConstraintExpr::dim(1);
        let compiled = expr.compile("c", 2);
        let x = arr(vec![0.0, 5.0]);
        let raw = compiled.evaluate_raw(&x).expect("evaluate_raw failed");
        assert!((raw - 5.0).abs() < 1e-6, "expected 5.0, got {raw}");
    }

    #[test]
    fn test_compile_between() {
        let expr = ConstraintExpr::between(0, -1.0, 1.0);
        let compiled = expr.compile("bound", 1);

        // Feasible: x[0] = 0.5
        let x_ok = arr(vec![0.5]);
        assert!(
            compiled.evaluate(&x_ok).expect("evaluate failed"),
            "0.5 should be in [-1, 1]"
        );

        // Infeasible: x[0] = 2.0
        let x_bad = arr(vec![2.0]);
        assert!(
            !compiled.evaluate(&x_bad).expect("evaluate failed"),
            "2.0 should not be in [-1, 1]"
        );
    }

    #[test]
    fn test_compile_affine_le() {
        // 2*x[0] + 3*x[1] <= 10
        let expr = ConstraintExpr::affine_le(&[(0, 2.0), (1, 3.0)], 10.0);
        let compiled = expr.compile("affine", 2);

        // 2*1 + 3*1 = 5 <= 10 → feasible
        let x_ok = arr(vec![1.0, 1.0]);
        assert!(
            compiled.evaluate(&x_ok).expect("evaluate failed"),
            "2+3=5 should be <= 10"
        );

        // 2*3 + 3*3 = 15 > 10 → infeasible
        let x_bad = arr(vec![3.0, 3.0]);
        assert!(
            !compiled.evaluate(&x_bad).expect("evaluate failed"),
            "6+9=15 should not be <= 10"
        );
    }

    #[test]
    fn test_compile_l2_norm_le() {
        // ||(x[0], x[1])||_2 <= 1.0
        let expr = ConstraintExpr::l2_norm_le(&[0, 1], 1.0);
        let compiled = expr.compile("l2ball", 2);

        // (0.3, 0.4): norm = 0.5 <= 1.0
        let x_ok = arr(vec![0.3, 0.4]);
        assert!(
            compiled.evaluate(&x_ok).expect("evaluate failed"),
            "norm(0.3, 0.4)=0.5 should be <= 1.0"
        );

        // (1.0, 1.0): norm = sqrt(2) > 1.0
        let x_bad = arr(vec![1.0, 1.0]);
        assert!(
            !compiled.evaluate(&x_bad).expect("evaluate failed"),
            "norm(1, 1)=sqrt(2) should not be <= 1.0"
        );
    }

    #[test]
    fn test_optimize_constant_folding() {
        // Add(Const(2), Const(3)) folds to a single LoadConst(5).
        let expr = ConstraintExpr::Add(
            Box::new(ConstraintExpr::Const(2.0)),
            Box::new(ConstraintExpr::Const(3.0)),
        );
        let compiled = expr.compile("fold", 0);
        let optimized = compiled.optimize();

        assert_eq!(compiled.complexity(), 3, "LoadConst, LoadConst, Add");
        assert_eq!(optimized.ops, vec![Opcode::LoadConst(5.0)]);
        assert_eq!(optimized.name, "fold");
        assert_eq!(optimized.num_dims, 0);

        let raw = optimized.evaluate_raw(&empty()).expect("evaluate_raw failed");
        assert!(
            (raw - 5.0).abs() < 1e-6,
            "folded result should be 5.0, got {raw}"
        );
    }

    #[test]
    fn test_program_evaluate_all() {
        let mut prog = ConstraintProgram::new();
        prog.add(ConstraintExpr::between(0, 0.0, 1.0), "x_bound", 1);
        prog.add(ConstraintExpr::between(1, 0.0, 1.0), "y_bound", 2);

        let x = arr(vec![0.5, 0.5]);
        let results = prog.evaluate_all(&x).expect("evaluate_all failed");

        assert_eq!(results.len(), 2, "should have 2 entries");
        assert!(results["x_bound"], "x_bound should be feasible");
        assert!(results["y_bound"], "y_bound should be feasible");
    }

    #[test]
    fn test_program_violated_returns_names() {
        let mut prog = ConstraintProgram::new();
        prog.add(ConstraintExpr::between(0, 0.0, 1.0), "x_bound", 1);
        prog.add(ConstraintExpr::between(1, 0.0, 1.0), "y_bound", 2);

        // x[0] = 2.0 violates x_bound; x[1] = 0.5 satisfies y_bound
        let x = arr(vec![2.0, 0.5]);
        let violated = prog.violated(&x).expect("violated failed");

        assert_eq!(violated, vec!["x_bound".to_string()]);
    }

    #[test]
    fn test_complexity_before_after_optimize() {
        // Nested constants: Add(Add(Const(1), Const(2)), Const(3)) → LoadConst(6)
        let expr = ConstraintExpr::Add(
            Box::new(ConstraintExpr::Add(
                Box::new(ConstraintExpr::Const(1.0)),
                Box::new(ConstraintExpr::Const(2.0)),
            )),
            Box::new(ConstraintExpr::Const(3.0)),
        );
        let compiled = expr.compile("nested", 0);
        let optimized = compiled.optimize();

        assert_eq!(compiled.complexity(), 5);
        assert_eq!(optimized.complexity(), 1);
        assert_eq!(optimized.ops, vec![Opcode::LoadConst(6.0)]);

        let raw = optimized.evaluate_raw(&empty()).expect("evaluate_raw failed");
        assert!((raw - 6.0).abs() < 1e-6, "result should be 6.0, got {raw}");
    }

    #[test]
    fn test_dimension_mismatch_is_reported() {
        let compiled = ConstraintExpr::dim(1).compile("c", 2);
        let err = compiled
            .evaluate_raw(&arr(vec![1.0]))
            .expect_err("a 1-element input must be rejected when num_dims = 2");
        assert!(
            matches!(
                &err,
                LogicError::DimensionMismatch {
                    expected: 2,
                    got: 1,
                    ..
                }
            ),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_division_by_zero_survives_optimization() {
        let expr = ConstraintExpr::Div(
            Box::new(ConstraintExpr::Const(1.0)),
            Box::new(ConstraintExpr::Const(0.0)),
        );
        let optimized = expr.compile("div0", 0).optimize();
        let err = optimized
            .evaluate_raw(&empty())
            .expect_err("1 / 0 must still fail after optimization");
        assert!(
            matches!(&err, LogicError::InvalidInput(msg) if msg.contains("division by zero")),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_negative_sqrt_survives_optimization() {
        let expr = ConstraintExpr::Sqrt(Box::new(ConstraintExpr::Const(-4.0)));
        let optimized = expr.compile("sqrt", 0).optimize();
        let err = optimized
            .evaluate_raw(&empty())
            .expect_err("sqrt(-4) must still fail after optimization");
        assert!(
            matches!(&err, LogicError::InvalidInput(msg) if msg.contains("negative argument")),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_malformed_programs_are_rejected() {
        let underflow = CompiledConstraint {
            ops: vec![Opcode::LoadConst(1.0), Opcode::Add],
            name: "underflow".to_string(),
            num_dims: 0,
        };
        let err = underflow
            .evaluate_raw(&empty())
            .expect_err("Add with one operand must fail");
        assert!(
            matches!(&err, LogicError::InvalidInput(msg) if msg.contains("stack underflow")),
            "unexpected error: {err:?}"
        );

        let blank = CompiledConstraint {
            ops: Vec::new(),
            name: "blank".to_string(),
            num_dims: 0,
        };
        let err = blank
            .evaluate_raw(&empty())
            .expect_err("an empty program has no result");
        assert!(
            matches!(&err, LogicError::InvalidInput(msg) if msg.contains("stack is empty")),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_optimize_preserves_results() {
        // (2 * 3 - x[0]) >= |-1|  &&  x[1] <= 4
        let expr = ConstraintExpr::And(
            Box::new(ConstraintExpr::Ge(
                Box::new(ConstraintExpr::Sub(
                    Box::new(ConstraintExpr::Mul(
                        Box::new(ConstraintExpr::Const(2.0)),
                        Box::new(ConstraintExpr::Const(3.0)),
                    )),
                    Box::new(ConstraintExpr::dim(0)),
                )),
                Box::new(ConstraintExpr::Abs(Box::new(ConstraintExpr::Const(-1.0)))),
            )),
            Box::new(ConstraintExpr::Le(
                Box::new(ConstraintExpr::dim(1)),
                Box::new(ConstraintExpr::Const(4.0)),
            )),
        );
        let compiled = expr.compile("mixed", 2);
        let optimized = compiled.optimize();
        assert!(
            optimized.complexity() < compiled.complexity(),
            "constant sub-expressions should have been folded"
        );

        for values in [
            vec![0.0, 0.0],
            vec![5.0, 4.0],
            vec![5.5, 4.0],
            vec![0.0, 4.5],
        ] {
            let x = arr(values);
            assert_eq!(
                compiled.evaluate_raw(&x).expect("original evaluation failed"),
                optimized.evaluate_raw(&x).expect("optimized evaluation failed"),
            );
        }
    }

    #[test]
    fn test_program_is_feasible() {
        let mut prog = ConstraintProgram::new();
        assert!(
            prog.is_feasible(&empty()).expect("is_feasible failed"),
            "an empty program is trivially feasible"
        );

        prog.add(ConstraintExpr::between(0, 0.0, 1.0), "x_bound", 1);
        prog.add(ConstraintExpr::between(1, 0.0, 1.0), "y_bound", 2);
        assert_eq!(prog.num_constraints(), 2);
        assert!(prog
            .is_feasible(&arr(vec![0.5, 0.5]))
            .expect("is_feasible failed"));
        assert!(!prog
            .is_feasible(&arr(vec![0.5, 2.0]))
            .expect("is_feasible failed"));
    }

    #[test]
    fn test_program_add_replaces_same_name() {
        let mut prog = ConstraintProgram::new();
        prog.add(ConstraintExpr::between(0, 0.0, 1.0), "x", 1);
        prog.add(ConstraintExpr::between(0, 5.0, 6.0), "x", 1);
        assert_eq!(prog.num_constraints(), 1);
        assert_eq!(
            prog.violated(&arr(vec![0.5])).expect("violated failed"),
            vec!["x".to_string()]
        );
    }

    #[test]
    fn test_program_reports_short_input() {
        let mut prog = ConstraintProgram::new();
        prog.add(ConstraintExpr::between(1, 0.0, 1.0), "y_bound", 2);
        assert!(prog.evaluate_all(&arr(vec![0.5])).is_err());
        assert!(prog.is_feasible(&arr(vec![0.5])).is_err());
    }
}