1extern crate alloc;
19
20use alloc::vec;
21use alloc::vec::Vec;
22
/// A propositional variable, identified by a zero-based index.
///
/// Literals pack the variable index into the upper 31 bits of a `u32`, so only
/// `0..=Var::MAX >> 1` can be represented as a [`SatLit`].
pub type Var = u32;

/// A literal: a variable together with a polarity.
///
/// Encoded as `(var << 1) | neg`, where `neg` is `1` for a negative literal. Negation is
/// therefore a single XOR, and the raw code can index per-literal tables directly.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct SatLit(u32);

impl SatLit {
    /// Largest variable index that survives the `var << 1` packing without losing bits.
    const MAX_VAR: Var = Var::MAX >> 1;

    /// Builds the literal for `var` with the given polarity (`true` = positive).
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `var` exceeds `Var::MAX >> 1`. Such an index would lose its
    /// top bit in the shift and silently alias a different, smaller variable.
    pub fn new(var: Var, positive: bool) -> Self {
        debug_assert!(
            var <= Self::MAX_VAR,
            "variable index {} does not fit in a SatLit",
            var
        );
        // The low bit is set for negative literals.
        SatLit((var << 1) | u32::from(!positive))
    }

    /// The positive literal `var`.
    pub fn positive(var: Var) -> Self {
        Self::new(var, true)
    }

    /// The negative literal `¬var`.
    pub fn negative(var: Var) -> Self {
        Self::new(var, false)
    }

    /// The variable this literal refers to.
    pub fn var(self) -> Var {
        self.0 >> 1
    }

    /// Whether this is the negated form of its variable.
    pub fn is_negative(self) -> bool {
        self.0 & 1 == 1
    }

    /// The complementary literal (same variable, opposite polarity).
    pub fn negate(self) -> SatLit {
        SatLit(self.0 ^ 1)
    }

    /// Dense index of this literal (`2 * var + neg`), used to address per-literal tables.
    fn code(self) -> usize {
        self.0 as usize
    }
}
62
63#[derive(Clone, Debug, Default)]
65pub struct Cnf {
66 pub num_vars: usize,
68 pub clauses: Vec<Vec<SatLit>>,
70}
71
72impl Cnf {
73 pub fn new(num_vars: usize) -> Self {
75 Cnf {
76 num_vars,
77 clauses: Vec::new(),
78 }
79 }
80 pub fn add_clause(&mut self, lits: Vec<SatLit>) {
82 self.clauses.push(lits);
83 }
84}
85
/// Why a variable holds its current value; consulted during conflict analysis.
#[derive(Clone, Copy)]
enum Reason {
    /// Chosen by the branching heuristic rather than implied.
    Decision,
    /// Forced by a single-literal clause; such assignments are made at decision level 0.
    Unit,
    /// Implied by the stored clause at this index, whose position 0 holds the implied literal.
    Long(usize),
}
95
96#[derive(Clone, Copy)]
99struct Watch {
100 cref: usize,
101 blocking: SatLit,
102}
103
/// Outcome of a single solver step.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Step {
    /// Propagation hit a conflict; a clause was learnt and the solver backjumped.
    LearntFromConflict,
    /// Propagation finished quietly and a new branching decision was made.
    Decided,
    /// Every variable is assigned and no conflict remains.
    Sat,
    /// A conflict occurred with no open decisions, so the formula is unsatisfiable.
    Unsat,
}
116
117struct Solver {
121 num_vars: usize,
122 clauses: Vec<Vec<SatLit>>, watches: Vec<Vec<Watch>>, assign: Vec<Option<bool>>, level: Vec<u32>, reason: Vec<Reason>, trail: Vec<SatLit>,
128 decisions: Vec<usize>, qhead: usize,
130 activity: Vec<f64>,
131 var_inc: f64,
132 polarity: Vec<bool>, seen: Vec<bool>, ok: bool, }
136
impl Solver {
    /// Builds a solver for `cnf` and loads every clause at decision level 0.
    ///
    /// Unit clauses are only enqueued here; propagation happens later, in `search`.
    /// If loading proves the formula unsatisfiable, `ok` is cleared.
    fn new(cnf: &Cnf) -> Self {
        let n = cnf.num_vars;
        let mut s = Solver {
            num_vars: n,
            clauses: Vec::new(),
            // One watch list per literal code: two literals per variable.
            watches: vec![Vec::new(); 2 * n],
            assign: vec![None; n],
            level: vec![0; n],
            reason: vec![Reason::Decision; n],
            trail: Vec::new(),
            decisions: Vec::new(),
            qhead: 0,
            activity: vec![0.0; n],
            var_inc: 1.0,
            polarity: vec![false; n],
            seen: vec![false; n],
            ok: true,
        };
        for clause in &cnf.clauses {
            s.add_clause(clause);
        }
        s
    }

