Skip to main content

oxiz_sat/solver/
mod.rs

1//! CDCL SAT Solver
2
3mod conflict;
4mod decide;
5pub mod heuristic;
6mod learn;
7mod propagate;
8
9pub use heuristic::{BoxedBranchingHeuristic, BranchingHeuristic};
10
11use crate::chb::CHB;
12use crate::chrono::ChronoBacktrack;
13use crate::clause::{ClauseDatabase, ClauseId};
14use crate::literal::{LBool, Lit, Var};
15use crate::lrb::LRB;
16use crate::memory_opt::{MemoryAction, MemoryOptimizer};
17#[allow(unused_imports)]
18use crate::prelude::*;
19use crate::trail::{Reason, Trail};
20use crate::vsids::VSIDS;
21use crate::watched::{WatchLists, Watcher};
22use smallvec::SmallVec;
23
24/// Binary implication graph for efficient binary clause propagation
25/// For each literal L, stores the list of literals that are implied when L is false
26/// (i.e., for binary clause (~L v M), when L is assigned false, M must be true)
27#[derive(Debug, Clone)]
28pub(super) struct BinaryImplicationGraph {
29    /// implications[lit] = list of (implied_lit, clause_id) pairs
30    implications: Vec<Vec<(Lit, ClauseId)>>,
31}
32
33impl BinaryImplicationGraph {
34    fn new(num_vars: usize) -> Self {
35        Self {
36            implications: vec![Vec::new(); num_vars * 2],
37        }
38    }
39
40    fn resize(&mut self, num_vars: usize) {
41        self.implications.resize(num_vars * 2, Vec::new());
42    }
43
44    fn add(&mut self, lit: Lit, implied: Lit, clause_id: ClauseId) {
45        self.implications[lit.code() as usize].push((implied, clause_id));
46    }
47
48    fn get(&self, lit: Lit) -> &[(Lit, ClauseId)] {
49        &self.implications[lit.code() as usize]
50    }
51
52    fn clear(&mut self) {
53        for implications in &mut self.implications {
54            implications.clear();
55        }
56    }
57}
58
59/// Result from a theory check
60#[derive(Debug, Clone)]
61pub enum TheoryCheckResult {
62    /// Theory is satisfied under current assignment
63    Sat,
64    /// Theory detected a conflict, returns conflict clause literals
65    Conflict(SmallVec<[Lit; 8]>),
66    /// Theory propagated new literals (lit, reason clause)
67    Propagated(Vec<(Lit, SmallVec<[Lit; 8]>)>),
68}
69
70/// Callback trait for theory solvers
71/// The CDCL(T) solver implements this to receive theory callbacks
72pub trait TheoryCallback {
73    /// Called when a literal is assigned
74    /// Returns a theory check result
75    fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult;
76
77    /// Called after propagation is complete to do a full theory check
78    fn final_check(&mut self) -> TheoryCheckResult;
79
80    /// Called when the decision level increases
81    fn on_new_level(&mut self, _level: u32) {}
82
83    /// Called when backtracking
84    fn on_backtrack(&mut self, level: u32);
85}
86
/// Outcome of a SAT solving call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SolverResult {
    /// A satisfying assignment was found.
    Sat,
    /// No satisfying assignment exists.
    Unsat,
    /// The solver gave up without an answer (e.g. timeout or resource limit).
    Unknown,
}
97
98/// Solver configuration
99#[derive(Clone)]
100pub struct SolverConfig {
101    /// Restart interval (number of conflicts)
102    pub restart_interval: u64,
103    /// Restart multiplier for geometric restarts
104    pub restart_multiplier: f64,
105    /// Clause deletion threshold
106    pub clause_deletion_threshold: usize,
107    /// Variable decay factor
108    pub var_decay: f64,
109    /// Clause decay factor
110    pub clause_decay: f64,
111    /// Random polarity probability (0.0 to 1.0)
112    pub random_polarity_prob: f64,
113    /// Restart strategy: "luby" or "geometric"
114    pub restart_strategy: RestartStrategy,
115    /// Enable lazy hyper-binary resolution
116    pub enable_lazy_hyper_binary: bool,
117    /// Use CHB instead of VSIDS for branching
118    pub use_chb_branching: bool,
119    /// Use LRB (Learning Rate Branching) for branching
120    pub use_lrb_branching: bool,
121    /// Enable inprocessing (periodic preprocessing during search)
122    pub enable_inprocessing: bool,
123    /// Inprocessing interval (number of conflicts between inprocessing)
124    pub inprocessing_interval: u64,
125    /// Enable chronological backtracking
126    pub enable_chronological_backtrack: bool,
127    /// Chronological backtracking threshold (max distance from assertion level)
128    pub chrono_backtrack_threshold: u32,
129    /// Optional external branching heuristic. When `Some`, called before built-in
130    /// VSIDS/LRB/CHB; returning `None` from the heuristic falls back to built-in.
131    /// Default: `None` (pure built-in strategy).
132    pub external_branching: Option<BoxedBranchingHeuristic>,
133}
134
135impl core::fmt::Debug for SolverConfig {
136    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137        f.debug_struct("SolverConfig")
138            .field("restart_interval", &self.restart_interval)
139            .field("restart_multiplier", &self.restart_multiplier)
140            .field("clause_deletion_threshold", &self.clause_deletion_threshold)
141            .field("var_decay", &self.var_decay)
142            .field("clause_decay", &self.clause_decay)
143            .field("random_polarity_prob", &self.random_polarity_prob)
144            .field("restart_strategy", &self.restart_strategy)
145            .field("enable_lazy_hyper_binary", &self.enable_lazy_hyper_binary)
146            .field("use_chb_branching", &self.use_chb_branching)
147            .field("use_lrb_branching", &self.use_lrb_branching)
148            .field("enable_inprocessing", &self.enable_inprocessing)
149            .field("inprocessing_interval", &self.inprocessing_interval)
150            .field(
151                "enable_chronological_backtrack",
152                &self.enable_chronological_backtrack,
153            )
154            .field(
155                "chrono_backtrack_threshold",
156                &self.chrono_backtrack_threshold,
157            )
158            .field(
159                "external_branching",
160                &self
161                    .external_branching
162                    .as_ref()
163                    .map(|_| "<BranchingHeuristic>"),
164            )
165            .finish()
166    }
167}
168
/// Policy that decides when the solver restarts its search.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RestartStrategy {
    /// Restart intervals follow the Luby sequence.
    Luby,
    /// Restart intervals grow geometrically.
    Geometric,
    /// Glucose-style dynamic restarts driven by the LBD of learned clauses.
    Glucose,
    /// Local restarts driven by the LBD trail.
    LocalLbd,
}
181
182impl Default for SolverConfig {
183    fn default() -> Self {
184        Self {
185            restart_interval: 100,
186            restart_multiplier: 1.5,
187            clause_deletion_threshold: 10000,
188            var_decay: 0.95,
189            clause_decay: 0.999,
190            random_polarity_prob: 0.02,
191            restart_strategy: RestartStrategy::Luby,
192            enable_lazy_hyper_binary: true,
193            use_chb_branching: false,
194            use_lrb_branching: false,
195            enable_inprocessing: false,
196            inprocessing_interval: 5000,
197            enable_chronological_backtrack: true,
198            chrono_backtrack_threshold: 100,
199            external_branching: None,
200        }
201    }
202}
203
/// Counters collected while the solver runs.
#[derive(Clone, Debug, Default)]
pub struct SolverStats {
    /// Decisions taken.
    pub decisions: u64,
    /// Literals assigned by propagation.
    pub propagations: u64,
    /// Conflicts encountered.
    pub conflicts: u64,
    /// Restarts performed.
    pub restarts: u64,
    /// Clauses learned (including unit clauses).
    pub learned_clauses: u64,
    /// Clauses removed from the database.
    pub deleted_clauses: u64,
    /// Learned clauses of length two.
    pub binary_clauses: u64,
    /// Learned clauses of length one.
    pub unit_clauses: u64,
    /// Sum of the LBD values of learned clauses.
    pub total_lbd: u64,
    /// Learned-clause minimizations performed.
    pub minimizations: u64,
    /// Literals dropped by minimization.
    pub literals_removed: u64,
    /// Chronological backtracks taken.
    pub chrono_backtracks: u64,
    /// Non-chronological backtracks taken.
    pub non_chrono_backtracks: u64,
}
234
235impl SolverStats {
236    /// Get average LBD of learned clauses
237    #[must_use]
238    pub fn avg_lbd(&self) -> f64 {
239        if self.learned_clauses == 0 {
240            0.0
241        } else {
242            self.total_lbd as f64 / self.learned_clauses as f64
243        }
244    }
245
246    /// Get average decisions per conflict
247    #[must_use]
248    pub fn avg_decisions_per_conflict(&self) -> f64 {
249        if self.conflicts == 0 {
250            0.0
251        } else {
252            self.decisions as f64 / self.conflicts as f64
253        }
254    }
255
256    /// Get propagations per conflict
257    #[must_use]
258    pub fn propagations_per_conflict(&self) -> f64 {
259        if self.conflicts == 0 {
260            0.0
261        } else {
262            self.propagations as f64 / self.conflicts as f64
263        }
264    }
265
266    /// Get clause deletion ratio
267    #[must_use]
268    pub fn deletion_ratio(&self) -> f64 {
269        if self.learned_clauses == 0 {
270            0.0
271        } else {
272            self.deleted_clauses as f64 / self.learned_clauses as f64
273        }
274    }
275
276    /// Get chronological backtrack ratio
277    #[must_use]
278    pub fn chrono_backtrack_ratio(&self) -> f64 {
279        let total = self.chrono_backtracks + self.non_chrono_backtracks;
280        if total == 0 {
281            0.0
282        } else {
283            self.chrono_backtracks as f64 / total as f64
284        }
285    }
286
287    /// Display formatted statistics
288    pub fn display(&self) {
289        println!("========== Solver Statistics ==========");
290        println!("Decisions:              {:>12}", self.decisions);
291        println!("Propagations:           {:>12}", self.propagations);
292        println!("Conflicts:              {:>12}", self.conflicts);
293        println!("Restarts:               {:>12}", self.restarts);
294        println!("Learned clauses:        {:>12}", self.learned_clauses);
295        println!("  - Unit clauses:       {:>12}", self.unit_clauses);
296        println!("  - Binary clauses:     {:>12}", self.binary_clauses);
297        println!("Deleted clauses:        {:>12}", self.deleted_clauses);
298        println!("Minimizations:          {:>12}", self.minimizations);
299        println!("Literals removed:       {:>12}", self.literals_removed);
300        println!("Chrono backtracks:      {:>12}", self.chrono_backtracks);
301        println!("Non-chrono backtracks:  {:>12}", self.non_chrono_backtracks);
302        println!("---------------------------------------");
303        println!("Avg LBD:                {:>12.2}", self.avg_lbd());
304        println!(
305            "Avg decisions/conflict: {:>12.2}",
306            self.avg_decisions_per_conflict()
307        );
308        println!(
309            "Propagations/conflict:  {:>12.2}",
310            self.propagations_per_conflict()
311        );
312        println!(
313            "Deletion ratio:         {:>12.2}%",
314            self.deletion_ratio() * 100.0
315        );
316        println!(
317            "Chrono backtrack ratio: {:>12.2}%",
318            self.chrono_backtrack_ratio() * 100.0
319        );
320        println!("=======================================");
321    }
322}
323
324/// CDCL SAT Solver
325#[derive(Debug)]
326pub struct Solver {
327    /// Configuration
328    pub(super) config: SolverConfig,
329    /// Number of variables
330    pub(super) num_vars: usize,
331    /// Clause database
332    pub(super) clauses: ClauseDatabase,
333    /// Assignment trail
334    pub(super) trail: Trail,
335    /// Watch lists
336    pub(super) watches: WatchLists,
337    /// VSIDS branching heuristic
338    pub(super) vsids: VSIDS,
339    /// CHB branching heuristic
340    pub(super) chb: CHB,
341    /// LRB branching heuristic
342    pub(super) lrb: LRB,
343    /// Statistics
344    pub(super) stats: SolverStats,
345    /// Learnt clause for conflict analysis
346    pub(super) learnt: SmallVec<[Lit; 16]>,
347    /// Seen flags for conflict analysis
348    pub(super) seen: Vec<bool>,
349    /// Analyze stack
350    pub(super) analyze_stack: Vec<Lit>,
351    /// Current restart threshold
352    pub(super) restart_threshold: u64,
353    /// Assertions stack for incremental solving (number of original clauses)
354    pub(super) assertion_levels: Vec<usize>,
355    /// Trail sizes at each assertion level (for proper pop backtracking)
356    pub(super) assertion_trail_sizes: Vec<usize>,
357    /// Clause IDs added at each assertion level (for proper pop)
358    pub(super) assertion_clause_ids: Vec<Vec<ClauseId>>,
359    /// Model (if sat)
360    pub(super) model: Vec<LBool>,
361    /// Whether formula is trivially unsatisfiable
362    pub(super) trivially_unsat: bool,
363    /// Phase saving: last polarity assigned to each variable
364    pub(super) phase: Vec<bool>,
365    /// Luby sequence index for restarts
366    pub(super) luby_index: u64,
367    /// Level marks for LBD computation
368    pub(super) level_marks: Vec<u32>,
369    /// Current mark counter for LBD computation
370    pub(super) lbd_mark: u32,
371    /// Learned clause IDs for deletion
372    pub(super) learned_clause_ids: Vec<ClauseId>,
373    /// Number of conflicts since last clause deletion
374    pub(super) conflicts_since_deletion: u64,
375    /// PRNG state (xorshift64)
376    pub(super) rng_state: u64,
377    /// For Glucose-style restarts: average LBD of recent conflicts
378    pub(super) recent_lbd_sum: u64,
379    /// Number of conflicts contributing to recent_lbd_sum
380    pub(super) recent_lbd_count: u64,
381    /// Binary implication graph for fast binary clause propagation
382    pub(super) binary_graph: BinaryImplicationGraph,
383    /// Global average LBD for local restarts
384    pub(super) global_lbd_sum: u64,
385    /// Number of conflicts contributing to global LBD
386    pub(super) global_lbd_count: u64,
387    /// Conflicts since last local restart
388    pub(super) conflicts_since_local_restart: u64,
389    /// Conflicts since last inprocessing
390    pub(super) conflicts_since_inprocessing: u64,
391    /// Chronological backtracking helper
392    pub(super) chrono_backtrack: ChronoBacktrack,
393    /// Clause activity bump increment (for MapleSAT-style clause bumping)
394    pub(super) clause_bump_increment: f64,
395    /// Memory optimizer with size-class pools for clause allocation
396    pub(super) memory_optimizer: MemoryOptimizer,
397}
398
399impl Default for Solver {
400    fn default() -> Self {
401        Self::new()
402    }
403}
404
405impl Solver {
406    /// Create a new solver
407    #[must_use]
408    pub fn new() -> Self {
409        Self::with_config(SolverConfig::default())
410    }
411
412    /// Create a new solver with configuration
413    #[must_use]
414    pub fn with_config(config: SolverConfig) -> Self {
415        let chrono_enabled = config.enable_chronological_backtrack;
416        let chrono_threshold = config.chrono_backtrack_threshold;
417
418        Self {
419            restart_threshold: config.restart_interval,
420            config,
421            num_vars: 0,
422            clauses: ClauseDatabase::new(),
423            trail: Trail::new(0),
424            watches: WatchLists::new(0),
425            vsids: VSIDS::new(0),
426            chb: CHB::new(0),
427            lrb: LRB::new(0),
428            stats: SolverStats::default(),
429            learnt: SmallVec::new(),
430            seen: Vec::new(),
431            analyze_stack: Vec::new(),
432            assertion_levels: vec![0],
433            assertion_trail_sizes: vec![0],
434            assertion_clause_ids: vec![Vec::new()],
435            model: Vec::new(),
436            trivially_unsat: false,
437            phase: Vec::new(),
438            luby_index: 0,
439            level_marks: Vec::new(),
440            lbd_mark: 0,
441            learned_clause_ids: Vec::new(),
442            conflicts_since_deletion: 0,
443            rng_state: 0x853c_49e6_748f_ea9b, // Random seed
444            recent_lbd_sum: 0,
445            recent_lbd_count: 0,
446            binary_graph: BinaryImplicationGraph::new(0),
447            global_lbd_sum: 0,
448            global_lbd_count: 0,
449            conflicts_since_local_restart: 0,
450            conflicts_since_inprocessing: 0,
451            chrono_backtrack: ChronoBacktrack::new(chrono_enabled, chrono_threshold),
452            clause_bump_increment: 1.0,
453            memory_optimizer: MemoryOptimizer::new(),
454        }
455    }
456
457    /// Create a new variable
458    pub fn new_var(&mut self) -> Var {
459        let var = Var::new(self.num_vars as u32);
460        self.num_vars += 1;
461        self.trail.resize(self.num_vars);
462        self.watches.resize(self.num_vars);
463        self.binary_graph.resize(self.num_vars);
464        self.vsids.insert(var);
465        self.chb.insert(var);
466        self.lrb.resize(self.num_vars);
467        self.seen.resize(self.num_vars, false);
468        self.model.resize(self.num_vars, LBool::Undef);
469        self.phase.resize(self.num_vars, false); // Default phase: negative
470        // Resize level_marks to at least num_vars (enough for decision levels)
471        if self.level_marks.len() < self.num_vars {
472            self.level_marks.resize(self.num_vars, 0);
473        }
474        var
475    }
476
477    /// Ensure we have at least n variables
478    pub fn ensure_vars(&mut self, n: usize) {
479        while self.num_vars < n {
480            self.new_var();
481        }
482    }
483
484    /// Add a clause
485    pub fn add_clause(&mut self, lits: impl IntoIterator<Item = Lit>) -> bool {
486        let mut clause_lits: SmallVec<[Lit; 8]> = lits.into_iter().collect();
487
488        // Ensure we have all variables
489        for lit in &clause_lits {
490            let var_idx = lit.var().index();
491            if var_idx >= self.num_vars {
492                self.ensure_vars(var_idx + 1);
493            }
494        }
495
496        // Remove duplicates and check for tautology
497        clause_lits.sort_by_key(|l| l.code());
498        clause_lits.dedup();
499
500        // Check for tautology (x and ~x in same clause)
501        for i in 0..clause_lits.len() {
502            for j in (i + 1)..clause_lits.len() {
503                if clause_lits[i] == clause_lits[j].negate() {
504                    return true; // Tautology - always satisfied
505                }
506            }
507        }
508
509        // Handle special cases
510        match clause_lits.len() {
511            0 => {
512                self.trivially_unsat = true;
513                return false; // Empty clause - unsat
514            }
515            1 => {
516                // Unit clause - enqueue at decision level 0
517                // Unit clauses must be assigned at level 0 to survive backtracking.
518                // After solve(), current_level may be > 0, so we must backtrack first.
519                let lit = clause_lits[0];
520
521                if self.trail.lit_value(lit).is_false() {
522                    // The literal conflicts with the current trail.
523                    // Check if the conflict is at decision level 0 (permanent constraint)
524                    // or from a previous solve (can be retried after backtrack).
525                    let var = lit.var();
526                    let level = self.trail.level(var);
527                    if level == 0 {
528                        // Conflict with a level-0 assignment - truly UNSAT
529                        self.trivially_unsat = true;
530                        return false;
531                    } else {
532                        // Conflict with higher-level assignment from previous solve.
533                        // Backtrack to root and assign the new unit literal at level 0.
534                        self.backtrack_to_root();
535                        self.trail.assign_decision(lit);
536                        return true;
537                    }
538                }
539
540                if self.trail.lit_value(lit).is_true() {
541                    // Already satisfied - check if at level 0
542                    let var = lit.var();
543                    let level = self.trail.level(var);
544                    if level == 0 {
545                        // Already assigned at level 0, nothing to do
546                        return true;
547                    }
548                    // Assigned at higher level - backtrack and reassign at level 0
549                    self.backtrack_to_root();
550                    self.trail.assign_decision(lit);
551                    return true;
552                }
553
554                // Variable is unassigned - backtrack to level 0 first to ensure
555                // the assignment is at level 0 (survives future backtracks)
556                if self.trail.decision_level() > 0 {
557                    self.backtrack_to_root();
558                }
559                self.trail.assign_decision(lit);
560                return true;
561            }
562            2 => {
563                // Binary clause - check if it conflicts with current assignment
564                let lit0 = clause_lits[0];
565                let lit1 = clause_lits[1];
566                let val0 = self.trail.lit_value(lit0);
567                let val1 = self.trail.lit_value(lit1);
568
569                // If clause is satisfied, just add it
570                if val0.is_true() || val1.is_true() {
571                    // Clause already satisfied by current assignment
572                    let clause_id = self.clauses.add_original(clause_lits.iter().copied());
573                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
574                        current_level_clauses.push(clause_id);
575                    }
576                    self.binary_graph.add(lit0.negate(), lit1, clause_id);
577                    self.binary_graph.add(lit1.negate(), lit0, clause_id);
578                    self.watches
579                        .add(lit0.negate(), Watcher::new(clause_id, lit1));
580                    self.watches
581                        .add(lit1.negate(), Watcher::new(clause_id, lit0));
582                    return true;
583                }
584
585                // If both literals are false, we have a conflict
586                if val0.is_false() && val1.is_false() {
587                    // Check if both are at level 0
588                    let level0 = self.trail.level(lit0.var());
589                    let level1 = self.trail.level(lit1.var());
590
591                    if level0 == 0 && level1 == 0 {
592                        // Conflict at level 0 - UNSAT
593                        self.trivially_unsat = true;
594                        return false;
595                    }
596
597                    // Backtrack to level 0 and add clause
598                    // The clause will be propagated on next solve()
599                    self.backtrack_to_root();
600                }
601
602                // If one literal is false and one undefined, propagate
603                // after adding the clause (via next solve())
604
605                let clause_id = self.clauses.add_original(clause_lits.iter().copied());
606                if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
607                    current_level_clauses.push(clause_id);
608                }
609                self.binary_graph.add(lit0.negate(), lit1, clause_id);
610                self.binary_graph.add(lit1.negate(), lit0, clause_id);
611                self.watches
612                    .add(lit0.negate(), Watcher::new(clause_id, lit1));
613                self.watches
614                    .add(lit1.negate(), Watcher::new(clause_id, lit0));
615                return true;
616            }
617            _ => {}
618        }
619
620        // Add clause (3+ literals)
621        // Check if clause is satisfied or conflicts with current assignment
622        let num_false = clause_lits
623            .iter()
624            .filter(|&l| self.trail.lit_value(*l).is_false())
625            .count();
626        let has_true = clause_lits
627            .iter()
628            .any(|l| self.trail.lit_value(*l).is_true());
629
630        if !has_true && num_false == clause_lits.len() {
631            // All literals are false - conflict
632            // Check if all at level 0
633            let all_at_zero = clause_lits.iter().all(|l| self.trail.level(l.var()) == 0);
634            if all_at_zero {
635                self.trivially_unsat = true;
636                return false;
637            }
638            // Backtrack to level 0
639            self.backtrack_to_root();
640        }
641
642        let clause_id = self.clauses.add_original(clause_lits.iter().copied());
643
644        // Track clause for incremental solving
645        if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
646            current_level_clauses.push(clause_id);
647        }
648
649        // Set up watches - prefer non-false literals for watching
650        let lit0 = clause_lits[0];
651        let lit1 = clause_lits[1];
652
653        self.watches
654            .add(lit0.negate(), Watcher::new(clause_id, lit1));
655        self.watches
656            .add(lit1.negate(), Watcher::new(clause_id, lit0));
657
658        true
659    }
660
661    /// Add a clause from DIMACS literals
662    pub fn add_clause_dimacs(&mut self, lits: &[i32]) -> bool {
663        self.add_clause(lits.iter().map(|&l| Lit::from_dimacs(l)))
664    }
665
666    /// Solve the SAT problem
667    pub fn solve(&mut self) -> SolverResult {
668        // Check if trivially unsatisfiable
669        if self.trivially_unsat {
670            return SolverResult::Unsat;
671        }
672
673        // Initial propagation
674        if self.propagate().is_some() {
675            return SolverResult::Unsat;
676        }
677
678        loop {
679            // Propagate
680            if let Some(conflict) = self.propagate() {
681                self.stats.conflicts += 1;
682                self.conflicts_since_inprocessing += 1;
683
684                if self.trail.decision_level() == 0 {
685                    return SolverResult::Unsat;
686                }
687
688                // Analyze conflict
689                let (backtrack_level, learnt_clause) = self.analyze(conflict);
690
691                // Backtrack with phase saving
692                self.backtrack_with_phase_saving(backtrack_level);
693
694                // Learn clause
695                if learnt_clause.len() == 1 {
696                    // Store unit learned clause in database for persistence
697                    let clause_id = self.clauses.add_learned(learnt_clause.iter().copied());
698                    self.stats.learned_clauses += 1;
699                    self.stats.unit_clauses += 1;
700                    self.learned_clause_ids.push(clause_id);
701
702                    // Track for incremental solving
703                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
704                        current_level_clauses.push(clause_id);
705                    }
706
707                    self.trail.assign_decision(learnt_clause[0]);
708                } else {
709                    // Compute LBD for the learned clause
710                    let lbd = self.compute_lbd(&learnt_clause);
711
712                    // Track recent LBD for Glucose-style and local restarts
713                    self.recent_lbd_sum += u64::from(lbd);
714                    self.recent_lbd_count += 1;
715                    self.global_lbd_sum += u64::from(lbd);
716                    self.global_lbd_count += 1;
717
718                    // Reset recent LBD tracking periodically
719                    if self.recent_lbd_count >= 5000 {
720                        self.recent_lbd_sum /= 2;
721                        self.recent_lbd_count /= 2;
722                    }
723
724                    let clause_id = self.clauses.add_learned(learnt_clause.iter().copied());
725                    self.stats.learned_clauses += 1;
726
727                    // Set LBD score for the clause
728                    if let Some(clause) = self.clauses.get_mut(clause_id) {
729                        clause.lbd = lbd;
730                    }
731
732                    // Track learned clause for potential deletion
733                    self.learned_clause_ids.push(clause_id);
734
735                    // Track clause for incremental solving
736                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
737                        current_level_clauses.push(clause_id);
738                    }
739
740                    // Watch first two literals
741                    let lit0 = learnt_clause[0];
742                    let lit1 = learnt_clause[1];
743                    self.watches
744                        .add(lit0.negate(), Watcher::new(clause_id, lit1));
745                    self.watches
746                        .add(lit1.negate(), Watcher::new(clause_id, lit0));
747
748                    // Propagate the asserting literal
749                    self.trail.assign_propagation(learnt_clause[0], clause_id);
750                }
751
752                // Decay activities
753                self.vsids.decay();
754                self.chb.decay();
755                self.lrb.decay();
756                self.lrb.on_conflict();
757                self.clauses.decay_activity(self.config.clause_decay);
758                // Increase clause bump increment (inverse of decay)
759                self.clause_bump_increment /= self.config.clause_decay;
760
761                // Track conflicts for clause deletion
762                self.conflicts_since_deletion += 1;
763
764                // Periodic clause database reduction
765                if self.conflicts_since_deletion >= self.config.clause_deletion_threshold as u64 {
766                    self.reduce_clause_database();
767                    self.conflicts_since_deletion = 0;
768
769                    // Vivification after clause database reduction (at level 0 after restart)
770                    if self.stats.restarts.is_multiple_of(10) {
771                        let saved_level = self.trail.decision_level();
772                        if saved_level == 0 {
773                            self.vivify_clauses();
774                        }
775                    }
776                }
777
778                // Check for restart
779                if self.stats.conflicts >= self.restart_threshold {
780                    self.restart();
781                }
782
783                // Periodic inprocessing
784                if self.config.enable_inprocessing
785                    && self.conflicts_since_inprocessing >= self.config.inprocessing_interval
786                {
787                    self.inprocess();
788                    self.conflicts_since_inprocessing = 0;
789                }
790            } else {
791                // No conflict - try to decide
792                if let Some(var) = self.pick_branch_var() {
793                    self.stats.decisions += 1;
794                    self.trail.new_decision_level();
795
796                    // Use phase saving with random polarity
797                    let polarity = if self.rand_bool(self.config.random_polarity_prob) {
798                        // Random polarity
799                        self.rand_bool(0.5)
800                    } else {
801                        // Saved phase
802                        self.phase[var.index()]
803                    };
804                    let lit = if polarity {
805                        Lit::pos(var)
806                    } else {
807                        Lit::neg(var)
808                    };
809                    self.trail.assign_decision(lit);
810                } else {
811                    // All variables assigned - SAT
812                    self.save_model();
813                    return SolverResult::Sat;
814                }
815            }
816        }
817    }
818
819    /// Solve with assumptions and return unsat core if UNSAT
820    ///
821    /// This is the key method for MaxSAT: it solves under assumptions and
822    /// if the result is UNSAT, returns the subset of assumptions in the core.
823    ///
824    /// # Arguments
825    /// * `assumptions` - Literals that must be true
826    ///
827    /// # Returns
828    /// * `(SolverResult, Option<Vec<Lit>>)` - Result and unsat core (if UNSAT)
829    pub fn solve_with_assumptions(
830        &mut self,
831        assumptions: &[Lit],
832    ) -> (SolverResult, Option<Vec<Lit>>) {
833        if self.trivially_unsat {
834            return (SolverResult::Unsat, Some(Vec::new()));
835        }
836
837        // Ensure all assumption variables exist
838        for &lit in assumptions {
839            while self.num_vars <= lit.var().index() {
840                self.new_var();
841            }
842        }
843
844        // Initial propagation at level 0
845        if self.propagate().is_some() {
846            return (SolverResult::Unsat, Some(Vec::new()));
847        }
848
849        // Create a new decision level for assumptions
850        let assumption_level_start = self.trail.decision_level();
851
852        // Assign assumptions as decisions
853        for (i, &lit) in assumptions.iter().enumerate() {
854            // Check if already assigned
855            let value = self.trail.lit_value(lit);
856            if value.is_true() {
857                continue; // Already satisfied
858            }
859            if value.is_false() {
860                // Conflict with assumption - extract core from conflicting assumptions
861                let core = self.extract_assumption_core(assumptions, i);
862                self.backtrack(assumption_level_start);
863                return (SolverResult::Unsat, Some(core));
864            }
865
866            // Make decision for assumption
867            self.trail.new_decision_level();
868            self.trail.assign_decision(lit);
869
870            // Propagate after each assumption
871            if let Some(_conflict) = self.propagate() {
872                // Conflict during assumption propagation
873                let core = self.analyze_assumption_conflict(assumptions);
874                self.backtrack(assumption_level_start);
875                return (SolverResult::Unsat, Some(core));
876            }
877        }
878
879        // Now solve normally
880        loop {
881            if let Some(conflict) = self.propagate() {
882                self.stats.conflicts += 1;
883
884                // Check if conflict involves assumptions
885                let backtrack_level = self.analyze_conflict_level(conflict);
886
887                if backtrack_level <= assumption_level_start {
888                    // Conflict forces backtracking past assumptions - UNSAT
889                    let core = self.analyze_assumption_conflict(assumptions);
890                    self.backtrack(assumption_level_start);
891                    return (SolverResult::Unsat, Some(core));
892                }
893
894                let (bt_level, learnt_clause) = self.analyze(conflict);
895                self.backtrack_with_phase_saving(bt_level.max(assumption_level_start + 1));
896                self.learn_clause(learnt_clause);
897
898                self.vsids.decay();
899                self.clauses.decay_activity(self.config.clause_decay);
900                self.handle_clause_deletion_and_restart_limited(assumption_level_start);
901            } else {
902                // No conflict - try to decide
903                if let Some(var) = self.pick_branch_var() {
904                    self.stats.decisions += 1;
905                    self.trail.new_decision_level();
906
907                    let polarity = if self.rand_bool(self.config.random_polarity_prob) {
908                        self.rand_bool(0.5)
909                    } else {
910                        self.phase.get(var.index()).copied().unwrap_or(false)
911                    };
912                    let lit = if polarity {
913                        Lit::pos(var)
914                    } else {
915                        Lit::neg(var)
916                    };
917                    self.trail.assign_decision(lit);
918                } else {
919                    // All variables assigned - SAT
920                    self.save_model();
921                    self.backtrack(assumption_level_start);
922                    return (SolverResult::Sat, None);
923                }
924            }
925        }
926    }
927
928    /// Solve with theory integration via callbacks
929    ///
930    /// This implements the CDCL(T) loop:
931    /// 1. BCP (Boolean Constraint Propagation)
932    /// 2. Theory propagation (via callback)
933    /// 3. On conflict: analyze and learn
934    /// 4. Decision
935    /// 5. Final theory check when all vars assigned
936    pub fn solve_with_theory<T: TheoryCallback>(&mut self, theory: &mut T) -> SolverResult {
937        if self.trivially_unsat {
938            return SolverResult::Unsat;
939        }
940
941        // Initial propagation
942        if self.propagate().is_some() {
943            return SolverResult::Unsat;
944        }
945
946        // Track how many assignments have been sent to the theory.
947        // We only send NEW assignments (not previously processed ones) to avoid
948        // duplicate theory constraints that would cause spurious UNSAT.
949        let mut theory_processed: usize = 0;
950
951        loop {
952            // Boolean propagation
953            if let Some(conflict) = self.propagate() {
954                self.stats.conflicts += 1;
955
956                if self.trail.decision_level() == 0 {
957                    return SolverResult::Unsat;
958                }
959
960                let (backtrack_level, learnt_clause) = self.analyze(conflict);
961                theory.on_backtrack(backtrack_level);
962                self.backtrack_with_phase_saving(backtrack_level);
963                // After backtrack, the trail may be shorter; update processed count
964                theory_processed = theory_processed.min(self.trail.assignments().len());
965                self.learn_clause(learnt_clause);
966
967                self.vsids.decay();
968                self.clauses.decay_activity(self.config.clause_decay);
969                self.handle_clause_deletion_and_restart();
970                continue;
971            }
972
973            // Theory propagation check after each assignment
974            loop {
975                // Get only NEW (unprocessed) assignments and notify theory
976                let assignments = self.trail.assignments().to_vec();
977                let mut theory_conflict = None;
978                let mut theory_propagations = Vec::new();
979
980                // Check only NEW assignments with theory (skip already-processed ones).
981                // Guard against stale theory_processed after backtracks/restarts.
982                let safe_start = theory_processed.min(assignments.len());
983                for &lit in &assignments[safe_start..] {
984                    match theory.on_assignment(lit) {
985                        TheoryCheckResult::Sat => {}
986                        TheoryCheckResult::Conflict(conflict_lits) => {
987                            theory_conflict = Some(conflict_lits);
988                            break;
989                        }
990                        TheoryCheckResult::Propagated(props) => {
991                            theory_propagations.extend(props);
992                        }
993                    }
994                }
995                // Update processed count
996                theory_processed = assignments.len();
997
998                // Handle theory conflict
999                if let Some(conflict_lits) = theory_conflict {
1000                    self.stats.conflicts += 1;
1001
1002                    if self.trail.decision_level() == 0 {
1003                        return SolverResult::Unsat;
1004                    }
1005
1006                    let (backtrack_level, learnt_clause) =
1007                        self.analyze_theory_conflict(&conflict_lits);
1008
1009                    // Empty learned clause signals all-level-0 conflict = fundamental UNSAT
1010                    if learnt_clause.is_empty() {
1011                        self.trivially_unsat = true;
1012                        return SolverResult::Unsat;
1013                    }
1014
1015                    theory.on_backtrack(backtrack_level);
1016                    self.backtrack_with_phase_saving(backtrack_level);
1017                    // After backtrack, update theory_processed to trail length
1018                    theory_processed = theory_processed.min(self.trail.assignments().len());
1019                    self.learn_clause(learnt_clause);
1020
1021                    self.vsids.decay();
1022                    self.clauses.decay_activity(self.config.clause_decay);
1023                    self.handle_clause_deletion_and_restart();
1024                    continue;
1025                }
1026
1027                // Handle theory propagations
1028                let mut made_propagation = false;
1029                for (lit, reason_lits) in theory_propagations {
1030                    if !self.trail.is_assigned(lit.var()) {
1031                        // Add reason clause and propagate
1032                        let clause_id = self.add_theory_reason_clause(&reason_lits, lit);
1033                        self.trail.assign_propagation(lit, clause_id);
1034                        made_propagation = true;
1035                    }
1036                }
1037
1038                if made_propagation {
1039                    // Re-run Boolean propagation
1040                    if let Some(conflict) = self.propagate() {
1041                        self.stats.conflicts += 1;
1042
1043                        if self.trail.decision_level() == 0 {
1044                            return SolverResult::Unsat;
1045                        }
1046
1047                        let (backtrack_level, learnt_clause) = self.analyze(conflict);
1048                        theory.on_backtrack(backtrack_level);
1049                        self.backtrack_with_phase_saving(backtrack_level);
1050                        // After backtrack, the trail is shorter; update processed count
1051                        theory_processed = theory_processed.min(self.trail.assignments().len());
1052                        self.learn_clause(learnt_clause);
1053
1054                        self.vsids.decay();
1055                        self.clauses.decay_activity(self.config.clause_decay);
1056                        self.handle_clause_deletion_and_restart();
1057                    }
1058                    continue;
1059                }
1060
1061                break;
1062            }
1063
1064            // Try to decide
1065            if let Some(var) = self.pick_branch_var() {
1066                self.stats.decisions += 1;
1067                self.trail.new_decision_level();
1068                let new_level = self.trail.decision_level();
1069                theory.on_new_level(new_level);
1070
1071                let polarity = if self.rand_bool(self.config.random_polarity_prob) {
1072                    self.rand_bool(0.5)
1073                } else {
1074                    self.phase[var.index()]
1075                };
1076                let lit = if polarity {
1077                    Lit::pos(var)
1078                } else {
1079                    Lit::neg(var)
1080                };
1081                self.trail.assign_decision(lit);
1082            } else {
1083                // All variables assigned - do final theory check
1084                match theory.final_check() {
1085                    TheoryCheckResult::Sat => {
1086                        self.save_model();
1087                        return SolverResult::Sat;
1088                    }
1089                    TheoryCheckResult::Conflict(conflict_lits) => {
1090                        self.stats.conflicts += 1;
1091
1092                        if self.trail.decision_level() == 0 {
1093                            return SolverResult::Unsat;
1094                        }
1095
1096                        let (backtrack_level, learnt_clause) =
1097                            self.analyze_theory_conflict(&conflict_lits);
1098
1099                        // If all conflict literals are at level 0, analyze_theory_conflict
1100                        // returns an empty learned clause as a signal of fundamental UNSAT.
1101                        if learnt_clause.is_empty() {
1102                            self.trivially_unsat = true;
1103                            return SolverResult::Unsat;
1104                        }
1105
1106                        theory.on_backtrack(backtrack_level);
1107                        self.backtrack_with_phase_saving(backtrack_level);
1108                        // After backtrack, update theory_processed
1109                        theory_processed = theory_processed.min(self.trail.assignments().len());
1110                        self.learn_clause(learnt_clause);
1111
1112                        self.vsids.decay();
1113                        self.clauses.decay_activity(self.config.clause_decay);
1114                        self.handle_clause_deletion_and_restart();
1115                    }
1116                    TheoryCheckResult::Propagated(props) => {
1117                        // Handle late propagations
1118                        for (lit, reason_lits) in props {
1119                            if !self.trail.is_assigned(lit.var()) {
1120                                let clause_id = self.add_theory_reason_clause(&reason_lits, lit);
1121                                self.trail.assign_propagation(lit, clause_id);
1122                            }
1123                        }
1124                    }
1125                }
1126            }
1127        }
1128    }
1129
1130    /// Get the model (if sat)
1131    #[must_use]
1132    pub fn model(&self) -> &[LBool] {
1133        &self.model
1134    }
1135
1136    /// Get the value of a variable in the model
1137    #[must_use]
1138    pub fn model_value(&self, var: Var) -> LBool {
1139        self.model.get(var.index()).copied().unwrap_or(LBool::Undef)
1140    }
1141
1142    /// Get statistics
1143    #[must_use]
1144    pub fn stats(&self) -> &SolverStats {
1145        &self.stats
1146    }
1147
1148    /// Get memory optimizer statistics
1149    #[must_use]
1150    pub fn memory_opt_stats(&self) -> &crate::memory_opt::MemoryOptStats {
1151        self.memory_optimizer.stats()
1152    }
1153
1154    /// Get number of variables
1155    #[must_use]
1156    pub fn num_vars(&self) -> usize {
1157        self.num_vars
1158    }
1159
1160    /// Get number of clauses
1161    #[must_use]
1162    pub fn num_clauses(&self) -> usize {
1163        self.clauses.len()
1164    }
1165
1166    /// Push a new assertion level (for incremental solving)
1167    ///
1168    /// This saves the current state so that clauses added after this point
1169    /// can be removed with pop(). Automatically backtracks to decision level 0
1170    /// to ensure a clean state for adding new constraints.
1171    pub fn push(&mut self) {
1172        // Backtrack to level 0 to ensure clean state
1173        // This is necessary because solve() may leave assignments on the trail
1174        // Use phase-saving backtrack to properly re-insert variables into decision heaps
1175        self.backtrack_with_phase_saving(0);
1176
1177        self.assertion_levels.push(self.clauses.num_original());
1178        self.assertion_trail_sizes.push(self.trail.size());
1179        self.assertion_clause_ids.push(Vec::new());
1180    }
1181
1182    /// Pop to previous assertion level
1183    pub fn pop(&mut self) {
1184        if self.assertion_levels.len() > 1 {
1185            self.assertion_levels.pop();
1186
1187            // Get the trail size to backtrack to
1188            let trail_size = self.assertion_trail_sizes.pop().unwrap_or(0);
1189
1190            // Remove all clauses added at this assertion level
1191            if let Some(clause_ids_to_remove) = self.assertion_clause_ids.pop() {
1192                for clause_id in clause_ids_to_remove {
1193                    // Remove from clause database
1194                    self.clauses.remove(clause_id);
1195
1196                    // Remove from learned clause tracking if it's a learned clause
1197                    self.learned_clause_ids.retain(|&id| id != clause_id);
1198
1199                    // Note: Watch lists will be cleaned up naturally during propagation
1200                    // as they check if clauses are deleted before using them
1201                }
1202            }
1203
1204            // Backtrack trail to the exact size it was at push()
1205            // This properly handles unit clauses that were added after push
1206            // Note: backtrack_to_size clears values but doesn't re-insert into heaps,
1207            // so we need to manually re-insert unassigned variables.
1208            let current_size = self.trail.size();
1209            if current_size > trail_size {
1210                // Collect variables that will be unassigned
1211                let mut unassigned_vars = Vec::new();
1212                for i in trail_size..current_size {
1213                    let lit = self.trail.assignments()[i];
1214                    unassigned_vars.push(lit.var());
1215                }
1216
1217                self.trail.backtrack_to_size(trail_size);
1218
1219                // Re-insert unassigned variables into decision heaps
1220                for var in unassigned_vars {
1221                    if !self.vsids.contains(var) {
1222                        self.vsids.insert(var);
1223                    }
1224                    if !self.chb.contains(var) {
1225                        self.chb.insert(var);
1226                    }
1227                    self.lrb.unassign(var);
1228                }
1229            }
1230
1231            // Ensure we're at decision level 0 with proper heap re-insertion
1232            self.backtrack_with_phase_saving(0);
1233
1234            // Clear the trivially_unsat flag as we've removed problematic clauses
1235            self.trivially_unsat = false;
1236        }
1237    }
1238
1239    /// Backtrack to decision level 0 (for AllSAT enumeration)
1240    ///
1241    /// This is necessary after a SAT result before adding blocking clauses
1242    /// to ensure the new clauses can trigger propagation correctly.
1243    /// Uses phase-saving backtrack to properly re-insert unassigned variables
1244    /// into the decision heaps (VSIDS, CHB, LRB).
1245    pub fn backtrack_to_root(&mut self) {
1246        self.backtrack_with_phase_saving(0);
1247    }
1248
1249    /// Reset the solver
1250    pub fn reset(&mut self) {
1251        self.clauses = ClauseDatabase::new();
1252        self.trail.clear();
1253        self.watches.clear();
1254        self.vsids.clear();
1255        self.chb.clear();
1256        self.stats = SolverStats::default();
1257        self.learnt.clear();
1258        self.seen.clear();
1259        self.analyze_stack.clear();
1260        self.assertion_levels.clear();
1261        self.assertion_levels.push(0);
1262        self.assertion_trail_sizes.clear();
1263        self.assertion_trail_sizes.push(0);
1264        self.assertion_clause_ids.clear();
1265        self.assertion_clause_ids.push(Vec::new());
1266        self.model.clear();
1267        self.num_vars = 0;
1268        self.restart_threshold = self.config.restart_interval;
1269        self.trivially_unsat = false;
1270        self.phase.clear();
1271        self.luby_index = 0;
1272        self.level_marks.clear();
1273        self.lbd_mark = 0;
1274        self.learned_clause_ids.clear();
1275        self.conflicts_since_deletion = 0;
1276        self.rng_state = 0x853c_49e6_748f_ea9b;
1277        self.recent_lbd_sum = 0;
1278        self.recent_lbd_count = 0;
1279        self.binary_graph.clear();
1280        self.global_lbd_sum = 0;
1281        self.global_lbd_count = 0;
1282        self.conflicts_since_local_restart = 0;
1283    }
1284
1285    /// Get the current trail (for theory solvers)
1286    #[must_use]
1287    pub fn trail(&self) -> &Trail {
1288        &self.trail
1289    }
1290
1291    /// Get the current decision level
1292    #[must_use]
1293    pub fn decision_level(&self) -> u32 {
1294        self.trail.decision_level()
1295    }
1296
1297    /// Debug method: print all learned clauses
1298    pub fn debug_print_learned_clauses(&self) {
1299        println!(
1300            "=== Learned Clauses ({}) ===",
1301            self.learned_clause_ids.len()
1302        );
1303        for (i, &cid) in self.learned_clause_ids.iter().enumerate() {
1304            if let Some(clause) = self.clauses.get(cid)
1305                && !clause.deleted
1306            {
1307                let lits: Vec<String> = clause
1308                    .lits
1309                    .iter()
1310                    .map(|lit| {
1311                        let var = lit.var().index();
1312                        if lit.is_pos() {
1313                            format!("v{}", var)
1314                        } else {
1315                            format!("~v{}", var)
1316                        }
1317                    })
1318                    .collect();
1319                println!(
1320                    "  Learned {}: ({}), LBD={}",
1321                    i,
1322                    lits.join(" | "),
1323                    clause.lbd
1324                );
1325            }
1326        }
1327    }
1328
1329    /// Debug method: print binary implication graph entries
1330    pub fn debug_print_binary_graph(&self) {
1331        println!("=== Binary Implication Graph ===");
1332        for lit_code in 0..(self.num_vars * 2) {
1333            let lit = Lit::from_code(lit_code as u32);
1334            let implications = self.binary_graph.get(lit);
1335            if !implications.is_empty() {
1336                let lit_str = if lit.is_pos() {
1337                    format!("v{}", lit.var().index())
1338                } else {
1339                    format!("~v{}", lit.var().index())
1340                };
1341                for &(implied, _cid) in implications {
1342                    let impl_str = if implied.is_pos() {
1343                        format!("v{}", implied.var().index())
1344                    } else {
1345                        format!("~v{}", implied.var().index())
1346                    };
1347                    println!("  {} -> {}", lit_str, impl_str);
1348                }
1349            }
1350        }
1351    }
1352}
1353
1354#[cfg(test)]
1355mod tests {
1356    use super::*;
1357
1358    #[test]
1359    fn test_empty_sat() {
1360        let mut solver = Solver::new();
1361        assert_eq!(solver.solve(), SolverResult::Sat);
1362    }
1363
1364    #[test]
1365    fn test_simple_sat() {
1366        let mut solver = Solver::new();
1367        let _x = solver.new_var();
1368        let _y = solver.new_var();
1369
1370        // x or y
1371        solver.add_clause_dimacs(&[1, 2]);
1372        // not x or y
1373        solver.add_clause_dimacs(&[-1, 2]);
1374
1375        assert_eq!(solver.solve(), SolverResult::Sat);
1376        assert!(solver.model_value(Var::new(1)).is_true()); // y must be true
1377    }
1378
1379    #[test]
1380    fn test_simple_unsat() {
1381        let mut solver = Solver::new();
1382        let _x = solver.new_var();
1383
1384        // x
1385        solver.add_clause_dimacs(&[1]);
1386        // not x
1387        solver.add_clause_dimacs(&[-1]);
1388
1389        assert_eq!(solver.solve(), SolverResult::Unsat);
1390    }
1391
1392    #[test]
1393    fn test_pigeonhole_2_1() {
1394        // 2 pigeons, 1 hole - UNSAT
1395        let mut solver = Solver::new();
1396        let _p1h1 = solver.new_var(); // pigeon 1 in hole 1
1397        let _p2h1 = solver.new_var(); // pigeon 2 in hole 1
1398
1399        // Each pigeon must be in some hole
1400        solver.add_clause_dimacs(&[1]); // p1 in h1
1401        solver.add_clause_dimacs(&[2]); // p2 in h1
1402
1403        // No hole can have two pigeons
1404        solver.add_clause_dimacs(&[-1, -2]); // not (p1h1 and p2h1)
1405
1406        assert_eq!(solver.solve(), SolverResult::Unsat);
1407    }
1408
1409    #[test]
1410    fn test_3sat_random() {
1411        let mut solver = Solver::new();
1412        for _ in 0..10 {
1413            solver.new_var();
1414        }
1415
1416        // Random 3-SAT instance (likely SAT)
1417        solver.add_clause_dimacs(&[1, 2, 3]);
1418        solver.add_clause_dimacs(&[-1, 4, 5]);
1419        solver.add_clause_dimacs(&[2, -3, 6]);
1420        solver.add_clause_dimacs(&[-4, 7, 8]);
1421        solver.add_clause_dimacs(&[5, -6, 9]);
1422        solver.add_clause_dimacs(&[-7, 8, 10]);
1423        solver.add_clause_dimacs(&[1, -8, -9]);
1424        solver.add_clause_dimacs(&[-2, 3, -10]);
1425
1426        let result = solver.solve();
1427        assert_eq!(result, SolverResult::Sat);
1428    }
1429
1430    #[test]
1431    fn test_luby_sequence() {
1432        // Luby sequence: 1, 1, 2, 1, 1, 2, 4, 1, 1, 2, 1, 1, 2, 4, 8, ...
1433        assert_eq!(Solver::luby(0), 1);
1434        assert_eq!(Solver::luby(1), 1);
1435        assert_eq!(Solver::luby(2), 2);
1436        assert_eq!(Solver::luby(3), 1);
1437        assert_eq!(Solver::luby(4), 1);
1438        assert_eq!(Solver::luby(5), 2);
1439        assert_eq!(Solver::luby(6), 4);
1440        assert_eq!(Solver::luby(7), 1);
1441    }
1442
1443    #[test]
1444    fn test_phase_saving() {
1445        let mut solver = Solver::new();
1446        for _ in 0..5 {
1447            solver.new_var();
1448        }
1449
1450        // Set up a problem where phase saving helps
1451        solver.add_clause_dimacs(&[1, 2]);
1452        solver.add_clause_dimacs(&[-1, 3]);
1453        solver.add_clause_dimacs(&[-2, 4]);
1454        solver.add_clause_dimacs(&[-3, -4, 5]);
1455        solver.add_clause_dimacs(&[-5, 1]);
1456
1457        let result = solver.solve();
1458        assert_eq!(result, SolverResult::Sat);
1459    }
1460
1461    #[test]
1462    fn test_lbd_computation() {
1463        // Test that clause deletion can handle a problem that generates learned clauses
1464        let mut solver = Solver::with_config(SolverConfig {
1465            clause_deletion_threshold: 5, // Trigger deletion quickly
1466            ..SolverConfig::default()
1467        });
1468
1469        for _ in 0..20 {
1470            solver.new_var();
1471        }
1472
1473        // A harder problem to generate more conflicts and learned clauses
1474        // PHP(3,2): 3 pigeons, 2 holes - UNSAT
1475        // Variables: p_i_h (pigeon i in hole h)
1476        // p11=1, p12=2, p21=3, p22=4, p31=5, p32=6
1477
1478        // Each pigeon must be in some hole
1479        solver.add_clause_dimacs(&[1, 2]); // p1 in h1 or h2
1480        solver.add_clause_dimacs(&[3, 4]); // p2 in h1 or h2
1481        solver.add_clause_dimacs(&[5, 6]); // p3 in h1 or h2
1482
1483        // No hole can have two pigeons
1484        solver.add_clause_dimacs(&[-1, -3]); // not (p1h1 and p2h1)
1485        solver.add_clause_dimacs(&[-1, -5]); // not (p1h1 and p3h1)
1486        solver.add_clause_dimacs(&[-3, -5]); // not (p2h1 and p3h1)
1487        solver.add_clause_dimacs(&[-2, -4]); // not (p1h2 and p2h2)
1488        solver.add_clause_dimacs(&[-2, -6]); // not (p1h2 and p3h2)
1489        solver.add_clause_dimacs(&[-4, -6]); // not (p2h2 and p3h2)
1490
1491        let result = solver.solve();
1492        assert_eq!(result, SolverResult::Unsat);
1493        // Verify we had some conflicts (and thus learned clauses)
1494        assert!(solver.stats().conflicts > 0);
1495    }
1496
1497    #[test]
1498    fn test_clause_activity_decay() {
1499        let mut solver = Solver::new();
1500        for _ in 0..10 {
1501            solver.new_var();
1502        }
1503
1504        // Add some clauses
1505        solver.add_clause_dimacs(&[1, 2, 3]);
1506        solver.add_clause_dimacs(&[-1, 4, 5]);
1507        solver.add_clause_dimacs(&[-2, -3, 6]);
1508
1509        // Solve (should be SAT)
1510        let result = solver.solve();
1511        assert_eq!(result, SolverResult::Sat);
1512    }
1513
1514    #[test]
1515    fn test_clause_minimization() {
1516        // Test that clause minimization works correctly on a problem
1517        // that will generate learned clauses
1518        let mut solver = Solver::new();
1519
1520        for _ in 0..15 {
1521            solver.new_var();
1522        }
1523
1524        // A problem structure that generates conflicts and learned clauses
1525        // Graph coloring with 3 colors on 5 vertices
1526        // Vertices: 1-5, Colors: R(0-4), G(5-9), B(10-14)
1527
1528        // Each vertex has at least one color
1529        solver.add_clause_dimacs(&[1, 6, 11]); // v1: R or G or B
1530        solver.add_clause_dimacs(&[2, 7, 12]); // v2
1531        solver.add_clause_dimacs(&[3, 8, 13]); // v3
1532        solver.add_clause_dimacs(&[4, 9, 14]); // v4
1533        solver.add_clause_dimacs(&[5, 10, 15]); // v5
1534
1535        // At most one color per vertex (pairwise exclusion)
1536        solver.add_clause_dimacs(&[-1, -6]); // v1: not (R and G)
1537        solver.add_clause_dimacs(&[-1, -11]); // v1: not (R and B)
1538        solver.add_clause_dimacs(&[-6, -11]); // v1: not (G and B)
1539
1540        solver.add_clause_dimacs(&[-2, -7]);
1541        solver.add_clause_dimacs(&[-2, -12]);
1542        solver.add_clause_dimacs(&[-7, -12]);
1543
1544        solver.add_clause_dimacs(&[-3, -8]);
1545        solver.add_clause_dimacs(&[-3, -13]);
1546        solver.add_clause_dimacs(&[-8, -13]);
1547
1548        // Adjacent vertices have different colors (edges: 1-2, 2-3, 3-4, 4-5)
1549        solver.add_clause_dimacs(&[-1, -2]); // edge 1-2: not both R
1550        solver.add_clause_dimacs(&[-6, -7]); // edge 1-2: not both G
1551        solver.add_clause_dimacs(&[-11, -12]); // edge 1-2: not both B
1552
1553        solver.add_clause_dimacs(&[-2, -3]); // edge 2-3
1554        solver.add_clause_dimacs(&[-7, -8]);
1555        solver.add_clause_dimacs(&[-12, -13]);
1556
1557        let result = solver.solve();
1558        assert_eq!(result, SolverResult::Sat);
1559
1560        // The solver may or may not have conflicts/learned clauses depending on
1561        // the decision heuristic. The key is that the result is correct.
1562        // If there are learned clauses, minimization would have been applied.
1563    }
1564
1565    /// A simple theory callback that does nothing (pure SAT)
1566    struct NullTheory;
1567
1568    impl TheoryCallback for NullTheory {
1569        fn on_assignment(&mut self, _lit: Lit) -> TheoryCheckResult {
1570            TheoryCheckResult::Sat
1571        }
1572
1573        fn final_check(&mut self) -> TheoryCheckResult {
1574            TheoryCheckResult::Sat
1575        }
1576
1577        fn on_backtrack(&mut self, _level: u32) {}
1578    }
1579
1580    #[test]
1581    fn test_solve_with_theory_sat() {
1582        let mut solver = Solver::new();
1583        let mut theory = NullTheory;
1584
1585        let _x = solver.new_var();
1586        let _y = solver.new_var();
1587
1588        // x or y
1589        solver.add_clause_dimacs(&[1, 2]);
1590        // not x or y
1591        solver.add_clause_dimacs(&[-1, 2]);
1592
1593        assert_eq!(solver.solve_with_theory(&mut theory), SolverResult::Sat);
1594        assert!(solver.model_value(Var::new(1)).is_true()); // y must be true
1595    }
1596
1597    #[test]
1598    fn test_solve_with_theory_unsat() {
1599        let mut solver = Solver::new();
1600        let mut theory = NullTheory;
1601
1602        let _x = solver.new_var();
1603
1604        // x
1605        solver.add_clause_dimacs(&[1]);
1606        // not x
1607        solver.add_clause_dimacs(&[-1]);
1608
1609        assert_eq!(solver.solve_with_theory(&mut theory), SolverResult::Unsat);
1610    }
1611
1612    /// A theory that forces x0 => x1 (if x0 is true, x1 must be true)
1613    struct ImplicationTheory {
1614        /// Track if x0 is assigned true
1615        x0_true: bool,
1616    }
1617
1618    impl ImplicationTheory {
1619        fn new() -> Self {
1620            Self { x0_true: false }
1621        }
1622    }
1623
1624    impl TheoryCallback for ImplicationTheory {
1625        fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult {
1626            // If x0 becomes true, propagate x1
1627            if lit.var().index() == 0 && lit.is_pos() {
1628                self.x0_true = true;
1629                // Propagate: x1 must be true because x0 is true
1630                // The reason is: ~x0 (if x0 were false, we wouldn't need x1)
1631                let reason: SmallVec<[Lit; 8]> = smallvec::smallvec![Lit::pos(Var::new(0))];
1632                return TheoryCheckResult::Propagated(vec![(Lit::pos(Var::new(1)), reason)]);
1633            }
1634            TheoryCheckResult::Sat
1635        }
1636
1637        fn final_check(&mut self) -> TheoryCheckResult {
1638            TheoryCheckResult::Sat
1639        }
1640
1641        fn on_backtrack(&mut self, _level: u32) {
1642            self.x0_true = false;
1643        }
1644    }
1645
1646    #[test]
1647    fn test_theory_propagation() {
1648        let mut solver = Solver::new();
1649        let mut theory = ImplicationTheory::new();
1650
1651        let _x0 = solver.new_var();
1652        let _x1 = solver.new_var();
1653
1654        // Force x0 to be true
1655        solver.add_clause_dimacs(&[1]);
1656
1657        let result = solver.solve_with_theory(&mut theory);
1658        assert_eq!(result, SolverResult::Sat);
1659
1660        // x0 should be true (forced by clause)
1661        assert!(solver.model_value(Var::new(0)).is_true());
1662        // x1 should also be true (propagated by theory)
1663        assert!(solver.model_value(Var::new(1)).is_true());
1664    }
1665
1666    /// Theory that says x0 and x1 can't both be true
1667    struct MutexTheory {
1668        x0_true: Option<Lit>,
1669        x1_true: Option<Lit>,
1670    }
1671
1672    impl MutexTheory {
1673        fn new() -> Self {
1674            Self {
1675                x0_true: None,
1676                x1_true: None,
1677            }
1678        }
1679    }
1680
1681    impl TheoryCallback for MutexTheory {
1682        fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult {
1683            if lit.var().index() == 0 && lit.is_pos() {
1684                self.x0_true = Some(lit);
1685            }
1686            if lit.var().index() == 1 && lit.is_pos() {
1687                self.x1_true = Some(lit);
1688            }
1689
1690            // If both are true, conflict
1691            if self.x0_true.is_some() && self.x1_true.is_some() {
1692                // Conflict clause: ~x0 or ~x1 (at least one must be false)
1693                let conflict: SmallVec<[Lit; 8]> = smallvec::smallvec![
1694                    Lit::pos(Var::new(0)), // x0 is true (we negate in conflict)
1695                    Lit::pos(Var::new(1))  // x1 is true
1696                ];
1697                return TheoryCheckResult::Conflict(conflict);
1698            }
1699            TheoryCheckResult::Sat
1700        }
1701
1702        fn final_check(&mut self) -> TheoryCheckResult {
1703            if self.x0_true.is_some() && self.x1_true.is_some() {
1704                let conflict: SmallVec<[Lit; 8]> =
1705                    smallvec::smallvec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))];
1706                return TheoryCheckResult::Conflict(conflict);
1707            }
1708            TheoryCheckResult::Sat
1709        }
1710
1711        fn on_backtrack(&mut self, _level: u32) {
1712            self.x0_true = None;
1713            self.x1_true = None;
1714        }
1715    }
1716
1717    #[test]
1718    fn test_theory_conflict() {
1719        let mut solver = Solver::new();
1720        let mut theory = MutexTheory::new();
1721
1722        let _x0 = solver.new_var();
1723        let _x1 = solver.new_var();
1724
1725        // Force both x0 and x1 to be true (should cause theory conflict)
1726        solver.add_clause_dimacs(&[1]);
1727        solver.add_clause_dimacs(&[2]);
1728
1729        let result = solver.solve_with_theory(&mut theory);
1730        assert_eq!(result, SolverResult::Unsat);
1731    }
1732
1733    #[test]
1734    fn test_solve_with_assumptions_sat() {
1735        let mut solver = Solver::new();
1736
1737        let x0 = solver.new_var();
1738        let x1 = solver.new_var();
1739
1740        // x0 \/ x1
1741        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1742
1743        // Assume x0 = true
1744        let assumptions = [Lit::pos(x0)];
1745        let (result, core) = solver.solve_with_assumptions(&assumptions);
1746
1747        assert_eq!(result, SolverResult::Sat);
1748        assert!(core.is_none());
1749    }
1750
1751    #[test]
1752    fn test_solve_with_assumptions_unsat() {
1753        let mut solver = Solver::new();
1754
1755        let x0 = solver.new_var();
1756        let x1 = solver.new_var();
1757
1758        // x0 -> ~x1 (encoded as ~x0 \/ ~x1)
1759        solver.add_clause([Lit::neg(x0), Lit::neg(x1)]);
1760
1761        // Assume both x0 = true and x1 = true (should be UNSAT)
1762        let assumptions = [Lit::pos(x0), Lit::pos(x1)];
1763        let (result, core) = solver.solve_with_assumptions(&assumptions);
1764
1765        assert_eq!(result, SolverResult::Unsat);
1766        assert!(core.is_some());
1767        let core = core.expect("UNSAT result must have conflict core");
1768        // Core should contain at least one of the conflicting assumptions
1769        assert!(!core.is_empty());
1770    }
1771
1772    #[test]
1773    fn test_solve_with_assumptions_core_extraction() {
1774        let mut solver = Solver::new();
1775
1776        let x0 = solver.new_var();
1777        let x1 = solver.new_var();
1778        let x2 = solver.new_var();
1779
1780        // ~x0 (x0 must be false)
1781        solver.add_clause([Lit::neg(x0)]);
1782
1783        // Assume x0 = true, x1 = true, x2 = true
1784        // Only x0 should be in the core
1785        let assumptions = [Lit::pos(x0), Lit::pos(x1), Lit::pos(x2)];
1786        let (result, core) = solver.solve_with_assumptions(&assumptions);
1787
1788        assert_eq!(result, SolverResult::Unsat);
1789        assert!(core.is_some());
1790        let core = core.expect("UNSAT result must have conflict core");
1791        // x0 should be in the core
1792        assert!(core.contains(&Lit::pos(x0)));
1793    }
1794
1795    #[test]
1796    fn test_solve_with_assumptions_incremental() {
1797        let mut solver = Solver::new();
1798
1799        let x0 = solver.new_var();
1800        let x1 = solver.new_var();
1801
1802        // x0 \/ x1
1803        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1804
1805        // First: assume ~x0 (should be SAT with x1 = true)
1806        let (result1, _) = solver.solve_with_assumptions(&[Lit::neg(x0)]);
1807        assert_eq!(result1, SolverResult::Sat);
1808
1809        // Second: assume ~x0 and ~x1 (should be UNSAT)
1810        let (result2, core2) = solver.solve_with_assumptions(&[Lit::neg(x0), Lit::neg(x1)]);
1811        assert_eq!(result2, SolverResult::Unsat);
1812        assert!(core2.is_some());
1813
1814        // Third: assume x0 (should be SAT again)
1815        let (result3, _) = solver.solve_with_assumptions(&[Lit::pos(x0)]);
1816        assert_eq!(result3, SolverResult::Sat);
1817    }
1818
1819    #[test]
1820    fn test_push_pop_simple() {
1821        let mut solver = Solver::new();
1822
1823        let x0 = solver.new_var();
1824
1825        // Should be SAT (x0 can be true or false)
1826        assert_eq!(solver.solve(), SolverResult::Sat);
1827
1828        // Push and add unit clause: x0
1829        solver.push();
1830        solver.add_clause([Lit::pos(x0)]);
1831        assert_eq!(solver.solve(), SolverResult::Sat);
1832        assert!(solver.model_value(x0).is_true());
1833
1834        // Pop - should be SAT again
1835        solver.pop();
1836        let result = solver.solve();
1837        assert_eq!(
1838            result,
1839            SolverResult::Sat,
1840            "After pop, expected SAT but got {:?}. trivially_unsat={}",
1841            result,
1842            solver.trivially_unsat
1843        );
1844    }
1845
1846    #[test]
1847    fn test_push_pop_incremental() {
1848        let mut solver = Solver::new();
1849
1850        let x0 = solver.new_var();
1851        let x1 = solver.new_var();
1852        let x2 = solver.new_var();
1853
1854        // Base level: x0 \/ x1
1855        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1856        assert_eq!(solver.solve(), SolverResult::Sat);
1857
1858        // Push and add: ~x0
1859        solver.push();
1860        solver.add_clause([Lit::neg(x0)]);
1861        assert_eq!(solver.solve(), SolverResult::Sat);
1862        // x1 must be true
1863        assert!(solver.model_value(x1).is_true());
1864
1865        // Push again and add: ~x1 (should be UNSAT)
1866        solver.push();
1867        solver.add_clause([Lit::neg(x1)]);
1868        assert_eq!(solver.solve(), SolverResult::Unsat);
1869
1870        // Pop back one level (remove ~x1, keep ~x0)
1871        solver.pop();
1872        assert_eq!(solver.solve(), SolverResult::Sat);
1873        assert!(solver.model_value(x1).is_true());
1874
1875        // Pop back to base level (remove ~x0)
1876        solver.pop();
1877        assert_eq!(solver.solve(), SolverResult::Sat);
1878        // Either x0 or x1 can be true now
1879
1880        // Push and add different clause: x0 /\ x2
1881        solver.push();
1882        solver.add_clause([Lit::pos(x0)]);
1883        solver.add_clause([Lit::pos(x2)]);
1884        assert_eq!(solver.solve(), SolverResult::Sat);
1885        assert!(solver.model_value(x0).is_true());
1886        assert!(solver.model_value(x2).is_true());
1887
1888        // Pop and verify clauses are removed
1889        solver.pop();
1890        assert_eq!(solver.solve(), SolverResult::Sat);
1891    }
1892
1893    #[test]
1894    fn test_push_pop_with_learned_clauses() {
1895        let mut solver = Solver::new();
1896
1897        let x0 = solver.new_var();
1898        let x1 = solver.new_var();
1899        let x2 = solver.new_var();
1900
1901        // Create a formula that will cause learning
1902        // (x0 \/ x1) /\ (~x0 \/ x2) /\ (~x1 \/ x2)
1903        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1904        solver.add_clause([Lit::neg(x0), Lit::pos(x2)]);
1905        solver.add_clause([Lit::neg(x1), Lit::pos(x2)]);
1906
1907        assert_eq!(solver.solve(), SolverResult::Sat);
1908
1909        // Push and add conflicting clause
1910        solver.push();
1911        solver.add_clause([Lit::neg(x2)]);
1912
1913        // This should be UNSAT and cause clause learning
1914        assert_eq!(solver.solve(), SolverResult::Unsat);
1915
1916        // Pop - learned clauses from this level should be removed
1917        solver.pop();
1918
1919        // Should be SAT again
1920        assert_eq!(solver.solve(), SolverResult::Sat);
1921    }
1922}