use std::collections::{BTreeMap, HashSet};
use crate::term::{Literal, Term};
/// An exact rational number `numerator / denominator`.
///
/// Every constructor and arithmetic method below returns a canonical value: the
/// denominator is positive and the fraction is fully reduced. Because of that, the
/// derived `PartialEq`/`Eq`/`Hash` agree with numeric equality. The fields are public,
/// so code that builds a `Rational` literal by hand is responsible for keeping that form.
///
/// Arithmetic is carried out in `i128` and reduced before narrowing back to `i64`, so
/// intermediate cross-products cannot silently wrap; a result whose *reduced* form does
/// not fit in `i64` panics instead of producing a wrong value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Rational {
    pub numerator: i64,
    pub denominator: i64,
}

impl Ord for Rational {
    /// Orders by numeric value: `a/b < c/d` iff `a*d < c*b` (denominators positive).
    ///
    /// A derived lexicographic order on `(numerator, denominator)` is not numeric
    /// (it ranks `1/2` below `2/5`). Ties in value are broken by denominator so the
    /// order stays consistent with `Eq` even for non-canonical literals.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = i128::from(self.numerator) * i128::from(other.denominator);
        let rhs = i128::from(other.numerator) * i128::from(self.denominator);
        lhs.cmp(&rhs)
            .then_with(|| self.denominator.cmp(&other.denominator))
    }
}

impl PartialOrd for Rational {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Rational {
    /// Builds the canonical form of `n / d`.
    ///
    /// # Panics
    /// Panics if `d == 0`, or if the reduced value does not fit in `i64`
    /// (e.g. `new(i64::MIN, -1)`).
    pub fn new(n: i64, d: i64) -> Self {
        if d == 0 {
            panic!("Rational denominator cannot be zero");
        }
        Rational::reduce(i128::from(n), i128::from(d))
    }

    /// Normalizes `n / d` (sign moved to the numerator, common factors removed) and
    /// narrows it back to `i64`. Callers guarantee `d != 0`.
    fn reduce(n: i128, d: i128) -> Self {
        debug_assert!(d != 0, "Rational denominator cannot be zero");
        let g = gcd(n.abs(), d.abs()).max(1);
        let sign: i128 = if d < 0 { -1 } else { 1 };
        let numerator = sign * n / g;
        let denominator = d.abs() / g;
        let fits = |v: i128| v >= i128::from(i64::MIN) && v <= i128::from(i64::MAX);
        assert!(
            fits(numerator) && fits(denominator),
            "Rational value {}/{} overflows i64",
            numerator,
            denominator
        );
        // Both casts are lossless: the range was checked just above.
        Rational {
            numerator: numerator as i64,
            denominator: denominator as i64,
        }
    }

    /// Returns `(numerator, denominator)` widened to `i128` for overflow-free products.
    fn widen(&self) -> (i128, i128) {
        (i128::from(self.numerator), i128::from(self.denominator))
    }

    /// The rational `0/1`.
    pub fn zero() -> Self {
        Rational {
            numerator: 0,
            denominator: 1,
        }
    }

    /// The integer `n` as `n/1`.
    pub fn from_int(n: i64) -> Self {
        Rational {
            numerator: n,
            denominator: 1,
        }
    }

    /// `self + other`, reduced.
    pub fn add(&self, other: &Rational) -> Rational {
        let (a, b) = self.widen();
        let (c, d) = other.widen();
        Rational::reduce(a * d + c * b, b * d)
    }

    /// `-self`.
    ///
    /// # Panics
    /// Panics if the numerator is `i64::MIN` (its negation is not representable).
    pub fn neg(&self) -> Rational {
        Rational {
            numerator: self
                .numerator
                .checked_neg()
                .expect("Rational negation overflows i64"),
            denominator: self.denominator,
        }
    }

    /// `self - other`, reduced (computed directly so `other == i64::MIN/1` is fine).
    pub fn sub(&self, other: &Rational) -> Rational {
        let (a, b) = self.widen();
        let (c, d) = other.widen();
        Rational::reduce(a * d - c * b, b * d)
    }

