use crate::clause::{ClauseDatabase, ClauseId};
use crate::literal::{Lit, Var};
#[allow(unused_imports)]
use crate::prelude::*;
use crate::proof::DratProof;
/// Counters describing the work performed by a `DratInprocessor`.
///
/// All counters start at zero (`Default`) and are plain `usize` values, so two
/// snapshots can be compared directly with `==`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DratInprocessingStats {
    /// Clauses removed by any technique (subsumption, blocked clause
    /// elimination, variable elimination).
    pub clauses_eliminated: usize,
    /// Variables removed by bounded variable elimination.
    pub variables_eliminated: usize,
    /// Clauses removed because another clause subsumed them.
    pub subsumptions: usize,
    /// Clauses removed because they were blocked on one of their literals.
    pub blocked_clauses_eliminated: usize,
    /// Additions and deletions written to the DRAT proof.
    pub proof_steps: usize,
    /// Number of inprocessing rounds started.
    pub rounds: usize,
}

impl DratInprocessingStats {
    /// Renders the counters as a multi-line, human-readable report.
    ///
    /// Kept for backward compatibility; it is equivalent to `self.to_string()`
    /// via the [`std::fmt::Display`] implementation.
    pub fn display(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for DratInprocessingStats {
    /// Writes one header line followed by one `- Label: value` line per
    /// counter, with no trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "DRAT Inprocessing Stats:\n\
             - Rounds: {}\n\
             - Clauses eliminated: {}\n\
             - Variables eliminated: {}\n\
             - Subsumptions: {}\n\
             - Blocked clauses: {}\n\
             - Proof steps: {}",
            self.rounds,
            self.clauses_eliminated,
            self.variables_eliminated,
            self.subsumptions,
            self.blocked_clauses_eliminated,
            self.proof_steps,
        )
    }
}
/// Tuning knobs for a `DratInprocessor`.
///
/// The three `enable_*` flags switch individual techniques on or off; the two
/// size limits only constrain bounded variable elimination.
#[derive(Clone, Debug)]
pub struct DratInprocessingConfig {
    /// Run subsumption elimination.
    pub enable_subsumption: bool,
    /// Run bounded variable elimination.
    pub enable_variable_elimination: bool,
    /// Run blocked clause elimination.
    pub enable_blocked_clause_elimination: bool,
    /// A variable is not eliminated if any active clause mentioning it has
    /// more literals than this.
    pub max_clause_size_for_elimination: usize,
    /// A variable `x` is not eliminated if (clauses with `x`) times
    /// (clauses with `¬x`) exceeds this bound.
    pub max_resolution_size: usize,
}

impl DratInprocessingConfig {
    /// Default value of `max_clause_size_for_elimination`.
    const DEFAULT_MAX_CLAUSE_SIZE: usize = 10;
    /// Default value of `max_resolution_size`.
    const DEFAULT_MAX_RESOLUTION_SIZE: usize = 100;
}

impl Default for DratInprocessingConfig {
    /// Every technique enabled, clause-size limit 10, resolution limit 100.
    fn default() -> Self {
        let all_enabled = true;
        Self {
            enable_subsumption: all_enabled,
            enable_variable_elimination: all_enabled,
            enable_blocked_clause_elimination: all_enabled,
            max_clause_size_for_elimination: Self::DEFAULT_MAX_CLAUSE_SIZE,
            max_resolution_size: Self::DEFAULT_MAX_RESOLUTION_SIZE,
        }
    }
}
/// A clause snapshot paired with the id it has in the [`ClauseDatabase`].
type ClauseEntry = (ClauseId, Vec<Lit>);

