Skip to main content

logicaffeine_kernel/
omega.rs

//! Omega Test: True Integer Arithmetic Decision Procedure
//!
//! This module implements the Omega test for linear integer arithmetic,
//! handling the discrete nature of integers correctly.
//!
//! # Difference from LIA
//!
//! Unlike [`crate::lia`] (which uses rational arithmetic), this module
//! handles integers with proper semantics:
//!
//! - `x > 1` becomes `x >= 2` (strict to non-strict for integers)
//! - `3x <= 10` implies `x <= 3` (integer division with floor)
//! - `2x = 5` is unsatisfiable (`2x` is always even, while `5` is odd)
//!
//! # Algorithm
//!
//! The algorithm is similar to Fourier-Motzkin elimination but with
//! integer-aware semantics:
//!
//! 1. **Normalize**: Scale constraints and normalize by GCD
//! 2. **Convert strict**: Transform `<` to `<=` using integer shift
//! 3. **Eliminate**: Fourier-Motzkin with integer coefficient handling
//! 4. **Check**: Verify constant constraints for contradictions
//!
//! Only the *real shadow* of the Omega test is computed (there is no dark- or
//! grey-shadow search). A `true` answer from [`omega_unsat`] therefore means
//! the constraints are unsatisfiable, while `false` only means "possibly
//! satisfiable".
//!
//! # When to Use
//!
//! Use omega when you need exact integer semantics. Use lia when
//! rational arithmetic is acceptable (it is faster, but it may miss
//! integer-specific unsatisfiability).

31use std::collections::{BTreeMap, HashSet};
32
33use crate::term::{Literal, Term};
34
/// Integer linear expression of the form c + a₁x₁ + a₂x₂ + ... + aₙxₙ.
///
/// Similar to [`crate::lia::LinearExpr`] but uses integer coefficients
/// instead of rationals for exact integer arithmetic.
///
/// The arithmetic helpers use plain `i64` operators, so overflow panics in
/// debug builds and wraps in release builds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntExpr {
    /// The constant term c.
    pub constant: i64,
    /// Maps variable index to its integer coefficient (sparse representation).
    ///
    /// The helpers below never introduce zero coefficients, but the field is
    /// public, so hand-built values may contain them.
    pub coeffs: BTreeMap<i64, i64>,
}

impl IntExpr {
    /// Creates the constant expression `c`.
    pub fn constant(c: i64) -> Self {
        IntExpr {
            constant: c,
            coeffs: BTreeMap::new(),
        }
    }

    /// Creates the single-variable expression `1·x_idx + 0`.
    pub fn var(idx: i64) -> Self {
        let mut coeffs = BTreeMap::new();
        coeffs.insert(idx, 1);
        IntExpr {
            constant: 0,
            coeffs,
        }
    }

    /// Returns `self + other`, removing coefficients that cancel to zero.
    pub fn add(&self, other: &Self) -> Self {
        let mut result = self.clone();
        result.constant += other.constant;
        for (&v, &c) in &other.coeffs {
            let entry = result.coeffs.entry(v).or_insert(0);
            *entry += c;
            if *entry == 0 {
                result.coeffs.remove(&v);
            }
        }
        result
    }

    /// Returns `-self`.
    pub fn neg(&self) -> Self {
        IntExpr {
            constant: -self.constant,
            coeffs: self.coeffs.iter().map(|(&v, &c)| (v, -c)).collect(),
        }
    }

    /// Returns `self - other`.
    pub fn sub(&self, other: &Self) -> Self {
        self.add(&other.neg())
    }

    /// Returns `k · self`; scaling by zero yields the constant `0`.
    pub fn scale(&self, k: i64) -> Self {
        if k == 0 {
            return IntExpr::constant(0);
        }
        IntExpr {
            constant: self.constant * k,
            coeffs: self
                .coeffs
                .iter()
                .map(|(&v, &c)| (v, c * k))
                .filter(|(_, c)| *c != 0)
                .collect(),
        }
    }

    /// Returns `true` if the expression mentions no variables.
    pub fn is_constant(&self) -> bool {
        self.coeffs.is_empty()
    }

    /// Returns the coefficient of `var`, or `0` if it does not occur.
    pub fn get_coeff(&self, var: i64) -> i64 {
        self.coeffs.get(&var).copied().unwrap_or(0)
    }
}

