Skip to main content

oxiz_solver/conflict/
clause_learning.rs

1//! Clause Learning for CDCL(T) Solver.
2//!
3//! Implements sophisticated conflict-driven clause learning including:
4//! - First UIP computation
5//! - Conflict clause minimization
6//! - Asserting clause generation
7//! - Learned clause database management
8//! - Clause subsumption and strengthening
9
10#[allow(unused_imports)]
11use crate::prelude::*;
12use oxiz_core::ast::{TermId, TermManager};
13
/// Clause learning engine for CDCL.
///
/// Bundles the implication graph, the learned-clause database, the clause
/// minimizer, the configuration and the running statistics that conflict
/// analysis reads and updates.
pub struct ClauseLearner {
    /// Implication graph consulted for decision levels, decision flags and reasons
    impl_graph: ImplicationGraph,
    /// Learned clause database
    learned_db: LearnedDatabase,
    /// Clause minimization engine (scratch state used by recursive minimization)
    minimizer: ClauseMinimizer,
    /// Configuration
    config: ClauseLearningConfig,
    /// Statistics
    stats: ClauseLearningStats,
}
27
/// Implication graph for conflict analysis.
///
/// Holds one [`ImplicationNode`] per recorded variable, a separate
/// variable → level table, and the current decision level.
#[derive(Debug, Clone)]
pub struct ImplicationGraph {
    /// Nodes: variable → implication node
    nodes: FxHashMap<TermId, ImplicationNode>,
    /// Adjacency list: variable → predecessors.
    /// Created empty by `new`; `add_node` does not fill it in.
    #[allow(dead_code)]
    predecessors: FxHashMap<TermId, Vec<TermId>>,
    /// Decision levels: variable → level.
    /// `get_level` reports 0 for variables missing from this map.
    levels: FxHashMap<TermId, usize>,
    /// Current decision level
    current_level: usize,
}
41
/// Node in the implication graph.
///
/// One record per variable, as stored by `ImplicationGraph::add_node`.
#[derive(Debug, Clone)]
pub struct ImplicationNode {
    /// Variable
    pub var: TermId,
    /// Assigned value
    pub value: bool,
    /// Decision level
    pub level: usize,
    /// Reason clause (None for decisions)
    pub reason: Option<ClauseId>,
    /// Is this a decision variable?
    /// Stored independently of `reason`; the two are not cross-checked on insertion.
    pub is_decision: bool,
}
56
/// Clause identifier.
///
/// A plain index. For learned clauses, `LearnedDatabase::add_clause` uses the
/// clause's position in its clause vector as the identifier.
pub type ClauseId = usize;
59
/// Learned clause database.
///
/// `clauses` and `activity` are parallel vectors: the entry at index `i` of
/// each belongs to the clause whose identifier is `i`.
#[derive(Debug, Clone)]
pub struct LearnedDatabase {
    /// Learned clauses
    clauses: Vec<LearnedClause>,
    /// Activity scores, one per clause; `reduce` keeps the highest-scoring half
    activity: Vec<f64>,
    /// Clause to ID mapping, keyed by the exact literal vector (order-sensitive)
    clause_map: FxHashMap<Vec<TermId>, ClauseId>,
    /// Bump increment for activity (added to a clause's score on each bump)
    bump_increment: f64,
    /// Decay factor (`decay` divides the bump increment by this value)
    decay_factor: f64,
}
74
/// A learned clause.
#[derive(Debug, Clone)]
pub struct LearnedClause {
    /// Literals in the clause
    pub literals: Vec<TermId>,
    /// Asserting literal (first UIP)
    pub asserting_lit: TermId,
    /// Backtrack level
    pub backtrack_level: usize,
    /// Clause activity.
    /// Copied into the database's activity vector by `add_clause`; later bumps
    /// update that vector, not this field.
    pub activity: f64,
    /// Is this clause locked (in conflict analysis)?
    /// Locked clauses are skipped by subsumption and strengthening and are
    /// always kept by `LearnedDatabase::reduce`.
    pub locked: bool,
    /// Glue level (LBD): number of distinct decision levels among the literals
    pub lbd: usize,
}
91
/// Clause minimization engine.
///
/// Scratch buffers reused across calls to recursive minimization.
#[derive(Debug, Clone)]
pub struct ClauseMinimizer {
    /// Seen variables during minimization (reset at the start of each run)
    seen: FxHashSet<TermId>,
    /// Variables to analyze (reset at the start of each run)
    analyze_stack: Vec<TermId>,
    /// Minimization cache
    #[allow(dead_code)]
    cache: FxHashMap<TermId, bool>,
}
103
/// Configuration for clause learning.
#[derive(Debug, Clone)]
pub struct ClauseLearningConfig {
    /// Enable clause minimization
    pub enable_minimization: bool,
    /// Enable recursive minimization (only consulted when minimization is enabled)
    pub enable_recursive_minimization: bool,
    /// Enable clause subsumption (`subsume_clauses` is a no-op when false)
    pub enable_subsumption: bool,
    /// Enable clause strengthening (`strengthen_clauses` is a no-op when false)
    pub enable_strengthening: bool,
    /// Maximum clause size for learning.
    /// NOTE(review): not consulted by the code visible in this module.
    pub max_learned_size: usize,
    /// LBD threshold for keeping clauses.
    /// NOTE(review): not consulted by the code visible in this module.
    pub lbd_threshold: usize,
    /// Activity decay factor (handed to `LearnedDatabase::new`)
    pub activity_decay: f64,
}
122
123impl Default for ClauseLearningConfig {
124    fn default() -> Self {
125        Self {
126            enable_minimization: true,
127            enable_recursive_minimization: true,
128            enable_subsumption: true,
129            enable_strengthening: true,
130            max_learned_size: 1000,
131            lbd_threshold: 5,
132            activity_decay: 0.95,
133        }
134    }
135}
136
/// Clause learning statistics.
///
/// All counters start at zero (via `Default`) and only ever grow.
#[derive(Debug, Clone, Default)]
pub struct ClauseLearningStats {
    /// Conflicts analyzed (counted on entry, even if the analysis later fails)
    pub conflicts_analyzed: usize,
    /// Clauses learned
    pub clauses_learned: usize,
    /// Literals in learned clauses (before minimization), summed over all conflicts
    pub literals_before_minimization: usize,
    /// Literals after minimization, summed over all conflicts
    pub literals_after_minimization: usize,
    /// Clauses subsumed
    pub clauses_subsumed: usize,
    /// Clauses strengthened
    pub clauses_strengthened: usize,
    /// UIP computations (counted only when the computation succeeds)
    pub uip_computations: usize,
    /// Clause database reductions
    pub db_reductions: usize,
}
157
158impl ClauseLearner {
159    /// Create a new clause learner.
160    pub fn new(config: ClauseLearningConfig) -> Self {
161        Self {
162            impl_graph: ImplicationGraph::new(),
163            learned_db: LearnedDatabase::new(config.activity_decay),
164            minimizer: ClauseMinimizer::new(),
165            config,
166            stats: ClauseLearningStats::default(),
167        }
168    }
169
170    /// Analyze a conflict and learn a clause.
171    pub fn analyze_conflict(
172        &mut self,
173        conflict_clause: ClauseId,
174        _tm: &TermManager,
175    ) -> Result<LearnedClause, String> {
176        self.stats.conflicts_analyzed += 1;
177
178        // Build initial conflict clause
179        let conflict_lits = self.get_clause_literals(conflict_clause)?;
180
181        // Compute First UIP
182        let (learned_lits, asserting_lit, backtrack_level) =
183            self.compute_first_uip(&conflict_lits)?;
184
185        self.stats.uip_computations += 1;
186        self.stats.literals_before_minimization += learned_lits.len();
187
188        // Minimize clause
189        let minimized_lits = if self.config.enable_minimization {
190            self.minimize_clause(&learned_lits)?
191        } else {
192            learned_lits
193        };
194
195        self.stats.literals_after_minimization += minimized_lits.len();
196
197        // Compute LBD (Literal Block Distance)
198        let lbd = self.compute_lbd(&minimized_lits);
199
200        // Create learned clause
201        let learned = LearnedClause {
202            literals: minimized_lits,
203            asserting_lit,
204            backtrack_level,
205            activity: 0.0,
206            locked: false,
207            lbd,
208        };
209
210        self.stats.clauses_learned += 1;
211
212        // Add to database
213        self.learned_db.add_clause(learned.clone());
214
215        Ok(learned)
216    }
217
218    /// Compute First UIP (Unique Implication Point).
219    fn compute_first_uip(
220        &mut self,
221        conflict_lits: &[TermId],
222    ) -> Result<(Vec<TermId>, TermId, usize), String> {
223        let current_level = self.impl_graph.current_level;
224
225        // Initialize with conflict clause
226        let mut clause = conflict_lits.to_vec();
227        let mut seen = FxHashSet::default();
228        let mut counter = 0;
229
230        // Count literals at current level
231        for &lit in &clause {
232            if self.impl_graph.get_level(lit) == current_level {
233                counter += 1;
234            }
235            seen.insert(lit);
236        }
237
238        // Resolve until we have exactly one literal at current level
239        let mut asserting_lit = TermId::from(0);
240
241        while counter > 1 {
242            // Find a literal to resolve on
243            let resolve_lit = clause
244                .iter()
245                .copied()
246                .find(|&lit| {
247                    self.impl_graph.get_level(lit) == current_level
248                        && !self.impl_graph.is_decision(lit)
249                })
250                .ok_or("No literal to resolve on")?;
251
252            // Get reason clause
253            let reason = self
254                .impl_graph
255                .get_reason(resolve_lit)
256                .ok_or("No reason for propagated literal")?;
257
258            let reason_lits = self.get_clause_literals(reason)?;
259
260            // Resolve
261            clause.retain(|&lit| lit != resolve_lit);
262            counter -= 1;
263
264            for &reason_lit in &reason_lits {
265                if reason_lit != resolve_lit && !seen.contains(&reason_lit) {
266                    clause.push(reason_lit);
267                    seen.insert(reason_lit);
268
269                    if self.impl_graph.get_level(reason_lit) == current_level {
270                        counter += 1;
271                    }
272                }
273            }
274        }
275
276        // Find the asserting literal (the one at current level)
277        for &lit in &clause {
278            if self.impl_graph.get_level(lit) == current_level {
279                asserting_lit = lit;
280                break;
281            }
282        }
283
284        // Compute backtrack level (second highest level in clause)
285        let mut levels: Vec<usize> = clause
286            .iter()
287            .map(|&lit| self.impl_graph.get_level(lit))
288            .collect();
289        levels.sort_unstable();
290        levels.dedup();
291
292        let backtrack_level = if levels.len() > 1 {
293            levels[levels.len() - 2]
294        } else {
295            0
296        };
297
298        Ok((clause, asserting_lit, backtrack_level))
299    }
300
301    /// Minimize a learned clause.
302    fn minimize_clause(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
303        if !self.config.enable_minimization {
304            return Ok(clause.to_vec());
305        }
306
307        let mut minimized = clause.to_vec();
308
309        // Remove redundant literals
310        minimized.retain(|&lit| !self.is_redundant(lit, clause));
311
312        // Recursive minimization
313        if self.config.enable_recursive_minimization {
314            minimized = self.recursive_minimize(&minimized)?;
315        }
316
317        Ok(minimized)
318    }
319
320    /// Check if a literal is redundant in a clause.
321    fn is_redundant(&mut self, lit: TermId, clause: &[TermId]) -> bool {
322        // Check if all literals in the reason of lit are in clause
323        if let Some(reason) = self.impl_graph.get_reason(lit)
324            && let Ok(reason_lits) = self.get_clause_literals(reason)
325        {
326            return reason_lits
327                .iter()
328                .all(|&r_lit| r_lit == lit || clause.contains(&r_lit));
329        }
330
331        false
332    }
333
334    /// Recursive clause minimization.
335    fn recursive_minimize(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
336        self.minimizer.seen.clear();
337        self.minimizer.analyze_stack.clear();
338
339        // Mark all clause literals as seen
340        for &lit in clause {
341            self.minimizer.seen.insert(lit);
342        }
343
344        let mut minimized = Vec::new();
345
346        for &lit in clause {
347            if !self.minimizer.can_remove(lit, &self.impl_graph)? {
348                minimized.push(lit);
349            }
350        }
351
352        Ok(minimized)
353    }
354
355    /// Compute Literal Block Distance (LBD/Glue).
356    fn compute_lbd(&self, clause: &[TermId]) -> usize {
357        let mut levels = FxHashSet::default();
358
359        for &lit in clause {
360            let level = self.impl_graph.get_level(lit);
361            levels.insert(level);
362        }
363
364        levels.len()
365    }
366
367    /// Get literals of a clause.
368    fn get_clause_literals(&self, _clause_id: ClauseId) -> Result<Vec<TermId>, String> {
369        // Placeholder: would retrieve from clause database
370        Ok(vec![])
371    }
372
373    /// Subsume redundant clauses.
374    pub fn subsume_clauses(&mut self) -> Result<(), String> {
375        if !self.config.enable_subsumption {
376            return Ok(());
377        }
378
379        let mut to_remove = Vec::new();
380
381        // Check each pair of clauses
382        for i in 0..self.learned_db.clauses.len() {
383            for j in (i + 1)..self.learned_db.clauses.len() {
384                if self.learned_db.clauses[i].locked || self.learned_db.clauses[j].locked {
385                    continue;
386                }
387
388                let clause_i = &self.learned_db.clauses[i].literals;
389                let clause_j = &self.learned_db.clauses[j].literals;
390
391                // Check if clause_i subsumes clause_j
392                if Self::subsumes(clause_i, clause_j) {
393                    to_remove.push(j);
394                    self.stats.clauses_subsumed += 1;
395                } else if Self::subsumes(clause_j, clause_i) {
396                    to_remove.push(i);
397                    self.stats.clauses_subsumed += 1;
398                    break;
399                }
400            }
401        }
402
403        // Remove subsumed clauses
404        to_remove.sort_unstable();
405        to_remove.dedup();
406        for &idx in to_remove.iter().rev() {
407            self.learned_db.clauses.remove(idx);
408            self.learned_db.activity.remove(idx);
409        }
410
411        Ok(())
412    }
413
414    /// Check if clause A subsumes clause B.
415    fn subsumes(a: &[TermId], b: &[TermId]) -> bool {
416        if a.len() > b.len() {
417            return false;
418        }
419
420        let b_set: FxHashSet<TermId> = b.iter().copied().collect();
421
422        a.iter().all(|lit| b_set.contains(lit))
423    }
424
425    /// Strengthen clauses by removing literals.
426    pub fn strengthen_clauses(&mut self) -> Result<(), String> {
427        if !self.config.enable_strengthening {
428            return Ok(());
429        }
430
431        for idx in 0..self.learned_db.clauses.len() {
432            if self.learned_db.clauses[idx].locked {
433                continue;
434            }
435
436            // Clone the literals to avoid simultaneous borrow of self (needed because
437            // can_remove_literal borrows self.impl_graph immutably while we modify the clause).
438            let original_literals = self.learned_db.clauses[idx].literals.clone();
439            let original_len = original_literals.len();
440
441            let mut new_literals = Vec::with_capacity(original_len);
442            for &lit in &original_literals {
443                if !self.can_remove_literal(lit, &original_literals) {
444                    new_literals.push(lit);
445                }
446            }
447
448            if new_literals.len() < original_len {
449                self.learned_db.clauses[idx].literals = new_literals;
450                self.stats.clauses_strengthened += 1;
451            }
452        }
453
454        Ok(())
455    }
456
457    /// Check if a literal can be removed from a clause.
458    ///
459    /// A literal `lit` is removable when every literal in its reason clause
460    /// (other than `lit` itself) is already in `clause`.  The check is
461    /// recursive: a peer literal that is not directly in `clause` may itself
462    /// be removable, allowing deeper subsumption.
463    ///
464    /// A `visited` set guards against revisiting nodes; implication graphs in
465    /// standard CDCL are acyclic, but the guard is kept for robustness.
466    fn can_remove_literal(&self, lit: TermId, clause: &[TermId]) -> bool {
467        let mut visited = FxHashSet::default();
468        self.can_remove_literal_rec(lit, clause, &mut visited)
469    }
470
471    fn can_remove_literal_rec(
472        &self,
473        lit: TermId,
474        clause: &[TermId],
475        visited: &mut FxHashSet<TermId>,
476    ) -> bool {
477        if !visited.insert(lit) {
478            // Already visited this node — avoid infinite looping.
479            return true;
480        }
481
482        // Decision literals have no reason clause and cannot be removed.
483        let reason_id = match self.impl_graph.get_reason(lit) {
484            Some(r) => r,
485            None => return false,
486        };
487
488        // Retrieve the reason clause's literals.  If the underlying clause
489        // database is not yet wired up, `get_clause_literals` returns an empty
490        // vec, which makes the vacuous `all` below return `true` — harmless
491        // because no real literal will be removed from an empty reason clause.
492        let reason_lits = match self.get_clause_literals(reason_id) {
493            Ok(lits) => lits,
494            Err(_) => return false,
495        };
496
497        // Every literal in the reason clause (other than `lit` itself) must
498        // either already be present in `clause` or be recursively removable.
499        for other_lit in &reason_lits {
500            if *other_lit == lit {
501                continue;
502            }
503            if !clause.contains(other_lit)
504                && !self.can_remove_literal_rec(*other_lit, clause, visited)
505            {
506                return false;
507            }
508        }
509
510        true
511    }
512
513    /// Reduce clause database.
514    pub fn reduce_database(&mut self) -> Result<(), String> {
515        self.stats.db_reductions += 1;
516
517        // Remove low-activity clauses
518        self.learned_db.reduce();
519
520        Ok(())
521    }
522
523    /// Bump clause activity.
524    pub fn bump_clause(&mut self, clause_id: ClauseId) {
525        self.learned_db.bump_activity(clause_id);
526    }
527
528    /// Get statistics.
529    pub fn stats(&self) -> &ClauseLearningStats {
530        &self.stats
531    }
532}
533
534impl ImplicationGraph {
535    /// Create a new implication graph.
536    pub fn new() -> Self {
537        Self {
538            nodes: FxHashMap::default(),
539            predecessors: FxHashMap::default(),
540            levels: FxHashMap::default(),
541            current_level: 0,
542        }
543    }
544
545    /// Add a node to the graph.
546    pub fn add_node(
547        &mut self,
548        var: TermId,
549        value: bool,
550        level: usize,
551        reason: Option<ClauseId>,
552        is_decision: bool,
553    ) {
554        self.nodes.insert(
555            var,
556            ImplicationNode {
557                var,
558                value,
559                level,
560                reason,
561                is_decision,
562            },
563        );
564
565        self.levels.insert(var, level);
566    }
567
568    /// Get decision level of a variable.
569    pub fn get_level(&self, var: TermId) -> usize {
570        self.levels.get(&var).copied().unwrap_or(0)
571    }
572
573    /// Check if a variable is a decision.
574    pub fn is_decision(&self, var: TermId) -> bool {
575        self.nodes.get(&var).is_some_and(|n| n.is_decision)
576    }
577
578    /// Get reason clause for a variable.
579    pub fn get_reason(&self, var: TermId) -> Option<ClauseId> {
580        self.nodes.get(&var).and_then(|n| n.reason)
581    }
582
583    /// Set current decision level.
584    pub fn set_level(&mut self, level: usize) {
585        self.current_level = level;
586    }
587}
588
589impl LearnedDatabase {
590    /// Create a new learned database.
591    pub fn new(decay_factor: f64) -> Self {
592        Self {
593            clauses: Vec::new(),
594            activity: Vec::new(),
595            clause_map: FxHashMap::default(),
596            bump_increment: 1.0,
597            decay_factor,
598        }
599    }
600
601    /// Add a clause to the database.
602    pub fn add_clause(&mut self, clause: LearnedClause) {
603        let clause_id = self.clauses.len();
604
605        self.clause_map.insert(clause.literals.clone(), clause_id);
606        self.activity.push(clause.activity);
607        self.clauses.push(clause);
608    }
609
610    /// Bump clause activity.
611    pub fn bump_activity(&mut self, clause_id: ClauseId) {
612        if clause_id < self.activity.len() {
613            self.activity[clause_id] += self.bump_increment;
614
615            // Rescale if needed
616            if self.activity[clause_id] > 1e20 {
617                for act in &mut self.activity {
618                    *act *= 1e-20;
619                }
620                self.bump_increment *= 1e-20;
621            }
622        }
623    }
624
625    /// Decay all activities.
626    pub fn decay(&mut self) {
627        self.bump_increment /= self.decay_factor;
628    }
629
630    /// Reduce database by removing low-activity clauses.
631    pub fn reduce(&mut self) {
632        let mut sorted_indices: Vec<usize> = (0..self.clauses.len()).collect();
633
634        // Sort by activity (descending)
635        sorted_indices.sort_by(|&a, &b| {
636            self.activity[b]
637                .partial_cmp(&self.activity[a])
638                .unwrap_or(core::cmp::Ordering::Equal)
639        });
640
641        // Keep top 50%
642        let keep_count = self.clauses.len() / 2;
643
644        let mut to_keep = FxHashSet::default();
645        for &idx in sorted_indices.iter().take(keep_count) {
646            to_keep.insert(idx);
647        }
648
649        // Also keep locked clauses
650        for (idx, clause) in self.clauses.iter().enumerate() {
651            if clause.locked {
652                to_keep.insert(idx);
653            }
654        }
655
656        // Rebuild database
657        let mut new_clauses = Vec::new();
658        let mut new_activity = Vec::new();
659
660        for (idx, clause) in self.clauses.iter().enumerate() {
661            if to_keep.contains(&idx) {
662                new_clauses.push(clause.clone());
663                new_activity.push(self.activity[idx]);
664            }
665        }
666
667        self.clauses = new_clauses;
668        self.activity = new_activity;
669        self.clause_map.clear();
670
671        // Rebuild map
672        for (idx, clause) in self.clauses.iter().enumerate() {
673            self.clause_map.insert(clause.literals.clone(), idx);
674        }
675    }
676}
677
678impl ClauseMinimizer {
679    /// Create a new clause minimizer.
680    pub fn new() -> Self {
681        Self {
682            seen: FxHashSet::default(),
683            analyze_stack: Vec::new(),
684            cache: FxHashMap::default(),
685        }
686    }
687
688    /// Check if a literal can be removed.
689    fn can_remove(&mut self, _lit: TermId, _graph: &ImplicationGraph) -> Result<bool, String> {
690        // Simplified: would do recursive analysis
691        Ok(false)
692    }
693}
694
695impl Default for ClauseLearner {
696    fn default() -> Self {
697        Self::new(ClauseLearningConfig::default())
698    }
699}
700
701impl Default for ImplicationGraph {
702    fn default() -> Self {
703        Self::new()
704    }
705}
706
707impl Default for ClauseMinimizer {
708    fn default() -> Self {
709        Self::new()
710    }
711}
712
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed learner starts with zeroed statistics.
    #[test]
    fn test_clause_learner() {
        let learner = ClauseLearner::default();
        assert_eq!(learner.stats.conflicts_analyzed, 0);
    }

    /// A node added as a decision reports its level and decision flag.
    #[test]
    fn test_implication_graph() {
        let mut graph = ImplicationGraph::new();

        let var = TermId::from(1);
        graph.add_node(var, true, 1, None, true);

        assert_eq!(graph.get_level(var), 1);
        assert!(graph.is_decision(var));
    }

    /// Adding one clause to an empty database stores exactly one clause.
    #[test]
    fn test_learned_database() {
        let mut db = LearnedDatabase::new(0.95);

        let clause = LearnedClause {
            literals: vec![TermId::from(1), TermId::from(2)],
            asserting_lit: TermId::from(1),
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 2,
        };

        db.add_clause(clause);
        assert_eq!(db.clauses.len(), 1);
    }

    /// A clause subsumes a superset of its literals, but not the reverse.
    #[test]
    fn test_subsumption() {
        let a = vec![TermId::from(1), TermId::from(2)];
        let b = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        assert!(ClauseLearner::subsumes(&a, &b));
        assert!(!ClauseLearner::subsumes(&b, &a));
    }

    /// With an empty graph every literal reports level 0, so the LBD is 1.
    #[test]
    fn test_lbd_computation() {
        let learner = ClauseLearner::default();

        let clause = vec![TermId::from(1), TermId::from(2), TermId::from(3)];
        let lbd = learner.compute_lbd(&clause);

        // LBD depends on decision levels, which are 0 by default
        assert_eq!(lbd, 1);
    }
}