Skip to main content

oxiz_sat/solver/
mod.rs

1//! CDCL SAT Solver
2
3mod conflict;
4mod decide;
5pub mod heuristic;
6mod incremental;
7mod learn;
8mod propagate;
9
10pub use heuristic::{BoxedBranchingHeuristic, BranchingHeuristic};
11
12use crate::chb::CHB;
13use crate::chrono::ChronoBacktrack;
14use crate::clause::{ClauseDatabase, ClauseId};
15use crate::literal::{LBool, Lit, Var};
16use crate::lrb::LRB;
17use crate::memory_opt::{MemoryAction, MemoryOptimizer};
18#[allow(unused_imports)]
19use crate::prelude::*;
20use crate::trail::{Reason, Trail};
21use crate::vsids::VSIDS;
22use crate::watched::{WatchLists, Watcher};
23use smallvec::SmallVec;
24
25/// Binary implication graph for efficient binary clause propagation
26/// For each literal L, stores the list of literals that are implied when L is false
27/// (i.e., for binary clause (~L v M), when L is assigned false, M must be true)
28#[derive(Debug, Clone)]
29pub(super) struct BinaryImplicationGraph {
30    /// implications[lit] = list of (implied_lit, clause_id) pairs
31    implications: Vec<Vec<(Lit, ClauseId)>>,
32}
33
34impl BinaryImplicationGraph {
35    fn new(num_vars: usize) -> Self {
36        Self {
37            implications: vec![Vec::new(); num_vars * 2],
38        }
39    }
40
41    fn resize(&mut self, num_vars: usize) {
42        self.implications.resize(num_vars * 2, Vec::new());
43    }
44
45    fn add(&mut self, lit: Lit, implied: Lit, clause_id: ClauseId) {
46        self.implications[lit.code() as usize].push((implied, clause_id));
47    }
48
49    fn get(&self, lit: Lit) -> &[(Lit, ClauseId)] {
50        &self.implications[lit.code() as usize]
51    }
52
53    fn clear(&mut self) {
54        for implications in &mut self.implications {
55            implications.clear();
56        }
57    }
58}
59
60/// Result from a theory check
61#[derive(Debug, Clone)]
62pub enum TheoryCheckResult {
63    /// Theory is satisfied under current assignment
64    Sat,
65    /// Theory detected a conflict, returns conflict clause literals
66    Conflict(SmallVec<[Lit; 8]>),
67    /// Theory propagated new literals (lit, reason clause)
68    Propagated(Vec<(Lit, SmallVec<[Lit; 8]>)>),
69}
70
71/// Callback trait for theory solvers
72/// The CDCL(T) solver implements this to receive theory callbacks
73pub trait TheoryCallback {
74    /// Called when a literal is assigned
75    /// Returns a theory check result
76    fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult;
77
78    /// Called after propagation is complete to do a full theory check
79    fn final_check(&mut self) -> TheoryCheckResult;
80
81    /// Called when the decision level increases
82    fn on_new_level(&mut self, _level: u32) {}
83
84    /// Called when backtracking
85    fn on_backtrack(&mut self, level: u32);
86}
87
/// Answer produced by a SAT solving call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolverResult {
    /// A satisfying assignment was found.
    Sat,
    /// The formula has no satisfying assignment.
    Unsat,
    /// No answer was reached (for example because a limit was hit).
    Unknown,
}
98
99/// Solver configuration
100#[derive(Clone)]
101pub struct SolverConfig {
102    /// Restart interval (number of conflicts)
103    pub restart_interval: u64,
104    /// Restart multiplier for geometric restarts
105    pub restart_multiplier: f64,
106    /// Clause deletion threshold
107    pub clause_deletion_threshold: usize,
108    /// Variable decay factor
109    pub var_decay: f64,
110    /// Clause decay factor
111    pub clause_decay: f64,
112    /// Random polarity probability (0.0 to 1.0)
113    pub random_polarity_prob: f64,
114    /// Restart strategy: "luby" or "geometric"
115    pub restart_strategy: RestartStrategy,
116    /// Enable lazy hyper-binary resolution
117    pub enable_lazy_hyper_binary: bool,
118    /// Use CHB instead of VSIDS for branching
119    pub use_chb_branching: bool,
120    /// Use LRB (Learning Rate Branching) for branching
121    pub use_lrb_branching: bool,
122    /// Enable inprocessing (periodic preprocessing during search)
123    pub enable_inprocessing: bool,
124    /// Inprocessing interval (number of conflicts between inprocessing)
125    pub inprocessing_interval: u64,
126    /// Enable chronological backtracking
127    pub enable_chronological_backtrack: bool,
128    /// Chronological backtracking threshold (max distance from assertion level)
129    pub chrono_backtrack_threshold: u32,
130    /// Optional external branching heuristic. When `Some`, called before built-in
131    /// VSIDS/LRB/CHB; returning `None` from the heuristic falls back to built-in.
132    /// Default: `None` (pure built-in strategy).
133    pub external_branching: Option<BoxedBranchingHeuristic>,
134}
135
136impl core::fmt::Debug for SolverConfig {
137    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
138        f.debug_struct("SolverConfig")
139            .field("restart_interval", &self.restart_interval)
140            .field("restart_multiplier", &self.restart_multiplier)
141            .field("clause_deletion_threshold", &self.clause_deletion_threshold)
142            .field("var_decay", &self.var_decay)
143            .field("clause_decay", &self.clause_decay)
144            .field("random_polarity_prob", &self.random_polarity_prob)
145            .field("restart_strategy", &self.restart_strategy)
146            .field("enable_lazy_hyper_binary", &self.enable_lazy_hyper_binary)
147            .field("use_chb_branching", &self.use_chb_branching)
148            .field("use_lrb_branching", &self.use_lrb_branching)
149            .field("enable_inprocessing", &self.enable_inprocessing)
150            .field("inprocessing_interval", &self.inprocessing_interval)
151            .field(
152                "enable_chronological_backtrack",
153                &self.enable_chronological_backtrack,
154            )
155            .field(
156                "chrono_backtrack_threshold",
157                &self.chrono_backtrack_threshold,
158            )
159            .field(
160                "external_branching",
161                &self
162                    .external_branching
163                    .as_ref()
164                    .map(|_| "<BranchingHeuristic>"),
165            )
166            .finish()
167    }
168}
169
/// Policy deciding when the solver restarts its search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartStrategy {
    /// Restart intervals follow the Luby sequence.
    Luby,
    /// Restart intervals grow geometrically.
    Geometric,
    /// Glucose-style dynamic restarts driven by learned-clause LBD.
    Glucose,
    /// Local restarts driven by recent versus global LBD.
    LocalLbd,
}
182
183impl Default for SolverConfig {
184    fn default() -> Self {
185        Self {
186            restart_interval: 100,
187            restart_multiplier: 1.5,
188            clause_deletion_threshold: 10000,
189            var_decay: 0.95,
190            clause_decay: 0.999,
191            random_polarity_prob: 0.02,
192            restart_strategy: RestartStrategy::Luby,
193            enable_lazy_hyper_binary: true,
194            use_chb_branching: false,
195            use_lrb_branching: false,
196            enable_inprocessing: false,
197            inprocessing_interval: 5000,
198            enable_chronological_backtrack: true,
199            chrono_backtrack_threshold: 100,
200            external_branching: None,
201        }
202    }
203}
204
/// Counters collected by the solver during search.
#[derive(Debug, Default, Clone)]
pub struct SolverStats {
    /// Decisions made
    pub decisions: u64,
    /// Propagations performed
    pub propagations: u64,
    /// Conflicts encountered
    pub conflicts: u64,
    /// Restarts performed
    pub restarts: u64,
    /// Clauses learned
    pub learned_clauses: u64,
    /// Clauses deleted
    pub deleted_clauses: u64,
    /// Binary clauses learned
    pub binary_clauses: u64,
    /// Unit clauses learned
    pub unit_clauses: u64,
    /// Sum of the LBD values of learned clauses
    pub total_lbd: u64,
    /// Clause minimizations performed
    pub minimizations: u64,
    /// Literals removed by minimization
    pub literals_removed: u64,
    /// Chronological backtracks performed
    pub chrono_backtracks: u64,
    /// Non-chronological backtracks performed
    pub non_chrono_backtracks: u64,
}

impl SolverStats {
    /// `numerator / denominator` as a float, or `0.0` when the denominator is zero.
    fn ratio_or_zero(numerator: u64, denominator: u64) -> f64 {
        match denominator {
            0 => 0.0,
            d => numerator as f64 / d as f64,
        }
    }

    /// Mean LBD of learned clauses (`0.0` when nothing was learned).
    #[must_use]
    pub fn avg_lbd(&self) -> f64 {
        Self::ratio_or_zero(self.total_lbd, self.learned_clauses)
    }

