#![allow(dead_code)]
#[allow(unused_imports)]
use crate::prelude::*;
use oxiz_core::ast::{TermId, TermManager};
/// Conflict-driven clause learner.
///
/// Combines an implication graph, a learned-clause database and a clause minimizer.
/// It is configured by a [`ClauseLearningConfig`] and reports counters in [`ClauseLearningStats`].
pub struct ClauseLearner {
/// Implication graph consulted for decision levels, reasons and decision flags.
impl_graph: ImplicationGraph,
/// Store of learned clauses together with their activity scores.
learned_db: LearnedDatabase,
/// Scratch state used by recursive clause minimization.
minimizer: ClauseMinimizer,
/// Feature toggles and tuning parameters.
config: ClauseLearningConfig,
/// Running counters updated by the learner's operations.
stats: ClauseLearningStats,
}
/// Per-variable assignment information consulted during conflict analysis.
#[derive(Debug, Clone)]
pub struct ImplicationGraph {
/// Node recorded for each assigned variable by `add_node`.
nodes: FxHashMap<TermId, ImplicationNode>,
/// NOTE(review): nothing in this module populates or reads this map.
predecessors: FxHashMap<TermId, Vec<TermId>>,
/// Decision level of each variable. This duplicates `ImplicationNode::level`.
/// `get_level` reports unknown variables as level 0.
levels: FxHashMap<TermId, usize>,
/// Decision level that conflict analysis treats as current; updated via `set_level`.
current_level: usize,
}
/// A single assignment recorded in an [`ImplicationGraph`].
#[derive(Debug, Clone)]
pub struct ImplicationNode {
/// The assigned variable.
pub var: TermId,
/// Truth value assigned to `var`.
pub value: bool,
/// Decision level at which the assignment was made.
pub level: usize,
/// Clause that propagated this assignment, if one was recorded.
/// Conflict analysis fails with "No reason for propagated literal" when it must resolve
/// on a literal that has none.
pub reason: Option<ClauseId>,
/// Whether the assignment was a decision. Conflict analysis never resolves on decisions.
pub is_decision: bool,
}
/// Identifier of a clause.
///
/// `LearnedDatabase` assigns ids as positions in its clause list, so ids shift when
/// clauses are removed.
pub type ClauseId = usize;
/// Learned clauses together with their activity scores.
#[derive(Debug, Clone)]
pub struct LearnedDatabase {
/// Learned clauses, indexed by `ClauseId`.
clauses: Vec<LearnedClause>,
/// Activity score per clause; kept parallel to `clauses`.
activity: Vec<f64>,
/// Maps a clause's literal list to its index in `clauses`. It is rebuilt by `reduce`.
/// NOTE(review): must be rebuilt whenever clauses are removed or reindexed.
clause_map: FxHashMap<Vec<TermId>, ClauseId>,
/// Amount added to a clause's activity on each bump. Each `decay` call divides it by
/// `decay_factor`.
bump_increment: f64,
/// Divisor applied to `bump_increment` by `decay`.
decay_factor: f64,
}
/// A clause produced by conflict analysis.
#[derive(Debug, Clone)]
pub struct LearnedClause {
/// Literals of the clause, after optional minimization.
pub literals: Vec<TermId>,
/// Literal of the clause at the current decision level when the clause was learned.
/// It is `TermId::from(0)` if there was none.
pub asserting_lit: TermId,
/// Level to backtrack to: the second-highest distinct decision level among the literals,
/// or 0 if they span a single level.
pub backtrack_level: usize,
/// Initial activity. `LearnedDatabase` tracks the live score in its own `activity` vector.
pub activity: f64,
/// Locked clauses always survive `LearnedDatabase::reduce` and are skipped by subsumption.
pub locked: bool,
/// Literal block distance: the number of distinct decision levels among `literals`.
pub lbd: usize,
}
/// Scratch state for recursive clause minimization.
#[derive(Debug, Clone)]
pub struct ClauseMinimizer {
/// Literals of the clause currently being minimized (filled by `recursive_minimize`).
seen: FxHashSet<TermId>,
/// Work stack. `recursive_minimize` clears it, but nothing else in this file uses it.
analyze_stack: Vec<TermId>,
/// NOTE(review): not read or written anywhere in this file.
cache: FxHashMap<TermId, bool>,
}
/// Feature toggles and tuning parameters for [`ClauseLearner`].
#[derive(Debug, Clone)]
pub struct ClauseLearningConfig {
/// Minimize learned clauses in `analyze_conflict`.
pub enable_minimization: bool,
/// Also run recursive minimization after local minimization.
pub enable_recursive_minimization: bool,
/// Allow `subsume_clauses` to remove subsumed learned clauses.
pub enable_subsumption: bool,
/// Gate for `strengthen_clauses`, which currently does nothing either way.
pub enable_strengthening: bool,
/// NOTE(review): not read anywhere in this file.
pub max_learned_size: usize,
/// NOTE(review): not read anywhere in this file.
pub lbd_threshold: usize,
/// Decay factor handed to the learned-clause database.
pub activity_decay: f64,
}
impl Default for ClauseLearningConfig {
    /// Default tuning values:
    /// - every learning technique is switched on;
    /// - `max_learned_size` is 1000;
    /// - `lbd_threshold` is 5;
    /// - `activity_decay` is 0.95.
    fn default() -> Self {
        const MAX_LEARNED_SIZE: usize = 1000;
        const LBD_THRESHOLD: usize = 5;
        const ACTIVITY_DECAY: f64 = 0.95;

        Self {
            activity_decay: ACTIVITY_DECAY,
            lbd_threshold: LBD_THRESHOLD,
            max_learned_size: MAX_LEARNED_SIZE,
            enable_strengthening: true,
            enable_subsumption: true,
            enable_recursive_minimization: true,
            enable_minimization: true,
        }
    }
}
/// Counters reported by [`ClauseLearner`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct ClauseLearningStats {
/// Number of `analyze_conflict` calls.
pub conflicts_analyzed: usize,
/// Number of clauses produced by `analyze_conflict`.
pub clauses_learned: usize,
/// Total literal count of learned clauses before minimization.
pub literals_before_minimization: usize,
/// Total literal count of learned clauses after minimization.
pub literals_after_minimization: usize,
/// Clause removals recorded by `subsume_clauses`.
pub clauses_subsumed: usize,
/// NOTE(review): never incremented in this file.
pub clauses_strengthened: usize,
/// Number of first-UIP computations performed.
pub uip_computations: usize,
/// Number of `reduce_database` calls.
pub db_reductions: usize,
}
impl ClauseLearner {
    /// Creates a learner with an empty implication graph and learned-clause database.
    ///
    /// The database's activity decay factor is taken from `config.activity_decay`.
    pub fn new(config: ClauseLearningConfig) -> Self {
        Self {
            impl_graph: ImplicationGraph::new(),
            learned_db: LearnedDatabase::new(config.activity_decay),
            minimizer: ClauseMinimizer::new(),
            config,
            stats: ClauseLearningStats::default(),
        }
    }

