1#![deny(missing_docs)]
2
3use rsat::cdcl::{Solver, SolverOptions};
6use solhop_types::dimacs::Dimacs;
7use solhop_types::{Lit, Solution, Var};
8
9pub fn solve(dimacs: Dimacs) -> (Solution, usize) {
11 match dimacs {
12 Dimacs::Cnf { n_vars, clauses } => {
13 let n_clauses = clauses.len();
14 let mut sat_solver = Solver::new(SolverOptions::default());
15 let _vars = (0..n_vars)
16 .map(|_| sat_solver.new_var())
17 .collect::<Vec<_>>();
18
19 let mut ref_vars = vec![];
20 let mut cost = clauses.len();
21
22 for clause in clauses {
23 let mut clause = clause;
24 let ref_var = sat_solver.new_var();
25 ref_vars.push(ref_var);
26 clause.push(ref_var.pos_lit());
27 sat_solver.add_clause(clause);
28 }
29
30 if n_clauses == 0 {
31 return (sat_solver.solve(vec![]), 0);
32 }
33
34 let totalizer_lits = gen_totalizer(&ref_vars, &mut sat_solver);
35
36 let mut last_best = None;
37
38 loop {
39 let sol = sat_solver.solve(vec![]);
40 match sol {
41 Solution::Unsat => break,
42 Solution::Best(_) => break,
43 Solution::Sat(model) => last_best = Some((model, cost)),
44 Solution::Unknown => break,
45 }
46 if cost == 0 {
47 break;
48 }
49 cost -= 1;
50 sat_solver.add_clause(vec![!totalizer_lits[cost]]);
51 }
52
53 match last_best {
54 Some((model, cost)) => (
55 Solution::Best(model[0..n_vars].iter().copied().collect()),
56 n_clauses - cost,
57 ),
58 None => (Solution::Unknown, 0),
59 }
60 }
61 _ => panic!("not implemented for wcnf yet!"),
62 }
63}
64
65fn gen_totalizer(vars: &[Var], solver: &mut Solver) -> Vec<Lit> {
66 debug_assert!(vars.len() >= 1);
67 let mut output: Vec<Vec<Lit>> = vars.into_iter().map(|&v| vec![v.pos_lit()]).collect();
68 loop {
69 output = totalizer_single_level(output, solver);
70 if output.len() == 1 {
71 break output[0].clone();
72 }
73 }
74}
75
76fn totalizer_single_level(input: Vec<Vec<Lit>>, solver: &mut Solver) -> Vec<Vec<Lit>> {
77 let mut output = vec![];
78 let mut input_iter = input.into_iter();
79 loop {
80 if let Some(first) = input_iter.next() {
81 if let Some(second) = input_iter.next() {
82 let a = first.len();
83 let b = second.len();
84 let parent_lits: Vec<_> = (0..a + b).map(|_| solver.new_var().pos_lit()).collect();
85 for i in 0..a {
86 solver.add_clause(vec![!first[i], parent_lits[i]]);
87 }
88 for j in 0..b {
89 solver.add_clause(vec![!second[j], parent_lits[j]]);
90 }
91 for i in 0..a {
92 for j in 0..b {
93 solver.add_clause(vec![!first[i], !second[j], parent_lits[i + j + 1]]);
94 }
95 }
96 for i in 1..parent_lits.len() {
97 solver.add_clause(vec![!parent_lits[i], parent_lits[i - 1]]);
98 }
99 output.push(parent_lits);
100 } else {
101 output.push(first);
102 break;
103 }
104 } else {
105 break;
106 }
107 }
108 output
109}