1use crate::clause::{ClauseDatabase, ClauseId};
12use crate::literal::{Lit, Var};
13#[allow(unused_imports)]
14use crate::prelude::*;
15use smallvec::SmallVec;
16
/// Tunable parameters for [`LocalSearch`].
#[derive(Debug, Clone)]
pub struct LocalSearchConfig {
    /// Maximum number of variable flips a single solve call may perform before
    /// it gives up with [`LocalSearchResult::Unknown`].
    pub max_flips: u64,
    /// WalkSAT noise: probability of taking a random-walk step (flipping a
    /// uniformly chosen variable of the selected clause) instead of a greedy one.
    pub random_walk_prob: f64,
    /// probSAT exponent `cb`: a variable of the selected clause is chosen with
    /// weight `1 / (break + 1)^cb`.
    pub cb_exponent: f64,
    /// Seed for the solver's internal pseudo-random generator.
    pub random_seed: u64,
}
29
30impl Default for LocalSearchConfig {
31 fn default() -> Self {
32 Self {
33 max_flips: 1_000_000,
34 random_walk_prob: 0.4,
35 cb_exponent: 2.3,
36 random_seed: 1234567,
37 }
38 }
39}
40
/// Outcome of a local-search run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalSearchResult {
    /// Every clause is satisfied by the returned assignment.
    Sat,
    /// No satisfying assignment was found. The search never proves
    /// unsatisfiability, so this does not mean the formula is UNSAT.
    Unknown,
}
49
/// Counters collected by a [`LocalSearch`] solver.
#[derive(Debug, Default, Clone)]
pub struct LocalSearchStats {
    /// Number of variable flips performed by this solver instance; it is not
    /// reset when a new solve call starts.
    pub flips: u64,
    /// Smallest number of simultaneously falsified clauses seen in the current
    /// run (re-initialised at the start of each solve call).
    pub min_unsat: usize,
    /// Number of flips that lowered `min_unsat`.
    pub improvements: u64,
}
60
61pub struct LocalSearch {
66 assignment: Vec<bool>,
68 break_count: Vec<u64>,
70 make_count: Vec<u64>,
72 unsat_clauses: Vec<ClauseId>,
74 unsat_set: HashMap<ClauseId, ()>,
76 true_count: HashMap<ClauseId, usize>,
78 config: LocalSearchConfig,
80 stats: LocalSearchStats,
82 rng_state: u64,
84}
85
86impl LocalSearch {
87 #[must_use]
89 pub fn new(num_vars: usize, config: LocalSearchConfig) -> Self {
90 Self {
91 assignment: vec![false; num_vars],
92 break_count: vec![0; num_vars],
93 make_count: vec![0; num_vars],
94 unsat_clauses: Vec::new(),
95 unsat_set: HashMap::new(),
96 true_count: HashMap::new(),
97 rng_state: config.random_seed,
98 config,
99 stats: LocalSearchStats::default(),
100 }
101 }
102
103 fn rand(&mut self) -> u64 {
105 const A: u64 = 1664525;
108 const C: u64 = 1013904223;
109 self.rng_state = self.rng_state.wrapping_mul(A).wrapping_add(C);
110 self.rng_state
111 }
112
113 fn rand_float(&mut self) -> f64 {
115 (self.rand() as f64) / (u64::MAX as f64)
116 }
117
118 fn initialize_random(&mut self, num_vars: usize) {
120 self.assignment.clear();
121 self.assignment.resize(num_vars, false);
122
123 for i in 0..num_vars {
124 self.assignment[i] = self.rand().is_multiple_of(2);
125 }
126 }
127
128 fn initialize(&mut self, clauses: &ClauseDatabase, num_vars: usize) {
130 self.initialize_random(num_vars);
131 self.break_count.clear();
132 self.break_count.resize(num_vars, 0);
133 self.make_count.clear();
134 self.make_count.resize(num_vars, 0);
135 self.unsat_clauses.clear();
136 self.unsat_set.clear();
137 self.true_count.clear();
138
139 for id in clauses.iter_ids() {
141 let clause = clauses
142 .get(id)
143 .expect("id from clauses.iter_ids() is valid");
144 let true_lits = clause.lits.iter().filter(|&&lit| self.is_true(lit)).count();
145
146 self.true_count.insert(id, true_lits);
147
148 if true_lits == 0 {
149 self.unsat_clauses.push(id);
150 self.unsat_set.insert(id, ());
151 }
152 }
153
154 for id in clauses.iter_ids() {
156 let clause = clauses
157 .get(id)
158 .expect("id from clauses.iter_ids() is valid");
159 let true_lits = self.true_count[&id];
160
161 for &lit in &clause.lits {
162 let var = lit.var();
163 let var_idx = var.index();
164
165 if self.is_true(lit) {
166 if true_lits == 1 {
168 self.break_count[var_idx] += 1;
169 }
170 } else {
171 if true_lits == 0 {
173 self.make_count[var_idx] += 1;
174 }
175 }
176 }
177 }
178
179 self.stats.min_unsat = self.unsat_clauses.len();
180 }
181
182 fn is_true(&self, lit: Lit) -> bool {
184 let var_value = self.assignment[lit.var().index()];
185 if lit.is_pos() { var_value } else { !var_value }
186 }
187
188 fn flip(&mut self, var: Var, clauses: &ClauseDatabase) {
190 let var_idx = var.index();
191 self.assignment[var_idx] = !self.assignment[var_idx];
192 self.stats.flips += 1;
193
194 let pos_lit = Lit::pos(var);
196 let neg_lit = Lit::neg(var);
197
198 for id in clauses.iter_ids() {
201 let clause = clauses
202 .get(id)
203 .expect("id from clauses.iter_ids() is valid");
204 if !clause.lits.contains(&pos_lit) && !clause.lits.contains(&neg_lit) {
205 continue;
206 }
207
208 let old_true_count = self.true_count[&id];
209 let was_unsat = old_true_count == 0;
210
211 let new_true_count = clause.lits.iter().filter(|&&lit| self.is_true(lit)).count();
213
214 self.true_count.insert(id, new_true_count);
215
216 let is_unsat = new_true_count == 0;
217
218 if !was_unsat && is_unsat {
220 self.unsat_clauses.push(id);
221 self.unsat_set.insert(id, ());
222 } else if was_unsat && !is_unsat {
223 self.unsat_set.remove(&id);
224 }
225
226 for &lit in &clause.lits {
228 let lit_var = lit.var();
229 let lit_var_idx = lit_var.index();
230
231 if old_true_count == 1 && self.is_true(lit) {
233 self.break_count[lit_var_idx] -= 1;
235 }
236 if new_true_count == 1 && self.is_true(lit) {
237 self.break_count[lit_var_idx] += 1;
239 }
240
241 if old_true_count == 0 && !self.is_true(lit) {
243 self.make_count[lit_var_idx] -= 1;
245 }
246 if new_true_count == 0 && !self.is_true(lit) {
247 self.make_count[lit_var_idx] += 1;
249 }
250 }
251 }
252
253 self.unsat_clauses
255 .retain(|&id| self.unsat_set.contains_key(&id));
256
257 if self.unsat_clauses.len() < self.stats.min_unsat {
259 self.stats.min_unsat = self.unsat_clauses.len();
260 self.stats.improvements += 1;
261 }
262 }
263
264 pub fn solve_walksat(
268 &mut self,
269 clauses: &ClauseDatabase,
270 num_vars: usize,
271 ) -> (LocalSearchResult, Option<Vec<bool>>) {
272 self.initialize(clauses, num_vars);
273
274 for _ in 0..self.config.max_flips {
275 if self.unsat_clauses.is_empty() {
276 return (LocalSearchResult::Sat, Some(self.assignment.clone()));
277 }
278
279 let clause_id = {
281 let idx = (self.rand() as usize) % self.unsat_clauses.len();
282 self.unsat_clauses[idx]
283 };
284 let clause = clauses.get(clause_id).expect("clause_id is valid");
285
286 let use_random_walk = self.rand_float() < self.config.random_walk_prob;
288
289 let var_to_flip = if use_random_walk {
291 let idx = (self.rand() as usize) % clause.lits.len();
293 clause.lits[idx].var()
294 } else {
295 let mut best_var = clause.lits[0].var();
297 let mut min_break = self.break_count[best_var.index()];
298
299 for &lit in &clause.lits[1..] {
300 let var = lit.var();
301 let break_cnt = self.break_count[var.index()];
302 if break_cnt < min_break {
303 min_break = break_cnt;
304 best_var = var;
305 }
306 }
307
308 best_var
309 };
310
311 self.flip(var_to_flip, clauses);
312 }
313
314 (LocalSearchResult::Unknown, None)
315 }
316
317 pub fn solve_probsat(
321 &mut self,
322 clauses: &ClauseDatabase,
323 num_vars: usize,
324 ) -> (LocalSearchResult, Option<Vec<bool>>) {
325 self.initialize(clauses, num_vars);
326
327 for _ in 0..self.config.max_flips {
328 if self.unsat_clauses.is_empty() {
329 return (LocalSearchResult::Sat, Some(self.assignment.clone()));
330 }
331
332 let clause_id = {
334 let idx = (self.rand() as usize) % self.unsat_clauses.len();
335 self.unsat_clauses[idx]
336 };
337 let clause = clauses.get(clause_id).expect("clause_id is valid");
338
339 let mut probs: SmallVec<[f64; 8]> = SmallVec::new();
341 let mut total = 0.0;
342
343 for &lit in &clause.lits {
344 let var = lit.var();
345 let break_cnt = self.break_count[var.index()];
346 let prob = 1.0 / libm::pow(break_cnt as f64 + 1.0, self.config.cb_exponent);
348 probs.push(prob);
349 total += prob;
350 }
351
352 for prob in &mut probs {
354 *prob /= total;
355 }
356
357 let r = self.rand_float();
359 let mut cumulative = 0.0;
360 let mut selected_var = clause.lits[0].var();
361
362 for (i, &lit) in clause.lits.iter().enumerate() {
363 cumulative += probs[i];
364 if r <= cumulative {
365 selected_var = lit.var();
366 break;
367 }
368 }
369
370 self.flip(selected_var, clauses);
371 }
372
373 (LocalSearchResult::Unknown, None)
374 }
375
376 #[must_use]
378 pub fn stats(&self) -> &LocalSearchStats {
379 &self.stats
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clause::Clause;

    /// Returns `true` if at least one literal of `lits` is true under `assignment`.
    fn is_satisfied(lits: &[Lit], assignment: &[bool]) -> bool {
        lits.iter().any(|&lit| {
            let value = assignment[lit.var().index()];
            if lit.is_pos() { value } else { !value }
        })
    }

    #[test]
    fn test_local_search_creation() {
        let config = LocalSearchConfig::default();
        let ls = LocalSearch::new(10, config);
        assert_eq!(ls.assignment.len(), 10);
    }

    #[test]
    fn test_local_search_simple_sat() {
        let mut db = ClauseDatabase::new();
        let c1 = Clause::new(vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))], false);
        let c2 = Clause::new(vec![Lit::neg(Var::new(0)), Lit::pos(Var::new(2))], false);
        let id1 = db.add(c1);
        let id2 = db.add(c2);

        let config = LocalSearchConfig {
            max_flips: 1000,
            ..Default::default()
        };
        let mut ls = LocalSearch::new(3, config);
        let (result, assignment) = ls.solve_walksat(&db, 3);

        assert_eq!(result, LocalSearchResult::Sat);
        let assignment = assignment.expect("SAT result must have assignment");
        for id in [id1, id2] {
            let clause = db.get(id).expect("Clause must exist in database");
            assert!(is_satisfied(&clause.lits, &assignment));
        }
    }

    #[test]
    fn test_local_search_stats() {
        let config = LocalSearchConfig {
            max_flips: 100,
            ..Default::default()
        };
        let mut ls = LocalSearch::new(5, config);

        let mut db = ClauseDatabase::new();
        db.add(Clause::new(
            vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))],
            false,
        ));
        db.add(Clause::new(
            vec![Lit::neg(Var::new(0)), Lit::pos(Var::new(2))],
            false,
        ));
        db.add(Clause::new(
            vec![Lit::neg(Var::new(1)), Lit::pos(Var::new(3))],
            false,
        ));

        let (result, _) = ls.solve_walksat(&db, 5);
        let stats = ls.stats();

        assert_eq!(result, LocalSearchResult::Sat);
        // Reaching SAT means the run passed through an assignment with no
        // falsified clause, so the recorded minimum must be zero.
        assert_eq!(stats.min_unsat, 0);
        assert!(stats.flips <= 100);
    }

    #[test]
    fn test_probsat() {
        let mut db = ClauseDatabase::new();
        let id1 = db.add(Clause::new(
            vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))],
            false,
        ));
        let id2 = db.add(Clause::new(
            vec![Lit::neg(Var::new(0)), Lit::pos(Var::new(2))],
            false,
        ));

        let config = LocalSearchConfig {
            max_flips: 1000,
            cb_exponent: 2.5,
            ..Default::default()
        };

        let mut ls = LocalSearch::new(3, config);
        let (result, assignment) = ls.solve_probsat(&db, 3);

        // The formula is trivially satisfiable, so the solver must find a model
        // (the previous `Sat || Unknown` assertion could never fail).
        assert_eq!(result, LocalSearchResult::Sat);
        let assignment = assignment.expect("SAT result must have assignment");
        for id in [id1, id2] {
            let clause = db.get(id).expect("Clause must exist in database");
            assert!(is_satisfied(&clause.lits, &assignment));
        }
    }
}