// elenchus_solver/sat.rs

//! A compact, single-threaded CDCL SAT solver in `no_std`, replicating the core
//! algorithm of varisat (jix/varisat) in a readable, lazy style.
//!
//! The solver is a small **state machine**. `Solver::step` first propagates to
//! a fixpoint and then performs exactly one transition. It returns a `Step`
//! saying which way the search went: learned from a conflict, made a decision,
//! SAT, or UNSAT.
//!
//! `Solver::search` drives steps until it reaches a terminal state. Model
//! enumeration is a lazy [`Models`] iterator that solves **incrementally**:
//! each `next()` adds a blocking clause and continues from the existing state
//! rather than re-solving from scratch.
//!
//! Pieces mirror varisat's modules:
//! - the trail and decision levels (`prop/assignment.rs`);
//! - two-watched-literal propagation (`prop/long.rs`);
//! - 1-UIP conflict analysis with clause learning (`analyze_conflict.rs`);
//! - non-chronological backjumping;
//! - VSIDS decisions with phase saving.
//!
//! Its infrastructure is intentionally omitted: proof/DRAT logging, clause-DB
//! GC, assumptions, restarts, the `partial_ref` context and multithreading.
17
18extern crate alloc;
19
20use alloc::vec;
21use alloc::vec::Vec;
22
23// --- literals & formulas ---------------------------------------------------
24
/// A boolean variable, identified by a dense index.
pub type Var = u32;

/// A literal: a variable plus a sign, packed as `var << 1 | negative`.
///
/// The packing makes [`SatLit::negate`] a single XOR. It also lets the solver
/// use the packed code directly as an index into its watch lists.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct SatLit(u32);

impl SatLit {
    /// The largest variable index that fits in the packed encoding.
    /// One bit of the `u32` is spent on the sign.
    pub const MAX_VAR: Var = u32::MAX >> 1;

    /// A literal for `var`, positive (true) or negative (`NOT var`).
    ///
    /// # Panics
    ///
    /// Panics if `var > SatLit::MAX_VAR`. Shifting such a variable left would
    /// silently drop its top bit and produce a literal of a *different*
    /// variable.
    pub fn new(var: Var, positive: bool) -> Self {
        assert!(
            var <= Self::MAX_VAR,
            "variable index {} does not fit in a SatLit",
            var
        );
        SatLit((var << 1) | (!positive as u32))
    }

    /// The positive literal `var`.
    pub fn positive(var: Var) -> Self {
        Self::new(var, true)
    }

    /// The negative literal `NOT var`.
    pub fn negative(var: Var) -> Self {
        Self::new(var, false)
    }

    /// The underlying variable.
    pub fn var(self) -> Var {
        self.0 >> 1
    }

    /// Whether this is the negative polarity.
    pub fn is_negative(self) -> bool {
        self.0 & 1 == 1
    }

    /// The same variable with the opposite sign.
    pub fn negate(self) -> SatLit {
        SatLit(self.0 ^ 1)
    }

