logicng 0.1.0-alpha.3

A Library for Creating, Manipulating, and Solving Boolean Formulas
Documentation
use std::collections::HashSet;

use itertools::Itertools;

use crate::datastructures::Assignment;
use crate::formulas::{EncodedFormula, FormulaFactory, Variable};
use crate::solver::functions::{enumerate_models_for_formula_with_config, ModelEnumerationConfig};
use crate::solver::minisat::MiniSat;
use crate::solver::minisat_config::MiniSatConfig;
use crate::solver::minisat_config::SolverCnfMethod::{FullPgOnSolver, PgOnSolver};

/// Builds the five `MiniSat` instances every test in this file iterates over.
///
/// The returned order is: incremental, non-incremental, `PgOnSolver`,
/// `PgOnSolver` with `auxiliary_variables_in_models(false)`, and
/// `FullPgOnSolver` with `auxiliary_variables_in_models(false)`.
fn solvers() -> [MiniSat; 5] {
    let incremental = MiniSat::new_with_config(MiniSatConfig::default().incremental(true));
    let non_incremental = MiniSat::new_with_config(MiniSatConfig::default().incremental(false));
    let pg = MiniSat::new_with_config(MiniSatConfig::default().cnf_method(PgOnSolver));
    let pg_without_aux =
        MiniSat::new_with_config(MiniSatConfig::default().cnf_method(PgOnSolver).auxiliary_variables_in_models(false));
    let full_pg_without_aux =
        MiniSat::new_with_config(MiniSatConfig::default().cnf_method(FullPgOnSolver).auxiliary_variables_in_models(false));

    [incremental, non_incremental, pg, pg_without_aux, full_pg_without_aux]
}

/// For every solver configuration, checks that the formulas read back via
/// `formula_on_solver` have the same models as the formulas that were added:
/// after an initial `add_all`, after a `reset` followed by a larger `add_all`,
/// and after adding one further formula on top of the existing state.
#[test]
fn test_formula_on_solver() {
    let mut all_solvers = solvers();
    for solver in all_solvers.iter_mut() {
        let f = &FormulaFactory::new();

        // Phase 1: four plain clauses/literals.
        let mut formulas: Vec<_> =
            ["A | B | C", "~A | ~B | ~C", "A | ~B", "A"].iter().map(|&text| f.parse(text).unwrap()).collect();
        solver.add_all(&formulas, f);
        let read_back = solver.formula_on_solver(f);
        compare_formulas(&formulas, &read_back, f);

        // Phase 2: start over with one additional clause.
        formulas.push(f.parse("~A | C").unwrap());
        solver.reset();
        solver.add_all(&formulas, f);
        let read_back = solver.formula_on_solver(f);
        compare_formulas(&formulas, &read_back, f);

        // Phase 3: add a cardinality constraint without resetting.
        let cardinality = f.parse("C + D + E <= 2").unwrap();
        formulas.push(cardinality);
        solver.add(cardinality, f);
        let read_back = solver.formula_on_solver(f);
        compare_formulas(&formulas, &read_back, f);
    }
}

/// For every solver configuration, checks what `formula_on_solver` reports when
/// the added formulas contradict each other: first for directly conflicting
/// unit formulas, then for two contradicting equivalences before and after a
/// call to `sat`.
#[test]
fn test_formula_on_solver_with_contradiction() {
    /// Collects a slice of formulas into a set so comparisons ignore order.
    fn to_set(formulas: &[EncodedFormula]) -> HashSet<EncodedFormula> {
        formulas.iter().copied().collect()
    }

    /// Asserts both sets have the same size and every expected formula is present.
    fn assert_same_formulas(on_solver: &HashSet<EncodedFormula>, expected: &HashSet<EncodedFormula>) {
        assert_eq!(on_solver.len(), expected.len());
        for fm in expected {
            assert!(on_solver.contains(fm));
        }
    }

    let mut all_solvers = solvers();
    for solver in all_solvers.iter_mut() {
        let f = &FormulaFactory::new();
        let parse_all = |texts: &[&str]| -> HashSet<EncodedFormula> { texts.iter().map(|&text| f.parse(text).unwrap()).collect() };

        // A and B are fixed, then `~A | ~B` is required: expect the three units plus falsum, in order.
        solver.add(f.parse("A").unwrap(), f);
        solver.add(f.parse("B").unwrap(), f);
        solver.add(f.parse("C & (~A | ~B)").unwrap(), f);
        let reported = solver.formula_on_solver(f);
        let expected_units = [f.variable("A"), f.variable("B"), f.variable("C"), f.falsum()];
        assert_eq!(reported.as_ref(), &expected_units);

        // Two contradicting equivalences: before solving, only their four clauses are expected.
        solver.reset();
        solver.add(f.parse("A <=> B").unwrap(), f);
        solver.add(f.parse("B <=> ~A").unwrap(), f);
        let on_solver = to_set(&solver.formula_on_solver(f));
        let expected = parse_all(&["A | B", "~A | B", "A | ~B", "~A | ~B"]);
        assert_same_formulas(&on_solver, &expected);

        // After a `sat` call the read-back is expected to additionally contain A, B and falsum.
        solver.sat();
        let on_solver = to_set(&solver.formula_on_solver(f));
        let mut expected = parse_all(&["A | B", "~A | B", "A | ~B", "~A | ~B", "A", "B"]);
        expected.insert(f.falsum());
        assert_same_formulas(&on_solver, &expected);
    }
}

/// Asserts that the conjunction of `original` and the conjunction of
/// `from_solver` have the same models.
///
/// Models are enumerated only over the variables occurring in `original`, and
/// the check is: both model sets have the same size and every model of
/// `original` is also a model of `from_solver`.
fn compare_formulas(original: &[EncodedFormula], from_solver: &[EncodedFormula], f: &FormulaFactory) {
    // Distinct variables of the original formulas, in first-occurrence order.
    let vars: Box<[Variable]> = original
        .iter()
        .flat_map(|fm| (*fm.variables(f)).clone())
        .unique()
        .collect();

    // Enumerates the models of the conjunction of `formulas` over `vars`.
    let models_over_vars = |formulas: &[EncodedFormula]| -> HashSet<Assignment> {
        enumerate_models_for_formula_with_config(f.and(formulas), f, &ModelEnumerationConfig::default().variables(vars.clone()))
            .iter()
            .map(Assignment::from)
            .collect()
    };

    let models1 = models_over_vars(original);
    let models2 = models_over_vars(from_solver);

    assert_eq!(models1.len(), models2.len());
    for m in &models1 {
        assert!(models2.contains(m));
    }
}