    /// Analyzes the conflict raised by `conflict_clause` and learns a new clause.
    ///
    /// The steps are:
    /// 1. Resolve the conflict back to a unique implication point.
    /// 2. Optionally minimize the result (per `config.enable_minimization`).
    /// 3. Compute its LBD.
    /// 4. Store a copy in the learned-clause database and return the clause.
    ///
    /// `_tm` is currently unused.
    ///
    /// # Errors
    ///
    /// Propagates errors from clause-literal lookup and from resolution
    /// ("No literal to resolve on", "No reason for propagated literal").
    pub fn analyze_conflict(
        &mut self,
        conflict_clause: ClauseId,
        _tm: &TermManager,
    ) -> Result<LearnedClause, String> {
        self.stats.conflicts_analyzed += 1;
        let conflict_lits = self.get_clause_literals(conflict_clause)?;
        let (learned_lits, asserting_lit, backtrack_level) =
            self.compute_first_uip(&conflict_lits)?;
        self.stats.uip_computations += 1;
        self.stats.literals_before_minimization += learned_lits.len();
        let minimized_lits = if self.config.enable_minimization {
            self.minimize_clause(&learned_lits)?
        } else {
            learned_lits
        };
        self.stats.literals_after_minimization += minimized_lits.len();
        let lbd = self.compute_lbd(&minimized_lits);
        let learned = LearnedClause {
            literals: minimized_lits,
            asserting_lit,
            backtrack_level,
            activity: 0.0,
            locked: false,
            lbd,
        };
        self.stats.clauses_learned += 1;
        self.learned_db.add_clause(learned.clone());
        Ok(learned)
    }

