//! # BatSat backend for RustSAT
//!
//! This module wraps the [BatSat](https://docs.rs/batsat) SAT solver so that it can be used
//! through the `rustsat` solver traits `Solve`, `SolveIncremental` and `SolveStats`.
//!
//! The wrapper type is `Solver<Cb>`, generic over BatSat's `Callbacks`;
//! `BasicSolver` is the instantiation with `batsat::BasicCallbacks`.
#![warn(clippy::pedantic)]
#![warn(missing_docs)]
#![warn(missing_debug_implementations)]
28
29use std::time::Duration;
30
31use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
32use rustsat::{
33 solvers::{
34 Solve, SolveIncremental, SolveStats, SolverResult, SolverState, SolverStats, StateError,
35 },
36 types::{Cl, Clause, Lit, TernaryVal, Var},
37 utils::Timer,
38};
39
40pub type BasicSolver = Solver<batsat::BasicCallbacks>;
42
43#[derive(Debug, PartialEq, Eq, Default)]
44enum InternalSolverState {
45 #[default]
46 Input,
47 Sat,
48 Unsat(bool),
49}
50
51impl InternalSolverState {
52 fn to_external(&self) -> SolverState {
53 match self {
54 InternalSolverState::Input => SolverState::Input,
55 InternalSolverState::Sat => SolverState::Sat,
56 InternalSolverState::Unsat(_) => SolverState::Unsat,
57 }
58 }
59}
60
61#[derive(Default)]
63pub struct Solver<Cb: Callbacks> {
64 internal: batsat::Solver<Cb>,
65 state: InternalSolverState,
66 n_sat: usize,
67 n_unsat: usize,
68 n_terminated: usize,
69 avg_clause_len: f32,
70 cpu_time: Duration,
71}
72
73impl<Cb: Callbacks> std::fmt::Debug for Solver<Cb> {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("Solver")
76 .field("internal", &"omitted")
77 .field("state", &self.state)
78 .field("n_sat", &self.n_sat)
79 .field("n_unsat", &self.n_unsat)
80 .field("n_terminated", &self.n_terminated)
81 .field("avg_clause_len", &self.avg_clause_len)
82 .field("cpu_time", &self.cpu_time)
83 .finish()
84 }
85}
86
87impl<Cb: Callbacks> Solver<Cb> {
88 #[must_use]
90 pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
91 &self.internal
92 }
93
94 #[must_use]
96 pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
97 &mut self.internal
98 }
99
100 #[allow(clippy::cast_precision_loss)]
101 #[inline]
102 fn update_avg_clause_len(&mut self, clause: &Cl) {
103 self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
104 + clause.len() as f32)
105 / (self.n_clauses() + 1) as f32;
106 }
107
108 fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
109 let assumps = assumps
110 .iter()
111 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32()), l.is_pos()))
112 .collect::<Vec<_>>();
113
114 let start = Timer::now();
115 let ret = match self.internal.solve_limited(&assumps) {
116 x if x == lbool::TRUE => {
117 self.n_sat += 1;
118 self.state = InternalSolverState::Sat;
119 SolverResult::Sat
120 }
121 x if x == lbool::FALSE => {
122 self.n_unsat += 1;
123 self.state = InternalSolverState::Unsat(!assumps.is_empty());
124 SolverResult::Unsat
125 }
126 x if x == lbool::UNDEF => {
127 self.n_terminated += 1;
128 self.state = InternalSolverState::Input;
129 SolverResult::Interrupted
130 }
131 _ => unreachable!(),
132 };
133 self.cpu_time += start.elapsed();
134 ret
135 }
136}
137
138impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
139 fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
140 iter.into_iter()
141 .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
142 }
143}
144
145impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
146where
147 C: AsRef<Cl> + ?Sized,
148 Cb: Callbacks,
149{
150 fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
151 iter.into_iter().for_each(|cl| {
152 self.add_clause_ref(cl)
153 .expect("Error adding clause in extend");
154 });
155 }
156}
157
158impl<Cb: Callbacks> Solve for Solver<Cb> {
159 fn signature(&self) -> &'static str {
160 "BatSat 0.6.0"
161 }
162
163 fn solve(&mut self) -> anyhow::Result<SolverResult> {
164 if let InternalSolverState::Sat = self.state {
166 return Ok(SolverResult::Sat);
167 }
168 if let InternalSolverState::Unsat(under_assumps) = &self.state {
169 if !under_assumps {
170 return Ok(SolverResult::Unsat);
171 }
172 }
173 Ok(self.solve_track_stats(&[]))
174 }
175
176 fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
177 if self.state != InternalSolverState::Sat {
178 return Err(StateError {
179 required_state: SolverState::Sat,
180 actual_state: self.state.to_external(),
181 }
182 .into());
183 }
184
185 let lit = batsat::Lit::new(batsat::Var::from_index(lit.vidx()), lit.is_pos());
186
187 match self.internal.value_lit(lit) {
188 x if x == lbool::TRUE => Ok(TernaryVal::True),
189 x if x == lbool::FALSE => Ok(TernaryVal::False),
190 x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
191 _ => unreachable!(),
192 }
193 }
194
195 fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
196 where
197 C: AsRef<Cl> + ?Sized,
198 {
199 let clause = clause.as_ref();
200 self.update_avg_clause_len(clause);
201
202 let mut clause: Vec<_> = clause
203 .iter()
204 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32()), l.is_pos()))
205 .collect();
206
207 self.internal.add_clause_reuse(&mut clause);
208 self.state = InternalSolverState::Input;
209
210 Ok(())
211 }
212
213 fn reserve(&mut self, max_var: Var) -> anyhow::Result<()> {
214 while self.internal.num_vars() <= max_var.idx32() {
215 self.internal.new_var_default();
216 }
217 Ok(())
218 }
219}
220
221impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
222 fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
223 Ok(self.solve_track_stats(assumps))
224 }
225
226 fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
227 match &self.state {
228 InternalSolverState::Unsat(under_assumps) => {
229 if *under_assumps {
230 Ok(self
231 .internal
232 .unsat_core()
233 .iter()
234 .map(|l| Lit::new(l.var().idx(), !l.sign()))
235 .collect::<Vec<_>>())
236 } else {
237 Ok(vec![])
238 }
239 }
240 other => Err(StateError {
241 required_state: SolverState::Unsat,
242 actual_state: other.to_external(),
243 }
244 .into()),
245 }
246 }
247}
248
249impl<Cb: Callbacks> SolveStats for Solver<Cb> {
250 fn stats(&self) -> SolverStats {
251 SolverStats {
252 n_sat: self.n_sat,
253 n_unsat: self.n_unsat,
254 n_terminated: self.n_terminated,
255 n_clauses: self.n_clauses(),
256 max_var: self.max_var(),
257 avg_clause_len: self.avg_clause_len,
258 cpu_solve_time: self.cpu_time,
259 }
260 }
261
262 fn n_sat_solves(&self) -> usize {
263 self.n_sat
264 }
265
266 fn n_unsat_solves(&self) -> usize {
267 self.n_unsat
268 }
269
270 fn n_terminated(&self) -> usize {
271 self.n_terminated
272 }
273
274 fn n_clauses(&self) -> usize {
275 usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
276 }
277
278 fn max_var(&self) -> Option<Var> {
279 let num = self.internal.num_vars();
280 if num > 0 {
281 Some(Var::new(num - 1))
282 } else {
283 None
284 }
285 }
286
287 fn avg_clause_len(&self) -> f32 {
288 self.avg_clause_len
289 }
290
291 fn cpu_solve_time(&self) -> Duration {
292 self.cpu_time
293 }
294}
295
// Instantiates the shared `rustsat` solver unit tests for `BasicSolver`.
#[cfg(test)]
mod test {
    // Arguments: the solver type under test, a pattern presumably checked against
    // `signature()` (which returns "BatSat 0.6.0"), and a boolean flag whose meaning
    // is defined by the macro — NOTE(review): confirm in `rustsat_solvertests`.
    rustsat_solvertests::basic_unittests!(
        super::BasicSolver,
        "BatSat [major].[minor].[patch]",
        false
    );
}