#[allow(unused_imports)]
use crate::prelude::*;
/// Identifier of a term: the `term` of a [`Literal`] and the operands of an [`Equality`].
pub type TermId = u32;
/// Identifier of a theory that reports propagations and conflicts.
pub type TheoryId = u32;
/// Depth of the decision stack; a fresh `ConflictResolver` starts at level 0.
pub type DecisionLevel = u32;
/// A term paired with a sign.
///
/// Built with [`Literal::positive`] (`polarity == true`) or
/// [`Literal::negative`] (`polarity == false`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Literal {
    /// The term this literal refers to.
    pub term: TermId,
    /// `true` for the positive literal of `term`, `false` for the negative one.
    pub polarity: bool,
}

impl Literal {
    /// Returns the positive literal of `term`.
    pub fn positive(term: TermId) -> Self {
        Self::signed(term, true)
    }

    /// Returns the negative literal of `term`.
    pub fn negative(term: TermId) -> Self {
        Self::signed(term, false)
    }

    /// Returns the literal over the same term with the opposite polarity.
    pub fn negate(self) -> Self {
        Self {
            polarity: !self.polarity,
            ..self
        }
    }

    /// Assembles a literal from its two components.
    fn signed(term: TermId, polarity: bool) -> Self {
        Self { term, polarity }
    }
}
/// An equality between two terms.
///
/// [`Equality::new`] normalizes the operands so that `lhs <= rhs`; as a
/// result `a = b` and `b = a` compare and hash identically when built that way.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Equality {
    /// The smaller operand (when constructed via [`Equality::new`]).
    pub lhs: TermId,
    /// The larger operand (when constructed via [`Equality::new`]).
    pub rhs: TermId,
}

