1#![warn(clippy::pedantic)]
20#![warn(missing_docs)]
21#![warn(missing_debug_implementations)]
22
23use std::time::Duration;
24
25use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
26use cpu_time::ProcessTime;
27use rustsat::{
28 solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats},
29 types::{Cl, Clause, Lit, TernaryVal, Var},
30};
31
/// A BatSat-backed [`Solver`] instantiated with batsat's
/// [`batsat::BasicCallbacks`].
pub type BasicSolver = Solver<batsat::BasicCallbacks>;
34
/// Wrapper around a [`batsat::Solver`] that implements the rustsat solver
/// traits ([`Solve`], [`SolveIncremental`], [`SolveStats`]) and records
/// statistics about the calls made through it.
#[derive(Default)]
pub struct Solver<Cb: Callbacks> {
    /// The wrapped BatSat solver instance.
    internal: batsat::Solver<Cb>,
    /// Number of solve calls whose BatSat result was `lbool::TRUE`.
    n_sat: usize,
    /// Number of solve calls whose BatSat result was `lbool::FALSE`.
    n_unsat: usize,
    /// Number of solve calls whose BatSat result was `lbool::UNDEF`.
    n_terminated: usize,
    /// Running average clause length, updated once per added clause.
    avg_clause_len: f32,
    /// Process CPU time accumulated over all solve calls.
    cpu_time: Duration,
}
45
46impl<Cb: Callbacks> std::fmt::Debug for Solver<Cb> {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("Solver")
49 .field("internal", &"omitted")
50 .field("n_sat", &self.n_sat)
51 .field("n_unsat", &self.n_unsat)
52 .field("n_terminated", &self.n_terminated)
53 .field("avg_clause_len", &self.avg_clause_len)
54 .field("cpu_time", &self.cpu_time)
55 .finish()
56 }
57}
58
59impl<Cb: Callbacks> Solver<Cb> {
60 #[must_use]
62 pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
63 &self.internal
64 }
65
66 #[must_use]
68 pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
69 &mut self.internal
70 }
71
72 #[allow(clippy::cast_precision_loss)]
73 #[inline]
74 fn update_avg_clause_len(&mut self, clause: &Cl) {
75 self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
76 + clause.len() as f32)
77 / (self.n_clauses() + 1) as f32;
78 }
79
80 fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
81 let a = assumps
82 .iter()
83 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
84 .collect::<Vec<_>>();
85
86 let start = ProcessTime::now();
87 let ret = match self.internal.solve_limited(&a) {
88 x if x == lbool::TRUE => {
89 self.n_sat += 1;
90 SolverResult::Sat
91 }
92 x if x == lbool::FALSE => {
93 self.n_unsat += 1;
94 SolverResult::Unsat
95 }
96 x if x == lbool::UNDEF => {
97 self.n_terminated += 1;
98 SolverResult::Interrupted
99 }
100 _ => unreachable!(),
101 };
102 self.cpu_time += start.elapsed();
103 ret
104 }
105}
106
107impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
108 fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
109 iter.into_iter()
110 .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
111 }
112}
113
114impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
115where
116 C: AsRef<Cl> + ?Sized,
117 Cb: Callbacks,
118{
119 fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
120 iter.into_iter().for_each(|cl| {
121 self.add_clause_ref(cl)
122 .expect("Error adding clause in extend");
123 });
124 }
125}
126
127impl<Cb: Callbacks> Solve for Solver<Cb> {
128 fn signature(&self) -> &'static str {
129 "BatSat 0.6.0"
130 }
131
132 fn solve(&mut self) -> anyhow::Result<SolverResult> {
133 Ok(self.solve_track_stats(&[]))
134 }
135
136 fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
137 let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos());
138
139 match self.internal.value_lit(l) {
140 x if x == lbool::TRUE => Ok(TernaryVal::True),
141 x if x == lbool::FALSE => Ok(TernaryVal::False),
142 x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
143 _ => unreachable!(),
144 }
145 }
146
147 fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
148 where
149 C: AsRef<Cl> + ?Sized,
150 {
151 let clause = clause.as_ref();
152 self.update_avg_clause_len(clause);
153
154 let mut c: Vec<_> = clause
155 .iter()
156 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
157 .collect();
158
159 self.internal.add_clause_reuse(&mut c);
160
161 Ok(())
162 }
163}
164
165impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
166 fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
167 Ok(self.solve_track_stats(assumps))
168 }
169
170 fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
171 Ok(self
172 .internal
173 .unsat_core()
174 .iter()
175 .map(|l| Lit::new(l.var().idx() - 1, !l.sign()))
176 .collect::<Vec<_>>())
177 }
178}
179
180impl<Cb: Callbacks> SolveStats for Solver<Cb> {
181 fn stats(&self) -> SolverStats {
182 SolverStats {
183 n_sat: self.n_sat,
184 n_unsat: self.n_unsat,
185 n_terminated: self.n_terminated,
186 n_clauses: self.n_clauses(),
187 max_var: self.max_var(),
188 avg_clause_len: self.avg_clause_len,
189 cpu_solve_time: self.cpu_time,
190 }
191 }
192
193 fn n_sat_solves(&self) -> usize {
194 self.n_sat
195 }
196
197 fn n_unsat_solves(&self) -> usize {
198 self.n_unsat
199 }
200
201 fn n_terminated(&self) -> usize {
202 self.n_terminated
203 }
204
205 fn n_clauses(&self) -> usize {
206 usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
207 }
208
209 fn max_var(&self) -> Option<Var> {
210 let num = self.internal.num_vars();
211 if num > 0 {
212 Some(Var::new(num - 2))
214 } else {
215 None
216 }
217 }
218
219 fn avg_clause_len(&self) -> f32 {
220 self.avg_clause_len
221 }
222
223 fn cpu_solve_time(&self) -> Duration {
224 self.cpu_time
225 }
226}
227
#[cfg(test)]
mod test {
    // Instantiates the shared rustsat basic solver unit tests for `BasicSolver`.
    // NOTE(review): the meaning of the `false` flag is defined by
    // `rustsat_solvertests::basic_unittests!`; confirm it there.
    rustsat_solvertests::basic_unittests!(super::BasicSolver, false);
}