#![allow(clippy::cast_possible_wrap)]
use std::collections::BTreeMap;
use std::fmt::Display;
use std::sync::Arc;
use CType::{GE, GT, LE, LT};
use crate::cardinality_constraints::CcEncoder;
use crate::datastructures::Assignment;
use crate::formulas::CType::EQ;
use crate::formulas::{EncodedFormula, FormulaFactory, Literal, Variable};
use crate::pseudo_booleans::PbEncoder;
use crate::solver::minisat::sat::Tristate;
use crate::solver::minisat::sat::Tristate::{False, True, Undef};
use crate::util::exceptions::panic_unexpected_formula_type;
use super::FormulaType;
/// Comparison operator of a cardinality or pseudo-Boolean constraint, relating the
/// constraint's left-hand side (its (weighted) sum of literals) to its right-hand side.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy)]
pub enum CType {
/// `lhs = rhs`, displayed as `=`.
EQ,
/// `lhs > rhs`, displayed as `>`.
GT,
/// `lhs >= rhs`, displayed as `>=`.
GE,
/// `lhs < rhs`, displayed as `<`.
LT,
/// `lhs <= rhs`, displayed as `<=`.
LE,
}
/// A cardinality constraint `v1 + v2 + ... + vn <comparator> rhs`, where the left-hand side
/// counts how many of the variables' positive literals are satisfied.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct CardinalityConstraint {
/// The variables whose positive literals are counted on the left-hand side.
pub variables: Box<[Variable]>,
/// How the number of satisfied variables is compared with `rhs`.
pub comparator: CType,
/// The (unsigned) right-hand side bound.
pub rhs: u64,
}
/// A pseudo-Boolean constraint `c1*l1 + c2*l2 + ... + cn*ln <comparator> rhs`, where each
/// literal contributes its coefficient to the left-hand side if it is satisfied.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct PbConstraint {
/// The literals of the left-hand side; `literals[i]` is weighted by `coefficients[i]`.
pub literals: Box<[Literal]>,
/// The integer weights, parallel to `literals`; they may be negative or zero.
pub coefficients: Box<[i64]>,
/// How the weighted sum is compared with `rhs`.
pub comparator: CType,
/// The right-hand side bound.
pub rhs: i64,
}
impl Display for CType {
    /// Writes the mathematical symbol of the comparator (e.g. `<=`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match *self {
            LE => "<=",
            LT => "<",
            GE => ">=",
            GT => ">",
            EQ => "=",
        };
        f.write_str(symbol)
    }
}
impl CardinalityConstraint {
    /// Creates a new cardinality constraint `variables <comparator> rhs`.
    pub(crate) fn new(variables: Box<[Variable]>, comparator: CType, rhs: u64) -> Self {
        Self { variables, comparator, rhs }
    }

    /// Returns `true` if this is an at-most-one constraint, i.e. `... <= 1` or the equivalent `... < 2`.
    pub fn is_amo(&self) -> bool {
        (self.comparator == LE && self.rhs == 1) || (self.comparator == LT && self.rhs == 2)
    }

    /// Returns `true` if this is an exactly-one constraint, i.e. `... = 1`.
    pub fn is_exo(&self) -> bool {
        self.comparator == EQ && self.rhs == 1
    }

    /// Returns the encoding of this constraint produced by a `CcEncoder` configured with
    /// `f.config.cc_config`.
    ///
    /// An encoding already present in `f.caches.cc_encoding` is returned directly; a freshly
    /// computed encoding is stored there only if `f.config.caches.cc_encoding` is enabled.
    ///
    /// # Panics
    ///
    /// Panics if this constraint is not registered in `f`.
    pub fn encode(&self, f: &FormulaFactory) -> Arc<[EncodedFormula]> {
        let index = f.ccs.lookup(self).expect("Cardinality Constraint must be present in FF");
        if let Some(cached) = f.caches.cc_encoding.get(&index) {
            return cached.clone();
        }
        let result: Arc<[_]> = Arc::from(CcEncoder::new(f.config.cc_config.clone()).encode(self, f));
        if f.config.caches.cc_encoding {
            f.caches.cc_encoding.insert(index, result.clone());
        }
        result
    }