    /// The packed code, used directly as an index into the watch lists.
    /// `v` and `NOT v` map to the adjacent codes `2v` and `2v + 1`.
    fn code(self) -> usize {
        self.0 as usize
    }
}
62
63/// A CNF formula over `num_vars` variables.
64#[derive(Clone, Debug, Default)]
65pub struct Cnf {
66    /// Number of variables; every [`Var`] used must be `< num_vars`.
67    pub num_vars: usize,
68    /// The clauses, each a disjunction of literals (the formula is their AND).
69    pub clauses: Vec<Vec<SatLit>>,
70}
71
72impl Cnf {
73    /// An empty formula over `num_vars` variables.
74    pub fn new(num_vars: usize) -> Self {
75        Cnf {
76            num_vars,
77            clauses: Vec::new(),
78        }
79    }
80    /// Append one clause (a disjunction of literals).
81    pub fn add_clause(&mut self, lits: Vec<SatLit>) {
82        self.clauses.push(lits);
83    }
84}
85
86// --- internal state --------------------------------------------------------
87
/// Why a variable holds its current value.
///
/// Conflict analysis and backtracking consult this. Analysis follows `Long`
/// reasons back to the clause that implied the value.
#[derive(Clone, Copy)]
enum Reason {
    /// Implied by unit propagation of the clause stored at this index.
    Long(usize),
    /// Forced by a single-literal clause.
    Unit,
    /// Chosen by the branching heuristic.
    /// It is also the placeholder for variables that were never assigned.
    Decision,
}
95
96/// One watched-literal entry: a clause plus a cached "other" literal so a true
97/// blocking literal lets us skip the clause entirely.
98#[derive(Clone, Copy)]
99struct Watch {
100    cref: usize,
101    blocking: SatLit,
102}
103
/// The result of a single CDCL transition, naming the direction the search
/// moved in.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Step {
    /// Every variable is assigned without conflict, so the formula is satisfied.
    Sat,
    /// A clause was falsified at decision level 0, so the formula is
    /// unsatisfiable.
    Unsat,
    /// Propagation found no conflict, and a new decision literal was assigned.
    Decided,
    /// A clause was falsified. It was analyzed, the search backjumped, and the
    /// learned clause was asserted.
    LearntFromConflict,
}
116
117/// The full CDCL search state: the assignment trail with decision levels, the
118/// clause database with two-watched-literal indices, VSIDS activities with phase
119/// saving, and a reusable `seen` scratch buffer for conflict analysis.
120struct Solver {
121    num_vars: usize,
122    clauses: Vec<Vec<SatLit>>, // originals + learned + blocking
123    watches: Vec<Vec<Watch>>, // indexed by literal code; a clause watching `w` lives in watches[!w]
124    assign: Vec<Option<bool>>, // per var
125    level: Vec<u32>,          // per var (valid when assigned)
126    reason: Vec<Reason>,      // per var (valid when assigned)
127    trail: Vec<SatLit>,
128    decisions: Vec<usize>, // trail index where each decision level starts
129    qhead: usize,
130    activity: Vec<f64>,
131    var_inc: f64,
132    polarity: Vec<bool>, // phase saving
133    seen: Vec<bool>,     // reusable scratch for analyze (invariant: all-false between calls)
134    ok: bool,            // false once the formula is known UNSAT
135}
136
137impl Solver {
138    /// Build a solver and load every clause of `cnf` under the empty assignment.
139    fn new(cnf: &Cnf) -> Self {
140        let n = cnf.num_vars;
141        let mut s = Solver {
142            num_vars: n,
143            clauses: Vec::new(),
144            watches: vec![Vec::new(); 2 * n],
145            assign: vec![None; n],
146            level: vec![0; n],
147            reason: vec![Reason::Decision; n],
148            trail: Vec::new(),
149            decisions: Vec::new(),
150            qhead: 0,
151            activity: vec![0.0; n],
152            var_inc: 1.0,
153            polarity: vec![false; n],
154            seen: vec![false; n],
155            ok: true,
156        };
157        for clause in &cnf.clauses {
158            s.add_clause(clause);
159        }
160        s
161    }
162
163    // -- assignment queries --
164
165    /// Is `l` currently assigned true? (Unassigned counts as neither true nor false.)
166    fn lit_is_true(&self, l: SatLit) -> bool {
167        self.assign[l.var() as usize] == Some(!l.is_negative())
168    }
169    /// Is `l` currently assigned false?
170    fn lit_is_false(&self, l: SatLit) -> bool {
171        self.assign[l.var() as usize] == Some(l.is_negative())
172    }
173    /// The current decision level (= number of open decisions).
174    fn current_level(&self) -> u32 {
175        self.decisions.len() as u32
176    }
177
178    // -- clause loading --
179
180    /// Register clause `cref` to be watched by literals `a` and `b`. A clause
181    /// watching a literal is stored under that literal's *negation's* code, so
182    /// it is revisited exactly when the watched literal becomes false.
183    fn watch(&mut self, cref: usize, a: SatLit, b: SatLit) {
184        self.watches[a.negate().code()].push(Watch { cref, blocking: b });
185        self.watches[b.negate().code()].push(Watch { cref, blocking: a });
186    }
187
188    /// Attach a clause under the *current* assignment. Both watched literals must
189    /// be non-false, or the clause is unit/conflicting and is handled directly.
190    /// This is what makes incremental clause addition (blocking clauses added
191    /// mid-search, at level 0) correct — naively watching `lits[0..2]` would break
192    /// the invariant when one is already false.
193    fn add_clause(&mut self, lits: &[SatLit]) {
194        if !self.ok {
195            return;
196        }
197        if lits.is_empty() {
198            self.ok = false;
199            return;
200        }
201        if lits.len() == 1 {
202            let l = lits[0];
203            if self.lit_is_false(l) {
204                self.ok = false;
205            } else if !self.lit_is_true(l) {
206                self.enqueue(l, Reason::Unit);
207            }
208            return;
209        }
210
211        // Find up to two non-false literals to watch.
212        let mut clause = lits.to_vec();
213        let mut first = None;
214        let mut second = None;
215        for (i, &l) in clause.iter().enumerate() {
216            if !self.lit_is_false(l) {
217                if first.is_none() {
218                    first = Some(i);
219                } else {
220                    second = Some(i);
221                    break;
222                }
223            }
224        }
225        let cref = self.clauses.len();
226        match (first, second) {
227            // Every literal is false under the current assignment → conflict.
228            (None, _) => self.ok = false,
229            // Exactly one non-false literal → the clause is unit; assert it.
230            (Some(a), None) => {
231                clause.swap(0, a);
232                self.watch(cref, clause[0], clause[1]);
233                let unit = clause[0];
234                self.clauses.push(clause);
235                if !self.lit_is_true(unit) {
236                    self.enqueue(unit, Reason::Long(cref));
237                }
238            }
239            // Two non-false literals → watch them (moved to positions 0 and 1).
240            (Some(a), Some(b)) => {
241                clause.swap(0, a);
242                clause.swap(1, b);
243                self.watch(cref, clause[0], clause[1]);
244                self.clauses.push(clause);
245            }
246        }
247    }
248
249    /// Assign `l` true at the current level with the given `reason`, and push it
250    /// onto the trail for propagation.
251    fn enqueue(&mut self, l: SatLit, reason: Reason) {
252        let v = l.var() as usize;
253        self.assign[v] = Some(!l.is_negative());
254        self.level[v] = self.current_level();
255        self.reason[v] = reason;
256        self.trail.push(l);
257    }
258
259    // -- propagation (two-watched-literal) --
260
261    /// Unit-propagate to a fixpoint. Returns the conflicting clause, if any.
262    fn propagate(&mut self) -> Option<usize> {
263        while self.qhead < self.trail.len() {
264            let p = self.trail[self.qhead];
265            self.qhead += 1;
266            if let Some(cref) = self.propagate_lit(p) {
267                return Some(cref);
268            }
269        }
270        None
271    }
272
273    /// Process the clauses watching `!p` after `p` became true.
274    fn propagate_lit(&mut self, p: SatLit) -> Option<usize> {
275        let fl = p.negate(); // the watched literal that just became false
276        let mut ws = core::mem::take(&mut self.watches[p.code()]);
277        let mut read = 0;
278        let mut write = 0;
279        let mut conflict = None;
280
281        while read < ws.len() {
282            let w = ws[read];
283            read += 1;
284
285            // A satisfied clause (true blocking literal) needs no inspection.
286            if self.lit_is_true(w.blocking) {
287                ws[write] = w;
288                write += 1;
289                continue;
290            }
291
292            let cref = w.cref;
293            if self.clauses[cref][0] == fl {
294                self.clauses[cref].swap(0, 1);
295            }
296            let other = self.clauses[cref][0];
297            let kept = Watch {
298                cref,
299                blocking: other,
300            };
301
302            if other != w.blocking && self.lit_is_true(other) {
303                ws[write] = kept;
304                write += 1;
305                continue;
306            }
307
308            // Try to move the watch to a non-false unwatched literal.
309            if let Some(repl) = self.find_replacement(cref, fl) {
310                self.watches[repl.negate().code()].push(kept);
311                continue; // watch left this list
312            }
313
314            // No replacement: keep watching `fl`; the clause is unit or conflicting.
315            ws[write] = kept;
316            write += 1;
317            if self.lit_is_false(other) {
318                while read < ws.len() {
319                    ws[write] = ws[read];
320                    write += 1;
321                    read += 1;
322                }
323                conflict = Some(cref);
324                break;
325            }
326            self.enqueue(other, Reason::Long(cref));
327        }
328
329        ws.truncate(write);
330        self.watches[p.code()] = ws;
331        conflict
332    }
333
334    /// Find a non-false literal in `clause[2..]`, swap it into the watched slot.
335    fn find_replacement(&mut self, cref: usize, fl: SatLit) -> Option<SatLit> {
336        let len = self.clauses[cref].len();
337        for k in 2..len {
338            let ck = self.clauses[cref][k];
339            if !self.lit_is_false(ck) {
340                self.clauses[cref][1] = ck;
341                self.clauses[cref][k] = fl;
342                return Some(ck);
343            }
344        }
345        None
346    }
347
348    // -- conflict analysis (1-UIP) --
349
350    /// VSIDS: raise variable `v`'s activity, rescaling all activities if it would
351    /// overflow `f64`'s comfortable range.
352    fn bump(&mut self, v: usize) {
353        self.activity[v] += self.var_inc;
354        if self.activity[v] > 1e100 {
355            for a in &mut self.activity {
356                *a *= 1e-100;
357            }
358            self.var_inc *= 1e-100;
359        }
360    }
361
362    /// Learn an asserting clause from `conflict` and return (clause, backjump level).
363    /// Uses the reusable `seen` buffer and restores it to all-false on exit.
364    fn analyze(&mut self, conflict: usize) -> (Vec<SatLit>, u32) {
365        let cur_level = self.current_level();
366        let mut learned: Vec<SatLit> = vec![SatLit(0)]; // slot 0 = asserting literal
367        let mut touched: Vec<Var> = Vec::new();
368        let mut counter = 0usize;
369        let mut idx = self.trail.len();
370        let mut p: Option<SatLit> = None;
371        let mut confl = conflict;
372
373        loop {
374            let start = if p.is_some() { 1 } else { 0 }; // a reason clause has p at index 0
375            for j in start..self.clauses[confl].len() {
376                let q = self.clauses[confl][j];
377                let v = q.var() as usize;
378                if !self.seen[v] && self.level[v] > 0 {
379                    self.seen[v] = true;
380                    touched.push(v as Var);
381                    self.bump(v);
382                    if self.level[v] == cur_level {
383                        counter += 1;
384                    } else {
385                        learned.push(q);
386                    }
387                }
388            }
389            // The most recently assigned `seen` literal on the trail.
390            loop {
391                idx -= 1;
392                if self.seen[self.trail[idx].var() as usize] {
393                    break;
394                }
395            }
396            let lit = self.trail[idx];
397            self.seen[lit.var() as usize] = false;
398            counter -= 1;
399            p = Some(lit);
400            if counter == 0 {
401                break;
402            }
403            confl = match self.reason[lit.var() as usize] {
404                Reason::Long(c) => c,
405                _ => unreachable!("a resolved current-level literal must have a clause reason"),
406            };
407        }
408        learned[0] = p.unwrap().negate();
409
410        let backjump = self.assertion_level(&mut learned);
411        self.var_inc *= 1.0 / 0.95; // VSIDS decay
412
413        for v in touched {
414            self.seen[v as usize] = false; // restore the scratch buffer
415        }
416        (learned, backjump)
417    }
418
419    /// Move the highest-level non-asserting literal to index 1 and return its
420    /// level (the level to backjump to), or 0 for a unit clause.
421    fn assertion_level(&self, learned: &mut [SatLit]) -> u32 {
422        if learned.len() == 1 {
423            return 0;
424        }
425        let mut max_i = 1;
426        let mut max_l = self.level[learned[1].var() as usize];
427        for (i, &lit) in learned.iter().enumerate().skip(2) {
428            let l = self.level[lit.var() as usize];
429            if l > max_l {
430                max_l = l;
431                max_i = i;
432            }
433        }
434        learned.swap(1, max_i);
435        max_l
436    }
437
438    /// Undo assignments above `level`, saving each unset variable's phase for
439    /// later reuse, and rewind the propagation queue to that level.
440    fn backtrack(&mut self, level: u32) {
441        if self.current_level() <= level {
442            return;
443        }
444        let new_len = self.decisions[level as usize];
445        for i in new_len..self.trail.len() {
446            let v = self.trail[i].var() as usize;
447            self.polarity[v] = self.assign[v] == Some(true);
448            self.assign[v] = None;
449        }
450        self.trail.truncate(new_len);
451        self.decisions.truncate(level as usize);
452        self.qhead = new_len;
453    }
454
455    /// Install a freshly learned clause and enqueue its asserting literal.
456    fn learn(&mut self, learned: Vec<SatLit>) {
457        if learned.len() == 1 {
458            self.enqueue(learned[0], Reason::Unit);
459        } else {
460            let cref = self.clauses.len();
461            self.watch(cref, learned[0], learned[1]);
462            let assert_lit = learned[0];
463            self.clauses.push(learned);
464            self.enqueue(assert_lit, Reason::Long(cref));
465        }
466    }
467
468    // -- decisions --
469
470    /// Choose the next decision: the unassigned variable with the highest VSIDS
471    /// activity, using its saved phase. `None` means all variables are assigned.
472    fn pick_branch(&self) -> Option<SatLit> {
473        let mut best: Option<usize> = None;
474        let mut best_act = -1.0;
475        for v in 0..self.num_vars {
476            if self.assign[v].is_none() && self.activity[v] > best_act {
477                best_act = self.activity[v];
478                best = Some(v);
479            }
480        }
481        best.map(|v| SatLit::new(v as Var, self.polarity[v]))
482    }
483
484    // -- the state machine --
485
486    /// Perform one CDCL transition.
487    fn step(&mut self) -> Step {
488        if let Some(cref) = self.propagate() {
489            if self.decisions.is_empty() {
490                return Step::Unsat;
491            }
492            let (learned, backjump) = self.analyze(cref);
493            self.backtrack(backjump);
494            self.learn(learned);
495            Step::LearntFromConflict
496        } else {
497            match self.pick_branch() {
498                None => Step::Sat,
499                Some(lit) => {
500                    self.decisions.push(self.trail.len());
501                    self.enqueue(lit, Reason::Decision);
502                    Step::Decided
503                }
504            }
505        }
506    }
507
508    /// Drive steps until SAT or UNSAT. Re-entrant: after [`Solver::block`] adds a
509    /// clause and resets to level 0, calling this again continues the search.
510    fn search(&mut self) -> bool {
511        if !self.ok {
512            return false;
513        }
514        loop {
515            match self.step() {
516                Step::Sat => return true,
517                Step::Unsat => {
518                    self.ok = false;
519                    return false;
520                }
521                _ => {}
522            }
523        }
524    }
525
526    /// Snapshot the assignment as `var -> bool` (any still-unassigned variable,
527    /// possible when it is unconstrained, defaults to false).
528    fn model(&self) -> Vec<bool> {
529        self.assign.iter().map(|a| a.unwrap_or(false)).collect()
530    }
531
532    /// Forbid the current `model`'s projection, then reset to level 0 so the next
533    /// [`Solver::search`] finds a different model. Returns `false` if the
534    /// projection is empty (there is only one model to report).
535    fn block(&mut self, project: &[Var], model: &[bool]) -> bool {
536        if project.is_empty() {
537            return false;
538        }
539        let block: Vec<SatLit> = project
540            .iter()
541            .map(|&v| {
542                if model[v as usize] {
543                    SatLit::negative(v)
544                } else {
545                    SatLit::positive(v)
546                }
547            })
548            .collect();
549        self.backtrack(0);
550        self.add_clause(&block);
551        true
552    }
553}
554
555// --- public API ------------------------------------------------------------
556
557/// Solve a CNF. Returns a full model (`var -> bool`) or `None` if unsatisfiable.
558pub fn solve(cnf: &Cnf) -> Option<Vec<bool>> {
559    let mut s = Solver::new(cnf);
560    if s.search() { Some(s.model()) } else { None }
561}
562
563/// A lazy iterator over the models of a CNF, distinct on the `project` variables.
564/// Solving is **incremental**: each step adds a blocking clause and continues
565/// from the existing solver state instead of restarting from scratch.
566pub struct Models {
567    solver: Solver,
568    project: Vec<Var>,
569    done: bool,
570}
571
572impl Iterator for Models {
573    type Item = Vec<bool>;
574
575    fn next(&mut self) -> Option<Vec<bool>> {
576        if self.done {
577            return None;
578        }
579        if !self.solver.search() {
580            self.done = true;
581            return None;
582        }
583        let model = self.solver.model();
584        if !self.solver.block(&self.project, &model) {
585            self.done = true;
586        }
587        Some(model)
588    }
589}
590
591/// Lazily enumerate all models of `cnf`, distinct over `project`.
592pub fn all_models(cnf: &Cnf, project: Vec<Var>) -> Models {
593    Models {
594        solver: Solver::new(cnf),
595        project,
596        done: false,
597    }
598}
599
600/// Up to `limit` models, distinct over `project` (eagerly collected).
601pub fn models(cnf: &Cnf, project: &[Var], limit: usize) -> Vec<Vec<bool>> {
602    all_models(cnf, project.to_vec()).take(limit).collect()
603}
604
605/// Count distinct models projected onto `project`, up to `limit`.
606pub fn models_upto(cnf: &Cnf, project: &[Var], limit: usize) -> usize {
607    all_models(cnf, project.to_vec()).take(limit).count()
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Does `model` satisfy every clause of `cnf`?
    fn satisfies(cnf: &Cnf, model: &[bool]) -> bool {
        cnf.clauses.iter().all(|clause| {
            clause
                .iter()
                .any(|&lit| model[lit.var() as usize] != lit.is_negative())
        })
    }

    /// Count the models of `cnf` by checking all `2^num_vars` assignments.
    fn brute_force_count(cnf: &Cnf) -> usize {
        (0u32..(1u32 << cnf.num_vars))
            .filter(|&bits| {
                let model: Vec<bool> = (0..cnf.num_vars)
                    .map(|v| ((bits >> v) & 1) == 1)
                    .collect();
                satisfies(cnf, &model)
            })
            .count()
    }

    /// Minimal deterministic xorshift64 generator. It keeps the randomized test
    /// reproducible and free of external crates.
    struct XorShift(u64);

    impl XorShift {
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        /// A value in `0..n` (slightly biased, which does not matter here).
        fn below(&mut self, n: u64) -> u64 {
            self.next_u64() % n
        }
    }

    #[test]
    fn trivial_sat() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        assert!(solve(&c).is_some());
    }

    #[test]
    fn unit_contradiction_unsat() {
        let mut c = Cnf::new(1);
        c.add_clause(vec![SatLit::positive(0)]);
        c.add_clause(vec![SatLit::negative(0)]);
        assert!(solve(&c).is_none());
    }

    #[test]
    fn all_four_combos_excluded_is_unsat() {
        let mut c = Cnf::new(2);
        let (a, b) = (0u32, 1u32);
        c.add_clause(vec![SatLit::positive(a), SatLit::positive(b)]);
        c.add_clause(vec![SatLit::negative(a), SatLit::positive(b)]);
        c.add_clause(vec![SatLit::positive(a), SatLit::negative(b)]);
        c.add_clause(vec![SatLit::negative(a), SatLit::negative(b)]);
        assert!(solve(&c).is_none());
    }

    #[test]
    fn forced_chain_has_unique_model() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::negative(0), SatLit::positive(1)]);
        c.add_clause(vec![SatLit::positive(0)]);
        let m = solve(&c).unwrap();
        assert!(m[0] && m[1]);
        assert_eq!(models_upto(&c, &[0, 1], 5), 1);
    }

    #[test]
    fn or_clause_has_three_models() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        assert_eq!(models_upto(&c, &[0, 1], 10), 3);
    }

    #[test]
    fn lazy_models_iterator_is_incremental() {
        // (a∨b) has 3 models; the iterator yields them lazily one at a time.
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        let first_two: Vec<_> = all_models(&c, vec![0, 1]).take(2).collect();
        assert_eq!(first_two.len(), 2);
        assert_ne!(first_two[0], first_two[1]);
        assert_eq!(all_models(&c, vec![0, 1]).count(), 3);
    }

    #[test]
    fn larger_random_like_sat_is_solved() {
        let mut c = Cnf::new(5);
        let l = |v: u32, p: bool| SatLit::new(v, p);
        c.add_clause(vec![l(0, true), l(1, true), l(2, false)]);
        c.add_clause(vec![l(0, false), l(2, true), l(3, true)]);
        c.add_clause(vec![l(1, false), l(3, false), l(4, true)]);
        c.add_clause(vec![l(2, false), l(4, false)]);
        c.add_clause(vec![l(0, true), l(4, true)]);
        let m = solve(&c).expect("sat");
        assert!(satisfies(&c, &m));
    }

    #[test]
    fn empty_formula_has_exactly_one_model() {
        let c = Cnf::new(0);
        assert_eq!(solve(&c), Some(vec![]));
        assert_eq!(models_upto(&c, &[], 5), 1);
    }

    #[test]
    fn empty_clause_is_unsat() {
        let mut c = Cnf::new(1);
        c.add_clause(vec![]);
        assert!(solve(&c).is_none());
        assert_eq!(models_upto(&c, &[0], 5), 0);
    }

    #[test]
    fn pigeonhole_three_into_two_is_unsat() {
        // Variable 2*i + h means "pigeon i sits in hole h". This needs real
        // decisions, conflict analysis and backjumping, not just level-0 units.
        let var = |i: u32, h: u32| 2 * i + h;
        let mut c = Cnf::new(6);
        for i in 0..3 {
            c.add_clause(vec![SatLit::positive(var(i, 0)), SatLit::positive(var(i, 1))]);
        }
        for h in 0..2 {
            for i in 0..3 {
                for j in (i + 1)..3 {
                    c.add_clause(vec![SatLit::negative(var(i, h)), SatLit::negative(var(j, h))]);
                }
            }
        }
        assert!(solve(&c).is_none());
        assert_eq!(models_upto(&c, &[0, 1, 2, 3, 4, 5], 10), 0);
    }

    #[test]
    fn projection_counts_distinct_projected_models() {
        // (a ∨ b) has 3 full models but only 2 distinct values of `a`.
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        assert_eq!(models_upto(&c, &[0], 10), 2);
        // With an empty projection, exactly one model is reported.
        assert_eq!(models_upto(&c, &[], 10), 1);
    }

    #[test]
    fn repeated_and_tautological_literals_are_handled() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(0)]); // effectively the unit `a`
        c.add_clause(vec![SatLit::positive(1), SatLit::negative(1)]); // tautology
        let m = solve(&c).expect("sat");
        assert!(m[0]);
        assert_eq!(models_upto(&c, &[0, 1], 10), 2);
    }

    #[test]
    fn model_counts_match_brute_force() {
        let mut rng = XorShift(0x9E37_79B9_7F4A_7C15);
        for _ in 0..200 {
            let num_vars = 1 + rng.below(6) as usize;
            let mut c = Cnf::new(num_vars);
            for _ in 0..rng.below(14) {
                let width = 1 + rng.below(3);
                let clause: Vec<SatLit> = (0..width)
                    .map(|_| {
                        let var = rng.below(num_vars as u64) as Var;
                        SatLit::new(var, rng.below(2) == 0)
                    })
                    .collect();
                c.add_clause(clause);
            }

            let expected = brute_force_count(&c);
            let all_vars: Vec<Var> = (0..num_vars as Var).collect();
            let found = models(&c, &all_vars, usize::MAX);
            assert_eq!(found.len(), expected);
            for (i, m) in found.iter().enumerate() {
                assert!(satisfies(&c, m), "reported a non-model");
                assert!(!found[..i].contains(m), "reported a model twice");
            }
            assert_eq!(solve(&c).is_some(), expected > 0);
        }
    }
}