#[allow(unused_imports)]
use crate::prelude::*;
use crate::*;
/// Verifies structural invariants of the assignment trail.
///
/// Checks that:
/// 1. the reported trail size is not negative;
/// 2. no trail entry's decision level exceeds the solver's current decision level;
/// 3. decision levels are non-decreasing along the trail.
///
/// The trail is read once. A level above the current one anywhere on the trail is
/// reported in preference to an ordering violation, even if the ordering violation
/// occurs earlier on the trail.
///
/// # Errors
/// Returns a description of the violated invariant, or propagates an error from
/// `Solver::get_trail_entry`.
pub fn check_trail_consistency(solver: &Solver) -> Result<(), String> {
    if solver.trail_size() < 0 {
        return Err("Trail size is negative".to_string());
    }
    let current_level = solver.decision_level() as usize;
    // Entry decision levels are `usize`, so they can never be negative; no check needed.
    // The first ordering violation is remembered rather than returned immediately so that
    // an over-deep level later on the trail still takes precedence.
    let mut ordering_violation: Option<String> = None;
    let mut last_level = 0;
    for i in 0..solver.trail_size() as usize {
        let entry = solver.get_trail_entry(i)?;
        if entry.decision_level > current_level {
            return Err(format!(
                "Trail entry {} has decision level {} > current level {}",
                i,
                entry.decision_level,
                solver.decision_level()
            ));
        }
        if ordering_violation.is_none() && entry.decision_level < last_level {
            ordering_violation = Some(format!(
                "Trail entry {} has decision level {} < previous {}",
                i, entry.decision_level, last_level
            ));
        }
        last_level = entry.decision_level;
    }
    ordering_violation.map_or(Ok(()), Err)
}
/// Verifies the solver's current decision level against the trail.
///
/// The level must be non-negative and must not exceed the trail size. When it is above
/// zero, at least one trail entry must carry exactly that level.
///
/// # Errors
/// Returns a description of the violated invariant. Errors from
/// `Solver::get_trail_entry` are propagated rather than treated as a non-matching entry.
pub fn check_decision_level_consistency(solver: &Solver) -> Result<(), String> {
    let level = solver.decision_level();
    if level < 0 {
        return Err(format!("Decision level is negative: {}", level));
    }
    if level > solver.trail_size() {
        return Err(format!(
            "Decision level {} exceeds trail size {}",
            level,
            solver.trail_size()
        ));
    }
    if level > 0 {
        let target = level as usize;
        let mut has_entries_at_level = false;
        for i in 0..solver.trail_size() as usize {
            // `?` here: a failed read is itself an inconsistency and must not be mistaken
            // for "no entry at this level".
            if solver.get_trail_entry(i)?.decision_level == target {
                has_entries_at_level = true;
                break;
            }
        }
        if !has_entries_at_level {
            return Err(format!(
                "Decision level {} has no trail entries",
                level
            ));
        }
    }
    Ok(())
}
/// Verifies well-formedness of every clause in the clause database.
///
/// For each clause, in this order:
/// - an empty clause is permitted only while the solver status is UNSAT;
/// - no variable may occur both negated and non-negated (tautology);
/// - no literal may occur twice;
/// - a learned clause must have a non-zero LBD that does not exceed its length.
///
/// # Errors
/// Returns a description of the first malformed clause found, or propagates an error
/// from `Solver::get_clause`.
pub fn check_clause_database_consistency(solver: &Solver) -> Result<(), String> {
    let total = solver.num_clauses();
    if total < 0 {
        return Err(format!("Number of clauses is negative: {}", total));
    }
    for idx in 0..total as usize {
        let clause = solver.get_clause(idx)?;
        let lits = &clause.literals;

        if lits.is_empty() && !matches!(solver.status(), SolverStatus::Unsat) {
            return Err(format!("Empty clause {} in non-UNSAT state", idx));
        }

        // Complementary pair: same variable, opposite polarity.
        for (pos, &first) in lits.iter().enumerate() {
            for &second in &lits[pos + 1..] {
                if first.var() == second.var() && first.is_negated() != second.is_negated() {
                    return Err(format!(
                        "Clause {} contains tautology: {:?} and {:?}",
                        idx, first, second
                    ));
                }
            }
        }

        // Exact repeat of a literal later in the clause.
        for (pos, &first) in lits.iter().enumerate() {
            if lits[pos + 1..].iter().any(|&later| first == later) {
                return Err(format!(
                    "Clause {} contains duplicate literal: {:?}",
                    idx, first
                ));
            }
        }

        // The remaining checks apply only to learned clauses.
        if !clause.is_learned {
            continue;
        }
        if clause.lbd == 0 {
            return Err(format!("Learned clause {} has zero LBD", idx));
        }
        if clause.lbd as usize > lits.len() {
            return Err(format!(
                "Learned clause {} has LBD {} > clause length {}",
                idx,
                clause.lbd,
                lits.len()
            ));
        }
    }
    Ok(())
}
/// Cross-checks each Boolean variable's assignment against the trail.
///
/// Every variable assigned `True` or `False` must appear on the trail exactly once, and
/// that trail entry's value must match the assignment. Unassigned variables are not checked.
///
/// The trail is scanned once up front and summarised per variable. The check therefore
/// costs O(trail + variables) rather than one full trail scan per assigned variable.
///
/// # Errors
/// For the lowest-numbered offending variable, a value mismatch is reported in preference
/// to a missing or repeated trail entry. Errors from `Solver::get_trail_entry` (raised during
/// the up-front scan) and from `Solver::get_assignment` are propagated.
pub fn check_variable_assignment_consistency(solver: &Solver) -> Result<(), String> {
    // var_id -> (occurrences on trail, seen with value true, seen with value false)
    let mut tally = std::collections::HashMap::new();
    for i in 0..solver.trail_size() as usize {
        let entry = solver.get_trail_entry(i)?;
        let slot = tally
            .entry(entry.var_id)
            .or_insert((0usize, false, false));
        slot.0 += 1;
        if entry.value {
            slot.1 = true;
        } else {
            slot.2 = true;
        }
    }

    let num_vars = solver.num_variables();
    for var_id in 0..num_vars {
        let assignment = solver.get_assignment(var_id)?;
        let expected = match assignment {
            Assignment::Unassigned => continue,
            Assignment::True => true,
            Assignment::False => false,
        };
        let (count, seen_true, seen_false) = tally
            .get(&var_id)
            .copied()
            .unwrap_or((0, false, false));
        // Any trail entry carrying the opposite value is a mismatch; its value is `!expected`.
        let mismatched = if expected { seen_false } else { seen_true };
        if mismatched {
            return Err(format!(
                "Variable {} has assignment {:?} but trail entry has {}",
                var_id, assignment, !expected
            ));
        }
        if count == 0 {
            return Err(format!(
                "Variable {} is assigned {:?} but not on trail",
                var_id, assignment
            ));
        }
        if count > 1 {
            return Err(format!(
                "Variable {} appears {} times on trail",
                var_id, count
            ));
        }
    }
    Ok(())
}
/// Validates the internal state of any attached theory solvers.
///
/// Arithmetic theory (when present):
/// - for every variable, a lower bound must not exceed an upper bound;
/// - any current value must lie within whichever bounds exist;
/// - the tableau is checked when the theory reports that it uses one.
///
/// Equality theory (when present):
/// - applying `find` to a node's representative must return that same representative;
/// - congruence closure is then checked by the theory itself.
///
/// # Errors
/// Returns the first violation found, or propagates errors from the theory accessors.
pub fn check_theory_solver_consistency(solver: &Solver) -> Result<(), String> {
    if let Some(arith) = solver.get_arith_theory() {
        for v in 0..arith.num_vars() {
            let lower = arith.get_lower_bound(v)?;
            let upper = arith.get_upper_bound(v)?;

            if let Some((lb, ub)) = lower.zip(upper).filter(|&(lb, ub)| lb > ub) {
                return Err(format!(
                    "Arithmetic variable {} has lower bound {} > upper bound {}",
                    v, lb, ub
                ));
            }

            if let Some(value) = arith.get_value(v)? {
                if let Some(lb) = lower.filter(|&lb| value < lb) {
                    return Err(format!(
                        "Arithmetic variable {} has value {} < lower bound {}",
                        v, value, lb
                    ));
                }
                if let Some(ub) = upper.filter(|&ub| value > ub) {
                    return Err(format!(
                        "Arithmetic variable {} has value {} > upper bound {}",
                        v, value, ub
                    ));
                }
            }
        }
        if arith.uses_tableau() {
            arith.check_tableau_consistency()?;
        }
    }

    if let Some(eq) = solver.get_equality_theory() {
        for node in 0..eq.num_nodes() {
            let root = eq.find(node)?;
            let root_of_root = eq.find(root)?;
            if root != root_of_root {
                return Err(format!(
                    "Equality node {} has inconsistent representative chain: {} -> {}",
                    node, root, root_of_root
                ));
            }
        }
        eq.check_congruence_closure()?;
    }

    Ok(())
}
/// Checks that the solver's model satisfies every original (non-learned) clause.
///
/// Does nothing unless the solver status is SAT. Each literal is evaluated through the
/// model. A literal counts as satisfied when its evaluated value equals the Boolean
/// constant for its polarity: `false` for a negated literal, `true` otherwise. Any other
/// evaluated value leaves the literal unsatisfied.
///
/// After the clause pass, theory-specific validation is delegated to
/// `Solver::check_theory_model_validity`.
///
/// # Errors
/// Reports the index of the first unsatisfied clause, or propagates errors from clause
/// access, model evaluation, or the theory check.
pub fn check_model_validity(solver: &Solver, tm: &TermManager) -> Result<(), String> {
    if !matches!(solver.status(), SolverStatus::Sat) {
        return Ok(());
    }
    let model = solver.get_model(tm);
    // The Boolean constants do not depend on the literal; build them once instead of
    // once per evaluated literal.
    let true_term = tm.mk_bool(true);
    let false_term = tm.mk_bool(false);
    for i in 0..solver.num_clauses() as usize {
        let clause = solver.get_clause(i)?;
        // Only original clauses are checked against the model.
        if clause.is_learned {
            continue;
        }
        let mut satisfied = false;
        for &lit in &clause.literals {
            let value = model.eval(lit.to_term(tm), tm)?;
            let wanted = if lit.is_negated() { &false_term } else { &true_term };
            if value == *wanted {
                satisfied = true;
                break;
            }
        }
        if !satisfied {
            return Err(format!("Clause {} is not satisfied by model", i));
        }
    }
    solver.check_theory_model_validity(tm)?;
    Ok(())
}
/// Runs every solver invariant check in sequence, stopping at the first violation.
///
/// Order: trail, decision level, clause database, variable assignments, theory solvers,
/// and finally the model. The model check returns early unless the solver status is SAT.
///
/// # Errors
/// Returns the error message of the first check that fails.
pub fn check_all_invariants(solver: &Solver, tm: &TermManager) -> Result<(), String> {
    check_trail_consistency(solver)
        .and_then(|()| check_decision_level_consistency(solver))
        .and_then(|()| check_clause_database_consistency(solver))
        .and_then(|()| check_variable_assignment_consistency(solver))
        .and_then(|()| check_theory_solver_consistency(solver))
        .and_then(|()| check_model_validity(solver, tm))
}
#[cfg(test)]
mod invariant_tests {
    use super::*;

