1extern crate alloc;
24
25use alloc::vec;
26use alloc::vec::Vec;
27
/// Index of a propositional variable (zero-based).
pub type Var = u32;

/// A literal: a variable paired with a polarity.
///
/// Packed as `(var << 1) | negated`, so a literal and its negation differ only
/// in the lowest bit, and the packed value can index per-literal tables of
/// length `2 * num_vars`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct SatLit(u32);

impl SatLit {
    /// Largest variable index that fits in the packed encoding (`u32::MAX >> 1`).
    pub const MAX_VAR: Var = 0x7FFF_FFFF;

    /// Builds the literal of `var`, positive when `positive` is true.
    ///
    /// # Panics
    ///
    /// Panics if `var > SatLit::MAX_VAR`. Without this check the shift would
    /// silently discard the top bit and yield a literal of a different variable.
    pub fn new(var: Var, positive: bool) -> Self {
        assert!(var <= Self::MAX_VAR, "SatLit variable index out of range");
        SatLit((var << 1) | (!positive as u32))
    }

    /// The positive literal `var`.
    pub fn positive(var: Var) -> Self {
        Self::new(var, true)
    }

    /// The negative literal `¬var`.
    pub fn negative(var: Var) -> Self {
        Self::new(var, false)
    }

    /// The variable this literal refers to.
    pub fn var(self) -> Var {
        self.0 >> 1
    }

    /// Whether this is the negated form of its variable.
    pub fn is_negative(self) -> bool {
        self.0 & 1 == 1
    }

    /// The complementary literal (same variable, opposite polarity).
    pub fn negate(self) -> SatLit {
        SatLit(self.0 ^ 1)
    }