    /// Whether literal `l` is currently assigned true.
    fn lit_is_true(&self, l: SatLit) -> bool {
        self.assign[l.var() as usize] == Some(!l.is_negative())
    }
    /// Whether literal `l` is currently assigned false.
    fn lit_is_false(&self, l: SatLit) -> bool {
        self.assign[l.var() as usize] == Some(l.is_negative())
    }
    /// Current decision level, i.e. the number of open decisions (level 0 holds facts).
    fn current_level(&self) -> u32 {
        self.decisions.len() as u32
    }

    /// Registers clause `cref` as watching literals `a` and `b`.
    ///
    /// A watch on `x` is filed under `¬x`, so it is visited when `¬x` becomes true, that is,
    /// when `x` becomes false. The other watched literal is stored as the blocking literal.
    fn watch(&mut self, cref: usize, a: SatLit, b: SatLit) {
        self.watches[a.negate().code()].push(Watch { cref, blocking: b });
        self.watches[b.negate().code()].push(Watch { cref, blocking: a });
    }

    /// Adds a clause to the solver.
    ///
    /// Called from `new` and from `block` (after `backtrack(0)`), so any literal that is
    /// already false here is false at level 0 and stays false. The watch choice below relies
    /// on that. Sets `ok = false` if the clause is empty or entirely false. Enqueues (without
    /// propagating) a literal the clause forces.
    fn add_clause(&mut self, lits: &[SatLit]) {
        if !self.ok {
            return;
        }
        if lits.is_empty() {
            self.ok = false;
            return;
        }
        if lits.len() == 1 {
            // Unit clauses are not stored; they become level-0 assignments directly.
            let l = lits[0];
            if self.lit_is_false(l) {
                self.ok = false;
            } else if !self.lit_is_true(l) {
                self.enqueue(l, Reason::Unit);
            }
            return;
        }

        let mut clause = lits.to_vec();
        // Locate the first two literals that are not false; they become the watches.
        let mut first = None;
        let mut second = None;
        for (i, &l) in clause.iter().enumerate() {
            if !self.lit_is_false(l) {
                if first.is_none() {
                    first = Some(i);
                } else {
                    second = Some(i);
                    break;
                }
            }
        }
        let cref = self.clauses.len();
        match (first, second) {
            // Every literal is false: the formula is unsatisfiable.
            (None, _) => self.ok = false,
            // Exactly one non-false literal: move it to index 0, watch it together with a false
            // literal (now at index 1), and imply it unless it is already true.
            (Some(a), None) => {
                clause.swap(0, a);
                self.watch(cref, clause[0], clause[1]);
                let unit = clause[0];
                self.clauses.push(clause);
                if !self.lit_is_true(unit) {
                    self.enqueue(unit, Reason::Long(cref));
                }
            }
            // Two non-false literals: move them to indices 0 and 1 (a < b, so the first swap
            // never disturbs index b) and watch them.
            (Some(a), Some(b)) => {
                clause.swap(0, a);
                clause.swap(1, b);
                self.watch(cref, clause[0], clause[1]);
                self.clauses.push(clause);
            }
        }
    }

    /// Makes `l` true at the current decision level with the given `reason` and appends it
    /// to the trail.
    ///
    /// Does not check the variable's current value; every caller in this impl passes a
    /// literal whose variable is unassigned.
    fn enqueue(&mut self, l: SatLit, reason: Reason) {
        let v = l.var() as usize;
        self.assign[v] = Some(!l.is_negative());
        self.level[v] = self.current_level();
        self.reason[v] = reason;
        self.trail.push(l);
    }

    /// Propagates every trail literal from `qhead` onward.
    ///
    /// Returns the index of the first conflicting clause found (leaving the rest of the
    /// queue unprocessed), or `None` once the queue is empty.
    fn propagate(&mut self) -> Option<usize> {
        while self.qhead < self.trail.len() {
            let p = self.trail[self.qhead];
            self.qhead += 1;
            if let Some(cref) = self.propagate_lit(p) {
                return Some(cref);
            }
        }
        None
    }