    /// Resolves the conflict clause until exactly one literal of the current decision level
    /// remains (a unique implication point).
    ///
    /// Returns `(learned_literals, asserting_literal, backtrack_level)`:
    /// - The asserting literal is the learned clause's literal at the current level, or the
    ///   placeholder `TermId::from(0)` if there is none.
    /// - The backtrack level is the second-highest distinct decision level in the clause, or 0
    ///   when the clause spans a single level.
    ///
    /// NOTE(review): resolution picks current-level literals in clause order, not reverse trail
    /// order, so the UIP reached is not guaranteed to be the *first* UIP. Confirm intent.
    ///
    /// # Errors
    ///
    /// - "No literal to resolve on" if no non-decision current-level literal is left.
    /// - "No reason for propagated literal" if the chosen literal has no reason clause.
    fn compute_first_uip(
        &mut self,
        conflict_lits: &[TermId],
    ) -> Result<(Vec<TermId>, TermId, usize), String> {
        let current_level = self.impl_graph.current_level;

        // Deduplicate up front; `seen` doubles as the "ever in the clause" set. Otherwise a
        // repeated current-level literal would be counted twice while `retain` below drops
        // every copy at once. `counter` would then stay too high and the loop would error out.
        let mut seen = FxHashSet::default();
        let mut clause: Vec<TermId> = conflict_lits
            .iter()
            .copied()
            .filter(|&lit| seen.insert(lit))
            .collect();

        // Number of clause literals assigned at the current decision level.
        let mut counter = clause
            .iter()
            .filter(|&&lit| self.impl_graph.get_level(lit) == current_level)
            .count();

        while counter > 1 {
            let resolve_lit = clause
                .iter()
                .copied()
                .find(|&lit| {
                    self.impl_graph.get_level(lit) == current_level
                        && !self.impl_graph.is_decision(lit)
                })
                .ok_or("No literal to resolve on")?;
            let reason = self
                .impl_graph
                .get_reason(resolve_lit)
                .ok_or("No reason for propagated literal")?;
            let reason_lits = self.get_clause_literals(reason)?;
            clause.retain(|&lit| lit != resolve_lit);
            counter -= 1;
            for reason_lit in reason_lits {
                // `seen.insert` is false both for literals still in the clause and for ones
                // already resolved away, so neither is re-added.
                if reason_lit != resolve_lit && seen.insert(reason_lit) {
                    clause.push(reason_lit);
                    if self.impl_graph.get_level(reason_lit) == current_level {
                        counter += 1;
                    }
                }
            }
        }

        let asserting_lit = clause
            .iter()
            .copied()
            .find(|&lit| self.impl_graph.get_level(lit) == current_level)
            .unwrap_or(TermId::from(0));

        let mut levels: Vec<usize> = clause
            .iter()
            .map(|&lit| self.impl_graph.get_level(lit))
            .collect();
        levels.sort_unstable();
        levels.dedup();
        // The highest distinct level is normally that of the asserting literal; backtrack to
        // the next one down.
        let backtrack_level = levels.iter().rev().nth(1).copied().unwrap_or(0);

        Ok((clause, asserting_lit, backtrack_level))
    }

