Skip to main content

oxiz_sat/solver/
mod.rs

1//! CDCL SAT Solver
2
3mod conflict;
4mod decide;
5mod learn;
6mod propagate;
7
8use crate::chb::CHB;
9use crate::chrono::ChronoBacktrack;
10use crate::clause::{ClauseDatabase, ClauseId};
11use crate::literal::{LBool, Lit, Var};
12use crate::lrb::LRB;
13use crate::memory_opt::{MemoryAction, MemoryOptimizer};
14#[allow(unused_imports)]
15use crate::prelude::*;
16use crate::trail::{Reason, Trail};
17use crate::vsids::VSIDS;
18use crate::watched::{WatchLists, Watcher};
19use smallvec::SmallVec;
20
/// Binary implication graph for efficient binary clause propagation.
///
/// For each literal `L`, stores the literals that are implied once `L` is
/// true: a binary clause `(~L v M)` contributes the entry `M` under key `L`,
/// because with `L` true the clause can only be satisfied by `M`.
/// (`Solver::add_clause` registers a clause `(a v b)` as `add(~a, b)` and
/// `add(~b, a)`.)
#[derive(Debug, Clone)]
pub(super) struct BinaryImplicationGraph {
    /// Indexed by `lit.code()`; `implications[lit]` is the list of
    /// `(implied_lit, clause_id)` pairs recorded for that literal.
    implications: Vec<Vec<(Lit, ClauseId)>>,
}
29
30impl BinaryImplicationGraph {
31    fn new(num_vars: usize) -> Self {
32        Self {
33            implications: vec![Vec::new(); num_vars * 2],
34        }
35    }
36
37    fn resize(&mut self, num_vars: usize) {
38        self.implications.resize(num_vars * 2, Vec::new());
39    }
40
41    fn add(&mut self, lit: Lit, implied: Lit, clause_id: ClauseId) {
42        self.implications[lit.code() as usize].push((implied, clause_id));
43    }
44
45    fn get(&self, lit: Lit) -> &[(Lit, ClauseId)] {
46        &self.implications[lit.code() as usize]
47    }
48
49    fn clear(&mut self) {
50        for implications in &mut self.implications {
51            implications.clear();
52        }
53    }
54}
55
/// Result from a theory check.
///
/// Returned by the [`TheoryCallback`] methods `on_assignment` and
/// `final_check`.
#[derive(Debug, Clone)]
pub enum TheoryCheckResult {
    /// Theory is satisfied under the current assignment.
    Sat,
    /// Theory detected a conflict; carries the conflict clause literals.
    Conflict(SmallVec<[Lit; 8]>),
    /// Theory propagated new literals; each entry is `(lit, reason clause)`.
    Propagated(Vec<(Lit, SmallVec<[Lit; 8]>)>),
}
66
/// Callback trait for theory solvers.
///
/// The CDCL(T) solver implements this to receive theory callbacks.
pub trait TheoryCallback {
    /// Called when a literal is assigned.
    ///
    /// Returns a theory check result for that assignment.
    fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult;

    /// Called after propagation is complete to do a full theory check.
    fn final_check(&mut self) -> TheoryCheckResult;

    /// Called when the decision level increases.
    ///
    /// The default implementation does nothing.
    fn on_new_level(&mut self, _level: u32) {}

    /// Called when backtracking.
    // NOTE(review): `level` is presumably the level being backtracked to —
    // confirm against the call site.
    fn on_backtrack(&mut self, level: u32);
}
83
/// Result of SAT solving.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolverResult {
    /// The formula is satisfiable.
    Sat,
    /// The formula is unsatisfiable.
    Unsat,
    /// No answer was reached (e.g., timeout, resource limit).
    Unknown,
}
94
/// Solver configuration.
///
/// `SolverConfig::default()` provides the baseline values.
#[derive(Debug, Clone)]
pub struct SolverConfig {
    /// Restart interval (number of conflicts)
    pub restart_interval: u64,
    /// Restart multiplier for geometric restarts
    pub restart_multiplier: f64,
    /// Clause deletion threshold (conflicts between clause database reductions)
    pub clause_deletion_threshold: usize,
    /// Variable decay factor
    pub var_decay: f64,
    /// Clause decay factor
    pub clause_decay: f64,
    /// Random polarity probability (0.0 to 1.0)
    pub random_polarity_prob: f64,
    /// Restart strategy; one of the [`RestartStrategy`] variants
    /// (`Luby`, `Geometric`, `Glucose`, `LocalLbd`)
    pub restart_strategy: RestartStrategy,
    /// Enable lazy hyper-binary resolution
    pub enable_lazy_hyper_binary: bool,
    /// Use CHB instead of VSIDS for branching
    pub use_chb_branching: bool,
    /// Use LRB (Learning Rate Branching) for branching
    pub use_lrb_branching: bool,
    /// Enable inprocessing (periodic preprocessing during search)
    pub enable_inprocessing: bool,
    /// Inprocessing interval (number of conflicts between inprocessing)
    pub inprocessing_interval: u64,
    /// Enable chronological backtracking
    pub enable_chronological_backtrack: bool,
    /// Chronological backtracking threshold (max distance from assertion level)
    pub chrono_backtrack_threshold: u32,
}
127
/// Restart strategy selected through `SolverConfig::restart_strategy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RestartStrategy {
    /// Luby sequence restarts
    Luby,
    /// Geometric restarts
    Geometric,
    /// Glucose-style dynamic restarts based on LBD
    Glucose,
    /// Local restarts based on LBD trail
    LocalLbd,
}
140
141impl Default for SolverConfig {
142    fn default() -> Self {
143        Self {
144            restart_interval: 100,
145            restart_multiplier: 1.5,
146            clause_deletion_threshold: 10000,
147            var_decay: 0.95,
148            clause_decay: 0.999,
149            random_polarity_prob: 0.02,
150            restart_strategy: RestartStrategy::Luby,
151            enable_lazy_hyper_binary: true,
152            use_chb_branching: false,
153            use_lrb_branching: false,
154            enable_inprocessing: false,
155            inprocessing_interval: 5000,
156            enable_chronological_backtrack: true,
157            chrono_backtrack_threshold: 100,
158        }
159    }
160}
161
/// Counters collected by the solver, plus derived ratios.
#[derive(Debug, Default, Clone)]
pub struct SolverStats {
    /// Number of decisions made
    pub decisions: u64,
    /// Number of propagations
    pub propagations: u64,
    /// Number of conflicts
    pub conflicts: u64,
    /// Number of restarts
    pub restarts: u64,
    /// Number of learned clauses
    pub learned_clauses: u64,
    /// Number of deleted clauses
    pub deleted_clauses: u64,
    /// Number of binary clauses learned
    pub binary_clauses: u64,
    /// Number of unit clauses learned
    pub unit_clauses: u64,
    /// Total LBD of learned clauses
    pub total_lbd: u64,
    /// Number of clause minimizations
    pub minimizations: u64,
    /// Literals removed by minimization
    pub literals_removed: u64,
    /// Number of chronological backtracks
    pub chrono_backtracks: u64,
    /// Number of non-chronological backtracks
    pub non_chrono_backtracks: u64,
}

impl SolverStats {
    /// Divides two counters as `f64`; an empty denominator yields `0.0`.
    fn ratio(numerator: u64, denominator: u64) -> f64 {
        match denominator {
            0 => 0.0,
            d => numerator as f64 / d as f64,
        }
    }

    /// Average LBD over all learned clauses (`0.0` when none were learned).
    #[must_use]
    pub fn avg_lbd(&self) -> f64 {
        Self::ratio(self.total_lbd, self.learned_clauses)
    }

    /// Decisions per conflict (`0.0` when there were no conflicts).
    #[must_use]
    pub fn avg_decisions_per_conflict(&self) -> f64 {
        Self::ratio(self.decisions, self.conflicts)
    }

    /// Propagations per conflict (`0.0` when there were no conflicts).
    #[must_use]
    pub fn propagations_per_conflict(&self) -> f64 {
        Self::ratio(self.propagations, self.conflicts)
    }

    /// Deleted clauses as a fraction of learned clauses
    /// (`0.0` when none were learned).
    #[must_use]
    pub fn deletion_ratio(&self) -> f64 {
        Self::ratio(self.deleted_clauses, self.learned_clauses)
    }

    /// Chronological backtracks as a fraction of all backtracks
    /// (`0.0` when there were no backtracks).
    #[must_use]
    pub fn chrono_backtrack_ratio(&self) -> f64 {
        let total = self.chrono_backtracks + self.non_chrono_backtracks;
        Self::ratio(self.chrono_backtracks, total)
    }