/// Integer constraint representing `expr <= 0` or `expr < 0`.
///
/// For integers, strict inequalities can be converted to non-strict:
/// `x < k` is equivalent to `x <= k - 1`.
#[derive(Debug, Clone)]
pub struct IntConstraint {
    /// The linear expression (constraint is expr OP 0).
    pub expr: IntExpr,
    /// If true, this is a strict inequality (`< 0`).
    /// If false, this is a non-strict inequality (`<= 0`).
    pub strict: bool,
}

impl IntConstraint {
    /// Evaluates a constant constraint.
    ///
    /// Returns `true` when the expression still mentions variables (the
    /// constraint cannot be decided yet); otherwise returns whether
    /// `constant < 0` (strict) or `constant <= 0` (non-strict) holds.
    pub fn is_satisfied_constant(&self) -> bool {
        if !self.expr.is_constant() {
            return true; // Can't determine yet
        }
        let c = self.expr.constant;
        if self.strict {
            c < 0
        } else {
            c <= 0
        }
    }

    /// Rewrites the constraint into canonical integer form without changing
    /// its set of integer solutions.
    ///
    /// 1. A strict constraint `e < 0` becomes `e + 1 <= 0`. This holds because
    ///    `e` is integer-valued.
    /// 2. Zero coefficients are dropped.
    /// 3. Let `g` be the GCD of the variable coefficients. Then
    ///    `Σ aᵢxᵢ + c <= 0` becomes `Σ (aᵢ/g)xᵢ + ⌈c/g⌉ <= 0`. Rounding the
    ///    constant up is exact over the integers because `Σ aᵢxᵢ` is always a
    ///    multiple of `g`. For example, `3x - 10 <= 0` tightens to `x - 3 <= 0`.
    ///
    /// Constant constraints are not rescaled; only their sign matters.
    pub fn normalize(&mut self) {
        if self.strict {
            // Saturating avoids overflow. At `i64::MAX` the result is merely a
            // weaker (still implied) constraint.
            self.expr.constant = self.expr.constant.saturating_add(1);
            self.strict = false;
        }

        self.expr.coeffs.retain(|_, a| *a != 0);

        // The constant is deliberately excluded from the GCD so that it can
        // be rounded (tightened) rather than only divided exactly.
        let g = self
            .expr
            .coeffs
            .values()
            .fold(0i64, |acc, &a| gcd(acc, a.abs()));

        if g > 1 {
            for a in self.expr.coeffs.values_mut() {
                *a /= g;
            }
            self.expr.constant = ceil_div(self.expr.constant, g);
        }
    }
}

/// Ceiling of `n / d`, for `d > 0`.
fn ceil_div(n: i64, d: i64) -> i64 {
    // `/` truncates toward zero, which is already the ceiling for negative
    // quotients. Only a positive remainder needs rounding up.
    let q = n / d;
    if n % d > 0 {
        q + 1
    } else {
        q
    }
}