    /// Runs every invariant check and panics with the violated invariant's message.
    /// A failing test therefore reports *which* invariant broke.
    fn assert_invariants(solver: &Solver, tm: &TermManager) {
        if let Err(msg) = check_all_invariants(solver, tm) {
            panic!("solver invariant violated: {}", msg);
        }
    }

    #[test]
    fn test_empty_solver_invariants() {
        let solver = Solver::new();
        let tm = TermManager::new();
        assert_invariants(&solver, &tm);
    }

    #[test]
    fn test_simple_sat_invariants() {
        let mut solver = Solver::new();
        let mut tm = TermManager::new();
        let p = tm.mk_var("p", tm.sorts.bool_sort);
        solver.assert(p, &mut tm);
        let result = solver.check(&mut tm);
        assert!(matches!(result, SolverResult::Sat));
        assert_invariants(&solver, &tm);
    }

    #[test]
    fn test_simple_unsat_invariants() {
        let mut solver = Solver::new();
        let mut tm = TermManager::new();
        let p = tm.mk_var("p", tm.sorts.bool_sort);
        solver.assert(p, &mut tm);
        solver.assert(tm.mk_not(p), &mut tm);
        let result = solver.check(&mut tm);
        assert!(matches!(result, SolverResult::Unsat));
        // The invariants must hold in the UNSAT state too; the clause check explicitly
        // permits an empty clause only when the status is UNSAT.
        assert_invariants(&solver, &tm);
    }

    #[test]
    fn test_backtrack_invariants() {
        let mut solver = Solver::new();
        let mut tm = TermManager::new();
        let p = tm.mk_var("p", tm.sorts.bool_sort);
        solver.push();
        solver.assert(p, &mut tm);
        assert_invariants(&solver, &tm);
        solver.pop();
        assert_invariants(&solver, &tm);
    }
}