impl Equality {
    /// Creates the equality `lhs = rhs` with its operands in ascending order.
    pub fn new(lhs: TermId, rhs: TermId) -> Self {
        Self {
            lhs: lhs.min(rhs),
            rhs: lhs.max(rhs),
        }
    }
}
/// Justification attached to an assignment on the trail or to a reported conflict.
///
/// During conflict analysis `ConflictResolver` treats the literals in
/// `TheoryPropagation::antecedents` and `EqualityPropagation::support` as the
/// reasons of an assignment; the other variants supply no reason literals.
#[derive(Debug, Clone)]
pub enum Explanation {
/// Asserted without a recorded justification (no reason literals).
Given,
/// Derived by a theory from other literals.
TheoryPropagation {
/// Theory that performed the propagation.
theory: TheoryId,
/// Literals the propagation depends on.
antecedents: Vec<Literal>,
},
/// Derived from equalities between terms.
EqualityPropagation {
/// Equalities involved in the propagation.
equalities: Vec<Equality>,
/// Literals supporting the propagation (used as its reasons).
support: Vec<Literal>,
},
/// Justified by a chain of equalities.
Transitivity {
/// Equalities forming the chain.
chain: Vec<Equality>,
},
/// Justified by congruence over a function application.
Congruence {
/// NOTE(review): presumably the function symbol / application term — confirm.
function: TermId,
/// Equalities between corresponding arguments.
arg_equalities: Vec<Equality>,
},
}
/// A conflict reported by a theory, the input to `ConflictResolver::analyze_conflict`.
#[derive(Debug, Clone)]
pub struct TheoryConflict {
/// Theory that reported the conflict (also the one blamed for it).
pub theory: TheoryId,
/// Literals involved in the conflict; analysis starts from this set.
pub literals: Vec<Literal>,
/// Passed through unchanged into `ConflictAnalysis::explanation`.
pub explanation: Explanation,
/// Decision level the UIP search targets.
pub level: DecisionLevel,
}
/// A clause derived from a conflict by `ConflictResolver::analyze_conflict`.
#[derive(Debug, Clone)]
pub struct ConflictClause {
/// Literals of the derived clause.
pub literals: Vec<Literal>,
/// Literal selected as the unique implication point, if any.
pub uip: Option<Literal>,
/// Decision level to backtrack to after learning this clause.
pub backtrack_level: DecisionLevel,
/// Theories the clause originates from.
pub theories: FxHashSet<TheoryId>,
/// Activity score; `analyze_conflict` initializes it to 1.0.
pub activity: f64,
}
/// Result of analyzing one theory conflict.
#[derive(Debug, Clone)]
pub struct ConflictAnalysis {
/// The derived conflict clause.
pub clause: ConflictClause,
/// The explanation carried by the analyzed `TheoryConflict`.
pub explanation: Explanation,
/// Theories held responsible for the conflict.
pub blamed_theories: FxHashSet<TheoryId>,
}
/// Switches controlling how `ConflictResolver` analyzes conflicts.
#[derive(Debug, Clone)]
pub struct ConflictResolutionConfig {
/// Run `minimization_algorithm` over the conflict literals.
pub enable_minimization: bool,
/// Resolve the conflict towards a unique implication point before minimizing.
pub enable_uip: bool,
/// Algorithm used when `enable_minimization` is set.
pub minimization_algorithm: MinimizationAlgorithm,
/// Upper bound on resolution steps during the UIP search.
pub max_resolution_steps: usize,
/// Count conflicts per reporting theory.
pub track_theory_blame: bool,
/// Store derived clauses in the resolver's learned-clause list.
pub enable_learning: bool,
}
impl Default for ConflictResolutionConfig {
fn default() -> Self {
Self {
enable_minimization: true,
enable_uip: true,
minimization_algorithm: MinimizationAlgorithm::Recursive,
max_resolution_steps: 1000,
track_theory_blame: true,
enable_learning: true,
}
}
}
/// Strategy used to shrink a conflict's literal set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MinimizationAlgorithm {
/// Keep the literals as they are.
None,
/// Remove duplicate literals, keeping first occurrences.
Simple,
/// Remove literals whose antecedents are (recursively) covered by the rest of the clause.
Recursive,
/// Currently handled exactly like `Simple` by `ConflictResolver`.
BinaryResolution,
}
/// Counters describing the work done by conflict analysis.
#[derive(Debug, Clone, Default)]
pub struct ConflictResolutionStats {
/// Number of conflicts handed to analysis.
pub conflicts_analyzed: u64,
/// Number of derived clauses stored as learned clauses.
pub clauses_learned: u64,
/// Total number of literals removed by minimization.
pub literals_minimized: u64,
/// Number of conflicts for which the UIP search was run.
pub uip_conflicts: u64,
/// Resolution-loop steps performed during UIP searches.
pub resolution_steps: u64,
/// Number of conflicts attributed to a theory (when blame tracking is on).
pub theory_blames: u64,
}
/// Analyzes theory conflicts against an assignment trail and learns clauses from them.
pub struct ConflictResolver {
/// Switches controlling the analysis steps.
config: ConflictResolutionConfig,
/// Counters updated during analysis.
stats: ConflictResolutionStats,
/// Assignments in insertion order: (literal, decision level, explanation).
trail: Vec<(Literal, DecisionLevel, Explanation)>,
/// Trail index of each literal's most recent assignment.
literal_position: FxHashMap<Literal, usize>,
/// Trail index at which each decision level starts; level 0 is registered at 0 on construction.
level_boundaries: FxHashMap<DecisionLevel, usize>,
/// Decision level the resolver is currently at.
current_level: DecisionLevel,
/// Clauses stored when learning is enabled.
learned_clauses: Vec<ConflictClause>,
/// Number of conflicts attributed to each theory.
theory_blame: FxHashMap<TheoryId, u64>,
}
impl ConflictResolver {
pub fn new() -> Self {
Self::with_config(ConflictResolutionConfig::default())
}
pub fn with_config(config: ConflictResolutionConfig) -> Self {
let mut level_boundaries = FxHashMap::default();
level_boundaries.insert(0, 0);
Self {
config,
stats: ConflictResolutionStats::default(),
trail: Vec::new(),
literal_position: FxHashMap::default(),
level_boundaries,
current_level: 0,
learned_clauses: Vec::new(),
theory_blame: FxHashMap::default(),
}
}
pub fn stats(&self) -> &ConflictResolutionStats {
&self.stats
}
pub fn add_assignment(
&mut self,
literal: Literal,
level: DecisionLevel,
explanation: Explanation,
) {
let position = self.trail.len();
self.trail.push((literal, level, explanation));
self.literal_position.insert(literal, position);
self.level_boundaries.entry(level).or_insert(position);
}
pub fn push_decision_level(&mut self) {
self.current_level += 1;
}
pub fn backtrack(&mut self, level: DecisionLevel) -> Result<(), String> {
if level > self.current_level {
return Err("Cannot backtrack to future level".to_string());
}
let backtrack_pos = self
.level_boundaries
.get(&level)
.copied()
.unwrap_or(self.trail.len());
self.trail.truncate(backtrack_pos);
self.literal_position.clear();
for (i, &(literal, _, _)) in self.trail.iter().enumerate() {
self.literal_position.insert(literal, i);
}
self.level_boundaries.retain(|&l, _| l <= level);
self.current_level = level;
Ok(())
}
pub fn analyze_conflict(
&mut self,
conflict: TheoryConflict,
) -> Result<ConflictAnalysis, String> {
self.stats.conflicts_analyzed += 1;
if self.config.track_theory_blame {
*self.theory_blame.entry(conflict.theory).or_insert(0) += 1;
self.stats.theory_blames += 1;
}
let mut conflict_literals = conflict.literals.clone();
if self.config.enable_uip {
conflict_literals = self.find_uip(&conflict_literals, conflict.level)?;
self.stats.uip_conflicts += 1;
}
if self.config.enable_minimization {
let before_size = conflict_literals.len();
conflict_literals = self.minimize_conflict(&conflict_literals)?;
let after_size = conflict_literals.len();
self.stats.literals_minimized += (before_size - after_size) as u64;
}
let backtrack_level = self.compute_backtrack_level(&conflict_literals, conflict.level)?;
let clause = ConflictClause {
literals: conflict_literals.clone(),
uip: self.find_uip_literal(&conflict_literals),
backtrack_level,
theories: {
let mut theories = FxHashSet::default();
theories.insert(conflict.theory);
theories
},
activity: 1.0,
};
if self.config.enable_learning {
self.learned_clauses.push(clause.clone());
self.stats.clauses_learned += 1;
}
Ok(ConflictAnalysis {
clause,
explanation: conflict.explanation,
blamed_theories: {
let mut theories = FxHashSet::default();
theories.insert(conflict.theory);
theories
},
})
}
fn find_uip(
&mut self,
literals: &[Literal],
level: DecisionLevel,
) -> Result<Vec<Literal>, String> {
let mut current_clause: FxHashSet<Literal> = literals.iter().copied().collect();
let mut seen = FxHashSet::default();
let mut counter = 0;
for &lit in ¤t_clause {
if self.get_decision_level(lit) == Some(level) {
counter += 1;
}
}
for _ in 0..self.config.max_resolution_steps {
self.stats.resolution_steps += 1;
if counter <= 1 {
break; }
let resolve_lit = self.find_resolution_literal(¤t_clause, level, &seen)?;
seen.insert(resolve_lit);
let reason = self.get_reason(resolve_lit)?;
current_clause.remove(&resolve_lit);
counter -= 1;
for &lit in &reason {
if !current_clause.contains(&lit) {
current_clause.insert(lit);
if self.get_decision_level(lit) == Some(level) {
counter += 1;
}
}
}
}
Ok(current_clause.into_iter().collect())
}
fn find_resolution_literal(
&self,
clause: &FxHashSet<Literal>,
level: DecisionLevel,
seen: &FxHashSet<Literal>,
) -> Result<Literal, String> {
for &(literal, lit_level, _) in self.trail.iter().rev() {
if lit_level == level && clause.contains(&literal) && !seen.contains(&literal) {
return Ok(literal);
}
}
Err("No resolution literal found".to_string())
}
fn get_decision_level(&self, literal: Literal) -> Option<DecisionLevel> {
self.literal_position
.get(&literal)
.and_then(|&pos| self.trail.get(pos))
.map(|(_, level, _)| *level)
}
fn get_reason(&self, literal: Literal) -> Result<Vec<Literal>, String> {
let position = self
.literal_position
.get(&literal)
.ok_or("Literal not in trail")?;
let (_, _, explanation) = &self.trail[*position];
match explanation {
Explanation::TheoryPropagation { antecedents, .. } => Ok(antecedents.clone()),
Explanation::EqualityPropagation { support, .. } => Ok(support.clone()),
_ => Ok(Vec::new()),
}
}
fn minimize_conflict(&self, literals: &[Literal]) -> Result<Vec<Literal>, String> {
match self.config.minimization_algorithm {
MinimizationAlgorithm::None => Ok(literals.to_vec()),
MinimizationAlgorithm::Simple => self.minimize_simple(literals),
MinimizationAlgorithm::Recursive => self.minimize_recursive(literals),
MinimizationAlgorithm::BinaryResolution => self.minimize_binary_resolution(literals),
}
}
fn minimize_simple(&self, literals: &[Literal]) -> Result<Vec<Literal>, String> {
let mut minimal = Vec::new();
let mut seen = FxHashSet::default();
for &lit in literals {
if !seen.contains(&lit) {
seen.insert(lit);
minimal.push(lit);
}
}
Ok(minimal)
}
fn minimize_recursive(&self, literals: &[Literal]) -> Result<Vec<Literal>, String> {
let mut minimal = Vec::new();
let mut redundant = FxHashSet::default();
for &lit in literals {
if self.is_redundant(lit, literals, &mut redundant)? {
continue;
}
minimal.push(lit);
}
Ok(minimal)
}
fn is_redundant(
&self,
literal: Literal,
clause: &[Literal],
redundant: &mut FxHashSet<Literal>,
) -> Result<bool, String> {
if redundant.contains(&literal) {
return Ok(true);
}
let reason = self.get_reason(literal).ok().unwrap_or_default();
for &reason_lit in &reason {
if !clause.contains(&reason_lit)
&& !redundant.contains(&reason_lit)
&& !self.is_redundant(reason_lit, clause, redundant)?
{
return Ok(false);
}
}
redundant.insert(literal);
Ok(true)
}
fn minimize_binary_resolution(&self, literals: &[Literal]) -> Result<Vec<Literal>, String> {
self.minimize_simple(literals)
}
fn compute_backtrack_level(
&self,
literals: &[Literal],
_conflict_level: DecisionLevel,
) -> Result<DecisionLevel, String> {
let mut levels: Vec<DecisionLevel> = literals
.iter()
.filter_map(|&lit| self.get_decision_level(lit))
.collect();
levels.sort_unstable();
levels.dedup();
if levels.len() >= 2 {
Ok(levels[levels.len() - 2])
} else if !levels.is_empty() {
Ok(levels[0].saturating_sub(1))
} else {
Ok(0)
}
}
fn find_uip_literal(&self, literals: &[Literal]) -> Option<Literal> {
literals
.iter()
.max_by_key(|&&lit| self.get_decision_level(lit).unwrap_or(0))
.copied()
}
pub fn learned_clauses(&self) -> &[ConflictClause] {
&self.learned_clauses
}
pub fn theory_blame(&self) -> &FxHashMap<TheoryId, u64> {
&self.theory_blame
}
pub fn clear(&mut self) {
self.trail.clear();
self.literal_position.clear();
self.level_boundaries.clear();
self.current_level = 0;
self.learned_clauses.clear();
self.theory_blame.clear();
}
pub fn reset_stats(&mut self) {
self.stats = ConflictResolutionStats::default();
}
}
impl Default for ConflictResolver {
fn default() -> Self {
Self::new()
}
}
/// Cache of explanations keyed by literal, used to assemble combined explanations.
pub struct ExplanationGenerator {
/// Latest explanation recorded for each literal.
cache: FxHashMap<Literal, Explanation>,
}
impl ExplanationGenerator {
pub fn new() -> Self {
Self {
cache: FxHashMap::default(),
}
}
pub fn add_explanation(&mut self, literal: Literal, explanation: Explanation) {
self.cache.insert(literal, explanation);
}
pub fn get_explanation(&self, literal: Literal) -> Option<&Explanation> {
self.cache.get(&literal)
}
pub fn build_chain(&self, literals: &[Literal]) -> Explanation {
let mut antecedents = Vec::new();
for &lit in literals {
if let Some(explanation) = self.cache.get(&lit)
&& let Explanation::TheoryPropagation {
antecedents: ants, ..
} = explanation
{
antecedents.extend_from_slice(ants);
}
}
Explanation::TheoryPropagation {
theory: 0,
antecedents,
}
}
pub fn clear(&mut self) {
self.cache.clear();
}
}
impl Default for ExplanationGenerator {
fn default() -> Self {
Self::new()
}
}
/// Routes conflicts to a per-theory `ConflictResolver`.
pub struct MultiTheoryConflictAnalyzer {
/// One resolver per registered theory.
resolvers: FxHashMap<TheoryId, ConflictResolver>,
/// Aggregate statistics updated by `analyze` and reset by `clear`.
combined_stats: ConflictResolutionStats,
}
impl MultiTheoryConflictAnalyzer {
    /// Creates an analyzer with no registered theories.
    pub fn new() -> Self {
        Self {
            resolvers: FxHashMap::default(),
            combined_stats: ConflictResolutionStats::default(),
        }
    }

