use std::collections::HashMap;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use crate::error::{SatError, SatResult};
/// A propositional literal: a variable identifier plus a polarity flag.
///
/// The DIMACS helpers interpret `variable` as the DIMACS variable number,
/// encoding a negated literal as the negative of that number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Literal {
    /// Identifier of the variable this literal refers to.
    pub variable: u32,
    /// `true` when the literal is the negation of `variable`.
    pub negated: bool,
}

impl Literal {
    /// Creates the non-negated literal of `variable`.
    pub fn positive(variable: u32) -> Self {
        Self {
            variable,
            negated: false,
        }
    }

    /// Creates the negated literal of `variable`.
    pub fn negative(variable: u32) -> Self {
        Self {
            variable,
            negated: true,
        }
    }

    /// Returns the literal over the same variable with the opposite polarity.
    pub fn negate(self) -> Self {
        Self {
            variable: self.variable,
            negated: !self.negated,
        }
    }

    /// Encodes the literal as a signed DIMACS integer.
    ///
    /// The result is exact for variables up to `i32::MAX`, and for the
    /// negated literal of variable `2^31` (which maps to `i32::MIN`).
    /// Any other variable does not fit in an `i32`; the value then wraps
    /// silently, as the `as` cast always did.
    pub fn to_dimacs(self) -> i32 {
        let magnitude = self.variable as i32;
        if self.negated {
            // `wrapping_neg` keeps the `i32::MIN` edge case from panicking
            // under overflow checks; for every other input it is plain `-x`.
            magnitude.wrapping_neg()
        } else {
            magnitude
        }
    }

    /// Decodes a signed DIMACS integer into a literal.
    ///
    /// Positive input yields a positive literal, anything else a negated one.
    /// `0` therefore decodes to the negated literal of variable 0; DIMACS
    /// uses 0 as a clause terminator, so callers should not pass it.
    ///
    /// `unsigned_abs` is used instead of `-dimacs` so that `i32::MIN`
    /// decodes to variable `2^31` rather than overflowing.
    pub fn from_dimacs(dimacs: i32) -> Self {
        Self {
            variable: dimacs.unsigned_abs(),
            negated: dimacs <= 0,
        }
    }
}
/// A clause: an ordered list of literals.
#[derive(Debug, Clone)]
pub struct Clause {
    /// The literals of the clause, stored exactly as supplied.
    pub literals: Vec<Literal>,
}

impl Clause {
    /// Wraps `literals` as a clause without reordering or deduplicating.
    pub fn new(literals: Vec<Literal>) -> Self {
        Clause { literals }
    }

    /// Builds a clause holding the single `literal`.
    pub fn unit(literal: Literal) -> Self {
        Clause {
            literals: vec![literal],
        }
    }

    /// Builds a two-literal clause, keeping `lit1` before `lit2`.
    pub fn binary(lit1: Literal, lit2: Literal) -> Self {
        Clause {
            literals: vec![lit1, lit2],
        }
    }

    /// Returns the DIMACS encoding of every literal, in clause order.
    pub fn to_dimacs(&self) -> Vec<i32> {
        let mut encoded = Vec::with_capacity(self.literals.len());
        for literal in &self.literals {
            encoded.push(literal.to_dimacs());
        }
        encoded
    }

    /// Returns `true` when the clause has no literals.
    pub fn is_empty(&self) -> bool {
        self.literals.first().is_none()
    }

