oxiz_theories/arithmetic/
solver.rs1use super::simplex::{LinExpr, Simplex, VarId};
4use crate::theory::{EqualityNotification, Theory, TheoryCombination, TheoryId, TheoryResult};
5use num_rational::Rational64;
6use num_traits::{One, Signed};
7use oxiz_core::ast::TermId;
8use oxiz_core::error::Result;
9use rustc_hash::FxHashMap;
10
/// Greatest common divisor of `|a|` and `|b|` using Euclid's algorithm.
///
/// `gcd_i64(0, 0)` is `0`, and `gcd_i64(0, n)` is `|n|`. Like `i64::abs`, this
/// overflows (panicking in debug builds) if either argument is `i64::MIN`.
fn gcd_i64(a: i64, b: i64) -> i64 {
    let (mut x, mut y) = (a.abs(), b.abs());
    while y != 0 {
        // Replace (x, y) with (y, x mod y) until the remainder vanishes.
        (x, y) = (y, x % y);
    }
    x
}
22
23#[derive(Debug)]
25pub struct ArithSolver {
26 simplex: Simplex,
28 term_to_var: FxHashMap<TermId, VarId>,
30 var_to_term: Vec<TermId>,
32 reason_counter: u32,
34 reasons: Vec<TermId>,
36 is_integer: bool,
38 context_stack: Vec<ContextState>,
40}
41
/// Solver sizes recorded by `push` so that `pop` can roll back to them.
#[derive(Debug, Clone)]
struct ContextState {
    /// Length of `var_to_term` when the scope was opened.
    num_vars: usize,
    /// Length of `reasons` when the scope was opened.
    num_reasons: usize,
}
48
49impl Default for ArithSolver {
50 fn default() -> Self {
51 Self::new(false)
52 }
53}
54
55impl ArithSolver {
56 #[must_use]
58 pub fn new(is_integer: bool) -> Self {
59 Self {
60 simplex: Simplex::new(),
61 term_to_var: FxHashMap::default(),
62 var_to_term: Vec::new(),
63 reason_counter: 0,
64 reasons: Vec::new(),
65 is_integer,
66 context_stack: Vec::new(),
67 }
68 }
69
70 #[must_use]
72 pub fn lra() -> Self {
73 Self::new(false)
74 }
75
76 #[must_use]
78 pub fn lia() -> Self {
79 Self::new(true)
80 }
81
82 pub fn intern(&mut self, term: TermId) -> VarId {
84 if let Some(&var) = self.term_to_var.get(&term) {
85 return var;
86 }
87
88 let var = self.simplex.new_var();
89 self.term_to_var.insert(term, var);
90 self.var_to_term.push(term);
91 var
92 }
93
94 fn add_reason(&mut self, term: TermId) -> u32 {
96 let id = self.reason_counter;
97 self.reason_counter += 1;
98 self.reasons.push(term);
99 id
100 }
101
102 fn normalize_expr(&self, expr: &mut LinExpr) {
109 if expr.terms.is_empty() {
110 return;
111 }
112
113 if self.is_integer {
115 let gcd = expr
117 .terms
118 .iter()
119 .map(|(_, c)| c.numer().abs())
120 .fold(0i64, |acc, n| if acc == 0 { n } else { gcd_i64(acc, n) });
121
122 if gcd > 1 {
123 let divisor = Rational64::from_integer(gcd);
124 expr.scale(Rational64::one() / divisor);
125 }
126 }
127
128 if let Some((_, c)) = expr.terms.first()
130 && c.is_negative()
131 {
132 expr.negate();
133 }
134
135 expr.terms.sort_by_key(|(v, _)| *v);
137 }
138
139 pub fn assert_le(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
141 let mut expr = LinExpr::new();
142
143 for (term, coef) in lhs {
144 let var = self.intern(*term);
145 expr.add_term(var, *coef);
146 }
147 expr.add_constant(-rhs);
148
149 self.normalize_expr(&mut expr);
151
152 let reason_id = self.add_reason(reason);
153 self.simplex.add_le(expr, reason_id);
154 }
155
156 pub fn assert_ge(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
158 let mut expr = LinExpr::new();
159
160 for (term, coef) in lhs {
161 let var = self.intern(*term);
162 expr.add_term(var, *coef);
163 }
164 expr.add_constant(-rhs);
165
166 self.normalize_expr(&mut expr);
168
169 let reason_id = self.add_reason(reason);
170 self.simplex.add_ge(expr, reason_id);
171 }
172
173 pub fn assert_eq(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
175 let mut expr = LinExpr::new();
176
177 for (term, coef) in lhs {
178 let var = self.intern(*term);
179 expr.add_term(var, *coef);
180 }
181 expr.add_constant(-rhs);
182
183 self.normalize_expr(&mut expr);
185
186 let reason_id = self.add_reason(reason);
187 self.simplex.add_eq(expr, reason_id);
188 }
189
190 pub fn assert_lt(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
193 let mut expr = LinExpr::new();
195
196 for (term, coef) in lhs {
197 let var = self.intern(*term);
198 expr.add_term(var, *coef);
199 }
200 expr.add_constant(-rhs);
201
202 let reason_id = self.add_reason(reason);
207 self.simplex.add_strict_lt(expr, reason_id);
208 }
209
210 pub fn assert_gt(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
213 let mut expr = LinExpr::new();
217
218 for (term, coef) in lhs {
219 let var = self.intern(*term);
220 expr.add_term(var, -(*coef));
222 }
223 expr.add_constant(rhs);
225
226 let reason_id = self.add_reason(reason);
232 self.simplex.add_strict_lt(expr, reason_id);
233 }
234
235 #[must_use]
242 pub fn value(&self, term: TermId) -> Option<Rational64> {
243 self.term_to_var.get(&term).map(|&var| {
244 if self.is_integer {
245 let dval = self.simplex.delta_value(var);
247
248 if dval.delta.is_positive() {
255 let real_val = dval.real;
258 if real_val.is_integer() {
259 Rational64::from_integer(real_val.to_integer() + 1)
260 } else {
261 Rational64::from_integer(real_val.ceil().to_integer())
262 }
263 } else if dval.delta.is_negative() {
264 let real_val = dval.real;
267 if real_val.is_integer() {
268 Rational64::from_integer(real_val.to_integer() - 1)
269 } else {
270 Rational64::from_integer(real_val.floor().to_integer())
271 }
272 } else {
273 dval.real
276 }
277 } else {
278 self.simplex.value(var)
280 }
281 })
282 }
283
284 #[allow(dead_code)]
292 fn tighten_bound(&self, bound: Rational64, is_upper: bool) -> Rational64 {
293 if !self.is_integer {
294 return bound;
295 }
296
297 if bound.is_integer() {
300 bound
301 } else if is_upper {
302 Rational64::from_integer(bound.floor().to_integer())
304 } else {
305 Rational64::from_integer(bound.ceil().to_integer())
307 }
308 }
309
310 pub fn tighten_constraints(&mut self) -> bool {
314 if !self.is_integer {
315 return false;
316 }
317
318 false
325 }
326}
327
328impl Theory for ArithSolver {
329 fn id(&self) -> TheoryId {
330 if self.is_integer {
331 TheoryId::LIA
332 } else {
333 TheoryId::LRA
334 }
335 }
336
337 fn name(&self) -> &str {
338 if self.is_integer { "LIA" } else { "LRA" }
339 }
340
341 fn can_handle(&self, _term: TermId) -> bool {
342 true
344 }
345
346 fn assert_true(&mut self, term: TermId) -> Result<TheoryResult> {
347 let _ = self.intern(term);
349 Ok(TheoryResult::Sat)
350 }
351
352 fn assert_false(&mut self, term: TermId) -> Result<TheoryResult> {
353 let _ = self.intern(term);
354 Ok(TheoryResult::Sat)
355 }
356
357 fn check(&mut self) -> Result<TheoryResult> {
358 match self.simplex.check() {
359 Ok(()) => Ok(TheoryResult::Sat),
360 Err(reasons) => {
361 let terms: Vec<_> = reasons
362 .iter()
363 .filter_map(|&r| self.reasons.get(r as usize).copied())
364 .collect();
365 Ok(TheoryResult::Unsat(terms))
366 }
367 }
368 }
369
370 fn push(&mut self) {
371 self.context_stack.push(ContextState {
372 num_vars: self.var_to_term.len(),
373 num_reasons: self.reasons.len(),
374 });
375 self.simplex.push();
376 }
377
378 fn pop(&mut self) {
379 if let Some(state) = self.context_stack.pop() {
380 self.var_to_term.truncate(state.num_vars);
381 self.reasons.truncate(state.num_reasons);
382 self.reason_counter = state.num_reasons as u32;
383 self.simplex.pop();
384 }
385 }
386
387 fn reset(&mut self) {
388 self.simplex.reset();
389 self.term_to_var.clear();
390 self.var_to_term.clear();
391 self.reason_counter = 0;
392 self.reasons.clear();
393 self.context_stack.clear();
394 }
395
396 fn get_model(&self) -> Vec<(TermId, TermId)> {
397 Vec::new()
400 }
401}
402
403impl TheoryCombination for ArithSolver {
404 fn notify_equality(&mut self, eq: EqualityNotification) -> bool {
405 let lhs_var = self.term_to_var.get(&eq.lhs).copied();
407 let rhs_var = self.term_to_var.get(&eq.rhs).copied();
408
409 if let (Some(_lhs), Some(_rhs)) = (lhs_var, rhs_var) {
410 let _reason = if let Some(r) = eq.reason {
423 self.add_reason(r)
424 } else {
425 self.add_reason(eq.lhs)
426 };
427
428 true
429 } else {
430 false
432 }
433 }
434
435 fn get_shared_equalities(&self) -> Vec<EqualityNotification> {
436 Vec::new()
440 }
441
442 fn is_relevant(&self, term: TermId) -> bool {
443 self.term_to_var.contains_key(&term)
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{One, Zero};

    /// Integer-valued rational constant.
    fn rat(n: i64) -> Rational64 {
        Rational64::from_integer(n)
    }

    /// A left-hand side consisting of the single summand `1·term`.
    fn unit(term: TermId) -> [(TermId, Rational64); 1] {
        [(term, Rational64::one())]
    }

    #[test]
    fn test_arith_basic() {
        let mut solver = ArithSolver::lra();
        let (x, y, reason) = (TermId::new(1), TermId::new(2), TermId::new(100));

        // x >= 0, y >= 0, x + y <= 10
        solver.assert_ge(&unit(x), rat(0), reason);
        solver.assert_ge(&unit(y), rat(0), reason);
        solver.assert_le(
            &[(x, Rational64::one()), (y, Rational64::one())],
            rat(10),
            reason,
        );

        assert!(matches!(solver.check().unwrap(), TheoryResult::Sat));
    }

    #[test]
    fn test_arith_unsat() {
        let mut solver = ArithSolver::lra();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        // x >= 10 and x <= 5 cannot both hold.
        solver.assert_ge(&unit(x), rat(10), reason);
        solver.assert_le(&unit(x), rat(5), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Unsat(_)));
    }

    #[test]
    fn test_arith_strict_inequality() {
        let mut solver = ArithSolver::lra();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        // 0 < x < 10
        solver.assert_gt(&unit(x), rat(0), reason);
        solver.assert_lt(&unit(x), rat(10), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Sat));
    }

    #[test]
    fn test_arith_strict_unsat() {
        let mut solver = ArithSolver::lra();
        let (x, reason) = (TermId::new(1), TermId::new(100));

        // x >= 5 together with x < 5 is contradictory.
        solver.assert_ge(&unit(x), rat(5), reason);
        solver.assert_lt(&unit(x), rat(5), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Unsat(_)));
    }

    #[test]
    fn test_coefficient_normalization_lia() {
        let mut solver = ArithSolver::lia();
        let (x, y, reason) = (TermId::new(1), TermId::new(2), TermId::new(100));

        // 2x + 4y <= 10, whose coefficients share the factor 2.
        solver.assert_le(&[(x, rat(2)), (y, rat(4))], rat(10), reason);

        assert!(matches!(solver.check().unwrap(), TheoryResult::Sat));
    }

    #[test]
    fn test_coefficient_normalization_sign() {
        let solver = ArithSolver::lra();

        // -3·v0 + 2·v1 should come out with a positive leading coefficient.
        let mut expr = LinExpr::new();
        expr.add_term(0, rat(-3));
        expr.add_term(1, rat(2));

        solver.normalize_expr(&mut expr);

        if let Some((_, c)) = expr.terms.first() {
            assert!(*c > Rational64::zero());
        }
    }

    #[test]
    fn test_gcd_computation() {
        let cases = [
            (12, 8, 4),
            (15, 25, 5),
            (7, 13, 1),
            (0, 5, 5),
            (5, 0, 5),
            (-12, 8, 4),
            (12, -8, 4),
        ];
        for (a, b, expected) in cases {
            assert_eq!(gcd_i64(a, b), expected, "gcd_i64({a}, {b})");
        }
    }

    #[test]
    fn test_bound_tightening_lia() {
        let solver = ArithSolver::lia();

        // Upper bound 5.7 floors to 5; lower bound 2.3 ceils to 3.
        assert_eq!(solver.tighten_bound(Rational64::new(57, 10), true), rat(5));
        assert_eq!(solver.tighten_bound(Rational64::new(23, 10), false), rat(3));
        // Integral bounds are left alone.
        assert_eq!(solver.tighten_bound(rat(5), true), rat(5));
    }

    #[test]
    fn test_bound_tightening_lra() {
        let solver = ArithSolver::lra();
        let bound = Rational64::new(57, 10);

        assert_eq!(solver.tighten_bound(bound, true), bound);
    }

    #[test]
    fn test_tighten_constraints() {
        let mut solver_lia = ArithSolver::lia();
        let mut solver_lra = ArithSolver::lra();

        assert!(!solver_lia.tighten_constraints());
        assert!(!solver_lra.tighten_constraints());
    }
}