    /// Mean number of decisions per conflict (`0.0` without conflicts).
    #[must_use]
    pub fn avg_decisions_per_conflict(&self) -> f64 {
        Self::ratio_or_zero(self.decisions, self.conflicts)
    }

    /// Mean number of propagations per conflict (`0.0` without conflicts).
    #[must_use]
    pub fn propagations_per_conflict(&self) -> f64 {
        Self::ratio_or_zero(self.propagations, self.conflicts)
    }

    /// Fraction of learned clauses that were deleted (`0.0` when nothing was learned).
    #[must_use]
    pub fn deletion_ratio(&self) -> f64 {
        Self::ratio_or_zero(self.deleted_clauses, self.learned_clauses)
    }

    /// Fraction of backtracks that were chronological (`0.0` without backtracks).
    #[must_use]
    pub fn chrono_backtrack_ratio(&self) -> f64 {
        let all_backtracks = self.chrono_backtracks + self.non_chrono_backtracks;
        Self::ratio_or_zero(self.chrono_backtracks, all_backtracks)
    }

    /// Print a formatted statistics report to standard output.
    pub fn display(&self) {
        let avg_lbd = self.avg_lbd();
        let decisions_per_conflict = self.avg_decisions_per_conflict();
        let propagations_per_conflict = self.propagations_per_conflict();
        let deletion_percent = self.deletion_ratio() * 100.0;
        let chrono_percent = self.chrono_backtrack_ratio() * 100.0;

        println!("========== Solver Statistics ==========");
        println!("Decisions:              {:>12}", self.decisions);
        println!("Propagations:           {:>12}", self.propagations);
        println!("Conflicts:              {:>12}", self.conflicts);
        println!("Restarts:               {:>12}", self.restarts);
        println!("Learned clauses:        {:>12}", self.learned_clauses);
        println!("  - Unit clauses:       {:>12}", self.unit_clauses);
        println!("  - Binary clauses:     {:>12}", self.binary_clauses);
        println!("Deleted clauses:        {:>12}", self.deleted_clauses);
        println!("Minimizations:          {:>12}", self.minimizations);
        println!("Literals removed:       {:>12}", self.literals_removed);
        println!("Chrono backtracks:      {:>12}", self.chrono_backtracks);
        println!("Non-chrono backtracks:  {:>12}", self.non_chrono_backtracks);
        println!("---------------------------------------");
        println!("Avg LBD:                {:>12.2}", avg_lbd);
        println!("Avg decisions/conflict: {:>12.2}", decisions_per_conflict);
        println!("Propagations/conflict:  {:>12.2}", propagations_per_conflict);
        println!("Deletion ratio:         {:>12.2}%", deletion_percent);
        println!("Chrono backtrack ratio: {:>12.2}%", chrono_percent);
        println!("=======================================");
    }
}
324
325/// CDCL SAT Solver
326#[derive(Debug)]
327pub struct Solver {
328    /// Configuration
329    pub(super) config: SolverConfig,
330    /// Number of variables
331    pub(super) num_vars: usize,
332    /// Clause database
333    pub(super) clauses: ClauseDatabase,
334    /// Assignment trail
335    pub(super) trail: Trail,
336    /// Watch lists
337    pub(super) watches: WatchLists,
338    /// VSIDS branching heuristic
339    pub(super) vsids: VSIDS,
340    /// CHB branching heuristic
341    pub(super) chb: CHB,
342    /// LRB branching heuristic
343    pub(super) lrb: LRB,
344    /// Statistics
345    pub(super) stats: SolverStats,
346    /// Learnt clause for conflict analysis
347    pub(super) learnt: SmallVec<[Lit; 16]>,
348    /// Seen flags for conflict analysis
349    pub(super) seen: Vec<bool>,
350    /// Analyze stack
351    pub(super) analyze_stack: Vec<Lit>,
352    /// Current restart threshold
353    pub(super) restart_threshold: u64,
354    /// Assertions stack for incremental solving (number of original clauses)
355    pub(super) assertion_levels: Vec<usize>,
356    /// Trail sizes at each assertion level (for proper pop backtracking)
357    pub(super) assertion_trail_sizes: Vec<usize>,
358    /// Clause IDs added at each assertion level (for proper pop)
359    pub(super) assertion_clause_ids: Vec<Vec<ClauseId>>,
360    /// Model (if sat)
361    pub(super) model: Vec<LBool>,
362    /// Whether formula is trivially unsatisfiable
363    pub(super) trivially_unsat: bool,
364    /// Phase saving: last polarity assigned to each variable
365    pub(super) phase: Vec<bool>,
366    /// Luby sequence index for restarts
367    pub(super) luby_index: u64,
368    /// Level marks for LBD computation
369    pub(super) level_marks: Vec<u32>,
370    /// Current mark counter for LBD computation
371    pub(super) lbd_mark: u32,
372    /// Learned clause IDs for deletion
373    pub(super) learned_clause_ids: Vec<ClauseId>,
374    /// Number of conflicts since last clause deletion
375    pub(super) conflicts_since_deletion: u64,
376    /// PRNG state (xorshift64)
377    pub(super) rng_state: u64,
378    /// For Glucose-style restarts: average LBD of recent conflicts
379    pub(super) recent_lbd_sum: u64,
380    /// Number of conflicts contributing to recent_lbd_sum
381    pub(super) recent_lbd_count: u64,
382    /// Binary implication graph for fast binary clause propagation
383    pub(super) binary_graph: BinaryImplicationGraph,
384    /// Global average LBD for local restarts
385    pub(super) global_lbd_sum: u64,
386    /// Number of conflicts contributing to global LBD
387    pub(super) global_lbd_count: u64,
388    /// Conflicts since last local restart
389    pub(super) conflicts_since_local_restart: u64,
390    /// Conflicts since last inprocessing
391    pub(super) conflicts_since_inprocessing: u64,
392    /// Chronological backtracking helper
393    pub(super) chrono_backtrack: ChronoBacktrack,
394    /// Clause activity bump increment (for MapleSAT-style clause bumping)
395    pub(super) clause_bump_increment: f64,
396    /// Memory optimizer with size-class pools for clause allocation
397    pub(super) memory_optimizer: MemoryOptimizer,
398}
399
400impl Default for Solver {
401    fn default() -> Self {
402        Self::new()
403    }
404}
405
406impl Solver {
407    /// Create a new solver
408    #[must_use]
409    pub fn new() -> Self {
410        Self::with_config(SolverConfig::default())
411    }
412
413    /// Create a new solver with configuration
414    #[must_use]
415    pub fn with_config(config: SolverConfig) -> Self {
416        let chrono_enabled = config.enable_chronological_backtrack;
417        let chrono_threshold = config.chrono_backtrack_threshold;
418
419        Self {
420            restart_threshold: config.restart_interval,
421            config,
422            num_vars: 0,
423            clauses: ClauseDatabase::new(),
424            trail: Trail::new(0),
425            watches: WatchLists::new(0),
426            vsids: VSIDS::new(0),
427            chb: CHB::new(0),
428            lrb: LRB::new(0),
429            stats: SolverStats::default(),
430            learnt: SmallVec::new(),
431            seen: Vec::new(),
432            analyze_stack: Vec::new(),
433            assertion_levels: vec![0],
434            assertion_trail_sizes: vec![0],
435            assertion_clause_ids: vec![Vec::new()],
436            model: Vec::new(),
437            trivially_unsat: false,
438            phase: Vec::new(),
439            luby_index: 0,
440            level_marks: Vec::new(),
441            lbd_mark: 0,
442            learned_clause_ids: Vec::new(),
443            conflicts_since_deletion: 0,
444            rng_state: 0x853c_49e6_748f_ea9b, // Random seed
445            recent_lbd_sum: 0,
446            recent_lbd_count: 0,
447            binary_graph: BinaryImplicationGraph::new(0),
448            global_lbd_sum: 0,
449            global_lbd_count: 0,
450            conflicts_since_local_restart: 0,
451            conflicts_since_inprocessing: 0,
452            chrono_backtrack: ChronoBacktrack::new(chrono_enabled, chrono_threshold),
453            clause_bump_increment: 1.0,
454            memory_optimizer: MemoryOptimizer::new(),
455        }
456    }
457
458    /// Create a new variable
459    pub fn new_var(&mut self) -> Var {
460        let var = Var::new(self.num_vars as u32);
461        self.num_vars += 1;
462        self.trail.resize(self.num_vars);
463        self.watches.resize(self.num_vars);
464        self.binary_graph.resize(self.num_vars);
465        self.vsids.insert(var);
466        self.chb.insert(var);
467        self.lrb.resize(self.num_vars);
468        self.seen.resize(self.num_vars, false);
469        self.model.resize(self.num_vars, LBool::Undef);
470        self.phase.resize(self.num_vars, false); // Default phase: negative
471        // Resize level_marks to at least num_vars (enough for decision levels)
472        if self.level_marks.len() < self.num_vars {
473            self.level_marks.resize(self.num_vars, 0);
474        }
475        var
476    }
477
478    /// Ensure we have at least n variables
479    pub fn ensure_vars(&mut self, n: usize) {
480        while self.num_vars < n {
481            self.new_var();
482        }
483    }
484
485    /// Add a clause
486    pub fn add_clause(&mut self, lits: impl IntoIterator<Item = Lit>) -> bool {
487        let mut clause_lits: SmallVec<[Lit; 8]> = lits.into_iter().collect();
488
489        // Ensure we have all variables
490        for lit in &clause_lits {
491            let var_idx = lit.var().index();
492            if var_idx >= self.num_vars {
493                self.ensure_vars(var_idx + 1);
494            }
495        }
496
497        // Remove duplicates and check for tautology
498        clause_lits.sort_by_key(|l| l.code());
499        clause_lits.dedup();
500
501        // Check for tautology (x and ~x in same clause)
502        for i in 0..clause_lits.len() {
503            for j in (i + 1)..clause_lits.len() {
504                if clause_lits[i] == clause_lits[j].negate() {
505                    return true; // Tautology - always satisfied
506                }
507            }
508        }
509
510        // Handle special cases
511        match clause_lits.len() {
512            0 => {
513                self.trivially_unsat = true;
514                return false; // Empty clause - unsat
515            }
516            1 => {
517                // Unit clause - enqueue at decision level 0
518                // Unit clauses must be assigned at level 0 to survive backtracking.
519                // After solve(), current_level may be > 0, so we must backtrack first.
520                let lit = clause_lits[0];
521
522                if self.trail.lit_value(lit).is_false() {
523                    // The literal conflicts with the current trail.
524                    // Check if the conflict is at decision level 0 (permanent constraint)
525                    // or from a previous solve (can be retried after backtrack).
526                    let var = lit.var();
527                    let level = self.trail.level(var);
528                    if level == 0 {
529                        // Conflict with a level-0 assignment - truly UNSAT
530                        self.trivially_unsat = true;
531                        return false;
532                    } else {
533                        // Conflict with higher-level assignment from previous solve.
534                        // Backtrack to root and assign the new unit literal at level 0.
535                        self.backtrack_to_root();
536                        self.trail.assign_decision(lit);
537                        return true;
538                    }
539                }
540
541                if self.trail.lit_value(lit).is_true() {
542                    // Already satisfied - check if at level 0
543                    let var = lit.var();
544                    let level = self.trail.level(var);
545                    if level == 0 {
546                        // Already assigned at level 0, nothing to do
547                        return true;
548                    }
549                    // Assigned at higher level - backtrack and reassign at level 0
550                    self.backtrack_to_root();
551                    self.trail.assign_decision(lit);
552                    return true;
553                }
554
555                // Variable is unassigned - backtrack to level 0 first to ensure
556                // the assignment is at level 0 (survives future backtracks)
557                if self.trail.decision_level() > 0 {
558                    self.backtrack_to_root();
559                }
560                self.trail.assign_decision(lit);
561                return true;
562            }
563            2 => {
564                // Binary clause - check if it conflicts with current assignment
565                let lit0 = clause_lits[0];
566                let lit1 = clause_lits[1];
567                let val0 = self.trail.lit_value(lit0);
568                let val1 = self.trail.lit_value(lit1);
569
570                // If clause is satisfied, just add it
571                if val0.is_true() || val1.is_true() {
572                    // Clause already satisfied by current assignment
573                    let clause_id = self.clauses.add_original(clause_lits.iter().copied());
574                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
575                        current_level_clauses.push(clause_id);
576                    }
577                    self.binary_graph.add(lit0.negate(), lit1, clause_id);
578                    self.binary_graph.add(lit1.negate(), lit0, clause_id);
579                    self.watches
580                        .add(lit0.negate(), Watcher::new(clause_id, lit1));
581                    self.watches
582                        .add(lit1.negate(), Watcher::new(clause_id, lit0));
583                    return true;
584                }
585
586                // If both literals are false, we have a conflict
587                if val0.is_false() && val1.is_false() {
588                    // Check if both are at level 0
589                    let level0 = self.trail.level(lit0.var());
590                    let level1 = self.trail.level(lit1.var());
591
592                    if level0 == 0 && level1 == 0 {
593                        // Conflict at level 0 - UNSAT
594                        self.trivially_unsat = true;
595                        return false;
596                    }
597
598                    // Backtrack to level 0 and add clause
599                    // The clause will be propagated on next solve()
600                    self.backtrack_to_root();
601                }
602
603                // If one literal is false and one undefined, propagate
604                // after adding the clause (via next solve())
605
606                let clause_id = self.clauses.add_original(clause_lits.iter().copied());
607                if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
608                    current_level_clauses.push(clause_id);
609                }
610                self.binary_graph.add(lit0.negate(), lit1, clause_id);
611                self.binary_graph.add(lit1.negate(), lit0, clause_id);
612                self.watches
613                    .add(lit0.negate(), Watcher::new(clause_id, lit1));
614                self.watches
615                    .add(lit1.negate(), Watcher::new(clause_id, lit0));
616                return true;
617            }
618            _ => {}
619        }
620
621        // Add clause (3+ literals)
622        // Check if clause is satisfied or conflicts with current assignment
623        let num_false = clause_lits
624            .iter()
625            .filter(|&l| self.trail.lit_value(*l).is_false())
626            .count();
627        let has_true = clause_lits
628            .iter()
629            .any(|l| self.trail.lit_value(*l).is_true());
630
631        if !has_true && num_false == clause_lits.len() {
632            // All literals are false - conflict
633            // Check if all at level 0
634            let all_at_zero = clause_lits.iter().all(|l| self.trail.level(l.var()) == 0);
635            if all_at_zero {
636                self.trivially_unsat = true;
637                return false;
638            }
639            // Backtrack to level 0
640            self.backtrack_to_root();
641        }
642
643        let clause_id = self.clauses.add_original(clause_lits.iter().copied());
644
645        // Track clause for incremental solving
646        if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
647            current_level_clauses.push(clause_id);
648        }
649
650        // Set up watches - prefer non-false literals for watching
651        let lit0 = clause_lits[0];
652        let lit1 = clause_lits[1];
653
654        self.watches
655            .add(lit0.negate(), Watcher::new(clause_id, lit1));
656        self.watches
657            .add(lit1.negate(), Watcher::new(clause_id, lit0));
658
659        true
660    }
661
662    /// Add a clause from DIMACS literals
663    pub fn add_clause_dimacs(&mut self, lits: &[i32]) -> bool {
664        self.add_clause(lits.iter().map(|&l| Lit::from_dimacs(l)))
665    }
666
667    /// Solve the SAT problem
668    pub fn solve(&mut self) -> SolverResult {
669        // Check if trivially unsatisfiable
670        if self.trivially_unsat {
671            return SolverResult::Unsat;
672        }
673
674        // Initial propagation
675        if self.propagate().is_some() {
676            return SolverResult::Unsat;
677        }
678
679        loop {
680            // Propagate
681            if let Some(conflict) = self.propagate() {
682                self.stats.conflicts += 1;
683                self.conflicts_since_inprocessing += 1;
684
685                if self.trail.decision_level() == 0 {
686                    return SolverResult::Unsat;
687                }
688
689                // Analyze conflict
690                let (backtrack_level, learnt_clause) = self.analyze(conflict);
691
692                // Backtrack with phase saving
693                self.backtrack_with_phase_saving(backtrack_level);
694
695                // Learn clause
696                if learnt_clause.len() == 1 {
697                    // Store unit learned clause in database for persistence
698                    let clause_id = self.clauses.add_learned(learnt_clause.iter().copied());
699                    self.stats.learned_clauses += 1;
700                    self.stats.unit_clauses += 1;
701                    self.learned_clause_ids.push(clause_id);
702
703                    // Track for incremental solving
704                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
705                        current_level_clauses.push(clause_id);
706                    }
707
708                    self.trail.assign_decision(learnt_clause[0]);
709                } else {
710                    // Compute LBD for the learned clause
711                    let lbd = self.compute_lbd(&learnt_clause);
712
713                    // Track recent LBD for Glucose-style and local restarts
714                    self.recent_lbd_sum += u64::from(lbd);
715                    self.recent_lbd_count += 1;
716                    self.global_lbd_sum += u64::from(lbd);
717                    self.global_lbd_count += 1;
718
719                    // Reset recent LBD tracking periodically
720                    if self.recent_lbd_count >= 5000 {
721                        self.recent_lbd_sum /= 2;
722                        self.recent_lbd_count /= 2;
723                    }
724
725                    let clause_id = self.clauses.add_learned(learnt_clause.iter().copied());
726                    self.stats.learned_clauses += 1;
727
728                    // Set LBD score for the clause
729                    if let Some(clause) = self.clauses.get_mut(clause_id) {
730                        clause.lbd = lbd;
731                    }
732
733                    // Track learned clause for potential deletion
734                    self.learned_clause_ids.push(clause_id);
735
736                    // Track clause for incremental solving
737                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
738                        current_level_clauses.push(clause_id);
739                    }
740
741                    // Watch first two literals
742                    let lit0 = learnt_clause[0];
743                    let lit1 = learnt_clause[1];
744                    self.watches
745                        .add(lit0.negate(), Watcher::new(clause_id, lit1));
746                    self.watches
747                        .add(lit1.negate(), Watcher::new(clause_id, lit0));
748
749                    // Propagate the asserting literal
750                    self.trail.assign_propagation(learnt_clause[0], clause_id);
751                }
752
753                // Decay activities
754                self.vsids.decay();
755                self.chb.decay();
756                self.lrb.decay();
757                self.lrb.on_conflict();
758                self.clauses.decay_activity(self.config.clause_decay);
759                // Increase clause bump increment (inverse of decay)
760                self.clause_bump_increment /= self.config.clause_decay;
761
762                // Track conflicts for clause deletion
763                self.conflicts_since_deletion += 1;
764
765                // Periodic clause database reduction
766                if self.conflicts_since_deletion >= self.config.clause_deletion_threshold as u64 {
767                    self.reduce_clause_database();
768                    self.conflicts_since_deletion = 0;
769
770                    // Vivification after clause database reduction (at level 0 after restart)
771                    if self.stats.restarts.is_multiple_of(10) {
772                        let saved_level = self.trail.decision_level();
773                        if saved_level == 0 {
774                            self.vivify_clauses();
775                        }
776                    }
777                }
778
779                // Check for restart
780                if self.stats.conflicts >= self.restart_threshold {
781                    self.restart();
782                }
783
784                // Periodic inprocessing
785                if self.config.enable_inprocessing
786                    && self.conflicts_since_inprocessing >= self.config.inprocessing_interval
787                {
788                    self.inprocess();
789                    self.conflicts_since_inprocessing = 0;
790                }
791            } else {
792                // No conflict - try to decide
793                if let Some(var) = self.pick_branch_var() {
794                    self.stats.decisions += 1;
795                    self.trail.new_decision_level();
796
797                    // Use phase saving with random polarity
798                    let polarity = if self.rand_bool(self.config.random_polarity_prob) {
799                        // Random polarity
800                        self.rand_bool(0.5)
801                    } else {
802                        // Saved phase
803                        self.phase[var.index()]
804                    };
805                    let lit = if polarity {
806                        Lit::pos(var)
807                    } else {
808                        Lit::neg(var)
809                    };
810                    self.trail.assign_decision(lit);
811                } else {
812                    // All variables assigned - SAT
813                    self.save_model();
814                    return SolverResult::Sat;
815                }
816            }
817        }
818    }
819
820    /// Solve with assumptions and return unsat core if UNSAT
821    ///
822    /// This is the key method for MaxSAT: it solves under assumptions and
823    /// if the result is UNSAT, returns the subset of assumptions in the core.
824    ///
825    /// # Arguments
826    /// * `assumptions` - Literals that must be true
827    ///
828    /// # Returns
829    /// * `(SolverResult, Option<Vec<Lit>>)` - Result and unsat core (if UNSAT)
830    pub fn solve_with_assumptions(
831        &mut self,
832        assumptions: &[Lit],
833    ) -> (SolverResult, Option<Vec<Lit>>) {
834        if self.trivially_unsat {
835            return (SolverResult::Unsat, Some(Vec::new()));
836        }
837
838        // Ensure all assumption variables exist
839        for &lit in assumptions {
840            while self.num_vars <= lit.var().index() {
841                self.new_var();
842            }
843        }
844
845        // Initial propagation at level 0
846        if self.propagate().is_some() {
847            return (SolverResult::Unsat, Some(Vec::new()));
848        }
849
850        // Create a new decision level for assumptions
851        let assumption_level_start = self.trail.decision_level();
852
853        // Assign assumptions as decisions
854        for (i, &lit) in assumptions.iter().enumerate() {
855            // Check if already assigned
856            let value = self.trail.lit_value(lit);
857            if value.is_true() {
858                continue; // Already satisfied
859            }
860            if value.is_false() {
861                // Conflict with assumption - extract core from conflicting assumptions
862                let core = self.extract_assumption_core(assumptions, i);
863                self.backtrack(assumption_level_start);
864                return (SolverResult::Unsat, Some(core));
865            }
866
867            // Make decision for assumption
868            self.trail.new_decision_level();
869            self.trail.assign_decision(lit);
870
871            // Propagate after each assumption
872            if let Some(_conflict) = self.propagate() {
873                // Conflict during assumption propagation
874                let core = self.analyze_assumption_conflict(assumptions);
875                self.backtrack(assumption_level_start);
876                return (SolverResult::Unsat, Some(core));
877            }
878        }
879
880        // Now solve normally
881        loop {
882            if let Some(conflict) = self.propagate() {
883                self.stats.conflicts += 1;
884
885                // Check if conflict involves assumptions
886                let backtrack_level = self.analyze_conflict_level(conflict);
887
888                if backtrack_level <= assumption_level_start {
889                    // Conflict forces backtracking past assumptions - UNSAT
890                    let core = self.analyze_assumption_conflict(assumptions);
891                    self.backtrack(assumption_level_start);
892                    return (SolverResult::Unsat, Some(core));
893                }
894
895                let (bt_level, learnt_clause) = self.analyze(conflict);
896                self.backtrack_with_phase_saving(bt_level.max(assumption_level_start + 1));
897                self.learn_clause(learnt_clause);
898
899                self.vsids.decay();
900                self.clauses.decay_activity(self.config.clause_decay);
901                self.handle_clause_deletion_and_restart_limited(assumption_level_start);
902            } else {
903                // No conflict - try to decide
904                if let Some(var) = self.pick_branch_var() {
905                    self.stats.decisions += 1;
906                    self.trail.new_decision_level();
907
908                    let polarity = if self.rand_bool(self.config.random_polarity_prob) {
909                        self.rand_bool(0.5)
910                    } else {
911                        self.phase.get(var.index()).copied().unwrap_or(false)
912                    };
913                    let lit = if polarity {
914                        Lit::pos(var)
915                    } else {
916                        Lit::neg(var)
917                    };
918                    self.trail.assign_decision(lit);
919                } else {
920                    // All variables assigned - SAT
921                    self.save_model();
922                    self.backtrack(assumption_level_start);
923                    return (SolverResult::Sat, None);
924                }
925            }
926        }
927    }
928
929    /// Solve with theory integration via callbacks
930    ///
931    /// This implements the CDCL(T) loop:
932    /// 1. BCP (Boolean Constraint Propagation)
933    /// 2. Theory propagation (via callback)
934    /// 3. On conflict: analyze and learn
935    /// 4. Decision
936    /// 5. Final theory check when all vars assigned
937    pub fn solve_with_theory<T: TheoryCallback>(&mut self, theory: &mut T) -> SolverResult {
938        if self.trivially_unsat {
939            return SolverResult::Unsat;
940        }
941
942        // Initial propagation
943        if self.propagate().is_some() {
944            return SolverResult::Unsat;
945        }
946
947        // Track how many assignments have been sent to the theory.
948        // We only send NEW assignments (not previously processed ones) to avoid
949        // duplicate theory constraints that would cause spurious UNSAT.
950        let mut theory_processed: usize = 0;
951
952        loop {
953            // Boolean propagation
954            if let Some(conflict) = self.propagate() {
955                self.stats.conflicts += 1;
956
957                if self.trail.decision_level() == 0 {
958                    return SolverResult::Unsat;
959                }
960
961                let (backtrack_level, learnt_clause) = self.analyze(conflict);
962                theory.on_backtrack(backtrack_level);
963                self.backtrack_with_phase_saving(backtrack_level);
964                // After backtrack, the trail may be shorter; update processed count
965                theory_processed = theory_processed.min(self.trail.assignments().len());
966                self.learn_clause(learnt_clause);
967
968                self.vsids.decay();
969                self.clauses.decay_activity(self.config.clause_decay);
970                self.handle_clause_deletion_and_restart();
971                continue;
972            }
973
974            // Theory propagation check after each assignment
975            loop {
976                // Get only NEW (unprocessed) assignments and notify theory
977                let assignments = self.trail.assignments().to_vec();
978                let mut theory_conflict = None;
979                let mut theory_propagations = Vec::new();
980
981                // Check only NEW assignments with theory (skip already-processed ones).
982                // Guard against stale theory_processed after backtracks/restarts.
983                let safe_start = theory_processed.min(assignments.len());
984                for &lit in &assignments[safe_start..] {
985                    match theory.on_assignment(lit) {
986                        TheoryCheckResult::Sat => {}
987                        TheoryCheckResult::Conflict(conflict_lits) => {
988                            theory_conflict = Some(conflict_lits);
989                            break;
990                        }
991                        TheoryCheckResult::Propagated(props) => {
992                            theory_propagations.extend(props);
993                        }
994                    }
995                }
996                // Update processed count
997                theory_processed = assignments.len();
998
999                // Handle theory conflict
1000                if let Some(conflict_lits) = theory_conflict {
1001                    self.stats.conflicts += 1;
1002
1003                    if self.trail.decision_level() == 0 {
1004                        return SolverResult::Unsat;
1005                    }
1006
1007                    let (backtrack_level, learnt_clause) =
1008                        self.analyze_theory_conflict(&conflict_lits);
1009
1010                    // Empty learned clause signals all-level-0 conflict = fundamental UNSAT
1011                    if learnt_clause.is_empty() {
1012                        self.trivially_unsat = true;
1013                        return SolverResult::Unsat;
1014                    }
1015
1016                    theory.on_backtrack(backtrack_level);
1017                    self.backtrack_with_phase_saving(backtrack_level);
1018                    // After backtrack, update theory_processed to trail length
1019                    theory_processed = theory_processed.min(self.trail.assignments().len());
1020                    self.learn_clause(learnt_clause);
1021
1022                    self.vsids.decay();
1023                    self.clauses.decay_activity(self.config.clause_decay);
1024                    self.handle_clause_deletion_and_restart();
1025                    continue;
1026                }
1027
1028                // Handle theory propagations
1029                let mut made_propagation = false;
1030                for (lit, reason_lits) in theory_propagations {
1031                    if !self.trail.is_assigned(lit.var()) {
1032                        // Add reason clause and propagate
1033                        let clause_id = self.add_theory_reason_clause(&reason_lits, lit);
1034                        self.trail.assign_propagation(lit, clause_id);
1035                        made_propagation = true;
1036                    }
1037                }
1038
1039                if made_propagation {
1040                    // Re-run Boolean propagation
1041                    if let Some(conflict) = self.propagate() {
1042                        self.stats.conflicts += 1;
1043
1044                        if self.trail.decision_level() == 0 {
1045                            return SolverResult::Unsat;
1046                        }
1047
1048                        let (backtrack_level, learnt_clause) = self.analyze(conflict);
1049                        theory.on_backtrack(backtrack_level);
1050                        self.backtrack_with_phase_saving(backtrack_level);
1051                        // After backtrack, the trail is shorter; update processed count
1052                        theory_processed = theory_processed.min(self.trail.assignments().len());
1053                        self.learn_clause(learnt_clause);
1054
1055                        self.vsids.decay();
1056                        self.clauses.decay_activity(self.config.clause_decay);
1057                        self.handle_clause_deletion_and_restart();
1058                    }
1059                    continue;
1060                }
1061
1062                break;
1063            }
1064
1065            // Try to decide
1066            if let Some(var) = self.pick_branch_var() {
1067                self.stats.decisions += 1;
1068                self.trail.new_decision_level();
1069                let new_level = self.trail.decision_level();
1070                theory.on_new_level(new_level);
1071
1072                let polarity = if self.rand_bool(self.config.random_polarity_prob) {
1073                    self.rand_bool(0.5)
1074                } else {
1075                    self.phase[var.index()]
1076                };
1077                let lit = if polarity {
1078                    Lit::pos(var)
1079                } else {
1080                    Lit::neg(var)
1081                };
1082                self.trail.assign_decision(lit);
1083            } else {
1084                // All variables assigned - do final theory check
1085                match theory.final_check() {
1086                    TheoryCheckResult::Sat => {
1087                        self.save_model();
1088                        return SolverResult::Sat;
1089                    }
1090                    TheoryCheckResult::Conflict(conflict_lits) => {
1091                        self.stats.conflicts += 1;
1092
1093                        if self.trail.decision_level() == 0 {
1094                            return SolverResult::Unsat;
1095                        }
1096
1097                        let (backtrack_level, learnt_clause) =
1098                            self.analyze_theory_conflict(&conflict_lits);
1099
1100                        // If all conflict literals are at level 0, analyze_theory_conflict
1101                        // returns an empty learned clause as a signal of fundamental UNSAT.
1102                        if learnt_clause.is_empty() {
1103                            self.trivially_unsat = true;
1104                            return SolverResult::Unsat;
1105                        }
1106
1107                        theory.on_backtrack(backtrack_level);
1108                        self.backtrack_with_phase_saving(backtrack_level);
1109                        // After backtrack, update theory_processed
1110                        theory_processed = theory_processed.min(self.trail.assignments().len());
1111                        self.learn_clause(learnt_clause);
1112
1113                        self.vsids.decay();
1114                        self.clauses.decay_activity(self.config.clause_decay);
1115                        self.handle_clause_deletion_and_restart();
1116                    }
1117                    TheoryCheckResult::Propagated(props) => {
1118                        // Handle late propagations
1119                        for (lit, reason_lits) in props {
1120                            if !self.trail.is_assigned(lit.var()) {
1121                                let clause_id = self.add_theory_reason_clause(&reason_lits, lit);
1122                                self.trail.assign_propagation(lit, clause_id);
1123                            }
1124                        }
1125                    }
1126                }
1127            }
1128        }
1129    }
1130
1131    /// Get the model (if sat)
1132    #[must_use]
1133    pub fn model(&self) -> &[LBool] {
1134        &self.model
1135    }
1136
1137    /// Get the value of a variable in the model
1138    #[must_use]
1139    pub fn model_value(&self, var: Var) -> LBool {
1140        self.model.get(var.index()).copied().unwrap_or(LBool::Undef)
1141    }
1142
1143    /// Get statistics
1144    #[must_use]
1145    pub fn stats(&self) -> &SolverStats {
1146        &self.stats
1147    }
1148
1149    /// Get memory optimizer statistics
1150    #[must_use]
1151    pub fn memory_opt_stats(&self) -> &crate::memory_opt::MemoryOptStats {
1152        self.memory_optimizer.stats()
1153    }
1154
1155    /// Get number of variables
1156    #[must_use]
1157    pub fn num_vars(&self) -> usize {
1158        self.num_vars
1159    }
1160
1161    /// Get number of clauses
1162    #[must_use]
1163    pub fn num_clauses(&self) -> usize {
1164        self.clauses.len()
1165    }
1166
1167    /// Push a new assertion level (for incremental solving)
1168    ///
1169    /// This saves the current state so that clauses added after this point
1170    /// can be removed with pop(). Automatically backtracks to decision level 0
1171    /// to ensure a clean state for adding new constraints.
1172    pub fn push(&mut self) {
1173        // Backtrack to level 0 to ensure clean state
1174        // This is necessary because solve() may leave assignments on the trail
1175        // Use phase-saving backtrack to properly re-insert variables into decision heaps
1176        self.backtrack_with_phase_saving(0);
1177
1178        self.assertion_levels.push(self.clauses.num_original());
1179        self.assertion_trail_sizes.push(self.trail.size());
1180        self.assertion_clause_ids.push(Vec::new());
1181    }
1182
1183    /// Pop to previous assertion level
1184    pub fn pop(&mut self) {
1185        if self.assertion_levels.len() > 1 {
1186            self.assertion_levels.pop();
1187
1188            // Get the trail size to backtrack to
1189            let trail_size = self.assertion_trail_sizes.pop().unwrap_or(0);
1190
1191            // Remove all clauses added at this assertion level
1192            if let Some(clause_ids_to_remove) = self.assertion_clause_ids.pop() {
1193                for clause_id in clause_ids_to_remove {
1194                    // Remove from clause database
1195                    self.clauses.remove(clause_id);
1196
1197                    // Remove from learned clause tracking if it's a learned clause
1198                    self.learned_clause_ids.retain(|&id| id != clause_id);
1199
1200                    // Note: Watch lists will be cleaned up naturally during propagation
1201                    // as they check if clauses are deleted before using them
1202                }
1203            }
1204
1205            // Backtrack trail to the exact size it was at push()
1206            // This properly handles unit clauses that were added after push
1207            // Note: backtrack_to_size clears values but doesn't re-insert into heaps,
1208            // so we need to manually re-insert unassigned variables.
1209            let current_size = self.trail.size();
1210            if current_size > trail_size {
1211                // Collect variables that will be unassigned
1212                let mut unassigned_vars = Vec::new();
1213                for i in trail_size..current_size {
1214                    let lit = self.trail.assignments()[i];
1215                    unassigned_vars.push(lit.var());
1216                }
1217
1218                self.trail.backtrack_to_size(trail_size);
1219
1220                // Re-insert unassigned variables into decision heaps
1221                for var in unassigned_vars {
1222                    if !self.vsids.contains(var) {
1223                        self.vsids.insert(var);
1224                    }
1225                    if !self.chb.contains(var) {
1226                        self.chb.insert(var);
1227                    }
1228                    self.lrb.unassign(var);
1229                }
1230            }
1231
1232            // Ensure we're at decision level 0 with proper heap re-insertion
1233            self.backtrack_with_phase_saving(0);
1234
1235            // Clear the trivially_unsat flag as we've removed problematic clauses
1236            self.trivially_unsat = false;
1237        }
1238    }
1239
1240    /// Backtrack to decision level 0 (for AllSAT enumeration)
1241    ///
1242    /// This is necessary after a SAT result before adding blocking clauses
1243    /// to ensure the new clauses can trigger propagation correctly.
1244    /// Uses phase-saving backtrack to properly re-insert unassigned variables
1245    /// into the decision heaps (VSIDS, CHB, LRB).
1246    pub fn backtrack_to_root(&mut self) {
1247        self.backtrack_with_phase_saving(0);
1248    }
1249
1250    /// Reset the solver
1251    pub fn reset(&mut self) {
1252        self.clauses = ClauseDatabase::new();
1253        self.trail.clear();
1254        self.watches.clear();
1255        self.vsids.clear();
1256        self.chb.clear();
1257        self.stats = SolverStats::default();
1258        self.learnt.clear();
1259        self.seen.clear();
1260        self.analyze_stack.clear();
1261        self.assertion_levels.clear();
1262        self.assertion_levels.push(0);
1263        self.assertion_trail_sizes.clear();
1264        self.assertion_trail_sizes.push(0);
1265        self.assertion_clause_ids.clear();
1266        self.assertion_clause_ids.push(Vec::new());
1267        self.model.clear();
1268        self.num_vars = 0;
1269        self.restart_threshold = self.config.restart_interval;
1270        self.trivially_unsat = false;
1271        self.phase.clear();
1272        self.luby_index = 0;
1273        self.level_marks.clear();
1274        self.lbd_mark = 0;
1275        self.learned_clause_ids.clear();
1276        self.conflicts_since_deletion = 0;
1277        self.rng_state = 0x853c_49e6_748f_ea9b;
1278        self.recent_lbd_sum = 0;
1279        self.recent_lbd_count = 0;
1280        self.binary_graph.clear();
1281        self.global_lbd_sum = 0;
1282        self.global_lbd_count = 0;
1283        self.conflicts_since_local_restart = 0;
1284    }
1285
1286    /// Get the current trail (for theory solvers)
1287    #[must_use]
1288    pub fn trail(&self) -> &Trail {
1289        &self.trail
1290    }
1291
1292    /// Get the current decision level
1293    #[must_use]
1294    pub fn decision_level(&self) -> u32 {
1295        self.trail.decision_level()
1296    }
1297
1298    /// Debug method: print all learned clauses
1299    pub fn debug_print_learned_clauses(&self) {
1300        println!(
1301            "=== Learned Clauses ({}) ===",
1302            self.learned_clause_ids.len()
1303        );
1304        for (i, &cid) in self.learned_clause_ids.iter().enumerate() {
1305            if let Some(clause) = self.clauses.get(cid)
1306                && !clause.deleted
1307            {
1308                let lits: Vec<String> = clause
1309                    .lits
1310                    .iter()
1311                    .map(|lit| {
1312                        let var = lit.var().index();
1313                        if lit.is_pos() {
1314                            format!("v{}", var)
1315                        } else {
1316                            format!("~v{}", var)
1317                        }
1318                    })
1319                    .collect();
1320                println!(
1321                    "  Learned {}: ({}), LBD={}",
1322                    i,
1323                    lits.join(" | "),
1324                    clause.lbd
1325                );
1326            }
1327        }
1328    }
1329
1330    /// Debug method: print binary implication graph entries
1331    pub fn debug_print_binary_graph(&self) {
1332        println!("=== Binary Implication Graph ===");
1333        for lit_code in 0..(self.num_vars * 2) {
1334            let lit = Lit::from_code(lit_code as u32);
1335            let implications = self.binary_graph.get(lit);
1336            if !implications.is_empty() {
1337                let lit_str = if lit.is_pos() {
1338                    format!("v{}", lit.var().index())
1339                } else {
1340                    format!("~v{}", lit.var().index())
1341                };
1342                for &(implied, _cid) in implications {
1343                    let impl_str = if implied.is_pos() {
1344                        format!("v{}", implied.var().index())
1345                    } else {
1346                        format!("~v{}", implied.var().index())
1347                    };
1348                    println!("  {} -> {}", lit_str, impl_str);
1349                }
1350            }
1351        }
1352    }
1353}
1354
1355#[cfg(test)]
1356mod tests {
1357    use super::*;
1358
1359    #[test]
1360    fn test_empty_sat() {
1361        let mut solver = Solver::new();
1362        assert_eq!(solver.solve(), SolverResult::Sat);
1363    }
1364
1365    #[test]
1366    fn test_simple_sat() {
1367        let mut solver = Solver::new();
1368        let _x = solver.new_var();
1369        let _y = solver.new_var();
1370
1371        // x or y
1372        solver.add_clause_dimacs(&[1, 2]);
1373        // not x or y
1374        solver.add_clause_dimacs(&[-1, 2]);
1375
1376        assert_eq!(solver.solve(), SolverResult::Sat);
1377        assert!(solver.model_value(Var::new(1)).is_true()); // y must be true
1378    }
1379
1380    #[test]
1381    fn test_simple_unsat() {
1382        let mut solver = Solver::new();
1383        let _x = solver.new_var();
1384
1385        // x
1386        solver.add_clause_dimacs(&[1]);
1387        // not x
1388        solver.add_clause_dimacs(&[-1]);
1389
1390        assert_eq!(solver.solve(), SolverResult::Unsat);
1391    }
1392
1393    #[test]
1394    fn test_pigeonhole_2_1() {
1395        // 2 pigeons, 1 hole - UNSAT
1396        let mut solver = Solver::new();
1397        let _p1h1 = solver.new_var(); // pigeon 1 in hole 1
1398        let _p2h1 = solver.new_var(); // pigeon 2 in hole 1
1399
1400        // Each pigeon must be in some hole
1401        solver.add_clause_dimacs(&[1]); // p1 in h1
1402        solver.add_clause_dimacs(&[2]); // p2 in h1
1403
1404        // No hole can have two pigeons
1405        solver.add_clause_dimacs(&[-1, -2]); // not (p1h1 and p2h1)
1406
1407        assert_eq!(solver.solve(), SolverResult::Unsat);
1408    }
1409
1410    #[test]
1411    fn test_3sat_random() {
1412        let mut solver = Solver::new();
1413        for _ in 0..10 {
1414            solver.new_var();
1415        }
1416
1417        // Random 3-SAT instance (likely SAT)
1418        solver.add_clause_dimacs(&[1, 2, 3]);
1419        solver.add_clause_dimacs(&[-1, 4, 5]);
1420        solver.add_clause_dimacs(&[2, -3, 6]);
1421        solver.add_clause_dimacs(&[-4, 7, 8]);
1422        solver.add_clause_dimacs(&[5, -6, 9]);
1423        solver.add_clause_dimacs(&[-7, 8, 10]);
1424        solver.add_clause_dimacs(&[1, -8, -9]);
1425        solver.add_clause_dimacs(&[-2, 3, -10]);
1426
1427        let result = solver.solve();
1428        assert_eq!(result, SolverResult::Sat);
1429    }
1430
1431    #[test]
1432    fn test_luby_sequence() {
1433        // Luby sequence: 1, 1, 2, 1, 1, 2, 4, 1, 1, 2, 1, 1, 2, 4, 8, ...
1434        assert_eq!(Solver::luby(0), 1);
1435        assert_eq!(Solver::luby(1), 1);
1436        assert_eq!(Solver::luby(2), 2);
1437        assert_eq!(Solver::luby(3), 1);
1438        assert_eq!(Solver::luby(4), 1);
1439        assert_eq!(Solver::luby(5), 2);
1440        assert_eq!(Solver::luby(6), 4);
1441        assert_eq!(Solver::luby(7), 1);
1442    }
1443
1444    #[test]
1445    fn test_phase_saving() {
1446        let mut solver = Solver::new();
1447        for _ in 0..5 {
1448            solver.new_var();
1449        }
1450
1451        // Set up a problem where phase saving helps
1452        solver.add_clause_dimacs(&[1, 2]);
1453        solver.add_clause_dimacs(&[-1, 3]);
1454        solver.add_clause_dimacs(&[-2, 4]);
1455        solver.add_clause_dimacs(&[-3, -4, 5]);
1456        solver.add_clause_dimacs(&[-5, 1]);
1457
1458        let result = solver.solve();
1459        assert_eq!(result, SolverResult::Sat);
1460    }
1461
1462    #[test]
1463    fn test_lbd_computation() {
1464        // Test that clause deletion can handle a problem that generates learned clauses
1465        let mut solver = Solver::with_config(SolverConfig {
1466            clause_deletion_threshold: 5, // Trigger deletion quickly
1467            ..SolverConfig::default()
1468        });
1469
1470        for _ in 0..20 {
1471            solver.new_var();
1472        }
1473
1474        // A harder problem to generate more conflicts and learned clauses
1475        // PHP(3,2): 3 pigeons, 2 holes - UNSAT
1476        // Variables: p_i_h (pigeon i in hole h)
1477        // p11=1, p12=2, p21=3, p22=4, p31=5, p32=6
1478
1479        // Each pigeon must be in some hole
1480        solver.add_clause_dimacs(&[1, 2]); // p1 in h1 or h2
1481        solver.add_clause_dimacs(&[3, 4]); // p2 in h1 or h2
1482        solver.add_clause_dimacs(&[5, 6]); // p3 in h1 or h2
1483
1484        // No hole can have two pigeons
1485        solver.add_clause_dimacs(&[-1, -3]); // not (p1h1 and p2h1)
1486        solver.add_clause_dimacs(&[-1, -5]); // not (p1h1 and p3h1)
1487        solver.add_clause_dimacs(&[-3, -5]); // not (p2h1 and p3h1)
1488        solver.add_clause_dimacs(&[-2, -4]); // not (p1h2 and p2h2)
1489        solver.add_clause_dimacs(&[-2, -6]); // not (p1h2 and p3h2)
1490        solver.add_clause_dimacs(&[-4, -6]); // not (p2h2 and p3h2)
1491
1492        let result = solver.solve();
1493        assert_eq!(result, SolverResult::Unsat);
1494        // Verify we had some conflicts (and thus learned clauses)
1495        assert!(solver.stats().conflicts > 0);
1496    }
1497
1498    #[test]
1499    fn test_clause_activity_decay() {
1500        let mut solver = Solver::new();
1501        for _ in 0..10 {
1502            solver.new_var();
1503        }
1504
1505        // Add some clauses
1506        solver.add_clause_dimacs(&[1, 2, 3]);
1507        solver.add_clause_dimacs(&[-1, 4, 5]);
1508        solver.add_clause_dimacs(&[-2, -3, 6]);
1509
1510        // Solve (should be SAT)
1511        let result = solver.solve();
1512        assert_eq!(result, SolverResult::Sat);
1513    }
1514
1515    #[test]
1516    fn test_clause_minimization() {
1517        // Test that clause minimization works correctly on a problem
1518        // that will generate learned clauses
1519        let mut solver = Solver::new();
1520
1521        for _ in 0..15 {
1522            solver.new_var();
1523        }
1524
1525        // A problem structure that generates conflicts and learned clauses
1526        // Graph coloring with 3 colors on 5 vertices
1527        // Vertices: 1-5, Colors: R(0-4), G(5-9), B(10-14)
1528
1529        // Each vertex has at least one color
1530        solver.add_clause_dimacs(&[1, 6, 11]); // v1: R or G or B
1531        solver.add_clause_dimacs(&[2, 7, 12]); // v2
1532        solver.add_clause_dimacs(&[3, 8, 13]); // v3
1533        solver.add_clause_dimacs(&[4, 9, 14]); // v4
1534        solver.add_clause_dimacs(&[5, 10, 15]); // v5
1535
1536        // At most one color per vertex (pairwise exclusion)
1537        solver.add_clause_dimacs(&[-1, -6]); // v1: not (R and G)
1538        solver.add_clause_dimacs(&[-1, -11]); // v1: not (R and B)
1539        solver.add_clause_dimacs(&[-6, -11]); // v1: not (G and B)
1540
1541        solver.add_clause_dimacs(&[-2, -7]);
1542        solver.add_clause_dimacs(&[-2, -12]);
1543        solver.add_clause_dimacs(&[-7, -12]);
1544
1545        solver.add_clause_dimacs(&[-3, -8]);
1546        solver.add_clause_dimacs(&[-3, -13]);
1547        solver.add_clause_dimacs(&[-8, -13]);
1548
1549        // Adjacent vertices have different colors (edges: 1-2, 2-3, 3-4, 4-5)
1550        solver.add_clause_dimacs(&[-1, -2]); // edge 1-2: not both R
1551        solver.add_clause_dimacs(&[-6, -7]); // edge 1-2: not both G
1552        solver.add_clause_dimacs(&[-11, -12]); // edge 1-2: not both B
1553
1554        solver.add_clause_dimacs(&[-2, -3]); // edge 2-3
1555        solver.add_clause_dimacs(&[-7, -8]);
1556        solver.add_clause_dimacs(&[-12, -13]);
1557
1558        let result = solver.solve();
1559        assert_eq!(result, SolverResult::Sat);
1560
1561        // The solver may or may not have conflicts/learned clauses depending on
1562        // the decision heuristic. The key is that the result is correct.
1563        // If there are learned clauses, minimization would have been applied.
1564    }
1565
1566    /// A simple theory callback that does nothing (pure SAT)
1567    struct NullTheory;
1568
1569    impl TheoryCallback for NullTheory {
1570        fn on_assignment(&mut self, _lit: Lit) -> TheoryCheckResult {
1571            TheoryCheckResult::Sat
1572        }
1573
1574        fn final_check(&mut self) -> TheoryCheckResult {
1575            TheoryCheckResult::Sat
1576        }
1577
1578        fn on_backtrack(&mut self, _level: u32) {}
1579    }
1580
1581    #[test]
1582    fn test_solve_with_theory_sat() {
1583        let mut solver = Solver::new();
1584        let mut theory = NullTheory;
1585
1586        let _x = solver.new_var();
1587        let _y = solver.new_var();
1588
1589        // x or y
1590        solver.add_clause_dimacs(&[1, 2]);
1591        // not x or y
1592        solver.add_clause_dimacs(&[-1, 2]);
1593
1594        assert_eq!(solver.solve_with_theory(&mut theory), SolverResult::Sat);
1595        assert!(solver.model_value(Var::new(1)).is_true()); // y must be true
1596    }
1597
1598    #[test]
1599    fn test_solve_with_theory_unsat() {
1600        let mut solver = Solver::new();
1601        let mut theory = NullTheory;
1602
1603        let _x = solver.new_var();
1604
1605        // x
1606        solver.add_clause_dimacs(&[1]);
1607        // not x
1608        solver.add_clause_dimacs(&[-1]);
1609
1610        assert_eq!(solver.solve_with_theory(&mut theory), SolverResult::Unsat);
1611    }
1612
1613    /// A theory that forces x0 => x1 (if x0 is true, x1 must be true)
1614    struct ImplicationTheory {
1615        /// Track if x0 is assigned true
1616        x0_true: bool,
1617    }
1618
1619    impl ImplicationTheory {
1620        fn new() -> Self {
1621            Self { x0_true: false }
1622        }
1623    }
1624
1625    impl TheoryCallback for ImplicationTheory {
1626        fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult {
1627            // If x0 becomes true, propagate x1
1628            if lit.var().index() == 0 && lit.is_pos() {
1629                self.x0_true = true;
1630                // Propagate: x1 must be true because x0 is true
1631                // The reason is: ~x0 (if x0 were false, we wouldn't need x1)
1632                let reason: SmallVec<[Lit; 8]> = smallvec::smallvec![Lit::pos(Var::new(0))];
1633                return TheoryCheckResult::Propagated(vec![(Lit::pos(Var::new(1)), reason)]);
1634            }
1635            TheoryCheckResult::Sat
1636        }
1637
1638        fn final_check(&mut self) -> TheoryCheckResult {
1639            TheoryCheckResult::Sat
1640        }
1641
1642        fn on_backtrack(&mut self, _level: u32) {
1643            self.x0_true = false;
1644        }
1645    }
1646
1647    #[test]
1648    fn test_theory_propagation() {
1649        let mut solver = Solver::new();
1650        let mut theory = ImplicationTheory::new();
1651
1652        let _x0 = solver.new_var();
1653        let _x1 = solver.new_var();
1654
1655        // Force x0 to be true
1656        solver.add_clause_dimacs(&[1]);
1657
1658        let result = solver.solve_with_theory(&mut theory);
1659        assert_eq!(result, SolverResult::Sat);
1660
1661        // x0 should be true (forced by clause)
1662        assert!(solver.model_value(Var::new(0)).is_true());
1663        // x1 should also be true (propagated by theory)
1664        assert!(solver.model_value(Var::new(1)).is_true());
1665    }
1666
1667    /// Theory that says x0 and x1 can't both be true
1668    struct MutexTheory {
1669        x0_true: Option<Lit>,
1670        x1_true: Option<Lit>,
1671    }
1672
1673    impl MutexTheory {
1674        fn new() -> Self {
1675            Self {
1676                x0_true: None,
1677                x1_true: None,
1678            }
1679        }
1680    }
1681
1682    impl TheoryCallback for MutexTheory {
1683        fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult {
1684            if lit.var().index() == 0 && lit.is_pos() {
1685                self.x0_true = Some(lit);
1686            }
1687            if lit.var().index() == 1 && lit.is_pos() {
1688                self.x1_true = Some(lit);
1689            }
1690
1691            // If both are true, conflict
1692            if self.x0_true.is_some() && self.x1_true.is_some() {
1693                // Conflict clause: ~x0 or ~x1 (at least one must be false)
1694                let conflict: SmallVec<[Lit; 8]> = smallvec::smallvec![
1695                    Lit::pos(Var::new(0)), // x0 is true (we negate in conflict)
1696                    Lit::pos(Var::new(1))  // x1 is true
1697                ];
1698                return TheoryCheckResult::Conflict(conflict);
1699            }
1700            TheoryCheckResult::Sat
1701        }
1702
1703        fn final_check(&mut self) -> TheoryCheckResult {
1704            if self.x0_true.is_some() && self.x1_true.is_some() {
1705                let conflict: SmallVec<[Lit; 8]> =
1706                    smallvec::smallvec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))];
1707                return TheoryCheckResult::Conflict(conflict);
1708            }
1709            TheoryCheckResult::Sat
1710        }
1711
1712        fn on_backtrack(&mut self, _level: u32) {
1713            self.x0_true = None;
1714            self.x1_true = None;
1715        }
1716    }
1717
1718    #[test]
1719    fn test_theory_conflict() {
1720        let mut solver = Solver::new();
1721        let mut theory = MutexTheory::new();
1722
1723        let _x0 = solver.new_var();
1724        let _x1 = solver.new_var();
1725
1726        // Force both x0 and x1 to be true (should cause theory conflict)
1727        solver.add_clause_dimacs(&[1]);
1728        solver.add_clause_dimacs(&[2]);
1729
1730        let result = solver.solve_with_theory(&mut theory);
1731        assert_eq!(result, SolverResult::Unsat);
1732    }
1733
1734    #[test]
1735    fn test_solve_with_assumptions_sat() {
1736        let mut solver = Solver::new();
1737
1738        let x0 = solver.new_var();
1739        let x1 = solver.new_var();
1740
1741        // x0 \/ x1
1742        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1743
1744        // Assume x0 = true
1745        let assumptions = [Lit::pos(x0)];
1746        let (result, core) = solver.solve_with_assumptions(&assumptions);
1747
1748        assert_eq!(result, SolverResult::Sat);
1749        assert!(core.is_none());
1750    }
1751
1752    #[test]
1753    fn test_solve_with_assumptions_unsat() {
1754        let mut solver = Solver::new();
1755
1756        let x0 = solver.new_var();
1757        let x1 = solver.new_var();
1758
1759        // x0 -> ~x1 (encoded as ~x0 \/ ~x1)
1760        solver.add_clause([Lit::neg(x0), Lit::neg(x1)]);
1761
1762        // Assume both x0 = true and x1 = true (should be UNSAT)
1763        let assumptions = [Lit::pos(x0), Lit::pos(x1)];
1764        let (result, core) = solver.solve_with_assumptions(&assumptions);
1765
1766        assert_eq!(result, SolverResult::Unsat);
1767        assert!(core.is_some());
1768        let core = core.expect("UNSAT result must have conflict core");
1769        // Core should contain at least one of the conflicting assumptions
1770        assert!(!core.is_empty());
1771    }
1772
1773    #[test]
1774    fn test_solve_with_assumptions_core_extraction() {
1775        let mut solver = Solver::new();
1776
1777        let x0 = solver.new_var();
1778        let x1 = solver.new_var();
1779        let x2 = solver.new_var();
1780
1781        // ~x0 (x0 must be false)
1782        solver.add_clause([Lit::neg(x0)]);
1783
1784        // Assume x0 = true, x1 = true, x2 = true
1785        // Only x0 should be in the core
1786        let assumptions = [Lit::pos(x0), Lit::pos(x1), Lit::pos(x2)];
1787        let (result, core) = solver.solve_with_assumptions(&assumptions);
1788
1789        assert_eq!(result, SolverResult::Unsat);
1790        assert!(core.is_some());
1791        let core = core.expect("UNSAT result must have conflict core");
1792        // x0 should be in the core
1793        assert!(core.contains(&Lit::pos(x0)));
1794    }
1795
1796    #[test]
1797    fn test_solve_with_assumptions_incremental() {
1798        let mut solver = Solver::new();
1799
1800        let x0 = solver.new_var();
1801        let x1 = solver.new_var();
1802
1803        // x0 \/ x1
1804        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1805
1806        // First: assume ~x0 (should be SAT with x1 = true)
1807        let (result1, _) = solver.solve_with_assumptions(&[Lit::neg(x0)]);
1808        assert_eq!(result1, SolverResult::Sat);
1809
1810        // Second: assume ~x0 and ~x1 (should be UNSAT)
1811        let (result2, core2) = solver.solve_with_assumptions(&[Lit::neg(x0), Lit::neg(x1)]);
1812        assert_eq!(result2, SolverResult::Unsat);
1813        assert!(core2.is_some());
1814
1815        // Third: assume x0 (should be SAT again)
1816        let (result3, _) = solver.solve_with_assumptions(&[Lit::pos(x0)]);
1817        assert_eq!(result3, SolverResult::Sat);
1818    }
1819
1820    #[test]
1821    fn test_push_pop_simple() {
1822        let mut solver = Solver::new();
1823
1824        let x0 = solver.new_var();
1825
1826        // Should be SAT (x0 can be true or false)
1827        assert_eq!(solver.solve(), SolverResult::Sat);
1828
1829        // Push and add unit clause: x0
1830        solver.push();
1831        solver.add_clause([Lit::pos(x0)]);
1832        assert_eq!(solver.solve(), SolverResult::Sat);
1833        assert!(solver.model_value(x0).is_true());
1834
1835        // Pop - should be SAT again
1836        solver.pop();
1837        let result = solver.solve();
1838        assert_eq!(
1839            result,
1840            SolverResult::Sat,
1841            "After pop, expected SAT but got {:?}. trivially_unsat={}",
1842            result,
1843            solver.trivially_unsat
1844        );
1845    }
1846
1847    #[test]
1848    fn test_push_pop_incremental() {
1849        let mut solver = Solver::new();
1850
1851        let x0 = solver.new_var();
1852        let x1 = solver.new_var();
1853        let x2 = solver.new_var();
1854
1855        // Base level: x0 \/ x1
1856        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1857        assert_eq!(solver.solve(), SolverResult::Sat);
1858
1859        // Push and add: ~x0
1860        solver.push();
1861        solver.add_clause([Lit::neg(x0)]);
1862        assert_eq!(solver.solve(), SolverResult::Sat);
1863        // x1 must be true
1864        assert!(solver.model_value(x1).is_true());
1865
1866        // Push again and add: ~x1 (should be UNSAT)
1867        solver.push();
1868        solver.add_clause([Lit::neg(x1)]);
1869        assert_eq!(solver.solve(), SolverResult::Unsat);
1870
1871        // Pop back one level (remove ~x1, keep ~x0)
1872        solver.pop();
1873        assert_eq!(solver.solve(), SolverResult::Sat);
1874        assert!(solver.model_value(x1).is_true());
1875
1876        // Pop back to base level (remove ~x0)
1877        solver.pop();
1878        assert_eq!(solver.solve(), SolverResult::Sat);
1879        // Either x0 or x1 can be true now
1880
1881        // Push and add different clause: x0 /\ x2
1882        solver.push();
1883        solver.add_clause([Lit::pos(x0)]);
1884        solver.add_clause([Lit::pos(x2)]);
1885        assert_eq!(solver.solve(), SolverResult::Sat);
1886        assert!(solver.model_value(x0).is_true());
1887        assert!(solver.model_value(x2).is_true());
1888
1889        // Pop and verify clauses are removed
1890        solver.pop();
1891        assert_eq!(solver.solve(), SolverResult::Sat);
1892    }
1893
1894    #[test]
1895    fn test_push_pop_with_learned_clauses() {
1896        let mut solver = Solver::new();
1897
1898        let x0 = solver.new_var();
1899        let x1 = solver.new_var();
1900        let x2 = solver.new_var();
1901
1902        // Create a formula that will cause learning
1903        // (x0 \/ x1) /\ (~x0 \/ x2) /\ (~x1 \/ x2)
1904        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1905        solver.add_clause([Lit::neg(x0), Lit::pos(x2)]);
1906        solver.add_clause([Lit::neg(x1), Lit::pos(x2)]);
1907
1908        assert_eq!(solver.solve(), SolverResult::Sat);
1909
1910        // Push and add conflicting clause
1911        solver.push();
1912        solver.add_clause([Lit::neg(x2)]);
1913
1914        // This should be UNSAT and cause clause learning
1915        assert_eq!(solver.solve(), SolverResult::Unsat);
1916
1917        // Pop - learned clauses from this level should be removed
1918        solver.pop();
1919
1920        // Should be SAT again
1921        assert_eq!(solver.solve(), SolverResult::Sat);
1922    }
1923}