    /// Returns whether this constraint holds under `assignment`, where the left-hand side is the
    /// number of variables whose positive literal evaluates to `true`.
    pub fn evaluate(&self, assignment: &Assignment) -> bool {
        let lhs = self.variables.iter().filter(|v| assignment.evaluate_lit(v.pos_lit())).count();
        evaluate_comparator(lhs as i64, self.comparator, self.rhs as i64)
    }

    /// Restricts this constraint by `assignment`.
    ///
    /// Variables assigned `true` are counted against `rhs`, variables assigned `false` are dropped,
    /// and unassigned variables form a new cardinality constraint with the reduced right-hand side.
    /// If the number of variables already assigned `true` decides the constraint (it exceeds `rhs`,
    /// or equals `rhs` for `<`), `verum`/`falsum` is returned directly.
    pub fn restrict(&self, assignment: &Assignment, f: &FormulaFactory) -> EncodedFormula {
        let mut satisfied = 0;
        let mut remaining = vec![];
        for &var in &*self.variables {
            let restricted = assignment.restrict_lit(var.pos_lit());
            match restricted.formula_type() {
                FormulaType::Lit(_) => remaining.push(var),
                FormulaType::True => satisfied += 1,
                FormulaType::False => {}
                // Report the offending restriction result, not the (well-formed) input variable.
                _ => panic_unexpected_formula_type(restricted, Some(f)),
            }
        }
        if satisfied > self.rhs {
            match self.comparator {
                EQ | LT | LE => f.falsum(),
                GT | GE => f.verum(),
            }
        } else if satisfied == self.rhs && self.comparator == LT {
            f.falsum()
        } else {
            // `satisfied <= self.rhs` here, so the subtraction cannot underflow.
            f.cc(self.comparator, self.rhs - satisfied, remaining)
        }
    }

    /// Returns the negation of this constraint, expressed with unit-coefficient pseudo-Boolean
    /// constraints over the same variables. `=` is negated into the disjunction of `<` and `>`.
    pub fn negate(&self, f: &FormulaFactory) -> EncodedFormula {
        match self.comparator {
            EQ => {
                let lt = self.unit_pbc(LT, f);
                let gt = self.unit_pbc(GT, f);
                f.or(&[lt, gt])
            }
            GT => self.unit_pbc(LE, f),
            GE => self.unit_pbc(LT, f),
            LT => self.unit_pbc(GE, f),
            LE => self.unit_pbc(GT, f),
        }
    }

    /// Builds `v1 + ... + vn <comparator> rhs` as a pseudo-Boolean constraint with all
    /// coefficients equal to 1.
    fn unit_pbc(&self, comparator: CType, f: &FormulaFactory) -> EncodedFormula {
        f.pbc(
            comparator,
            self.rhs as i64,
            self.variables.iter().map(Variable::pos_lit).collect::<Box<[_]>>(),
            vec![1; self.variables.len()],
        )
    }

