#![allow(missing_docs)]
#[allow(unused_imports)]
use crate::prelude::*;
use core::cmp::Ordering;
/// Plain integer identifier used in this module for literals as well as clause ids.
///
/// Literal ids come in adjacent even/odd pairs: `PropagationPipeline::negate` maps an
/// even id to the next odd id and an odd id to the previous even id.
pub type TermId = usize;
/// Classifies which reasoning layer produced a [`Propagation`].
///
/// The pipeline uses this tag to decide which statistics counter to bump.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PropagationLevel {
    /// Tag set by `PropagationPipeline::mk_unit_propagation`.
    Boolean,
    /// Tag set by `PropagationPipeline::mk_equality_propagation`.
    Equality,
    /// Tag set by `PropagationPipeline::mk_theory_propagation`; carries the theory it
    /// is attributed to.
    Theory(TheoryId),
}
/// Names the theory a theory-level propagation is attributed to.
///
/// Used as the key of the per-theory counters in `PropagationStats` and of the
/// per-theory priorities in `PropagationConfig`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TheoryId {
    /// Default priority 10 in `PropagationConfig::default`.
    Arithmetic,
    /// Default priority 9 in `PropagationConfig::default`.
    BitVector,
    /// Default priority 8 in `PropagationConfig::default`.
    Array,
    /// Default priority 7 in `PropagationConfig::default`.
    Datatype,
    /// Default priority 6 in `PropagationConfig::default`.
    String,
    /// Default priority 5 in `PropagationConfig::default`.
    Uninterpreted,
}
/// One queued or applied propagation: a literal plus the bookkeeping that explains it.
///
/// Equality and ordering are implemented by hand elsewhere in this file rather than
/// derived.
#[derive(Clone, Debug)]
pub struct Propagation {
    /// The literal being asserted.
    pub literal: TermId,
    /// Which reasoning layer produced the propagation; selects the statistics counter.
    pub level: PropagationLevel,
    /// Justification that `PropagationPipeline::get_reason` hands back.
    pub reason: PropagationReason,
    /// Scheduling weight compared by the `Ord` impl used for the pending heap.
    pub priority: u32,
}
/// Justification recorded alongside a propagated literal.
#[derive(Clone, Debug)]
pub enum PropagationReason {
    /// A single id; `PropagationPipeline::mk_unit_propagation` stores the clause id here.
    UnitClause(TermId),
    /// A pair of ids.
    BinaryClause(TermId, TermId),
    /// An arbitrary-length list of ids.
    Clause(Vec<TermId>),
    /// The two sides passed to `PropagationPipeline::mk_equality_propagation`.
    Equality { lhs: TermId, rhs: TermId },
    /// The explanation passed to `PropagationPipeline::mk_theory_propagation`.
    Theory { explanation: Vec<TermId> },
}
/// Total order used by the pending heap: `priority` first, then `literal`.
///
/// The literal tie-break keeps the order consistent with `PartialEq`, which treats two
/// propagations as equal only when both `literal` and `priority` match. Comparing the
/// priority alone would report `Ordering::Equal` for values that are not `==`, which
/// the `Ord` contract forbids. It also makes the order among equal-priority entries
/// deterministic.
impl Ord for Propagation {
    fn cmp(&self, other: &Self) -> Ordering {
        self.priority
            .cmp(&other.priority)
            .then_with(|| self.literal.cmp(&other.literal))
    }
}
/// Partial order that always defers to the total order defined by `Ord`.
impl PartialOrd for Propagation {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegating keeps `<`, `>` and friends in lock-step with `cmp`.
        Some(Ord::cmp(self, other))
    }
}
/// Two propagations are equal when they assert the same literal with the same priority.
///
/// `level` and `reason` do not take part in the comparison.
impl PartialEq for Propagation {
    fn eq(&self, other: &Self) -> bool {
        (self.literal, self.priority) == (other.literal, other.priority)
    }
}
// Marker impl: the hand-written `eq` compares two integers, so it is a full equivalence
// relation. `Ord` has `Eq` as a supertrait, so this impl is required for the heap.
impl Eq for Propagation {}
/// Counters maintained by `PropagationPipeline`; all start at zero / empty.
#[derive(Clone, Debug, Default)]
pub struct PropagationStats {
    /// Incremented once per propagation the pipeline processes.
    pub total_propagations: u64,
    /// Propagations tagged `PropagationLevel::Boolean`.
    pub boolean_propagations: u64,
    /// Propagations tagged `PropagationLevel::Equality`.
    pub equality_propagations: u64,
    /// Propagations tagged `PropagationLevel::Theory`, counted per theory.
    pub theory_propagations: FxHashMap<TheoryId, u64>,
    /// Number of conflicts reported while propagating.
    pub conflicts_detected: u64,
    /// Largest pending-queue length observed right after an insertion.
    pub priority_queue_size_max: usize,
}
/// Settings handed to `PropagationPipeline::new`.
#[derive(Clone, Debug)]
pub struct PropagationConfig {
    /// Stored with the configuration; not read by any code in this file.
    pub use_priority_queue: bool,
    /// Upper bound on the number of propagations one `propagate` call performs.
    pub max_propagations_per_iter: usize,
    /// Stored with the configuration; not read by any code in this file.
    pub enable_lazy_propagation: bool,
    /// Priority given to theory propagations, looked up by theory.
    pub theory_priorities: FxHashMap<TheoryId, u32>,
}
impl Default for PropagationConfig {
fn default() -> Self {
let mut priorities = FxHashMap::default();
priorities.insert(TheoryId::Arithmetic, 10);
priorities.insert(TheoryId::BitVector, 9);
priorities.insert(TheoryId::Array, 8);
priorities.insert(TheoryId::Datatype, 7);
priorities.insert(TheoryId::String, 6);
priorities.insert(TheoryId::Uninterpreted, 5);
Self {
use_priority_queue: true,
max_propagations_per_iter: 1000,
enable_lazy_propagation: true,
theory_priorities: priorities,
}
}
}
/// Priority-ordered propagation queue with an assignment trail and decision levels.
///
/// Propagations are queued with [`add_propagation`](Self::add_propagation) and applied
/// by [`propagate`](Self::propagate). Applying a literal whose negation is already on
/// the trail is reported as a conflict.
pub struct PropagationPipeline {
    /// Settings supplied at construction time.
    config: PropagationConfig,
    /// Counters updated while propagations are queued and applied.
    stats: PropagationStats,
    /// Propagations waiting to be applied, ordered by `Propagation`'s `Ord` impl.
    pending: BinaryHeap<Propagation>,
    /// Literals currently on the trail, kept as a set for membership tests.
    propagated: FxHashSet<TermId>,
    /// Current decision level; always equal to `level_starts.len()`.
    current_level: usize,
    /// Applied propagations, oldest first.
    trail: Vec<Propagation>,
    /// `level_starts[i]` is the trail length at the moment level `i + 1` was opened,
    /// i.e. the number of trail entries that belong to levels `0..=i`.
    level_starts: Vec<usize>,
}