    /// Registers `theory` with a fresh resolver built from `config`.
    ///
    /// Re-registering a theory replaces its previous resolver.
    pub fn register_theory(&mut self, theory: TheoryId, config: ConflictResolutionConfig) {
        self.resolvers
            .insert(theory, ConflictResolver::with_config(config));
    }

    /// Routes `conflict` to the resolver registered for `conflict.theory`.
    ///
    /// Every counter the resolver increments during the call is added to
    /// [`Self::combined_stats`]. Previously only `conflicts_analyzed` was
    /// aggregated, leaving all other combined counters at zero.
    ///
    /// # Errors
    /// Returns `Err` if the theory is not registered, or the resolver's error.
    pub fn analyze(&mut self, conflict: TheoryConflict) -> Result<ConflictAnalysis, String> {
        let resolver = self
            .resolvers
            .get_mut(&conflict.theory)
            .ok_or("Theory not registered")?;
        let before = resolver.stats().clone();
        let result = resolver.analyze_conflict(conflict);
        // Aggregate on both paths so the totals mirror the resolvers' counters.
        Self::accumulate(&mut self.combined_stats, &before, resolver.stats());
        result
    }

    /// Adds the per-counter difference `after - before` to `total`.
    fn accumulate(
        total: &mut ConflictResolutionStats,
        before: &ConflictResolutionStats,
        after: &ConflictResolutionStats,
    ) {
        // Resolver counters only grow during analysis; saturate defensively.
        let delta = |now: u64, then: u64| now.saturating_sub(then);
        total.conflicts_analyzed += delta(after.conflicts_analyzed, before.conflicts_analyzed);
        total.clauses_learned += delta(after.clauses_learned, before.clauses_learned);
        total.literals_minimized += delta(after.literals_minimized, before.literals_minimized);
        total.uip_conflicts += delta(after.uip_conflicts, before.uip_conflicts);
        total.resolution_steps += delta(after.resolution_steps, before.resolution_steps);
        total.theory_blames += delta(after.theory_blames, before.theory_blames);
    }

