1#![allow(dead_code)] #[allow(unused_imports)]
12use crate::prelude::*;
13use oxiz_core::ast::{TermId, TermManager};
14
15pub struct ClauseLearner {
17 impl_graph: ImplicationGraph,
19 learned_db: LearnedDatabase,
21 minimizer: ClauseMinimizer,
23 config: ClauseLearningConfig,
25 stats: ClauseLearningStats,
27}
28
29#[derive(Debug, Clone)]
31pub struct ImplicationGraph {
32 nodes: FxHashMap<TermId, ImplicationNode>,
34 predecessors: FxHashMap<TermId, Vec<TermId>>,
36 levels: FxHashMap<TermId, usize>,
38 current_level: usize,
40}
41
42#[derive(Debug, Clone)]
44pub struct ImplicationNode {
45 pub var: TermId,
47 pub value: bool,
49 pub level: usize,
51 pub reason: Option<ClauseId>,
53 pub is_decision: bool,
55}
56
/// Identifier of a clause.
///
/// Inside [`LearnedDatabase`] the id of a learned clause is its index in the
/// clause vector.
pub type ClauseId = usize;
59
60#[derive(Debug, Clone)]
62pub struct LearnedDatabase {
63 clauses: Vec<LearnedClause>,
65 activity: Vec<f64>,
67 clause_map: FxHashMap<Vec<TermId>, ClauseId>,
69 bump_increment: f64,
71 decay_factor: f64,
73}
74
75#[derive(Debug, Clone)]
77pub struct LearnedClause {
78 pub literals: Vec<TermId>,
80 pub asserting_lit: TermId,
82 pub backtrack_level: usize,
84 pub activity: f64,
86 pub locked: bool,
88 pub lbd: usize,
90}
91
92#[derive(Debug, Clone)]
94pub struct ClauseMinimizer {
95 seen: FxHashSet<TermId>,
97 analyze_stack: Vec<TermId>,
99 cache: FxHashMap<TermId, bool>,
101}
102
/// Switches and tuning parameters for [`ClauseLearner`].
#[derive(Debug, Clone)]
pub struct ClauseLearningConfig {
    /// Run clause minimization on freshly learned clauses.
    pub enable_minimization: bool,
    /// After the local minimization pass, also run the recursive pass.
    pub enable_recursive_minimization: bool,
    /// Allow `subsume_clauses` to remove subsumed learned clauses.
    pub enable_subsumption: bool,
    /// Allow `strengthen_clauses` to run.
    pub enable_strengthening: bool,
    /// Size limit for learned clauses.
    ///
    /// NOTE(review): not read anywhere in the visible code — confirm how and
    /// where it is enforced.
    pub max_learned_size: usize,
    /// LBD threshold.
    ///
    /// NOTE(review): not read anywhere in the visible code — confirm usage.
    pub lbd_threshold: usize,
    /// Decay factor handed to the learned-clause database on construction.
    pub activity_decay: f64,
}
121
122impl Default for ClauseLearningConfig {
123 fn default() -> Self {
124 Self {
125 enable_minimization: true,
126 enable_recursive_minimization: true,
127 enable_subsumption: true,
128 enable_strengthening: true,
129 max_learned_size: 1000,
130 lbd_threshold: 5,
131 activity_decay: 0.95,
132 }
133 }
134}
135
/// Running counters maintained by [`ClauseLearner`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct ClauseLearningStats {
    /// Number of calls to `analyze_conflict`.
    pub conflicts_analyzed: usize,
    /// Number of clauses produced by conflict analysis.
    pub clauses_learned: usize,
    /// Sum of learned-clause lengths before minimization.
    pub literals_before_minimization: usize,
    /// Sum of learned-clause lengths after minimization.
    pub literals_after_minimization: usize,
    /// Learned clauses recorded as subsumed by `subsume_clauses`.
    pub clauses_subsumed: usize,
    /// Clauses strengthened.
    ///
    /// NOTE(review): never updated in the visible code.
    pub clauses_strengthened: usize,
    /// Number of completed UIP computations.
    pub uip_computations: usize,
    /// Number of calls to `reduce_database`.
    pub db_reductions: usize,
}
156
157impl ClauseLearner {
158 pub fn new(config: ClauseLearningConfig) -> Self {
160 Self {
161 impl_graph: ImplicationGraph::new(),
162 learned_db: LearnedDatabase::new(config.activity_decay),
163 minimizer: ClauseMinimizer::new(),
164 config,
165 stats: ClauseLearningStats::default(),
166 }
167 }
168
169 pub fn analyze_conflict(
171 &mut self,
172 conflict_clause: ClauseId,
173 _tm: &TermManager,
174 ) -> Result<LearnedClause, String> {
175 self.stats.conflicts_analyzed += 1;
176
177 let conflict_lits = self.get_clause_literals(conflict_clause)?;
179
180 let (learned_lits, asserting_lit, backtrack_level) =
182 self.compute_first_uip(&conflict_lits)?;
183
184 self.stats.uip_computations += 1;
185 self.stats.literals_before_minimization += learned_lits.len();
186
187 let minimized_lits = if self.config.enable_minimization {
189 self.minimize_clause(&learned_lits)?
190 } else {
191 learned_lits
192 };
193
194 self.stats.literals_after_minimization += minimized_lits.len();
195
196 let lbd = self.compute_lbd(&minimized_lits);
198
199 let learned = LearnedClause {
201 literals: minimized_lits,
202 asserting_lit,
203 backtrack_level,
204 activity: 0.0,
205 locked: false,
206 lbd,
207 };
208
209 self.stats.clauses_learned += 1;
210
211 self.learned_db.add_clause(learned.clone());
213
214 Ok(learned)
215 }
216
217 fn compute_first_uip(
219 &mut self,
220 conflict_lits: &[TermId],
221 ) -> Result<(Vec<TermId>, TermId, usize), String> {
222 let current_level = self.impl_graph.current_level;
223
224 let mut clause = conflict_lits.to_vec();
226 let mut seen = FxHashSet::default();
227 let mut counter = 0;
228
229 for &lit in &clause {
231 if self.impl_graph.get_level(lit) == current_level {
232 counter += 1;
233 }
234 seen.insert(lit);
235 }
236
237 let mut asserting_lit = TermId::from(0);
239
240 while counter > 1 {
241 let resolve_lit = clause
243 .iter()
244 .copied()
245 .find(|&lit| {
246 self.impl_graph.get_level(lit) == current_level
247 && !self.impl_graph.is_decision(lit)
248 })
249 .ok_or("No literal to resolve on")?;
250
251 let reason = self
253 .impl_graph
254 .get_reason(resolve_lit)
255 .ok_or("No reason for propagated literal")?;
256
257 let reason_lits = self.get_clause_literals(reason)?;
258
259 clause.retain(|&lit| lit != resolve_lit);
261 counter -= 1;
262
263 for &reason_lit in &reason_lits {
264 if reason_lit != resolve_lit && !seen.contains(&reason_lit) {
265 clause.push(reason_lit);
266 seen.insert(reason_lit);
267
268 if self.impl_graph.get_level(reason_lit) == current_level {
269 counter += 1;
270 }
271 }
272 }
273 }
274
275 for &lit in &clause {
277 if self.impl_graph.get_level(lit) == current_level {
278 asserting_lit = lit;
279 break;
280 }
281 }
282
283 let mut levels: Vec<usize> = clause
285 .iter()
286 .map(|&lit| self.impl_graph.get_level(lit))
287 .collect();
288 levels.sort_unstable();
289 levels.dedup();
290
291 let backtrack_level = if levels.len() > 1 {
292 levels[levels.len() - 2]
293 } else {
294 0
295 };
296
297 Ok((clause, asserting_lit, backtrack_level))
298 }
299
300 fn minimize_clause(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
302 if !self.config.enable_minimization {
303 return Ok(clause.to_vec());
304 }
305
306 let mut minimized = clause.to_vec();
307
308 minimized.retain(|&lit| !self.is_redundant(lit, clause));
310
311 if self.config.enable_recursive_minimization {
313 minimized = self.recursive_minimize(&minimized)?;
314 }
315
316 Ok(minimized)
317 }
318
319 fn is_redundant(&mut self, lit: TermId, clause: &[TermId]) -> bool {
321 if let Some(reason) = self.impl_graph.get_reason(lit)
323 && let Ok(reason_lits) = self.get_clause_literals(reason)
324 {
325 return reason_lits
326 .iter()
327 .all(|&r_lit| r_lit == lit || clause.contains(&r_lit));
328 }
329
330 false
331 }
332
333 fn recursive_minimize(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
335 self.minimizer.seen.clear();
336 self.minimizer.analyze_stack.clear();
337
338 for &lit in clause {
340 self.minimizer.seen.insert(lit);
341 }
342
343 let mut minimized = Vec::new();
344
345 for &lit in clause {
346 if !self.minimizer.can_remove(lit, &self.impl_graph)? {
347 minimized.push(lit);
348 }
349 }
350
351 Ok(minimized)
352 }
353
354 fn compute_lbd(&self, clause: &[TermId]) -> usize {
356 let mut levels = FxHashSet::default();
357
358 for &lit in clause {
359 let level = self.impl_graph.get_level(lit);
360 levels.insert(level);
361 }
362
363 levels.len()
364 }
365
366 fn get_clause_literals(&self, _clause_id: ClauseId) -> Result<Vec<TermId>, String> {
368 Ok(vec![])
370 }
371
372 pub fn subsume_clauses(&mut self) -> Result<(), String> {
374 if !self.config.enable_subsumption {
375 return Ok(());
376 }
377
378 let mut to_remove = Vec::new();
379
380 for i in 0..self.learned_db.clauses.len() {
382 for j in (i + 1)..self.learned_db.clauses.len() {
383 if self.learned_db.clauses[i].locked || self.learned_db.clauses[j].locked {
384 continue;
385 }
386
387 let clause_i = &self.learned_db.clauses[i].literals;
388 let clause_j = &self.learned_db.clauses[j].literals;
389
390 if Self::subsumes(clause_i, clause_j) {
392 to_remove.push(j);
393 self.stats.clauses_subsumed += 1;
394 } else if Self::subsumes(clause_j, clause_i) {
395 to_remove.push(i);
396 self.stats.clauses_subsumed += 1;
397 break;
398 }
399 }
400 }
401
402 to_remove.sort_unstable();
404 to_remove.dedup();
405 for &idx in to_remove.iter().rev() {
406 self.learned_db.clauses.remove(idx);
407 self.learned_db.activity.remove(idx);
408 }
409
410 Ok(())
411 }
412
413 fn subsumes(a: &[TermId], b: &[TermId]) -> bool {
415 if a.len() > b.len() {
416 return false;
417 }
418
419 let b_set: FxHashSet<TermId> = b.iter().copied().collect();
420
421 a.iter().all(|lit| b_set.contains(lit))
422 }
423
424 pub fn strengthen_clauses(&mut self) -> Result<(), String> {
426 if !self.config.enable_strengthening {
427 return Ok(());
428 }
429
430 Ok(())
450 }
451
452 fn can_remove_literal(&self, _lit: TermId, _clause: &[TermId]) -> bool {
454 false
456 }
457
458 pub fn reduce_database(&mut self) -> Result<(), String> {
460 self.stats.db_reductions += 1;
461
462 self.learned_db.reduce();
464
465 Ok(())
466 }
467
468 pub fn bump_clause(&mut self, clause_id: ClauseId) {
470 self.learned_db.bump_activity(clause_id);
471 }
472
473 pub fn stats(&self) -> &ClauseLearningStats {
475 &self.stats
476 }
477}
478
479impl ImplicationGraph {
480 pub fn new() -> Self {
482 Self {
483 nodes: FxHashMap::default(),
484 predecessors: FxHashMap::default(),
485 levels: FxHashMap::default(),
486 current_level: 0,
487 }
488 }
489
490 pub fn add_node(
492 &mut self,
493 var: TermId,
494 value: bool,
495 level: usize,
496 reason: Option<ClauseId>,
497 is_decision: bool,
498 ) {
499 self.nodes.insert(
500 var,
501 ImplicationNode {
502 var,
503 value,
504 level,
505 reason,
506 is_decision,
507 },
508 );
509
510 self.levels.insert(var, level);
511 }
512
513 pub fn get_level(&self, var: TermId) -> usize {
515 self.levels.get(&var).copied().unwrap_or(0)
516 }
517
518 pub fn is_decision(&self, var: TermId) -> bool {
520 self.nodes.get(&var).is_some_and(|n| n.is_decision)
521 }
522
523 pub fn get_reason(&self, var: TermId) -> Option<ClauseId> {
525 self.nodes.get(&var).and_then(|n| n.reason)
526 }
527
528 pub fn set_level(&mut self, level: usize) {
530 self.current_level = level;
531 }
532}
533
534impl LearnedDatabase {
535 pub fn new(decay_factor: f64) -> Self {
537 Self {
538 clauses: Vec::new(),
539 activity: Vec::new(),
540 clause_map: FxHashMap::default(),
541 bump_increment: 1.0,
542 decay_factor,
543 }
544 }
545
546 pub fn add_clause(&mut self, clause: LearnedClause) {
548 let clause_id = self.clauses.len();
549
550 self.clause_map.insert(clause.literals.clone(), clause_id);
551 self.activity.push(clause.activity);
552 self.clauses.push(clause);
553 }
554
555 pub fn bump_activity(&mut self, clause_id: ClauseId) {
557 if clause_id < self.activity.len() {
558 self.activity[clause_id] += self.bump_increment;
559
560 if self.activity[clause_id] > 1e20 {
562 for act in &mut self.activity {
563 *act *= 1e-20;
564 }
565 self.bump_increment *= 1e-20;
566 }
567 }
568 }
569
570 pub fn decay(&mut self) {
572 self.bump_increment /= self.decay_factor;
573 }
574
575 pub fn reduce(&mut self) {
577 let mut sorted_indices: Vec<usize> = (0..self.clauses.len()).collect();
578
579 sorted_indices.sort_by(|&a, &b| {
581 self.activity[b]
582 .partial_cmp(&self.activity[a])
583 .unwrap_or(core::cmp::Ordering::Equal)
584 });
585
586 let keep_count = self.clauses.len() / 2;
588
589 let mut to_keep = FxHashSet::default();
590 for &idx in sorted_indices.iter().take(keep_count) {
591 to_keep.insert(idx);
592 }
593
594 for (idx, clause) in self.clauses.iter().enumerate() {
596 if clause.locked {
597 to_keep.insert(idx);
598 }
599 }
600
601 let mut new_clauses = Vec::new();
603 let mut new_activity = Vec::new();
604
605 for (idx, clause) in self.clauses.iter().enumerate() {
606 if to_keep.contains(&idx) {
607 new_clauses.push(clause.clone());
608 new_activity.push(self.activity[idx]);
609 }
610 }
611
612 self.clauses = new_clauses;
613 self.activity = new_activity;
614 self.clause_map.clear();
615
616 for (idx, clause) in self.clauses.iter().enumerate() {
618 self.clause_map.insert(clause.literals.clone(), idx);
619 }
620 }
621}
622
623impl ClauseMinimizer {
624 pub fn new() -> Self {
626 Self {
627 seen: FxHashSet::default(),
628 analyze_stack: Vec::new(),
629 cache: FxHashMap::default(),
630 }
631 }
632
633 fn can_remove(&mut self, _lit: TermId, _graph: &ImplicationGraph) -> Result<bool, String> {
635 Ok(false)
637 }
638}
639
640impl Default for ClauseLearner {
641 fn default() -> Self {
642 Self::new(ClauseLearningConfig::default())
643 }
644}
645
646impl Default for ImplicationGraph {
647 fn default() -> Self {
648 Self::new()
649 }
650}
651
652impl Default for ClauseMinimizer {
653 fn default() -> Self {
654 Self::new()
655 }
656}
657
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built learner has not analysed any conflict.
    #[test]
    fn test_clause_learner() {
        let fresh = ClauseLearner::default();
        assert_eq!(fresh.stats.conflicts_analyzed, 0);
    }

    /// A node added as a decision reports its level and decision flag.
    #[test]
    fn test_implication_graph() {
        let var = TermId::from(1);

        let mut graph = ImplicationGraph::new();
        graph.add_node(var, true, 1, None, true);

        assert_eq!(graph.get_level(var), 1);
        assert!(graph.is_decision(var));
    }

    /// Adding one clause leaves exactly one clause in the database.
    #[test]
    fn test_learned_database() {
        let first = TermId::from(1);
        let second = TermId::from(2);
        let clause = LearnedClause {
            literals: vec![first, second],
            asserting_lit: first,
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 2,
        };

        let mut db = LearnedDatabase::new(0.95);
        db.add_clause(clause);

        assert_eq!(db.clauses.len(), 1);
    }

    /// A clause subsumes its supersets but not the other way round.
    #[test]
    fn test_subsumption() {
        let small = vec![TermId::from(1), TermId::from(2)];
        let large = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        assert!(ClauseLearner::subsumes(&small, &large));
        assert!(!ClauseLearner::subsumes(&large, &small));
    }

    /// Literals unknown to the graph all fall on level 0, giving an LBD of 1.
    #[test]
    fn test_lbd_computation() {
        let learner = ClauseLearner::default();
        let clause = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        assert_eq!(learner.compute_lbd(&clause), 1);
    }
}