use crate::instance::SolveInstance;
use crate::proof::{SolveProof, SolveResult, SolveStats, SolveStatus};
/// Tuning parameters for the gradient-descent SAT solver.
///
/// - `max_iterations`: number of optimizer steps attempted before giving up.
/// - `learning_rate`: factor applied to the gradient when updating velocities.
/// - `momentum`: fraction of the previous velocity carried into the next step.
/// - `discretize_threshold`: relaxed values `>=` this threshold round to `true`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SolverConfig {
    pub max_iterations: u32,
    pub learning_rate: f32,
    pub momentum: f32,
    pub discretize_threshold: f32,
}

impl Default for SolverConfig {
    /// Balanced preset: 10 000 iterations, learning rate 0.1, momentum 0.9,
    /// threshold 0.5.
    fn default() -> Self {
        Self::new(10_000, 0.1, 0.9, 0.5)
    }
}

impl SolverConfig {
    /// Builds a configuration from explicit values.
    #[inline]
    pub const fn new(
        max_iterations: u32,
        learning_rate: f32,
        momentum: f32,
        discretize_threshold: f32,
    ) -> Self {
        Self {
            max_iterations,
            learning_rate,
            momentum,
            discretize_threshold,
        }
    }

    /// Quick preset: 1 000 iterations with a larger step (0.2).
    #[inline]
    pub const fn fast() -> Self {
        Self::new(1_000, 0.2, 0.9, 0.5)
    }

    /// Slow preset: 50 000 iterations, a small step (0.05) and heavier
    /// momentum (0.95).
    #[inline]
    pub const fn thorough() -> Self {
        Self::new(50_000, 0.05, 0.95, 0.5)
    }

    /// Returns a copy with `max_iterations` replaced.
    #[inline]
    pub const fn with_max_iterations(self, max_iterations: u32) -> Self {
        Self {
            max_iterations,
            ..self
        }
    }

    /// Returns a copy with `learning_rate` replaced.
    #[inline]
    pub const fn with_learning_rate(self, learning_rate: f32) -> Self {
        Self {
            learning_rate,
            ..self
        }
    }

    /// Returns a copy with `momentum` replaced.
    #[inline]
    pub const fn with_momentum(self, momentum: f32) -> Self {
        Self { momentum, ..self }
    }

    /// Returns a copy with `discretize_threshold` replaced by `threshold`.
    #[inline]
    pub const fn with_discretize_threshold(self, threshold: f32) -> Self {
        Self {
            discretize_threshold: threshold,
            ..self
        }
    }
}
/// Mutable optimizer state: one relaxed value, one velocity and one gradient
/// per variable, stored in parallel vectors indexed by variable number.
#[derive(Debug, Clone)]
pub struct SolverState {
    pub assignments: Vec<f32>,
    pub velocities: Vec<f32>,
    pub gradients: Vec<f32>,
}

impl SolverState {
    /// Creates state for `num_vars` variables.
    ///
    /// The initial relaxed values are deterministic. They are the fractional
    /// parts of successive multiples of the golden ratio, rescaled into
    /// roughly `[0.3, 0.7]`, so repeated runs start from the same point.
    /// Velocities and gradients start at zero.
    pub fn new(num_vars: u32) -> Self {
        const GOLDEN_RATIO: f64 = 1.618033988749895;
        let len = num_vars as usize;
        let mut initial = Vec::with_capacity(len);
        for k in 1..=len {
            let spread = (k as f64 * GOLDEN_RATIO).fract() as f32;
            initial.push(0.3 + spread * 0.4);
        }
        Self::with_assignments(initial)
    }

    /// Wraps caller-supplied relaxed values; velocities and gradients are
    /// zero-filled to the same length.
    pub fn with_assignments(assignments: Vec<f32>) -> Self {
        let zeros = vec![0.0; assignments.len()];
        Self {
            velocities: zeros.clone(),
            gradients: zeros,
            assignments,
        }
    }

    /// Rounds every relaxed value to `true` when it is `>= threshold`.
    #[inline]
    pub fn discretize(&self, threshold: f32) -> Vec<bool> {
        let mut rounded = Vec::with_capacity(self.assignments.len());
        for value in &self.assignments {
            rounded.push(threshold <= *value);
        }
        rounded
    }