    /// Returns `true` when the clause has exactly one literal.
    pub fn is_unit(&self) -> bool {
        matches!(self.literals.as_slice(), [_])
    }
}
/// A SAT instance: a declared variable count, a clause list, optional
/// assumptions and descriptive metadata.
#[derive(Debug, Clone)]
pub struct SatProblem {
/// Declared number of variables. The solvers in this file enumerate
/// variables as `1..=num_variables` when building a model.
pub num_variables: u32,
/// Clauses of the instance, in insertion order.
pub clauses: Vec<Clause>,
/// Literals assumed for a solve. `BasicSolver` treats each one as an extra
/// unit clause; `MockSolver::solve` does not read this field.
pub assumptions: Vec<Literal>,
/// Human-readable annotations; not consulted by the solvers in this file.
pub metadata: ProblemMetadata,
}
/// Descriptive annotations attached to a `SatProblem`.
#[derive(Debug, Clone, Default)]
pub struct ProblemMetadata {
/// Free-form description of the problem.
pub description: String,
/// Display names keyed by variable identifier (written by
/// `SatProblem::set_variable_name`).
pub variable_names: HashMap<u32, String>,
/// Origin notes keyed by a `usize` — presumably the clause's index in
/// `SatProblem::clauses`; TODO confirm, nothing in this file writes it.
pub clause_origins: HashMap<usize, String>,
}
impl SatProblem {
pub fn new(num_variables: u32) -> Self {
Self {
num_variables,
clauses: Vec::new(),
assumptions: Vec::new(),
metadata: ProblemMetadata::default(),
}
}
pub fn add_clause(&mut self, clause: Clause) {
self.clauses.push(clause);
}
pub fn add_clauses(&mut self, clauses: Vec<Clause>) {
self.clauses.extend(clauses);
}
pub fn add_assumption(&mut self, assumption: Literal) {
self.assumptions.push(assumption);
}
pub fn set_variable_name(&mut self, variable: u32, name: String) {
self.metadata.variable_names.insert(variable, name);
}
pub fn get_variable_name(&self, variable: u32) -> Option<&str> {
self.metadata
.variable_names
.get(&variable)
.map(|s| s.as_str())
}
pub fn get_clauses(&self) -> &Vec<Clause> {
&self.clauses
}
pub fn stats(&self) -> ProblemStats {
ProblemStats {
variables: self.num_variables,
clauses: self.clauses.len(),
literals: self.clauses.iter().map(|c| c.literals.len()).sum(),
assumptions: self.assumptions.len(),
}
}
}
/// Size summary of a `SatProblem`, as produced by `SatProblem::stats`.
#[derive(Debug, Clone)]
pub struct ProblemStats {
/// The problem's declared `num_variables`.
pub variables: u32,
/// Number of clauses.
pub clauses: usize,
/// Total number of literal occurrences summed over all clauses.
pub literals: usize,
/// Number of assumption literals.
pub assumptions: usize,
}
/// Outcome of a solve call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SatSolution {
    /// A satisfying assignment was found; the model carries it.
    Satisfiable(SatModel),
    /// The solver determined that no satisfying assignment exists.
    Unsatisfiable,
    /// The solver produced neither a model nor an unsatisfiability verdict.
    Unknown,
}

impl SatSolution {
    /// Returns `true` for the `Satisfiable` variant.
    pub fn is_satisfiable(&self) -> bool {
        match self {
            SatSolution::Satisfiable(_) => true,
            SatSolution::Unsatisfiable | SatSolution::Unknown => false,
        }
    }

    /// Returns `true` for the `Unsatisfiable` variant.
    pub fn is_unsatisfiable(&self) -> bool {
        match self {
            SatSolution::Unsatisfiable => true,
            SatSolution::Satisfiable(_) | SatSolution::Unknown => false,
        }
    }

    /// Returns `true` for the `Unknown` variant.
    pub fn is_unknown(&self) -> bool {
        match self {
            SatSolution::Unknown => true,
            SatSolution::Satisfiable(_) | SatSolution::Unsatisfiable => false,
        }
    }

    /// Borrows the model when the solution is `Satisfiable`, else `None`.
    pub fn model(&self) -> Option<&SatModel> {
        if let SatSolution::Satisfiable(found) = self {
            Some(found)
        } else {
            None
        }
    }
}
/// A (possibly partial) truth assignment keyed by variable identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SatModel {
    /// Value of each assigned variable; absent keys are unassigned.
    pub assignments: HashMap<u32, bool>,
}

impl SatModel {
    /// Creates a model with no assignments.
    pub fn new() -> Self {
        SatModel {
            assignments: HashMap::new(),
        }
    }

    /// Sets `variable` to `value`, replacing any earlier value.
    pub fn assign(&mut self, variable: u32, value: bool) {
        self.assignments.insert(variable, value);
    }