    /// Drops literals from `clause` that are implied by the rest of it.
    ///
    /// First it applies local minimization via `is_redundant`. Then, if
    /// `config.enable_recursive_minimization` is set, it runs `recursive_minimize`.
    /// Returns `clause` unchanged when minimization is disabled.
    fn minimize_clause(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
        if !self.config.enable_minimization {
            return Ok(clause.to_vec());
        }
        let mut minimized = clause.to_vec();
        minimized.retain(|&lit| !self.is_redundant(lit, clause));
        if self.config.enable_recursive_minimization {
            minimized = self.recursive_minimize(&minimized)?;
        }
        Ok(minimized)
    }

    /// A literal is locally redundant when it has a reason clause whose other literals all
    /// occur in `clause`. Literals without a (retrievable) reason are never redundant.
    fn is_redundant(&mut self, lit: TermId, clause: &[TermId]) -> bool {
        if let Some(reason) = self.impl_graph.get_reason(lit)
            && let Ok(reason_lits) = self.get_clause_literals(reason)
        {
            return reason_lits
                .iter()
                .all(|&r_lit| r_lit == lit || clause.contains(&r_lit));
        }
        false
    }

    /// Keeps the literals of `clause` that the minimizer reports as non-removable.
    fn recursive_minimize(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
        self.minimizer.seen.clear();
        self.minimizer.analyze_stack.clear();
        self.minimizer.seen.extend(clause.iter().copied());
        let mut minimized = Vec::new();
        for &lit in clause {
            if !self.minimizer.can_remove(lit, &self.impl_graph)? {
                minimized.push(lit);
            }
        }
        Ok(minimized)
    }

    /// Literal block distance: the number of distinct decision levels among `clause`'s literals.
    fn compute_lbd(&self, clause: &[TermId]) -> usize {
        clause
            .iter()
            .map(|&lit| self.impl_graph.get_level(lit))
            .collect::<FxHashSet<usize>>()
            .len()
    }

    /// Placeholder lookup: always returns an empty literal list.
    fn get_clause_literals(&self, _clause_id: ClauseId) -> Result<Vec<TermId>, String> {
        Ok(vec![])
    }

    /// Removes every unlocked learned clause that is a superset of another unlocked one.
    ///
    /// Clauses already marked for removal no longer take part in comparisons. Each removed
    /// clause is counted once in `clauses_subsumed`. Activities stay parallel to clauses, and
    /// the literal map is re-keyed to the surviving clauses' new indices.
    pub fn subsume_clauses(&mut self) -> Result<(), String> {
        if !self.config.enable_subsumption {
            return Ok(());
        }

        let clauses = &self.learned_db.clauses;
        let n = clauses.len();
        let mut removed = vec![false; n];
        for i in 0..n {
            if removed[i] || clauses[i].locked {
                continue;
            }
            for j in (i + 1)..n {
                if removed[j] || clauses[j].locked {
                    continue;
                }
                let (lits_i, lits_j) = (&clauses[i].literals, &clauses[j].literals);
                if Self::subsumes(lits_i, lits_j) {
                    removed[j] = true;
                } else if Self::subsumes(lits_j, lits_i) {
                    // Clause `i` is gone; there is nothing left to compare it against.
                    removed[i] = true;
                    break;
                }
            }
        }

        let removed_count = removed.iter().filter(|&&gone| gone).count();
        if removed_count == 0 {
            return Ok(());
        }
        self.stats.clauses_subsumed += removed_count;

        // Single O(n) rebuild instead of repeated `Vec::remove` (O(n) each).
        let db = &mut self.learned_db;
        let old_clauses = core::mem::take(&mut db.clauses);
        let old_activity = core::mem::take(&mut db.activity);
        for ((clause, act), gone) in old_clauses.into_iter().zip(old_activity).zip(removed) {
            if !gone {
                db.clauses.push(clause);
                db.activity.push(act);
            }
        }
        // Surviving clauses moved to new indices, and removed ones must not stay reachable.
        db.clause_map.clear();
        for (idx, clause) in db.clauses.iter().enumerate() {
            db.clause_map.insert(clause.literals.clone(), idx);
        }
        Ok(())
    }