impl PropagationPipeline {
    /// Creates an empty pipeline at decision level 0.
    pub fn new(config: PropagationConfig) -> Self {
        Self {
            config,
            stats: PropagationStats::default(),
            pending: BinaryHeap::new(),
            propagated: FxHashSet::default(),
            current_level: 0,
            trail: Vec::new(),
            level_starts: Vec::new(),
        }
    }

    /// Queues `propagation` unless its literal is already on the trail.
    ///
    /// Also records the largest queue length seen so far in the statistics.
    pub fn add_propagation(&mut self, propagation: Propagation) {
        if self.propagated.contains(&propagation.literal) {
            return;
        }
        self.pending.push(propagation);
        self.stats.priority_queue_size_max =
            self.stats.priority_queue_size_max.max(self.pending.len());
    }

    /// Applies queued propagations in heap order.
    ///
    /// Stops when the queue is empty or after `max_propagations_per_iter` literals have
    /// been applied; in the latter case the rest stays queued and `Ok(())` is still
    /// returned. Entries whose literal is already on the trail are dropped without
    /// counting towards the cap.
    ///
    /// # Errors
    ///
    /// Returns the conflict `[literal, negation]` as soon as a popped literal's
    /// negation is found on the trail. The offending propagation is discarded; anything
    /// still queued is left in place.
    pub fn propagate(&mut self) -> Result<(), Vec<TermId>> {
        let mut performed = 0;
        // Test the cap before popping so nothing has to be pushed back.
        while performed < self.config.max_propagations_per_iter {
            let Some(propagation) = self.pending.pop() else {
                break;
            };
            if self.propagated.contains(&propagation.literal) {
                continue;
            }
            self.perform_propagation(propagation)?;
            performed += 1;
        }
        Ok(())
    }