    /// Returns the value of `variable`, or `None` when it is unassigned.
    pub fn get(&self, variable: u32) -> Option<bool> {
        match self.assignments.get(&variable) {
            Some(&value) => Some(value),
            None => None,
        }
    }

    /// Returns `true` when `variable` has a value.
    pub fn is_assigned(&self, variable: u32) -> bool {
        self.assignments.get(&variable).is_some()
    }

    /// Collects the variables assigned `true`, in the map's (unspecified)
    /// iteration order.
    pub fn true_variables(&self) -> Vec<u32> {
        self.assignments
            .iter()
            .filter_map(|(&variable, &value)| value.then_some(variable))
            .collect()
    }

    /// Collects the variables assigned `false`, in the map's (unspecified)
    /// iteration order.
    pub fn false_variables(&self) -> Vec<u32> {
        self.assignments
            .iter()
            .filter_map(|(&variable, &value)| (!value).then_some(variable))
            .collect()
    }

    /// Builds a model from a dense vector: element `i` becomes the value of
    /// variable `i + 1`.
    pub fn from_vec(assignments: Vec<bool>) -> Self {
        let by_variable = assignments
            .iter()
            .enumerate()
            .map(|(index, &value)| ((index + 1) as u32, value))
            .collect();
        SatModel {
            assignments: by_variable,
        }
    }

    /// Produces the dense values of variables `1..=max_variable`, reporting
    /// unassigned variables as `false`.
    pub fn to_vec(&self, max_variable: u32) -> Vec<bool> {
        let mut dense = Vec::new();
        for variable in 1..=max_variable {
            dense.push(matches!(self.assignments.get(&variable), Some(true)));
        }
        dense
    }
}

impl Default for SatModel {
    /// Same as [`SatModel::new`]: an empty model.
    fn default() -> Self {
        Self::new()
    }
}
/// Counters a solver exposes through `SatSolver::get_stats`.
///
/// Within this file only `solve_time` (both solvers) and `decisions`
/// (`MockSolver`) are ever written; the remaining fields keep their
/// `Default` value of zero.
#[derive(Debug, Clone, Default)]
pub struct SolverStats {
/// Wall-clock duration of the most recent solve call.
pub solve_time: Duration,
/// Number of decisions made.
pub decisions: u64,
/// Number of conflicts encountered.
pub conflicts: u64,
/// Number of restarts performed.
pub restarts: u64,
/// Memory usage — unit not established in this file (presumably bytes;
/// TODO confirm against the backend that fills it in).
pub memory_usage: u64,
/// Number of learned clauses.
pub learned_clauses: u64,
/// Number of propagations performed.
pub propagations: u64,
}
/// Common asynchronous interface implemented by every solver backend.
///
/// Implementors must be `Send + Sync`; the factory functions hand solvers
/// out as `Box<dyn SatSolver>`.
#[async_trait]
pub trait SatSolver: Send + Sync {
/// Solves `problem` and reports the outcome.
async fn solve(&mut self, problem: &SatProblem) -> SatResult<SatSolution>;
/// Solves `problem` with a caller-supplied `timeout`. The implementations
/// in this file map an elapsed timeout to `SatError::timeout(timeout)`.
async fn solve_with_timeout(
&mut self,
problem: &SatProblem,
timeout: Duration,
) -> SatResult<SatSolution>;
/// Hands additional clauses to the solver. Support varies: in this file
/// `BasicSolver` returns an error and `MockSolver` accepts and discards them.
async fn add_clauses(&mut self, clauses: &[Clause]) -> SatResult<()>;
/// Solves under the given assumption literals. Note that no problem is
/// passed in; in this file `BasicSolver` solves the assumptions alone as
/// unit clauses and `MockSolver` returns an error.
async fn solve_under_assumptions(&mut self, assumptions: &[Literal]) -> SatResult<SatSolution>;
/// Returns a copy of the solver's statistics.
fn get_stats(&self) -> SolverStats;
/// Resets the solver; the implementations in this file clear the statistics.
fn reset(&mut self);
/// Static, human-readable name of the backend.
fn name(&self) -> &'static str;
}
/// Tunable limits and options handed to a solver at construction time.
///
/// The solvers visible in this file store or ignore these values; how the
/// Varisat backend interprets them is not shown here.
#[derive(Debug, Clone)]
pub struct SolverConfig {
    /// Overall time budget for a solve.
    pub timeout: Duration,
    /// Whether preprocessing is enabled.
    pub preprocessing: bool,
    /// Optional seed for randomised decisions.
    pub random_seed: Option<u64>,
    /// Optional memory cap (unit presumably bytes — TODO confirm).
    pub memory_limit: Option<u64>,
    /// Optional cap on the number of conflicts.
    pub conflict_limit: Option<u64>,
    /// Optional cap on the number of decisions.
    pub decision_limit: Option<u64>,
}

