1use super::simplex::{LinExpr, Simplex, VarId};
4use crate::theory::{EqualityNotification, Theory, TheoryCombination, TheoryId, TheoryResult};
5use num_rational::Rational64;
6use num_traits::{One, Signed};
7use oxiz_core::ast::TermId;
8use oxiz_core::error::Result;
9use rustc_hash::FxHashMap;
10
/// Greatest common divisor of `a` and `b`, always non-negative.
///
/// `gcd_i64(0, 0)` is `0`; `gcd_i64(n, 0)` is `|n|`.
fn gcd_i64(a: i64, b: i64) -> i64 {
    // Work on magnitudes so the remainder sequence stays non-negative.
    let (mut x, mut y) = (a.abs(), b.abs());
    while y != 0 {
        (x, y) = (y, x % y);
    }
    x
}
22
23#[derive(Debug)]
25pub struct ArithSolver {
26 simplex: Simplex,
28 term_to_var: FxHashMap<TermId, VarId>,
30 var_to_term: Vec<TermId>,
32 reason_counter: u32,
34 reasons: Vec<TermId>,
36 is_integer: bool,
38 context_stack: Vec<ContextState>,
40}
41
/// Sizes of the solver's undoable vectors, captured when a scope is pushed.
#[derive(Debug, Clone)]
struct ContextState {
    /// Number of interned terms at the time of the push.
    num_vars: usize,
    /// Number of registered reasons at the time of the push.
    num_reasons: usize,
}
48
49impl Default for ArithSolver {
50 fn default() -> Self {
51 Self::new(false)
52 }
53}
54
55impl ArithSolver {
56 #[must_use]
58 pub fn new(is_integer: bool) -> Self {
59 Self {
60 simplex: Simplex::new(),
61 term_to_var: FxHashMap::default(),
62 var_to_term: Vec::new(),
63 reason_counter: 0,
64 reasons: Vec::new(),
65 is_integer,
66 context_stack: Vec::new(),
67 }
68 }
69
70 #[must_use]
72 pub fn lra() -> Self {
73 Self::new(false)
74 }
75
76 #[must_use]
78 pub fn lia() -> Self {
79 Self::new(true)
80 }
81
82 pub fn intern(&mut self, term: TermId) -> VarId {
84 if let Some(&var) = self.term_to_var.get(&term) {
85 return var;
86 }
87
88 let var = self.simplex.new_var();
89 self.term_to_var.insert(term, var);
90 self.var_to_term.push(term);
91 var
92 }
93
94 fn add_reason(&mut self, term: TermId) -> u32 {
96 let id = self.reason_counter;
97 self.reason_counter += 1;
98 self.reasons.push(term);
99 id
100 }
101
102 fn normalize_expr(&self, expr: &mut LinExpr) {
109 if expr.terms.is_empty() {
110 return;
111 }
112
113 if self.is_integer {
115 let gcd = expr
117 .terms
118 .iter()
119 .map(|(_, c)| c.numer().abs())
120 .fold(0i64, |acc, n| if acc == 0 { n } else { gcd_i64(acc, n) });
121
122 if gcd > 1 {
123 let divisor = Rational64::from_integer(gcd);
124 expr.scale(Rational64::one() / divisor);
125 }
126 }
127
128 if let Some((_, c)) = expr.terms.first()
130 && c.is_negative()
131 {
132 expr.negate();
133 }
134
135 expr.terms.sort_by_key(|(v, _)| *v);
137 }
138
139 pub fn assert_le(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
141 let mut expr = LinExpr::new();
142
143 for (term, coef) in lhs {
144 let var = self.intern(*term);
145 expr.add_term(var, *coef);
146 }
147 expr.add_constant(-rhs);
148
149 self.normalize_expr(&mut expr);
151
152 let reason_id = self.add_reason(reason);
153 self.simplex.add_le(expr, reason_id);
154 }
155
156 pub fn assert_ge(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
158 let mut expr = LinExpr::new();
159
160 for (term, coef) in lhs {
161 let var = self.intern(*term);
162 expr.add_term(var, *coef);
163 }
164 expr.add_constant(-rhs);
165
166 self.normalize_expr(&mut expr);
168
169 let reason_id = self.add_reason(reason);
170 self.simplex.add_ge(expr, reason_id);
171 }
172
173 pub fn assert_eq(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
181 let mut expr = LinExpr::new();
182
183 for (term, coef) in lhs {
184 let var = self.intern(*term);
185 expr.add_term(var, *coef);
186 }
187 expr.add_constant(-rhs);
188
189 if self.is_integer {
192 let coeffs: Vec<i64> = expr
194 .terms
195 .iter()
196 .filter_map(|(_, c)| {
197 if c.denom() == &1 {
198 Some(*c.numer())
199 } else {
200 None
201 }
202 })
203 .collect();
204
205 let const_term = if expr.constant.denom() == &1 {
207 -*expr.constant.numer()
208 } else {
209 if let Some(&(var, _)) = expr.terms.first() {
211 self.simplex.set_lower(var, Rational64::from_integer(1), 0);
212 self.simplex.set_upper(var, Rational64::from_integer(0), 0);
213 }
214 return;
215 };
216
217 if !coeffs.is_empty() && coeffs.len() == expr.terms.len() {
219 let g = coeffs.iter().fold(0i64, |acc, &c| gcd_i64(acc, c.abs()));
221
222 if g > 0 && const_term % g != 0 {
223 if let Some(&(var, _)) = expr.terms.first() {
226 self.simplex.set_lower(var, Rational64::from_integer(1), 0);
227 self.simplex.set_upper(var, Rational64::from_integer(0), 0);
228 }
229 return;
230 }
231 }
232 }
233
234 self.normalize_expr(&mut expr);
236
237 let reason_id = self.add_reason(reason);
238 self.simplex.add_eq(expr, reason_id);
239 }
240
241 pub fn assert_lt(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
245 if self.is_integer {
248 self.assert_le(lhs, rhs - Rational64::one(), reason);
250 return;
251 }
252
253 let mut expr = LinExpr::new();
256
257 for (term, coef) in lhs {
258 let var = self.intern(*term);
259 expr.add_term(var, *coef);
260 }
261 expr.add_constant(-rhs);
262
263 let reason_id = self.add_reason(reason);
268 self.simplex.add_strict_lt(expr, reason_id);
269 }
270
271 pub fn assert_gt(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
275 if self.is_integer {
278 self.assert_ge(lhs, rhs + Rational64::one(), reason);
280 return;
281 }
282
283 let mut expr = LinExpr::new();
288
289 for (term, coef) in lhs {
290 let var = self.intern(*term);
291 expr.add_term(var, -(*coef));
293 }
294 expr.add_constant(rhs);
296
297 let reason_id = self.add_reason(reason);
303 self.simplex.add_strict_lt(expr, reason_id);
304 }
305
306 #[must_use]
313 pub fn value(&self, term: TermId) -> Option<Rational64> {
314 self.term_to_var.get(&term).map(|&var| {
315 if self.is_integer {
316 let dval = self.simplex.delta_value(var);
318
319 if dval.delta.is_positive() {
326 let real_val = dval.real;
329 if real_val.is_integer() {
330 Rational64::from_integer(real_val.to_integer() + 1)
331 } else {
332 Rational64::from_integer(real_val.ceil().to_integer())
333 }
334 } else if dval.delta.is_negative() {
335 let real_val = dval.real;
338 if real_val.is_integer() {
339 Rational64::from_integer(real_val.to_integer() - 1)
340 } else {
341 Rational64::from_integer(real_val.floor().to_integer())
342 }
343 } else {
344 dval.real
347 }
348 } else {
349 self.simplex.value(var)
351 }
352 })
353 }
354
355 #[allow(dead_code)]
363 fn tighten_bound(&self, bound: Rational64, is_upper: bool) -> Rational64 {
364 if !self.is_integer {
365 return bound;
366 }
367
368 if bound.is_integer() {
371 bound
372 } else if is_upper {
373 Rational64::from_integer(bound.floor().to_integer())
375 } else {
376 Rational64::from_integer(bound.ceil().to_integer())
378 }
379 }
380
381 pub fn tighten_constraints(&mut self) -> bool {
385 if !self.is_integer {
386 return false;
387 }
388
389 false
396 }
397}
398
399impl Theory for ArithSolver {
400 fn id(&self) -> TheoryId {
401 if self.is_integer {
402 TheoryId::LIA
403 } else {
404 TheoryId::LRA
405 }
406 }
407
408 fn name(&self) -> &str {
409 if self.is_integer { "LIA" } else { "LRA" }
410 }
411
412 fn can_handle(&self, _term: TermId) -> bool {
413 true
415 }
416
417 fn assert_true(&mut self, term: TermId) -> Result<TheoryResult> {
418 let _ = self.intern(term);
420 Ok(TheoryResult::Sat)
421 }
422
423 fn assert_false(&mut self, term: TermId) -> Result<TheoryResult> {
424 let _ = self.intern(term);
425 Ok(TheoryResult::Sat)
426 }
427
428 fn check(&mut self) -> Result<TheoryResult> {
429 match self.simplex.check() {
430 Ok(()) => Ok(TheoryResult::Sat),
431 Err(reasons) => {
432 let terms: Vec<_> = reasons
433 .iter()
434 .filter_map(|&r| self.reasons.get(r as usize).copied())
435 .collect();
436 Ok(TheoryResult::Unsat(terms))
437 }
438 }
439 }
440
441 fn push(&mut self) {
442 self.context_stack.push(ContextState {
443 num_vars: self.var_to_term.len(),
444 num_reasons: self.reasons.len(),
445 });
446 self.simplex.push();
447 }
448
449 fn pop(&mut self) {
450 if let Some(state) = self.context_stack.pop() {
451 self.var_to_term.truncate(state.num_vars);
452 self.reasons.truncate(state.num_reasons);
453 self.reason_counter = state.num_reasons as u32;
454 self.simplex.pop();
455 }
456 }
457
458 fn reset(&mut self) {
459 self.simplex.reset();
460 self.term_to_var.clear();
461 self.var_to_term.clear();
462 self.reason_counter = 0;
463 self.reasons.clear();
464 self.context_stack.clear();
465 }
466
467 fn get_model(&self) -> Vec<(TermId, TermId)> {
468 Vec::new()
471 }
472}
473
474impl TheoryCombination for ArithSolver {
475 fn notify_equality(&mut self, eq: EqualityNotification) -> bool {
476 let lhs_var = self.term_to_var.get(&eq.lhs).copied();
478 let rhs_var = self.term_to_var.get(&eq.rhs).copied();
479
480 if let (Some(_lhs), Some(_rhs)) = (lhs_var, rhs_var) {
481 let _reason = if let Some(r) = eq.reason {
494 self.add_reason(r)
495 } else {
496 self.add_reason(eq.lhs)
497 };
498
499 true
500 } else {
501 false
503 }
504 }
505
506 fn get_shared_equalities(&self) -> Vec<EqualityNotification> {
507 Vec::new()
511 }
512
513 fn is_relevant(&self, term: TermId) -> bool {
514 self.term_to_var.contains_key(&term)
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{One, Zero};

    /// Integer-valued rational constant.
    fn int(n: i64) -> Rational64 {
        Rational64::from_integer(n)
    }

    /// The single-term linear combination `1·term`.
    fn unit(term: TermId) -> [(TermId, Rational64); 1] {
        [(term, Rational64::one())]
    }

    #[test]
    fn test_arith_basic() {
        let mut solver = ArithSolver::lra();
        let (x, y, reason) = (TermId::new(1), TermId::new(2), TermId::new(100));

        solver.assert_ge(&unit(x), int(0), reason);
        solver.assert_ge(&unit(y), int(0), reason);
        solver.assert_le(&[(x, Rational64::one()), (y, Rational64::one())], int(10), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Sat));
    }

    #[test]
    fn test_arith_unsat() {
        let mut solver = ArithSolver::lra();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        solver.assert_ge(&unit(x), int(10), reason);
        solver.assert_le(&unit(x), int(5), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Unsat(_)));
    }

    #[test]
    fn test_arith_strict_inequality() {
        let mut solver = ArithSolver::lra();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        solver.assert_gt(&unit(x), int(0), reason);
        solver.assert_lt(&unit(x), int(10), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Sat));
    }

    #[test]
    fn test_arith_strict_unsat() {
        let mut solver = ArithSolver::lra();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        solver.assert_ge(&unit(x), int(5), reason);
        solver.assert_lt(&unit(x), int(5), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Unsat(_)));
    }

    #[test]
    fn test_coefficient_normalization_lia() {
        let mut solver = ArithSolver::lia();
        let (x, y, reason) = (TermId::new(1), TermId::new(2), TermId::new(100));

        // 2x + 4y <= 10 should be reduced by the common factor 2.
        solver.assert_le(&[(x, int(2)), (y, int(4))], int(10), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Sat));
    }

    #[test]
    fn test_coefficient_normalization_sign() {
        let solver = ArithSolver::lra();

        let mut expr = LinExpr::new();
        expr.add_term(0, int(-3));
        expr.add_term(1, int(2));

        solver.normalize_expr(&mut expr);

        if let Some((_, leading)) = expr.terms.first() {
            assert!(leading > &Rational64::zero());
        }
    }

    #[test]
    fn test_gcd_computation() {
        let cases = [
            ((12, 8), 4),
            ((15, 25), 5),
            ((7, 13), 1),
            ((0, 5), 5),
            ((5, 0), 5),
            ((-12, 8), 4),
            ((12, -8), 4),
        ];
        for ((a, b), expected) in cases {
            assert_eq!(gcd_i64(a, b), expected);
        }
    }

    #[test]
    fn test_bound_tightening_lia() {
        let solver = ArithSolver::lia();

        assert_eq!(solver.tighten_bound(Rational64::new(57, 10), true), int(5));
        assert_eq!(solver.tighten_bound(Rational64::new(23, 10), false), int(3));
        assert_eq!(solver.tighten_bound(int(5), true), int(5));
    }

    #[test]
    fn test_bound_tightening_lra() {
        let solver = ArithSolver::lra();
        let bound = Rational64::new(57, 10);

        assert_eq!(solver.tighten_bound(bound, true), bound);
    }

    #[test]
    fn test_tighten_constraints() {
        let mut lia = ArithSolver::lia();
        let mut lra = ArithSolver::lra();

        assert!(!lia.tighten_constraints());
        assert!(!lra.tighten_constraints());
    }

    #[test]
    fn test_lia_strict_inequality_empty_interval() {
        let mut solver = ArithSolver::lia();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        solver.assert_gt(&unit(x), int(5), reason);
        solver.assert_lt(&unit(x), int(6), reason);

        let result = solver.check().unwrap();
        assert!(
            matches!(result, TheoryResult::Unsat(_)),
            "Expected UNSAT for x > 5 AND x < 6 in LIA, got {:?}",
            result
        );
    }

    #[test]
    fn test_lra_strict_inequality_has_solution() {
        let mut solver = ArithSolver::lra();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        solver.assert_gt(&unit(x), int(5), reason);
        solver.assert_lt(&unit(x), int(6), reason);

        let result = solver.check().unwrap();
        assert!(
            matches!(result, TheoryResult::Sat),
            "Expected SAT for x > 5 AND x < 6 in LRA, got {:?}",
            result
        );
    }

    #[test]
    fn test_lia_strict_at_boundary() {
        let mut solver = ArithSolver::lia();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        solver.assert_ge(&unit(x), int(5), reason);
        solver.assert_lt(&unit(x), int(6), reason);

        let result = solver.check().unwrap();
        assert!(
            matches!(result, TheoryResult::Sat),
            "Expected SAT for x >= 5 AND x < 6 in LIA, got {:?}",
            result
        );
    }
}