    /// Returns `true` when every literal of `a` also occurs in `b`.
    fn subsumes(a: &[TermId], b: &[TermId]) -> bool {
        if a.len() > b.len() {
            return false;
        }
        let b_set: FxHashSet<TermId> = b.iter().copied().collect();
        a.iter().all(|lit| b_set.contains(lit))
    }

    /// Clause strengthening entry point; currently does nothing beyond the config check.
    pub fn strengthen_clauses(&mut self) -> Result<(), String> {
        if !self.config.enable_strengthening {
            return Ok(());
        }
        Ok(())
    }

    /// Placeholder: never reports a literal as removable.
    fn can_remove_literal(&self, _lit: TermId, _clause: &[TermId]) -> bool {
        false
    }

    /// Shrinks the learned-clause database and counts the reduction.
    pub fn reduce_database(&mut self) -> Result<(), String> {
        self.stats.db_reductions += 1;
        self.learned_db.reduce();
        Ok(())
    }

    /// Increases the activity of the learned clause `clause_id`.
    pub fn bump_clause(&mut self, clause_id: ClauseId) {
        self.learned_db.bump_activity(clause_id);
    }

    /// Counters accumulated so far.
    pub fn stats(&self) -> &ClauseLearningStats {
        &self.stats
    }
}
impl ImplicationGraph {
    /// Returns an empty graph whose current decision level is 0.
    pub fn new() -> Self {
        Self {
            current_level: 0,
            levels: FxHashMap::default(),
            predecessors: FxHashMap::default(),
            nodes: FxHashMap::default(),
        }
    }

    /// Records (or overwrites) the assignment of `var`, updating both the node table and the
    /// per-variable level table.
    pub fn add_node(
        &mut self,
        var: TermId,
        value: bool,
        level: usize,
        reason: Option<ClauseId>,
        is_decision: bool,
    ) {
        let node = ImplicationNode {
            var,
            value,
            level,
            reason,
            is_decision,
        };
        self.levels.insert(var, level);
        self.nodes.insert(var, node);
    }

    /// Decision level of `var`, or 0 when `var` has no recorded assignment.
    pub fn get_level(&self, var: TermId) -> usize {
        self.levels.get(&var).copied().unwrap_or_default()
    }

    /// `true` only if `var` is recorded and its assignment was a decision.
    pub fn is_decision(&self, var: TermId) -> bool {
        matches!(self.nodes.get(&var), Some(node) if node.is_decision)
    }

    /// Reason clause of `var`. Returns `None` if `var` is unknown or has no reason.
    pub fn get_reason(&self, var: TermId) -> Option<ClauseId> {
        self.nodes.get(&var)?.reason
    }

    /// Sets the decision level treated as current by conflict analysis.
    pub fn set_level(&mut self, level: usize) {
        self.current_level = level;
    }
}
impl LearnedDatabase {
    /// Creates an empty database whose bump increment starts at 1.0.
    ///
    /// `decay_factor` is the divisor applied to the bump increment by [`Self::decay`].
    pub fn new(decay_factor: f64) -> Self {
        Self {
            clauses: Vec::new(),
            activity: Vec::new(),
            clause_map: FxHashMap::default(),
            bump_increment: 1.0,
            decay_factor,
        }
    }