    /// Visits every clause watching `fl = ¬p`, now that `p` has become true.
    ///
    /// The watch list is taken out of `self.watches` and compacted in place: `read` scans it,
    /// and `write` marks the end of the watches being kept. For each clause, one of these
    /// happens:
    /// - the blocking literal is true: the watch is kept unchanged;
    /// - the other watched literal is true: the watch is kept with that literal as blocker;
    /// - a non-false replacement exists: the watch moves to the replacement's list;
    /// - otherwise the clause is unit: the other watched literal is enqueued;
    /// - or, if that literal is false, the clause is a conflict.
    ///
    /// On conflict the remaining watches are copied over unchanged and the clause index is
    /// returned.
    fn propagate_lit(&mut self, p: SatLit) -> Option<usize> {
        let fl = p.negate();
        let mut ws = core::mem::take(&mut self.watches[p.code()]);
        let mut read = 0;
        let mut write = 0;
        let mut conflict = None;

        while read < ws.len() {
            let w = ws[read];
            read += 1;

            // Fast path: the clause is already satisfied by its blocking literal.
            if self.lit_is_true(w.blocking) {
                ws[write] = w;
                write += 1;
                continue;
            }

            // Normalise so the falsified watched literal sits at index 1.
            let cref = w.cref;
            if self.clauses[cref][0] == fl {
                self.clauses[cref].swap(0, 1);
            }
            let other = self.clauses[cref][0];
            let kept = Watch {
                cref,
                blocking: other,
            };

            // The other watched literal satisfies the clause (skip re-testing the blocker,
            // which is already known not to be true).
            if other != w.blocking && self.lit_is_true(other) {
                ws[write] = kept;
                write += 1;
                continue;
            }

            // Move the watch to a non-false literal, dropping it from this list. The new list
            // is never `p`'s own list, because the replacement is not false and so cannot be `fl`.
            if let Some(repl) = self.find_replacement(cref, fl) {
                self.watches[repl.negate().code()].push(kept);
                continue;
            }

            // No replacement: the clause is unit or conflicting, and the watch stays here.
            ws[write] = kept;
            write += 1;
            if self.lit_is_false(other) {
                // Conflict: keep the not-yet-visited watches and stop.
                while read < ws.len() {
                    ws[write] = ws[read];
                    write += 1;
                    read += 1;
                }
                conflict = Some(cref);
                break;
            }
            // Unit: `other` is at index 0, which is where `analyze` expects the implied literal.
            self.enqueue(other, Reason::Long(cref));
        }

        ws.truncate(write);
        self.watches[p.code()] = ws;
        conflict
    }

    /// Searches positions `2..` of clause `cref` for a literal that is not false.
    ///
    /// If found, it is moved into watched position 1, `fl` (which `propagate_lit` has
    /// normalised to position 1) is moved to the vacated slot, and the new watched literal is
    /// returned. Otherwise returns `None` and leaves the clause unchanged.
    fn find_replacement(&mut self, cref: usize, fl: SatLit) -> Option<SatLit> {
        let len = self.clauses[cref].len();
        for k in 2..len {
            let ck = self.clauses[cref][k];
            if !self.lit_is_false(ck) {
                self.clauses[cref][1] = ck;
                self.clauses[cref][k] = fl;
                return Some(ck);
            }
        }
        None
    }

    /// Adds `var_inc` to variable `v`'s activity.
    ///
    /// If that pushes it above 1e100, every activity and `var_inc` itself are scaled by
    /// 1e-100 to keep the values within floating-point range.
    fn bump(&mut self, v: usize) {
        self.activity[v] += self.var_inc;
        if self.activity[v] > 1e100 {
            for a in &mut self.activity {
                *a *= 1e-100;
            }
            self.var_inc *= 1e-100;
        }
    }

