extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
/// Index of a propositional variable, counted from zero.
pub type Var = u32;

/// A literal: a variable together with a sign, packed into a single `u32`.
///
/// The encoding is `2 * var + sign`, where `sign` is `1` for a negated
/// literal. A literal and its negation therefore differ only in the lowest
/// bit, and the packed value can be used directly as a table index.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct SatLit(u32);

impl SatLit {
    /// Largest variable index that still leaves room for the sign bit.
    const MAX_VAR: Var = u32::MAX >> 1;

    /// Builds the literal over `var` that is satisfied when `var` equals
    /// `positive`.
    ///
    /// # Panics
    ///
    /// Panics if `var` is larger than `u32::MAX >> 1`. Shifting such an index
    /// would silently drop its top bit and produce a literal over a different
    /// variable.
    pub fn new(var: Var, positive: bool) -> Self {
        assert!(
            var <= Self::MAX_VAR,
            "variable index {} does not fit in a SatLit",
            var
        );
        SatLit((var << 1) | u32::from(!positive))
    }

    /// Shorthand for `SatLit::new(var, true)`.
    pub fn positive(var: Var) -> Self {
        Self::new(var, true)
    }

    /// Shorthand for `SatLit::new(var, false)`.
    pub fn negative(var: Var) -> Self {
        Self::new(var, false)
    }

    /// Returns the variable this literal talks about.
    pub fn var(self) -> Var {
        self.0 >> 1
    }

    /// Returns `true` when the literal is the negation of its variable.
    pub fn is_negative(self) -> bool {
        self.0 & 1 == 1
    }

    /// Returns the literal over the same variable with the opposite sign.
    pub fn negate(self) -> SatLit {
        // Only the sign bit changes.
        SatLit(self.0 ^ 1)
    }

    /// Packed encoding as a `usize`, suitable for indexing per-literal tables.
    fn code(self) -> usize {
        self.0 as usize
    }
}
/// A formula in conjunctive normal form: a list of clauses, each a list of
/// literals.
#[derive(Clone, Debug, Default)]
pub struct Cnf {
    /// Number of variables the formula is declared to range over.
    pub num_vars: usize,
    /// The clauses, in insertion order.
    pub clauses: Vec<Vec<SatLit>>,
}

impl Cnf {
    /// Creates a formula over `num_vars` variables with no clauses.
    pub fn new(num_vars: usize) -> Self {
        Self {
            num_vars,
            ..Self::default()
        }
    }