    /// Appends `clause` and registers its literal list in `clause_map`.
    ///
    /// The new clause's id is its position, i.e. the previous number of clauses. If an
    /// identical literal list is already present, its map entry is redirected to the new clause.
    pub fn add_clause(&mut self, clause: LearnedClause) {
        let clause_id = self.clauses.len();
        self.clause_map.insert(clause.literals.clone(), clause_id);
        self.activity.push(clause.activity);
        self.clauses.push(clause);
    }

    /// Adds the current bump increment to the activity of `clause_id`.
    ///
    /// Out-of-range ids are ignored. When the bumped score exceeds 1e20, every score and the
    /// bump increment are multiplied by 1e-20 to stay within `f64` range. Each clause's own
    /// `activity` field is kept equal to the tracked score, so readers of the public field do
    /// not see a stale value.
    pub fn bump_activity(&mut self, clause_id: ClauseId) {
        const RESCALE_LIMIT: f64 = 1e20;
        const RESCALE_FACTOR: f64 = 1e-20;

        let Some(score) = self.activity.get_mut(clause_id) else {
            return;
        };
        *score += self.bump_increment;
        let bumped = *score;
        if let Some(clause) = self.clauses.get_mut(clause_id) {
            clause.activity = bumped;
        }

        if bumped > RESCALE_LIMIT {
            for score in &mut self.activity {
                *score *= RESCALE_FACTOR;
            }
            self.bump_increment *= RESCALE_FACTOR;
            for (clause, &score) in self.clauses.iter_mut().zip(&self.activity) {
                clause.activity = score;
            }
        }
    }

    /// Divides the bump increment by `decay_factor`, so later bumps weigh more than earlier
    /// ones when `decay_factor < 1`.
    pub fn decay(&mut self) {
        self.bump_increment /= self.decay_factor;
    }

    /// Keeps the most active half of the clauses (rounded down) plus every locked clause.
    ///
    /// Relative order is preserved, and ties in activity favour lower indices. Activities stay
    /// parallel to clauses, and `clause_map` is rebuilt for the new indices.
    pub fn reduce(&mut self) {
        let total = self.clauses.len();

        // Stable sort: equal activities keep index order, so lower indices win ties.
        let mut by_activity: Vec<usize> = (0..total).collect();
        by_activity.sort_by(|&a, &b| {
            self.activity[b]
                .partial_cmp(&self.activity[a])
                .unwrap_or(core::cmp::Ordering::Equal)
        });

        let mut keep = vec![false; total];
        for &idx in by_activity.iter().take(total / 2) {
            keep[idx] = true;
        }
        for (flag, clause) in keep.iter_mut().zip(&self.clauses) {
            if clause.locked {
                *flag = true;
            }
        }

        // Move survivors instead of cloning them.
        let old_clauses = core::mem::take(&mut self.clauses);
        let old_activity = core::mem::take(&mut self.activity);
        for ((clause, score), kept) in old_clauses.into_iter().zip(old_activity).zip(keep) {
            if kept {
                self.clauses.push(clause);
                self.activity.push(score);
            }
        }

        self.clause_map.clear();
        for (idx, clause) in self.clauses.iter().enumerate() {
            self.clause_map.insert(clause.literals.clone(), idx);
        }
    }
}
impl ClauseMinimizer {
    /// Returns a minimizer with empty scratch state.
    pub fn new() -> Self {
        let seen = FxHashSet::default();
        let analyze_stack = Vec::new();
        let cache = FxHashMap::default();
        Self {
            seen,
            analyze_stack,
            cache,
        }
    }

