// rustsat/solvers/simulators.rs
use crate::{
    instances::Cnf,
    types::{Cl, Clause, Lit},
    utils::Timer,
};

use super::Solve;
12
13#[derive(Debug, PartialEq, Eq, Default)]
14enum InternalSolverState {
15 #[default]
16 Init,
17 Sat,
18 Unsat(Vec<Lit>),
19 Unknown,
20}
21
22impl InternalSolverState {
23 fn to_external(&self) -> super::SolverState {
24 match self {
25 InternalSolverState::Init | InternalSolverState::Unknown => super::SolverState::Input,
26 InternalSolverState::Sat => super::SolverState::Sat,
27 InternalSolverState::Unsat(_) => super::SolverState::Unsat,
28 }
29 }
30}
31
32#[derive(Debug)]
43pub struct Incremental<S, Init = super::DefaultInitializer> {
44 solver: S,
45 init: std::marker::PhantomData<Init>,
46 state: InternalSolverState,
47 clauses: Cnf,
48 stats: super::SolverStats,
49}
50
51impl<S, Init> Default for Incremental<S, Init>
52where
53 Init: super::Initialize<S>,
54{
55 fn default() -> Self {
56 Self {
57 solver: Init::init(),
58 init: std::marker::PhantomData,
59 state: InternalSolverState::default(),
60 clauses: Cnf::default(),
61 stats: super::SolverStats::default(),
62 }
63 }
64}
65
66impl<S, Init> Incremental<S, Init> {
67 #[allow(clippy::cast_precision_loss)]
68 #[inline]
69 fn update_avg_clause_len(&mut self, clause: &Cl) {
70 self.stats.avg_clause_len =
71 (self.stats.avg_clause_len * ((self.stats.n_clauses - 1) as f32) + clause.len() as f32)
72 / self.stats.n_clauses as f32;
73 }
74
75 #[inline]
76 fn update_max_var(&mut self, clause: &Cl) {
77 if self.stats.max_var.is_none() {
78 self.stats.max_var = Some(crate::types::Var::new(0));
79 }
80 let max_var = self.stats.max_var.as_mut().unwrap();
81 for lit in clause {
82 *max_var = std::cmp::max(*max_var, lit.var());
83 }
84 }
85}
86
87impl<S, Init> super::Solve for Incremental<S, Init>
88where
89 S: super::Solve,
90 Init: super::Initialize<S>,
91{
92 fn signature(&self) -> &'static str {
93 self.solver.signature()
94 }
95
96 fn solve(&mut self) -> anyhow::Result<super::SolverResult> {
97 match &self.state {
98 InternalSolverState::Sat => return Ok(super::SolverResult::Sat),
99 InternalSolverState::Unsat(lits) if lits.is_empty() => {
100 return Ok(super::SolverResult::Unsat)
101 }
102 InternalSolverState::Unknown | InternalSolverState::Unsat(_) => {
103 self.solver = Init::init();
104 self.solver.add_cnf_ref(&self.clauses)?;
105 }
106 InternalSolverState::Init => (),
107 }
108 let start = Timer::now();
109 let res = self.solver.solve()?;
110 self.stats.cpu_solve_time += start.elapsed();
111 match res {
112 super::SolverResult::Sat => {
113 self.stats.n_sat += 1;
114 self.state = InternalSolverState::Sat;
115 }
116 super::SolverResult::Unsat => {
117 self.stats.n_unsat += 1;
118 self.state = InternalSolverState::Unsat(vec![]);
119 }
120 super::SolverResult::Interrupted => {
121 self.stats.n_terminated += 1;
122 self.state = InternalSolverState::Unknown;
123 }
124 }
125 Ok(res)
126 }
127
128 fn lit_val(&self, lit: Lit) -> anyhow::Result<crate::types::TernaryVal> {
129 self.solver.lit_val(lit)
130 }
131
132 fn var_val(&self, var: crate::types::Var) -> anyhow::Result<crate::types::TernaryVal> {
133 self.solver.var_val(var)
134 }
135
136 fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
137 where
138 C: AsRef<crate::types::Cl> + ?Sized,
139 {
140 self.stats.n_clauses += 1;
141 self.update_avg_clause_len(clause.as_ref());
142 self.update_max_var(clause.as_ref());
143 if matches!(self.state, InternalSolverState::Init) {
144 self.solver.add_clause_ref(clause)?;
145 } else {
146 self.state = InternalSolverState::Init;
147 self.solver = Init::init();
148 self.solver.add_cnf_ref(&self.clauses)?;
149 self.solver.add_clause_ref(&clause)?;
150 }
151 self.clauses
152 .add_clause(clause.as_ref().iter().copied().collect());
153 Ok(())
154 }
155
156 fn add_clause(&mut self, clause: Clause) -> anyhow::Result<()> {
157 self.stats.n_clauses += 1;
158 self.update_avg_clause_len(&clause);
159 self.update_max_var(&clause);
160 if matches!(self.state, InternalSolverState::Init) {
161 self.solver.add_clause_ref(&clause)?;
162 } else {
163 self.state = InternalSolverState::Init;
164 self.solver = Init::init();
165 self.solver.add_cnf_ref(&self.clauses)?;
166 self.solver.add_clause_ref(&clause)?;
167 }
168 self.clauses.add_clause(clause);
169 Ok(())
170 }
171
172 fn solution(&self, high_var: crate::types::Var) -> anyhow::Result<crate::types::Assignment> {
173 self.solver.solution(high_var)
174 }
175}
176
177impl<S, Init> super::SolveStats for Incremental<S, Init> {
178 fn stats(&self) -> super::SolverStats {
179 self.stats.clone()
180 }
181}
182
183impl<S, Init> super::SolveIncremental for Incremental<S, Init>
184where
185 S: super::Solve,
186 Init: super::Initialize<S>,
187{
188 fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<super::SolverResult> {
189 let start = Timer::now();
190 if !matches!(self.state, InternalSolverState::Init) {
191 self.solver = Init::init();
192 self.solver.add_cnf_ref(&self.clauses)?;
193 }
194 for lit in assumps {
195 self.solver.add_unit(*lit)?;
196 }
197 let res = self.solver.solve()?;
198 self.stats.cpu_solve_time += start.elapsed();
199 match res {
200 super::SolverResult::Sat => {
201 self.stats.n_sat += 1;
202 self.state = InternalSolverState::Sat;
203 }
204 super::SolverResult::Unsat => {
205 self.stats.n_unsat += 1;
206 self.state = InternalSolverState::Unsat(assumps.iter().map(|l| !*l).collect());
207 }
208 super::SolverResult::Interrupted => {
209 self.stats.n_terminated += 1;
210 self.state = InternalSolverState::Unknown;
211 }
212 }
213 Ok(res)
214 }
215
216 fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
217 match &self.state {
218 InternalSolverState::Unsat(core) => Ok(core.clone()),
219 other => Err(super::StateError {
220 required_state: super::SolverState::Unsat,
221 actual_state: other.to_external(),
222 }
223 .into()),
224 }
225 }
226}
227
228impl<S, Init> Extend<Clause> for Incremental<S, Init>
229where
230 S: super::Solve,
231 Init: super::Initialize<S>,
232{
233 fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
234 iter.into_iter()
235 .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
236 }
237}
238
239impl<'a, S, Init, C> Extend<&'a C> for Incremental<S, Init>
240where
241 S: super::Solve,
242 Init: super::Initialize<S>,
243 C: AsRef<Cl> + ?Sized,
244{
245 fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
246 iter.into_iter().for_each(|cl| {
247 self.add_clause_ref(cl)
248 .expect("Error adding clause in extend");
249 });
250 }
251}