    /// Renders this constraint as `a + b + c <comparator> rhs`.
    pub fn to_string(&self, f: &FormulaFactory) -> String {
        let lhs = self.variables.iter().map(|v| v.pos_lit().to_string(f)).collect::<Vec<String>>().join(" + ");
        format!("{lhs} {} {}", self.comparator, self.rhs)
    }
}
impl PbConstraint {
pub(crate) fn new(literals: Box<[Literal]>, coefficients: Box<[i64]>, comparator: CType, rhs: i64) -> Self {
Self { literals, coefficients, comparator, rhs }
}
pub fn max_weight(&self) -> i64 {
*self.coefficients.iter().max().expect("A pseudo-boolean constraint without literals should never be created.")
}
pub fn encode(&self, f: &FormulaFactory) -> Arc<[EncodedFormula]> {
let index = f.pbcs.lookup(self).expect("Pseudo-Boolean Constraint must be present in FF");
if let Some(cached) = f.caches.pbc_encoding.get(&index) {
return cached.clone();
}
let result: Arc<[_]> = PbEncoder::new(f.config.pb_config.clone()).encode(self, f);
if f.config.caches.pbc_encoding {
f.caches.pbc_encoding.insert(index, result.clone());
}
result
}
pub fn evaluate(&self, assignment: &Assignment) -> bool {
evaluate_comparator(self.evaluate_lhs(assignment), self.comparator, self.rhs)
}
pub fn restrict(&self, assignment: &Assignment, f: &FormulaFactory) -> EncodedFormula {
let mut new_lits = Vec::with_capacity(self.literals.len());
let mut new_coeffs = Vec::with_capacity(self.coefficients.len());
let mut lhs_fixed = 0;
let mut min_value = 0;
let mut max_value = 0;
for i in 0..self.literals.len() {
let lit = self.literals[i];
match assignment.restrict_lit(lit).formula_type() {
FormulaType::Lit(_) => {
new_lits.push(lit);
let coeff = self.coefficients[i];
new_coeffs.push(coeff);
if coeff > 0 {
max_value += coeff;
} else {
min_value += coeff;
};
}
FormulaType::True => lhs_fixed += self.coefficients[i],
FormulaType::False => {}
_ => unreachable!("The function `restrict_lit` can only produce `True`, `False`, or `Lit`."),
}
}
if new_lits.is_empty() {
return f.constant(evaluate_comparator(lhs_fixed, self.comparator, self.rhs));
}
let new_rhs = self.rhs - lhs_fixed;
if self.comparator != EQ {
let fixed = evaluate_coeffs(min_value, max_value, new_rhs, self.comparator);
if fixed == True {
return f.verum();
} else if fixed == False {
return f.falsum();
}
}
f.pbc(self.comparator, new_rhs, new_lits, new_coeffs)
}
pub fn negate(&self, f: &FormulaFactory) -> EncodedFormula {
match self.comparator {
EQ => {
let lt = f.pbc(LT, self.rhs, self.literals.clone(), self.coefficients.clone());
let gt = f.pbc(GT, self.rhs, self.literals.clone(), self.coefficients.clone());
f.or(&[lt, gt])
}
GT => f.pbc(LE, self.rhs, self.literals.clone(), self.coefficients.clone()),
GE => f.pbc(LT, self.rhs, self.literals.clone(), self.coefficients.clone()),
LT => f.pbc(GE, self.rhs, self.literals.clone(), self.coefficients.clone()),
LE => f.pbc(GT, self.rhs, self.literals.clone(), self.coefficients.clone()),
}
}
pub fn normalize(&self, f: &FormulaFactory) -> EncodedFormula {
let mut norm_ps = Vec::with_capacity(self.literals.len());
let mut norm_cs = Vec::with_capacity(self.literals.len());
let mut norm_rhs;
match self.comparator {
EQ => {
for i in 0..self.literals.len() {
norm_ps.push(self.literals[i]);
norm_cs.push(self.coefficients[i]);
}
norm_rhs = self.rhs;
let f1 = normalize_le(&mut norm_ps, &mut norm_cs, norm_rhs, f);
norm_ps.clear();
norm_cs.clear();
for i in 0..self.literals.len() {
norm_ps.push(self.literals[i]);
norm_cs.push(-self.coefficients[i]);
}
norm_rhs = -self.rhs;
let f2 = normalize_le(&mut norm_ps, &mut norm_cs, norm_rhs, f);
f.and(&[f1, f2])
}
LT | LE => {
for i in 0..self.literals.len() {
norm_ps.push(self.literals[i]);
norm_cs.push(self.coefficients[i]);
}
norm_rhs = if self.comparator == LE { self.rhs } else { self.rhs - 1 };
normalize_le(&mut norm_ps, &mut norm_cs, norm_rhs, f)
}
GT | GE => {
for i in 0..self.literals.len() {
norm_ps.push(self.literals[i]);
norm_cs.push(-self.coefficients[i]);
}
norm_rhs = if self.comparator == GE { -self.rhs } else { -self.rhs - 1 };
normalize_le(&mut norm_ps, &mut norm_cs, norm_rhs, f)
}
}
}
pub fn to_string(&self, f: &FormulaFactory) -> String {
let lhs = self
.literals
.iter()
.zip(self.coefficients.iter())
.map(|pair| if *pair.1 == 1 { pair.0.to_string(f) } else { format!("{}*{}", pair.1, pair.0.to_string(f)) })
.collect::<Vec<String>>()
.join(" + ");
format!("{lhs} {} {}", self.comparator, self.rhs)
}
fn evaluate_lhs(&self, assignment: &Assignment) -> i64 {
let mut lhs = 0;
for i in 0..self.literals.len() {
let lit = &self.literals[i];
if assignment.evaluate_lit(*lit) {
lhs += self.coefficients[i];
}
}
lhs
}
}
/// Evaluates `lhs <comparator> rhs` for concrete integer values.
pub const fn evaluate_comparator(lhs: i64, comparator: CType, rhs: i64) -> bool {
    // Operands are written right-hand side first; each arm is the same relation as `lhs <op> rhs`.
    match comparator {
        EQ => rhs == lhs,
        GT => rhs < lhs,
        GE => rhs <= lhs,
        LT => rhs > lhs,
        LE => rhs >= lhs,
    }
}
/// Decides `lhs <comparator> rhs` when `lhs` is only known to lie in `[min_value, max_value]`.
///
/// Returns `True` if every value of the interval satisfies the comparison, `False` if no value
/// does, and `Undef` otherwise. Expects `min_value <= max_value`.
///
/// Unlike a coarse "position of `rhs` relative to the bounds" classification, this also decides
/// the degenerate interval `min_value == max_value` (e.g. when all free coefficients are 0).
const fn evaluate_coeffs(min_value: i64, max_value: i64, rhs: i64, comparator: CType) -> Tristate {
    // `always`: the bound that is hardest to satisfy already satisfies the comparison.
    // `never`: the bound that is easiest to satisfy already violates it.
    let (always, never) = match comparator {
        EQ => (min_value == rhs && max_value == rhs, rhs < min_value || rhs > max_value),
        LE => (max_value <= rhs, min_value > rhs),
        LT => (max_value < rhs, min_value >= rhs),
        GE => (min_value >= rhs, max_value < rhs),
        GT => (min_value > rhs, max_value <= rhs),
    };
    if always {
        True
    } else if never {
        False
    } else {
        Undef
    }
}
/// Normalizes the constraint `sum(cs[i] * ps[i]) <= rhs` and builds the resulting formula.
///
/// All occurrences of a variable are merged into a single literal with a positive weight, and
/// the constant parts are moved to the right-hand side. Returns `falsum` if the right-hand side
/// becomes negative and `verum` if the sum of all weights does not exceed it. Otherwise weights
/// and right-hand side are divided by their greatest common divisor and a `<=` constraint is built.
///
/// `ps` and `cs` are used as scratch buffers and are overwritten with the merged literals/weights.
fn normalize_le(ps: &mut Vec<Literal>, cs: &mut Vec<i64>, rhs: i64, f: &FormulaFactory) -> EncodedFormula {
    // Per variable: (total weight of its negative occurrences, total weight of its positive ones),
    // visited in the map's key order later on.
    let mut weights: BTreeMap<_, (i64, i64)> = BTreeMap::new();
    for (&lit, &coeff) in ps.iter().zip(cs.iter()) {
        if coeff == 0 {
            continue;
        }
        let (negative, positive) = weights.entry(lit.variable()).or_insert((0, 0));
        if lit.phase() {
            *positive += coeff;
        } else {
            *negative += coeff;
        }
    }

    // neg*~x + pos*x == neg + (pos - neg)*x == pos + (neg - pos)*~x:
    // keep the literal with the non-negative weight and move the constant to the right-hand side.
    let mut bound = rhs;
    let mut total = 0;
    ps.clear();
    cs.clear();
    for (variable, (negative, positive)) in weights {
        let (weight, lit) = if negative < positive {
            bound -= negative;
            (positive - negative, variable.pos_lit())
        } else {
            bound -= positive;
            (negative - positive, variable.neg_lit())
        };
        if weight != 0 {
            ps.push(lit);
            cs.push(weight);
            total += weight;
        }
    }

    if bound < 0 {
        return f.falsum();
    }
    if total <= bound {
        return f.verum();
    }
    // Weights are positive and `bound` is non-negative here, so dividing once by the gcd of all of
    // them leaves coprime values; a repeated reduction would only find a gcd of 1.
    let divisor = cs.iter().fold(bound, |acc, &weight| gcd(acc, weight));
    if divisor != 0 && divisor != 1 {
        for weight in cs.iter_mut() {
            *weight /= divisor;
        }
        bound /= divisor;
    }
    f.pbc(LE, bound, ps.clone(), cs.clone())
}
/// Computes the greatest common divisor with the Euclidean algorithm.
///
/// If either argument is zero the other one is returned, so `gcd(0, 0) == 0`.
fn gcd(small: i64, big: i64) -> i64 {
    let mut divisor = small;
    let mut dividend = big;
    while divisor != 0 {
        let remainder = dividend % divisor;
        dividend = divisor;
        divisor = remainder;
    }
    dividend
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use crate::formulas::pbc_cc::evaluate_coeffs;
    use crate::formulas::CType::{EQ, GE, GT, LE, LT};
    use crate::formulas::{EncodedFormula, FormulaFactory, ToFormula};
    use crate::solver::minisat::sat::Tristate::{False, True, Undef};

    /// Interprets `formula` as a pseudo-Boolean constraint and returns its normalization.
    fn normalized(formula: &EncodedFormula, f: &FormulaFactory) -> EncodedFormula {
        formula.as_pbc(f).unwrap().normalize(f)
    }

    #[test]
    fn test_normalization() {
        let f = &FormulaFactory::new();
        let lits: Box<[_]> = vec![f.lit("a", true), f.lit("b", false), f.lit("c", true), f.lit("d", true), f.lit("b", false)].into();
        let coeffs: Box<[_]> = vec![2, -3, 3, 0, 1].into();
        let cases = [
            (EQ, 2, "(2*a + 2*b + 3*c <= 4) & (2*~a + 2*~b + 3*~c <= 3)"),
            (GE, 1, "2*~a + 2*~b + 3*~c <= 4"),
            (GT, 0, "2*~a + 2*~b + 3*~c <= 4"),
            (LE, 1, "2*a + 2*b + 3*c <= 3"),
            (LT, 2, "2*a + 2*b + 3*c <= 3"),
        ];
        // Build every constraint first, then compare each normalization with its expectation.
        let constraints: Vec<_> =
            cases.iter().map(|&(comparator, rhs, _)| f.pbc(comparator, rhs, lits.clone(), coeffs.clone())).collect();
        for (constraint, &(_, _, expected)) in constraints.into_iter().zip(cases.iter()) {
            assert_eq!(expected.to_formula(f), normalized(&constraint, f));
        }
    }

    #[test]
    fn test_normalization_trivial() {
        let f = &FormulaFactory::new();
        let lits: Box<[_]> = vec![f.lit("a", true), f.lit("b", false), f.lit("c", true), f.lit("d", true)].into();
        let coeffs: Box<[_]> = vec![2, -2, 3, 0].into();
        let constraints: Vec<_> =
            [4, 5, 7, 10, -3].iter().map(|&rhs| f.pbc(LE, rhs, lits.clone(), coeffs.clone())).collect();
        for (position, constraint) in constraints.into_iter().enumerate() {
            let expected = match position {
                0 => "2*a + 2*b + 3*c <= 6".to_formula(f),
                4 => f.falsum(),
                _ => f.verum(),
            };
            assert_eq!(expected, normalized(&constraint, f));
        }
    }

    #[test]
    fn test_normalization_simplifications() {
        let f = &FormulaFactory::new();
        let lits: Box<[_]> = vec![f.lit("a", true), f.lit("a", true), f.lit("c", true), f.lit("d", true)].into();
        let coeffs: Box<[_]> = vec![2, -2, 4, 4].into();
        let pb1 = f.pbc(LE, 4, lits, coeffs);
        assert_eq!("c + d <= 1".to_formula(f), normalized(&pb1, f));
        assert!(normalized(&pb1, f).is_cc());

        let lits2: Box<[_]> = vec![f.lit("a", true), f.lit("a", false), f.lit("c", true), f.lit("d", true)].into();
        let coeffs2: Box<[_]> = vec![2, 2, 4, 2].into();
        let pb2 = f.pbc(LE, 4, lits2, coeffs2);
        assert_eq!("2*c + d <= 1".to_formula(f), normalized(&pb2, f));
    }

    #[test]
    fn test_evaluate_coeffs() {
        // Right-hand sides probed against the value range [-2, 2]:
        // below, above, at the lower bound, at the upper bound, and strictly inside.
        let rhs_values = [-3, 3, -2, 2, 0];
        let expectations = [
            (EQ, [False, False, Undef, Undef, Undef]),
            (GE, [True, False, True, Undef, Undef]),
            (GT, [True, False, Undef, False, Undef]),
            (LE, [False, True, Undef, True, Undef]),
            (LT, [False, True, False, Undef, Undef]),
        ];
        for (comparator, expected) in expectations {
            for (rhs, want) in rhs_values.iter().zip(expected.iter()) {
                assert_eq!(&evaluate_coeffs(-2, 2, *rhs, comparator), want, "comparator {comparator:?}, rhs {rhs}");
            }
        }
    }
}