impl Default for SolverConfig {
    /// Defaults: 30 s timeout, preprocessing on, no seed, a memory limit of
    /// 2^30, and no conflict or decision limit.
    fn default() -> Self {
        const DEFAULT_TIMEOUT_SECS: u64 = 30;
        const DEFAULT_MEMORY_LIMIT: u64 = 1 << 30;

        SolverConfig {
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            preprocessing: true,
            random_seed: None,
            memory_limit: Some(DEFAULT_MEMORY_LIMIT),
            conflict_limit: None,
            decision_limit: None,
        }
    }
}
#[cfg(feature = "varisat-solver")]
mod varisat_solver;
#[cfg(feature = "varisat-solver")]
pub use varisat_solver::VarisatSolver;
/// Selects which solver implementation the factory should build.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolverBackend {
    /// The built-in heuristic mock solver.
    Mock,
    /// The Varisat backend (requires the `varisat-solver` feature).
    Varisat,
    /// The Cadical backend (gated on the `cadical-solver` feature).
    Cadical,
}

impl Default for SolverBackend {
    /// Picks the backend from the enabled Cargo features: Varisat wins over
    /// Cadical, and `Mock` is used when neither feature is enabled.
    fn default() -> Self {
        let varisat_enabled = cfg!(feature = "varisat-solver");
        let cadical_enabled = cfg!(feature = "cadical-solver");
        match (varisat_enabled, cadical_enabled) {
            (true, _) => SolverBackend::Varisat,
            (false, true) => SolverBackend::Cadical,
            (false, false) => SolverBackend::Mock,
        }
    }
}
/// Builds a solver for the default backend (see `SolverBackend::default`).
///
/// # Errors
/// Propagates whatever `create_solver_with_backend` returns for that backend.
pub async fn create_solver(config: SolverConfig) -> SatResult<Box<dyn SatSolver>> {
    let backend = SolverBackend::default();
    create_solver_with_backend(backend, config).await
}
/// Builds a boxed solver for an explicitly chosen `backend`.
///
/// # Errors
/// * `Varisat`: a solver failure when the `varisat-solver` feature is off,
///   otherwise any error from `VarisatSolver::new`.
/// * `Cadical`: always a solver failure — the backend is either not compiled
///   in or not yet implemented.
pub async fn create_solver_with_backend(
    backend: SolverBackend,
    config: SolverConfig,
) -> SatResult<Box<dyn SatSolver>> {
    match backend {
        SolverBackend::Mock => {
            let solver: Box<dyn SatSolver> = Box::new(MockSolver::new(config));
            Ok(solver)
        }
        SolverBackend::Varisat => {
            // `VarisatSolver` only exists when the feature is enabled, so the
            // two alternatives must be selected at compile time.
            #[cfg(feature = "varisat-solver")]
            {
                Ok(Box::new(VarisatSolver::new(config)?))
            }
            #[cfg(not(feature = "varisat-solver"))]
            {
                let reason = "Varisat solver not available (feature not enabled)";
                Err(SatError::solver_failure(reason.to_string()))
            }
        }
        SolverBackend::Cadical => {
            // Both outcomes are errors; only the message depends on the feature.
            let reason = if cfg!(feature = "cadical-solver") {
                "Cadical solver not yet implemented"
            } else {
                "Cadical solver not available (feature not enabled)"
            };
            Err(SatError::solver_failure(reason.to_string()))
        }
    }
}
/// A self-contained DPLL solver.
///
/// No `SolverBackend` variant maps to this type, so the factory functions in
/// this file never construct it; callers must use `BasicSolver::new` directly.
#[derive(Debug)]
// `config` is stored but not read by any code visible in this file.
#[allow(dead_code)] pub struct BasicSolver {
// Configuration supplied at construction time.
config: SolverConfig,
// Statistics; only `solve_time` is updated by this solver.
stats: SolverStats,
}
impl BasicSolver {
pub fn new(config: SolverConfig) -> SatResult<Self> {
Ok(Self {
config,
stats: SolverStats::default(),
})
}
fn dpll_solve(&self, problem: &SatProblem) -> SatResult<SatSolution> {
let mut assignments = HashMap::new();
let mut clauses = problem.clauses.clone();
for assumption in &problem.assumptions {
clauses.push(Clause::unit(*assumption));
}
self.dpll_recursive(&mut clauses, &mut assignments, problem.num_variables)
}
#[allow(clippy::only_used_in_recursion)]
fn dpll_recursive(
&self,
clauses: &mut Vec<Clause>,
assignments: &mut HashMap<u32, bool>,
num_vars: u32,
) -> SatResult<SatSolution> {
clauses.retain(|clause| {
!clause.literals.iter().any(|lit| {
assignments
.get(&lit.variable)
.is_some_and(|&val| val != lit.negated)
})
});
for clause in clauses.iter_mut() {
clause.literals.retain(|lit| {
assignments
.get(&lit.variable)
.is_none_or(|&val| val != lit.negated)
});
}
if clauses.iter().any(|c| c.literals.is_empty()) {
return Ok(SatSolution::Unsatisfiable);
}
if clauses.is_empty() {
let mut model = SatModel::new();
for var in 1..=num_vars {
let value = assignments.get(&var).copied().unwrap_or(false);
model.assign(var, value);
}
return Ok(SatSolution::Satisfiable(model));
}
loop {
let mut found_unit = false;
for clause in &*clauses {
if clause.literals.len() == 1 {
let lit = clause.literals[0];
if let std::collections::hash_map::Entry::Vacant(e) =
assignments.entry(lit.variable)
{
e.insert(!lit.negated);
found_unit = true;
break;
}
}
}
if !found_unit {
break;
}
clauses.retain(|clause| {
!clause.literals.iter().any(|lit| {
assignments
.get(&lit.variable)
.is_some_and(|&val| val != lit.negated)
})
});
for clause in clauses.iter_mut() {
clause.literals.retain(|lit| {
assignments
.get(&lit.variable)
.is_none_or(|&val| val != lit.negated)
});
}
if clauses.iter().any(|c| c.literals.is_empty()) {
return Ok(SatSolution::Unsatisfiable);
}
}
let unassigned_var = (1..=num_vars).find(|&var| !assignments.contains_key(&var));
let var = match unassigned_var {
Some(v) => v,
None => {
let mut model = SatModel::new();
for var in 1..=num_vars {
let value = assignments.get(&var).copied().unwrap_or(false);
model.assign(var, value);
}
return Ok(SatSolution::Satisfiable(model));
}
};
let mut pos_assignments = assignments.clone();
let mut pos_clauses = clauses.clone();
pos_assignments.insert(var, true);
if let SatSolution::Satisfiable(model) =
self.dpll_recursive(&mut pos_clauses, &mut pos_assignments, num_vars)?
{
return Ok(SatSolution::Satisfiable(model));
}
assignments.insert(var, false);
self.dpll_recursive(clauses, assignments, num_vars)
}
}
#[async_trait]
impl SatSolver for BasicSolver {
    /// Runs the DPLL search and records the elapsed wall-clock time.
    async fn solve(&mut self, problem: &SatProblem) -> SatResult<SatSolution> {
        let started = Instant::now();
        let outcome = self.dpll_solve(problem);
        self.stats.solve_time = started.elapsed();
        outcome
    }