    /// Updates the statistics for `propagation`, then either reports a conflict or
    /// appends the propagation to the trail.
    fn perform_propagation(&mut self, propagation: Propagation) -> Result<(), Vec<TermId>> {
        // Counters are bumped before the conflict check, so a conflicting propagation
        // is still included in the totals.
        self.stats.total_propagations += 1;
        match propagation.level {
            PropagationLevel::Boolean => self.stats.boolean_propagations += 1,
            PropagationLevel::Equality => self.stats.equality_propagations += 1,
            PropagationLevel::Theory(theory_id) => {
                *self.stats.theory_propagations.entry(theory_id).or_insert(0) += 1;
            }
        }

        if let Some(conflict) = self.check_conflict(propagation.literal) {
            self.stats.conflicts_detected += 1;
            return Err(conflict);
        }

        self.propagated.insert(propagation.literal);
        self.trail.push(propagation);
        Ok(())
    }

    /// Returns `[literal, negation]` when the negation of `literal` is on the trail.
    fn check_conflict(&self, literal: TermId) -> Option<Vec<TermId>> {
        let negated = self.negate(literal);
        self.propagated
            .contains(&negated)
            .then(|| vec![literal, negated])
    }

    /// Maps a literal id to its complement: even ids pair with the next odd id.
    fn negate(&self, literal: TermId) -> TermId {
        // Flipping the lowest bit is `+ 1` for even ids and `- 1` for odd ids.
        literal ^ 1
    }

    /// Undoes every propagation made above `target_level` and drops all queued work.
    ///
    /// Backtracking to level 0 clears the whole trail, including literals applied
    /// before the first `increment_level` call.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving the pipeline untouched, when `target_level` is above
    /// the current level.
    pub fn backtrack(&mut self, target_level: usize) -> Result<(), String> {
        if target_level > self.current_level {
            return Err(format!(
                "cannot backtrack to level {} from lower level {}",
                target_level, self.current_level
            ));
        }

        let keep = self
            .get_trail_size_at_level(target_level)
            .min(self.trail.len());
        for undone in self.trail.drain(keep..) {
            self.propagated.remove(&undone.literal);
        }
        self.pending.clear();
        // Forget the marks of the levels that were just closed.
        self.level_starts.truncate(target_level);
        self.current_level = target_level;
        Ok(())
    }

    /// Number of trail entries that survive a backtrack to `level`.
    fn get_trail_size_at_level(&self, level: usize) -> usize {
        if level == 0 {
            return 0;
        }
        // No mark exists for the level that is currently open: keep the whole trail.
        self.level_starts
            .get(level)
            .copied()
            .unwrap_or(self.trail.len())
    }

    /// Opens a new decision level, remembering where it starts on the trail.
    pub fn increment_level(&mut self) {
        self.level_starts.push(self.trail.len());
        self.current_level += 1;
    }

    /// Current decision level.
    pub fn current_level(&self) -> usize {
        self.current_level
    }

    /// Whether `literal` is currently on the trail.
    pub fn is_propagated(&self, literal: TermId) -> bool {
        self.propagated.contains(&literal)
    }

    /// Reason recorded for `literal`, or `None` when it is not on the trail.
    ///
    /// This scans the trail from the oldest entry.
    pub fn get_reason(&self, literal: TermId) -> Option<&PropagationReason> {
        self.trail
            .iter()
            .find(|p| p.literal == literal)
            .map(|p| &p.reason)
    }