    /// First-UIP conflict analysis starting from clause `conflict`.
    ///
    /// Resolves backwards along the trail until exactly one literal of the current decision
    /// level remains. Returns the learnt clause together with the backjump level (0 for a
    /// unit clause). The clause has the negation of that remaining literal at index 0 and,
    /// when longer than one, a literal of the highest remaining level at index 1.
    ///
    /// Every variable encountered at a level above 0 has its activity bumped, `var_inc` is
    /// grown afterwards, and `seen` is all-false again on return.
    fn analyze(&mut self, conflict: usize) -> (Vec<SatLit>, u32) {
        let cur_level = self.current_level();
        // Index 0 is a placeholder, overwritten with the asserting literal after the loop.
        let mut learned: Vec<SatLit> = vec![SatLit(0)];
        let mut touched: Vec<Var> = Vec::new();
        // Number of current-level variables marked but not yet resolved away.
        let mut counter = 0usize;
        // Trail scan position, moving from the newest assignment backwards.
        let mut idx = self.trail.len();
        // Literal most recently resolved on; `None` only while scanning the conflict clause.
        let mut p: Option<SatLit> = None;
        let mut confl = conflict;

        loop {
            // Reason clauses hold their implied literal (which is `p`) at index 0, so it is
            // skipped. The conflict clause is scanned in full.
            let start = if p.is_some() { 1 } else { 0 };
            for j in start..self.clauses[confl].len() {
                let q = self.clauses[confl][j];
                let v = q.var() as usize;
                // Level-0 assignments are permanent and are dropped from the learnt clause.
                if !self.seen[v] && self.level[v] > 0 {
                    self.seen[v] = true;
                    touched.push(v as Var);
                    self.bump(v);
                    if self.level[v] == cur_level {
                        // Resolved away later by walking the trail.
                        counter += 1;
                    } else {
                        // Lower-level literals go straight into the learnt clause.
                        learned.push(q);
                    }
                }
            }
            // Step back to the newest trail literal still marked. Because `counter > 0`, it is
            // a current-level one, since those all sit above lower-level entries on the trail.
            loop {
                idx -= 1;
                if self.seen[self.trail[idx].var() as usize] {
                    break;
                }
            }
            let lit = self.trail[idx];
            self.seen[lit.var() as usize] = false;
            counter -= 1;
            p = Some(lit);
            // `lit` is the only current-level literal left: the first UIP.
            if counter == 0 {
                break;
            }
            // A decision is always the last current-level literal reached, so it never gets
            // here. `Unit` reasons exist only at level 0, which is filtered out above.
            confl = match self.reason[lit.var() as usize] {
                Reason::Long(c) => c,
                _ => unreachable!("a resolved current-level literal must have a clause reason"),
            };
        }
        learned[0] = p.unwrap().negate();

        let backjump = self.assertion_level(&mut learned);
        // Growing the increment makes future bumps outweigh older ones.
        self.var_inc *= 1.0 / 0.95;
        for v in touched {
            self.seen[v as usize] = false;
        }
        (learned, backjump)
    }

    /// Moves the literal with the highest decision level among `learned[1..]` to index 1 and
    /// returns that level. Returns 0 for a unit clause.
    ///
    /// After backtracking to this level, `learned[0]` is the clause's only unassigned literal,
    /// and indices 0 and 1 are valid watches.
    fn assertion_level(&self, learned: &mut [SatLit]) -> u32 {
        if learned.len() == 1 {
            return 0;
        }
        let mut max_i = 1;
        let mut max_l = self.level[learned[1].var() as usize];
        for (i, &lit) in learned.iter().enumerate().skip(2) {
            let l = self.level[lit.var() as usize];
            if l > max_l {
                max_l = l;
                max_i = i;
            }
        }
        learned.swap(1, max_i);
        max_l
    }

    /// Undoes every assignment made above decision level `level`.
    ///
    /// Each unassigned variable's last value is saved as its preferred polarity. Does nothing
    /// if the solver is already at or below `level`. `level` and `reason` entries are left
    /// stale; they are only read for assigned variables.
    fn backtrack(&mut self, level: u32) {
        if self.current_level() <= level {
            return;
        }
        let new_len = self.decisions[level as usize];
        for i in new_len..self.trail.len() {
            let v = self.trail[i].var() as usize;
            self.polarity[v] = self.assign[v] == Some(true);
            self.assign[v] = None;
        }
        self.trail.truncate(new_len);
        self.decisions.truncate(level as usize);
        // Everything left on the trail has already been propagated.
        self.qhead = new_len;
    }

    /// Stores the learnt clause and enqueues its asserting literal `learned[0]`.
    ///
    /// Called right after backtracking to the clause's assertion level. Unit clauses are not
    /// stored; their literal becomes a level-0 fact.
    fn learn(&mut self, learned: Vec<SatLit>) {
        if learned.len() == 1 {
            self.enqueue(learned[0], Reason::Unit);
        } else {
            let cref = self.clauses.len();
            self.watch(cref, learned[0], learned[1]);
            let assert_lit = learned[0];
            self.clauses.push(learned);
            self.enqueue(assert_lit, Reason::Long(cref));
        }
    }