    /// `self * other`, reduced.
    pub fn mul(&self, other: &Rational) -> Rational {
        let (a, b) = self.widen();
        let (c, d) = other.widen();
        Rational::reduce(a * c, b * d)
    }

    /// `self / other`, or `None` when `other` is zero.
    pub fn div(&self, other: &Rational) -> Option<Rational> {
        if other.numerator == 0 {
            return None;
        }
        let (a, b) = self.widen();
        let (c, d) = other.widen();
        // `reduce` moves a negative divisor's sign into the numerator.
        Some(Rational::reduce(a * d, b * c))
    }

    /// `true` when the value is strictly below zero.
    pub fn is_negative(&self) -> bool {
        self.numerator < 0
    }

    /// `true` when the value is strictly above zero.
    pub fn is_positive(&self) -> bool {
        self.numerator > 0
    }

    /// `true` when the value is exactly zero.
    pub fn is_zero(&self) -> bool {
        self.numerator == 0
    }
}

/// Greatest common divisor of two non-negative integers (iterative Euclid).
/// `gcd(0, 0) == 0`; `Rational::reduce` clamps the result with `.max(1)`.
fn gcd(a: i128, b: i128) -> i128 {
    let (mut a, mut b) = (a, b);
    while b != 0 {
        let r = a % b;
        a = b;
        b = r;
    }
    a
}
/// An affine expression `constant + Σ coefficients[v] * x_v` over rationals.
///
/// Variables are identified by `i64` keys. The methods below never leave a zero
/// coefficient in the map (`add` and `scale` drop zeros, `var` inserts 1 and `neg`
/// preserves non-zeroness), so `coefficients.is_empty()` means "no variables".
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LinearExpr {
    pub constant: Rational,
    pub coefficients: BTreeMap<i64, Rational>,
}
impl LinearExpr {
    /// The variable-free expression `c`.
    pub fn constant(c: Rational) -> Self {
        Self {
            constant: c,
            coefficients: BTreeMap::new(),
        }
    }

    /// The expression `1 * x_idx`.
    pub fn var(idx: i64) -> Self {
        let mut coefficients = BTreeMap::new();
        coefficients.insert(idx, Rational::from_int(1));
        Self {
            constant: Rational::zero(),
            coefficients,
        }
    }

    /// Term-wise sum; coefficients that cancel to zero are removed.
    pub fn add(&self, other: &LinearExpr) -> LinearExpr {
        let mut sum = self.clone();
        sum.constant = sum.constant.add(&other.constant);
        for (&v, c) in other.coefficients.iter() {
            let combined = sum.get_coeff(v).add(c);
            if combined.is_zero() {
                sum.coefficients.remove(&v);
            } else {
                sum.coefficients.insert(v, combined);
            }
        }
        sum
    }

    /// Negates the constant and every coefficient.
    pub fn neg(&self) -> LinearExpr {
        let constant = self.constant.neg();
        let mut coefficients = BTreeMap::new();
        for (v, c) in &self.coefficients {
            coefficients.insert(*v, c.neg());
        }
        LinearExpr {
            constant,
            coefficients,
        }
    }

    /// `self - other`.
    pub fn sub(&self, other: &LinearExpr) -> LinearExpr {
        let negated = other.neg();
        self.add(&negated)
    }

    /// Multiplies every term by `c`; scaling by zero yields the constant `0`.
    pub fn scale(&self, c: &Rational) -> LinearExpr {
        if c.is_zero() {
            return LinearExpr::constant(Rational::zero());
        }
        let constant = self.constant.mul(c);
        let coefficients = self
            .coefficients
            .iter()
            .filter_map(|(v, k)| {
                let product = k.mul(c);
                if product.is_zero() {
                    None
                } else {
                    Some((*v, product))
                }
            })
            .collect();
        LinearExpr {
            constant,
            coefficients,
        }
    }

    /// `true` when the expression mentions no variables.
    pub fn is_constant(&self) -> bool {
        self.coefficients.is_empty()
    }

