Skip to main content

oxiz_solver/conflict/
clause_learning.rs

1//! Clause Learning for CDCL(T) Solver.
2#![allow(dead_code)] // Under development
3//!
4//! Implements sophisticated conflict-driven clause learning including:
5//! - First UIP computation
6//! - Conflict clause minimization
7//! - Asserting clause generation
8//! - Learned clause database management
9//! - Clause subsumption and strengthening
10
11#[allow(unused_imports)]
12use crate::prelude::*;
13use oxiz_core::ast::{TermId, TermManager};
14
15/// Clause learning engine for CDCL.
16pub struct ClauseLearner {
17    /// Implication graph
18    impl_graph: ImplicationGraph,
19    /// Learned clause database
20    learned_db: LearnedDatabase,
21    /// Clause minimization engine
22    minimizer: ClauseMinimizer,
23    /// Configuration
24    config: ClauseLearningConfig,
25    /// Statistics
26    stats: ClauseLearningStats,
27}
28
29/// Implication graph for conflict analysis.
30#[derive(Debug, Clone)]
31pub struct ImplicationGraph {
32    /// Nodes: variable → implication node
33    nodes: FxHashMap<TermId, ImplicationNode>,
34    /// Adjacency list: variable → predecessors
35    predecessors: FxHashMap<TermId, Vec<TermId>>,
36    /// Decision levels: variable → level
37    levels: FxHashMap<TermId, usize>,
38    /// Current decision level
39    current_level: usize,
40}
41
42/// Node in the implication graph.
43#[derive(Debug, Clone)]
44pub struct ImplicationNode {
45    /// Variable
46    pub var: TermId,
47    /// Assigned value
48    pub value: bool,
49    /// Decision level
50    pub level: usize,
51    /// Reason clause (None for decisions)
52    pub reason: Option<ClauseId>,
53    /// Is this a decision variable?
54    pub is_decision: bool,
55}
56
/// Identifier used to refer to a clause.
pub type ClauseId = usize;
59
60/// Learned clause database.
61#[derive(Debug, Clone)]
62pub struct LearnedDatabase {
63    /// Learned clauses
64    clauses: Vec<LearnedClause>,
65    /// Activity scores for LRU
66    activity: Vec<f64>,
67    /// Clause to ID mapping
68    clause_map: FxHashMap<Vec<TermId>, ClauseId>,
69    /// Bump increment for activity
70    bump_increment: f64,
71    /// Decay factor
72    decay_factor: f64,
73}
74
75/// A learned clause.
76#[derive(Debug, Clone)]
77pub struct LearnedClause {
78    /// Literals in the clause
79    pub literals: Vec<TermId>,
80    /// Asserting literal (first UIP)
81    pub asserting_lit: TermId,
82    /// Backtrack level
83    pub backtrack_level: usize,
84    /// Clause activity
85    pub activity: f64,
86    /// Is this clause locked (in conflict analysis)?
87    pub locked: bool,
88    /// Glue level (LBD)
89    pub lbd: usize,
90}
91
92/// Clause minimization engine.
93#[derive(Debug, Clone)]
94pub struct ClauseMinimizer {
95    /// Seen variables during minimization
96    seen: FxHashSet<TermId>,
97    /// Variables to analyze
98    analyze_stack: Vec<TermId>,
99    /// Minimization cache
100    cache: FxHashMap<TermId, bool>,
101}
102
/// Tunable options for the clause learner.
#[derive(Debug, Clone)]
pub struct ClauseLearningConfig {
    /// Turn clause minimization on or off.
    pub enable_minimization: bool,
    /// Turn the recursive minimization pass on or off.
    pub enable_recursive_minimization: bool,
    /// Turn clause subsumption on or off.
    pub enable_subsumption: bool,
    /// Turn clause strengthening on or off.
    pub enable_strengthening: bool,
    /// Upper bound on the size of a learned clause.
    pub max_learned_size: usize,
    /// LBD threshold for keeping clauses.
    pub lbd_threshold: usize,
    /// Decay factor applied to clause activity.
    pub activity_decay: f64,
}
121
122impl Default for ClauseLearningConfig {
123    fn default() -> Self {
124        Self {
125            enable_minimization: true,
126            enable_recursive_minimization: true,
127            enable_subsumption: true,
128            enable_strengthening: true,
129            max_learned_size: 1000,
130            lbd_threshold: 5,
131            activity_decay: 0.95,
132        }
133    }
134}
135
/// Counters collected by the clause learner.
#[derive(Debug, Clone, Default)]
pub struct ClauseLearningStats {
    /// Number of conflicts analyzed.
    pub conflicts_analyzed: usize,
    /// Number of clauses learned.
    pub clauses_learned: usize,
    /// Total literal count of learned clauses before minimization.
    pub literals_before_minimization: usize,
    /// Total literal count of learned clauses after minimization.
    pub literals_after_minimization: usize,
    /// Number of clauses removed by subsumption.
    pub clauses_subsumed: usize,
    /// Number of clauses strengthened.
    pub clauses_strengthened: usize,
    /// Number of UIP computations performed.
    pub uip_computations: usize,
    /// Number of clause-database reductions performed.
    pub db_reductions: usize,
}
156
157impl ClauseLearner {
158    /// Create a new clause learner.
159    pub fn new(config: ClauseLearningConfig) -> Self {
160        Self {
161            impl_graph: ImplicationGraph::new(),
162            learned_db: LearnedDatabase::new(config.activity_decay),
163            minimizer: ClauseMinimizer::new(),
164            config,
165            stats: ClauseLearningStats::default(),
166        }
167    }
168
169    /// Analyze a conflict and learn a clause.
170    pub fn analyze_conflict(
171        &mut self,
172        conflict_clause: ClauseId,
173        _tm: &TermManager,
174    ) -> Result<LearnedClause, String> {
175        self.stats.conflicts_analyzed += 1;
176
177        // Build initial conflict clause
178        let conflict_lits = self.get_clause_literals(conflict_clause)?;
179
180        // Compute First UIP
181        let (learned_lits, asserting_lit, backtrack_level) =
182            self.compute_first_uip(&conflict_lits)?;
183
184        self.stats.uip_computations += 1;
185        self.stats.literals_before_minimization += learned_lits.len();
186
187        // Minimize clause
188        let minimized_lits = if self.config.enable_minimization {
189            self.minimize_clause(&learned_lits)?
190        } else {
191            learned_lits
192        };
193
194        self.stats.literals_after_minimization += minimized_lits.len();
195
196        // Compute LBD (Literal Block Distance)
197        let lbd = self.compute_lbd(&minimized_lits);
198
199        // Create learned clause
200        let learned = LearnedClause {
201            literals: minimized_lits,
202            asserting_lit,
203            backtrack_level,
204            activity: 0.0,
205            locked: false,
206            lbd,
207        };
208
209        self.stats.clauses_learned += 1;
210
211        // Add to database
212        self.learned_db.add_clause(learned.clone());
213
214        Ok(learned)
215    }
216
217    /// Compute First UIP (Unique Implication Point).
218    fn compute_first_uip(
219        &mut self,
220        conflict_lits: &[TermId],
221    ) -> Result<(Vec<TermId>, TermId, usize), String> {
222        let current_level = self.impl_graph.current_level;
223
224        // Initialize with conflict clause
225        let mut clause = conflict_lits.to_vec();
226        let mut seen = FxHashSet::default();
227        let mut counter = 0;
228
229        // Count literals at current level
230        for &lit in &clause {
231            if self.impl_graph.get_level(lit) == current_level {
232                counter += 1;
233            }
234            seen.insert(lit);
235        }
236
237        // Resolve until we have exactly one literal at current level
238        let mut asserting_lit = TermId::from(0);
239
240        while counter > 1 {
241            // Find a literal to resolve on
242            let resolve_lit = clause
243                .iter()
244                .copied()
245                .find(|&lit| {
246                    self.impl_graph.get_level(lit) == current_level
247                        && !self.impl_graph.is_decision(lit)
248                })
249                .ok_or("No literal to resolve on")?;
250
251            // Get reason clause
252            let reason = self
253                .impl_graph
254                .get_reason(resolve_lit)
255                .ok_or("No reason for propagated literal")?;
256
257            let reason_lits = self.get_clause_literals(reason)?;
258
259            // Resolve
260            clause.retain(|&lit| lit != resolve_lit);
261            counter -= 1;
262
263            for &reason_lit in &reason_lits {
264                if reason_lit != resolve_lit && !seen.contains(&reason_lit) {
265                    clause.push(reason_lit);
266                    seen.insert(reason_lit);
267
268                    if self.impl_graph.get_level(reason_lit) == current_level {
269                        counter += 1;
270                    }
271                }
272            }
273        }
274
275        // Find the asserting literal (the one at current level)
276        for &lit in &clause {
277            if self.impl_graph.get_level(lit) == current_level {
278                asserting_lit = lit;
279                break;
280            }
281        }
282
283        // Compute backtrack level (second highest level in clause)
284        let mut levels: Vec<usize> = clause
285            .iter()
286            .map(|&lit| self.impl_graph.get_level(lit))
287            .collect();
288        levels.sort_unstable();
289        levels.dedup();
290
291        let backtrack_level = if levels.len() > 1 {
292            levels[levels.len() - 2]
293        } else {
294            0
295        };
296
297        Ok((clause, asserting_lit, backtrack_level))
298    }
299
300    /// Minimize a learned clause.
301    fn minimize_clause(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
302        if !self.config.enable_minimization {
303            return Ok(clause.to_vec());
304        }
305
306        let mut minimized = clause.to_vec();
307
308        // Remove redundant literals
309        minimized.retain(|&lit| !self.is_redundant(lit, clause));
310
311        // Recursive minimization
312        if self.config.enable_recursive_minimization {
313            minimized = self.recursive_minimize(&minimized)?;
314        }
315
316        Ok(minimized)
317    }
318
319    /// Check if a literal is redundant in a clause.
320    fn is_redundant(&mut self, lit: TermId, clause: &[TermId]) -> bool {
321        // Check if all literals in the reason of lit are in clause
322        if let Some(reason) = self.impl_graph.get_reason(lit)
323            && let Ok(reason_lits) = self.get_clause_literals(reason)
324        {
325            return reason_lits
326                .iter()
327                .all(|&r_lit| r_lit == lit || clause.contains(&r_lit));
328        }
329
330        false
331    }
332
333    /// Recursive clause minimization.
334    fn recursive_minimize(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
335        self.minimizer.seen.clear();
336        self.minimizer.analyze_stack.clear();
337
338        // Mark all clause literals as seen
339        for &lit in clause {
340            self.minimizer.seen.insert(lit);
341        }
342
343        let mut minimized = Vec::new();
344
345        for &lit in clause {
346            if !self.minimizer.can_remove(lit, &self.impl_graph)? {
347                minimized.push(lit);
348            }
349        }
350
351        Ok(minimized)
352    }
353
354    /// Compute Literal Block Distance (LBD/Glue).
355    fn compute_lbd(&self, clause: &[TermId]) -> usize {
356        let mut levels = FxHashSet::default();
357
358        for &lit in clause {
359            let level = self.impl_graph.get_level(lit);
360            levels.insert(level);
361        }
362
363        levels.len()
364    }
365
366    /// Get literals of a clause.
367    fn get_clause_literals(&self, _clause_id: ClauseId) -> Result<Vec<TermId>, String> {
368        // Placeholder: would retrieve from clause database
369        Ok(vec![])
370    }
371
372    /// Subsume redundant clauses.
373    pub fn subsume_clauses(&mut self) -> Result<(), String> {
374        if !self.config.enable_subsumption {
375            return Ok(());
376        }
377
378        let mut to_remove = Vec::new();
379
380        // Check each pair of clauses
381        for i in 0..self.learned_db.clauses.len() {
382            for j in (i + 1)..self.learned_db.clauses.len() {
383                if self.learned_db.clauses[i].locked || self.learned_db.clauses[j].locked {
384                    continue;
385                }
386
387                let clause_i = &self.learned_db.clauses[i].literals;
388                let clause_j = &self.learned_db.clauses[j].literals;
389
390                // Check if clause_i subsumes clause_j
391                if Self::subsumes(clause_i, clause_j) {
392                    to_remove.push(j);
393                    self.stats.clauses_subsumed += 1;
394                } else if Self::subsumes(clause_j, clause_i) {
395                    to_remove.push(i);
396                    self.stats.clauses_subsumed += 1;
397                    break;
398                }
399            }
400        }
401
402        // Remove subsumed clauses
403        to_remove.sort_unstable();
404        to_remove.dedup();
405        for &idx in to_remove.iter().rev() {
406            self.learned_db.clauses.remove(idx);
407            self.learned_db.activity.remove(idx);
408        }
409
410        Ok(())
411    }
412
413    /// Check if clause A subsumes clause B.
414    fn subsumes(a: &[TermId], b: &[TermId]) -> bool {
415        if a.len() > b.len() {
416            return false;
417        }
418
419        let b_set: FxHashSet<TermId> = b.iter().copied().collect();
420
421        a.iter().all(|lit| b_set.contains(lit))
422    }
423
424    /// Strengthen clauses by removing literals.
425    pub fn strengthen_clauses(&mut self) -> Result<(), String> {
426        if !self.config.enable_strengthening {
427            return Ok(());
428        }
429
430        // TODO: Clause strengthening not implemented yet (can_remove_literal always returns false)
431        // for clause in &mut self.learned_db.clauses {
432        //     if clause.locked {
433        //         continue;
434        //     }
435        //
436        //     let original_len = clause.literals.len();
437        //
438        //     // Try to remove each literal (clone to avoid borrow checker issues)
439        //     let original_literals = clause.literals.clone();
440        //     clause.literals.retain(|&lit| {
441        //         !self.can_remove_literal(lit, &original_literals)
442        //     });
443        //
444        //     if clause.literals.len() < original_len {
445        //         self.stats.clauses_strengthened += 1;
446        //     }
447        // }
448
449        Ok(())
450    }
451
452    /// Check if a literal can be removed from a clause.
453    fn can_remove_literal(&self, _lit: TermId, _clause: &[TermId]) -> bool {
454        // Simplified: would check if clause remains asserting without this literal
455        false
456    }
457
458    /// Reduce clause database.
459    pub fn reduce_database(&mut self) -> Result<(), String> {
460        self.stats.db_reductions += 1;
461
462        // Remove low-activity clauses
463        self.learned_db.reduce();
464
465        Ok(())
466    }
467
468    /// Bump clause activity.
469    pub fn bump_clause(&mut self, clause_id: ClauseId) {
470        self.learned_db.bump_activity(clause_id);
471    }
472
473    /// Get statistics.
474    pub fn stats(&self) -> &ClauseLearningStats {
475        &self.stats
476    }
477}
478
479impl ImplicationGraph {
480    /// Create a new implication graph.
481    pub fn new() -> Self {
482        Self {
483            nodes: FxHashMap::default(),
484            predecessors: FxHashMap::default(),
485            levels: FxHashMap::default(),
486            current_level: 0,
487        }
488    }
489
490    /// Add a node to the graph.
491    pub fn add_node(
492        &mut self,
493        var: TermId,
494        value: bool,
495        level: usize,
496        reason: Option<ClauseId>,
497        is_decision: bool,
498    ) {
499        self.nodes.insert(
500            var,
501            ImplicationNode {
502                var,
503                value,
504                level,
505                reason,
506                is_decision,
507            },
508        );
509
510        self.levels.insert(var, level);
511    }
512
513    /// Get decision level of a variable.
514    pub fn get_level(&self, var: TermId) -> usize {
515        self.levels.get(&var).copied().unwrap_or(0)
516    }
517
518    /// Check if a variable is a decision.
519    pub fn is_decision(&self, var: TermId) -> bool {
520        self.nodes.get(&var).is_some_and(|n| n.is_decision)
521    }
522
523    /// Get reason clause for a variable.
524    pub fn get_reason(&self, var: TermId) -> Option<ClauseId> {
525        self.nodes.get(&var).and_then(|n| n.reason)
526    }
527
528    /// Set current decision level.
529    pub fn set_level(&mut self, level: usize) {
530        self.current_level = level;
531    }
532}
533
534impl LearnedDatabase {
535    /// Create a new learned database.
536    pub fn new(decay_factor: f64) -> Self {
537        Self {
538            clauses: Vec::new(),
539            activity: Vec::new(),
540            clause_map: FxHashMap::default(),
541            bump_increment: 1.0,
542            decay_factor,
543        }
544    }
545
546    /// Add a clause to the database.
547    pub fn add_clause(&mut self, clause: LearnedClause) {
548        let clause_id = self.clauses.len();
549
550        self.clause_map.insert(clause.literals.clone(), clause_id);
551        self.activity.push(clause.activity);
552        self.clauses.push(clause);
553    }
554
555    /// Bump clause activity.
556    pub fn bump_activity(&mut self, clause_id: ClauseId) {
557        if clause_id < self.activity.len() {
558            self.activity[clause_id] += self.bump_increment;
559
560            // Rescale if needed
561            if self.activity[clause_id] > 1e20 {
562                for act in &mut self.activity {
563                    *act *= 1e-20;
564                }
565                self.bump_increment *= 1e-20;
566            }
567        }
568    }
569
570    /// Decay all activities.
571    pub fn decay(&mut self) {
572        self.bump_increment /= self.decay_factor;
573    }
574
575    /// Reduce database by removing low-activity clauses.
576    pub fn reduce(&mut self) {
577        let mut sorted_indices: Vec<usize> = (0..self.clauses.len()).collect();
578
579        // Sort by activity (descending)
580        sorted_indices.sort_by(|&a, &b| {
581            self.activity[b]
582                .partial_cmp(&self.activity[a])
583                .unwrap_or(core::cmp::Ordering::Equal)
584        });
585
586        // Keep top 50%
587        let keep_count = self.clauses.len() / 2;
588
589        let mut to_keep = FxHashSet::default();
590        for &idx in sorted_indices.iter().take(keep_count) {
591            to_keep.insert(idx);
592        }
593
594        // Also keep locked clauses
595        for (idx, clause) in self.clauses.iter().enumerate() {
596            if clause.locked {
597                to_keep.insert(idx);
598            }
599        }
600
601        // Rebuild database
602        let mut new_clauses = Vec::new();
603        let mut new_activity = Vec::new();
604
605        for (idx, clause) in self.clauses.iter().enumerate() {
606            if to_keep.contains(&idx) {
607                new_clauses.push(clause.clone());
608                new_activity.push(self.activity[idx]);
609            }
610        }
611
612        self.clauses = new_clauses;
613        self.activity = new_activity;
614        self.clause_map.clear();
615
616        // Rebuild map
617        for (idx, clause) in self.clauses.iter().enumerate() {
618            self.clause_map.insert(clause.literals.clone(), idx);
619        }
620    }
621}
622
623impl ClauseMinimizer {
624    /// Create a new clause minimizer.
625    pub fn new() -> Self {
626        Self {
627            seen: FxHashSet::default(),
628            analyze_stack: Vec::new(),
629            cache: FxHashMap::default(),
630        }
631    }
632
633    /// Check if a literal can be removed.
634    fn can_remove(&mut self, _lit: TermId, _graph: &ImplicationGraph) -> Result<bool, String> {
635        // Simplified: would do recursive analysis
636        Ok(false)
637    }
638}
639
640impl Default for ClauseLearner {
641    fn default() -> Self {
642        Self::new(ClauseLearningConfig::default())
643    }
644}
645
646impl Default for ImplicationGraph {
647    fn default() -> Self {
648        Self::new()
649    }
650}
651
652impl Default for ClauseMinimizer {
653    fn default() -> Self {
654        Self::new()
655    }
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built learner has analyzed no conflicts.
    #[test]
    fn test_clause_learner() {
        let learner = ClauseLearner::default();
        let analyzed = learner.stats.conflicts_analyzed;
        assert_eq!(analyzed, 0);
    }

    /// A decision node reports its level and its decision flag.
    #[test]
    fn test_implication_graph() {
        let var = TermId::from(1);

        let mut graph = ImplicationGraph::new();
        graph.add_node(var, true, 1, None, true);

        assert_eq!(graph.get_level(var), 1);
        assert!(graph.is_decision(var));
    }

    /// Adding one clause leaves exactly one clause in the database.
    #[test]
    fn test_learned_database() {
        let first = TermId::from(1);
        let second = TermId::from(2);

        let mut db = LearnedDatabase::new(0.95);
        db.add_clause(LearnedClause {
            literals: vec![first, second],
            asserting_lit: first,
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 2,
        });

        assert_eq!(db.clauses.len(), 1);
    }

    /// A subset clause subsumes its superset, but not the other way round.
    #[test]
    fn test_subsumption() {
        let small = vec![TermId::from(1), TermId::from(2)];
        let large = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        assert!(ClauseLearner::subsumes(&small, &large));
        assert!(!ClauseLearner::subsumes(&large, &small));
    }

    /// With no recorded levels every literal is at level 0, giving LBD 1.
    #[test]
    fn test_lbd_computation() {
        let learner = ClauseLearner::default();
        let clause = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        // LBD depends on decision levels, which are 0 by default
        assert_eq!(learner.compute_lbd(&clause), 1);
    }
}