    /// Builds a Boolean-level propagation of `literal` justified by unit clause `clause`.
    ///
    /// The priority is fixed at 100, above equality (50) and the default theory
    /// priorities (at most 10).
    pub fn mk_unit_propagation(&self, literal: TermId, clause: TermId) -> Propagation {
        Propagation {
            literal,
            level: PropagationLevel::Boolean,
            reason: PropagationReason::UnitClause(clause),
            priority: 100,
        }
    }

    /// Builds an equality-level propagation of `literal` justified by `lhs = rhs`,
    /// with fixed priority 50.
    pub fn mk_equality_propagation(
        &self,
        literal: TermId,
        lhs: TermId,
        rhs: TermId,
    ) -> Propagation {
        Propagation {
            literal,
            level: PropagationLevel::Equality,
            reason: PropagationReason::Equality { lhs, rhs },
            priority: 50,
        }
    }

    /// Builds a theory-level propagation of `literal`.
    ///
    /// The priority comes from `config.theory_priorities`, falling back to 10 when the
    /// theory has no entry.
    pub fn mk_theory_propagation(
        &self,
        literal: TermId,
        theory: TheoryId,
        explanation: Vec<TermId>,
    ) -> Propagation {
        let priority = self
            .config
            .theory_priorities
            .get(&theory)
            .copied()
            .unwrap_or(10);
        Propagation {
            literal,
            level: PropagationLevel::Theory(theory),
            reason: PropagationReason::Theory { explanation },
            priority,
        }
    }

    /// Read-only access to the counters.
    pub fn stats(&self) -> &PropagationStats {
        &self.stats
    }

    /// Clears the queue, the trail and the decision levels. Statistics are kept.
    pub fn reset(&mut self) {
        self.pending.clear();
        self.propagated.clear();
        self.trail.clear();
        self.level_starts.clear();
        self.current_level = 0;
    }
}
/// Watch-list index over clauses, used to find clauses affected by an assignment.
pub struct PropagationWatcher {
    /// For each literal, the ids of the clauses currently registered as watching it.
    watch_lists: FxHashMap<TermId, Vec<TermId>>,
    /// Literals of every registered clause, keyed by clause id.
    clauses: FxHashMap<TermId, Vec<TermId>>,
}
impl PropagationWatcher {
    /// Creates a watcher with no clauses.
    pub fn new() -> Self {
        Self {
            watch_lists: FxHashMap::default(),
            clauses: FxHashMap::default(),
        }
    }

    /// Registers a clause and watches its first two literals.
    ///
    /// # Errors
    ///
    /// Returns an error when `literals` has fewer than two entries; nothing is stored
    /// in that case.
    pub fn add_clause(&mut self, clause_id: TermId, literals: Vec<TermId>) -> Result<(), String> {
        if literals.len() < 2 {
            return Err("Clause must have at least 2 literals for watching".to_string());
        }
        for &watched in &literals[..2] {
            self.watch_lists.entry(watched).or_default().push(clause_id);
        }
        self.clauses.insert(clause_id, literals);
        Ok(())
    }

    /// Revisits every clause that watches `assigned_literal`.
    ///
    /// For each such clause, in order of preference:
    /// 1. move the watch to another literal that is not false and does not already
    ///    watch the clause;
    /// 2. otherwise, if some literal of the clause is true, leave it alone;
    /// 3. otherwise, if exactly one literal is unassigned, queue it on `pipeline` as a
    ///    unit propagation;
    /// 4. otherwise, if every literal is false, report the clause as a conflict.
    ///
    /// A clause with two or more unassigned literals and no replacement watch is left
    /// untouched.
    ///
    /// NOTE(review): this presumably expects `assigned_literal` to be a watched literal
    /// that has just become false, with that assignment already visible through
    /// `pipeline.is_propagated` — confirm against the callers.
    ///
    /// # Errors
    ///
    /// Returns the literals of the first clause found with all literals false.
    pub fn update_watches(
        &mut self,
        assigned_literal: TermId,
        pipeline: &mut PropagationPipeline,
    ) -> Result<(), Vec<TermId>> {
        // Snapshot the ids: the watch lists are edited while they are being walked.
        let clause_ids: Vec<TermId> = self
            .watch_lists
            .get(&assigned_literal)
            .cloned()
            .unwrap_or_default();

        for clause_id in clause_ids {
            let Some(clause) = self.clauses.get(&clause_id) else {
                continue;
            };

            if let Some(new_watch) =
                self.find_new_watch(clause_id, clause, assigned_literal, pipeline)
            {
                if let Some(watch_list) = self.watch_lists.get_mut(&assigned_literal) {
                    watch_list.retain(|&id| id != clause_id);
                }
                self.watch_lists
                    .entry(new_watch)
                    .or_default()
                    .push(clause_id);
                continue;
            }

            // A true literal satisfies the clause: it is neither unit nor a conflict.
            if clause.iter().any(|&lit| pipeline.is_propagated(lit)) {
                continue;
            }

            if let Some(unit_literal) = self.find_unit_literal(clause, pipeline) {
                let prop = pipeline.mk_unit_propagation(unit_literal, clause_id);
                pipeline.add_propagation(prop);
                continue;
            }

            // Only a clause with every literal false is a conflict; one with several
            // unassigned literals is simply still open.
            if clause
                .iter()
                .all(|&lit| pipeline.is_propagated(pipeline.negate(lit)))
            {
                return Err(clause.clone());
            }
        }
        Ok(())
    }

