Skip to main content

oxiz_sat/
local_search.rs

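//! Local Search SAT Solver (ProbSAT/WalkSAT)
//!
//! Local search is a complementary technique to CDCL that can be very effective
//! for satisfiable instances. It works by maintaining a complete assignment and
//! iteratively flipping variables to reduce the number of unsatisfied clauses.
//! It is incomplete: it can find satisfying assignments, but it cannot prove
//! that a formula is unsatisfiable.
//!
//! This module implements:
//! - ProbSAT: picks a variable of a random unsatisfied clause with probability
//!   proportional to `(break + 1)^-cb`, where `break` is the variable's break count
//! - WalkSAT: greedily picks the minimum-break variable, with an occasional random walk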
10
11use crate::clause::{ClauseDatabase, ClauseId};
12use crate::literal::{Lit, Var};
13#[allow(unused_imports)]
14use crate::prelude::*;
15use smallvec::SmallVec;
16
/// Tunable parameters shared by the WalkSAT and ProbSAT search loops.
#[derive(Clone, Debug)]
pub struct LocalSearchConfig {
    /// Upper bound on the number of variable flips before the search gives up.
    pub max_flips: u64,
    /// Chance (between 0 and 1) that WalkSAT flips a random variable of the
    /// chosen clause instead of a minimum-break one; typically 0.3-0.5.
    pub random_walk_prob: f64,
    /// Exponent `cb` of ProbSAT's polynomial weighting `(break + 1)^-cb`;
    /// typically 2.0-3.0.
    pub cb_exponent: f64,
    /// Seed for the internal pseudo-random generator, making runs reproducible.
    pub random_seed: u64,
}
29
30impl Default for LocalSearchConfig {
31    fn default() -> Self {
32        Self {
33            max_flips: 1_000_000,
34            random_walk_prob: 0.4,
35            cb_exponent: 2.3,
36            random_seed: 1234567,
37        }
38    }
39}
40
/// Outcome of a local search run.
///
/// There is no "unsatisfiable" outcome: a run that does not find a model is
/// reported as [`LocalSearchResult::Unknown`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LocalSearchResult {
    /// A satisfying assignment was found.
    Sat,
    /// No satisfying assignment was found (e.g. the flip budget ran out).
    Unknown,
}
49
/// Counters collected while a local search runs.
#[derive(Clone, Debug, Default)]
pub struct LocalSearchStats {
    /// Number of variable flips performed.
    pub flips: u64,
    /// Fewest unsatisfied clauses observed during the search.
    pub min_unsat: usize,
    /// How many times `min_unsat` was lowered by a flip.
    pub improvements: u64,
}
60
61/// Local Search SAT Solver
62///
63/// Implements both WalkSAT and ProbSAT algorithms.
64/// Maintains a complete assignment and iteratively flips variables.
65pub struct LocalSearch {
66    /// Current variable assignment (true/false for each variable)
67    assignment: Vec<bool>,
68    /// Break count for each variable (how many satisfied clauses would become unsatisfied)
69    break_count: Vec<u64>,
70    /// Make count for each variable (how many unsatisfied clauses would become satisfied)
71    make_count: Vec<u64>,
72    /// List of currently unsatisfied clauses
73    unsat_clauses: Vec<ClauseId>,
74    /// Set of unsatisfied clauses for quick lookup
75    unsat_set: HashMap<ClauseId, ()>,
76    /// Number of true literals in each clause
77    true_count: HashMap<ClauseId, usize>,
78    /// Configuration
79    config: LocalSearchConfig,
80    /// Statistics
81    stats: LocalSearchStats,
82    /// Simple LCG random number generator state
83    rng_state: u64,
84}
85
86impl LocalSearch {
87    /// Create a new local search solver
88    #[must_use]
89    pub fn new(num_vars: usize, config: LocalSearchConfig) -> Self {
90        Self {
91            assignment: vec![false; num_vars],
92            break_count: vec![0; num_vars],
93            make_count: vec![0; num_vars],
94            unsat_clauses: Vec::new(),
95            unsat_set: HashMap::new(),
96            true_count: HashMap::new(),
97            rng_state: config.random_seed,
98            config,
99            stats: LocalSearchStats::default(),
100        }
101    }
102
103    /// Simple LCG random number generator
104    fn rand(&mut self) -> u64 {
105        // Linear Congruential Generator: Xn+1 = (a * Xn + c) mod m
106        // Using values from Numerical Recipes
107        const A: u64 = 1664525;
108        const C: u64 = 1013904223;
109        self.rng_state = self.rng_state.wrapping_mul(A).wrapping_add(C);
110        self.rng_state
111    }
112
113    /// Generate a random float between 0.0 and 1.0
114    fn rand_float(&mut self) -> f64 {
115        (self.rand() as f64) / (u64::MAX as f64)
116    }
117
118    /// Initialize with a random assignment
119    fn initialize_random(&mut self, num_vars: usize) {
120        self.assignment.clear();
121        self.assignment.resize(num_vars, false);
122
123        for i in 0..num_vars {
124            self.assignment[i] = self.rand().is_multiple_of(2);
125        }
126    }
127
128    /// Initialize data structures for search
129    fn initialize(&mut self, clauses: &ClauseDatabase, num_vars: usize) {
130        self.initialize_random(num_vars);
131        self.break_count.clear();
132        self.break_count.resize(num_vars, 0);
133        self.make_count.clear();
134        self.make_count.resize(num_vars, 0);
135        self.unsat_clauses.clear();
136        self.unsat_set.clear();
137        self.true_count.clear();
138
139        // Calculate initial true counts and unsat clauses
140        for id in clauses.iter_ids() {
141            let clause = clauses
142                .get(id)
143                .expect("id from clauses.iter_ids() is valid");
144            let true_lits = clause.lits.iter().filter(|&&lit| self.is_true(lit)).count();
145
146            self.true_count.insert(id, true_lits);
147
148            if true_lits == 0 {
149                self.unsat_clauses.push(id);
150                self.unsat_set.insert(id, ());
151            }
152        }
153
154        // Calculate break and make counts
155        for id in clauses.iter_ids() {
156            let clause = clauses
157                .get(id)
158                .expect("id from clauses.iter_ids() is valid");
159            let true_lits = self.true_count[&id];
160
161            for &lit in &clause.lits {
162                let var = lit.var();
163                let var_idx = var.index();
164
165                if self.is_true(lit) {
166                    // If this is the only true literal, flipping would break this clause
167                    if true_lits == 1 {
168                        self.break_count[var_idx] += 1;
169                    }
170                } else {
171                    // Flipping would make this clause true (if it's currently false)
172                    if true_lits == 0 {
173                        self.make_count[var_idx] += 1;
174                    }
175                }
176            }
177        }
178
179        self.stats.min_unsat = self.unsat_clauses.len();
180    }
181
182    /// Check if a literal is true under the current assignment
183    fn is_true(&self, lit: Lit) -> bool {
184        let var_value = self.assignment[lit.var().index()];
185        if lit.is_pos() { var_value } else { !var_value }
186    }
187
188    /// Flip a variable and update data structures
189    fn flip(&mut self, var: Var, clauses: &ClauseDatabase) {
190        let var_idx = var.index();
191        self.assignment[var_idx] = !self.assignment[var_idx];
192        self.stats.flips += 1;
193
194        // Update true counts and unsat status for all clauses containing this variable
195        let pos_lit = Lit::pos(var);
196        let neg_lit = Lit::neg(var);
197
198        // We need to find all clauses containing this variable
199        // Since we don't have a watch list here, we iterate all clauses
200        for id in clauses.iter_ids() {
201            let clause = clauses
202                .get(id)
203                .expect("id from clauses.iter_ids() is valid");
204            if !clause.lits.contains(&pos_lit) && !clause.lits.contains(&neg_lit) {
205                continue;
206            }
207
208            let old_true_count = self.true_count[&id];
209            let was_unsat = old_true_count == 0;
210
211            // Recalculate true count
212            let new_true_count = clause.lits.iter().filter(|&&lit| self.is_true(lit)).count();
213
214            self.true_count.insert(id, new_true_count);
215
216            let is_unsat = new_true_count == 0;
217
218            // Update unsat list
219            if !was_unsat && is_unsat {
220                self.unsat_clauses.push(id);
221                self.unsat_set.insert(id, ());
222            } else if was_unsat && !is_unsat {
223                self.unsat_set.remove(&id);
224            }
225
226            // Update break/make counts for literals in this clause
227            for &lit in &clause.lits {
228                let lit_var = lit.var();
229                let lit_var_idx = lit_var.index();
230
231                // Update break count
232                if old_true_count == 1 && self.is_true(lit) {
233                    // This literal was the only true one, flipping would break
234                    self.break_count[lit_var_idx] -= 1;
235                }
236                if new_true_count == 1 && self.is_true(lit) {
237                    // This literal is now the only true one
238                    self.break_count[lit_var_idx] += 1;
239                }
240
241                // Update make count
242                if old_true_count == 0 && !self.is_true(lit) {
243                    // Clause was unsat, flipping this literal would make it
244                    self.make_count[lit_var_idx] -= 1;
245                }
246                if new_true_count == 0 && !self.is_true(lit) {
247                    // Clause is now unsat, flipping this literal would make it
248                    self.make_count[lit_var_idx] += 1;
249                }
250            }
251        }
252
253        // Clean up unsat_clauses list
254        self.unsat_clauses
255            .retain(|&id| self.unsat_set.contains_key(&id));
256
257        // Track improvements
258        if self.unsat_clauses.len() < self.stats.min_unsat {
259            self.stats.min_unsat = self.unsat_clauses.len();
260            self.stats.improvements += 1;
261        }
262    }
263
264    /// Run WalkSAT algorithm
265    ///
266    /// Returns the result and the final assignment (if SAT)
267    pub fn solve_walksat(
268        &mut self,
269        clauses: &ClauseDatabase,
270        num_vars: usize,
271    ) -> (LocalSearchResult, Option<Vec<bool>>) {
272        self.initialize(clauses, num_vars);
273
274        for _ in 0..self.config.max_flips {
275            if self.unsat_clauses.is_empty() {
276                return (LocalSearchResult::Sat, Some(self.assignment.clone()));
277            }
278
279            // Pick a random unsatisfied clause
280            let clause_id = {
281                let idx = (self.rand() as usize) % self.unsat_clauses.len();
282                self.unsat_clauses[idx]
283            };
284            let clause = clauses.get(clause_id).expect("clause_id is valid");
285
286            // Decide whether to use random walk
287            let use_random_walk = self.rand_float() < self.config.random_walk_prob;
288
289            // Select which variable to flip
290            let var_to_flip = if use_random_walk {
291                // Pick a random variable from the clause
292                let idx = (self.rand() as usize) % clause.lits.len();
293                clause.lits[idx].var()
294            } else {
295                // Pick the variable with minimum break count
296                let mut best_var = clause.lits[0].var();
297                let mut min_break = self.break_count[best_var.index()];
298
299                for &lit in &clause.lits[1..] {
300                    let var = lit.var();
301                    let break_cnt = self.break_count[var.index()];
302                    if break_cnt < min_break {
303                        min_break = break_cnt;
304                        best_var = var;
305                    }
306                }
307
308                best_var
309            };
310
311            self.flip(var_to_flip, clauses);
312        }
313
314        (LocalSearchResult::Unknown, None)
315    }
316
317    /// Run ProbSAT algorithm
318    ///
319    /// Returns the result and the final assignment (if SAT)
320    pub fn solve_probsat(
321        &mut self,
322        clauses: &ClauseDatabase,
323        num_vars: usize,
324    ) -> (LocalSearchResult, Option<Vec<bool>>) {
325        self.initialize(clauses, num_vars);
326
327        for _ in 0..self.config.max_flips {
328            if self.unsat_clauses.is_empty() {
329                return (LocalSearchResult::Sat, Some(self.assignment.clone()));
330            }
331
332            // Pick a random unsatisfied clause
333            let clause_id = {
334                let idx = (self.rand() as usize) % self.unsat_clauses.len();
335                self.unsat_clauses[idx]
336            };
337            let clause = clauses.get(clause_id).expect("clause_id is valid");
338
339            // Calculate probabilities based on break counts
340            let mut probs: SmallVec<[f64; 8]> = SmallVec::new();
341            let mut total = 0.0;
342
343            for &lit in &clause.lits {
344                let var = lit.var();
345                let break_cnt = self.break_count[var.index()];
346                // Probability is inversely proportional to (break_count + 1)^cb
347                let prob = 1.0 / libm::pow(break_cnt as f64 + 1.0, self.config.cb_exponent);
348                probs.push(prob);
349                total += prob;
350            }
351
352            // Normalize probabilities
353            for prob in &mut probs {
354                *prob /= total;
355            }
356
357            // Select variable based on probability distribution
358            let r = self.rand_float();
359            let mut cumulative = 0.0;
360            let mut selected_var = clause.lits[0].var();
361
362            for (i, &lit) in clause.lits.iter().enumerate() {
363                cumulative += probs[i];
364                if r <= cumulative {
365                    selected_var = lit.var();
366                    break;
367                }
368            }
369
370            self.flip(selected_var, clauses);
371        }
372
373        (LocalSearchResult::Unknown, None)
374    }
375
376    /// Get statistics
377    #[must_use]
378    pub fn stats(&self) -> &LocalSearchStats {
379        &self.stats
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clause::Clause;

    /// Returns `true` when `assignment` satisfies every clause stored in `db`.
    fn satisfies_all(db: &ClauseDatabase, assignment: &[bool]) -> bool {
        for id in db.iter_ids() {
            let clause = db.get(id).expect("id from iter_ids() is valid");
            let satisfied = clause
                .lits
                .iter()
                .any(|&lit| assignment[lit.var().index()] == lit.is_pos());
            if !satisfied {
                return false;
            }
        }
        true
    }

    #[test]
    fn test_local_search_creation() {
        let config = LocalSearchConfig::default();
        let ls = LocalSearch::new(10, config);
        assert_eq!(ls.assignment.len(), 10);
    }

    #[test]
    fn test_local_search_simple_sat() {
        // Create a simple satisfiable formula: (x1 v x2) ^ (~x1 v x3)
        let mut db = ClauseDatabase::new();
        db.add(Clause::new(
            vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))],
            false,
        ));
        db.add(Clause::new(
            vec![Lit::neg(Var::new(0)), Lit::pos(Var::new(2))],
            false,
        ));

        let config = LocalSearchConfig {
            max_flips: 1000,
            ..Default::default()
        };

        let mut ls = LocalSearch::new(3, config);
        let (result, assignment) = ls.solve_walksat(&db, 3);

        assert_eq!(result, LocalSearchResult::Sat);
        let assignment = assignment.expect("SAT result must have assignment");
        assert!(satisfies_all(&db, &assignment));
    }

    #[test]
    fn test_local_search_stats() {
        let config = LocalSearchConfig {
            max_flips: 100,
            ..Default::default()
        };
        let mut ls = LocalSearch::new(5, config);

        let mut db = ClauseDatabase::new();
        db.add(Clause::new(
            vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))],
            false,
        ));
        db.add(Clause::new(
            vec![Lit::neg(Var::new(0)), Lit::pos(Var::new(2))],
            false,
        ));
        db.add(Clause::new(
            vec![Lit::neg(Var::new(1)), Lit::pos(Var::new(3))],
            false,
        ));

        let (result, assignment) = ls.solve_walksat(&db, 5);
        let stats = ls.stats();

        // Should find a solution for this easy formula
        assert_eq!(result, LocalSearchResult::Sat);
        assert!(satisfies_all(
            &db,
            &assignment.expect("SAT result must have assignment")
        ));
        // Reaching SAT means zero unsatisfied clauses were observed.
        assert_eq!(stats.min_unsat, 0);
    }

    #[test]
    fn test_probsat() {
        let mut db = ClauseDatabase::new();
        db.add(Clause::new(
            vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))],
            false,
        ));
        db.add(Clause::new(
            vec![Lit::neg(Var::new(0)), Lit::pos(Var::new(2))],
            false,
        ));

        let config = LocalSearchConfig {
            max_flips: 1000,
            cb_exponent: 2.5,
            ..Default::default()
        };

        let mut ls = LocalSearch::new(3, config);
        let (result, assignment) = ls.solve_probsat(&db, 3);

        // A 3-variable, 2-clause formula is solved well within 1000 flips.
        assert_eq!(result, LocalSearchResult::Sat);
        assert!(satisfies_all(
            &db,
            &assignment.expect("SAT result must have assignment")
        ));
    }

    #[test]
    fn test_unsatisfiable_formula_exhausts_flip_budget() {
        // (x0) ^ (~x0) has no model, so every run must spend its whole budget.
        let mut db = ClauseDatabase::new();
        db.add(Clause::new(vec![Lit::pos(Var::new(0))], false));
        db.add(Clause::new(vec![Lit::neg(Var::new(0))], false));

        let config = LocalSearchConfig {
            max_flips: 50,
            ..Default::default()
        };

        let mut walksat = LocalSearch::new(1, config.clone());
        let (result, assignment) = walksat.solve_walksat(&db, 1);
        assert_eq!(result, LocalSearchResult::Unknown);
        assert!(assignment.is_none());
        assert_eq!(walksat.stats().flips, 50);

        let mut probsat = LocalSearch::new(1, config);
        let (result, assignment) = probsat.solve_probsat(&db, 1);
        assert_eq!(result, LocalSearchResult::Unknown);
        assert!(assignment.is_none());
        assert_eq!(probsat.stats().flips, 50);
    }
}