    /// Number of variables tracked by this state.
    #[inline]
    pub fn num_vars(&self) -> usize {
        self.assignments.len()
    }

    /// Zeroes every velocity.
    #[inline]
    pub fn reset_velocities(&mut self) {
        self.velocities.iter_mut().for_each(|v| *v = 0.0);
    }

    /// Zeroes every gradient.
    #[inline]
    pub fn clear_gradients(&mut self) {
        self.gradients.iter_mut().for_each(|g| *g = 0.0);
    }
}
/// Incomplete SAT solver.
///
/// Each boolean variable is relaxed to a value clamped into `[0, 1]`. The
/// solver runs momentum gradient descent on the relaxed objective and rounds
/// the relaxed values to booleans after every step.
#[derive(Debug, Clone)]
pub struct Solver {
    config: SolverConfig,
}

impl Solver {
    /// Creates a solver using `SolverConfig::default()`.
    #[inline]
    pub fn new_cpu() -> Self {
        Self {
            config: SolverConfig::default(),
        }
    }

    /// Creates a solver that runs with `config`.
    #[inline]
    pub fn with_config_cpu(config: SolverConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this solver runs with.
    #[inline]
    pub fn config(&self) -> &SolverConfig {
        &self.config
    }

    /// Searches for a satisfying assignment of `instance`.
    ///
    /// Trivial instances are answered without running the optimizer:
    /// - no variables: `Sat` with an empty assignment, or `Unknown` if some
    ///   clause is empty;
    /// - no clauses: `Sat` with every variable `false`;
    /// - any empty clause: `Unknown`, reporting the all-`false` assignment.
    ///
    /// Otherwise the solver runs up to `max_iterations` gradient steps. It
    /// returns `Sat` as soon as the rounded assignment satisfies every clause.
    ///
    /// If the budget runs out, it returns `Unknown` with an approximate proof
    /// holding the rounded assignment that satisfied the most clauses (the
    /// earliest one on ties). If no step satisfied any clause, the proof holds
    /// the final rounding instead. This method never reports `Unsat`.
    pub fn solve(&self, instance: SolveInstance) -> SolveResult {
        let start = std::time::Instant::now();
        let elapsed_us = || start.elapsed().as_micros() as u64;
        let num_clauses = instance.clauses.len() as u32;
        let has_empty_clause = instance.clauses.iter().any(|c| c.is_empty());

        if instance.num_vars == 0 {
            if has_empty_clause {
                return SolveResult {
                    status: SolveStatus::Unknown,
                    proof: SolveProof::approximate(vec![], 0, num_clauses, 0),
                    stats: SolveStats::new(0, elapsed_us(), 0),
                };
            }
            return SolveResult::satisfiable(vec![])
                .with_stats(SolveStats::new(0, elapsed_us(), 0));
        }

        let num_vars = instance.num_vars as usize;
        if instance.clauses.is_empty() {
            return SolveResult::satisfiable(vec![false; num_vars])
                .with_stats(SolveStats::new(0, elapsed_us(), 0));
        }
        if has_empty_clause {
            // An empty clause can never be satisfied, so optimizing is
            // pointless; report how the all-false assignment scores.
            let assignment = vec![false; num_vars];
            let satisfied = instance.count_satisfied(&assignment) as u32;
            return SolveResult {
                status: SolveStatus::Unknown,
                proof: SolveProof::approximate(assignment, satisfied, num_clauses, 0),
                stats: SolveStats::new(0, elapsed_us(), 0),
            };
        }

        let mut state = SolverState::new(instance.num_vars);
        let mut best_assignment: Option<Vec<bool>> = None;
        let mut best_satisfied: u32 = 0;
        for iter in 0..self.config.max_iterations {
            self.compute_gradients(&instance, &mut state);
            self.update_assignments(&mut state);
            let discrete = state.discretize(self.config.discretize_threshold);
            let satisfied = instance.count_satisfied(&discrete) as u32;
            if instance.is_satisfied(&discrete) {
                return SolveResult::satisfiable(discrete).with_stats(SolveStats {
                    iterations: iter + 1,
                    duration_us: elapsed_us(),
                    peak_memory: 0,
                });
            }
            // Strict `>` keeps the earliest assignment among equally good ones.
            if satisfied > best_satisfied {
                best_satisfied = satisfied;
                best_assignment = Some(discrete);
            }
        }

        let final_discrete =
            best_assignment.unwrap_or_else(|| state.discretize(self.config.discretize_threshold));
        let final_satisfied = instance.count_satisfied(&final_discrete) as u32;
        SolveResult {
            status: SolveStatus::Unknown,
            proof: SolveProof::approximate(
                final_discrete,
                final_satisfied,
                num_clauses,
                self.config.max_iterations,
            ),
            stats: SolveStats {
                iterations: self.config.max_iterations,
                duration_us: elapsed_us(),
                peak_memory: 0,
            },
        }
    }

    /// Overwrites `state.gradients` with the gradient of the relaxed objective.
    ///
    /// The objective is the sum over clauses of `prod (1 - lit_val)`. Here
    /// `lit_val` is `x` for a positive literal and `1 - x` for a negated one.
    ///
    /// Clauses whose relaxed unsatisfaction is already below `0.001` are
    /// skipped. Every literal occurrence contributes its own product-rule
    /// term. A variable that appears more than once in a clause (e.g.
    /// `x ∨ ¬x`) therefore receives the sum of all its terms.
    fn compute_gradients(&self, instance: &SolveInstance, state: &mut SolverState) {
        state.gradients.fill(0.0);
        for clause in &instance.clauses {
            let mut clause_unsat = 1.0f32;
            for lit in &clause.literals {
                let val = state.assignments[lit.var as usize];
                let lit_val = if lit.negated { 1.0 - val } else { val };
                clause_unsat *= 1.0 - lit_val;
            }
            if clause_unsat < 0.001 {
                continue;
            }
            for (i, lit) in clause.literals.iter().enumerate() {
                // Product of every *other occurrence's* factor. The product rule
                // requires excluding by position, not by variable; excluding by
                // variable dropped the other occurrences of a repeated variable.
                let mut other_product = 1.0f32;
                for (j, other_lit) in clause.literals.iter().enumerate() {
                    if j == i {
                        continue;
                    }
                    let other_val = state.assignments[other_lit.var as usize];
                    let lit_val = if other_lit.negated {
                        1.0 - other_val
                    } else {
                        other_val
                    };
                    other_product *= 1.0 - lit_val;
                }
                // d(1 - x)/dx = -1 for a positive literal; d(x)/dx = +1 for a negated one.
                let sign = if lit.negated { 1.0 } else { -1.0 };
                state.gradients[lit.var as usize] += sign * other_product;
            }
        }
    }

    /// Applies one momentum step:
    /// `v = momentum * v - learning_rate * g`, then `x = clamp(x + v, 0, 1)`.
    ///
    /// Panics if `velocities` or `gradients` is shorter than `assignments`.
    fn update_assignments(&self, state: &mut SolverState) {
        let momentum = self.config.momentum;
        let learning_rate = self.config.learning_rate;
        for i in 0..state.assignments.len() {
            state.velocities[i] = momentum * state.velocities[i] - learning_rate * state.gradients[i];
            state.assignments[i] = (state.assignments[i] + state.velocities[i]).clamp(0.0, 1.0);
        }
    }
}

impl Default for Solver {
    /// Same as [`Solver::new_cpu`].
    fn default() -> Self {
        Self::new_cpu()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the solver configuration, optimizer state and the
    //! gradient-descent solver. Tests that assert a `Sat` status also unwrap
    //! the assignment with `expect`. If a `Sat` result carried no assignment,
    //! the follow-up checks would otherwise be silently skipped.
    use super::*;
    use crate::instance::{Clause, Literal};

    #[test]
    fn test_solver_config_default() {
        let config = SolverConfig::default();
        assert_eq!(config.max_iterations, 10000);
        assert_eq!(config.learning_rate, 0.1);
        assert_eq!(config.momentum, 0.9);
        assert_eq!(config.discretize_threshold, 0.5);
    }

    #[test]
    fn test_solver_config_new() {
        let config = SolverConfig::new(5000, 0.05, 0.95, 0.4);
        assert_eq!(config.max_iterations, 5000);
        assert_eq!(config.learning_rate, 0.05);
        assert_eq!(config.momentum, 0.95);
        assert_eq!(config.discretize_threshold, 0.4);
    }

    #[test]
    fn test_solver_config_fast() {
        let config = SolverConfig::fast();
        assert_eq!(config.max_iterations, 1000);
        assert_eq!(config.learning_rate, 0.2);
    }

    #[test]
    fn test_solver_config_thorough() {
        let config = SolverConfig::thorough();
        assert_eq!(config.max_iterations, 50000);
        assert_eq!(config.learning_rate, 0.05);
    }

    #[test]
    fn test_solver_config_builders() {
        let config = SolverConfig::default()
            .with_max_iterations(2000)
            .with_learning_rate(0.2)
            .with_momentum(0.8)
            .with_discretize_threshold(0.6);
        assert_eq!(config.max_iterations, 2000);
        assert_eq!(config.learning_rate, 0.2);
        assert_eq!(config.momentum, 0.8);
        assert_eq!(config.discretize_threshold, 0.6);
    }

    #[test]
    fn test_solver_state_new() {
        let state = SolverState::new(5);
        assert_eq!(state.assignments.len(), 5);
        assert_eq!(state.velocities.len(), 5);
        assert_eq!(state.gradients.len(), 5);
        assert!(state.velocities.iter().all(|&v| v == 0.0));
        assert!(state.gradients.iter().all(|&g| g == 0.0));
        for &val in &state.assignments {
            assert!((0.3..=0.7).contains(&val), "initial value {val} out of range");
        }
    }

    #[test]
    fn test_solver_state_with_assignments() {
        let assignments = vec![0.1, 0.5, 0.9];
        let state = SolverState::with_assignments(assignments.clone());
        assert_eq!(state.assignments, assignments);
        assert!(state.velocities.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_solver_state_discretize() {
        let mut state = SolverState::new(4);
        state.assignments = vec![0.2, 0.5, 0.6, 0.9];
        let discrete = state.discretize(0.5);
        assert_eq!(discrete, vec![false, true, true, true]);
        let discrete_high = state.discretize(0.7);
        assert_eq!(discrete_high, vec![false, false, false, true]);
    }

    #[test]
    fn test_solver_state_num_vars() {
        let state = SolverState::new(10);
        assert_eq!(state.num_vars(), 10);
    }

    #[test]
    fn test_solver_state_reset_velocities() {
        let mut state = SolverState::new(3);
        state.velocities = vec![1.0, 2.0, 3.0];
        state.reset_velocities();
        assert!(state.velocities.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_solver_state_clear_gradients() {
        let mut state = SolverState::new(3);
        state.gradients = vec![1.0, 2.0, 3.0];
        state.clear_gradients();
        assert!(state.gradients.iter().all(|&g| g == 0.0));
    }

    #[test]
    fn test_solver_new_cpu() {
        let solver = Solver::new_cpu();
        assert_eq!(solver.config().max_iterations, 10000);
    }

    #[test]
    fn test_solver_with_config_cpu() {
        let config = SolverConfig::fast();
        let solver = Solver::with_config_cpu(config);
        assert_eq!(solver.config().max_iterations, 1000);
    }

    #[test]
    fn test_solver_default() {
        let solver = Solver::default();
        assert_eq!(solver.config().max_iterations, 10000);
    }

    #[test]
    fn test_solver_simple_sat() {
        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        // The only clause is the unit clause x0.
        assert!(assignment[0]);
    }

    #[test]
    fn test_solver_two_clause() {
        let instance = SolveInstance::new(
            2,
            vec![
                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        // (x0 ∨ x1) ∧ (¬x0 ∨ x1) forces x1 = true.
        assert!(assignment[1]);
    }

    #[test]
    fn test_solver_unsat() {
        let instance = SolveInstance::new(
            1,
            vec![
                Clause::new(vec![Literal::positive(0)]),
                Clause::new(vec![Literal::negative(0)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert!(matches!(
            result.status,
            SolveStatus::Unsat | SolveStatus::Unknown
        ));
    }

    #[test]
    fn test_solver_empty_instance() {
        let instance = SolveInstance::new(3, vec![]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
    }

    #[test]
    fn test_solver_no_variables() {
        let instance = SolveInstance::new(0, vec![]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
    }

    #[test]
    fn test_solver_no_variables_empty_clause() {
        let instance = SolveInstance::new(0, vec![Clause::new(vec![])]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Unknown);
    }

    #[test]
    fn test_solver_unit_propagation() {
        let instance = SolveInstance::new(
            2,
            vec![
                Clause::new(vec![Literal::positive(0)]),
                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        // The unit clause x0 forces x0 = true.
        assert!(assignment[0]);
    }

    #[test]
    fn test_solver_negative_unit() {
        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::negative(0)])]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        // The unit clause ¬x0 forces x0 = false.
        assert!(!assignment[0]);
    }

    #[test]
    fn test_solver_three_vars() {
        let instance = SolveInstance::new(
            3,
            vec![
                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
                Clause::new(vec![Literal::negative(1), Literal::positive(2)]),
                Clause::new(vec![Literal::negative(2), Literal::positive(0)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance.clone());
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(instance.is_satisfied(assignment));
    }

    #[test]
    fn test_solver_all_positive() {
        let instance = SolveInstance::new(
            3,
            vec![
                Clause::new(vec![Literal::positive(0)]),
                Clause::new(vec![Literal::positive(1)]),
                Clause::new(vec![Literal::positive(2)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(assignment.iter().all(|&v| v));
    }

    #[test]
    fn test_solver_all_negative() {
        let instance = SolveInstance::new(
            3,
            vec![
                Clause::new(vec![Literal::negative(0)]),
                Clause::new(vec![Literal::negative(1)]),
                Clause::new(vec![Literal::negative(2)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(assignment.iter().all(|&v| !v));
    }

    #[test]
    fn test_solver_xor_like() {
        let instance = SolveInstance::new(
            2,
            vec![
                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
                Clause::new(vec![Literal::negative(0), Literal::negative(1)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(assignment[0] != assignment[1]);
    }

    #[test]
    fn test_solver_binary_clause() {
        let instance = SolveInstance::new(
            2,
            vec![Clause::new(vec![Literal::positive(0), Literal::positive(1)])],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(assignment[0] || assignment[1]);
    }

    #[test]
    fn test_solver_ternary_clause() {
        let instance = SolveInstance::new(
            3,
            vec![Clause::new(vec![
                Literal::positive(0),
                Literal::positive(1),
                Literal::positive(2),
            ])],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(assignment[0] || assignment[1] || assignment[2]);
    }

    #[test]
    fn test_solver_stats() {
        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert!(result.stats.iterations > 0);
        assert!(result.stats.iterations <= solver.config().max_iterations);
    }

    #[test]
    fn test_solver_stats_iterations_limited() {
        let config = SolverConfig::default().with_max_iterations(10);
        let solver = Solver::with_config_cpu(config);
        let instance = SolveInstance::new(
            1,
            vec![
                Clause::new(vec![Literal::positive(0)]),
                Clause::new(vec![Literal::negative(0)]),
            ],
        );
        let result = solver.solve(instance);
        assert!(result.stats.iterations <= 10);
    }

    #[test]
    fn test_solver_zero_iterations() {
        let solver = Solver::with_config_cpu(SolverConfig::default().with_max_iterations(0));
        let instance = SolveInstance::new(
            1,
            vec![
                Clause::new(vec![Literal::positive(0)]),
                Clause::new(vec![Literal::negative(0)]),
            ],
        );
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Unknown);
        assert_eq!(result.stats.iterations, 0);
    }

    #[test]
    fn test_compute_gradients_single_positive() {
        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
        let solver = Solver::new_cpu();
        let mut state = SolverState::with_assignments(vec![0.5]);
        solver.compute_gradients(&instance, &mut state);
        assert!(state.gradients[0] < 0.0);
    }

    #[test]
    fn test_compute_gradients_single_negative() {
        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::negative(0)])]);
        let solver = Solver::new_cpu();
        let mut state = SolverState::with_assignments(vec![0.5]);
        solver.compute_gradients(&instance, &mut state);
        assert!(state.gradients[0] > 0.0);
    }

    #[test]
    fn test_compute_gradients_two_literals() {
        // Objective (1 - x0)(1 - x1): d/dx0 = -(1 - x1), d/dx1 = -(1 - x0).
        let instance = SolveInstance::new(
            2,
            vec![Clause::new(vec![Literal::positive(0), Literal::positive(1)])],
        );
        let solver = Solver::new_cpu();
        let mut state = SolverState::with_assignments(vec![0.5, 0.25]);
        solver.compute_gradients(&instance, &mut state);
        assert!((state.gradients[0] + 0.75).abs() < 1e-6);
        assert!((state.gradients[1] + 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_compute_gradients_satisfied_clause() {
        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
        let solver = Solver::new_cpu();
        let mut state = SolverState::with_assignments(vec![1.0]);
        solver.compute_gradients(&instance, &mut state);
        assert!(state.gradients[0].abs() < 0.01);
    }

    #[test]
    fn test_update_assignments_clamps() {
        let solver = Solver::with_config_cpu(SolverConfig::default().with_learning_rate(10.0));
        let mut state = SolverState::with_assignments(vec![0.5]);
        state.gradients = vec![-1.0];
        solver.update_assignments(&mut state);
        assert!(state.assignments[0] >= 0.0);
        assert!(state.assignments[0] <= 1.0);
    }

    #[test]
    fn test_update_assignments_momentum() {
        let solver = Solver::with_config_cpu(
            SolverConfig::default()
                .with_learning_rate(0.1)
                .with_momentum(0.5),
        );
        let mut state = SolverState::with_assignments(vec![0.5]);
        state.velocities = vec![0.1];
        state.gradients = vec![-0.1];
        solver.update_assignments(&mut state);
        let expected_velocity = 0.5 * 0.1 - 0.1 * (-0.1);
        assert!((state.velocities[0] - expected_velocity).abs() < 1e-6);
    }

    #[test]
    fn test_solver_empty_clause() {
        let instance = SolveInstance::new(1, vec![Clause::new(vec![])]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Unknown);
    }

    #[test]
    fn test_solver_large_clause() {
        let literals: Vec<Literal> = (0..10).map(Literal::positive).collect();
        let instance = SolveInstance::new(10, vec![Clause::new(literals)]);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
    }

    #[test]
    fn test_solver_many_clauses() {
        let clauses: Vec<Clause> = (0..20)
            .map(|i| Clause::new(vec![Literal::positive(i)]))
            .collect();
        let instance = SolveInstance::new(20, clauses);
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(assignment.iter().all(|&v| v));
    }

    #[test]
    fn test_solver_pigeon_hole_small() {
        // x0 ∧ x1 ∧ (¬x0 ∨ ¬x1) is contradictory; the solver can only say Unknown.
        let instance = SolveInstance::new(
            2,
            vec![
                Clause::new(vec![Literal::positive(0)]),
                Clause::new(vec![Literal::positive(1)]),
                Clause::new(vec![Literal::negative(0), Literal::negative(1)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Unknown);
    }

    #[test]
    fn test_solver_deterministic() {
        let instance = SolveInstance::new(
            2,
            vec![
                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result1 = solver.solve(instance.clone());
        let result2 = solver.solve(instance);
        assert_eq!(result1.status, result2.status);
        assert_eq!(result1.assignment(), result2.assignment());
    }

    #[test]
    fn test_solver_with_fast_config() {
        let instance = SolveInstance::new(
            2,
            vec![Clause::new(vec![Literal::positive(0), Literal::positive(1)])],
        );
        let solver = Solver::with_config_cpu(SolverConfig::fast());
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
    }

    #[test]
    fn test_solver_implication_chain() {
        let instance = SolveInstance::new(
            4,
            vec![
                Clause::new(vec![Literal::positive(0)]),
                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
                Clause::new(vec![Literal::negative(1), Literal::positive(2)]),
                Clause::new(vec![Literal::negative(2), Literal::positive(3)]),
            ],
        );
        let solver = Solver::new_cpu();
        let result = solver.solve(instance);
        assert_eq!(result.status, SolveStatus::Sat);
        let assignment = result.assignment().expect("SAT result must carry an assignment");
        assert!(assignment.iter().all(|&v| v));
    }
}