    /// Prints the counters and derived ratios to standard output.
    pub fn display(&self) {
        // Derived values are computed up front so the print section below
        // is a plain table.
        let avg_lbd = self.avg_lbd();
        let decisions_per_conflict = self.avg_decisions_per_conflict();
        let propagations_per_conflict = self.propagations_per_conflict();
        let deletion_pct = self.deletion_ratio() * 100.0;
        let chrono_pct = self.chrono_backtrack_ratio() * 100.0;

        println!("========== Solver Statistics ==========");
        println!("Decisions:              {:>12}", self.decisions);
        println!("Propagations:           {:>12}", self.propagations);
        println!("Conflicts:              {:>12}", self.conflicts);
        println!("Restarts:               {:>12}", self.restarts);
        println!("Learned clauses:        {:>12}", self.learned_clauses);
        println!("  - Unit clauses:       {:>12}", self.unit_clauses);
        println!("  - Binary clauses:     {:>12}", self.binary_clauses);
        println!("Deleted clauses:        {:>12}", self.deleted_clauses);
        println!("Minimizations:          {:>12}", self.minimizations);
        println!("Literals removed:       {:>12}", self.literals_removed);
        println!("Chrono backtracks:      {:>12}", self.chrono_backtracks);
        println!("Non-chrono backtracks:  {:>12}", self.non_chrono_backtracks);
        println!("---------------------------------------");
        println!("Avg LBD:                {:>12.2}", avg_lbd);
        println!("Avg decisions/conflict: {:>12.2}", decisions_per_conflict);
        println!("Propagations/conflict:  {:>12.2}", propagations_per_conflict);
        println!("Deletion ratio:         {:>12.2}%", deletion_pct);
        println!("Chrono backtrack ratio: {:>12.2}%", chrono_pct);
        println!("=======================================");
    }
}
281
/// CDCL SAT Solver
///
/// Holds the clause database, assignment trail, watch lists, branching
/// heuristics, and the bookkeeping used for restarts, learned-clause
/// deletion and incremental (assertion-level) solving.
#[derive(Debug)]
pub struct Solver {
    /// Configuration
    pub(super) config: SolverConfig,
    /// Number of variables
    pub(super) num_vars: usize,
    /// Clause database
    pub(super) clauses: ClauseDatabase,
    /// Assignment trail
    pub(super) trail: Trail,
    /// Watch lists
    pub(super) watches: WatchLists,
    /// VSIDS branching heuristic
    pub(super) vsids: VSIDS,
    /// CHB branching heuristic
    pub(super) chb: CHB,
    /// LRB branching heuristic
    pub(super) lrb: LRB,
    /// Statistics
    pub(super) stats: SolverStats,
    /// Learnt clause for conflict analysis
    pub(super) learnt: SmallVec<[Lit; 16]>,
    /// Seen flags for conflict analysis (one per variable)
    pub(super) seen: Vec<bool>,
    /// Analyze stack
    pub(super) analyze_stack: Vec<Lit>,
    /// Current restart threshold (initialised from `config.restart_interval`)
    pub(super) restart_threshold: u64,
    /// Assertions stack for incremental solving (number of original clauses)
    pub(super) assertion_levels: Vec<usize>,
    /// Trail sizes at each assertion level (for proper pop backtracking)
    pub(super) assertion_trail_sizes: Vec<usize>,
    /// Clause IDs added at each assertion level (for proper pop)
    pub(super) assertion_clause_ids: Vec<Vec<ClauseId>>,
    /// Model (if sat); one entry per variable
    pub(super) model: Vec<LBool>,
    /// Whether formula is trivially unsatisfiable
    pub(super) trivially_unsat: bool,
    /// Phase saving: last polarity assigned to each variable
    pub(super) phase: Vec<bool>,
    /// Luby sequence index for restarts
    pub(super) luby_index: u64,
    /// Level marks for LBD computation
    pub(super) level_marks: Vec<u32>,
    /// Current mark counter for LBD computation
    pub(super) lbd_mark: u32,
    /// Learned clause IDs for deletion
    pub(super) learned_clause_ids: Vec<ClauseId>,
    /// Number of conflicts since last clause deletion
    pub(super) conflicts_since_deletion: u64,
    /// PRNG state (xorshift64)
    pub(super) rng_state: u64,
    /// For Glucose-style restarts: running sum of the LBD of recent conflicts
    pub(super) recent_lbd_sum: u64,
    /// Number of conflicts contributing to recent_lbd_sum
    pub(super) recent_lbd_count: u64,
    /// Binary implication graph for fast binary clause propagation
    pub(super) binary_graph: BinaryImplicationGraph,
    /// Global LBD sum for local restarts
    pub(super) global_lbd_sum: u64,
    /// Number of conflicts contributing to global LBD
    pub(super) global_lbd_count: u64,
    /// Conflicts since last local restart
    pub(super) conflicts_since_local_restart: u64,
    /// Conflicts since last inprocessing
    pub(super) conflicts_since_inprocessing: u64,
    /// Chronological backtracking helper
    pub(super) chrono_backtrack: ChronoBacktrack,
    /// Clause activity bump increment (for MapleSAT-style clause bumping)
    pub(super) clause_bump_increment: f64,
    /// Memory optimizer with size-class pools for clause allocation
    pub(super) memory_optimizer: MemoryOptimizer,
}
356
357impl Default for Solver {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
363impl Solver {
364    /// Create a new solver
365    #[must_use]
366    pub fn new() -> Self {
367        Self::with_config(SolverConfig::default())
368    }
369
370    /// Create a new solver with configuration
371    #[must_use]
372    pub fn with_config(config: SolverConfig) -> Self {
373        let chrono_enabled = config.enable_chronological_backtrack;
374        let chrono_threshold = config.chrono_backtrack_threshold;
375
376        Self {
377            restart_threshold: config.restart_interval,
378            config,
379            num_vars: 0,
380            clauses: ClauseDatabase::new(),
381            trail: Trail::new(0),
382            watches: WatchLists::new(0),
383            vsids: VSIDS::new(0),
384            chb: CHB::new(0),
385            lrb: LRB::new(0),
386            stats: SolverStats::default(),
387            learnt: SmallVec::new(),
388            seen: Vec::new(),
389            analyze_stack: Vec::new(),
390            assertion_levels: vec![0],
391            assertion_trail_sizes: vec![0],
392            assertion_clause_ids: vec![Vec::new()],
393            model: Vec::new(),
394            trivially_unsat: false,
395            phase: Vec::new(),
396            luby_index: 0,
397            level_marks: Vec::new(),
398            lbd_mark: 0,
399            learned_clause_ids: Vec::new(),
400            conflicts_since_deletion: 0,
401            rng_state: 0x853c_49e6_748f_ea9b, // Random seed
402            recent_lbd_sum: 0,
403            recent_lbd_count: 0,
404            binary_graph: BinaryImplicationGraph::new(0),
405            global_lbd_sum: 0,
406            global_lbd_count: 0,
407            conflicts_since_local_restart: 0,
408            conflicts_since_inprocessing: 0,
409            chrono_backtrack: ChronoBacktrack::new(chrono_enabled, chrono_threshold),
410            clause_bump_increment: 1.0,
411            memory_optimizer: MemoryOptimizer::new(),
412        }
413    }
414
415    /// Create a new variable
416    pub fn new_var(&mut self) -> Var {
417        let var = Var::new(self.num_vars as u32);
418        self.num_vars += 1;
419        self.trail.resize(self.num_vars);
420        self.watches.resize(self.num_vars);
421        self.binary_graph.resize(self.num_vars);
422        self.vsids.insert(var);
423        self.chb.insert(var);
424        self.lrb.resize(self.num_vars);
425        self.seen.resize(self.num_vars, false);
426        self.model.resize(self.num_vars, LBool::Undef);
427        self.phase.resize(self.num_vars, false); // Default phase: negative
428        // Resize level_marks to at least num_vars (enough for decision levels)
429        if self.level_marks.len() < self.num_vars {
430            self.level_marks.resize(self.num_vars, 0);
431        }
432        var
433    }
434
435    /// Ensure we have at least n variables
436    pub fn ensure_vars(&mut self, n: usize) {
437        while self.num_vars < n {
438            self.new_var();
439        }
440    }
441
442    /// Add a clause
443    pub fn add_clause(&mut self, lits: impl IntoIterator<Item = Lit>) -> bool {
444        let mut clause_lits: SmallVec<[Lit; 8]> = lits.into_iter().collect();
445
446        // Ensure we have all variables
447        for lit in &clause_lits {
448            let var_idx = lit.var().index();
449            if var_idx >= self.num_vars {
450                self.ensure_vars(var_idx + 1);
451            }
452        }
453
454        // Remove duplicates and check for tautology
455        clause_lits.sort_by_key(|l| l.code());
456        clause_lits.dedup();
457
458        // Check for tautology (x and ~x in same clause)
459        for i in 0..clause_lits.len() {
460            for j in (i + 1)..clause_lits.len() {
461                if clause_lits[i] == clause_lits[j].negate() {
462                    return true; // Tautology - always satisfied
463                }
464            }
465        }
466
467        // Handle special cases
468        match clause_lits.len() {
469            0 => {
470                self.trivially_unsat = true;
471                return false; // Empty clause - unsat
472            }
473            1 => {
474                // Unit clause - enqueue at decision level 0
475                // Unit clauses must be assigned at level 0 to survive backtracking.
476                // After solve(), current_level may be > 0, so we must backtrack first.
477                let lit = clause_lits[0];
478
479                if self.trail.lit_value(lit).is_false() {
480                    // The literal conflicts with the current trail.
481                    // Check if the conflict is at decision level 0 (permanent constraint)
482                    // or from a previous solve (can be retried after backtrack).
483                    let var = lit.var();
484                    let level = self.trail.level(var);
485                    if level == 0 {
486                        // Conflict with a level-0 assignment - truly UNSAT
487                        self.trivially_unsat = true;
488                        return false;
489                    } else {
490                        // Conflict with higher-level assignment from previous solve.
491                        // Backtrack to root and assign the new unit literal at level 0.
492                        self.backtrack_to_root();
493                        self.trail.assign_decision(lit);
494                        return true;
495                    }
496                }
497
498                if self.trail.lit_value(lit).is_true() {
499                    // Already satisfied - check if at level 0
500                    let var = lit.var();
501                    let level = self.trail.level(var);
502                    if level == 0 {
503                        // Already assigned at level 0, nothing to do
504                        return true;
505                    }
506                    // Assigned at higher level - backtrack and reassign at level 0
507                    self.backtrack_to_root();
508                    self.trail.assign_decision(lit);
509                    return true;
510                }
511
512                // Variable is unassigned - backtrack to level 0 first to ensure
513                // the assignment is at level 0 (survives future backtracks)
514                if self.trail.decision_level() > 0 {
515                    self.backtrack_to_root();
516                }
517                self.trail.assign_decision(lit);
518                return true;
519            }
520            2 => {
521                // Binary clause - check if it conflicts with current assignment
522                let lit0 = clause_lits[0];
523                let lit1 = clause_lits[1];
524                let val0 = self.trail.lit_value(lit0);
525                let val1 = self.trail.lit_value(lit1);
526
527                // If clause is satisfied, just add it
528                if val0.is_true() || val1.is_true() {
529                    // Clause already satisfied by current assignment
530                    let clause_id = self.clauses.add_original(clause_lits.iter().copied());
531                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
532                        current_level_clauses.push(clause_id);
533                    }
534                    self.binary_graph.add(lit0.negate(), lit1, clause_id);
535                    self.binary_graph.add(lit1.negate(), lit0, clause_id);
536                    self.watches
537                        .add(lit0.negate(), Watcher::new(clause_id, lit1));
538                    self.watches
539                        .add(lit1.negate(), Watcher::new(clause_id, lit0));
540                    return true;
541                }
542
543                // If both literals are false, we have a conflict
544                if val0.is_false() && val1.is_false() {
545                    // Check if both are at level 0
546                    let level0 = self.trail.level(lit0.var());
547                    let level1 = self.trail.level(lit1.var());
548
549                    if level0 == 0 && level1 == 0 {
550                        // Conflict at level 0 - UNSAT
551                        self.trivially_unsat = true;
552                        return false;
553                    }
554
555                    // Backtrack to level 0 and add clause
556                    // The clause will be propagated on next solve()
557                    self.backtrack_to_root();
558                }
559
560                // If one literal is false and one undefined, propagate
561                // after adding the clause (via next solve())
562
563                let clause_id = self.clauses.add_original(clause_lits.iter().copied());
564                if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
565                    current_level_clauses.push(clause_id);
566                }
567                self.binary_graph.add(lit0.negate(), lit1, clause_id);
568                self.binary_graph.add(lit1.negate(), lit0, clause_id);
569                self.watches
570                    .add(lit0.negate(), Watcher::new(clause_id, lit1));
571                self.watches
572                    .add(lit1.negate(), Watcher::new(clause_id, lit0));
573                return true;
574            }
575            _ => {}
576        }
577
578        // Add clause (3+ literals)
579        // Check if clause is satisfied or conflicts with current assignment
580        let num_false = clause_lits
581            .iter()
582            .filter(|&l| self.trail.lit_value(*l).is_false())
583            .count();
584        let has_true = clause_lits
585            .iter()
586            .any(|l| self.trail.lit_value(*l).is_true());
587
588        if !has_true && num_false == clause_lits.len() {
589            // All literals are false - conflict
590            // Check if all at level 0
591            let all_at_zero = clause_lits.iter().all(|l| self.trail.level(l.var()) == 0);
592            if all_at_zero {
593                self.trivially_unsat = true;
594                return false;
595            }
596            // Backtrack to level 0
597            self.backtrack_to_root();
598        }
599
600        let clause_id = self.clauses.add_original(clause_lits.iter().copied());
601
602        // Track clause for incremental solving
603        if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
604            current_level_clauses.push(clause_id);
605        }
606
607        // Set up watches - prefer non-false literals for watching
608        let lit0 = clause_lits[0];
609        let lit1 = clause_lits[1];
610
611        self.watches
612            .add(lit0.negate(), Watcher::new(clause_id, lit1));
613        self.watches
614            .add(lit1.negate(), Watcher::new(clause_id, lit0));
615
616        true
617    }
618
619    /// Add a clause from DIMACS literals
620    pub fn add_clause_dimacs(&mut self, lits: &[i32]) -> bool {
621        self.add_clause(lits.iter().map(|&l| Lit::from_dimacs(l)))
622    }
623
    /// Solve the SAT problem
    ///
    /// Runs the CDCL loop: propagate; on a conflict, analyze it, backtrack and
    /// learn a clause; otherwise pick a branching variable and decide on it.
    ///
    /// Returns `SolverResult::Unsat` if the solver is already marked trivially
    /// unsatisfiable, if the initial propagation reports a conflict, or if a
    /// conflict occurs at decision level 0. Returns `SolverResult::Sat` (after
    /// saving the model) once no branching variable is left. This function
    /// has no path that returns `SolverResult::Unknown`.
    pub fn solve(&mut self) -> SolverResult {
        // Check if trivially unsatisfiable
        if self.trivially_unsat {
            return SolverResult::Unsat;
        }

        // Initial propagation
        // NOTE(review): a conflict here returns Unsat without checking the
        // decision level — confirm the trail is always at level 0 (or fully
        // propagated) when solve() is re-entered after a previous call.
        if self.propagate().is_some() {
            return SolverResult::Unsat;
        }

        loop {
            // Propagate
            if let Some(conflict) = self.propagate() {
                self.stats.conflicts += 1;
                self.conflicts_since_inprocessing += 1;

                // A conflict with no decisions to undo proves unsatisfiability
                if self.trail.decision_level() == 0 {
                    return SolverResult::Unsat;
                }

                // Analyze conflict
                let (backtrack_level, learnt_clause) = self.analyze(conflict);

                // Backtrack with phase saving
                self.backtrack_with_phase_saving(backtrack_level);

                // Learn clause
                if learnt_clause.len() == 1 {
                    // Store unit learned clause in database for persistence
                    let clause_id = self.clauses.add_learned(learnt_clause.iter().copied());
                    self.stats.learned_clauses += 1;
                    self.stats.unit_clauses += 1;
                    self.learned_clause_ids.push(clause_id);

                    // Track for incremental solving
                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
                        current_level_clauses.push(clause_id);
                    }

                    // A unit clause needs no watches; assert its literal directly
                    self.trail.assign_decision(learnt_clause[0]);
                } else {
                    // Compute LBD for the learned clause
                    let lbd = self.compute_lbd(&learnt_clause);

                    // Track recent LBD for Glucose-style and local restarts
                    self.recent_lbd_sum += u64::from(lbd);
                    self.recent_lbd_count += 1;
                    self.global_lbd_sum += u64::from(lbd);
                    self.global_lbd_count += 1;

                    // Halve the recent LBD sum and count once 5000 samples
                    // have accumulated
                    if self.recent_lbd_count >= 5000 {
                        self.recent_lbd_sum /= 2;
                        self.recent_lbd_count /= 2;
                    }

                    let clause_id = self.clauses.add_learned(learnt_clause.iter().copied());
                    self.stats.learned_clauses += 1;
                    // NOTE(review): stats.binary_clauses and stats.total_lbd
                    // are not updated in this function — presumably handled
                    // elsewhere; confirm.

                    // Set LBD score for the clause
                    if let Some(clause) = self.clauses.get_mut(clause_id) {
                        clause.lbd = lbd;
                    }

                    // Track learned clause for potential deletion
                    self.learned_clause_ids.push(clause_id);

                    // Track clause for incremental solving
                    if let Some(current_level_clauses) = self.assertion_clause_ids.last_mut() {
                        current_level_clauses.push(clause_id);
                    }

                    // Watch first two literals
                    let lit0 = learnt_clause[0];
                    let lit1 = learnt_clause[1];
                    self.watches
                        .add(lit0.negate(), Watcher::new(clause_id, lit1));
                    self.watches
                        .add(lit1.negate(), Watcher::new(clause_id, lit0));

                    // Propagate the asserting literal, with the new clause as
                    // its reason
                    self.trail.assign_propagation(learnt_clause[0], clause_id);
                }

                // Decay activities
                self.vsids.decay();
                self.chb.decay();
                self.lrb.decay();
                self.lrb.on_conflict();
                self.clauses.decay_activity(self.config.clause_decay);
                // Increase clause bump increment (inverse of decay)
                self.clause_bump_increment /= self.config.clause_decay;

                // Track conflicts for clause deletion
                self.conflicts_since_deletion += 1;

                // Periodic clause database reduction
                if self.conflicts_since_deletion >= self.config.clause_deletion_threshold as u64 {
                    self.reduce_clause_database();
                    self.conflicts_since_deletion = 0;

                    // Vivification after clause database reduction: only when
                    // the restart count is a multiple of 10 and the trail is
                    // currently at level 0
                    if self.stats.restarts.is_multiple_of(10) {
                        let saved_level = self.trail.decision_level();
                        if saved_level == 0 {
                            self.vivify_clauses();
                        }
                    }
                }

                // Check for restart
                if self.stats.conflicts >= self.restart_threshold {
                    self.restart();
                }

                // Periodic inprocessing
                if self.config.enable_inprocessing
                    && self.conflicts_since_inprocessing >= self.config.inprocessing_interval
                {
                    self.inprocess();
                    self.conflicts_since_inprocessing = 0;
                }
            } else {
                // No conflict - try to decide
                if let Some(var) = self.pick_branch_var() {
                    self.stats.decisions += 1;
                    self.trail.new_decision_level();

                    // Use phase saving with random polarity
                    let polarity = if self.rand_bool(self.config.random_polarity_prob) {
                        // Random polarity
                        self.rand_bool(0.5)
                    } else {
                        // Saved phase
                        self.phase[var.index()]
                    };
                    let lit = if polarity {
                        Lit::pos(var)
                    } else {
                        Lit::neg(var)
                    };
                    self.trail.assign_decision(lit);
                } else {
                    // All variables assigned - SAT
                    self.save_model();
                    return SolverResult::Sat;
                }
            }
        }
    }
