oxiz_theories/arithmetic/
solver.rs

1//! Arithmetic Theory Solver
2
3use super::simplex::{LinExpr, Simplex, VarId};
4use crate::theory::{EqualityNotification, Theory, TheoryCombination, TheoryId, TheoryResult};
5use num_rational::Rational64;
6use num_traits::{One, Signed};
7use oxiz_core::ast::TermId;
8use oxiz_core::error::Result;
9use rustc_hash::FxHashMap;
10
/// Greatest common divisor of two `i64` values (Euclid's algorithm).
///
/// Signs are ignored, so the result is never negative; `gcd_i64(0, 0)` is `0`
/// and `gcd_i64(0, n)` is `|n|`.
///
/// Like `i64::abs`, this overflows (panicking in debug builds) if either
/// argument is `i64::MIN`.
fn gcd_i64(a: i64, b: i64) -> i64 {
    let (mut x, mut y) = (a.abs(), b.abs());
    while y != 0 {
        (x, y) = (y, x % y);
    }
    x
}
22
/// Arithmetic Theory Solver (LRA/LIA)
///
/// Linear constraints over interned terms are handed to a [`Simplex`]
/// instance; each constraint carries a reason ID that maps back to the term
/// that justified it, so conflicts can be reported in terms of `TermId`s.
#[derive(Debug)]
pub struct ArithSolver {
    /// Simplex instance that receives every asserted linear constraint
    simplex: Simplex,
    /// Term to simplex-variable mapping (filled by `intern`)
    term_to_var: FxHashMap<TermId, VarId>,
    /// Interned terms in interning order; its length is snapshotted by `push`
    var_to_term: Vec<TermId>,
    /// Next reason ID to hand out; kept equal to `reasons.len()`
    reason_counter: u32,
    /// Reason ID to justifying term (the vector index is the reason ID)
    reasons: Vec<TermId>,
    /// `true` for LIA (integers), `false` for LRA (reals)
    is_integer: bool,
    /// One saved state per `push`, restored by `pop`
    context_stack: Vec<ContextState>,
}
41
/// State for push/pop: the sizes of the solver's growable vectors at the
/// time of a `push`, used by `pop` to truncate them back.
#[derive(Debug, Clone)]
struct ContextState {
    /// Length of `ArithSolver::var_to_term` at push time
    num_vars: usize,
    /// Length of `ArithSolver::reasons` at push time
    num_reasons: usize,
}
48
49impl Default for ArithSolver {
50    fn default() -> Self {
51        Self::new(false)
52    }
53}
54
55impl ArithSolver {
56    /// Create a new arithmetic solver
57    #[must_use]
58    pub fn new(is_integer: bool) -> Self {
59        Self {
60            simplex: Simplex::new(),
61            term_to_var: FxHashMap::default(),
62            var_to_term: Vec::new(),
63            reason_counter: 0,
64            reasons: Vec::new(),
65            is_integer,
66            context_stack: Vec::new(),
67        }
68    }
69
70    /// Create a new LRA solver
71    #[must_use]
72    pub fn lra() -> Self {
73        Self::new(false)
74    }
75
76    /// Create a new LIA solver
77    #[must_use]
78    pub fn lia() -> Self {
79        Self::new(true)
80    }
81
82    /// Intern a term as a variable
83    pub fn intern(&mut self, term: TermId) -> VarId {
84        if let Some(&var) = self.term_to_var.get(&term) {
85            return var;
86        }
87
88        let var = self.simplex.new_var();
89        self.term_to_var.insert(term, var);
90        self.var_to_term.push(term);
91        var
92    }
93
94    /// Add a reason and return its ID
95    fn add_reason(&mut self, term: TermId) -> u32 {
96        let id = self.reason_counter;
97        self.reason_counter += 1;
98        self.reasons.push(term);
99        id
100    }
101
102    /// Normalize a linear expression
103    ///
104    /// Normalization performs:
105    /// 1. Coefficient reduction: divide by GCD of all coefficients
106    /// 2. Sign normalization: ensure first coefficient is positive
107    /// 3. Sorting: order terms by variable ID for canonical form
108    fn normalize_expr(&self, expr: &mut LinExpr) {
109        if expr.terms.is_empty() {
110            return;
111        }
112
113        // For integer arithmetic, reduce by GCD
114        if self.is_integer {
115            // Find GCD of all coefficients
116            let gcd = expr
117                .terms
118                .iter()
119                .map(|(_, c)| c.numer().abs())
120                .fold(0i64, |acc, n| if acc == 0 { n } else { gcd_i64(acc, n) });
121
122            if gcd > 1 {
123                let divisor = Rational64::from_integer(gcd);
124                expr.scale(Rational64::one() / divisor);
125            }
126        }
127
128        // Ensure first coefficient is positive
129        if let Some((_, c)) = expr.terms.first()
130            && c.is_negative()
131        {
132            expr.negate();
133        }
134
135        // Sort terms by variable ID for canonical form
136        expr.terms.sort_by_key(|(v, _)| *v);
137    }
138
139    /// Assert: lhs <= rhs
140    pub fn assert_le(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
141        let mut expr = LinExpr::new();
142
143        for (term, coef) in lhs {
144            let var = self.intern(*term);
145            expr.add_term(var, *coef);
146        }
147        expr.add_constant(-rhs);
148
149        // Normalize the expression
150        self.normalize_expr(&mut expr);
151
152        let reason_id = self.add_reason(reason);
153        self.simplex.add_le(expr, reason_id);
154    }
155
156    /// Assert: lhs >= rhs
157    pub fn assert_ge(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
158        let mut expr = LinExpr::new();
159
160        for (term, coef) in lhs {
161            let var = self.intern(*term);
162            expr.add_term(var, *coef);
163        }
164        expr.add_constant(-rhs);
165
166        // Normalize the expression
167        self.normalize_expr(&mut expr);
168
169        let reason_id = self.add_reason(reason);
170        self.simplex.add_ge(expr, reason_id);
171    }
172
173    /// Assert: lhs = rhs
174    pub fn assert_eq(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
175        let mut expr = LinExpr::new();
176
177        for (term, coef) in lhs {
178            let var = self.intern(*term);
179            expr.add_term(var, *coef);
180        }
181        expr.add_constant(-rhs);
182
183        // Normalize the expression
184        self.normalize_expr(&mut expr);
185
186        let reason_id = self.add_reason(reason);
187        self.simplex.add_eq(expr, reason_id);
188    }
189
190    /// Assert: lhs < rhs (strict inequality)
191    /// For LRA, uses infinitesimals: lhs <= rhs - δ
192    pub fn assert_lt(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
193        // lhs < rhs is equivalent to lhs - rhs < 0
194        let mut expr = LinExpr::new();
195
196        for (term, coef) in lhs {
197            let var = self.intern(*term);
198            expr.add_term(var, *coef);
199        }
200        expr.add_constant(-rhs);
201
202        // Note: We do NOT normalize here because normalize_expr may negate
203        // the expression to make the first coefficient positive, which would
204        // flip the inequality direction for strict inequalities.
205
206        let reason_id = self.add_reason(reason);
207        self.simplex.add_strict_lt(expr, reason_id);
208    }
209
210    /// Assert: lhs > rhs (strict inequality)
211    /// For LRA, uses infinitesimals: lhs >= rhs + δ
212    pub fn assert_gt(&mut self, lhs: &[(TermId, Rational64)], rhs: Rational64, reason: TermId) {
213        // lhs > rhs is equivalent to rhs - lhs < 0
214        // We build rhs - lhs directly instead of negating lhs - rhs
215        // This avoids issues with normalize_expr which ensures positive first coefficient
216        let mut expr = LinExpr::new();
217
218        for (term, coef) in lhs {
219            let var = self.intern(*term);
220            // Add negative coefficient since we want rhs - lhs
221            expr.add_term(var, -(*coef));
222        }
223        // Add +rhs (since we want rhs - lhs, not lhs - rhs)
224        expr.add_constant(rhs);
225
226        // Note: We do NOT normalize here because:
227        // 1. normalize_expr may negate to make first coefficient positive
228        // 2. This would flip the inequality direction
229        // 3. For strict inequalities, the sign matters
230
231        let reason_id = self.add_reason(reason);
232        self.simplex.add_strict_lt(expr, reason_id);
233    }
234
235    /// Get the current value of a variable
236    ///
237    /// For integer arithmetic (LIA), this properly rounds values that have
238    /// infinitesimal components from strict inequalities:
239    /// - If value is `r + δ` (positive delta), return `ceil(r)` for integers
240    /// - If value is `r - δ` (negative delta), return `floor(r)` for integers
241    #[must_use]
242    pub fn value(&self, term: TermId) -> Option<Rational64> {
243        self.term_to_var.get(&term).map(|&var| {
244            if self.is_integer {
245                // Get the full delta-rational value
246                let dval = self.simplex.delta_value(var);
247
248                // For integer arithmetic, round based on delta:
249                // - Positive delta means we have a strict lower bound (x > r)
250                //   so round up to the next integer
251                // - Negative delta means we have a strict upper bound (x < r)
252                //   so round down to the previous integer
253                // - Zero delta means exact value, round to nearest integer
254                if dval.delta.is_positive() {
255                    // x > r implies x >= ceil(r) for integers
256                    // If r is already an integer, we need r + 1
257                    let real_val = dval.real;
258                    if real_val.is_integer() {
259                        Rational64::from_integer(real_val.to_integer() + 1)
260                    } else {
261                        Rational64::from_integer(real_val.ceil().to_integer())
262                    }
263                } else if dval.delta.is_negative() {
264                    // x < r implies x <= floor(r) for integers
265                    // If r is already an integer, we need r - 1
266                    let real_val = dval.real;
267                    if real_val.is_integer() {
268                        Rational64::from_integer(real_val.to_integer() - 1)
269                    } else {
270                        Rational64::from_integer(real_val.floor().to_integer())
271                    }
272                } else {
273                    // No strict bound, just return the value
274                    // Round to nearest integer for consistency
275                    dval.real
276                }
277            } else {
278                // For reals, just return the real part
279                self.simplex.value(var)
280            }
281        })
282    }
283
284    /// Tighten a rational bound for integer variables
285    ///
286    /// For integer variables:
287    /// - x <= 5.7 becomes x <= 5
288    /// - x >= 2.3 becomes x >= 3
289    /// - x < 5.0 becomes x <= 4
290    /// - x > 2.0 becomes x >= 3
291    #[allow(dead_code)]
292    fn tighten_bound(&self, bound: Rational64, is_upper: bool) -> Rational64 {
293        if !self.is_integer {
294            return bound;
295        }
296
297        // For upper bounds (<=), floor the value
298        // For lower bounds (>=), ceiling the value
299        if bound.is_integer() {
300            bound
301        } else if is_upper {
302            // x <= 5.7 becomes x <= 5
303            Rational64::from_integer(bound.floor().to_integer())
304        } else {
305            // x >= 2.3 becomes x >= 3
306            Rational64::from_integer(bound.ceil().to_integer())
307        }
308    }
309
310    /// Tighten constraints for integer arithmetic
311    ///
312    /// Returns true if any tightening was performed
313    pub fn tighten_constraints(&mut self) -> bool {
314        if !self.is_integer {
315            return false;
316        }
317
318        // In a full implementation, we would:
319        // 1. Iterate through all bounds
320        // 2. Apply tightening rules
321        // 3. Propagate tightened bounds
322        //
323        // For now, tightening is applied during assertion
324        false
325    }
326}
327
328impl Theory for ArithSolver {
329    fn id(&self) -> TheoryId {
330        if self.is_integer {
331            TheoryId::LIA
332        } else {
333            TheoryId::LRA
334        }
335    }
336
337    fn name(&self) -> &str {
338        if self.is_integer { "LIA" } else { "LRA" }
339    }
340
341    fn can_handle(&self, _term: TermId) -> bool {
342        // In a full implementation, check if term is arithmetic
343        true
344    }
345
346    fn assert_true(&mut self, term: TermId) -> Result<TheoryResult> {
347        // In a full implementation, parse the term and add constraints
348        let _ = self.intern(term);
349        Ok(TheoryResult::Sat)
350    }
351
352    fn assert_false(&mut self, term: TermId) -> Result<TheoryResult> {
353        let _ = self.intern(term);
354        Ok(TheoryResult::Sat)
355    }
356
357    fn check(&mut self) -> Result<TheoryResult> {
358        match self.simplex.check() {
359            Ok(()) => Ok(TheoryResult::Sat),
360            Err(reasons) => {
361                let terms: Vec<_> = reasons
362                    .iter()
363                    .filter_map(|&r| self.reasons.get(r as usize).copied())
364                    .collect();
365                Ok(TheoryResult::Unsat(terms))
366            }
367        }
368    }
369
370    fn push(&mut self) {
371        self.context_stack.push(ContextState {
372            num_vars: self.var_to_term.len(),
373            num_reasons: self.reasons.len(),
374        });
375        self.simplex.push();
376    }
377
378    fn pop(&mut self) {
379        if let Some(state) = self.context_stack.pop() {
380            self.var_to_term.truncate(state.num_vars);
381            self.reasons.truncate(state.num_reasons);
382            self.reason_counter = state.num_reasons as u32;
383            self.simplex.pop();
384        }
385    }
386
387    fn reset(&mut self) {
388        self.simplex.reset();
389        self.term_to_var.clear();
390        self.var_to_term.clear();
391        self.reason_counter = 0;
392        self.reasons.clear();
393        self.context_stack.clear();
394    }
395
396    fn get_model(&self) -> Vec<(TermId, TermId)> {
397        // Return variable -> value pairs
398        // In a full implementation, we'd create value terms
399        Vec::new()
400    }
401}
402
403impl TheoryCombination for ArithSolver {
404    fn notify_equality(&mut self, eq: EqualityNotification) -> bool {
405        // Check if both terms are relevant to arithmetic
406        let lhs_var = self.term_to_var.get(&eq.lhs).copied();
407        let rhs_var = self.term_to_var.get(&eq.rhs).copied();
408
409        if let (Some(_lhs), Some(_rhs)) = (lhs_var, rhs_var) {
410            // For an equality constraint lhs = rhs, we need to ensure both
411            // lhs - rhs = 0 and rhs - lhs = 0 (which is the same constraint)
412            // In the simplex implementation, we can model this by creating
413            // a slack variable s and asserting:
414            // lhs = rhs (by setting bounds on the difference)
415
416            // For now, this is a simplified implementation that doesn't fully
417            // enforce the equality in the simplex tableau. A complete implementation
418            // would need to extend the simplex solver to support equality constraints
419            // or introduce a slack variable to model the equality.
420
421            // As a placeholder, we just record that the notification was received
422            let _reason = if let Some(r) = eq.reason {
423                self.add_reason(r)
424            } else {
425                self.add_reason(eq.lhs)
426            };
427
428            true
429        } else {
430            // Terms not relevant to this arithmetic solver
431            false
432        }
433    }
434
435    fn get_shared_equalities(&self) -> Vec<EqualityNotification> {
436        // In a full implementation, we would track which equalities were derived
437        // and return those that should be shared with other theories.
438        // For now, return an empty vector as a placeholder.
439        Vec::new()
440    }
441
442    fn is_relevant(&self, term: TermId) -> bool {
443        // Check if this term has been interned in the arithmetic solver
444        self.term_to_var.contains_key(&term)
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{One, Zero};

    #[test]
    fn test_arith_basic() {
        let mut solver = ArithSolver::lra();

        let x = TermId::new(1);
        let y = TermId::new(2);
        let reason = TermId::new(100);

        // x >= 0
        solver.assert_ge(
            &[(x, Rational64::one())],
            Rational64::from_integer(0),
            reason,
        );

        // y >= 0
        solver.assert_ge(
            &[(y, Rational64::one())],
            Rational64::from_integer(0),
            reason,
        );

        // x + y <= 10
        solver.assert_le(
            &[(x, Rational64::one()), (y, Rational64::one())],
            Rational64::from_integer(10),
            reason,
        );

        let result = solver.check().unwrap();
        assert!(matches!(result, TheoryResult::Sat));
    }

    #[test]
    fn test_arith_unsat() {
        let mut solver = ArithSolver::lra();

        let x = TermId::new(1);
        let reason = TermId::new(100);

        // x >= 10
        solver.assert_ge(
            &[(x, Rational64::one())],
            Rational64::from_integer(10),
            reason,
        );

        // x <= 5
        solver.assert_le(
            &[(x, Rational64::one())],
            Rational64::from_integer(5),
            reason,
        );

        let result = solver.check().unwrap();
        assert!(matches!(result, TheoryResult::Unsat(_)));
    }

    #[test]
    fn test_arith_eq_conflict() {
        let mut solver = ArithSolver::lra();

        let x = TermId::new(1);
        let reason = TermId::new(100);

        // x = 3
        solver.assert_eq(
            &[(x, Rational64::one())],
            Rational64::from_integer(3),
            reason,
        );

        // x >= 5 contradicts x = 3
        solver.assert_ge(
            &[(x, Rational64::one())],
            Rational64::from_integer(5),
            reason,
        );

        let result = solver.check().unwrap();
        assert!(matches!(result, TheoryResult::Unsat(_)));
    }

    #[test]
    fn test_arith_strict_inequality() {
        let mut solver = ArithSolver::lra();

        let x = TermId::new(1);
        let reason = TermId::new(100);

        // x > 0 (strict)
        solver.assert_gt(
            &[(x, Rational64::one())],
            Rational64::from_integer(0),
            reason,
        );

        // x < 10 (strict)
        solver.assert_lt(
            &[(x, Rational64::one())],
            Rational64::from_integer(10),
            reason,
        );

        let result = solver.check().unwrap();
        assert!(matches!(result, TheoryResult::Sat));
    }

    #[test]
    fn test_arith_strict_unsat() {
        let mut solver = ArithSolver::lra();

        let x = TermId::new(1);
        let reason = TermId::new(100);

        // x >= 5
        solver.assert_ge(
            &[(x, Rational64::one())],
            Rational64::from_integer(5),
            reason,
        );

        // x < 5 (strict) - should be unsatisfiable with x >= 5
        solver.assert_lt(
            &[(x, Rational64::one())],
            Rational64::from_integer(5),
            reason,
        );

        let result = solver.check().unwrap();
        assert!(matches!(result, TheoryResult::Unsat(_)));
    }

    #[test]
    fn test_coefficient_normalization_lia() {
        let mut solver = ArithSolver::lia();

        let x = TermId::new(1);
        let y = TermId::new(2);
        let reason = TermId::new(100);

        // 2x + 4y <= 10 should be normalized to x + 2y <= 5 (GCD = 2)
        solver.assert_le(
            &[
                (x, Rational64::from_integer(2)),
                (y, Rational64::from_integer(4)),
            ],
            Rational64::from_integer(10),
            reason,
        );

        // The solver should handle this correctly
        let result = solver.check().unwrap();
        assert!(matches!(result, TheoryResult::Sat));
    }

    #[test]
    fn test_coefficient_normalization_sign() {
        let solver = ArithSolver::lra();

        // -3*v0 + 2*v1: the first coefficient is negative
        let mut expr = LinExpr::new();
        expr.add_term(0, Rational64::from_integer(-3));
        expr.add_term(1, Rational64::from_integer(2));

        solver.normalize_expr(&mut expr);

        // After normalization, the first coefficient must be positive
        let (_, first) = expr
            .terms
            .first()
            .expect("normalization must not drop terms");
        assert!(*first > Rational64::zero());
    }

    #[test]
    fn test_gcd_computation() {
        assert_eq!(gcd_i64(12, 8), 4);
        assert_eq!(gcd_i64(15, 25), 5);
        assert_eq!(gcd_i64(7, 13), 1);
        assert_eq!(gcd_i64(0, 5), 5);
        assert_eq!(gcd_i64(5, 0), 5);
        assert_eq!(gcd_i64(-12, 8), 4);
        assert_eq!(gcd_i64(12, -8), 4);
    }

    #[test]
    fn test_bound_tightening_lia() {
        let solver = ArithSolver::lia();

        // Upper bound tightening: x <= 5.7 -> x <= 5
        let tightened = solver.tighten_bound(Rational64::new(57, 10), true);
        assert_eq!(tightened, Rational64::from_integer(5));

        // Lower bound tightening: x >= 2.3 -> x >= 3
        let tightened = solver.tighten_bound(Rational64::new(23, 10), false);
        assert_eq!(tightened, Rational64::from_integer(3));

        // Integer bounds don't change
        let tightened = solver.tighten_bound(Rational64::from_integer(5), true);
        assert_eq!(tightened, Rational64::from_integer(5));
    }

    #[test]
    fn test_bound_tightening_lra() {
        let solver = ArithSolver::lra();

        // No tightening for real arithmetic
        let bound = Rational64::new(57, 10);
        let tightened = solver.tighten_bound(bound, true);
        assert_eq!(tightened, bound);
    }

    #[test]
    fn test_tighten_constraints() {
        let mut solver_lia = ArithSolver::lia();
        let mut solver_lra = ArithSolver::lra();

        // Constraint tightening is not implemented yet, so nothing changes
        assert!(!solver_lia.tighten_constraints());
        assert!(!solver_lra.tighten_constraints());
    }
}