/// Inprocessing engine that simplifies a clause database and records every
/// step in a DRAT proof.
///
/// Three techniques are available, each toggled by [`DratInprocessingConfig`]:
/// subsumption, blocked clause elimination and bounded variable elimination.
///
/// The processor only reads the [`ClauseDatabase`] (through `iter_ids` and
/// `get`). It never removes clauses from it. A removed clause is written to
/// the proof as a deletion and remembered internally, so that every later
/// pass and every later round ignores it. Callers that want the database to
/// reflect these removals can query [`DratInprocessor::is_clause_eliminated`].
///
/// NOTE(review): resolvents produced by variable elimination are written to
/// the proof but are not inserted into the database, so later passes never
/// see them — confirm that callers add them (or expect this).
pub struct DratInprocessor {
    /// Which techniques run and their size limits.
    config: DratInprocessingConfig,
    /// Counters accumulated since construction or the last reset.
    stats: DratInprocessingStats,
    /// Clauses this processor has already deleted from the proof; all passes
    /// skip them.
    eliminated_clauses: HashSet<ClauseId>,
    /// Variables already removed by variable elimination.
    eliminated_vars: HashSet<Var>,
}

impl DratInprocessor {
    /// Variables with more literal occurrences than this (summed over both
    /// polarities and all active clauses) are not elimination candidates.
    const MAX_CANDIDATE_OCCURRENCES: usize = 5;

    /// Creates an inprocessor with the given configuration and no history.
    pub fn new(config: DratInprocessingConfig) -> Self {
        Self {
            config,
            stats: DratInprocessingStats::default(),
            eliminated_clauses: HashSet::new(),
            eliminated_vars: HashSet::new(),
        }
    }

    /// Creates an inprocessor using [`DratInprocessingConfig::default`].
    pub fn default_config() -> Self {
        Self::new(DratInprocessingConfig::default())
    }

    /// Runs one inprocessing round.
    ///
    /// Enabled techniques run in a fixed order: subsumption, then blocked
    /// clause elimination, then variable elimination.
    ///
    /// # Returns
    /// The number of clauses removed during this round, summed over all
    /// techniques.
    ///
    /// # Errors
    /// Returns the first I/O error reported while writing to `proof`. The
    /// round stops at that point, and the statistics and eliminated-clause set
    /// may be partially updated.
    pub fn inprocess(
        &mut self,
        db: &mut ClauseDatabase,
        proof: &mut DratProof,
    ) -> std::io::Result<usize> {
        self.stats.rounds += 1;
        let mut simplifications = 0;
        if self.config.enable_subsumption {
            simplifications += self.eliminate_subsumed(db, proof)?;
        }
        if self.config.enable_blocked_clause_elimination {
            simplifications += self.eliminate_blocked_clauses(db, proof)?;
        }
        if self.config.enable_variable_elimination {
            simplifications += self.eliminate_variables(db, proof)?;
        }
        Ok(simplifications)
    }

    /// Deletes active clauses that are subsumed by another active clause.
    ///
    /// Every unordered pair of active clauses is compared once, in
    /// `db.iter_ids()` order. If the earlier clause subsumes the later one,
    /// the later one is removed and scanning continues. Otherwise, if the
    /// later clause subsumes the earlier one, the earlier one is removed and
    /// the scan moves on to the next earlier clause. Of two identical clauses,
    /// the first is kept.
    ///
    /// Returns the number of clauses removed.
    fn eliminate_subsumed(
        &mut self,
        db: &mut ClauseDatabase,
        proof: &mut DratProof,
    ) -> std::io::Result<usize> {
        let mut eliminated = 0;
        let clause_ids: Vec<ClauseId> = db.iter_ids().collect();
        for (i, &cid1) in clause_ids.iter().enumerate() {
            if self.eliminated_clauses.contains(&cid1) {
                continue;
            }
            let clause1 = match db.get(cid1) {
                Some(c) => c.lits.to_vec(),
                None => continue,
            };
            for &cid2 in &clause_ids[i + 1..] {
                if self.eliminated_clauses.contains(&cid2) {
                    continue;
                }
                let clause2 = match db.get(cid2) {
                    Some(c) => c.lits.to_vec(),
                    None => continue,
                };
                if Self::subsumes(&clause1, &clause2) {
                    proof.delete_clause(&clause2)?;
                    self.stats.proof_steps += 1;
                    self.eliminated_clauses.insert(cid2);
                    self.stats.subsumptions += 1;
                    eliminated += 1;
                } else if Self::subsumes(&clause2, &clause1) {
                    proof.delete_clause(&clause1)?;
                    self.stats.proof_steps += 1;
                    self.eliminated_clauses.insert(cid1);
                    self.stats.subsumptions += 1;
                    eliminated += 1;
                    // `clause1` is gone; nothing further can be subsumed by it.
                    break;
                }
            }
        }
        self.stats.clauses_eliminated += eliminated;
        Ok(eliminated)
    }