    /// Wraps `solve` in `tokio::time::timeout`.
    ///
    /// NOTE(review): `solve` runs the search synchronously and has no await
    /// point, so the timer cannot interrupt a long-running search.
    async fn solve_with_timeout(
        &mut self,
        problem: &SatProblem,
        timeout: Duration,
    ) -> SatResult<SatSolution> {
        match tokio::time::timeout(timeout, self.solve(problem)).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(SatError::timeout(timeout)),
        }
    }

    /// Always fails: this solver keeps no clause state between calls.
    async fn add_clauses(&mut self, _clauses: &[Clause]) -> SatResult<()> {
        let reason = "Incremental solving not supported by BasicSolver";
        Err(SatError::solver_failure(reason.to_string()))
    }

    /// Solves a fresh problem made only of the assumptions as unit clauses,
    /// sized to the largest variable they mention (0 when there are none).
    async fn solve_under_assumptions(&mut self, assumptions: &[Literal]) -> SatResult<SatSolution> {
        let highest_variable = assumptions
            .iter()
            .fold(0u32, |highest, lit| highest.max(lit.variable));
        let mut problem = SatProblem::new(highest_variable);
        problem.add_clauses(assumptions.iter().copied().map(Clause::unit).collect());
        self.solve(&problem).await
    }

    /// Returns a copy of the current statistics.
    fn get_stats(&self) -> SolverStats {
        self.stats.clone()
    }

    /// Clears the statistics back to their defaults.
    fn reset(&mut self) {
        self.stats = SolverStats::default();
    }

    fn name(&self) -> &'static str {
        "BasicDPLL"
    }
}
/// Heuristic stand-in solver used for the `SolverBackend::Mock` backend.
/// Its `solve` always reports `Satisfiable` without verifying the clauses.
#[derive(Debug)]
struct MockSolver {
// Statistics; `solve` updates `solve_time` and `decisions`.
stats: SolverStats,
}
impl MockSolver {
fn new(_config: SolverConfig) -> Self {
Self {
stats: SolverStats::default(),
}
}
fn should_assign_variable_true(
&self,
var: u32,
clauses: &[Clause],
assignments: &std::collections::HashMap<u32, bool>,
) -> bool {
let mut true_satisfaction_score = 0;
let mut false_satisfaction_score = 0;
for clause in clauses {
let already_satisfied = clause.literals.iter().any(|lit| {
assignments
.get(&lit.variable)
.is_some_and(|&val| val != lit.negated)
});
if already_satisfied {
continue;
}
for literal in &clause.literals {
if literal.variable == var {
if literal.negated {
false_satisfaction_score += 1;
} else {
true_satisfaction_score += 1;
}
break;
}
}
}
if true_satisfaction_score > false_satisfaction_score {
true
} else if false_satisfaction_score > true_satisfaction_score {
false
} else {
var <= 2
}
}
}
#[async_trait]
impl SatSolver for MockSolver {
async fn solve(&mut self, problem: &SatProblem) -> SatResult<SatSolution> {
let start = Instant::now();
tokio::time::sleep(Duration::from_millis(1)).await;
let mut assignments = std::collections::HashMap::new();
for clause in &problem.clauses {
if clause.literals.len() == 1 {
let lit = clause.literals[0];
assignments.insert(lit.variable, !lit.negated);
}
}
for var in 1..=problem.num_variables {
if !assignments.contains_key(&var) {
let should_assign_true =
self.should_assign_variable_true(var, &problem.clauses, &assignments);
assignments.insert(var, should_assign_true);
}
}
let mut model = SatModel::new();
for var in 1..=problem.num_variables {
let value = assignments.get(&var).copied().unwrap_or(false);
model.assign(var, value);
}
self.stats.solve_time = start.elapsed();
self.stats.decisions = assignments.len() as u64;
Ok(SatSolution::Satisfiable(model))
}
async fn solve_with_timeout(
&mut self,
problem: &SatProblem,
timeout: Duration,
) -> SatResult<SatSolution> {
tokio::time::timeout(timeout, self.solve(problem))
.await
.map_err(|_| SatError::timeout(timeout))?
}
async fn add_clauses(&mut self, _clauses: &[Clause]) -> SatResult<()> {
Ok(())
}
async fn solve_under_assumptions(
&mut self,
_assumptions: &[Literal],
) -> SatResult<SatSolution> {
Err(SatError::solver_failure(
"Assumptions not yet implemented in mock solver".to_string(),
))
}
fn get_stats(&self) -> SolverStats {
self.stats.clone()
}
fn reset(&mut self) {
self.stats = SolverStats::default();
}
fn name(&self) -> &'static str {
"MockSolver"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors, DIMACS encoding and negation of a single literal.
    #[test]
    fn test_literal_operations() {
        let (plain, inverted) = (Literal::positive(42), Literal::negative(42));

        assert_eq!(plain.variable, 42);
        assert!(!plain.negated);
        assert_eq!(plain.to_dimacs(), 42);

        assert_eq!(inverted.variable, 42);
        assert!(inverted.negated);
        assert_eq!(inverted.to_dimacs(), -42);

        assert_eq!(plain.negate(), inverted);
        assert_eq!(inverted.negate(), plain);
    }