    /// Appends `lits` as a new clause. `num_vars` is left untouched.
    pub fn add_clause(&mut self, lits: Vec<SatLit>) {
        let Cnf { clauses, .. } = self;
        clauses.push(lits);
    }
}
/// Why a variable on the trail holds its current value.
#[derive(Clone, Copy)]
enum Reason {
/// Chosen by `decide` (an assumption or a branching choice). Also used as the
/// initial filler value for variables that were never assigned.
Decision,
/// Forced by a one-literal clause, either an input unit or a learned unit.
Unit,
/// Forced by the stored clause with this index into `Solver::clauses`.
Long(usize),
}
/// One entry of a watch list: a clause to revisit when a literal becomes false.
#[derive(Clone, Copy)]
struct Watch {
/// Index of the watched clause in `Solver::clauses`.
cref: usize,
/// A literal of that clause that is tested first; when it is already true
/// the clause is skipped without being inspected.
blocking: SatLit,
}
/// Outcome of one call to `Solver::decide`.
enum Decision {
/// A new literal (assumption or branch) was put on the trail and still has
/// to be propagated.
Propagated,
/// Every variable is assigned; the current assignment is a model.
Sat,
/// An assumption is already false; carries the core computed by
/// `Solver::analyze_final`.
UnsatCore(Vec<SatLit>),
}
/// State of one conflict-driven search over a fixed set of variables.
struct Solver {
    /// Number of variables; every per-variable table has this length.
    num_vars: usize,
    /// Stored clauses of two or more literals (input and learned). Unit
    /// clauses are put on the trail directly and never stored here.
    clauses: Vec<Vec<SatLit>>,
    /// Watch lists indexed by literal code: `watches[p.code()]` lists the
    /// clauses that watch `p.negate()` and must be revisited once `p` is true.
    watches: Vec<Vec<Watch>>,
    /// Current value of each variable, `None` while unassigned.
    assign: Vec<Option<bool>>,
    /// Decision level at which each variable was last assigned. Not reset on
    /// backtracking, so only meaningful for assigned variables.
    level: Vec<u32>,
    /// Why each variable was last assigned. Not reset on backtracking.
    reason: Vec<Reason>,
    /// Assigned literals in assignment order.
    trail: Vec<SatLit>,
    /// `decisions[k]` is the trail length at the moment level `k + 1` was opened.
    decisions: Vec<usize>,
    /// Index of the next trail entry that still has to be propagated.
    qhead: usize,
    /// Per-variable activity score used to pick branching variables.
    activity: Vec<f64>,
    /// Amount added to a variable's activity on each bump.
    var_inc: f64,
    /// Value each variable had when it was last unassigned; reused as the sign
    /// of the next branching literal over that variable.
    polarity: Vec<bool>,
    /// Scratch marks used by conflict analysis; all `false` between calls.
    seen: Vec<bool>,
    /// `false` once the clause set is known to be unsatisfiable on its own.
    ok: bool,
    /// Literals to be decided, in order, before any free branching.
    assumptions: Vec<SatLit>,
}
impl Solver {
/// Builds a solver for `cnf` and loads all of its clauses.
///
/// The solver is sized for `cnf.num_vars` variables, or for the highest
/// variable mentioned in any clause if that is larger. `num_vars` is a public
/// field that nothing keeps in step with the clauses (for example
/// `Cnf::default()` followed by `add_clause`), and trusting it blindly made
/// every per-variable table lookup panic with an out-of-bounds index.
///
/// Unit clauses are put on the trail here but not propagated; that happens on
/// the first call to `run`.
fn new(cnf: &Cnf) -> Self {
    // One past the largest variable index that occurs in any clause.
    let mentioned = cnf
        .clauses
        .iter()
        .flat_map(|clause| clause.iter())
        .map(|lit| lit.var() as usize + 1)
        .max()
        .unwrap_or(0);
    let n = cnf.num_vars.max(mentioned);
    let mut s = Solver {
        num_vars: n,
        clauses: Vec::new(),
        // Two watch lists per variable: one for each sign.
        watches: vec![Vec::new(); 2 * n],
        assign: vec![None; n],
        level: vec![0; n],
        reason: vec![Reason::Decision; n],
        trail: Vec::new(),
        decisions: Vec::new(),
        qhead: 0,
        activity: vec![0.0; n],
        var_inc: 1.0,
        polarity: vec![false; n],
        seen: vec![false; n],
        ok: true,
        assumptions: Vec::new(),
    };
    for clause in &cnf.clauses {
        s.add_clause(clause);
    }
    s
}
/// Returns `true` when `l`'s variable is assigned and the value satisfies `l`.
fn lit_is_true(&self, l: SatLit) -> bool {
    match self.assign[l.var() as usize] {
        Some(value) => value != l.is_negative(),
        None => false,
    }
}
/// Returns `true` when `l`'s variable is assigned and the value falsifies `l`.
fn lit_is_false(&self, l: SatLit) -> bool {
    match self.assign[l.var() as usize] {
        Some(value) => value == l.is_negative(),
        None => false,
    }
}
/// Current decision level: the number of decision levels opened so far.
fn current_level(&self) -> u32 {
    let open_levels = self.decisions.len();
    open_levels as u32
}
/// Registers clause `cref` as watching literals `a` and `b`.
///
/// Each watch is filed under the negation of the watched literal, with the
/// other literal recorded as its blocking literal.
fn watch(&mut self, cref: usize, a: SatLit, b: SatLit) {
    for &(watched, blocking) in &[(a, b), (b, a)] {
        let list = watched.negate().code();
        self.watches[list].push(Watch { cref, blocking });
    }
}
/// Adds `lits` as a clause under the current assignment.
///
/// Does nothing once the solver is known to be unsatisfiable. An empty
/// clause, a falsified unit, or a clause whose literals are all false marks
/// the solver unsatisfiable. A unit clause is put on the trail and not
/// stored. Longer clauses are stored with two non-false literals moved to the
/// front and watched; if only one non-false literal exists it is moved to the
/// front and, unless already true, put on the trail with the clause as reason.
fn add_clause(&mut self, lits: &[SatLit]) {
    if !self.ok {
        return;
    }
    match lits.len() {
        0 => {
            self.ok = false;
            return;
        }
        1 => {
            let only = lits[0];
            if self.lit_is_false(only) {
                self.ok = false;
            } else if !self.lit_is_true(only) {
                self.enqueue(only, Reason::Unit);
            }
            return;
        }
        _ => {}
    }

    let mut clause = lits.to_vec();
    // Positions of the first two literals that are not currently false.
    let (first, second) = {
        let mut open = clause
            .iter()
            .enumerate()
            .filter(|&(_, &lit)| !self.lit_is_false(lit))
            .map(|(i, _)| i);
        (open.next(), open.next())
    };

    let cref = self.clauses.len();
    let a = match first {
        Some(a) => a,
        None => {
            // Every literal is false: the clause is already violated.
            self.ok = false;
            return;
        }
    };
    clause.swap(0, a);
    if let Some(b) = second {
        clause.swap(1, b);
    }
    self.watch(cref, clause[0], clause[1]);
    let head = clause[0];
    self.clauses.push(clause);
    // With a single non-false literal the clause forces that literal.
    if second.is_none() && !self.lit_is_true(head) {
        self.enqueue(head, Reason::Long(cref));
    }
}
/// Makes `l` true at the current decision level, records `reason` for it, and
/// appends it to the trail. Does not check whether the variable is already
/// assigned.
fn enqueue(&mut self, l: SatLit, reason: Reason) {
    let slot = l.var() as usize;
    let value = !l.is_negative();
    self.assign[slot] = Some(value);
    let depth = self.current_level();
    self.level[slot] = depth;
    self.reason[slot] = reason;
    self.trail.push(l);
}
/// Propagates every trail literal from `qhead` onwards.
///
/// Returns the index of a falsified clause as soon as one is found, or `None`
/// once the whole trail has been processed.
fn propagate(&mut self) -> Option<usize> {
    while let Some(&next) = self.trail.get(self.qhead) {
        self.qhead += 1;
        let conflict = self.propagate_lit(next);
        if conflict.is_some() {
            return conflict;
        }
    }
    None
}
/// Revisits every clause that watches the literal falsified by `p` becoming
/// true.
///
/// For each such clause the watch is either kept, moved to another literal of
/// the clause that is not false, or the clause turns out to be unit (its other
/// watched literal is put on the trail) or falsified. Returns the index of the
/// first falsified clause, or `None`.
fn propagate_lit(&mut self, p: SatLit) -> Option<usize> {
// `fl` is the watched literal that has just become false. The watch list is
// moved out of `self` so it can be edited while other lists are pushed to.
let fl = p.negate(); let mut ws = core::mem::take(&mut self.watches[p.code()]);
// `read` scans the old list; `write` is where the next surviving watch goes.
let mut read = 0;
let mut write = 0;
let mut conflict = None;
while read < ws.len() {
let w = ws[read];
read += 1;
// Blocking literal already true: the clause is satisfied, keep the watch as is.
if self.lit_is_true(w.blocking) {
ws[write] = w;
write += 1;
continue;
}
let cref = w.cref;
// Put the false literal at index 1 so the other watched literal is at index 0.
if self.clauses[cref][0] == fl {
self.clauses[cref].swap(0, 1);
}
let other = self.clauses[cref][0];
// Any watch written from here on uses the other watched literal as blocker.
let kept = Watch {
cref,
blocking: other,
};
// The other watched literal is true: clause satisfied, keep watching `fl`.
if other != w.blocking && self.lit_is_true(other) {
ws[write] = kept;
write += 1;
continue;
}
// A non-false literal exists further along: it takes over the watch, so the
// entry moves to that literal's list and is dropped from this one.
if let Some(repl) = self.find_replacement(cref, fl) {
self.watches[repl.negate().code()].push(kept);
continue; }
// No replacement: every literal except `other` is false. Keep the watch here.
ws[write] = kept;
write += 1;
if self.lit_is_false(other) {
// Conflict. Copy the unvisited watches across so none are lost, then stop.
while read < ws.len() {
ws[write] = ws[read];
write += 1;
read += 1;
}
conflict = Some(cref);
break;
}
// The clause is unit: it forces `other`, which sits at index 0 of the clause.
self.enqueue(other, Reason::Long(cref));
}
// Drop the slots left over by moved watches and put the list back.
ws.truncate(write);
self.watches[p.code()] = ws;
conflict
}
/// Looks for a literal at index 2 or later of clause `cref` that is not false.
///
/// If one is found it is stored at index 1, `fl` is stored in its old slot,
/// and the literal is returned. Otherwise the clause is left untouched and
/// `None` is returned.
fn find_replacement(&mut self, cref: usize, fl: SatLit) -> Option<SatLit> {
    let found = self.clauses[cref]
        .iter()
        .enumerate()
        .skip(2)
        .find(|&(_, &lit)| !self.lit_is_false(lit))
        .map(|(k, &lit)| (k, lit));
    let (k, lit) = found?;
    let clause = &mut self.clauses[cref];
    clause[1] = lit;
    clause[k] = fl;
    Some(lit)
}
/// Raises the activity of variable `v` by the current increment.
///
/// When the score passes `1e100`, all activities and the increment are scaled
/// down by `1e-100` to keep them in range.
fn bump(&mut self, v: usize) {
    let raised = self.activity[v] + self.var_inc;
    self.activity[v] = raised;
    if raised > 1e100 {
        self.activity.iter_mut().for_each(|score| *score *= 1e-100);
        self.var_inc *= 1e-100;
    }
}
/// Derives a learned clause from the falsified clause `conflict`.
///
/// Starting from the conflict, literals assigned at the current level are
/// replaced by the rest of their reason clause, walking the trail backwards,
/// until only one current-level literal remains. Returns the learned clause
/// and the level to backtrack to. Index 0 of the clause holds the negation of
/// that remaining literal; when the clause has more than one literal, index 1
/// holds a literal of the returned level. Must be called at a level above 0.
fn analyze(&mut self, conflict: usize) -> (Vec<SatLit>, u32) {
let cur_level = self.current_level();
// Index 0 is a placeholder that is overwritten once the last literal is known.
// `touched` remembers every variable marked in `seen` so the marks can be cleared.
let mut learned: Vec<SatLit> = vec![SatLit(0)]; let mut touched: Vec<Var> = Vec::new();
// Number of marked current-level literals not yet resolved away.
let mut counter = 0usize;
let mut idx = self.trail.len();
// Literal most recently taken from the trail; `None` while on the conflict clause.
let mut p: Option<SatLit> = None;
let mut confl = conflict;
loop {
// In a reason clause index 0 holds the implied literal itself, so it is skipped.
let start = if p.is_some() { 1 } else { 0 }; for j in start..self.clauses[confl].len() {
let q = self.clauses[confl][j];
let v = q.var() as usize;
// Level-0 literals are skipped and never enter the learned clause.
if !self.seen[v] && self.level[v] > 0 {
self.seen[v] = true;
touched.push(v as Var);
self.bump(v);
if self.level[v] == cur_level {
counter += 1;
} else {
learned.push(q);
}
}
}
// Walk back to the most recently assigned marked literal.
loop {
idx -= 1;
if self.seen[self.trail[idx].var() as usize] {
break;
}
}
let lit = self.trail[idx];
self.seen[lit.var() as usize] = false;
counter -= 1;
p = Some(lit);
// Only one current-level literal is left: stop resolving.
if counter == 0 {
break;
}
confl = match self.reason[lit.var() as usize] {
Reason::Long(c) => c,
_ => unreachable!("a resolved current-level literal must have a clause reason"),
};
}
// `p` is true on the trail; its negation is the literal the clause will force.
learned[0] = p.unwrap().negate();
let backjump = self.assertion_level(&mut learned);
// Growing the increment makes later bumps weigh more than earlier ones.
self.var_inc *= 1.0 / 0.95;
for v in touched {
self.seen[v as usize] = false; }
(learned, backjump)
}
/// Returns the highest decision level among `learned[1..]` and moves the
/// first literal of that level to index 1.
///
/// A one-literal clause yields level 0 and is left untouched.
fn assertion_level(&self, learned: &mut [SatLit]) -> u32 {
    if learned.len() == 1 {
        return 0;
    }
    let level_of = |lit: SatLit| self.level[lit.var() as usize];
    let seed = (1, level_of(learned[1]));
    // Strict comparison keeps the earliest literal among equal levels.
    let (deepest_at, deepest) = learned
        .iter()
        .enumerate()
        .skip(2)
        .fold(seed, |best, (i, &lit)| {
            let l = level_of(lit);
            if l > best.1 {
                (i, l)
            } else {
                best
            }
        });
    learned.swap(1, deepest_at);
    deepest
}
/// Collects the assumptions responsible for `true_lit` being true.
///
/// The result starts with `true_lit.negate()` (the assumption that cannot be
/// satisfied), followed by every decision literal at an assumption level that
/// `true_lit` was derived from. At level 0 only the first literal is returned.
fn analyze_final(&mut self, true_lit: SatLit) -> Vec<SatLit> {
// At level 0 no decision was made, so no other assumption can be involved.
let mut core = vec![true_lit.negate()]; if self.current_level() == 0 {
return core;
}
let assn = self.assumptions.len() as u32;
// Trail entries before `decisions[0]` belong to level 0 and are never visited.
let start = self.decisions[0]; self.seen[true_lit.var() as usize] = true;
// Every variable marked in `seen`, so that all marks can be cleared at the end.
let mut touched = vec![true_lit.var()];
let mut i = self.trail.len();
// Walk the trail backwards, expanding each marked literal through its reason.
while i > start {
i -= 1;
let x = self.trail[i].var() as usize;
if !self.seen[x] {
continue;
}
self.seen[x] = false;
match self.reason[x] {
Reason::Decision => {
// Levels 1..=assn are the levels opened for assumptions.
if self.level[x] > 0 && self.level[x] <= assn {
core.push(self.trail[i]);
}
}
Reason::Unit => {}
Reason::Long(cr) => {
// Index 0 is the implied literal itself; mark the rest of the clause.
for j in 1..self.clauses[cr].len() {
let v = self.clauses[cr][j].var();
if self.level[v as usize] > 0 && !self.seen[v as usize] {
self.seen[v as usize] = true;
touched.push(v);
}
}
}
}
}
// Clear any mark the walk did not reach (for example `true_lit` at level 0).
for v in touched {
self.seen[v as usize] = false;
}
core
}
/// Undoes every assignment made above decision level `level`.
///
/// Each undone variable has its value saved in `polarity` and becomes
/// unassigned. Propagation restarts from the shortened trail's end. Does
/// nothing when the solver is already at or below `level`.
fn backtrack(&mut self, level: u32) {
    if self.current_level() <= level {
        return;
    }
    let keep_levels = level as usize;
    let keep_trail = self.decisions[keep_levels];
    while self.trail.len() > keep_trail {
        let undone = match self.trail.pop() {
            Some(lit) => lit,
            None => break,
        };
        let v = undone.var() as usize;
        // Save the value so the next branch on `v` tries it first.
        self.polarity[v] = self.assign[v] == Some(true);
        self.assign[v] = None;
    }
    self.decisions.truncate(keep_levels);
    self.qhead = keep_trail;
}
/// Installs a learned clause and puts its first literal on the trail.
///
/// A one-literal clause is not stored; its literal is recorded as a unit.
/// Longer clauses are stored, watched on their first two literals, and used as
/// the reason for the first literal.
fn learn(&mut self, learned: Vec<SatLit>) {
    let assert_lit = learned[0];
    let reason = if learned.len() == 1 {
        Reason::Unit
    } else {
        let cref = self.clauses.len();
        self.watch(cref, assert_lit, learned[1]);
        self.clauses.push(learned);
        Reason::Long(cref)
    };
    self.enqueue(assert_lit, reason);
}
/// Chooses the next branching literal.
///
/// Picks the unassigned variable with the highest activity (the lowest index
/// wins ties) and gives it the sign stored in `polarity`. Returns `None` when
/// every variable is assigned.
fn pick_branch(&self) -> Option<SatLit> {
    let (choice, _) = (0..self.num_vars)
        .filter(|&v| self.assign[v].is_none())
        .fold((None, -1.0), |(choice, top), v| {
            let act = self.activity[v];
            if act > top {
                (Some(v), act)
            } else {
                (choice, top)
            }
        });
    choice.map(|v| SatLit::new(v as Var, self.polarity[v]))
}
/// Opens the next decision level.
///
/// Pending assumptions come first, one level each: an assumption that is
/// already true gets an empty level, one that is false ends the search with a
/// core, and an unassigned one is put on the trail. Once all assumptions are
/// in place a branching literal is chosen; if none is left the assignment is
/// complete.
fn decide(&mut self) -> Decision {
    while let Some(&assumed) = self.assumptions.get(self.current_level() as usize) {
        if self.lit_is_false(assumed) {
            let witness = assumed.negate();
            return Decision::UnsatCore(self.analyze_final(witness));
        }
        let already_true = self.lit_is_true(assumed);
        // The level is opened either way so that level numbers keep lining up
        // with assumption positions.
        self.decisions.push(self.trail.len());
        if !already_true {
            self.enqueue(assumed, Reason::Decision);
            return Decision::Propagated;
        }
    }
    let lit = match self.pick_branch() {
        Some(lit) => lit,
        None => return Decision::Sat,
    };
    self.decisions.push(self.trail.len());
    self.enqueue(lit, Reason::Decision);
    Decision::Propagated
}
/// Runs the search to completion.
///
/// Returns `Ok(())` when a full assignment satisfying all clauses and
/// assumptions is on the trail. Returns `Err(core)` otherwise: the core is
/// empty when the clauses are unsatisfiable on their own, and otherwise lists
/// assumptions that cannot hold together.
fn run(&mut self) -> Result<(), Vec<SatLit>> {
    if !self.ok {
        return Err(Vec::new());
    }
    loop {
        let conflict = match self.propagate() {
            Some(cref) => cref,
            None => match self.decide() {
                Decision::Propagated => continue,
                Decision::Sat => return Ok(()),
                Decision::UnsatCore(core) => return Err(core),
            },
        };
        // A conflict with no decisions to undo refutes the clause set itself.
        if self.current_level() == 0 {
            self.ok = false;
            return Err(Vec::new());
        }
        let (learned, backjump) = self.analyze(conflict);
        self.backtrack(backjump);
        self.learn(learned);
    }
}
/// Runs the search and reports only whether a model was found.
fn search(&mut self) -> bool {
    match self.run() {
        Ok(()) => true,
        Err(_) => false,
    }
}
/// Returns the current value of every variable, reporting unassigned
/// variables as `false`.
fn model(&self) -> Vec<bool> {
    let mut values = Vec::with_capacity(self.assign.len());
    for slot in &self.assign {
        values.push(*slot == Some(true));
    }
    values
}
/// Excludes `model`'s values for the variables in `project` from later
/// searches.
///
/// Backtracks to level 0 and adds a clause requiring at least one projected
/// variable to differ from `model`. Returns `false`, changing nothing, when
/// `project` is empty, since there is then nothing to tell models apart.
fn block(&mut self, project: &[Var], model: &[bool]) -> bool {
    if project.is_empty() {
        return false;
    }
    let mut blocking = Vec::with_capacity(project.len());
    for &v in project {
        // The literal that is false under `model`.
        blocking.push(SatLit::new(v, !model[v as usize]));
    }
    self.backtrack(0);
    self.add_clause(&blocking);
    true
}
}
/// Result of solving a formula under assumptions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Solved {
/// Satisfiable; holds one value per variable, indexed by variable.
Sat(Vec<bool>),
/// Unsatisfiable. The list is empty when the clauses are unsatisfiable on
/// their own; otherwise it holds assumptions that cannot hold together.
Unsat(Vec<SatLit>),
}
/// Solves `cnf` with every literal in `assumptions` required to be true.
///
/// Returns [`Solved::Sat`] with a model, or [`Solved::Unsat`] with a core: an
/// empty list when `cnf` is unsatisfiable by itself, otherwise assumptions
/// that cannot hold together with `cnf`.
///
/// An assumption may name a variable at or beyond `cnf.num_vars`. Such a
/// variable is treated as an otherwise unconstrained variable and the returned
/// model is long enough to include it. Previously this case panicked with an
/// out-of-bounds index inside the solver.
pub fn solve_assuming(cnf: &Cnf, assumptions: &[SatLit]) -> Solved {
    // One past the largest variable mentioned by any assumption.
    let needed = assumptions
        .iter()
        .map(|lit| lit.var() as usize + 1)
        .max()
        .unwrap_or(0);
    let mut s = if needed > cnf.num_vars {
        // Rare path: copy the formula only to declare the extra variables.
        let mut widened = cnf.clone();
        widened.num_vars = needed;
        Solver::new(&widened)
    } else {
        Solver::new(cnf)
    };
    s.assumptions = assumptions.to_vec();
    match s.run() {
        Ok(()) => Solved::Sat(s.model()),
        Err(core) => Solved::Unsat(core),
    }
}
/// Solves `cnf` without assumptions, returning a model if one exists.
pub fn solve(cnf: &Cnf) -> Option<Vec<bool>> {
    if let Solved::Sat(model) = solve_assuming(cnf, &[]) {
        Some(model)
    } else {
        None
    }
}
/// Lazy iterator over models of a formula that differ on a chosen set of
/// variables.
pub struct Models {
    /// Solver that accumulates one blocking clause per model returned.
    solver: Solver,
    /// Variables on which successive models must differ.
    project: Vec<Var>,
    /// Set once no further model can be produced.
    done: bool,
}