    /// Returns `true` if every literal of `clause1` also occurs in `clause2`.
    ///
    /// The length pre-check assumes `clause1` has no duplicate literals.
    fn subsumes(clause1: &[Lit], clause2: &[Lit]) -> bool {
        clause1.len() <= clause2.len() && clause1.iter().all(|lit| clause2.contains(lit))
    }

    /// Deletes every active clause that is blocked on at least one of its
    /// literals.
    ///
    /// Returns the number of clauses removed.
    fn eliminate_blocked_clauses(
        &mut self,
        db: &mut ClauseDatabase,
        proof: &mut DratProof,
    ) -> std::io::Result<usize> {
        let mut eliminated = 0;
        let clause_ids: Vec<ClauseId> = db.iter_ids().collect();
        for cid in clause_ids {
            if self.eliminated_clauses.contains(&cid) {
                continue;
            }
            let clause = match db.get(cid) {
                Some(c) => c.lits.to_vec(),
                None => continue,
            };
            let blocked = clause
                .iter()
                .any(|&lit| self.is_blocked(db, &clause, lit));
            if blocked {
                proof.delete_clause(&clause)?;
                self.stats.proof_steps += 1;
                self.eliminated_clauses.insert(cid);
                self.stats.blocked_clauses_eliminated += 1;
                eliminated += 1;
            }
        }
        self.stats.clauses_eliminated += eliminated;
        Ok(eliminated)
    }

    /// Returns `true` if `clause` is blocked on `lit`.
    ///
    /// That means every resolvent on `lit.var()` of `clause` with an active
    /// clause containing `¬lit` is a tautology. The result is vacuously `true`
    /// when no active clause contains `¬lit`. Every clause id in `db` is
    /// visited.
    fn is_blocked(&self, db: &ClauseDatabase, clause: &[Lit], lit: Lit) -> bool {
        let neg_lit = lit.negate();
        for other_cid in db.iter_ids() {
            if self.eliminated_clauses.contains(&other_cid) {
                continue;
            }
            let other_clause = match db.get(other_cid) {
                Some(c) => &c.lits,
                None => continue,
            };
            if !other_clause.contains(&neg_lit) {
                continue;
            }
            let resolvent = Self::resolve(clause, other_clause, lit.var());
            if !Self::is_tautology(&resolvent) {
                // One non-tautological resolvent is enough to refute blocking.
                return false;
            }
        }
        true
    }

    /// Resolves `clause1` with `clause2` on `var`.
    ///
    /// The result is all literals of `clause1` except those of `var`,
    /// followed by the literals of `clause2` except those of `var` and except
    /// literals already present. Duplicates inside `clause1` are kept. The
    /// function does not check that the clauses actually clash on `var`.
    fn resolve(clause1: &[Lit], clause2: &[Lit], var: Var) -> Vec<Lit> {
        let mut resolvent: Vec<Lit> = clause1
            .iter()
            .copied()
            .filter(|lit| lit.var() != var)
            .collect();
        for &lit in clause2 {
            if lit.var() != var && !resolvent.contains(&lit) {
                resolvent.push(lit);
            }
        }
        resolvent
    }

    /// Returns `true` if `clause` contains some literal together with its
    /// negation. Quadratic in the clause length.
    fn is_tautology(clause: &[Lit]) -> bool {
        clause.iter().any(|lit| clause.contains(&lit.negate()))
    }

