1#![warn(clippy::pedantic)]
12#![warn(missing_docs)]
13#![warn(missing_debug_implementations)]
14
15use std::time::Duration;
16
17use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
18use cpu_time::ProcessTime;
19use rustsat::{
20 solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats},
21 types::{Cl, Clause, Lit, TernaryVal, Var},
22};
23
24pub type BasicSolver = Solver<batsat::BasicCallbacks>;
26
27#[derive(Default)]
29pub struct Solver<Cb: Callbacks> {
30 internal: batsat::Solver<Cb>,
31 n_sat: usize,
32 n_unsat: usize,
33 n_terminated: usize,
34 avg_clause_len: f32,
35 cpu_time: Duration,
36}
37
38impl<Cb: Callbacks> std::fmt::Debug for Solver<Cb> {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("Solver")
41 .field("internal", &"omitted")
42 .field("n_sat", &self.n_sat)
43 .field("n_unsat", &self.n_unsat)
44 .field("n_terminated", &self.n_terminated)
45 .field("avg_clause_len", &self.avg_clause_len)
46 .field("cpu_time", &self.cpu_time)
47 .finish()
48 }
49}
50
51impl<Cb: Callbacks> Solver<Cb> {
52 #[must_use]
54 pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
55 &self.internal
56 }
57
58 #[must_use]
60 pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
61 &mut self.internal
62 }
63
64 #[allow(clippy::cast_precision_loss)]
65 #[inline]
66 fn update_avg_clause_len(&mut self, clause: &Cl) {
67 self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
68 + clause.len() as f32)
69 / (self.n_clauses() + 1) as f32;
70 }
71
72 fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
73 let a = assumps
74 .iter()
75 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
76 .collect::<Vec<_>>();
77
78 let start = ProcessTime::now();
79 let ret = match self.internal.solve_limited(&a) {
80 x if x == lbool::TRUE => {
81 self.n_sat += 1;
82 SolverResult::Sat
83 }
84 x if x == lbool::FALSE => {
85 self.n_unsat += 1;
86 SolverResult::Unsat
87 }
88 x if x == lbool::UNDEF => {
89 self.n_terminated += 1;
90 SolverResult::Interrupted
91 }
92 _ => unreachable!(),
93 };
94 self.cpu_time += start.elapsed();
95 ret
96 }
97}
98
99impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
100 fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
101 iter.into_iter()
102 .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
103 }
104}
105
106impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
107where
108 C: AsRef<Cl> + ?Sized,
109 Cb: Callbacks,
110{
111 fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
112 iter.into_iter().for_each(|cl| {
113 self.add_clause_ref(cl)
114 .expect("Error adding clause in extend");
115 });
116 }
117}
118
119impl<Cb: Callbacks> Solve for Solver<Cb> {
120 fn signature(&self) -> &'static str {
121 "BatSat 0.6.0"
122 }
123
124 fn solve(&mut self) -> anyhow::Result<SolverResult> {
125 Ok(self.solve_track_stats(&[]))
126 }
127
128 fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
129 let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos());
130
131 match self.internal.value_lit(l) {
132 x if x == lbool::TRUE => Ok(TernaryVal::True),
133 x if x == lbool::FALSE => Ok(TernaryVal::False),
134 x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
135 _ => unreachable!(),
136 }
137 }
138
139 fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
140 where
141 C: AsRef<Cl> + ?Sized,
142 {
143 let clause = clause.as_ref();
144 self.update_avg_clause_len(clause);
145
146 let mut c: Vec<_> = clause
147 .iter()
148 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
149 .collect();
150
151 self.internal.add_clause_reuse(&mut c);
152
153 Ok(())
154 }
155}
156
157impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
158 fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
159 Ok(self.solve_track_stats(assumps))
160 }
161
162 fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
163 Ok(self
164 .internal
165 .unsat_core()
166 .iter()
167 .map(|l| Lit::new(l.var().idx() - 1, !l.sign()))
168 .collect::<Vec<_>>())
169 }
170}
171
172impl<Cb: Callbacks> SolveStats for Solver<Cb> {
173 fn stats(&self) -> SolverStats {
174 SolverStats {
175 n_sat: self.n_sat,
176 n_unsat: self.n_unsat,
177 n_terminated: self.n_terminated,
178 n_clauses: self.n_clauses(),
179 max_var: self.max_var(),
180 avg_clause_len: self.avg_clause_len,
181 cpu_solve_time: self.cpu_time,
182 }
183 }
184
185 fn n_sat_solves(&self) -> usize {
186 self.n_sat
187 }
188
189 fn n_unsat_solves(&self) -> usize {
190 self.n_unsat
191 }
192
193 fn n_terminated(&self) -> usize {
194 self.n_terminated
195 }
196
197 fn n_clauses(&self) -> usize {
198 usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
199 }
200
201 fn max_var(&self) -> Option<Var> {
202 let num = self.internal.num_vars();
203 if num > 0 {
204 Some(Var::new(num - 2))
206 } else {
207 None
208 }
209 }
210
211 fn avg_clause_len(&self) -> f32 {
212 self.avg_clause_len
213 }
214
215 fn cpu_solve_time(&self) -> Duration {
216 self.cpu_time
217 }
218}
219
// Unit tests: instantiates the shared rustsat solver test suite for `BasicSolver`.
#[cfg(test)]
mod test {
    // NOTE(review): the meaning of the `false` argument is defined by
    // `rustsat_solvertests::basic_unittests!` — confirm there before changing it.
    rustsat_solvertests::basic_unittests!(super::BasicSolver, false);
}