    /// Dense index of this literal (`2 * var + negated`), used to address
    /// per-literal tables such as watch lists.
    fn code(self) -> usize {
        self.0 as usize
    }
}
67
68#[derive(Clone, Debug, Default)]
70pub struct Cnf {
71 pub num_vars: usize,
73 pub clauses: Vec<Vec<SatLit>>,
75}
76
77impl Cnf {
78 pub fn new(num_vars: usize) -> Self {
80 Cnf {
81 num_vars,
82 clauses: Vec::new(),
83 }
84 }
85 pub fn add_clause(&mut self, lits: Vec<SatLit>) {
87 self.clauses.push(lits);
88 }
89}
90
/// Why a variable holds its current value.
#[derive(Clone, Copy)]
enum Reason {
    /// Chosen by the search: a branching decision or an assumption.
    Decision,
    /// Asserted by a single-literal clause (input or learned).
    Unit,
    /// Implied by the clause stored at this index of `Solver::clauses`.
    Long(usize),
}
100
101#[derive(Clone, Copy)]
104struct Watch {
105 cref: usize,
106 blocking: SatLit,
107}
108
109enum Decision {
111 Propagated,
113 Sat,
115 UnsatCore(Vec<SatLit>),
118}
119
120struct Solver {
124 num_vars: usize,
125 clauses: Vec<Vec<SatLit>>, watches: Vec<Vec<Watch>>, assign: Vec<Option<bool>>, level: Vec<u32>, reason: Vec<Reason>, trail: Vec<SatLit>,
131 decisions: Vec<usize>, qhead: usize,
133 activity: Vec<f64>,
134 var_inc: f64,
135 polarity: Vec<bool>, seen: Vec<bool>, ok: bool, assumptions: Vec<SatLit>,
142}
143
144impl Solver {
145 fn new(cnf: &Cnf) -> Self {
147 let n = cnf.num_vars;
148 let mut s = Solver {
149 num_vars: n,
150 clauses: Vec::new(),
151 watches: vec![Vec::new(); 2 * n],
152 assign: vec![None; n],
153 level: vec![0; n],
154 reason: vec![Reason::Decision; n],
155 trail: Vec::new(),
156 decisions: Vec::new(),
157 qhead: 0,
158 activity: vec![0.0; n],
159 var_inc: 1.0,
160 polarity: vec![false; n],
161 seen: vec![false; n],
162 ok: true,
163 assumptions: Vec::new(),
164 };
165 for clause in &cnf.clauses {
166 s.add_clause(clause);
167 }
168 s
169 }
170
171 fn lit_is_true(&self, l: SatLit) -> bool {
175 self.assign[l.var() as usize] == Some(!l.is_negative())
176 }
177 fn lit_is_false(&self, l: SatLit) -> bool {
179 self.assign[l.var() as usize] == Some(l.is_negative())
180 }
181 fn current_level(&self) -> u32 {
183 self.decisions.len() as u32
184 }
185
186 fn watch(&mut self, cref: usize, a: SatLit, b: SatLit) {
192 self.watches[a.negate().code()].push(Watch { cref, blocking: b });
193 self.watches[b.negate().code()].push(Watch { cref, blocking: a });
194 }
195
    /// Loads one clause into the solver; nothing is propagated here.
    ///
    /// - Does nothing once the solver is already known unsatisfiable (`!self.ok`).
    /// - An empty clause, or one whose literals are all currently false, sets
    ///   `ok = false`.
    /// - A single literal is not stored. It is enqueued with `Reason::Unit`
    ///   unless already true; if already false, `ok = false` is set instead.
    /// - Otherwise the clause is stored and watched on positions 0 and 1:
    ///   - With at least two non-false literals, those two are moved into
    ///     positions 0 and 1.
    ///   - With exactly one, it is moved to position 0 and enqueued with this
    ///     clause as its reason.
    ///
    /// NOTE(review): in the one-non-false case position 1 holds an already
    /// false literal. That is only sound if that assignment is never undone. It
    /// holds for the callers here (`Solver::new`, and `block` after
    /// `backtrack(0)`), which run at level 0; confirm before calling this at a
    /// higher level.
    fn add_clause(&mut self, lits: &[SatLit]) {
        if !self.ok {
            return;
        }
        if lits.is_empty() {
            self.ok = false;
            return;
        }
        if lits.len() == 1 {
            let l = lits[0];
            if self.lit_is_false(l) {
                self.ok = false;
            } else if !self.lit_is_true(l) {
                self.enqueue(l, Reason::Unit);
            }
            return;
        }

        let mut clause = lits.to_vec();
        // Locate the first two literals that are not currently false.
        let mut first = None;
        let mut second = None;
        for (i, &l) in clause.iter().enumerate() {
            if !self.lit_is_false(l) {
                if first.is_none() {
                    first = Some(i);
                } else {
                    second = Some(i);
                    break;
                }
            }
        }
        let cref = self.clauses.len();
        match (first, second) {
            // Every literal is false: the clause is already violated.
            (None, _) => self.ok = false,
            // Exactly one non-false literal: this clause implies it.
            (Some(a), None) => {
                clause.swap(0, a);
                self.watch(cref, clause[0], clause[1]);
                let unit = clause[0];
                self.clauses.push(clause);
                if !self.lit_is_true(unit) {
                    self.enqueue(unit, Reason::Long(cref));
                }
            }
            // Two non-false literals become the watched positions 0 and 1.
            // Since b > a, the second swap never touches position 0.
            (Some(a), Some(b)) => {
                clause.swap(0, a);
                clause.swap(1, b);
                self.watch(cref, clause[0], clause[1]);
                self.clauses.push(clause);
            }
        }
    }
256
257 fn enqueue(&mut self, l: SatLit, reason: Reason) {
260 let v = l.var() as usize;
261 self.assign[v] = Some(!l.is_negative());
262 self.level[v] = self.current_level();
263 self.reason[v] = reason;
264 self.trail.push(l);
265 }
266
267 fn propagate(&mut self) -> Option<usize> {
271 while self.qhead < self.trail.len() {
272 let p = self.trail[self.qhead];
273 self.qhead += 1;
274 if let Some(cref) = self.propagate_lit(p) {
275 return Some(cref);
276 }
277 }
278 None
279 }
280
    /// Visits every clause watching `¬p` now that `p` has become true.
    ///
    /// Each watcher ends up in one of three states:
    /// - kept in `p`'s list, because the clause is satisfied or no replacement
    ///   literal exists;
    /// - moved to the list of a newly found non-false literal;
    /// - kept while the clause's other watched literal is enqueued as implied.
    ///
    /// Returns the index of a clause whose literals are all false, if one is
    /// met. Any watchers not yet visited are then kept unchanged.
    fn propagate_lit(&mut self, p: SatLit) -> Option<usize> {
        let fl = p.negate();
        // Take the list out so clauses and other watch lists can be mutated
        // while iterating; it is put back at the end.
        let mut ws = core::mem::take(&mut self.watches[p.code()]);
        // Watchers are compacted in place: `read` scans, `ws[..write]` is kept.
        let mut read = 0;
        let mut write = 0;
        let mut conflict = None;

        while read < ws.len() {
            let w = ws[read];
            read += 1;

            // Blocker is true: the clause is satisfied; keep without reading it.
            if self.lit_is_true(w.blocking) {
                ws[write] = w;
                write += 1;
                continue;
            }

            let cref = w.cref;
            // Normalize so the false literal `fl` sits at position 1 and the
            // other watched literal at position 0.
            if self.clauses[cref][0] == fl {
                self.clauses[cref].swap(0, 1);
            }
            let other = self.clauses[cref][0];
            let kept = Watch {
                cref,
                blocking: other,
            };

            // Other watched literal is true: satisfied; refresh the blocker.
            if other != w.blocking && self.lit_is_true(other) {
                ws[write] = kept;
                write += 1;
                continue;
            }

            // A non-false replacement was swapped into position 1: the watcher
            // moves to that literal's list and leaves `p`'s list.
            if let Some(repl) = self.find_replacement(cref, fl) {
                self.watches[repl.negate().code()].push(kept);
                continue;
            }

            // No replacement: the clause is either unit on `other` or conflicting.
            ws[write] = kept;
            write += 1;
            if self.lit_is_false(other) {
                // Every literal is false: keep the unvisited watchers and stop.
                while read < ws.len() {
                    ws[write] = ws[read];
                    write += 1;
                    read += 1;
                }
                conflict = Some(cref);
                break;
            }
            // `other` is the clause's only non-false literal, so it is implied.
            self.enqueue(other, Reason::Long(cref));
        }

        ws.truncate(write);
        self.watches[p.code()] = ws;
        conflict
    }
341
342 fn find_replacement(&mut self, cref: usize, fl: SatLit) -> Option<SatLit> {
344 let len = self.clauses[cref].len();
345 for k in 2..len {
346 let ck = self.clauses[cref][k];
347 if !self.lit_is_false(ck) {
348 self.clauses[cref][1] = ck;
349 self.clauses[cref][k] = fl;
350 return Some(ck);
351 }
352 }
353 None
354 }
355
356 fn bump(&mut self, v: usize) {
361 self.activity[v] += self.var_inc;
362 if self.activity[v] > 1e100 {
363 for a in &mut self.activity {
364 *a *= 1e-100;
365 }
366 self.var_inc *= 1e-100;
367 }
368 }
369
    /// Derives a learned clause from the conflicting clause `conflict`.
    ///
    /// It resolves backwards along the trail until exactly one literal of the
    /// current decision level remains (the first unique implication point).
    ///
    /// Returns the learned clause and the backjump level from
    /// `assertion_level`. In the clause, position 0 holds the negated UIP and
    /// position 1 the literal of highest remaining level. Every variable met
    /// (at level > 0) gets its activity bumped, and `var_inc` is grown
    /// afterwards.
    fn analyze(&mut self, conflict: usize) -> (Vec<SatLit>, u32) {
        let cur_level = self.current_level();
        // Slot 0 is a placeholder for the asserting literal, filled after the loop.
        let mut learned: Vec<SatLit> = vec![SatLit(0)];
        // Every variable whose `seen` flag gets set, so all can be cleared at the end.
        let mut touched: Vec<Var> = Vec::new();
        // Current-level literals marked but not yet resolved away.
        let mut counter = 0usize;
        let mut idx = self.trail.len();
        let mut p: Option<SatLit> = None;
        let mut confl = conflict;

        loop {
            // For a reason clause (p is set), position 0 is the implied literal
            // being resolved on, so it is skipped; the conflict clause is read whole.
            let start = if p.is_some() { 1 } else { 0 };
            for j in start..self.clauses[confl].len() {
                let q = self.clauses[confl][j];
                let v = q.var() as usize;
                // Level-0 literals are permanently false and are left out.
                if !self.seen[v] && self.level[v] > 0 {
                    self.seen[v] = true;
                    touched.push(v as Var);
                    self.bump(v);
                    if self.level[v] == cur_level {
                        counter += 1;
                    } else {
                        learned.push(q);
                    }
                }
            }
            // Walk the trail backwards to the most recent marked variable.
            loop {
                idx -= 1;
                if self.seen[self.trail[idx].var() as usize] {
                    break;
                }
            }
            let lit = self.trail[idx];
            self.seen[lit.var() as usize] = false;
            counter -= 1;
            p = Some(lit);
            // `lit` is the last current-level literal left: the first UIP.
            if counter == 0 {
                break;
            }
            confl = match self.reason[lit.var() as usize] {
                Reason::Long(c) => c,
                _ => unreachable!("a resolved current-level literal must have a clause reason"),
            };
        }
        learned[0] = p.unwrap().negate();

        let backjump = self.assertion_level(&mut learned);
        // A larger increment makes future bumps outweigh past ones (activity decay).
        self.var_inc *= 1.0 / 0.95;
        // Clear every mark set above so `seen` is clean for the next call.
        for v in touched {
            self.seen[v as usize] = false;
        }
        (learned, backjump)
    }
426
427 fn assertion_level(&self, learned: &mut [SatLit]) -> u32 {
430 if learned.len() == 1 {
431 return 0;
432 }
433 let mut max_i = 1;
434 let mut max_l = self.level[learned[1].var() as usize];
435 for (i, &lit) in learned.iter().enumerate().skip(2) {
436 let l = self.level[lit.var() as usize];
437 if l > max_l {
438 max_l = l;
439 max_i = i;
440 }
441 }
442 learned.swap(1, max_i);
443 max_l
444 }
445
    /// Collects the assumptions responsible for `true_lit` being true.
    ///
    /// `decide` calls this with the negation of an assumption it found
    /// already false.
    ///
    /// Returns `¬true_lit` (that falsified assumption) first. It is followed by
    /// every decision literal at levels `1..=assumptions.len()` that is reached
    /// by following reason clauses back from `true_lit`. At level 0 only
    /// `¬true_lit` is returned.
    fn analyze_final(&mut self, true_lit: SatLit) -> Vec<SatLit> {
        let mut core = vec![true_lit.negate()];
        // At level 0 the clauses alone force `true_lit`; nothing else to add.
        if self.current_level() == 0 {
            return core;
        }
        let assn = self.assumptions.len() as u32;
        // Trail index where level 1 begins; level-0 entries are never visited.
        let start = self.decisions[0];
        self.seen[true_lit.var() as usize] = true;
        let mut touched = vec![true_lit.var()];
        let mut i = self.trail.len();
        // Sweep the trail newest-first, expanding marked variables into their antecedents.
        while i > start {
            i -= 1;
            let x = self.trail[i].var() as usize;
            if !self.seen[x] {
                continue;
            }
            self.seen[x] = false;
            match self.reason[x] {
                Reason::Decision => {
                    // Decisions at levels 1..=assn are placed by `decide`'s
                    // assumption loop, so they are assumption literals.
                    if self.level[x] > 0 && self.level[x] <= assn {
                        core.push(self.trail[i]);
                    }
                }
                // A unit reason has no antecedents to follow.
                Reason::Unit => {}
                Reason::Long(cr) => {
                    // Position 0 is x itself; mark the remaining antecedents above level 0.
                    for j in 1..self.clauses[cr].len() {
                        let v = self.clauses[cr][j].var();
                        if self.level[v as usize] > 0 && !self.seen[v as usize] {
                            self.seen[v as usize] = true;
                            touched.push(v);
                        }
                    }
                }
            }
        }
        // Clear any marks left set (e.g. on level-0 variables) for later analyses.
        for v in touched {
            self.seen[v as usize] = false;
        }
        core
    }
494
495 fn backtrack(&mut self, level: u32) {
498 if self.current_level() <= level {
499 return;
500 }
501 let new_len = self.decisions[level as usize];
502 for i in new_len..self.trail.len() {
503 let v = self.trail[i].var() as usize;
504 self.polarity[v] = self.assign[v] == Some(true);
505 self.assign[v] = None;
506 }
507 self.trail.truncate(new_len);
508 self.decisions.truncate(level as usize);
509 self.qhead = new_len;
510 }
511
512 fn learn(&mut self, learned: Vec<SatLit>) {
514 if learned.len() == 1 {
515 self.enqueue(learned[0], Reason::Unit);
516 } else {
517 let cref = self.clauses.len();
518 self.watch(cref, learned[0], learned[1]);
519 let assert_lit = learned[0];
520 self.clauses.push(learned);
521 self.enqueue(assert_lit, Reason::Long(cref));
522 }
523 }
524
525 fn pick_branch(&self) -> Option<SatLit> {
530 let mut best: Option<usize> = None;
531 let mut best_act = -1.0;
532 for v in 0..self.num_vars {
533 if self.assign[v].is_none() && self.activity[v] > best_act {
534 best_act = self.activity[v];
535 best = Some(v);
536 }
537 }
538 best.map(|v| SatLit::new(v as Var, self.polarity[v]))
539 }
540
541 fn decide(&mut self) -> Decision {
548 while (self.current_level() as usize) < self.assumptions.len() {
549 let p = self.assumptions[self.current_level() as usize];
550 if self.lit_is_true(p) {
551 self.decisions.push(self.trail.len()); } else if self.lit_is_false(p) {
553 return Decision::UnsatCore(self.analyze_final(p.negate()));
554 } else {
555 self.decisions.push(self.trail.len());
556 self.enqueue(p, Reason::Decision);
557 return Decision::Propagated;
558 }
559 }
560 match self.pick_branch() {
561 None => Decision::Sat,
562 Some(lit) => {
563 self.decisions.push(self.trail.len());
564 self.enqueue(lit, Reason::Decision);
565 Decision::Propagated
566 }
567 }
568 }
569
570 fn run(&mut self) -> Result<(), Vec<SatLit>> {
575 if !self.ok {
576 return Err(Vec::new());
577 }
578 loop {
579 if let Some(cref) = self.propagate() {
580 if self.current_level() == 0 {
581 self.ok = false;
582 return Err(Vec::new());
583 }
584 let (learned, backjump) = self.analyze(cref);
585 self.backtrack(backjump);
586 self.learn(learned);
587 } else {
588 match self.decide() {
589 Decision::Propagated => {}
590 Decision::Sat => return Ok(()),
591 Decision::UnsatCore(core) => return Err(core),
592 }
593 }
594 }
595 }
596
597 fn search(&mut self) -> bool {
600 self.run().is_ok()
601 }
602
603 fn model(&self) -> Vec<bool> {
606 self.assign.iter().map(|a| a.unwrap_or(false)).collect()
607 }
608
609 fn block(&mut self, project: &[Var], model: &[bool]) -> bool {
613 if project.is_empty() {
614 return false;
615 }
616 let block: Vec<SatLit> = project
617 .iter()
618 .map(|&v| {
619 if model[v as usize] {
620 SatLit::negative(v)
621 } else {
622 SatLit::positive(v)
623 }
624 })
625 .collect();
626 self.backtrack(0);
627 self.add_clause(&block);
628 true
629 }
630}
631
632#[derive(Clone, Debug, PartialEq, Eq)]
636pub enum Solved {
637 Sat(Vec<bool>),
639 Unsat(Vec<SatLit>),
643}
644
645pub fn solve_assuming(cnf: &Cnf, assumptions: &[SatLit]) -> Solved {
649 let mut s = Solver::new(cnf);
650 s.assumptions = assumptions.to_vec();
651 match s.run() {
652 Ok(()) => Solved::Sat(s.model()),
653 Err(core) => Solved::Unsat(core),
654 }
655}
656
657pub fn solve(cnf: &Cnf) -> Option<Vec<bool>> {
659 match solve_assuming(cnf, &[]) {
660 Solved::Sat(model) => Some(model),
661 Solved::Unsat(_) => None,
662 }
663}
664
665pub struct Models {
669 solver: Solver,
670 project: Vec<Var>,
671 done: bool,
672}
673
674impl Iterator for Models {
675 type Item = Vec<bool>;
676
677 fn next(&mut self) -> Option<Vec<bool>> {
678 if self.done {
679 return None;
680 }
681 if !self.solver.search() {
682 self.done = true;
683 return None;
684 }
685 let model = self.solver.model();
686 if !self.solver.block(&self.project, &model) {
687 self.done = true;
688 }
689 Some(model)
690 }
691}
692
693pub fn all_models(cnf: &Cnf, project: Vec<Var>) -> Models {
695 Models {
696 solver: Solver::new(cnf),
697 project,
698 done: false,
699 }
700}
701
702pub fn models(cnf: &Cnf, project: &[Var], limit: usize) -> Vec<Vec<bool>> {
704 all_models(cnf, project.to_vec()).take(limit).collect()
705}
706
707pub fn models_upto(cnf: &Cnf, project: &[Var], limit: usize) -> usize {
709 all_models(cnf, project.to_vec()).take(limit).count()
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn trivial_sat() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        assert!(solve(&c).is_some());
    }

    #[test]
    fn unit_contradiction_unsat() {
        let mut c = Cnf::new(1);
        c.add_clause(vec![SatLit::positive(0)]);
        c.add_clause(vec![SatLit::negative(0)]);
        assert!(solve(&c).is_none());
    }

    #[test]
    fn all_four_combos_excluded_is_unsat() {
        let mut c = Cnf::new(2);
        let (a, b) = (0u32, 1u32);
        c.add_clause(vec![SatLit::positive(a), SatLit::positive(b)]);
        c.add_clause(vec![SatLit::negative(a), SatLit::positive(b)]);
        c.add_clause(vec![SatLit::positive(a), SatLit::negative(b)]);
        c.add_clause(vec![SatLit::negative(a), SatLit::negative(b)]);
        assert!(solve(&c).is_none());
    }

    #[test]
    fn forced_chain_has_unique_model() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::negative(0), SatLit::positive(1)]);
        c.add_clause(vec![SatLit::positive(0)]);
        let m = solve(&c).unwrap();
        assert!(m[0] && m[1]);
        assert_eq!(models_upto(&c, &[0, 1], 5), 1);
    }

    #[test]
    fn or_clause_has_three_models() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        assert_eq!(models_upto(&c, &[0, 1], 10), 3);
    }

    #[test]
    fn lazy_models_iterator_is_incremental() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        let first_two: Vec<_> = all_models(&c, vec![0, 1]).take(2).collect();
        assert_eq!(first_two.len(), 2);
        assert_ne!(first_two[0], first_two[1]);
        assert_eq!(all_models(&c, vec![0, 1]).count(), 3);
    }

    #[test]
    fn assumption_forces_a_model() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        match solve_assuming(&c, &[SatLit::negative(0)]) {
            Solved::Sat(m) => {
                assert!(!m[0] && m[1]);
            }
            Solved::Unsat(_) => panic!("should be SAT under ¬a"),
        }
    }

    #[test]
    fn contradicted_assumptions_yield_a_sufficient_core() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::negative(0), SatLit::negative(1)]);
        let assumptions = [SatLit::positive(0), SatLit::positive(1)];
        match solve_assuming(&c, &assumptions) {
            Solved::Unsat(core) => {
                assert!(!core.is_empty());
                assert!(core.iter().all(|l| assumptions.contains(l)));
                // Adding the core as unit clauses must make the formula unsatisfiable.
                let mut cc = c.clone();
                for l in &core {
                    cc.add_clause(vec![*l]);
                }
                assert!(solve(&cc).is_none());
            }
            Solved::Sat(_) => panic!("a ∧ b violates (¬a ∨ ¬b)"),
        }
    }

    #[test]
    fn satisfiable_assumptions_round_trip() {
        let mut c = Cnf::new(3);
        c.add_clause(vec![
            SatLit::positive(0),
            SatLit::positive(1),
            SatLit::positive(2),
        ]);
        let assumptions = [SatLit::positive(0), SatLit::negative(2)];
        match solve_assuming(&c, &assumptions) {
            Solved::Sat(m) => {
                assert!(m[0] && !m[2]);
            }
            Solved::Unsat(_) => panic!("should be SAT"),
        }
    }

    #[test]
    fn larger_random_like_sat_is_solved() {
        let mut c = Cnf::new(5);
        let l = |v: u32, p: bool| SatLit::new(v, p);
        c.add_clause(vec![l(0, true), l(1, true), l(2, false)]);
        c.add_clause(vec![l(0, false), l(2, true), l(3, true)]);
        c.add_clause(vec![l(1, false), l(3, false), l(4, true)]);
        c.add_clause(vec![l(2, false), l(4, false)]);
        c.add_clause(vec![l(0, true), l(4, true)]);
        let m = solve(&c).expect("sat");
        for clause in &c.clauses {
            assert!(
                clause
                    .iter()
                    .any(|&lit| m[lit.var() as usize] != lit.is_negative())
            );
        }
    }

    #[test]
    fn literal_encoding_round_trips() {
        let pos = SatLit::positive(7);
        let neg = SatLit::negative(7);
        assert_eq!(pos.var(), 7);
        assert_eq!(neg.var(), 7);
        assert!(!pos.is_negative());
        assert!(neg.is_negative());
        assert_eq!(pos.negate(), neg);
        assert_eq!(neg.negate(), pos);
        assert_eq!(SatLit::new(7, true), pos);
    }

    #[test]
    fn empty_clause_is_unsat() {
        let mut c = Cnf::new(1);
        c.add_clause(Vec::new());
        assert!(solve(&c).is_none());
    }

    #[test]
    fn zero_variable_formula_has_empty_model() {
        assert_eq!(solve(&Cnf::new(0)), Some(Vec::new()));
    }

    #[test]
    fn projection_counts_distinct_projected_assignments() {
        let mut c = Cnf::new(3);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        // Projected onto variable 0 alone, only its two values can differ.
        assert_eq!(models_upto(&c, &[0], 10), 2);
        // An empty projection yields exactly one model.
        assert_eq!(models_upto(&c, &[], 10), 1);
        // `models` honours the limit.
        assert_eq!(models(&c, &[0, 1], 2).len(), 2);
    }

    #[test]
    fn contradictory_assumptions_both_appear_in_core() {
        let c = Cnf::new(1);
        let assumptions = [SatLit::positive(0), SatLit::negative(0)];
        match solve_assuming(&c, &assumptions) {
            Solved::Unsat(core) => {
                assert!(core.contains(&SatLit::positive(0)));
                assert!(core.contains(&SatLit::negative(0)));
            }
            Solved::Sat(_) => panic!("x ∧ ¬x cannot be satisfied"),
        }
    }

    #[test]
    fn three_pigeons_two_holes_is_unsat() {
        // Variable 2 * i + j means "pigeon i sits in hole j".
        let var = |i: u32, j: u32| 2 * i + j;
        let mut c = Cnf::new(6);
        for i in 0..3 {
            c.add_clause(vec![SatLit::positive(var(i, 0)), SatLit::positive(var(i, 1))]);
        }
        for j in 0..2 {
            for a in 0..3 {
                for b in (a + 1)..3 {
                    c.add_clause(vec![SatLit::negative(var(a, j)), SatLit::negative(var(b, j))]);
                }
            }
        }
        assert!(solve(&c).is_none());
        assert_eq!(models_upto(&c, &[0, 1, 2, 3, 4, 5], 5), 0);
    }
}