    /// Picks the unassigned variable with the highest activity (lowest index on ties),
    /// using its saved polarity. Returns `None` when every variable is assigned.
    ///
    /// NOTE(review): this is a linear scan over all variables on every decision.
    fn pick_branch(&self) -> Option<SatLit> {
        let mut best: Option<usize> = None;
        let mut best_act = -1.0;
        for v in 0..self.num_vars {
            if self.assign[v].is_none() && self.activity[v] > best_act {
                best_act = self.activity[v];
                best = Some(v);
            }
        }
        best.map(|v| SatLit::new(v as Var, self.polarity[v]))
    }

    /// Performs one CDCL step.
    ///
    /// Propagates first. On a conflict it either reports `Unsat` (no open decisions) or
    /// learns a clause and backjumps. Otherwise it opens a new decision level, or reports
    /// `Sat` when nothing is left to assign.
    fn step(&mut self) -> Step {
        if let Some(cref) = self.propagate() {
            if self.decisions.is_empty() {
                return Step::Unsat;
            }
            let (learned, backjump) = self.analyze(cref);
            self.backtrack(backjump);
            self.learn(learned);
            Step::LearntFromConflict
        } else {
            match self.pick_branch() {
                None => Step::Sat,
                Some(lit) => {
                    // Record where the new level starts on the trail before assigning.
                    self.decisions.push(self.trail.len());
                    self.enqueue(lit, Reason::Decision);
                    Step::Decided
                }
            }
        }
    }

    /// Runs steps until the formula is decided. Returns `true` if it is satisfiable.
    ///
    /// On success the satisfying assignment is left in place for `model`. On failure `ok` is
    /// cleared, so later calls return `false` immediately.
    fn search(&mut self) -> bool {
        if !self.ok {
            return false;
        }
        loop {
            match self.step() {
                Step::Sat => return true,
                Step::Unsat => {
                    self.ok = false;
                    return false;
                }
                _ => {}
            }
        }
    }

    /// Snapshot of the current assignment; unassigned variables read as `false`.
    fn model(&self) -> Vec<bool> {
        self.assign.iter().map(|a| a.unwrap_or(false)).collect()
    }