impl Iterator for Models {
    type Item = Vec<bool>;

    /// Finds the next model, then blocks its projection so it is not returned
    /// again. With an empty projection only a single model is produced.
    fn next(&mut self) -> Option<Vec<bool>> {
        if self.done || !self.solver.search() {
            self.done = true;
            return None;
        }
        let model = self.solver.model();
        let can_continue = self.solver.block(&self.project, &model);
        self.done = !can_continue;
        Some(model)
    }
}
/// Returns a lazy iterator over the models of `cnf` that differ on at least
/// one variable in `project`.
///
/// With an empty `project` the iterator yields at most one model.
///
/// A projection variable at or beyond `cnf.num_vars` is treated as an
/// otherwise unconstrained variable, so both of its values are enumerated and
/// the returned models are long enough to include it. Previously this case
/// panicked with an out-of-bounds index when the first model was blocked.
pub fn all_models(cnf: &Cnf, project: Vec<Var>) -> Models {
    // One past the largest projected variable.
    let needed = project
        .iter()
        .map(|&v| (v as usize).saturating_add(1))
        .max()
        .unwrap_or(0);
    let solver = if needed > cnf.num_vars {
        // Rare path: copy the formula only to declare the extra variables.
        let mut widened = cnf.clone();
        widened.num_vars = needed;
        Solver::new(&widened)
    } else {
        Solver::new(cnf)
    };
    Models {
        solver,
        project,
        done: false,
    }
}
/// Collects up to `limit` models of `cnf` that differ on the variables in
/// `project`.
pub fn models(cnf: &Cnf, project: &[Var], limit: usize) -> Vec<Vec<bool>> {
    let mut found = Vec::new();
    for model in all_models(cnf, project.to_vec()).take(limit) {
        found.push(model);
    }
    found
}
/// Counts the models of `cnf` that differ on the variables in `project`,
/// stopping once `limit` have been found.
pub fn models_upto(cnf: &Cnf, project: &[Var], limit: usize) -> usize {
    all_models(cnf, project.to_vec())
        .take(limit)
        .fold(0, |seen, _| seen + 1)
}
// Unit tests for the public solving and enumeration entry points.
#[cfg(test)]
mod tests {
use super::*;
// A single two-literal clause is satisfiable.
#[test]
fn trivial_sat() {
let mut c = Cnf::new(2);
c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
assert!(solve(&c).is_some());
}
// Two opposite unit clauses are refuted while loading the formula.
#[test]
fn unit_contradiction_unsat() {
let mut c = Cnf::new(1);
c.add_clause(vec![SatLit::positive(0)]);
c.add_clause(vec![SatLit::negative(0)]);
assert!(solve(&c).is_none());
}
// Each of the four clauses rules out one assignment of (a, b), leaving none.
#[test]
fn all_four_combos_excluded_is_unsat() {
let mut c = Cnf::new(2);
let (a, b) = (0u32, 1u32);
c.add_clause(vec![SatLit::positive(a), SatLit::positive(b)]);
c.add_clause(vec![SatLit::negative(a), SatLit::positive(b)]);
c.add_clause(vec![SatLit::positive(a), SatLit::negative(b)]);
c.add_clause(vec![SatLit::negative(a), SatLit::negative(b)]);
assert!(solve(&c).is_none());
}
// The unit clause forces variable 0, which in turn forces variable 1.
#[test]
fn forced_chain_has_unique_model() {
let mut c = Cnf::new(2);
c.add_clause(vec![SatLit::negative(0), SatLit::positive(1)]);
c.add_clause(vec![SatLit::positive(0)]);
let m = solve(&c).unwrap();
assert!(m[0] && m[1]);
assert_eq!(models_upto(&c, &[0, 1], 5), 1);
}
// (a or b) has three models; the limit of 10 is not reached.
#[test]
fn or_clause_has_three_models() {
let mut c = Cnf::new(2);
c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
assert_eq!(models_upto(&c, &[0, 1], 10), 3);
}
// Taking a prefix of the iterator yields distinct models; draining it yields all three.
#[test]
fn lazy_models_iterator_is_incremental() {
let mut c = Cnf::new(2);
c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
let first_two: Vec<_> = all_models(&c, vec![0, 1]).take(2).collect();
assert_eq!(first_two.len(), 2);
assert_ne!(first_two[0], first_two[1]);
assert_eq!(all_models(&c, vec![0, 1]).count(), 3);
}
// Assuming variable 0 is false leaves variable 1 as the only way to satisfy the clause.
#[test]
fn assumption_forces_a_model() {
let mut c = Cnf::new(2);
c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
match solve_assuming(&c, &[SatLit::negative(0)]) {
Solved::Sat(m) => {
assert!(!m[0] && m[1]);
}
Solved::Unsat(_) => panic!("should be SAT under ¬a"),
}
}
// The core must be non-empty, drawn from the assumptions, and still unsatisfiable
// when its literals are added to the formula as unit clauses.
#[test]
fn contradicted_assumptions_yield_a_sufficient_core() {
let mut c = Cnf::new(2);
c.add_clause(vec![SatLit::negative(0), SatLit::negative(1)]);
let assumptions = [SatLit::positive(0), SatLit::positive(1)];
match solve_assuming(&c, &assumptions) {
Solved::Unsat(core) => {
assert!(!core.is_empty());
assert!(core.iter().all(|l| assumptions.contains(l)));
let mut cc = c.clone();
for l in &core {
cc.add_clause(vec![*l]);
}
assert!(solve(&cc).is_none());
}
Solved::Sat(_) => panic!("a ∧ b violates (¬a ∨ ¬b)"),
}
}
// Compatible assumptions must be reflected in the returned model.
#[test]
fn satisfiable_assumptions_round_trip() {
let mut c = Cnf::new(3);
c.add_clause(vec![
SatLit::positive(0),
SatLit::positive(1),
SatLit::positive(2),
]);
let assumptions = [SatLit::positive(0), SatLit::negative(2)];
match solve_assuming(&c, &assumptions) {
Solved::Sat(m) => {
assert!(m[0] && !m[2]);
}
Solved::Unsat(_) => panic!("should be SAT"),
}
}
// A five-variable instance; the returned model is checked against every clause.
#[test]
fn larger_random_like_sat_is_solved() {
let mut c = Cnf::new(5);
let l = |v: u32, p: bool| SatLit::new(v, p);
c.add_clause(vec![l(0, true), l(1, true), l(2, false)]);
c.add_clause(vec![l(0, false), l(2, true), l(3, true)]);
c.add_clause(vec![l(1, false), l(3, false), l(4, true)]);
c.add_clause(vec![l(2, false), l(4, false)]);
c.add_clause(vec![l(0, true), l(4, true)]);
let m = solve(&c).expect("sat");
for clause in &c.clauses {
// A literal is satisfied when the variable's value differs from its sign bit.
assert!(
clause
.iter()
.any(|&lit| m[lit.var() as usize] != lit.is_negative())
);
}
}
}