    /// Decoding from DIMACS and a full encode/decode round trip.
    #[test]
    fn test_literal_dimacs_conversion() {
        assert_eq!(Literal::from_dimacs(42), Literal::positive(42));
        assert_eq!(Literal::from_dimacs(-42), Literal::negative(42));

        let original = Literal::positive(123);
        let round_tripped = Literal::from_dimacs(original.to_dimacs());
        assert_eq!(round_tripped, original);
    }

    /// Binary and unit clause constructors plus the shape predicates.
    #[test]
    fn test_clause_creation() {
        let pair = Clause::binary(Literal::positive(1), Literal::negative(2));
        assert_eq!(pair.literals.len(), 2);
        assert_eq!(pair.to_dimacs(), vec![1, -2]);

        let single = Clause::unit(Literal::positive(5));
        assert!(single.is_unit());
        assert!(!single.is_empty());
    }

    /// Problem building, statistics and variable-name metadata.
    #[test]
    fn test_sat_problem() {
        let mut problem = SatProblem::new(3);
        problem.add_clause(Clause::binary(Literal::positive(1), Literal::positive(2)));
        problem.add_assumption(Literal::negative(3));
        problem.set_variable_name(1, "package_a".to_string());

        let summary = problem.stats();
        assert_eq!(summary.variables, 3);
        assert_eq!(summary.clauses, 1);
        assert_eq!(summary.assumptions, 1);

        assert_eq!(problem.get_variable_name(1), Some("package_a"));
        assert_eq!(problem.get_variable_name(2), None);
    }