    /// Picks a replacement watch for `old_watch` in clause `clause_id`.
    ///
    /// A candidate must differ from `old_watch`, must not be false, and must not
    /// already watch the clause. Without the last condition the clause's other watch
    /// could be chosen, which would register the clause twice under one literal and
    /// hide the fact that the clause has become unit.
    fn find_new_watch(
        &self,
        clause_id: TermId,
        clause: &[TermId],
        old_watch: TermId,
        pipeline: &PropagationPipeline,
    ) -> Option<TermId> {
        clause.iter().copied().find(|&lit| {
            lit != old_watch
                && !pipeline.is_propagated(pipeline.negate(lit))
                && !self.is_watching(lit, clause_id)
        })
    }

    /// Whether `clause_id` is registered in the watch list of `literal`.
    fn is_watching(&self, literal: TermId, clause_id: TermId) -> bool {
        self.watch_lists
            .get(&literal)
            .map_or(false, |ids| ids.contains(&clause_id))
    }

    /// Returns the clause's only unassigned literal, or `None` when the number of
    /// unassigned literals is not exactly one.
    fn find_unit_literal(
        &self,
        clause: &[TermId],
        pipeline: &PropagationPipeline,
    ) -> Option<TermId> {
        let mut unassigned = clause.iter().copied().filter(|&lit| {
            !pipeline.is_propagated(lit) && !pipeline.is_propagated(pipeline.negate(lit))
        });
        match (unassigned.next(), unassigned.next()) {
            (Some(only), None) => Some(only),
            _ => None,
        }
    }
}
impl Default for PropagationWatcher {
fn default() -> Self {
Self::new()
}
}
/// Unit tests for the propagation pipeline and the clause watcher.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pipeline starts at level 0 with zeroed counters.
    #[test]
    fn test_pipeline_creation() {
        let config = PropagationConfig::default();
        let pipeline = PropagationPipeline::new(config);
        assert_eq!(pipeline.current_level(), 0);
        assert_eq!(pipeline.stats.total_propagations, 0);
    }

    /// Queuing a propagation places it on the pending heap.
    #[test]
    fn test_add_propagation() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        let prop = Propagation {
            literal: 1,
            level: PropagationLevel::Boolean,
            reason: PropagationReason::UnitClause(42),
            priority: 100,
        };
        pipeline.add_propagation(prop);
        assert_eq!(pipeline.pending.len(), 1);
    }

    /// A unit propagation is applied and counted as a Boolean propagation.
    #[test]
    fn test_unit_propagation() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        let prop = pipeline.mk_unit_propagation(1, 42);
        pipeline.add_propagation(prop);
        let result = pipeline.propagate();
        assert!(result.is_ok());
        assert_eq!(pipeline.stats.boolean_propagations, 1);
        assert!(pipeline.is_propagated(1));
    }

    /// Propagating literal 3 after its complement 2 is reported as a conflict.
    #[test]
    fn test_conflict_detection() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        let prop1 = pipeline.mk_unit_propagation(2, 42);
        pipeline.add_propagation(prop1);
        // The first literal has nothing to clash with.
        assert!(pipeline.propagate().is_ok());
        let prop2 = pipeline.mk_unit_propagation(3, 43);
        pipeline.add_propagation(prop2);
        let result = pipeline.propagate();
        assert_eq!(result, Err(vec![3, 2]));
        assert_eq!(pipeline.stats.conflicts_detected, 1);
    }

    /// An equality propagation is applied and counted as such.
    #[test]
    fn test_equality_propagation() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        let prop = pipeline.mk_equality_propagation(10, 5, 5);
        pipeline.add_propagation(prop);
        let result = pipeline.propagate();
        assert!(result.is_ok());
        assert_eq!(pipeline.stats.equality_propagations, 1);
    }

    /// A theory propagation bumps the counter of its theory.
    #[test]
    fn test_theory_propagation() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        let prop = pipeline.mk_theory_propagation(20, TheoryId::Arithmetic, vec![1, 2, 3]);
        pipeline.add_propagation(prop);
        let result = pipeline.propagate();
        assert!(result.is_ok());
        assert_eq!(
            *pipeline
                .stats
                .theory_propagations
                .get(&TheoryId::Arithmetic)
                .expect("test operation should succeed"),
            1
        );
    }

    /// The higher-priority propagation is applied first even though it was queued last.
    #[test]
    fn test_priority_ordering() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        let low_priority = Propagation {
            literal: 1,
            level: PropagationLevel::Boolean,
            reason: PropagationReason::UnitClause(1),
            priority: 10,
        };
        let high_priority = Propagation {
            literal: 2,
            level: PropagationLevel::Boolean,
            reason: PropagationReason::UnitClause(2),
            priority: 100,
        };
        pipeline.add_propagation(low_priority);
        pipeline.add_propagation(high_priority);
        assert!(pipeline.propagate().is_ok());
        assert!(pipeline.is_propagated(2));
        assert!(pipeline.is_propagated(1));
        // Checking membership alone would pass in any order; the trail records the
        // order in which the literals were actually applied.
        let applied: Vec<TermId> = pipeline.trail.iter().map(|p| p.literal).collect();
        assert_eq!(applied, vec![2, 1]);
    }

    /// Backtracking to level 0 undoes a propagation made at level 1.
    #[test]
    fn test_backtrack() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        pipeline.increment_level();
        let prop = pipeline.mk_unit_propagation(5, 50);
        pipeline.add_propagation(prop);
        let _ = pipeline.propagate();
        assert!(pipeline.is_propagated(5));
        pipeline
            .backtrack(0)
            .expect("test operation should succeed");
        assert!(!pipeline.is_propagated(5));
        assert_eq!(pipeline.current_level(), 0);
    }

    /// A fresh watcher has no watch lists.
    #[test]
    fn test_watcher_creation() {
        let watcher = PropagationWatcher::new();
        assert_eq!(watcher.watch_lists.len(), 0);
    }

    /// Adding a clause watches its first two literals.
    #[test]
    fn test_add_clause_to_watcher() {
        let mut watcher = PropagationWatcher::new();
        let result = watcher.add_clause(1, vec![10, 20, 30]);
        assert!(result.is_ok());
        assert!(watcher.watch_lists.contains_key(&10));
        assert!(watcher.watch_lists.contains_key(&20));
    }

    /// Clauses with fewer than two literals are rejected.
    #[test]
    fn test_add_short_clause_fails() {
        let mut watcher = PropagationWatcher::new();
        let result = watcher.add_clause(1, vec![10]);
        assert!(result.is_err());
    }

    /// `reset` forgets applied literals and returns to level 0.
    #[test]
    fn test_reset() {
        let config = PropagationConfig::default();
        let mut pipeline = PropagationPipeline::new(config);
        let prop = pipeline.mk_unit_propagation(1, 42);
        pipeline.add_propagation(prop);
        let _ = pipeline.propagate();
        pipeline.reset();
        assert!(!pipeline.is_propagated(1));
        assert_eq!(pipeline.current_level(), 0);
    }
}