    /// Coefficient of `x_var`, or zero when the variable does not occur.
    pub fn get_coeff(&self, var: i64) -> Rational {
        match self.coefficients.get(&var) {
            Some(c) => c.clone(),
            None => Rational::zero(),
        }
    }
}
/// A linear constraint `expr < 0` (when `strict`) or `expr <= 0` (otherwise).
#[derive(Debug, Clone)]
pub struct Constraint {
    pub expr: LinearExpr,
    pub strict: bool,
}
impl Constraint {
    /// Evaluates a variable-free constraint.
    ///
    /// Constraints that still mention variables cannot be decided here and are
    /// reported as satisfied (`true`).
    pub fn is_satisfied_constant(&self) -> bool {
        let value = &self.expr.constant;
        match (self.expr.is_constant(), self.strict) {
            (false, _) => true,
            (true, true) => value.is_negative(),
            (true, false) => !value.is_positive(),
        }
    }
}
/// Reasons a term or goal cannot be handled by the linear-arithmetic procedure.
#[derive(Debug)]
pub enum LiaError {
    /// The term uses an operation outside linear arithmetic (variable products,
    /// `div`/`mod`, or an unknown operator); the string describes which.
    NonLinear(String),
    /// The term matches none of the recognized shapes (literal, variable, name,
    /// binary application).
    MalformedTerm,
    /// NOTE(review): not constructed in this file; presumably signals a goal that is
    /// not a comparison — confirm with callers.
    NotInequality,
}

impl std::fmt::Display for LiaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LiaError::NonLinear(msg) => write!(f, "non-linear term: {}", msg),
            LiaError::MalformedTerm => f.write_str("malformed term"),
            LiaError::NotInequality => f.write_str("goal is not an inequality"),
        }
    }
}