    /// Placeholder removability check: never reports a literal as removable.
    fn can_remove(&mut self, _lit: TermId, _graph: &ImplicationGraph) -> Result<bool, String> {
        let removable = false;
        Ok(removable)
    }
}
impl Default for ClauseLearner {
fn default() -> Self {
Self::new(ClauseLearningConfig::default())
}
}
impl Default for ImplicationGraph {
fn default() -> Self {
Self::new()
}
}
impl Default for ClauseMinimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unlocked, zero-activity learned clause over `literals`.
    /// The first literal is used as the asserting literal.
    fn learned_clause(literals: Vec<TermId>) -> LearnedClause {
        LearnedClause {
            asserting_lit: literals[0],
            literals,
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 1,
        }
    }

    #[test]
    fn test_clause_learner() {
        let learner = ClauseLearner::default();
        assert_eq!(learner.stats.conflicts_analyzed, 0);
    }

    #[test]
    fn test_implication_graph() {
        let mut graph = ImplicationGraph::new();
        let var = TermId::from(1);
        graph.add_node(var, true, 1, None, true);
        assert_eq!(graph.get_level(var), 1);
        assert!(graph.is_decision(var));
    }

    #[test]
    fn test_learned_database() {
        let mut db = LearnedDatabase::new(0.95);
        let clause = LearnedClause {
            literals: vec![TermId::from(1), TermId::from(2)],
            asserting_lit: TermId::from(1),
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 2,
        };
        db.add_clause(clause);
        assert_eq!(db.clauses.len(), 1);
    }

    #[test]
    fn test_subsumption() {
        let a = vec![TermId::from(1), TermId::from(2)];
        let b = vec![TermId::from(1), TermId::from(2), TermId::from(3)];
        assert!(ClauseLearner::subsumes(&a, &b));
        assert!(!ClauseLearner::subsumes(&b, &a));
    }

    #[test]
    fn test_lbd_computation() {
        let learner = ClauseLearner::default();
        let clause = vec![TermId::from(1), TermId::from(2), TermId::from(3)];
        let lbd = learner.compute_lbd(&clause);
        assert_eq!(lbd, 1);
    }

    #[test]
    fn test_subsume_clauses_removes_superset() {
        let mut learner = ClauseLearner::default();
        let superset = vec![TermId::from(1), TermId::from(2), TermId::from(3)];
        let subset = vec![TermId::from(1), TermId::from(2)];
        let unrelated = vec![TermId::from(4)];
        learner.learned_db.add_clause(learned_clause(superset));
        learner.learned_db.add_clause(learned_clause(subset.clone()));
        learner.learned_db.add_clause(learned_clause(unrelated.clone()));

        learner.subsume_clauses().expect("subsumption never fails");

        let remaining: Vec<Vec<TermId>> = learner
            .learned_db
            .clauses
            .iter()
            .map(|c| c.literals.clone())
            .collect();
        assert_eq!(remaining, vec![subset, unrelated]);
        assert_eq!(learner.learned_db.activity.len(), 2);
        assert_eq!(learner.stats().clauses_subsumed, 1);
    }

    #[test]
    fn test_reduce_keeps_most_active_and_locked() {
        let mut db = LearnedDatabase::new(0.95);
        let mut locked = learned_clause(vec![TermId::from(1)]);
        locked.locked = true;
        db.add_clause(locked);
        db.add_clause(learned_clause(vec![TermId::from(2)]));
        db.add_clause(learned_clause(vec![TermId::from(3)]));
        db.add_clause(learned_clause(vec![TermId::from(4)]));
        db.bump_activity(2);
        db.bump_activity(2);
        db.bump_activity(3);
        // Out-of-range ids are ignored.
        db.bump_activity(99);

        db.reduce();

        let kept: Vec<TermId> = db.clauses.iter().map(|c| c.literals[0]).collect();
        assert_eq!(kept, vec![TermId::from(1), TermId::from(3), TermId::from(4)]);
        assert_eq!(db.activity, vec![0.0, 2.0, 1.0]);
        assert_eq!(db.clause_map.get(&[TermId::from(3)][..]), Some(&1));
        assert_eq!(db.clause_map.get(&[TermId::from(2)][..]), None);
    }
}