    /// Forbids `model`'s values on the `project` variables.
    ///
    /// Returns to level 0 and adds the clause "some projected variable differs". Returns
    /// `false`, changing nothing, when `project` is empty.
    fn block(&mut self, project: &[Var], model: &[bool]) -> bool {
        if project.is_empty() {
            return false;
        }
        let block: Vec<SatLit> = project
            .iter()
            .map(|&v| {
                if model[v as usize] {
                    SatLit::negative(v)
                } else {
                    SatLit::positive(v)
                }
            })
            .collect();
        self.backtrack(0);
        self.add_clause(&block);
        true
    }
}
554
555pub fn solve(cnf: &Cnf) -> Option<Vec<bool>> {
559 let mut s = Solver::new(cnf);
560 if s.search() { Some(s.model()) } else { None }
561}
562
563pub struct Models {
567 solver: Solver,
568 project: Vec<Var>,
569 done: bool,
570}
571
572impl Iterator for Models {
573 type Item = Vec<bool>;
574
575 fn next(&mut self) -> Option<Vec<bool>> {
576 if self.done {
577 return None;
578 }
579 if !self.solver.search() {
580 self.done = true;
581 return None;
582 }
583 let model = self.solver.model();
584 if !self.solver.block(&self.project, &model) {
585 self.done = true;
586 }
587 Some(model)
588 }
589}
590
591pub fn all_models(cnf: &Cnf, project: Vec<Var>) -> Models {
593 Models {
594 solver: Solver::new(cnf),
595 project,
596 done: false,
597 }
598}
599
600pub fn models(cnf: &Cnf, project: &[Var], limit: usize) -> Vec<Vec<bool>> {
602 all_models(cnf, project.to_vec()).take(limit).collect()
603}
604
605pub fn models_upto(cnf: &Cnf, project: &[Var], limit: usize) -> usize {
607 all_models(cnf, project.to_vec()).take(limit).count()
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether `model` satisfies every clause of `cnf`.
    fn satisfies(cnf: &Cnf, model: &[bool]) -> bool {
        cnf.clauses.iter().all(|clause| {
            clause
                .iter()
                .any(|&lit| model[lit.var() as usize] != lit.is_negative())
        })
    }

    #[test]
    fn trivial_sat() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        assert!(solve(&c).is_some());
    }

    #[test]
    fn unit_contradiction_unsat() {
        let mut c = Cnf::new(1);
        c.add_clause(vec![SatLit::positive(0)]);
        c.add_clause(vec![SatLit::negative(0)]);
        assert!(solve(&c).is_none());
    }

    #[test]
    fn all_four_combos_excluded_is_unsat() {
        let mut c = Cnf::new(2);
        let (a, b) = (0u32, 1u32);
        c.add_clause(vec![SatLit::positive(a), SatLit::positive(b)]);
        c.add_clause(vec![SatLit::negative(a), SatLit::positive(b)]);
        c.add_clause(vec![SatLit::positive(a), SatLit::negative(b)]);
        c.add_clause(vec![SatLit::negative(a), SatLit::negative(b)]);
        assert!(solve(&c).is_none());
    }

    #[test]
    fn forced_chain_has_unique_model() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::negative(0), SatLit::positive(1)]);
        c.add_clause(vec![SatLit::positive(0)]);
        let m = solve(&c).unwrap();
        assert!(m[0] && m[1]);
        assert_eq!(models_upto(&c, &[0, 1], 5), 1);
    }

    #[test]
    fn or_clause_has_three_models() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        assert_eq!(models_upto(&c, &[0, 1], 10), 3);
    }

    #[test]
    fn lazy_models_iterator_is_incremental() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        let first_two: Vec<_> = all_models(&c, vec![0, 1]).take(2).collect();
        assert_eq!(first_two.len(), 2);
        assert_ne!(first_two[0], first_two[1]);
        assert_eq!(all_models(&c, vec![0, 1]).count(), 3);
    }

    #[test]
    fn larger_random_like_sat_is_solved() {
        let mut c = Cnf::new(5);
        let l = |v: u32, p: bool| SatLit::new(v, p);
        c.add_clause(vec![l(0, true), l(1, true), l(2, false)]);
        c.add_clause(vec![l(0, false), l(2, true), l(3, true)]);
        c.add_clause(vec![l(1, false), l(3, false), l(4, true)]);
        c.add_clause(vec![l(2, false), l(4, false)]);
        c.add_clause(vec![l(0, true), l(4, true)]);
        let m = solve(&c).expect("sat");
        assert!(satisfies(&c, &m));
    }

    #[test]
    fn empty_clause_is_unsat() {
        let mut c = Cnf::new(1);
        c.add_clause(vec![]);
        assert!(solve(&c).is_none());
        assert_eq!(models_upto(&c, &[0], 5), 0);
    }

    #[test]
    fn zero_variable_formula_has_one_empty_model() {
        let c = Cnf::new(0);
        assert_eq!(solve(&c), Some(vec![]));
        assert_eq!(models_upto(&c, &[], 5), 1);
    }

    #[test]
    fn pigeonhole_three_into_two_is_unsat() {
        // Variable p(i, j) means "pigeon i sits in hole j". This needs real conflict
        // analysis and backjumping rather than propagation alone.
        let p = |i: u32, j: u32| i * 2 + j;
        let mut c = Cnf::new(6);
        for i in 0..3 {
            c.add_clause(vec![SatLit::positive(p(i, 0)), SatLit::positive(p(i, 1))]);
        }
        for j in 0..2 {
            for i in 0..3 {
                for k in (i + 1)..3 {
                    c.add_clause(vec![SatLit::negative(p(i, j)), SatLit::negative(p(k, j))]);
                }
            }
        }
        assert!(solve(&c).is_none());
        assert_eq!(models_upto(&c, &[0, 1, 2, 3, 4, 5], 10), 0);
    }

    #[test]
    fn exactly_one_of_three_enumerates_each_choice() {
        let mut c = Cnf::new(3);
        c.add_clause(vec![
            SatLit::positive(0),
            SatLit::positive(1),
            SatLit::positive(2),
        ]);
        for a in 0..3 {
            for b in (a + 1)..3 {
                c.add_clause(vec![SatLit::negative(a), SatLit::negative(b)]);
            }
        }
        let all = models(&c, &[0, 1, 2], 10);
        assert_eq!(all.len(), 3);
        for m in &all {
            assert!(satisfies(&c, m));
            assert_eq!(m.iter().filter(|&&b| b).count(), 1);
        }
    }

    #[test]
    fn projection_counts_distinct_projected_models() {
        let mut c = Cnf::new(2);
        c.add_clause(vec![SatLit::positive(0), SatLit::positive(1)]);
        // Variable 0 can be either value, so projecting onto it gives two models.
        assert_eq!(models_upto(&c, &[0], 10), 2);
        // The limit caps the count even though three full models exist.
        assert_eq!(models_upto(&c, &[0, 1], 2), 2);
    }
}