    /// Applies bounded variable elimination to every candidate variable that
    /// passes [`Self::can_eliminate_variable`], in ascending index order.
    ///
    /// Returns the number of clauses removed.
    fn eliminate_variables(
        &mut self,
        db: &mut ClauseDatabase,
        proof: &mut DratProof,
    ) -> std::io::Result<usize> {
        let mut eliminated = 0;
        for var in self.find_elimination_candidates(db) {
            // Re-checked per variable: earlier eliminations in this loop
            // change which clauses are still active.
            if self.can_eliminate_variable(db, var) {
                eliminated += self.eliminate_variable(db, proof, var)?;
                self.eliminated_vars.insert(var);
                self.stats.variables_eliminated += 1;
            }
        }
        Ok(eliminated)
    }

    /// Returns the variables with at most [`Self::MAX_CANDIDATE_OCCURRENCES`]
    /// literal occurrences in active clauses.
    ///
    /// The result is sorted by index (`Var.0`), so the elimination order does
    /// not depend on hash-map iteration order.
    fn find_elimination_candidates(&self, db: &ClauseDatabase) -> Vec<Var> {
        let mut var_counts: HashMap<Var, usize> = HashMap::new();
        for cid in db.iter_ids() {
            if self.eliminated_clauses.contains(&cid) {
                continue;
            }
            let clause = match db.get(cid) {
                Some(c) => &c.lits,
                None => continue,
            };
            for &lit in clause {
                *var_counts.entry(lit.var()).or_insert(0) += 1;
            }
        }
        let mut candidates: Vec<Var> = var_counts
            .into_iter()
            .filter(|&(_, count)| count <= Self::MAX_CANDIDATE_OCCURRENCES)
            .map(|(var, _)| var)
            .collect();
        candidates.sort_unstable_by_key(|v| v.0);
        candidates
    }

    /// Returns `true` if `var` may be eliminated under the configured limits.
    ///
    /// `var` must not be eliminated already, the number of resolvent pairs
    /// must not exceed `max_resolution_size`, and no clause containing `var`
    /// may be longer than `max_clause_size_for_elimination`.
    fn can_eliminate_variable(&self, db: &ClauseDatabase, var: Var) -> bool {
        if self.eliminated_vars.contains(&var) {
            return false;
        }
        let (pos_clauses, neg_clauses) = self.collect_clauses_with_var(db, var);
        let num_resolvents = pos_clauses.len().saturating_mul(neg_clauses.len());
        if num_resolvents > self.config.max_resolution_size {
            return false;
        }
        let max_len = self.config.max_clause_size_for_elimination;
        pos_clauses
            .iter()
            .chain(neg_clauses.iter())
            .all(|(_, lits)| lits.len() <= max_len)
    }

    /// Collects the active clauses that mention `var`, together with their ids.
    ///
    /// Clauses containing `var` positively go to the first list. Clauses
    /// containing only `¬var` go to the second. A clause containing both
    /// polarities lands in the first list.
    fn collect_clauses_with_var(
        &self,
        db: &ClauseDatabase,
        var: Var,
    ) -> (Vec<ClauseEntry>, Vec<ClauseEntry>) {
        let pos_lit = Lit::pos(var);
        let neg_lit = Lit::neg(var);
        let mut pos_clauses = Vec::new();
        let mut neg_clauses = Vec::new();
        for cid in db.iter_ids() {
            if self.eliminated_clauses.contains(&cid) {
                continue;
            }
            let lits = match db.get(cid) {
                Some(c) => &c.lits,
                None => continue,
            };
            // Only clone clauses that actually mention `var`.
            if lits.contains(&pos_lit) {
                pos_clauses.push((cid, lits.to_vec()));
            } else if lits.contains(&neg_lit) {
                neg_clauses.push((cid, lits.to_vec()));
            }
        }
        (pos_clauses, neg_clauses)
    }

