1use crate::error::{LogicError, LogicResult};
20use scirs2_core::ndarray::Array1;
21use std::collections::HashMap;
22
/// A single instruction of the stack-based constraint bytecode.
///
/// Instructions operate on a stack of `f32` values. For binary instructions
/// the right-hand operand is the value on top of the stack and the left-hand
/// operand is the one below it. Boolean results are encoded as `1.0` (true)
/// and `0.0` (false); the logical instructions treat every value other than
/// `0.0` as true.
#[derive(Debug, Clone, PartialEq)]
pub enum Opcode {
    /// Push the input component at the given index.
    LoadDim(usize),
    /// Push a literal constant.
    LoadConst(f32),
    /// Pop two values and push `lhs + rhs`.
    Add,
    /// Pop two values and push `lhs - rhs`.
    Sub,
    /// Pop two values and push `lhs * rhs`.
    Mul,
    /// Pop two values and push `lhs / rhs`; evaluation fails when `rhs` is zero.
    Div,
    /// Pop one value and push its negation.
    Neg,
    /// Pop one value and push its absolute value.
    Abs,
    /// Pop one value and push its square root; evaluation fails when it is negative.
    Sqrt,
    /// Pop two values and push the smaller one.
    Min,
    /// Pop two values and push the larger one.
    Max,
    /// Pop two values and push `1.0` if `lhs <= rhs`, otherwise `0.0`.
    CmpLe,
    /// Pop two values and push `1.0` if `lhs >= rhs`, otherwise `0.0`.
    CmpGe,
    /// Pop two values and push `1.0` if both are non-zero, otherwise `0.0`.
    And,
    /// Pop two values and push `1.0` if at least one is non-zero, otherwise `0.0`.
    Or,
    /// Pop one value and push `1.0` if it equals `0.0`, otherwise `0.0`.
    Not,
    /// Push a copy of the value on top of the stack.
    Dup,
    /// Discard the value on top of the stack.
    Pop,
}
70
/// A constraint lowered to a linear sequence of [`Opcode`]s.
#[derive(Debug, Clone)]
pub struct CompiledConstraint {
    /// Instructions executed in order by `evaluate_raw`.
    pub ops: Vec<Opcode>,
    /// Name given to the constraint when it was compiled.
    pub name: String,
    /// Minimum length the input vector must have for evaluation to start.
    pub num_dims: usize,
}
87
88impl CompiledConstraint {
89 pub fn evaluate(&self, x: &Array1<f32>) -> LogicResult<bool> {
93 let raw = self.evaluate_raw(x)?;
94 Ok(raw != 0.0)
95 }
96
97 pub fn evaluate_raw(&self, x: &Array1<f32>) -> LogicResult<f32> {
99 if x.len() < self.num_dims {
100 return Err(LogicError::DimensionMismatch {
101 expected: self.num_dims,
102 got: x.len(),
103 });
104 }
105
106 let mut stack: Vec<f32> = Vec::with_capacity(self.ops.len());
107
108 for op in &self.ops {
109 match op {
110 Opcode::LoadDim(dim) => {
111 let val = x.get(*dim).copied().ok_or_else(|| {
112 LogicError::InvalidInput(format!(
113 "LoadDim: dimension {} out of bounds (len={})",
114 dim,
115 x.len()
116 ))
117 })?;
118 stack.push(val);
119 }
120 Opcode::LoadConst(v) => {
121 stack.push(*v);
122 }
123 Opcode::Add => {
124 let b = stack_pop(&mut stack, "Add")?;
125 let a = stack_pop(&mut stack, "Add")?;
126 stack.push(a + b);
127 }
128 Opcode::Sub => {
129 let b = stack_pop(&mut stack, "Sub")?;
130 let a = stack_pop(&mut stack, "Sub")?;
131 stack.push(a - b);
132 }
133 Opcode::Mul => {
134 let b = stack_pop(&mut stack, "Mul")?;
135 let a = stack_pop(&mut stack, "Mul")?;
136 stack.push(a * b);
137 }
138 Opcode::Div => {
139 let b = stack_pop(&mut stack, "Div")?;
140 let a = stack_pop(&mut stack, "Div")?;
141 if b == 0.0 {
142 return Err(LogicError::InvalidInput(
143 "Div: division by zero".to_string(),
144 ));
145 }
146 stack.push(a / b);
147 }
148 Opcode::Neg => {
149 let a = stack_pop(&mut stack, "Neg")?;
150 stack.push(-a);
151 }
152 Opcode::Abs => {
153 let a = stack_pop(&mut stack, "Abs")?;
154 stack.push(a.abs());
155 }
156 Opcode::Sqrt => {
157 let a = stack_pop(&mut stack, "Sqrt")?;
158 if a < 0.0 {
159 return Err(LogicError::InvalidInput(format!(
160 "Sqrt: negative argument {a}"
161 )));
162 }
163 stack.push(a.sqrt());
164 }
165 Opcode::Min => {
166 let b = stack_pop(&mut stack, "Min")?;
167 let a = stack_pop(&mut stack, "Min")?;
168 stack.push(a.min(b));
169 }
170 Opcode::Max => {
171 let b = stack_pop(&mut stack, "Max")?;
172 let a = stack_pop(&mut stack, "Max")?;
173 stack.push(a.max(b));
174 }
175 Opcode::CmpLe => {
176 let b = stack_pop(&mut stack, "CmpLe")?;
177 let a = stack_pop(&mut stack, "CmpLe")?;
178 stack.push(if a <= b { 1.0 } else { 0.0 });
179 }
180 Opcode::CmpGe => {
181 let b = stack_pop(&mut stack, "CmpGe")?;
182 let a = stack_pop(&mut stack, "CmpGe")?;
183 stack.push(if a >= b { 1.0 } else { 0.0 });
184 }
185 Opcode::And => {
186 let b = stack_pop(&mut stack, "And")?;
187 let a = stack_pop(&mut stack, "And")?;
188 stack.push(if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 });
189 }
190 Opcode::Or => {
191 let b = stack_pop(&mut stack, "Or")?;
192 let a = stack_pop(&mut stack, "Or")?;
193 stack.push(if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 });
194 }
195 Opcode::Not => {
196 let a = stack_pop(&mut stack, "Not")?;
197 stack.push(if a == 0.0 { 1.0 } else { 0.0 });
198 }
199 Opcode::Dup => {
200 let a = stack.last().copied().ok_or_else(|| {
201 LogicError::InvalidInput("Dup: stack underflow".to_string())
202 })?;
203 stack.push(a);
204 }
205 Opcode::Pop => {
206 stack_pop(&mut stack, "Pop")?;
207 }
208 }
209 }
210
211 stack.last().copied().ok_or_else(|| {
212 LogicError::InvalidInput("evaluate_raw: stack is empty after execution".to_string())
213 })
214 }
215
216 pub fn optimize(&self) -> Self {
222 let folded = constant_fold(&self.ops);
223 let dce = dead_code_eliminate(&folded);
224 Self {
225 ops: dce,
226 name: self.name.clone(),
227 num_dims: self.num_dims,
228 }
229 }
230
231 pub fn complexity(&self) -> usize {
233 self.ops.len()
234 }
235}
236
237#[inline]
242fn stack_pop(stack: &mut Vec<f32>, op: &str) -> LogicResult<f32> {
243 stack
244 .pop()
245 .ok_or_else(|| LogicError::InvalidInput(format!("{op}: stack underflow")))
246}
247
248fn constant_fold(ops: &[Opcode]) -> Vec<Opcode> {
254 let mut out: Vec<Opcode> = Vec::with_capacity(ops.len());
255
256 let mut i = 0;
257 while i < ops.len() {
258 if i + 2 < ops.len() {
260 if let (Opcode::LoadConst(a), Opcode::LoadConst(b)) = (&ops[i], &ops[i + 1]) {
261 let a = *a;
262 let b = *b;
263 let folded = match &ops[i + 2] {
264 Opcode::Add => Some(a + b),
265 Opcode::Sub => Some(a - b),
266 Opcode::Mul => Some(a * b),
267 Opcode::Div => {
268 if b != 0.0 {
269 Some(a / b)
270 } else {
271 None
272 }
273 }
274 Opcode::Min => Some(a.min(b)),
275 Opcode::Max => Some(a.max(b)),
276 Opcode::CmpLe => Some(if a <= b { 1.0 } else { 0.0 }),
277 Opcode::CmpGe => Some(if a >= b { 1.0 } else { 0.0 }),
278 Opcode::And => Some(if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 }),
279 Opcode::Or => Some(if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 }),
280 _ => None,
281 };
282 if let Some(result) = folded {
283 out.push(Opcode::LoadConst(result));
284 i += 3;
285 continue;
286 }
287 }
288 }
289
290 if i + 1 < ops.len() {
292 if let Opcode::LoadConst(a) = &ops[i] {
293 let a = *a;
294 let folded = match &ops[i + 1] {
295 Opcode::Neg => Some(-a),
296 Opcode::Abs => Some(a.abs()),
297 Opcode::Sqrt => {
298 if a >= 0.0 {
299 Some(a.sqrt())
300 } else {
301 None
302 }
303 }
304 Opcode::Not => Some(if a == 0.0 { 1.0 } else { 0.0 }),
305 _ => None,
306 };
307 if let Some(result) = folded {
308 out.push(Opcode::LoadConst(result));
309 i += 2;
310 continue;
311 }
312 }
313 }
314
315 out.push(ops[i].clone());
316 i += 1;
317 }
318
319 if out.len() < ops.len() {
321 constant_fold(&out)
322 } else {
323 out
324 }
325}
326
327fn dead_code_eliminate(ops: &[Opcode]) -> Vec<Opcode> {
330 let mut out: Vec<Opcode> = Vec::with_capacity(ops.len());
331 let mut i = 0;
332 while i < ops.len() {
333 if i + 1 < ops.len() {
334 if let Opcode::LoadConst(_) = &ops[i] {
335 if let Opcode::Pop = &ops[i + 1] {
336 i += 2;
338 continue;
339 }
340 }
341 }
342 out.push(ops[i].clone());
343 i += 1;
344 }
345 out
346}
347
/// Expression tree describing a constraint before it is compiled to
/// [`Opcode`]s.
///
/// Binary variants hold their left-hand operand first and their right-hand
/// operand second.
#[derive(Debug, Clone)]
pub enum ConstraintExpr {
    /// The input component at the given index.
    Dim(usize),
    /// A literal constant.
    Const(f32),
    /// Sum of the two operands.
    Add(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// First operand minus the second.
    Sub(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Product of the two operands.
    Mul(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// First operand divided by the second.
    Div(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Negation of the operand.
    Neg(Box<ConstraintExpr>),
    /// Absolute value of the operand.
    Abs(Box<ConstraintExpr>),
    /// Square root of the operand.
    Sqrt(Box<ConstraintExpr>),
    /// Comparison `first <= second`.
    Le(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Comparison `first >= second`.
    Ge(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Logical conjunction of the two operands.
    And(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Logical disjunction of the two operands.
    Or(Box<ConstraintExpr>, Box<ConstraintExpr>),
    /// Logical negation of the operand.
    Not(Box<ConstraintExpr>),
}
388
389impl ConstraintExpr {
390 pub fn compile(&self, name: &str, num_dims: usize) -> CompiledConstraint {
396 let mut ops = Vec::new();
397 emit(self, &mut ops);
398 CompiledConstraint {
399 ops,
400 name: name.to_string(),
401 num_dims,
402 }
403 }
404
405 pub fn dim(i: usize) -> Self {
411 ConstraintExpr::Dim(i)
412 }
413
414 pub fn constant(v: f32) -> Self {
416 ConstraintExpr::Const(v)
417 }
418
419 pub fn between(dim: usize, lo: f32, hi: f32) -> Self {
421 let x = ConstraintExpr::Dim(dim);
422 let lo_le = ConstraintExpr::Le(Box::new(ConstraintExpr::Const(lo)), Box::new(x.clone()));
423 let hi_le = ConstraintExpr::Le(Box::new(x), Box::new(ConstraintExpr::Const(hi)));
424 ConstraintExpr::And(Box::new(lo_le), Box::new(hi_le))
425 }
426
427 pub fn l2_norm_le(dims: &[usize], radius: f32) -> Self {
431 assert!(!dims.is_empty(), "l2_norm_le: dims must not be empty");
432
433 let mut sum_sq: ConstraintExpr = ConstraintExpr::Mul(
435 Box::new(ConstraintExpr::Dim(dims[0])),
436 Box::new(ConstraintExpr::Dim(dims[0])),
437 );
438 for &d in &dims[1..] {
439 let sq = ConstraintExpr::Mul(
440 Box::new(ConstraintExpr::Dim(d)),
441 Box::new(ConstraintExpr::Dim(d)),
442 );
443 sum_sq = ConstraintExpr::Add(Box::new(sum_sq), Box::new(sq));
444 }
445
446 let norm = ConstraintExpr::Sqrt(Box::new(sum_sq));
447 ConstraintExpr::Le(Box::new(norm), Box::new(ConstraintExpr::Const(radius)))
448 }
449
450 pub fn affine_le(coeffs: &[(usize, f32)], rhs: f32) -> Self {
452 assert!(!coeffs.is_empty(), "affine_le: coeffs must not be empty");
453
454 let term = |(dim, c): &(usize, f32)| -> ConstraintExpr {
455 ConstraintExpr::Mul(
456 Box::new(ConstraintExpr::Const(*c)),
457 Box::new(ConstraintExpr::Dim(*dim)),
458 )
459 };
460
461 let mut sum = term(&coeffs[0]);
462 for coeff in &coeffs[1..] {
463 sum = ConstraintExpr::Add(Box::new(sum), Box::new(term(coeff)));
464 }
465
466 ConstraintExpr::Le(Box::new(sum), Box::new(ConstraintExpr::Const(rhs)))
467 }
468}
469
470fn emit(expr: &ConstraintExpr, ops: &mut Vec<Opcode>) {
475 match expr {
476 ConstraintExpr::Dim(i) => {
477 ops.push(Opcode::LoadDim(*i));
478 }
479 ConstraintExpr::Const(v) => {
480 ops.push(Opcode::LoadConst(*v));
481 }
482 ConstraintExpr::Add(a, b) => {
483 emit(a, ops);
484 emit(b, ops);
485 ops.push(Opcode::Add);
486 }
487 ConstraintExpr::Sub(a, b) => {
488 emit(a, ops);
489 emit(b, ops);
490 ops.push(Opcode::Sub);
491 }
492 ConstraintExpr::Mul(a, b) => {
493 emit(a, ops);
494 emit(b, ops);
495 ops.push(Opcode::Mul);
496 }
497 ConstraintExpr::Div(a, b) => {
498 emit(a, ops);
499 emit(b, ops);
500 ops.push(Opcode::Div);
501 }
502 ConstraintExpr::Neg(a) => {
503 emit(a, ops);
504 ops.push(Opcode::Neg);
505 }
506 ConstraintExpr::Abs(a) => {
507 emit(a, ops);
508 ops.push(Opcode::Abs);
509 }
510 ConstraintExpr::Sqrt(a) => {
511 emit(a, ops);
512 ops.push(Opcode::Sqrt);
513 }
514 ConstraintExpr::Le(a, b) => {
515 emit(a, ops);
516 emit(b, ops);
517 ops.push(Opcode::CmpLe);
518 }
519 ConstraintExpr::Ge(a, b) => {
520 emit(a, ops);
521 emit(b, ops);
522 ops.push(Opcode::CmpGe);
523 }
524 ConstraintExpr::And(a, b) => {
525 emit(a, ops);
526 emit(b, ops);
527 ops.push(Opcode::And);
528 }
529 ConstraintExpr::Or(a, b) => {
530 emit(a, ops);
531 emit(b, ops);
532 ops.push(Opcode::Or);
533 }
534 ConstraintExpr::Not(a) => {
535 emit(a, ops);
536 ops.push(Opcode::Not);
537 }
538 }
539}
540
541pub struct ConstraintProgram {
550 constraints: HashMap<String, CompiledConstraint>,
551}
552
553impl Default for ConstraintProgram {
554 fn default() -> Self {
555 Self::new()
556 }
557}
558
559impl ConstraintProgram {
560 pub fn new() -> Self {
562 Self {
563 constraints: HashMap::new(),
564 }
565 }
566
567 pub fn add(&mut self, expr: ConstraintExpr, name: &str, num_dims: usize) {
569 let compiled = expr.compile(name, num_dims);
570 self.constraints.insert(name.to_string(), compiled);
571 }
572
573 pub fn evaluate_all(&self, x: &Array1<f32>) -> LogicResult<HashMap<String, bool>> {
575 let mut results = HashMap::with_capacity(self.constraints.len());
576 for (name, constraint) in &self.constraints {
577 let feasible = constraint.evaluate(x)?;
578 results.insert(name.clone(), feasible);
579 }
580 Ok(results)
581 }
582
583 pub fn violated(&self, x: &Array1<f32>) -> LogicResult<Vec<String>> {
585 let all = self.evaluate_all(x)?;
586 let mut names: Vec<String> = all
587 .into_iter()
588 .filter_map(|(name, feasible)| if feasible { None } else { Some(name) })
589 .collect();
590 names.sort(); Ok(names)
592 }
593
594 pub fn is_feasible(&self, x: &Array1<f32>) -> LogicResult<bool> {
596 for constraint in self.constraints.values() {
597 if !constraint.evaluate(x)? {
598 return Ok(false);
599 }
600 }
601 Ok(true)
602 }
603
604 pub fn num_constraints(&self) -> usize {
606 self.constraints.len()
607 }
608}
609
// Unit tests for expression compilation, evaluation, optimisation and the
// `ConstraintProgram` container.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Builds a one-dimensional input array from a vector of values.
    fn arr(values: Vec<f32>) -> Array1<f32> {
        Array1::from_vec(values)
    }

    // A bare constant evaluates to itself, even with an empty input.
    #[test]
    fn test_compile_constant() {
        let expr = ConstraintExpr::constant(3.0);
        let compiled = expr.compile("c", 0);
        let x: Array1<f32> = Array1::from_vec(vec![]);
        let raw = compiled.evaluate_raw(&x).expect("evaluate_raw failed");
        assert!((raw - 3.0).abs() < 1e-6, "expected 3.0, got {raw}");
    }

    // A dimension reference reads the matching input component.
    #[test]
    fn test_compile_load_dim() {
        let expr = ConstraintExpr::dim(1);
        let compiled = expr.compile("c", 2);
        let x = arr(vec![0.0, 5.0]);
        let raw = compiled.evaluate_raw(&x).expect("evaluate_raw failed");
        assert!((raw - 5.0).abs() < 1e-6, "expected 5.0, got {raw}");
    }

    // `between` accepts a value inside the interval and rejects one outside.
    #[test]
    fn test_compile_between() {
        let expr = ConstraintExpr::between(0, -1.0, 1.0);
        let compiled = expr.compile("bound", 1);

        let x_ok = arr(vec![0.5]);
        assert!(
            compiled.evaluate(&x_ok).expect("evaluate failed"),
            "0.5 should be in [-1, 1]"
        );

        let x_bad = arr(vec![2.0]);
        assert!(
            !compiled.evaluate(&x_bad).expect("evaluate failed"),
            "2.0 should not be in [-1, 1]"
        );
    }

    // `affine_le` checks 2*x0 + 3*x1 <= 10 on a satisfying and a violating point.
    #[test]
    fn test_compile_affine_le() {
        let expr = ConstraintExpr::affine_le(&[(0, 2.0), (1, 3.0)], 10.0);
        let compiled = expr.compile("affine", 2);

        let x_ok = arr(vec![1.0, 1.0]);
        assert!(
            compiled.evaluate(&x_ok).expect("evaluate failed"),
            "2+3=5 should be <= 10"
        );

        let x_bad = arr(vec![3.0, 3.0]);
        assert!(
            !compiled.evaluate(&x_bad).expect("evaluate failed"),
            "6+9=15 should not be <= 10"
        );
    }

    // `l2_norm_le` checks membership in the unit ball over dimensions 0 and 1.
    #[test]
    fn test_compile_l2_norm_le() {
        let expr = ConstraintExpr::l2_norm_le(&[0, 1], 1.0);
        let compiled = expr.compile("l2ball", 2);

        let x_ok = arr(vec![0.3, 0.4]);
        assert!(
            compiled.evaluate(&x_ok).expect("evaluate failed"),
            "norm(0.3, 0.4)=0.5 should be <= 1.0"
        );

        let x_bad = arr(vec![1.0, 1.0]);
        assert!(
            !compiled.evaluate(&x_bad).expect("evaluate failed"),
            "norm(1, 1)=sqrt(2) should not be <= 1.0"
        );
    }

    // Optimising `2 + 3` shrinks the program and keeps the value 5.
    #[test]
    fn test_optimize_constant_folding() {
        let expr = ConstraintExpr::Add(
            Box::new(ConstraintExpr::Const(2.0)),
            Box::new(ConstraintExpr::Const(3.0)),
        );
        let compiled = expr.compile("fold", 0);
        let optimized = compiled.optimize();

        assert!(
            optimized.complexity() < compiled.complexity(),
            "optimized ({}) should have fewer ops than original ({})",
            optimized.complexity(),
            compiled.complexity()
        );

        let x: Array1<f32> = Array1::from_vec(vec![]);
        let raw = optimized.evaluate_raw(&x).expect("evaluate_raw failed");
        assert!(
            (raw - 5.0).abs() < 1e-6,
            "folded result should be 5.0, got {raw}"
        );
    }

    // `evaluate_all` reports one entry per registered constraint.
    #[test]
    fn test_program_evaluate_all() {
        let mut prog = ConstraintProgram::new();
        prog.add(ConstraintExpr::between(0, 0.0, 1.0), "x_bound", 1);
        prog.add(ConstraintExpr::between(1, 0.0, 1.0), "y_bound", 2);

        let x = arr(vec![0.5, 0.5]);
        let results = prog.evaluate_all(&x).expect("evaluate_all failed");

        assert_eq!(results.len(), 2, "should have 2 entries");
        assert!(results["x_bound"], "x_bound should be feasible");
        assert!(results["y_bound"], "y_bound should be feasible");
    }

    // `violated` lists only the constraints the point fails.
    #[test]
    fn test_program_violated_returns_names() {
        let mut prog = ConstraintProgram::new();
        prog.add(ConstraintExpr::between(0, 0.0, 1.0), "x_bound", 1);
        prog.add(ConstraintExpr::between(1, 0.0, 1.0), "y_bound", 2);

        let x = arr(vec![2.0, 0.5]);
        let violated = prog.violated(&x).expect("violated failed");

        assert_eq!(violated, vec!["x_bound".to_string()]);
    }

    // Optimising the nested sum `(1 + 2) + 3` never grows the program and
    // keeps the value 6.
    #[test]
    fn test_complexity_before_after_optimize() {
        let expr = ConstraintExpr::Add(
            Box::new(ConstraintExpr::Add(
                Box::new(ConstraintExpr::Const(1.0)),
                Box::new(ConstraintExpr::Const(2.0)),
            )),
            Box::new(ConstraintExpr::Const(3.0)),
        );
        let compiled = expr.compile("nested", 0);
        let optimized = compiled.optimize();

        assert!(
            optimized.complexity() <= compiled.complexity(),
            "optimized complexity {} should be <= original {}",
            optimized.complexity(),
            compiled.complexity()
        );

        let x: Array1<f32> = Array1::from_vec(vec![]);
        let raw = optimized.evaluate_raw(&x).expect("evaluate_raw failed");
        assert!((raw - 6.0).abs() < 1e-6, "result should be 6.0, got {raw}");
    }
}