    /// Statistics aggregated over all `analyze` calls since the last `clear`.
    pub fn combined_stats(&self) -> &ConflictResolutionStats {
        &self.combined_stats
    }

    /// Resolver registered for `theory`, if any.
    pub fn get_resolver(&self, theory: TheoryId) -> Option<&ConflictResolver> {
        self.resolvers.get(&theory)
    }

    /// Clears every registered resolver and resets the combined statistics.
    pub fn clear(&mut self) {
        for resolver in self.resolvers.values_mut() {
            resolver.clear();
        }
        self.combined_stats = ConflictResolutionStats::default();
    }
}
impl Default for MultiTheoryConflictAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a level-0 conflict with a `Given` explanation.
    fn given_conflict(theory: TheoryId, literals: Vec<Literal>) -> TheoryConflict {
        TheoryConflict {
            theory,
            literals,
            explanation: Explanation::Given,
            level: 0,
        }
    }

    #[test]
    fn test_literal_creation() {
        let lit = Literal::positive(1);
        assert_eq!(lit.term, 1);
        assert!(lit.polarity);
        let neg = Literal::negative(1);
        assert_eq!(neg.term, 1);
        assert!(!neg.polarity);
    }

    #[test]
    fn test_literal_negation() {
        let lit = Literal::positive(1);
        let neg = lit.negate();
        assert!(!neg.polarity);
        assert_eq!(neg, Literal::negative(1));
        assert_eq!(neg.negate(), lit);
    }

    #[test]
    fn test_equality_is_order_independent() {
        assert_eq!(Equality::new(4, 2), Equality::new(2, 4));
        assert_eq!(Equality::new(4, 2), Equality { lhs: 2, rhs: 4 });
    }

    #[test]
    fn test_resolver_creation() {
        let resolver = ConflictResolver::new();
        assert_eq!(resolver.stats().conflicts_analyzed, 0);
        assert_eq!(resolver.current_level, 0);
        assert!(resolver.learned_clauses().is_empty());
    }

    #[test]
    fn test_add_assignment() {
        let mut resolver = ConflictResolver::new();
        let lit = Literal::positive(1);
        resolver.add_assignment(lit, 0, Explanation::Given);
        assert_eq!(resolver.trail.len(), 1);
        assert_eq!(resolver.get_decision_level(lit), Some(0));
        assert_eq!(resolver.get_decision_level(lit.negate()), None);
    }

    #[test]
    fn test_decision_level() {
        let mut resolver = ConflictResolver::new();
        resolver.push_decision_level();
        assert_eq!(resolver.current_level, 1);
    }

    #[test]
    fn test_backtrack() {
        let mut resolver = ConflictResolver::new();
        resolver.push_decision_level();
        resolver.add_assignment(Literal::positive(1), 1, Explanation::Given);
        resolver.backtrack(0).expect("Backtrack failed");
        assert_eq!(resolver.trail.len(), 0);
        assert_eq!(resolver.current_level, 0);
        assert_eq!(resolver.get_decision_level(Literal::positive(1)), None);
    }

    #[test]
    fn test_backtrack_to_future_level_is_rejected() {
        let mut resolver = ConflictResolver::new();
        assert!(resolver.backtrack(1).is_err());
        assert_eq!(resolver.current_level, 0);
    }

    #[test]
    fn test_conflict_analysis() {
        let mut resolver = ConflictResolver::new();
        let conflict = given_conflict(0, vec![Literal::positive(1), Literal::negative(2)]);
        let analysis = resolver
            .analyze_conflict(conflict)
            .expect("Conflict analysis failed");
        assert!(analysis.blamed_theories.contains(&0));
        assert!(analysis.clause.theories.contains(&0));
        assert_eq!(analysis.clause.backtrack_level, 0);
        assert_eq!(resolver.stats().conflicts_analyzed, 1);
        assert_eq!(resolver.stats().clauses_learned, 1);
        assert_eq!(resolver.learned_clauses().len(), 1);
        assert_eq!(resolver.theory_blame().get(&0), Some(&1));
    }

    #[test]
    fn test_explanation_generator() {
        let mut generator = ExplanationGenerator::new();
        let lit = Literal::positive(1);
        generator.add_explanation(lit, Explanation::Given);
        assert!(generator.get_explanation(lit).is_some());
        assert!(generator.get_explanation(lit.negate()).is_none());
    }

    #[test]
    fn test_multi_theory_analyzer() {
        let mut analyzer = MultiTheoryConflictAnalyzer::new();
        analyzer.register_theory(0, ConflictResolutionConfig::default());
        let result = analyzer.analyze(given_conflict(0, vec![Literal::positive(1)]));
        assert!(result.is_ok());
        assert_eq!(analyzer.combined_stats().conflicts_analyzed, 1);
        assert_eq!(
            analyzer
                .get_resolver(0)
                .map(|resolver| resolver.stats().conflicts_analyzed),
            Some(1)
        );
    }

    #[test]
    fn test_multi_theory_analyzer_unregistered_theory() {
        let mut analyzer = MultiTheoryConflictAnalyzer::new();
        assert!(analyzer.analyze(given_conflict(7, Vec::new())).is_err());
    }

    #[test]
    fn test_simple_minimization() {
        let resolver = ConflictResolver::new();
        let literals = [
            Literal::positive(1),
            Literal::positive(2),
            Literal::positive(1),
        ];
        let minimized = resolver
            .minimize_simple(&literals)
            .expect("Minimization failed");
        assert_eq!(minimized, vec![Literal::positive(1), Literal::positive(2)]);
    }

    #[test]
    fn test_backtrack_level_computation() {
        let mut resolver = ConflictResolver::new();
        resolver.add_assignment(Literal::positive(1), 0, Explanation::Given);
        resolver.push_decision_level();
        resolver.add_assignment(Literal::positive(2), 1, Explanation::Given);
        resolver.push_decision_level();
        resolver.add_assignment(Literal::positive(3), 2, Explanation::Given);
        let literals = [
            Literal::positive(1),
            Literal::positive(2),
            Literal::positive(3),
        ];
        // Second-highest distinct level among {0, 1, 2}.
        assert_eq!(resolver.compute_backtrack_level(&literals, 2), Ok(1));
    }
}