    /// Eliminates `var` by clause distribution.
    ///
    /// First, every non-tautological resolvent of a positive clause with a
    /// negative clause is added to the proof. Then every clause mentioning
    /// `var` is deleted from the proof and marked eliminated.
    ///
    /// Returns the number of clauses removed.
    fn eliminate_variable(
        &mut self,
        db: &mut ClauseDatabase,
        proof: &mut DratProof,
        var: Var,
    ) -> std::io::Result<usize> {
        let (pos_clauses, neg_clauses) = self.collect_clauses_with_var(db, var);
        let neg_lit = Lit::neg(var);
        // Resolvents are added before any antecedent is deleted, so each one
        // is checkable (RUP) against the formula still present in the proof.
        for (_, pos_clause) in &pos_clauses {
            // A clause holding both `var` and `¬var` is a tautology. Resolving
            // it strips every `var` literal and yields clauses the formula does
            // not imply, so it is only deleted below, never resolved.
            if pos_clause.contains(&neg_lit) {
                continue;
            }
            for (_, neg_clause) in &neg_clauses {
                let resolvent = Self::resolve(pos_clause, neg_clause, var);
                if Self::is_tautology(&resolvent) {
                    continue;
                }
                proof.add_clause(&resolvent)?;
                self.stats.proof_steps += 1;
            }
        }
        for (cid, clause) in pos_clauses.iter().chain(neg_clauses.iter()) {
            proof.delete_clause(clause)?;
            self.stats.proof_steps += 1;
            // Record the deletion so that no later pass resolves on, or deletes
            // again, a clause that is no longer in the proof.
            self.eliminated_clauses.insert(*cid);
        }
        let eliminated = pos_clauses.len() + neg_clauses.len();
        self.stats.clauses_eliminated += eliminated;
        Ok(eliminated)
    }

    /// Returns the statistics accumulated so far.
    pub fn stats(&self) -> &DratInprocessingStats {
        &self.stats
    }

    /// Resets the statistics only; eliminated clauses and variables are kept.
    pub fn reset_stats(&mut self) {
        self.stats = DratInprocessingStats::default();
    }

    /// Forgets all eliminated clauses and variables and resets the statistics.
    pub fn clear(&mut self) {
        self.eliminated_clauses.clear();
        self.eliminated_vars.clear();
        self.stats = DratInprocessingStats::default();
    }

    /// Returns `true` if this processor has removed clause `cid` (and
    /// written its deletion to the proof).
    pub fn is_clause_eliminated(&self, cid: ClauseId) -> bool {
        self.eliminated_clauses.contains(&cid)
    }

    /// Returns `true` if this processor has eliminated variable `var`.
    pub fn is_var_eliminated(&self, var: Var) -> bool {
        self.eliminated_vars.contains(&var)
    }
}