776
777    /// Solve with assumptions and return unsat core if UNSAT
778    ///
779    /// This is the key method for MaxSAT: it solves under assumptions and
780    /// if the result is UNSAT, returns the subset of assumptions in the core.
781    ///
782    /// # Arguments
783    /// * `assumptions` - Literals that must be true
784    ///
785    /// # Returns
786    /// * `(SolverResult, Option<Vec<Lit>>)` - Result and unsat core (if UNSAT)
787    pub fn solve_with_assumptions(
788        &mut self,
789        assumptions: &[Lit],
790    ) -> (SolverResult, Option<Vec<Lit>>) {
791        if self.trivially_unsat {
792            return (SolverResult::Unsat, Some(Vec::new()));
793        }
794
795        // Ensure all assumption variables exist
796        for &lit in assumptions {
797            while self.num_vars <= lit.var().index() {
798                self.new_var();
799            }
800        }
801
802        // Initial propagation at level 0
803        if self.propagate().is_some() {
804            return (SolverResult::Unsat, Some(Vec::new()));
805        }
806
807        // Create a new decision level for assumptions
808        let assumption_level_start = self.trail.decision_level();
809
810        // Assign assumptions as decisions
811        for (i, &lit) in assumptions.iter().enumerate() {
812            // Check if already assigned
813            let value = self.trail.lit_value(lit);
814            if value.is_true() {
815                continue; // Already satisfied
816            }
817            if value.is_false() {
818                // Conflict with assumption - extract core from conflicting assumptions
819                let core = self.extract_assumption_core(assumptions, i);
820                self.backtrack(assumption_level_start);
821                return (SolverResult::Unsat, Some(core));
822            }
823
824            // Make decision for assumption
825            self.trail.new_decision_level();
826            self.trail.assign_decision(lit);
827
828            // Propagate after each assumption
829            if let Some(_conflict) = self.propagate() {
830                // Conflict during assumption propagation
831                let core = self.analyze_assumption_conflict(assumptions);
832                self.backtrack(assumption_level_start);
833                return (SolverResult::Unsat, Some(core));
834            }
835        }
836
837        // Now solve normally
838        loop {
839            if let Some(conflict) = self.propagate() {
840                self.stats.conflicts += 1;
841
842                // Check if conflict involves assumptions
843                let backtrack_level = self.analyze_conflict_level(conflict);
844
845                if backtrack_level <= assumption_level_start {
846                    // Conflict forces backtracking past assumptions - UNSAT
847                    let core = self.analyze_assumption_conflict(assumptions);
848                    self.backtrack(assumption_level_start);
849                    return (SolverResult::Unsat, Some(core));
850                }
851
852                let (bt_level, learnt_clause) = self.analyze(conflict);
853                self.backtrack_with_phase_saving(bt_level.max(assumption_level_start + 1));
854                self.learn_clause(learnt_clause);
855
856                self.vsids.decay();
857                self.clauses.decay_activity(self.config.clause_decay);
858                self.handle_clause_deletion_and_restart_limited(assumption_level_start);
859            } else {
860                // No conflict - try to decide
861                if let Some(var) = self.pick_branch_var() {
862                    self.stats.decisions += 1;
863                    self.trail.new_decision_level();
864
865                    let polarity = if self.rand_bool(self.config.random_polarity_prob) {
866                        self.rand_bool(0.5)
867                    } else {
868                        self.phase.get(var.index()).copied().unwrap_or(false)
869                    };
870                    let lit = if polarity {
871                        Lit::pos(var)
872                    } else {
873                        Lit::neg(var)
874                    };
875                    self.trail.assign_decision(lit);
876                } else {
877                    // All variables assigned - SAT
878                    self.save_model();
879                    self.backtrack(assumption_level_start);
880                    return (SolverResult::Sat, None);
881                }
882            }
883        }
884    }
885
886    /// Solve with theory integration via callbacks
887    ///
888    /// This implements the CDCL(T) loop:
889    /// 1. BCP (Boolean Constraint Propagation)
890    /// 2. Theory propagation (via callback)
891    /// 3. On conflict: analyze and learn
892    /// 4. Decision
893    /// 5. Final theory check when all vars assigned
894    pub fn solve_with_theory<T: TheoryCallback>(&mut self, theory: &mut T) -> SolverResult {
895        if self.trivially_unsat {
896            return SolverResult::Unsat;
897        }
898
899        // Initial propagation
900        if self.propagate().is_some() {
901            return SolverResult::Unsat;
902        }
903
904        // Track how many assignments have been sent to the theory.
905        // We only send NEW assignments (not previously processed ones) to avoid
906        // duplicate theory constraints that would cause spurious UNSAT.
907        let mut theory_processed: usize = 0;
908
909        loop {
910            // Boolean propagation
911            if let Some(conflict) = self.propagate() {
912                self.stats.conflicts += 1;
913
914                if self.trail.decision_level() == 0 {
915                    return SolverResult::Unsat;
916                }
917
918                let (backtrack_level, learnt_clause) = self.analyze(conflict);
919                theory.on_backtrack(backtrack_level);
920                self.backtrack_with_phase_saving(backtrack_level);
921                // After backtrack, the trail may be shorter; update processed count
922                theory_processed = theory_processed.min(self.trail.assignments().len());
923                self.learn_clause(learnt_clause);
924
925                self.vsids.decay();
926                self.clauses.decay_activity(self.config.clause_decay);
927                self.handle_clause_deletion_and_restart();
928                continue;
929            }
930
931            // Theory propagation check after each assignment
932            loop {
933                // Get only NEW (unprocessed) assignments and notify theory
934                let assignments = self.trail.assignments().to_vec();
935                let mut theory_conflict = None;
936                let mut theory_propagations = Vec::new();
937
938                // Check only NEW assignments with theory (skip already-processed ones).
939                // Guard against stale theory_processed after backtracks/restarts.
940                let safe_start = theory_processed.min(assignments.len());
941                for &lit in &assignments[safe_start..] {
942                    match theory.on_assignment(lit) {
943                        TheoryCheckResult::Sat => {}
944                        TheoryCheckResult::Conflict(conflict_lits) => {
945                            theory_conflict = Some(conflict_lits);
946                            break;
947                        }
948                        TheoryCheckResult::Propagated(props) => {
949                            theory_propagations.extend(props);
950                        }
951                    }
952                }
953                // Update processed count
954                theory_processed = assignments.len();
955
956                // Handle theory conflict
957                if let Some(conflict_lits) = theory_conflict {
958                    self.stats.conflicts += 1;
959
960                    if self.trail.decision_level() == 0 {
961                        return SolverResult::Unsat;
962                    }
963
964                    let (backtrack_level, learnt_clause) =
965                        self.analyze_theory_conflict(&conflict_lits);
966
967                    // Empty learned clause signals all-level-0 conflict = fundamental UNSAT
968                    if learnt_clause.is_empty() {
969                        self.trivially_unsat = true;
970                        return SolverResult::Unsat;
971                    }
972
973                    theory.on_backtrack(backtrack_level);
974                    self.backtrack_with_phase_saving(backtrack_level);
975                    // After backtrack, update theory_processed to trail length
976                    theory_processed = theory_processed.min(self.trail.assignments().len());
977                    self.learn_clause(learnt_clause);
978
979                    self.vsids.decay();
980                    self.clauses.decay_activity(self.config.clause_decay);
981                    self.handle_clause_deletion_and_restart();
982                    continue;
983                }
984
985                // Handle theory propagations
986                let mut made_propagation = false;
987                for (lit, reason_lits) in theory_propagations {
988                    if !self.trail.is_assigned(lit.var()) {
989                        // Add reason clause and propagate
990                        let clause_id = self.add_theory_reason_clause(&reason_lits, lit);
991                        self.trail.assign_propagation(lit, clause_id);
992                        made_propagation = true;
993                    }
994                }
995
996                if made_propagation {
997                    // Re-run Boolean propagation
998                    if let Some(conflict) = self.propagate() {
999                        self.stats.conflicts += 1;
1000
1001                        if self.trail.decision_level() == 0 {
1002                            return SolverResult::Unsat;
1003                        }
1004
1005                        let (backtrack_level, learnt_clause) = self.analyze(conflict);
1006                        theory.on_backtrack(backtrack_level);
1007                        self.backtrack_with_phase_saving(backtrack_level);
1008                        // After backtrack, the trail is shorter; update processed count
1009                        theory_processed = theory_processed.min(self.trail.assignments().len());
1010                        self.learn_clause(learnt_clause);
1011
1012                        self.vsids.decay();
1013                        self.clauses.decay_activity(self.config.clause_decay);
1014                        self.handle_clause_deletion_and_restart();
1015                    }
1016                    continue;
1017                }
1018
1019                break;
1020            }
1021
1022            // Try to decide
1023            if let Some(var) = self.pick_branch_var() {
1024                self.stats.decisions += 1;
1025                self.trail.new_decision_level();
1026                let new_level = self.trail.decision_level();
1027                theory.on_new_level(new_level);
1028
1029                let polarity = if self.rand_bool(self.config.random_polarity_prob) {
1030                    self.rand_bool(0.5)
1031                } else {
1032                    self.phase[var.index()]
1033                };
1034                let lit = if polarity {
1035                    Lit::pos(var)
1036                } else {
1037                    Lit::neg(var)
1038                };
1039                self.trail.assign_decision(lit);
1040            } else {
1041                // All variables assigned - do final theory check
1042                match theory.final_check() {
1043                    TheoryCheckResult::Sat => {
1044                        self.save_model();
1045                        return SolverResult::Sat;
1046                    }
1047                    TheoryCheckResult::Conflict(conflict_lits) => {
1048                        self.stats.conflicts += 1;
1049
1050                        if self.trail.decision_level() == 0 {
1051                            return SolverResult::Unsat;
1052                        }
1053
1054                        let (backtrack_level, learnt_clause) =
1055                            self.analyze_theory_conflict(&conflict_lits);
1056
1057                        // If all conflict literals are at level 0, analyze_theory_conflict
1058                        // returns an empty learned clause as a signal of fundamental UNSAT.
1059                        if learnt_clause.is_empty() {
1060                            self.trivially_unsat = true;
1061                            return SolverResult::Unsat;
1062                        }
1063
1064                        theory.on_backtrack(backtrack_level);
1065                        self.backtrack_with_phase_saving(backtrack_level);
1066                        // After backtrack, update theory_processed
1067                        theory_processed = theory_processed.min(self.trail.assignments().len());
1068                        self.learn_clause(learnt_clause);
1069
1070                        self.vsids.decay();
1071                        self.clauses.decay_activity(self.config.clause_decay);
1072                        self.handle_clause_deletion_and_restart();
1073                    }
1074                    TheoryCheckResult::Propagated(props) => {
1075                        // Handle late propagations
1076                        for (lit, reason_lits) in props {
1077                            if !self.trail.is_assigned(lit.var()) {
1078                                let clause_id = self.add_theory_reason_clause(&reason_lits, lit);
1079                                self.trail.assign_propagation(lit, clause_id);
1080                            }
1081                        }
1082                    }
1083                }
1084            }
1085        }
1086    }
1087
1088    /// Get the model (if sat)
1089    #[must_use]
1090    pub fn model(&self) -> &[LBool] {
1091        &self.model
1092    }
1093
1094    /// Get the value of a variable in the model
1095    #[must_use]
1096    pub fn model_value(&self, var: Var) -> LBool {
1097        self.model.get(var.index()).copied().unwrap_or(LBool::Undef)
1098    }
1099
1100    /// Get statistics
1101    #[must_use]
1102    pub fn stats(&self) -> &SolverStats {
1103        &self.stats
1104    }
1105
1106    /// Get memory optimizer statistics
1107    #[must_use]
1108    pub fn memory_opt_stats(&self) -> &crate::memory_opt::MemoryOptStats {
1109        self.memory_optimizer.stats()
1110    }
1111
1112    /// Get number of variables
1113    #[must_use]
1114    pub fn num_vars(&self) -> usize {
1115        self.num_vars
1116    }
1117
1118    /// Get number of clauses
1119    #[must_use]
1120    pub fn num_clauses(&self) -> usize {
1121        self.clauses.len()
1122    }
1123
1124    /// Push a new assertion level (for incremental solving)
1125    ///
1126    /// This saves the current state so that clauses added after this point
1127    /// can be removed with pop(). Automatically backtracks to decision level 0
1128    /// to ensure a clean state for adding new constraints.
1129    pub fn push(&mut self) {
1130        // Backtrack to level 0 to ensure clean state
1131        // This is necessary because solve() may leave assignments on the trail
1132        // Use phase-saving backtrack to properly re-insert variables into decision heaps
1133        self.backtrack_with_phase_saving(0);
1134
1135        self.assertion_levels.push(self.clauses.num_original());
1136        self.assertion_trail_sizes.push(self.trail.size());
1137        self.assertion_clause_ids.push(Vec::new());
1138    }
1139
1140    /// Pop to previous assertion level
1141    pub fn pop(&mut self) {
1142        if self.assertion_levels.len() > 1 {
1143            self.assertion_levels.pop();
1144
1145            // Get the trail size to backtrack to
1146            let trail_size = self.assertion_trail_sizes.pop().unwrap_or(0);
1147
1148            // Remove all clauses added at this assertion level
1149            if let Some(clause_ids_to_remove) = self.assertion_clause_ids.pop() {
1150                for clause_id in clause_ids_to_remove {
1151                    // Remove from clause database
1152                    self.clauses.remove(clause_id);
1153
1154                    // Remove from learned clause tracking if it's a learned clause
1155                    self.learned_clause_ids.retain(|&id| id != clause_id);
1156
1157                    // Note: Watch lists will be cleaned up naturally during propagation
1158                    // as they check if clauses are deleted before using them
1159                }
1160            }
1161
1162            // Backtrack trail to the exact size it was at push()
1163            // This properly handles unit clauses that were added after push
1164            // Note: backtrack_to_size clears values but doesn't re-insert into heaps,
1165            // so we need to manually re-insert unassigned variables.
1166            let current_size = self.trail.size();
1167            if current_size > trail_size {
1168                // Collect variables that will be unassigned
1169                let mut unassigned_vars = Vec::new();
1170                for i in trail_size..current_size {
1171                    let lit = self.trail.assignments()[i];
1172                    unassigned_vars.push(lit.var());
1173                }
1174
1175                self.trail.backtrack_to_size(trail_size);
1176
1177                // Re-insert unassigned variables into decision heaps
1178                for var in unassigned_vars {
1179                    if !self.vsids.contains(var) {
1180                        self.vsids.insert(var);
1181                    }
1182                    if !self.chb.contains(var) {
1183                        self.chb.insert(var);
1184                    }
1185                    self.lrb.unassign(var);
1186                }
1187            }
1188
1189            // Ensure we're at decision level 0 with proper heap re-insertion
1190            self.backtrack_with_phase_saving(0);
1191
1192            // Clear the trivially_unsat flag as we've removed problematic clauses
1193            self.trivially_unsat = false;
1194        }
1195    }
1196
1197    /// Backtrack to decision level 0 (for AllSAT enumeration)
1198    ///
1199    /// This is necessary after a SAT result before adding blocking clauses
1200    /// to ensure the new clauses can trigger propagation correctly.
1201    /// Uses phase-saving backtrack to properly re-insert unassigned variables
1202    /// into the decision heaps (VSIDS, CHB, LRB).
1203    pub fn backtrack_to_root(&mut self) {
1204        self.backtrack_with_phase_saving(0);
1205    }
1206
1207    /// Reset the solver
1208    pub fn reset(&mut self) {
1209        self.clauses = ClauseDatabase::new();
1210        self.trail.clear();
1211        self.watches.clear();
1212        self.vsids.clear();
1213        self.chb.clear();
1214        self.stats = SolverStats::default();
1215        self.learnt.clear();
1216        self.seen.clear();
1217        self.analyze_stack.clear();
1218        self.assertion_levels.clear();
1219        self.assertion_levels.push(0);
1220        self.assertion_trail_sizes.clear();
1221        self.assertion_trail_sizes.push(0);
1222        self.assertion_clause_ids.clear();
1223        self.assertion_clause_ids.push(Vec::new());
1224        self.model.clear();
1225        self.num_vars = 0;
1226        self.restart_threshold = self.config.restart_interval;
1227        self.trivially_unsat = false;
1228        self.phase.clear();
1229        self.luby_index = 0;
1230        self.level_marks.clear();
1231        self.lbd_mark = 0;
1232        self.learned_clause_ids.clear();
1233        self.conflicts_since_deletion = 0;
1234        self.rng_state = 0x853c_49e6_748f_ea9b;
1235        self.recent_lbd_sum = 0;
1236        self.recent_lbd_count = 0;
1237        self.binary_graph.clear();
1238        self.global_lbd_sum = 0;
1239        self.global_lbd_count = 0;
1240        self.conflicts_since_local_restart = 0;
1241    }
1242
1243    /// Get the current trail (for theory solvers)
1244    #[must_use]
1245    pub fn trail(&self) -> &Trail {
1246        &self.trail
1247    }
1248
1249    /// Get the current decision level
1250    #[must_use]
1251    pub fn decision_level(&self) -> u32 {
1252        self.trail.decision_level()
1253    }
1254
1255    /// Debug method: print all learned clauses
1256    pub fn debug_print_learned_clauses(&self) {
1257        println!(
1258            "=== Learned Clauses ({}) ===",
1259            self.learned_clause_ids.len()
1260        );
1261        for (i, &cid) in self.learned_clause_ids.iter().enumerate() {
1262            if let Some(clause) = self.clauses.get(cid)
1263                && !clause.deleted
1264            {
1265                let lits: Vec<String> = clause
1266                    .lits
1267                    .iter()
1268                    .map(|lit| {
1269                        let var = lit.var().index();
1270                        if lit.is_pos() {
1271                            format!("v{}", var)
1272                        } else {
1273                            format!("~v{}", var)
1274                        }
1275                    })
1276                    .collect();
1277                println!(
1278                    "  Learned {}: ({}), LBD={}",
1279                    i,
1280                    lits.join(" | "),
1281                    clause.lbd
1282                );
1283            }
1284        }
1285    }
1286
1287    /// Debug method: print binary implication graph entries
1288    pub fn debug_print_binary_graph(&self) {
1289        println!("=== Binary Implication Graph ===");
1290        for lit_code in 0..(self.num_vars * 2) {
1291            let lit = Lit::from_code(lit_code as u32);
1292            let implications = self.binary_graph.get(lit);
1293            if !implications.is_empty() {
1294                let lit_str = if lit.is_pos() {
1295                    format!("v{}", lit.var().index())
1296                } else {
1297                    format!("~v{}", lit.var().index())
1298                };
1299                for &(implied, _cid) in implications {
1300                    let impl_str = if implied.is_pos() {
1301                        format!("v{}", implied.var().index())
1302                    } else {
1303                        format!("~v{}", implied.var().index())
1304                    };
1305                    println!("  {} -> {}", lit_str, impl_str);
1306                }
1307            }
1308        }
1309    }
1310}
1311
1312#[cfg(test)]
1313mod tests {
1314    use super::*;
1315
/// A solver with no variables and no clauses is satisfiable.
#[test]
fn test_empty_sat() {
    let outcome = Solver::new().solve();
    assert_eq!(outcome, SolverResult::Sat);
}
1321
/// (x | y) & (!x | y) is satisfiable and every model sets y.
#[test]
fn test_simple_sat() {
    let mut sat = Solver::new();
    let _x = sat.new_var();
    let _y = sat.new_var();

    sat.add_clause_dimacs(&[1, 2]); // x or y
    sat.add_clause_dimacs(&[-1, 2]); // not x or y

    let outcome = sat.solve();
    assert_eq!(outcome, SolverResult::Sat);

    // DIMACS variable 2 is the solver's variable index 1
    let y = Var::new(1);
    assert!(sat.model_value(y).is_true());
}
1336
/// Two contradictory unit clauses (x and !x) are unsatisfiable.
#[test]
fn test_simple_unsat() {
    let mut sat = Solver::new();
    let _x = sat.new_var();

    sat.add_clause_dimacs(&[1]); // x
    sat.add_clause_dimacs(&[-1]); // not x

    let outcome = sat.solve();
    assert_eq!(outcome, SolverResult::Unsat);
}
1349
/// Pigeonhole with 2 pigeons and 1 hole is unsatisfiable.
#[test]
fn test_pigeonhole_2_1() {
    let mut sat = Solver::new();
    let _pigeon1_hole1 = sat.new_var();
    let _pigeon2_hole1 = sat.new_var();

    // Every pigeon sits in the only hole
    sat.add_clause_dimacs(&[1]);
    sat.add_clause_dimacs(&[2]);

    // ...but the hole takes at most one pigeon
    sat.add_clause_dimacs(&[-1, -2]);

    let outcome = sat.solve();
    assert_eq!(outcome, SolverResult::Unsat);
}
1366
/// A fixed, hand-written 3-SAT instance over 10 variables that is satisfiable.
#[test]
fn test_3sat_random() {
    let mut sat = Solver::new();
    (0..10).for_each(|_| {
        sat.new_var();
    });

    let clauses: [&[_]; 8] = [
        &[1, 2, 3],
        &[-1, 4, 5],
        &[2, -3, 6],
        &[-4, 7, 8],
        &[5, -6, 9],
        &[-7, 8, 10],
        &[1, -8, -9],
        &[-2, 3, -10],
    ];
    for clause in clauses {
        sat.add_clause_dimacs(clause);
    }

    assert_eq!(sat.solve(), SolverResult::Sat);
}
1387
/// `Solver::luby` yields the Luby sequence 1, 1, 2, 1, 1, 2, 4, 1, ...
#[test]
fn test_luby_sequence() {
    // (index, expected value) for the first eight terms
    let expected = [
        (0, 1),
        (1, 1),
        (2, 2),
        (3, 1),
        (4, 1),
        (5, 2),
        (6, 4),
        (7, 1),
    ];
    for (index, value) in expected {
        assert_eq!(Solver::luby(index), value);
    }
}
1400
/// A small satisfiable instance over 5 variables; the solver must report SAT.
#[test]
fn test_phase_saving() {
    let mut sat = Solver::new();
    (0..5).for_each(|_| {
        sat.new_var();
    });

    // Chain of implications intended to exercise phase saving
    let clauses: [&[_]; 5] = [&[1, 2], &[-1, 3], &[-2, 4], &[-3, -4, 5], &[-5, 1]];
    for clause in clauses {
        sat.add_clause_dimacs(clause);
    }

    assert_eq!(sat.solve(), SolverResult::Sat);
}
1418
/// PHP(3,2) with a low clause-deletion threshold: the solver must report
/// UNSAT and must have hit at least one conflict on the way.
#[test]
fn test_lbd_computation() {
    // Small threshold so that clause deletion can kick in quickly
    let config = SolverConfig {
        clause_deletion_threshold: 5,
        ..SolverConfig::default()
    };
    let mut sat = Solver::with_config(config);

    (0..20).for_each(|_| {
        sat.new_var();
    });

    // 3 pigeons, 2 holes. DIMACS variables p_i_h (pigeon i in hole h):
    // p11=1, p12=2, p21=3, p22=4, p31=5, p32=6
    let clauses: [&[_]; 9] = [
        // Each pigeon sits in hole 1 or hole 2
        &[1, 2],
        &[3, 4],
        &[5, 6],
        // Hole 1 holds at most one pigeon
        &[-1, -3],
        &[-1, -5],
        &[-3, -5],
        // Hole 2 holds at most one pigeon
        &[-2, -4],
        &[-2, -6],
        &[-4, -6],
    ];
    for clause in clauses {
        sat.add_clause_dimacs(clause);
    }

    let outcome = sat.solve();
    assert_eq!(outcome, SolverResult::Unsat);
    // UNSAT here cannot be reached without conflicts (hence learned clauses)
    assert!(sat.stats().conflicts > 0);
}
1454
/// Three clauses over 10 variables; solving must succeed with SAT.
#[test]
fn test_clause_activity_decay() {
    let mut sat = Solver::new();
    (0..10).for_each(|_| {
        sat.new_var();
    });

    let clauses: [&[_]; 3] = [&[1, 2, 3], &[-1, 4, 5], &[-2, -3, 6]];
    for clause in clauses {
        sat.add_clause_dimacs(clause);
    }

    assert_eq!(sat.solve(), SolverResult::Sat);
}
1471
/// Partial 3-colouring instance over 15 variables; the result must be SAT.
///
/// Whether any conflicts (and therefore learned, minimized clauses) occur
/// depends on the decision heuristic; only the result is checked.
#[test]
fn test_clause_minimization() {
    let mut sat = Solver::new();
    (0..15).for_each(|_| {
        sat.new_var();
    });

    // Five vertices, three colours. DIMACS variables: R = 1..=5,
    // G = 6..=10, B = 11..=15 (vertex i uses i, i + 5, i + 10).
    let clauses: [&[_]; 20] = [
        // Each vertex has at least one colour
        &[1, 6, 11],
        &[2, 7, 12],
        &[3, 8, 13],
        &[4, 9, 14],
        &[5, 10, 15],
        // At most one colour, vertex 1
        &[-1, -6],
        &[-1, -11],
        &[-6, -11],
        // At most one colour, vertex 2
        &[-2, -7],
        &[-2, -12],
        &[-7, -12],
        // At most one colour, vertex 3
        &[-3, -8],
        &[-3, -13],
        &[-8, -13],
        // Edge 1-2: endpoints differ in R, G and B
        &[-1, -2],
        &[-6, -7],
        &[-11, -12],
        // Edge 2-3: endpoints differ in R, G and B
        &[-2, -3],
        &[-7, -8],
        &[-12, -13],
    ];
    for clause in clauses {
        sat.add_clause_dimacs(clause);
    }

    assert_eq!(sat.solve(), SolverResult::Sat);
}
1522
1523    /// A simple theory callback that does nothing (pure SAT)
1524    struct NullTheory;
1525
1526    impl TheoryCallback for NullTheory {
1527        fn on_assignment(&mut self, _lit: Lit) -> TheoryCheckResult {
1528            TheoryCheckResult::Sat
1529        }
1530
1531        fn final_check(&mut self) -> TheoryCheckResult {
1532            TheoryCheckResult::Sat
1533        }
1534
1535        fn on_backtrack(&mut self, _level: u32) {}
1536    }
1537
/// With a no-op theory, (x | y) & (!x | y) is SAT and y is true in the model.
#[test]
fn test_solve_with_theory_sat() {
    let mut sat = Solver::new();
    let mut no_theory = NullTheory;

    let _x = sat.new_var();
    let _y = sat.new_var();

    sat.add_clause_dimacs(&[1, 2]); // x or y
    sat.add_clause_dimacs(&[-1, 2]); // not x or y

    let outcome = sat.solve_with_theory(&mut no_theory);
    assert_eq!(outcome, SolverResult::Sat);

    // DIMACS variable 2 is the solver's variable index 1
    let y = Var::new(1);
    assert!(sat.model_value(y).is_true());
}
1554
/// With a no-op theory, the unit clauses x and !x are UNSAT.
#[test]
fn test_solve_with_theory_unsat() {
    let mut sat = Solver::new();
    let mut no_theory = NullTheory;

    let _x = sat.new_var();

    sat.add_clause_dimacs(&[1]); // x
    sat.add_clause_dimacs(&[-1]); // not x

    let outcome = sat.solve_with_theory(&mut no_theory);
    assert_eq!(outcome, SolverResult::Unsat);
}
1569
1570    /// A theory that forces x0 => x1 (if x0 is true, x1 must be true)
1571    struct ImplicationTheory {
1572        /// Track if x0 is assigned true
1573        x0_true: bool,
1574    }
1575
1576    impl ImplicationTheory {
1577        fn new() -> Self {
1578            Self { x0_true: false }
1579        }
1580    }
1581
1582    impl TheoryCallback for ImplicationTheory {
1583        fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult {
1584            // If x0 becomes true, propagate x1
1585            if lit.var().index() == 0 && lit.is_pos() {
1586                self.x0_true = true;
1587                // Propagate: x1 must be true because x0 is true
1588                // The reason is: ~x0 (if x0 were false, we wouldn't need x1)
1589                let reason: SmallVec<[Lit; 8]> = smallvec::smallvec![Lit::pos(Var::new(0))];
1590                return TheoryCheckResult::Propagated(vec![(Lit::pos(Var::new(1)), reason)]);
1591            }
1592            TheoryCheckResult::Sat
1593        }
1594
1595        fn final_check(&mut self) -> TheoryCheckResult {
1596            TheoryCheckResult::Sat
1597        }
1598
1599        fn on_backtrack(&mut self, _level: u32) {
1600            self.x0_true = false;
1601        }
1602    }
1603
1604    #[test]
1605    fn test_theory_propagation() {
1606        let mut solver = Solver::new();
1607        let mut theory = ImplicationTheory::new();
1608
1609        let _x0 = solver.new_var();
1610        let _x1 = solver.new_var();
1611
1612        // Force x0 to be true
1613        solver.add_clause_dimacs(&[1]);
1614
1615        let result = solver.solve_with_theory(&mut theory);
1616        assert_eq!(result, SolverResult::Sat);
1617
1618        // x0 should be true (forced by clause)
1619        assert!(solver.model_value(Var::new(0)).is_true());
1620        // x1 should also be true (propagated by theory)
1621        assert!(solver.model_value(Var::new(1)).is_true());
1622    }
1623
1624    /// Theory that says x0 and x1 can't both be true
1625    struct MutexTheory {
1626        x0_true: Option<Lit>,
1627        x1_true: Option<Lit>,
1628    }
1629
1630    impl MutexTheory {
1631        fn new() -> Self {
1632            Self {
1633                x0_true: None,
1634                x1_true: None,
1635            }
1636        }
1637    }
1638
1639    impl TheoryCallback for MutexTheory {
1640        fn on_assignment(&mut self, lit: Lit) -> TheoryCheckResult {
1641            if lit.var().index() == 0 && lit.is_pos() {
1642                self.x0_true = Some(lit);
1643            }
1644            if lit.var().index() == 1 && lit.is_pos() {
1645                self.x1_true = Some(lit);
1646            }
1647
1648            // If both are true, conflict
1649            if self.x0_true.is_some() && self.x1_true.is_some() {
1650                // Conflict clause: ~x0 or ~x1 (at least one must be false)
1651                let conflict: SmallVec<[Lit; 8]> = smallvec::smallvec![
1652                    Lit::pos(Var::new(0)), // x0 is true (we negate in conflict)
1653                    Lit::pos(Var::new(1))  // x1 is true
1654                ];
1655                return TheoryCheckResult::Conflict(conflict);
1656            }
1657            TheoryCheckResult::Sat
1658        }
1659
1660        fn final_check(&mut self) -> TheoryCheckResult {
1661            if self.x0_true.is_some() && self.x1_true.is_some() {
1662                let conflict: SmallVec<[Lit; 8]> =
1663                    smallvec::smallvec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))];
1664                return TheoryCheckResult::Conflict(conflict);
1665            }
1666            TheoryCheckResult::Sat
1667        }
1668
1669        fn on_backtrack(&mut self, _level: u32) {
1670            self.x0_true = None;
1671            self.x1_true = None;
1672        }
1673    }
1674
1675    #[test]
1676    fn test_theory_conflict() {
1677        let mut solver = Solver::new();
1678        let mut theory = MutexTheory::new();
1679
1680        let _x0 = solver.new_var();
1681        let _x1 = solver.new_var();
1682
1683        // Force both x0 and x1 to be true (should cause theory conflict)
1684        solver.add_clause_dimacs(&[1]);
1685        solver.add_clause_dimacs(&[2]);
1686
1687        let result = solver.solve_with_theory(&mut theory);
1688        assert_eq!(result, SolverResult::Unsat);
1689    }
1690
1691    #[test]
1692    fn test_solve_with_assumptions_sat() {
1693        let mut solver = Solver::new();
1694
1695        let x0 = solver.new_var();
1696        let x1 = solver.new_var();
1697
1698        // x0 \/ x1
1699        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1700
1701        // Assume x0 = true
1702        let assumptions = [Lit::pos(x0)];
1703        let (result, core) = solver.solve_with_assumptions(&assumptions);
1704
1705        assert_eq!(result, SolverResult::Sat);
1706        assert!(core.is_none());
1707    }
1708
1709    #[test]
1710    fn test_solve_with_assumptions_unsat() {
1711        let mut solver = Solver::new();
1712
1713        let x0 = solver.new_var();
1714        let x1 = solver.new_var();
1715
1716        // x0 -> ~x1 (encoded as ~x0 \/ ~x1)
1717        solver.add_clause([Lit::neg(x0), Lit::neg(x1)]);
1718
1719        // Assume both x0 = true and x1 = true (should be UNSAT)
1720        let assumptions = [Lit::pos(x0), Lit::pos(x1)];
1721        let (result, core) = solver.solve_with_assumptions(&assumptions);
1722
1723        assert_eq!(result, SolverResult::Unsat);
1724        assert!(core.is_some());
1725        let core = core.expect("UNSAT result must have conflict core");
1726        // Core should contain at least one of the conflicting assumptions
1727        assert!(!core.is_empty());
1728    }
1729
1730    #[test]
1731    fn test_solve_with_assumptions_core_extraction() {
1732        let mut solver = Solver::new();
1733
1734        let x0 = solver.new_var();
1735        let x1 = solver.new_var();
1736        let x2 = solver.new_var();
1737
1738        // ~x0 (x0 must be false)
1739        solver.add_clause([Lit::neg(x0)]);
1740
1741        // Assume x0 = true, x1 = true, x2 = true
1742        // Only x0 should be in the core
1743        let assumptions = [Lit::pos(x0), Lit::pos(x1), Lit::pos(x2)];
1744        let (result, core) = solver.solve_with_assumptions(&assumptions);
1745
1746        assert_eq!(result, SolverResult::Unsat);
1747        assert!(core.is_some());
1748        let core = core.expect("UNSAT result must have conflict core");
1749        // x0 should be in the core
1750        assert!(core.contains(&Lit::pos(x0)));
1751    }
1752
1753    #[test]
1754    fn test_solve_with_assumptions_incremental() {
1755        let mut solver = Solver::new();
1756
1757        let x0 = solver.new_var();
1758        let x1 = solver.new_var();
1759
1760        // x0 \/ x1
1761        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1762
1763        // First: assume ~x0 (should be SAT with x1 = true)
1764        let (result1, _) = solver.solve_with_assumptions(&[Lit::neg(x0)]);
1765        assert_eq!(result1, SolverResult::Sat);
1766
1767        // Second: assume ~x0 and ~x1 (should be UNSAT)
1768        let (result2, core2) = solver.solve_with_assumptions(&[Lit::neg(x0), Lit::neg(x1)]);
1769        assert_eq!(result2, SolverResult::Unsat);
1770        assert!(core2.is_some());
1771
1772        // Third: assume x0 (should be SAT again)
1773        let (result3, _) = solver.solve_with_assumptions(&[Lit::pos(x0)]);
1774        assert_eq!(result3, SolverResult::Sat);
1775    }
1776
1777    #[test]
1778    fn test_push_pop_simple() {
1779        let mut solver = Solver::new();
1780
1781        let x0 = solver.new_var();
1782
1783        // Should be SAT (x0 can be true or false)
1784        assert_eq!(solver.solve(), SolverResult::Sat);
1785
1786        // Push and add unit clause: x0
1787        solver.push();
1788        solver.add_clause([Lit::pos(x0)]);
1789        assert_eq!(solver.solve(), SolverResult::Sat);
1790        assert!(solver.model_value(x0).is_true());
1791
1792        // Pop - should be SAT again
1793        solver.pop();
1794        let result = solver.solve();
1795        assert_eq!(
1796            result,
1797            SolverResult::Sat,
1798            "After pop, expected SAT but got {:?}. trivially_unsat={}",
1799            result,
1800            solver.trivially_unsat
1801        );
1802    }
1803
1804    #[test]
1805    fn test_push_pop_incremental() {
1806        let mut solver = Solver::new();
1807
1808        let x0 = solver.new_var();
1809        let x1 = solver.new_var();
1810        let x2 = solver.new_var();
1811
1812        // Base level: x0 \/ x1
1813        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1814        assert_eq!(solver.solve(), SolverResult::Sat);
1815
1816        // Push and add: ~x0
1817        solver.push();
1818        solver.add_clause([Lit::neg(x0)]);
1819        assert_eq!(solver.solve(), SolverResult::Sat);
1820        // x1 must be true
1821        assert!(solver.model_value(x1).is_true());
1822
1823        // Push again and add: ~x1 (should be UNSAT)
1824        solver.push();
1825        solver.add_clause([Lit::neg(x1)]);
1826        assert_eq!(solver.solve(), SolverResult::Unsat);
1827
1828        // Pop back one level (remove ~x1, keep ~x0)
1829        solver.pop();
1830        assert_eq!(solver.solve(), SolverResult::Sat);
1831        assert!(solver.model_value(x1).is_true());
1832
1833        // Pop back to base level (remove ~x0)
1834        solver.pop();
1835        assert_eq!(solver.solve(), SolverResult::Sat);
1836        // Either x0 or x1 can be true now
1837
1838        // Push and add different clause: x0 /\ x2
1839        solver.push();
1840        solver.add_clause([Lit::pos(x0)]);
1841        solver.add_clause([Lit::pos(x2)]);
1842        assert_eq!(solver.solve(), SolverResult::Sat);
1843        assert!(solver.model_value(x0).is_true());
1844        assert!(solver.model_value(x2).is_true());
1845
1846        // Pop and verify clauses are removed
1847        solver.pop();
1848        assert_eq!(solver.solve(), SolverResult::Sat);
1849    }
1850
1851    #[test]
1852    fn test_push_pop_with_learned_clauses() {
1853        let mut solver = Solver::new();
1854
1855        let x0 = solver.new_var();
1856        let x1 = solver.new_var();
1857        let x2 = solver.new_var();
1858
1859        // Create a formula that will cause learning
1860        // (x0 \/ x1) /\ (~x0 \/ x2) /\ (~x1 \/ x2)
1861        solver.add_clause([Lit::pos(x0), Lit::pos(x1)]);
1862        solver.add_clause([Lit::neg(x0), Lit::pos(x2)]);
1863        solver.add_clause([Lit::neg(x1), Lit::pos(x2)]);
1864
1865        assert_eq!(solver.solve(), SolverResult::Sat);
1866
1867        // Push and add conflicting clause
1868        solver.push();
1869        solver.add_clause([Lit::neg(x2)]);
1870
1871        // This should be UNSAT and cause clause learning
1872        assert_eq!(solver.solve(), SolverResult::Unsat);
1873
1874        // Pop - learned clauses from this level should be removed
1875        solver.pop();
1876
1877        // Should be SAT again
1878        assert_eq!(solver.solve(), SolverResult::Sat);
1879    }
1880}