impl std::error::Error for LiaError {}
/// Converts a syntactic arithmetic term into a [`LinearExpr`].
///
/// Recognized shapes, tried in order: `SLit n` (constant), `SVar i` (variable `i`),
/// `SName s` (variable indexed by `name_to_var_index(s)`), and binary `SApp` nodes
/// for `add`, `sub` and `mul` (where at least one side must reduce to a constant).
///
/// # Errors
/// `LiaError::NonLinear` for a product of two non-constant sides, `div`/`mod`, or an
/// unknown operator; `LiaError::MalformedTerm` for anything else.
pub fn reify_linear(term: &Term) -> Result<LinearExpr, LiaError> {
    if let Some(lit) = extract_slit(term) {
        return Ok(LinearExpr::constant(Rational::from_int(lit)));
    }
    if let Some(index) = extract_svar(term) {
        return Ok(LinearExpr::var(index));
    }
    if let Some(name) = extract_sname(term) {
        return Ok(LinearExpr::var(name_to_var_index(&name)));
    }
    let (op, a, b) = match extract_binary_app(term) {
        Some(parts) => parts,
        None => return Err(LiaError::MalformedTerm),
    };
    match op.as_str() {
        // Left operand is reified (and may fail) before the right one.
        "add" => Ok(reify_linear(&a)?.add(&reify_linear(&b)?)),
        "sub" => Ok(reify_linear(&a)?.sub(&reify_linear(&b)?)),
        "mul" => {
            let left = reify_linear(&a)?;
            let right = reify_linear(&b)?;
            match (left.is_constant(), right.is_constant()) {
                (true, _) => Ok(right.scale(&left.constant)),
                (false, true) => Ok(left.scale(&right.constant)),
                (false, false) => Err(LiaError::NonLinear(
                    "multiplication of two variables is not linear".to_string(),
                )),
            }
        }
        "div" | "mod" => Err(LiaError::NonLinear(format!(
            "operation '{}' is not supported in lia",
            op
        ))),
        other => Err(LiaError::NonLinear(format!("unknown operation '{}'", other))),
    }
}
/// Recognizes a binary comparison `SApp (SApp (SName rel) lhs) rhs`.
///
/// Returns `(rel, lhs, rhs)` when `rel` is one of `Lt`/`Le`/`Gt`/`Ge` in either
/// capitalized or lowercase spelling, otherwise `None`.
pub fn extract_comparison(term: &Term) -> Option<(String, Term, Term)> {
    const RELATIONS: [&str; 8] = ["Lt", "Le", "Gt", "Ge", "lt", "le", "gt", "ge"];
    extract_binary_app(term).filter(|(rel, _, _)| RELATIONS.contains(&rel.as_str()))
}
/// Builds the constraint expressing the *negation* of the goal `lhs rel rhs`.
///
/// The result is read as `expr <= 0` (non-strict) or `expr < 0` (strict):
/// * `lhs < rhs`  negates to `rhs - lhs <= 0`
/// * `lhs <= rhs` negates to `rhs - lhs < 0`
/// * `lhs > rhs`  negates to `lhs - rhs <= 0`
/// * `lhs >= rhs` negates to `lhs - rhs < 0`
///
/// Returns `None` for any other relation name.
pub fn goal_to_negated_constraint(
    rel: &str,
    lhs: &LinearExpr,
    rhs: &LinearExpr,
) -> Option<Constraint> {
    let lhs_minus_rhs = lhs.sub(rhs);
    let (expr, strict) = match rel {
        "Lt" | "lt" => (rhs.sub(lhs), false),
        "Le" | "le" => (rhs.sub(lhs), true),
        "Gt" | "gt" => (lhs_minus_rhs, false),
        "Ge" | "ge" => (lhs_minus_rhs, true),
        _ => return None,
    };
    Some(Constraint { expr, strict })
}
/// Decides, by Fourier–Motzkin elimination, whether the conjunction of
/// `constraints` is unsatisfiable over the rationals.
///
/// Returns `true` only when a contradiction is derived. Rational unsatisfiability
/// implies integer unsatisfiability, so `true` is also sound for integer variables;
/// `false` does not guarantee an integer solution exists. An empty slice is
/// satisfiable and returns `false`.
///
/// Variables are eliminated in ascending index order so runs are deterministic.
/// Each elimination step can multiply the number of constraints (lower × upper
/// bound pairs), so this can grow quickly on large inputs.
pub fn fourier_motzkin_unsat(constraints: &[Constraint]) -> bool {
    if constraints.is_empty() {
        return false;
    }
    let mut vars: Vec<i64> = constraints
        .iter()
        .flat_map(|c| c.expr.coefficients.keys().copied())
        .collect::<HashSet<_>>()
        .into_iter()
        .collect();
    // HashSet iteration order is unspecified; sort for reproducible elimination.
    vars.sort_unstable();
    let mut current = constraints.to_vec();
    for var in vars {
        current = eliminate_variable(&current, var);
        // Stop early as soon as a variable-free constraint is violated.
        if current
            .iter()
            .any(|c| c.expr.is_constant() && !c.is_satisfied_constant())
        {
            return true;
        }
    }
    current.iter().any(|c| !c.is_satisfied_constant())
}
/// Performs one Fourier–Motzkin step, removing `var` from `constraints`.
///
/// Each constraint is read as `expr <= 0` (or `expr < 0` when `strict`). Writing
/// `expr = a*var + rest` with `a != 0`, the constraint bounds `var` by `-rest / a`:
/// from above when `a > 0`, from below when `a < 0` (dividing by a negative number
/// flips the inequality). Every (lower, upper) pair produces the var-free constraint
/// `lower - upper <= 0`, strict if either side was strict. Constraints that do not
/// mention `var` are passed through unchanged.
fn eliminate_variable(constraints: &[Constraint], var: i64) -> Vec<Constraint> {
    let mut lower: Vec<(LinearExpr, bool)> = Vec::new();
    let mut upper: Vec<(LinearExpr, bool)> = Vec::new();
    let mut result: Vec<Constraint> = Vec::new();
    for c in constraints {
        let coeff = c.expr.get_coeff(var);
        if coeff.is_zero() {
            result.push(c.clone());
            continue;
        }
        let mut rest = c.expr.clone();
        rest.coefficients.remove(&var);
        let inverse = Rational::from_int(1)
            .div(&coeff)
            .expect("coefficient was checked to be non-zero");
        // The bound must be divided by the coefficient; merely negating `rest`
        // (i.e. scaling by coeff/coeff == 1) gives wrong bounds whenever |a| != 1.
        let bound = rest.neg().scale(&inverse);
        if coeff.is_positive() {
            upper.push((bound, c.strict));
        } else {
            lower.push((bound, c.strict));
        }
    }
    for (lo_expr, lo_strict) in &lower {
        for (hi_expr, hi_strict) in &upper {
            result.push(Constraint {
                expr: lo_expr.sub(hi_expr),
                strict: *lo_strict || *hi_strict,
            });
        }
    }
    result
}
/// Matches `SLit n` (an application of the global `SLit` to an integer literal)
/// and returns `n`.
fn extract_slit(term: &Term) -> Option<i64> {
    match term {
        Term::App(ctor, arg) => match (ctor.as_ref(), arg.as_ref()) {
            (Term::Global(name), Term::Lit(Literal::Int(n))) if name == "SLit" => Some(*n),
            _ => None,
        },
        _ => None,
    }
}
/// Matches `SVar i` (an application of the global `SVar` to an integer literal)
/// and returns the variable index `i`.
fn extract_svar(term: &Term) -> Option<i64> {
    match term {
        Term::App(ctor, arg) => match (ctor.as_ref(), arg.as_ref()) {
            (Term::Global(name), Term::Lit(Literal::Int(i))) if name == "SVar" => Some(*i),
            _ => None,
        },
        _ => None,
    }
}
/// Matches `SName s` (an application of the global `SName` to a text literal)
/// and returns an owned copy of `s`.
fn extract_sname(term: &Term) -> Option<String> {
    match term {
        Term::App(ctor, arg) => match (ctor.as_ref(), arg.as_ref()) {
            (Term::Global(name), Term::Lit(Literal::Text(s))) if name == "SName" => {
                Some(s.clone())
            }
            _ => None,
        },
        _ => None,
    }
}
/// Matches the curried binary application `SApp (SApp (SName op) a) b` and
/// returns `(op, a, b)` with owned copies of the operands.
fn extract_binary_app(term: &Term) -> Option<(String, Term, Term)> {
    // Splits `SApp f x`, encoded as `App(App(Global("SApp"), f), x)`, into `(f, x)`.
    fn split_sapp(t: &Term) -> Option<(&Term, &Term)> {
        match t {
            Term::App(head, x) => match head.as_ref() {
                Term::App(ctor, f) => match ctor.as_ref() {
                    Term::Global(name) if name == "SApp" => Some((f.as_ref(), x.as_ref())),
                    _ => None,
                },
                _ => None,
            },
            _ => None,
        }
    }

    let (fun, b) = split_sapp(term)?;
    let (op_term, a) = split_sapp(fun)?;
    let op = extract_sname(op_term)?;
    Some((op, a.clone(), b.clone()))
}
/// Maps a free name to a synthetic variable index in `[-i64::MAX, -1_000_000]`.
///
/// The index is derived from a 31-multiplier polynomial hash of the name's bytes,
/// keeping named variables disjoint from (presumably non-negative) `SVar` indices.
/// For every hash whose magnitude is at most `i64::MAX - 1_000_000` the result is
/// `-(|hash| + 1_000_000)`; larger magnitudes (including `i64::MIN`, which used to
/// overflow `abs()`) are folded back into the same range instead of panicking or
/// wrapping to a positive index.
///
/// NOTE(review): distinct names can hash to the same index and would then be
/// treated as the same variable.
fn name_to_var_index(name: &str) -> i64 {
    const OFFSET: u64 = 1_000_000;
    // Number of magnitudes in [OFFSET, i64::MAX].
    const SPAN: u64 = i64::MAX as u64 - OFFSET + 1;
    let hash: i64 = name
        .bytes()
        .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(i64::from(b)));
    let magnitude = hash.unsigned_abs() % SPAN + OFFSET;
    // magnitude <= (SPAN - 1) + OFFSET == i64::MAX, so the cast is lossless.
    -(magnitude as i64)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rational_arithmetic() {
        let half = Rational::new(1, 2);
        let third = Rational::new(1, 3);
        let sum = half.add(&third);
        assert_eq!(sum, Rational::new(5, 6));
    }

    #[test]
    fn test_linear_expr_add() {
        let x = LinearExpr::var(0);
        let y = LinearExpr::var(1);
        let sum = x.add(&y);
        assert!(!sum.is_constant());
        assert_eq!(sum.get_coeff(0), Rational::from_int(1));
        assert_eq!(sum.get_coeff(1), Rational::from_int(1));
    }

    #[test]
    fn test_linear_expr_cancel() {
        let x = LinearExpr::var(0);
        let neg_x = x.neg();
        let zero = x.add(&neg_x);
        assert!(zero.is_constant());
        assert!(zero.constant.is_zero());
    }

    #[test]
    fn test_constraint_satisfied() {
        let c1 = Constraint {
            expr: LinearExpr::constant(Rational::from_int(-1)),
            strict: false,
        };
        assert!(c1.is_satisfied_constant());
        let c2 = Constraint {
            expr: LinearExpr::constant(Rational::from_int(1)),
            strict: false,
        };
        assert!(!c2.is_satisfied_constant());
        let c3 = Constraint {
            expr: LinearExpr::constant(Rational::zero()),
            strict: false,
        };
        assert!(c3.is_satisfied_constant());
        let c4 = Constraint {
            expr: LinearExpr::constant(Rational::zero()),
            strict: true,
        };
        assert!(!c4.is_satisfied_constant());
    }

    #[test]
    fn test_fourier_motzkin_constant() {
        let constraints = vec![Constraint {
            expr: LinearExpr::constant(Rational::from_int(1)),
            strict: false,
        }];
        assert!(fourier_motzkin_unsat(&constraints));
        let constraints2 = vec![Constraint {
            expr: LinearExpr::constant(Rational::from_int(-1)),
            strict: false,
        }];
        assert!(!fourier_motzkin_unsat(&constraints2));
    }

    #[test]
    fn test_x_lt_x_plus_1() {
        // Goal `x < x + 1`: its negation `(x + 1) - x <= 0` collapses to `1 <= 0`.
        let x = LinearExpr::var(0);
        let one = LinearExpr::constant(Rational::from_int(1));
        let xp1 = x.add(&one);
        let negated =
            goal_to_negated_constraint("Lt", &x, &xp1).expect("Lt is a supported relation");
        assert!(negated.expr.is_constant());
        assert!(fourier_motzkin_unsat(&[negated]));
    }

    #[test]
    fn test_goal_negation_strictness() {
        let x = LinearExpr::var(0);
        let y = LinearExpr::var(1);
        let lt = goal_to_negated_constraint("Lt", &x, &y).unwrap();
        assert!(!lt.strict);
        assert_eq!(lt.expr, y.sub(&x));
        let ge = goal_to_negated_constraint("ge", &x, &y).unwrap();
        assert!(ge.strict);
        assert_eq!(ge.expr, x.sub(&y));
        assert!(goal_to_negated_constraint("eq", &x, &y).is_none());
    }

    #[test]
    fn test_fourier_motzkin_unit_bounds() {
        let x = LinearExpr::var(0);
        let one = LinearExpr::constant(Rational::from_int(1));
        // `1 - x <= 0`, i.e. x >= 1.
        let at_least_one = Constraint {
            expr: one.sub(&x),
            strict: false,
        };
        // x <= 0 contradicts x >= 1.
        let at_most_zero = Constraint {
            expr: x.clone(),
            strict: false,
        };
        assert!(fourier_motzkin_unsat(&[at_most_zero, at_least_one.clone()]));
        // x <= 1 together with x >= 1 is satisfied by x == 1.
        let at_most_one = Constraint {
            expr: x.sub(&one),
            strict: false,
        };
        assert!(!fourier_motzkin_unsat(&[at_most_one, at_least_one.clone()]));
        // x < 1 excludes x == 1.
        let below_one = Constraint {
            expr: x.sub(&one),
            strict: true,
        };
        assert!(fourier_motzkin_unsat(&[below_one, at_least_one]));
    }
}