impl Default for DratInprocessor {
    /// Equivalent to [`DratInprocessor::default_config`].
    fn default() -> Self {
        Self::default_config()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh processor starts with all counters at zero.
    #[test]
    fn test_drat_inprocessor_creation() {
        let inprocessor = DratInprocessor::default();
        let stats = inprocessor.stats();
        assert_eq!(stats.rounds, 0);
        assert_eq!(stats.clauses_eliminated, 0);
        assert_eq!(stats.variables_eliminated, 0);
        assert_eq!(stats.proof_steps, 0);
    }

    #[test]
    fn test_subsumes() {
        let v0 = Var(0);
        let v1 = Var(1);
        let clause1 = vec![Lit::pos(v0)];
        let clause2 = vec![Lit::pos(v0), Lit::pos(v1)];
        assert!(DratInprocessor::subsumes(&clause1, &clause2));
        assert!(!DratInprocessor::subsumes(&clause2, &clause1));
    }

    #[test]
    fn test_subsumes_equal() {
        let v0 = Var(0);
        let v1 = Var(1);
        let clause1 = vec![Lit::pos(v0), Lit::pos(v1)];
        let clause2 = vec![Lit::pos(v0), Lit::pos(v1)];
        assert!(DratInprocessor::subsumes(&clause1, &clause2));
        assert!(DratInprocessor::subsumes(&clause2, &clause1));
    }

    /// Opposite polarity of the same variable is not subsumption.
    #[test]
    fn test_subsumes_opposite_polarity() {
        let v0 = Var(0);
        let v1 = Var(1);
        let clause1 = vec![Lit::pos(v0)];
        let clause2 = vec![Lit::neg(v0), Lit::pos(v1)];
        assert!(!DratInprocessor::subsumes(&clause1, &clause2));
    }

    #[test]
    fn test_resolve() {
        let v0 = Var(0);
        let v1 = Var(1);
        let clause1 = vec![Lit::pos(v0), Lit::pos(v1)];
        let clause2 = vec![Lit::neg(v0)];
        let resolvent = DratInprocessor::resolve(&clause1, &clause2, v0);
        assert_eq!(resolvent.len(), 1);
        assert!(resolvent.contains(&Lit::pos(v1)));
    }

    /// A literal shared by both antecedents appears once in the resolvent.
    #[test]
    fn test_resolve_deduplicates_shared_literal() {
        let v0 = Var(0);
        let v1 = Var(1);
        let clause1 = vec![Lit::pos(v0), Lit::pos(v1)];
        let clause2 = vec![Lit::neg(v0), Lit::pos(v1)];
        let resolvent = DratInprocessor::resolve(&clause1, &clause2, v0);
        assert_eq!(resolvent.len(), 1);
        assert!(resolvent.contains(&Lit::pos(v1)));
    }

    #[test]
    fn test_is_tautology() {
        let v0 = Var(0);
        let v1 = Var(1);
        let tautology = vec![Lit::pos(v0), Lit::neg(v0), Lit::pos(v1)];
        let non_tautology = vec![Lit::pos(v0), Lit::pos(v1)];
        assert!(DratInprocessor::is_tautology(&tautology));
        assert!(!DratInprocessor::is_tautology(&non_tautology));
    }

    /// The empty clause is never a tautology.
    #[test]
    fn test_is_tautology_empty() {
        assert!(!DratInprocessor::is_tautology(&[]));
    }

    #[test]
    fn test_config_default() {
        let config = DratInprocessingConfig::default();
        assert!(config.enable_subsumption);
        assert!(config.enable_variable_elimination);
        assert!(config.enable_blocked_clause_elimination);
        assert_eq!(config.max_clause_size_for_elimination, 10);
        assert_eq!(config.max_resolution_size, 100);
    }

    /// Pins the exact report. The previous `contains("2")` check passed
    /// trivially because "20" contains "2".
    #[test]
    fn test_stats_display() {
        let stats = DratInprocessingStats {
            clauses_eliminated: 10,
            variables_eliminated: 2,
            subsumptions: 5,
            blocked_clauses_eliminated: 3,
            proof_steps: 20,
            rounds: 1,
        };
        let expected = "DRAT Inprocessing Stats:\n\
                        - Rounds: 1\n\
                        - Clauses eliminated: 10\n\
                        - Variables eliminated: 2\n\
                        - Subsumptions: 5\n\
                        - Blocked clauses: 3\n\
                        - Proof steps: 20";
        assert_eq!(stats.display(), expected);
    }

    #[test]
    fn test_clear() {
        let mut inprocessor = DratInprocessor::default();
        inprocessor.stats.rounds = 5;
        inprocessor.stats.proof_steps = 7;
        inprocessor.eliminated_vars.insert(Var(0));
        inprocessor.clear();
        assert_eq!(inprocessor.stats.rounds, 0);
        assert_eq!(inprocessor.stats.proof_steps, 0);
        assert!(inprocessor.eliminated_vars.is_empty());
        assert!(inprocessor.eliminated_clauses.is_empty());
    }

    /// `reset_stats` zeroes the counters but keeps the elimination history.
    #[test]
    fn test_reset_stats_keeps_eliminated_state() {
        let mut inprocessor = DratInprocessor::default();
        inprocessor.stats.rounds = 3;
        inprocessor.eliminated_vars.insert(Var(1));
        inprocessor.reset_stats();
        assert_eq!(inprocessor.stats().rounds, 0);
        assert!(inprocessor.eliminated_vars.contains(&Var(1)));
    }
}