Skip to main content

rustsat_batsat/
lib.rs

1//! # rustsat-batsat - Interface to the BatSat SAT Solver for RustSAT
2//!
3//! Interface to the [BatSat](https://github.com/c-cube/batsat) incremental SAT-Solver to be used with the [RustSAT](https://github.com/chrjabs/rustsat) library.
4//!
5//! BatSat is fully implemented in Rust which has advantages in restricted compilation scenarios like WebAssembly.
6//!
7//! # BatSat Version
8//!
9//! The version of BatSat in this crate is Version 0.6.0.
10//!
11//! ## Minimum Supported Rust Version (MSRV)
12//!
//! Currently, the MSRV is 1.76.0. The plan is to always support an MSRV that is at least a year
//! old.
15//!
16//! Bumps in the MSRV will _not_ be considered breaking changes. If you need a specific MSRV, make
17//! sure to pin a precise version of RustSAT.
18
// NOTE: For some reason, batsat flipped the memory representation of the sign bit in its literal
// representation compared to Minisat (and therefore RustSAT), see
// https://github.com/c-cube/batsat/commit/8563ae6e3a59478a0d414fe647d99ad9b989841f
// For this reason, we cannot transmute RustSAT literals to batsat literals and instead have to
// recreate the literals through batsat's API.
24
25#![warn(clippy::pedantic)]
26#![warn(missing_docs)]
27#![warn(missing_debug_implementations)]
28
29use std::time::Duration;
30
31use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
32use rustsat::{
33    solvers::{
34        Solve, SolveIncremental, SolveStats, SolverResult, SolverState, SolverStats, StateError,
35    },
36    types::{Cl, Clause, Lit, TernaryVal, Var},
37    utils::Timer,
38};
39
/// RustSAT wrapper for [`batsat::BasicSolver`]
///
/// This is [`Solver`] instantiated with [`batsat::BasicCallbacks`] as its callback type.
pub type BasicSolver = Solver<batsat::BasicCallbacks>;
42
/// Internal state of the wrapper, tracking the outcome of the most recent solver interaction
#[derive(Debug, PartialEq, Eq, Default)]
enum InternalSolverState {
    /// No valid result is available: the initial state, and the state after a clause was added
    /// or a solve call was interrupted
    #[default]
    Input,
    /// The last solve call returned satisfiable
    Sat,
    /// The last solve call returned unsatisfiable; the flag is `true` if that call was made with
    /// a non-empty set of assumptions
    Unsat(bool),
}
50
51impl InternalSolverState {
52    fn to_external(&self) -> SolverState {
53        match self {
54            InternalSolverState::Input => SolverState::Input,
55            InternalSolverState::Sat => SolverState::Sat,
56            InternalSolverState::Unsat(_) => SolverState::Unsat,
57        }
58    }
59}
60
/// RustSAT wrapper for a [`batsat::Solver`] Solver from BatSat
#[derive(Default)]
pub struct Solver<Cb: Callbacks> {
    /// The wrapped BatSat solver
    internal: batsat::Solver<Cb>,
    /// Result state of the most recent solver interaction
    state: InternalSolverState,
    /// Number of solve calls that returned satisfiable
    n_sat: usize,
    /// Number of solve calls that returned unsatisfiable
    n_unsat: usize,
    /// Number of solve calls that ended without a result
    n_terminated: usize,
    /// Running average of the length of the clauses added through this wrapper
    avg_clause_len: f32,
    /// Accumulated time spent in solve calls, as measured by [`Timer`]
    cpu_time: Duration,
}
72
73impl<Cb: Callbacks> std::fmt::Debug for Solver<Cb> {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("Solver")
76            .field("internal", &"omitted")
77            .field("state", &self.state)
78            .field("n_sat", &self.n_sat)
79            .field("n_unsat", &self.n_unsat)
80            .field("n_terminated", &self.n_terminated)
81            .field("avg_clause_len", &self.avg_clause_len)
82            .field("cpu_time", &self.cpu_time)
83            .finish()
84    }
85}
86
87impl<Cb: Callbacks> Solver<Cb> {
88    /// Gets a reference to the internal [`BasicSolver`]
89    #[must_use]
90    pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
91        &self.internal
92    }
93
94    /// Gets a mutable reference to the internal [`BasicSolver`]
95    #[must_use]
96    pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
97        &mut self.internal
98    }
99
100    #[allow(clippy::cast_precision_loss)]
101    #[inline]
102    fn update_avg_clause_len(&mut self, clause: &Cl) {
103        self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
104            + clause.len() as f32)
105            / (self.n_clauses() + 1) as f32;
106    }
107
108    fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
109        let assumps = assumps
110            .iter()
111            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32()), l.is_pos()))
112            .collect::<Vec<_>>();
113
114        let start = Timer::now();
115        let ret = match self.internal.solve_limited(&assumps) {
116            x if x == lbool::TRUE => {
117                self.n_sat += 1;
118                self.state = InternalSolverState::Sat;
119                SolverResult::Sat
120            }
121            x if x == lbool::FALSE => {
122                self.n_unsat += 1;
123                self.state = InternalSolverState::Unsat(!assumps.is_empty());
124                SolverResult::Unsat
125            }
126            x if x == lbool::UNDEF => {
127                self.n_terminated += 1;
128                self.state = InternalSolverState::Input;
129                SolverResult::Interrupted
130            }
131            _ => unreachable!(),
132        };
133        self.cpu_time += start.elapsed();
134        ret
135    }
136}
137
138impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
139    fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
140        iter.into_iter()
141            .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
142    }
143}
144
145impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
146where
147    C: AsRef<Cl> + ?Sized,
148    Cb: Callbacks,
149{
150    fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
151        iter.into_iter().for_each(|cl| {
152            self.add_clause_ref(cl)
153                .expect("Error adding clause in extend");
154        });
155    }
156}
157
158impl<Cb: Callbacks> Solve for Solver<Cb> {
159    fn signature(&self) -> &'static str {
160        "BatSat 0.6.0"
161    }
162
163    fn solve(&mut self) -> anyhow::Result<SolverResult> {
164        // If already solved, return state
165        if let InternalSolverState::Sat = self.state {
166            return Ok(SolverResult::Sat);
167        }
168        if let InternalSolverState::Unsat(under_assumps) = &self.state {
169            if !under_assumps {
170                return Ok(SolverResult::Unsat);
171            }
172        }
173        Ok(self.solve_track_stats(&[]))
174    }
175
176    fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
177        if self.state != InternalSolverState::Sat {
178            return Err(StateError {
179                required_state: SolverState::Sat,
180                actual_state: self.state.to_external(),
181            }
182            .into());
183        }
184
185        let lit = batsat::Lit::new(batsat::Var::from_index(lit.vidx()), lit.is_pos());
186
187        match self.internal.value_lit(lit) {
188            x if x == lbool::TRUE => Ok(TernaryVal::True),
189            x if x == lbool::FALSE => Ok(TernaryVal::False),
190            x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
191            _ => unreachable!(),
192        }
193    }
194
195    fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
196    where
197        C: AsRef<Cl> + ?Sized,
198    {
199        let clause = clause.as_ref();
200        self.update_avg_clause_len(clause);
201
202        let mut clause: Vec<_> = clause
203            .iter()
204            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32()), l.is_pos()))
205            .collect();
206
207        self.internal.add_clause_reuse(&mut clause);
208        self.state = InternalSolverState::Input;
209
210        Ok(())
211    }
212
213    fn reserve(&mut self, max_var: Var) -> anyhow::Result<()> {
214        while self.internal.num_vars() <= max_var.idx32() {
215            self.internal.new_var_default();
216        }
217        Ok(())
218    }
219}
220
221impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
222    fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
223        Ok(self.solve_track_stats(assumps))
224    }
225
226    fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
227        match &self.state {
228            InternalSolverState::Unsat(under_assumps) => {
229                if *under_assumps {
230                    Ok(self
231                        .internal
232                        .unsat_core()
233                        .iter()
234                        .map(|l| Lit::new(l.var().idx(), !l.sign()))
235                        .collect::<Vec<_>>())
236                } else {
237                    Ok(vec![])
238                }
239            }
240            other => Err(StateError {
241                required_state: SolverState::Unsat,
242                actual_state: other.to_external(),
243            }
244            .into()),
245        }
246    }
247}
248
249impl<Cb: Callbacks> SolveStats for Solver<Cb> {
250    fn stats(&self) -> SolverStats {
251        SolverStats {
252            n_sat: self.n_sat,
253            n_unsat: self.n_unsat,
254            n_terminated: self.n_terminated,
255            n_clauses: self.n_clauses(),
256            max_var: self.max_var(),
257            avg_clause_len: self.avg_clause_len,
258            cpu_solve_time: self.cpu_time,
259        }
260    }
261
262    fn n_sat_solves(&self) -> usize {
263        self.n_sat
264    }
265
266    fn n_unsat_solves(&self) -> usize {
267        self.n_unsat
268    }
269
270    fn n_terminated(&self) -> usize {
271        self.n_terminated
272    }
273
274    fn n_clauses(&self) -> usize {
275        usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
276    }
277
278    fn max_var(&self) -> Option<Var> {
279        let num = self.internal.num_vars();
280        if num > 0 {
281            Some(Var::new(num - 1))
282        } else {
283            None
284        }
285    }
286
287    fn avg_clause_len(&self) -> f32 {
288        self.avg_clause_len
289    }
290
291    fn cpu_solve_time(&self) -> Duration {
292        self.cpu_time
293    }
294}
295
#[cfg(test)]
mod test {
    // Instantiates the `basic_unittests!` suite from `rustsat_solvertests` for `BasicSolver`.
    // NOTE(review): the string argument looks like a pattern that the solver signature is
    // expected to match, and the meaning of the trailing `false` is defined by the macro —
    // confirm both against `rustsat_solvertests`.
    rustsat_solvertests::basic_unittests!(
        super::BasicSolver,
        "BatSat [major].[minor].[patch]",
        false
    );
}