/// GCD using the Euclidean algorithm.
///
/// Intended for non-negative inputs. Returns `1` when both inputs are zero,
/// so the result is never zero.
fn gcd(mut a: i64, mut b: i64) -> i64 {
    while b != 0 {
        let r = a % b;
        a = b;
        b = r;
    }
    a.max(1)
}
174
175/// Reify a Syntax term to an integer linear expression.
176///
177/// Converts the deep embedding (Syntax) into an integer linear expression.
178/// Similar to [`crate::lia::reify_linear`] but produces integer coefficients.
179///
180/// # Supported Forms
181///
182/// - `SLit n` - Integer literal becomes a constant
183/// - `SVar i` - De Bruijn variable becomes a linear variable
184/// - `SName "x"` - Named global becomes a linear variable (hashed)
185/// - `add`, `sub`, `mul` - Arithmetic operations (mul only if one operand is constant)
186///
187/// # Returns
188///
189/// `Some(expr)` on success, `None` if the term is non-linear or malformed.
190pub fn reify_int_linear(term: &Term) -> Option<IntExpr> {
191    // SLit n -> constant
192    if let Some(n) = extract_slit(term) {
193        return Some(IntExpr::constant(n));
194    }
195
196    // SVar i -> variable
197    if let Some(i) = extract_svar(term) {
198        return Some(IntExpr::var(i));
199    }
200
201    // SName "x" -> named variable (global constant treated as free variable)
202    if let Some(name) = extract_sname(term) {
203        let hash = name_to_var_index(&name);
204        return Some(IntExpr::var(hash));
205    }
206
207    // Binary operations
208    if let Some((op, a, b)) = extract_binary_app(term) {
209        match op.as_str() {
210            "add" => {
211                let la = reify_int_linear(&a)?;
212                let lb = reify_int_linear(&b)?;
213                return Some(la.add(&lb));
214            }
215            "sub" => {
216                let la = reify_int_linear(&a)?;
217                let lb = reify_int_linear(&b)?;
218                return Some(la.sub(&lb));
219            }
220            "mul" => {
221                let la = reify_int_linear(&a)?;
222                let lb = reify_int_linear(&b)?;
223                // Only linear if one side is constant
224                if la.is_constant() {
225                    return Some(lb.scale(la.constant));
226                }
227                if lb.is_constant() {
228                    return Some(la.scale(lb.constant));
229                }
230                return None; // Non-linear
231            }
232            _ => return None,
233        }
234    }
235
236    None
237}
238
239/// Extract comparison from goal: (SApp (SApp (SName "Lt"|"Le"|"Gt"|"Ge") lhs) rhs)
240pub fn extract_comparison(term: &Term) -> Option<(String, Term, Term)> {
241    if let Some((rel, lhs, rhs)) = extract_binary_app(term) {
242        match rel.as_str() {
243            "Lt" | "Le" | "Gt" | "Ge" | "lt" | "le" | "gt" | "ge" => {
244                return Some((rel, lhs, rhs));
245            }
246            _ => {}
247        }
248    }
249    None
250}
251
252/// Convert a goal to constraints for validity checking using integer semantics.
253///
254/// Key difference from lia: strict inequalities are converted for integers.
255/// - x < k becomes x <= k - 1 (since x must be an integer)
256/// - x > k becomes x >= k + 1
257///
258/// To prove a goal is valid, we check if its negation is unsatisfiable.
259pub fn goal_to_negated_constraint(rel: &str, lhs: &IntExpr, rhs: &IntExpr) -> Option<IntConstraint> {
260    // diff = lhs - rhs
261    let diff = lhs.sub(rhs);
262
263    match rel {
264        // Lt: a < b valid iff NOT(a >= b)
265        // For integers: a >= b means a - b >= 0
266        // We check if a - b >= 0 is satisfiable
267        // Constraint form for unsatisfiability check: -(a - b) <= 0, i.e., (b - a) <= 0
268        "Lt" | "lt" => Some(IntConstraint {
269            expr: rhs.sub(lhs),
270            strict: false,
271        }),
272
273        // Le: a <= b valid iff NOT(a > b)
274        // For integers: a > b means a - b >= 1 (strict to non-strict!)
275        // So negation is: a - b >= 1, i.e., a - b - 1 >= 0
276        // Constraint: -(a - b - 1) <= 0, i.e., (b - a + 1) <= 0
277        // Equivalently: (b - a) <= -1
278        "Le" | "le" => {
279            let mut expr = rhs.sub(lhs);
280            expr.constant += 1; // b - a + 1 <= 0
281            Some(IntConstraint {
282                expr,
283                strict: false,
284            })
285        }
286
287        // Gt: a > b valid iff NOT(a <= b)
288        // For integers: a <= b means a - b <= 0
289        // Constraint: (a - b) <= 0
290        "Gt" | "gt" => Some(IntConstraint {
291            expr: diff,
292            strict: false,
293        }),
294
295        // Ge: a >= b valid iff NOT(a < b)
296        // For integers: a < b means a - b <= -1 (strict to non-strict!)
297        // Constraint: (a - b) <= -1, i.e., (a - b + 1) <= 0
298        "Ge" | "ge" => {
299            let mut expr = diff;
300            expr.constant += 1; // (a - b + 1) <= 0
301            Some(IntConstraint {
302                expr,
303                strict: false,
304            })
305        }
306
307        _ => None,
308    }
309}
310
311/// Check if integer constraints are unsatisfiable using the Omega test.
312///
313/// This is the main entry point for the omega decision procedure. It uses
314/// integer-aware Fourier-Motzkin elimination to check for contradictions.
315///
316/// # Integer Semantics
317///
318/// Unlike rational Fourier-Motzkin, this procedure:
319/// - Normalizes constraints by their GCD
320/// - Handles strict inequalities by integer shift (`< k` becomes `<= k-1`)
321/// - Detects integer-specific unsatisfiability
322///
323/// # Returns
324///
325/// - `true` if no integer assignment satisfies all constraints (unsatisfiable)
326/// - `false` if the constraints may be satisfiable
327///
328/// # Usage for Validity
329///
330/// To prove a goal G is valid over integers, check if NOT(G) is unsatisfiable.
331/// If `omega_unsat(negation_constraints)` returns true, the goal is valid.
332pub fn omega_unsat(constraints: &[IntConstraint]) -> bool {
333    if constraints.is_empty() {
334        return false;
335    }
336
337    // Normalize all constraints
338    let mut current: Vec<IntConstraint> = constraints.to_vec();
339    for c in &mut current {
340        c.normalize();
341    }
342
343    // Check for immediate contradictions
344    for c in &current {
345        if c.expr.is_constant() && !c.is_satisfied_constant() {
346            return true;
347        }
348    }
349
350    // Collect all variables
351    let vars: Vec<i64> = current
352        .iter()
353        .flat_map(|c| c.expr.coeffs.keys().copied())
354        .collect::<HashSet<_>>()
355        .into_iter()
356        .collect();
357
358    // Eliminate each variable
359    for var in vars {
360        current = eliminate_variable_int(&current, var);
361
362        // Early termination: check for constant contradictions
363        for c in &current {
364            if c.expr.is_constant() && !c.is_satisfied_constant() {
365                return true;
366            }
367        }
368    }
369
370    // Check all remaining constant constraints
371    current
372        .iter()
373        .any(|c| c.expr.is_constant() && !c.is_satisfied_constant())
374}
375
376/// Eliminate a variable from constraints using integer-aware Fourier-Motzkin.
377fn eliminate_variable_int(constraints: &[IntConstraint], var: i64) -> Vec<IntConstraint> {
378    let mut lower: Vec<(IntExpr, i64)> = vec![]; // (rest, |coeff|) for lower bounds
379    let mut upper: Vec<(IntExpr, i64)> = vec![]; // (rest, coeff) for upper bounds
380    let mut independent: Vec<IntConstraint> = vec![];
381
382    for c in constraints {
383        let coeff = c.expr.get_coeff(var);
384        if coeff == 0 {
385            independent.push(c.clone());
386        } else {
387            // c.expr = coeff*var + rest <= 0
388            let mut rest = c.expr.clone();
389            rest.coeffs.remove(&var);
390
391            if coeff > 0 {
392                // coeff*var + rest <= 0
393                // var <= -rest/coeff (upper bound)
394                upper.push((rest, coeff));
395            } else {
396                // coeff*var + rest <= 0, coeff < 0
397                // |coeff|*(-var) + rest <= 0
398                // -var <= -rest/|coeff|
399                // var >= rest/|coeff| (lower bound)
400                lower.push((rest, -coeff));
401            }
402        }
403    }
404
405    // Combine lower and upper bounds
406    // If lo/a <= var <= -hi/b, then lo/a <= -hi/b
407    // Multiply out: b*lo <= -a*hi
408    // Rearrange: b*lo + a*hi <= 0
409    for (lo_rest, lo_coeff) in &lower {
410        for (hi_rest, hi_coeff) in &upper {
411            // Lower: var >= lo_rest / lo_coeff (lo_coeff is positive)
412            // Upper: var <= -hi_rest / hi_coeff (hi_coeff is positive)
413            // Combined: lo_rest / lo_coeff <= -hi_rest / hi_coeff
414            // => hi_coeff * lo_rest <= -lo_coeff * hi_rest
415            // => hi_coeff * lo_rest + lo_coeff * hi_rest <= 0
416            let new_expr = lo_rest.scale(*hi_coeff).add(&hi_rest.scale(*lo_coeff));
417
418            let mut new_constraint = IntConstraint {
419                expr: new_expr,
420                strict: false,
421            };
422            new_constraint.normalize();
423            independent.push(new_constraint);
424        }
425    }
426
427    independent
428}
429
// =============================================================================
// Helper functions for extracting Syntax patterns
// =============================================================================
433
434/// Extract integer from SLit n
435fn extract_slit(term: &Term) -> Option<i64> {
436    if let Term::App(ctor, arg) = term {
437        if let Term::Global(name) = ctor.as_ref() {
438            if name == "SLit" {
439                if let Term::Lit(Literal::Int(n)) = arg.as_ref() {
440                    return Some(*n);
441                }
442            }
443        }
444    }
445    None
446}
447
448/// Extract variable index from SVar i
449fn extract_svar(term: &Term) -> Option<i64> {
450    if let Term::App(ctor, arg) = term {
451        if let Term::Global(name) = ctor.as_ref() {
452            if name == "SVar" {
453                if let Term::Lit(Literal::Int(i)) = arg.as_ref() {
454                    return Some(*i);
455                }
456            }
457        }
458    }
459    None
460}
461
462/// Extract name from SName "x"
463fn extract_sname(term: &Term) -> Option<String> {
464    if let Term::App(ctor, arg) = term {
465        if let Term::Global(name) = ctor.as_ref() {
466            if name == "SName" {
467                if let Term::Lit(Literal::Text(s)) = arg.as_ref() {
468                    return Some(s.clone());
469                }
470            }
471        }
472    }
473    None
474}
475
476/// Extract binary application: SApp (SApp (SName "op") a) b
477fn extract_binary_app(term: &Term) -> Option<(String, Term, Term)> {
478    if let Term::App(outer, b) = term {
479        if let Term::App(sapp_outer, inner) = outer.as_ref() {
480            if let Term::Global(ctor) = sapp_outer.as_ref() {
481                if ctor == "SApp" {
482                    if let Term::App(partial, a) = inner.as_ref() {
483                        if let Term::App(sapp_inner, op_term) = partial.as_ref() {
484                            if let Term::Global(ctor2) = sapp_inner.as_ref() {
485                                if ctor2 == "SApp" {
486                                    if let Some(op) = extract_sname(op_term) {
487                                        return Some((
488                                            op,
489                                            a.as_ref().clone(),
490                                            b.as_ref().clone(),
491                                        ));
492                                    }
493                                }
494                            }
495                        }
496                    }
497                }
498            }
499        }
500    }
501    None
502}
503
/// Maps a global name to a negative variable index that no other name
/// shares.
///
/// Indices come from a process-wide interning table. The first distinct
/// name receives `-1_000_000`, the next `-1_000_001`, and so on. Repeated
/// calls with the same name return the same index for the lifetime of the
/// process. The table is never pruned.
fn name_to_var_index(name: &str) -> i64 {
    // A hash-based index is not used because distinct names can collide
    // (e.g. "Aa" and "BB" share a 31-based string hash). A collision would
    // silently merge unrelated variables and let false goals be "proved".
    const NAMED_VAR_BASE: i64 = 1_000_000;
    static INTERNED: std::sync::Mutex<BTreeMap<String, i64>> =
        std::sync::Mutex::new(BTreeMap::new());

    // The table is only ever mutated by `insert`, so a lock poisoned by a
    // panic elsewhere still guards a usable map.
    let mut table = INTERNED
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    if let Some(&index) = table.get(name) {
        return index;
    }

    let ordinal = i64::try_from(table.len()).expect("interned name count fits in i64");
    let index = -(NAMED_VAR_BASE + ordinal);
    table.insert(name.to_owned(), index);
    index
}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the non-strict constraint `expr <= 0`.
    fn at_most_zero(expr: IntExpr) -> IntConstraint {
        IntConstraint {
            expr,
            strict: false,
        }
    }

    #[test]
    fn test_int_expr_add() {
        let sum = IntExpr::var(0).add(&IntExpr::var(1));
        assert!(!sum.is_constant());
        for idx in [0, 1] {
            assert_eq!(sum.get_coeff(idx), 1);
        }
    }

    #[test]
    fn test_int_expr_cancel() {
        let x = IntExpr::var(0);
        let zero = x.add(&x.neg());
        assert!(zero.is_constant());
        assert_eq!(zero.constant, 0);
    }

    #[test]
    fn test_constraint_satisfied() {
        // `c <= 0` holds exactly when the constant is not positive.
        for (c, holds) in [(-1, true), (1, false), (0, true)] {
            assert_eq!(
                at_most_zero(IntExpr::constant(c)).is_satisfied_constant(),
                holds,
                "constant {}",
                c
            );
        }
    }

    #[test]
    fn test_omega_constant() {
        // `1 <= 0` is a contradiction; `-1 <= 0` is not.
        assert!(omega_unsat(&[at_most_zero(IntExpr::constant(1))]));
        assert!(!omega_unsat(&[at_most_zero(IntExpr::constant(-1))]));
    }

    #[test]
    fn test_x_lt_x_plus_1() {
        // `x < x + 1` holds for every integer x. Its negation `x >= x + 1`
        // becomes `(x + 1) - x <= 0`, i.e. `1 <= 0`, which must be unsat.
        let negation = at_most_zero(IntExpr::constant(1));
        assert!(omega_unsat(&[negation]));
    }
}