Skip to main content

oxiz_solver/conflict/
clause_learning.rs

1//! Clause Learning for CDCL(T) Solver.
2#![allow(dead_code)] // Under development
3//!
4//! Implements sophisticated conflict-driven clause learning including:
5//! - First UIP computation
6//! - Conflict clause minimization
7//! - Asserting clause generation
8//! - Learned clause database management
9//! - Clause subsumption and strengthening
10
11use oxiz_core::ast::{TermId, TermManager};
12use rustc_hash::{FxHashMap, FxHashSet};
13
14/// Clause learning engine for CDCL.
15pub struct ClauseLearner {
16    /// Implication graph
17    impl_graph: ImplicationGraph,
18    /// Learned clause database
19    learned_db: LearnedDatabase,
20    /// Clause minimization engine
21    minimizer: ClauseMinimizer,
22    /// Configuration
23    config: ClauseLearningConfig,
24    /// Statistics
25    stats: ClauseLearningStats,
26}
27
/// Implication graph for conflict analysis.
///
/// Records, per variable, its assignment node and decision level, plus the
/// current decision level that conflict analysis treats as the conflicting
/// level.
#[derive(Debug, Clone)]
pub struct ImplicationGraph {
    /// Nodes: variable → implication node
    nodes: FxHashMap<TermId, ImplicationNode>,
    /// Adjacency list: variable → predecessors
    // NOTE(review): `add_node` never fills this map, and nothing in this file
    // reads it; presumably reserved for trail-ordered analysis.
    predecessors: FxHashMap<TermId, Vec<TermId>>,
    /// Decision levels: variable → level (mirrors `ImplicationNode::level`;
    /// variables absent here are reported as level 0 by `get_level`)
    levels: FxHashMap<TermId, usize>,
    /// Current decision level (set via `set_level`)
    current_level: usize,
}
40
/// Node in the implication graph: one assigned variable.
#[derive(Debug, Clone)]
pub struct ImplicationNode {
    /// Variable
    pub var: TermId,
    /// Assigned value
    pub value: bool,
    /// Decision level at which the variable was assigned
    pub level: usize,
    /// Reason clause that implied this assignment (None for decisions)
    pub reason: Option<ClauseId>,
    /// Is this a decision variable? Decision literals are never resolved on
    /// during UIP computation.
    pub is_decision: bool,
}
55
/// Clause identifier.
///
/// Inside `LearnedDatabase` an ID is the clause's position in its clause
/// vector. Removing clauses compacts that vector, so IDs held elsewhere can
/// go stale after a reduction.
pub type ClauseId = usize;
58
/// Learned clause database.
#[derive(Debug, Clone)]
pub struct LearnedDatabase {
    /// Learned clauses; a clause's position here is its `ClauseId`
    clauses: Vec<LearnedClause>,
    /// Activity scores, index-aligned with `clauses` (used to pick which
    /// clauses `reduce` keeps)
    activity: Vec<f64>,
    /// Clause literals → ID mapping (rebuilt by `reduce`)
    clause_map: FxHashMap<Vec<TermId>, ClauseId>,
    /// Amount added to a clause's activity by `bump_activity`
    bump_increment: f64,
    /// Decay factor: `decay` divides `bump_increment` by this value
    decay_factor: f64,
}
73
/// A learned clause produced by conflict analysis.
#[derive(Debug, Clone)]
pub struct LearnedClause {
    /// Literals in the clause
    pub literals: Vec<TermId>,
    /// Asserting literal: the single literal at the conflict level after UIP
    /// resolution (`TermId::from(0)` when no such literal was found)
    pub asserting_lit: TermId,
    /// Backtrack level: second-highest distinct decision level among the
    /// literals, or 0 if they all share one level
    pub backtrack_level: usize,
    /// Clause activity (the database also tracks activities in a parallel
    /// vector)
    pub activity: f64,
    /// Is this clause locked (in conflict analysis)? Locked clauses are always
    /// retained by `LearnedDatabase::reduce`.
    pub locked: bool,
    /// Glue level (LBD): number of distinct decision levels among the literals
    pub lbd: usize,
}
90
/// Clause minimization engine: scratch state for recursive minimization.
#[derive(Debug, Clone)]
pub struct ClauseMinimizer {
    /// Seen variables during minimization (filled with the clause's literals
    /// before each recursive pass)
    seen: FxHashSet<TermId>,
    /// Variables to analyze (cleared before each recursive pass)
    analyze_stack: Vec<TermId>,
    /// Minimization cache
    // NOTE(review): not read or written anywhere in this file yet.
    cache: FxHashMap<TermId, bool>,
}
101
/// Configuration for clause learning.
#[derive(Debug, Clone)]
pub struct ClauseLearningConfig {
    /// Enable clause minimization after UIP computation
    pub enable_minimization: bool,
    /// Enable recursive minimization (only consulted when minimization is on)
    pub enable_recursive_minimization: bool,
    /// Enable clause subsumption (`subsume_clauses` is a no-op otherwise)
    pub enable_subsumption: bool,
    /// Enable clause strengthening (`strengthen_clauses` is a no-op otherwise)
    pub enable_strengthening: bool,
    /// Maximum clause size for learning
    // NOTE(review): not read anywhere in this file yet.
    pub max_learned_size: usize,
    /// LBD threshold for keeping clauses
    // NOTE(review): not read anywhere in this file yet.
    pub lbd_threshold: usize,
    /// Activity decay factor, passed to the learned database on construction
    pub activity_decay: f64,
}
120
121impl Default for ClauseLearningConfig {
122    fn default() -> Self {
123        Self {
124            enable_minimization: true,
125            enable_recursive_minimization: true,
126            enable_subsumption: true,
127            enable_strengthening: true,
128            max_learned_size: 1000,
129            lbd_threshold: 5,
130            activity_decay: 0.95,
131        }
132    }
133}
134
/// Clause learning statistics (all counters start at zero).
#[derive(Debug, Clone, Default)]
pub struct ClauseLearningStats {
    /// Conflicts analyzed (one per `analyze_conflict` call)
    pub conflicts_analyzed: usize,
    /// Clauses learned
    pub clauses_learned: usize,
    /// Literals in learned clauses (before minimization), summed over all
    /// learned clauses
    pub literals_before_minimization: usize,
    /// Literals after minimization, summed over all learned clauses
    pub literals_after_minimization: usize,
    /// Clauses subsumed
    pub clauses_subsumed: usize,
    /// Clauses strengthened
    // NOTE(review): never incremented in this file; strengthening is a stub.
    pub clauses_strengthened: usize,
    /// UIP computations
    pub uip_computations: usize,
    /// Clause database reductions (one per `reduce_database` call)
    pub db_reductions: usize,
}
155
156impl ClauseLearner {
157    /// Create a new clause learner.
158    pub fn new(config: ClauseLearningConfig) -> Self {
159        Self {
160            impl_graph: ImplicationGraph::new(),
161            learned_db: LearnedDatabase::new(config.activity_decay),
162            minimizer: ClauseMinimizer::new(),
163            config,
164            stats: ClauseLearningStats::default(),
165        }
166    }
167
168    /// Analyze a conflict and learn a clause.
169    pub fn analyze_conflict(
170        &mut self,
171        conflict_clause: ClauseId,
172        _tm: &TermManager,
173    ) -> Result<LearnedClause, String> {
174        self.stats.conflicts_analyzed += 1;
175
176        // Build initial conflict clause
177        let conflict_lits = self.get_clause_literals(conflict_clause)?;
178
179        // Compute First UIP
180        let (learned_lits, asserting_lit, backtrack_level) =
181            self.compute_first_uip(&conflict_lits)?;
182
183        self.stats.uip_computations += 1;
184        self.stats.literals_before_minimization += learned_lits.len();
185
186        // Minimize clause
187        let minimized_lits = if self.config.enable_minimization {
188            self.minimize_clause(&learned_lits)?
189        } else {
190            learned_lits
191        };
192
193        self.stats.literals_after_minimization += minimized_lits.len();
194
195        // Compute LBD (Literal Block Distance)
196        let lbd = self.compute_lbd(&minimized_lits);
197
198        // Create learned clause
199        let learned = LearnedClause {
200            literals: minimized_lits,
201            asserting_lit,
202            backtrack_level,
203            activity: 0.0,
204            locked: false,
205            lbd,
206        };
207
208        self.stats.clauses_learned += 1;
209
210        // Add to database
211        self.learned_db.add_clause(learned.clone());
212
213        Ok(learned)
214    }
215
216    /// Compute First UIP (Unique Implication Point).
217    fn compute_first_uip(
218        &mut self,
219        conflict_lits: &[TermId],
220    ) -> Result<(Vec<TermId>, TermId, usize), String> {
221        let current_level = self.impl_graph.current_level;
222
223        // Initialize with conflict clause
224        let mut clause = conflict_lits.to_vec();
225        let mut seen = FxHashSet::default();
226        let mut counter = 0;
227
228        // Count literals at current level
229        for &lit in &clause {
230            if self.impl_graph.get_level(lit) == current_level {
231                counter += 1;
232            }
233            seen.insert(lit);
234        }
235
236        // Resolve until we have exactly one literal at current level
237        let mut asserting_lit = TermId::from(0);
238
239        while counter > 1 {
240            // Find a literal to resolve on
241            let resolve_lit = clause
242                .iter()
243                .copied()
244                .find(|&lit| {
245                    self.impl_graph.get_level(lit) == current_level
246                        && !self.impl_graph.is_decision(lit)
247                })
248                .ok_or("No literal to resolve on")?;
249
250            // Get reason clause
251            let reason = self
252                .impl_graph
253                .get_reason(resolve_lit)
254                .ok_or("No reason for propagated literal")?;
255
256            let reason_lits = self.get_clause_literals(reason)?;
257
258            // Resolve
259            clause.retain(|&lit| lit != resolve_lit);
260            counter -= 1;
261
262            for &reason_lit in &reason_lits {
263                if reason_lit != resolve_lit && !seen.contains(&reason_lit) {
264                    clause.push(reason_lit);
265                    seen.insert(reason_lit);
266
267                    if self.impl_graph.get_level(reason_lit) == current_level {
268                        counter += 1;
269                    }
270                }
271            }
272        }
273
274        // Find the asserting literal (the one at current level)
275        for &lit in &clause {
276            if self.impl_graph.get_level(lit) == current_level {
277                asserting_lit = lit;
278                break;
279            }
280        }
281
282        // Compute backtrack level (second highest level in clause)
283        let mut levels: Vec<usize> = clause
284            .iter()
285            .map(|&lit| self.impl_graph.get_level(lit))
286            .collect();
287        levels.sort_unstable();
288        levels.dedup();
289
290        let backtrack_level = if levels.len() > 1 {
291            levels[levels.len() - 2]
292        } else {
293            0
294        };
295
296        Ok((clause, asserting_lit, backtrack_level))
297    }
298
299    /// Minimize a learned clause.
300    fn minimize_clause(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
301        if !self.config.enable_minimization {
302            return Ok(clause.to_vec());
303        }
304
305        let mut minimized = clause.to_vec();
306
307        // Remove redundant literals
308        minimized.retain(|&lit| !self.is_redundant(lit, clause));
309
310        // Recursive minimization
311        if self.config.enable_recursive_minimization {
312            minimized = self.recursive_minimize(&minimized)?;
313        }
314
315        Ok(minimized)
316    }
317
318    /// Check if a literal is redundant in a clause.
319    fn is_redundant(&mut self, lit: TermId, clause: &[TermId]) -> bool {
320        // Check if all literals in the reason of lit are in clause
321        if let Some(reason) = self.impl_graph.get_reason(lit)
322            && let Ok(reason_lits) = self.get_clause_literals(reason)
323        {
324            return reason_lits
325                .iter()
326                .all(|&r_lit| r_lit == lit || clause.contains(&r_lit));
327        }
328
329        false
330    }
331
332    /// Recursive clause minimization.
333    fn recursive_minimize(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
334        self.minimizer.seen.clear();
335        self.minimizer.analyze_stack.clear();
336
337        // Mark all clause literals as seen
338        for &lit in clause {
339            self.minimizer.seen.insert(lit);
340        }
341
342        let mut minimized = Vec::new();
343
344        for &lit in clause {
345            if !self.minimizer.can_remove(lit, &self.impl_graph)? {
346                minimized.push(lit);
347            }
348        }
349
350        Ok(minimized)
351    }
352
353    /// Compute Literal Block Distance (LBD/Glue).
354    fn compute_lbd(&self, clause: &[TermId]) -> usize {
355        let mut levels = FxHashSet::default();
356
357        for &lit in clause {
358            let level = self.impl_graph.get_level(lit);
359            levels.insert(level);
360        }
361
362        levels.len()
363    }
364
365    /// Get literals of a clause.
366    fn get_clause_literals(&self, _clause_id: ClauseId) -> Result<Vec<TermId>, String> {
367        // Placeholder: would retrieve from clause database
368        Ok(vec![])
369    }
370
371    /// Subsume redundant clauses.
372    pub fn subsume_clauses(&mut self) -> Result<(), String> {
373        if !self.config.enable_subsumption {
374            return Ok(());
375        }
376
377        let mut to_remove = Vec::new();
378
379        // Check each pair of clauses
380        for i in 0..self.learned_db.clauses.len() {
381            for j in (i + 1)..self.learned_db.clauses.len() {
382                if self.learned_db.clauses[i].locked || self.learned_db.clauses[j].locked {
383                    continue;
384                }
385
386                let clause_i = &self.learned_db.clauses[i].literals;
387                let clause_j = &self.learned_db.clauses[j].literals;
388
389                // Check if clause_i subsumes clause_j
390                if Self::subsumes(clause_i, clause_j) {
391                    to_remove.push(j);
392                    self.stats.clauses_subsumed += 1;
393                } else if Self::subsumes(clause_j, clause_i) {
394                    to_remove.push(i);
395                    self.stats.clauses_subsumed += 1;
396                    break;
397                }
398            }
399        }
400
401        // Remove subsumed clauses
402        to_remove.sort_unstable();
403        to_remove.dedup();
404        for &idx in to_remove.iter().rev() {
405            self.learned_db.clauses.remove(idx);
406            self.learned_db.activity.remove(idx);
407        }
408
409        Ok(())
410    }
411
412    /// Check if clause A subsumes clause B.
413    fn subsumes(a: &[TermId], b: &[TermId]) -> bool {
414        if a.len() > b.len() {
415            return false;
416        }
417
418        let b_set: FxHashSet<TermId> = b.iter().copied().collect();
419
420        a.iter().all(|lit| b_set.contains(lit))
421    }
422
423    /// Strengthen clauses by removing literals.
424    pub fn strengthen_clauses(&mut self) -> Result<(), String> {
425        if !self.config.enable_strengthening {
426            return Ok(());
427        }
428
429        // TODO: Clause strengthening not implemented yet (can_remove_literal always returns false)
430        // for clause in &mut self.learned_db.clauses {
431        //     if clause.locked {
432        //         continue;
433        //     }
434        //
435        //     let original_len = clause.literals.len();
436        //
437        //     // Try to remove each literal (clone to avoid borrow checker issues)
438        //     let original_literals = clause.literals.clone();
439        //     clause.literals.retain(|&lit| {
440        //         !self.can_remove_literal(lit, &original_literals)
441        //     });
442        //
443        //     if clause.literals.len() < original_len {
444        //         self.stats.clauses_strengthened += 1;
445        //     }
446        // }
447
448        Ok(())
449    }
450
451    /// Check if a literal can be removed from a clause.
452    fn can_remove_literal(&self, _lit: TermId, _clause: &[TermId]) -> bool {
453        // Simplified: would check if clause remains asserting without this literal
454        false
455    }
456
457    /// Reduce clause database.
458    pub fn reduce_database(&mut self) -> Result<(), String> {
459        self.stats.db_reductions += 1;
460
461        // Remove low-activity clauses
462        self.learned_db.reduce();
463
464        Ok(())
465    }
466
467    /// Bump clause activity.
468    pub fn bump_clause(&mut self, clause_id: ClauseId) {
469        self.learned_db.bump_activity(clause_id);
470    }
471
472    /// Get statistics.
473    pub fn stats(&self) -> &ClauseLearningStats {
474        &self.stats
475    }
476}
477
478impl ImplicationGraph {
479    /// Create a new implication graph.
480    pub fn new() -> Self {
481        Self {
482            nodes: FxHashMap::default(),
483            predecessors: FxHashMap::default(),
484            levels: FxHashMap::default(),
485            current_level: 0,
486        }
487    }
488
489    /// Add a node to the graph.
490    pub fn add_node(
491        &mut self,
492        var: TermId,
493        value: bool,
494        level: usize,
495        reason: Option<ClauseId>,
496        is_decision: bool,
497    ) {
498        self.nodes.insert(
499            var,
500            ImplicationNode {
501                var,
502                value,
503                level,
504                reason,
505                is_decision,
506            },
507        );
508
509        self.levels.insert(var, level);
510    }
511
512    /// Get decision level of a variable.
513    pub fn get_level(&self, var: TermId) -> usize {
514        self.levels.get(&var).copied().unwrap_or(0)
515    }
516
517    /// Check if a variable is a decision.
518    pub fn is_decision(&self, var: TermId) -> bool {
519        self.nodes.get(&var).is_some_and(|n| n.is_decision)
520    }
521
522    /// Get reason clause for a variable.
523    pub fn get_reason(&self, var: TermId) -> Option<ClauseId> {
524        self.nodes.get(&var).and_then(|n| n.reason)
525    }
526
527    /// Set current decision level.
528    pub fn set_level(&mut self, level: usize) {
529        self.current_level = level;
530    }
531}
532
533impl LearnedDatabase {
534    /// Create a new learned database.
535    pub fn new(decay_factor: f64) -> Self {
536        Self {
537            clauses: Vec::new(),
538            activity: Vec::new(),
539            clause_map: FxHashMap::default(),
540            bump_increment: 1.0,
541            decay_factor,
542        }
543    }
544
545    /// Add a clause to the database.
546    pub fn add_clause(&mut self, clause: LearnedClause) {
547        let clause_id = self.clauses.len();
548
549        self.clause_map.insert(clause.literals.clone(), clause_id);
550        self.activity.push(clause.activity);
551        self.clauses.push(clause);
552    }
553
554    /// Bump clause activity.
555    pub fn bump_activity(&mut self, clause_id: ClauseId) {
556        if clause_id < self.activity.len() {
557            self.activity[clause_id] += self.bump_increment;
558
559            // Rescale if needed
560            if self.activity[clause_id] > 1e20 {
561                for act in &mut self.activity {
562                    *act *= 1e-20;
563                }
564                self.bump_increment *= 1e-20;
565            }
566        }
567    }
568
569    /// Decay all activities.
570    pub fn decay(&mut self) {
571        self.bump_increment /= self.decay_factor;
572    }
573
574    /// Reduce database by removing low-activity clauses.
575    pub fn reduce(&mut self) {
576        let mut sorted_indices: Vec<usize> = (0..self.clauses.len()).collect();
577
578        // Sort by activity (descending)
579        sorted_indices.sort_by(|&a, &b| {
580            self.activity[b]
581                .partial_cmp(&self.activity[a])
582                .unwrap_or(std::cmp::Ordering::Equal)
583        });
584
585        // Keep top 50%
586        let keep_count = self.clauses.len() / 2;
587
588        let mut to_keep = FxHashSet::default();
589        for &idx in sorted_indices.iter().take(keep_count) {
590            to_keep.insert(idx);
591        }
592
593        // Also keep locked clauses
594        for (idx, clause) in self.clauses.iter().enumerate() {
595            if clause.locked {
596                to_keep.insert(idx);
597            }
598        }
599
600        // Rebuild database
601        let mut new_clauses = Vec::new();
602        let mut new_activity = Vec::new();
603
604        for (idx, clause) in self.clauses.iter().enumerate() {
605            if to_keep.contains(&idx) {
606                new_clauses.push(clause.clone());
607                new_activity.push(self.activity[idx]);
608            }
609        }
610
611        self.clauses = new_clauses;
612        self.activity = new_activity;
613        self.clause_map.clear();
614
615        // Rebuild map
616        for (idx, clause) in self.clauses.iter().enumerate() {
617            self.clause_map.insert(clause.literals.clone(), idx);
618        }
619    }
620}
621
622impl ClauseMinimizer {
623    /// Create a new clause minimizer.
624    pub fn new() -> Self {
625        Self {
626            seen: FxHashSet::default(),
627            analyze_stack: Vec::new(),
628            cache: FxHashMap::default(),
629        }
630    }
631
632    /// Check if a literal can be removed.
633    fn can_remove(&mut self, _lit: TermId, _graph: &ImplicationGraph) -> Result<bool, String> {
634        // Simplified: would do recursive analysis
635        Ok(false)
636    }
637}
638
639impl Default for ClauseLearner {
640    fn default() -> Self {
641        Self::new(ClauseLearningConfig::default())
642    }
643}
644
645impl Default for ImplicationGraph {
646    fn default() -> Self {
647        Self::new()
648    }
649}
650
651impl Default for ClauseMinimizer {
652    fn default() -> Self {
653        Self::new()
654    }
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unlocked, zero-activity learned clause; `literals` must be
    /// non-empty (its first literal becomes the asserting literal).
    fn clause_of(literals: Vec<TermId>) -> LearnedClause {
        let asserting_lit = literals[0];
        let lbd = literals.len();
        LearnedClause {
            literals,
            asserting_lit,
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd,
        }
    }

    #[test]
    fn test_clause_learner() {
        let learner = ClauseLearner::default();
        assert_eq!(learner.stats.conflicts_analyzed, 0);
    }

    #[test]
    fn test_implication_graph() {
        let mut graph = ImplicationGraph::new();

        let var = TermId::from(1);
        graph.add_node(var, true, 1, None, true);

        assert_eq!(graph.get_level(var), 1);
        assert!(graph.is_decision(var));
    }

    #[test]
    fn test_learned_database() {
        let mut db = LearnedDatabase::new(0.95);

        let clause = LearnedClause {
            literals: vec![TermId::from(1), TermId::from(2)],
            asserting_lit: TermId::from(1),
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 2,
        };

        db.add_clause(clause);
        assert_eq!(db.clauses.len(), 1);
    }

    #[test]
    fn test_reduce_keeps_most_active_and_locked() {
        let mut db = LearnedDatabase::new(0.95);
        db.add_clause(clause_of(vec![TermId::from(1)]));
        db.add_clause(clause_of(vec![TermId::from(2)]));
        db.add_clause(clause_of(vec![TermId::from(3)]));
        db.add_clause(clause_of(vec![TermId::from(4)]));
        db.clauses[0].locked = true;
        db.bump_activity(2);
        db.bump_activity(2);
        db.bump_activity(3);

        db.reduce();

        // Top half by activity (IDs 2 and 3) plus the locked clause 0,
        // in their original relative order.
        let kept: Vec<Vec<TermId>> = db.clauses.iter().map(|c| c.literals.clone()).collect();
        assert_eq!(
            kept,
            vec![
                vec![TermId::from(1)],
                vec![TermId::from(3)],
                vec![TermId::from(4)],
            ]
        );
        assert_eq!(db.activity.len(), 3);
        assert_eq!(db.clause_map.get(&vec![TermId::from(3)]), Some(&1));
    }

    #[test]
    fn test_bump_activity_out_of_range_is_ignored() {
        let mut db = LearnedDatabase::new(0.95);
        db.bump_activity(7);
        assert!(db.activity.is_empty());
    }

    #[test]
    fn test_subsumption() {
        let a = vec![TermId::from(1), TermId::from(2)];
        let b = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        assert!(ClauseLearner::subsumes(&a, &b));
        assert!(!ClauseLearner::subsumes(&b, &a));
        // Identical clauses subsume each other.
        assert!(ClauseLearner::subsumes(&a, &a));
    }

    #[test]
    fn test_subsume_clauses_removes_supersets() {
        let mut learner = ClauseLearner::default();
        learner
            .learned_db
            .add_clause(clause_of(vec![TermId::from(1), TermId::from(2)]));
        learner.learned_db.add_clause(clause_of(vec![
            TermId::from(1),
            TermId::from(2),
            TermId::from(3),
        ]));
        learner.learned_db.add_clause(clause_of(vec![TermId::from(4)]));

        learner
            .subsume_clauses()
            .expect("subsumption does not fail");

        assert_eq!(learner.learned_db.clauses.len(), 2);
        assert_eq!(learner.learned_db.activity.len(), 2);
        assert_eq!(learner.stats().clauses_subsumed, 1);
    }

    #[test]
    fn test_lbd_computation() {
        let learner = ClauseLearner::default();

        let clause = vec![TermId::from(1), TermId::from(2), TermId::from(3)];
        let lbd = learner.compute_lbd(&clause);

        // LBD depends on decision levels, which are 0 by default
        assert_eq!(lbd, 1);
    }
}