    /// Assignment lookup, polarity listings and dense-vector conversion.
    #[test]
    fn test_sat_model() {
        let mut model = SatModel::new();
        for (variable, value) in [(1, true), (2, false), (3, true)] {
            model.assign(variable, value);
        }

        assert_eq!(model.get(1), Some(true));
        assert_eq!(model.get(2), Some(false));
        assert_eq!(model.get(4), None);

        // The listings come back in hash-map order, so sort before comparing.
        let mut assigned_true = model.true_variables();
        assigned_true.sort();
        assert_eq!(assigned_true, vec![1, 3]);

        let mut assigned_false = model.false_variables();
        assigned_false.sort();
        assert_eq!(assigned_false, vec![2]);

        let dense = model.to_vec(3);
        assert_eq!(dense, vec![true, false, true]);

        let rebuilt = SatModel::from_vec(dense);
        assert_eq!(model.assignments, rebuilt.assignments);
    }

    /// The mock honours unit clauses, assigns every variable and fills stats.
    #[tokio::test]
    async fn test_mock_solver() {
        let mut solver = MockSolver::new(SolverConfig::default());

        let mut problem = SatProblem::new(2);
        problem.add_clause(Clause::unit(Literal::positive(1)));

        let outcome = solver.solve(&problem).await;
        assert!(outcome.is_ok());

        let SatSolution::Satisfiable(model) = outcome.unwrap() else {
            panic!("Expected satisfiable result")
        };
        assert_eq!(model.get(1), Some(true));
        assert!(model.get(2).is_some());

        let stats = solver.get_stats();
        assert!(stats.decisions > 0);
        assert!(stats.solve_time.as_nanos() > 0);
    }

    /// Spot-checks the default configuration values.
    #[test]
    fn test_solver_config_defaults() {
        let defaults = SolverConfig::default();
        assert_eq!(defaults.timeout, Duration::from_secs(30));
        assert!(defaults.preprocessing);
        assert!(defaults.random_seed.is_none());
    }
}