1#[allow(unused_imports)]
11use crate::prelude::*;
12use oxiz_core::ast::{TermId, TermManager};
13
14pub struct ClauseLearner {
16 impl_graph: ImplicationGraph,
18 learned_db: LearnedDatabase,
20 minimizer: ClauseMinimizer,
22 config: ClauseLearningConfig,
24 stats: ClauseLearningStats,
26}
27
28#[derive(Debug, Clone)]
30pub struct ImplicationGraph {
31 nodes: FxHashMap<TermId, ImplicationNode>,
33 #[allow(dead_code)]
35 predecessors: FxHashMap<TermId, Vec<TermId>>,
36 levels: FxHashMap<TermId, usize>,
38 current_level: usize,
40}
41
42#[derive(Debug, Clone)]
44pub struct ImplicationNode {
45 pub var: TermId,
47 pub value: bool,
49 pub level: usize,
51 pub reason: Option<ClauseId>,
53 pub is_decision: bool,
55}
56
57pub type ClauseId = usize;
59
60#[derive(Debug, Clone)]
62pub struct LearnedDatabase {
63 clauses: Vec<LearnedClause>,
65 activity: Vec<f64>,
67 clause_map: FxHashMap<Vec<TermId>, ClauseId>,
69 bump_increment: f64,
71 decay_factor: f64,
73}
74
75#[derive(Debug, Clone)]
77pub struct LearnedClause {
78 pub literals: Vec<TermId>,
80 pub asserting_lit: TermId,
82 pub backtrack_level: usize,
84 pub activity: f64,
86 pub locked: bool,
88 pub lbd: usize,
90}
91
92#[derive(Debug, Clone)]
94pub struct ClauseMinimizer {
95 seen: FxHashSet<TermId>,
97 analyze_stack: Vec<TermId>,
99 #[allow(dead_code)]
101 cache: FxHashMap<TermId, bool>,
102}
103
/// Feature switches and tuning parameters for [`ClauseLearner`].
#[derive(Clone, Debug)]
pub struct ClauseLearningConfig {
    /// Remove redundant literals from freshly learned clauses.
    pub enable_minimization: bool,
    /// Additionally run recursive minimization after local minimization.
    pub enable_recursive_minimization: bool,
    /// Allow `subsume_clauses` to delete subsumed learned clauses.
    pub enable_subsumption: bool,
    /// Allow `strengthen_clauses` to drop removable literals.
    pub enable_strengthening: bool,
    /// Intended cap on learned-clause size (not consulted within this module).
    pub max_learned_size: usize,
    /// Intended LBD cutoff (not consulted within this module).
    pub lbd_threshold: usize,
    /// Decay factor handed to the learned-clause database.
    pub activity_decay: f64,
}

impl Default for ClauseLearningConfig {
    /// Every feature enabled, size cap 1000, LBD threshold 5, decay 0.95.
    fn default() -> Self {
        ClauseLearningConfig {
            activity_decay: 0.95,
            lbd_threshold: 5,
            max_learned_size: 1000,
            enable_strengthening: true,
            enable_subsumption: true,
            enable_recursive_minimization: true,
            enable_minimization: true,
        }
    }
}
136
/// Counters describing the work done by a [`ClauseLearner`].
#[derive(Clone, Debug, Default)]
pub struct ClauseLearningStats {
    /// Calls to `analyze_conflict`.
    pub conflicts_analyzed: usize,
    /// Clauses appended to the learned database.
    pub clauses_learned: usize,
    /// Total literal count of learned clauses before minimization.
    pub literals_before_minimization: usize,
    /// Total literal count of learned clauses after minimization.
    pub literals_after_minimization: usize,
    /// Learned clauses removed by subsumption.
    pub clauses_subsumed: usize,
    /// Learned clauses shortened by strengthening.
    pub clauses_strengthened: usize,
    /// Successful first-UIP computations.
    pub uip_computations: usize,
    /// Calls to `reduce_database`.
    pub db_reductions: usize,
}
157
158impl ClauseLearner {
159 pub fn new(config: ClauseLearningConfig) -> Self {
161 Self {
162 impl_graph: ImplicationGraph::new(),
163 learned_db: LearnedDatabase::new(config.activity_decay),
164 minimizer: ClauseMinimizer::new(),
165 config,
166 stats: ClauseLearningStats::default(),
167 }
168 }
169
170 pub fn analyze_conflict(
172 &mut self,
173 conflict_clause: ClauseId,
174 _tm: &TermManager,
175 ) -> Result<LearnedClause, String> {
176 self.stats.conflicts_analyzed += 1;
177
178 let conflict_lits = self.get_clause_literals(conflict_clause)?;
180
181 let (learned_lits, asserting_lit, backtrack_level) =
183 self.compute_first_uip(&conflict_lits)?;
184
185 self.stats.uip_computations += 1;
186 self.stats.literals_before_minimization += learned_lits.len();
187
188 let minimized_lits = if self.config.enable_minimization {
190 self.minimize_clause(&learned_lits)?
191 } else {
192 learned_lits
193 };
194
195 self.stats.literals_after_minimization += minimized_lits.len();
196
197 let lbd = self.compute_lbd(&minimized_lits);
199
200 let learned = LearnedClause {
202 literals: minimized_lits,
203 asserting_lit,
204 backtrack_level,
205 activity: 0.0,
206 locked: false,
207 lbd,
208 };
209
210 self.stats.clauses_learned += 1;
211
212 self.learned_db.add_clause(learned.clone());
214
215 Ok(learned)
216 }
217
218 fn compute_first_uip(
220 &mut self,
221 conflict_lits: &[TermId],
222 ) -> Result<(Vec<TermId>, TermId, usize), String> {
223 let current_level = self.impl_graph.current_level;
224
225 let mut clause = conflict_lits.to_vec();
227 let mut seen = FxHashSet::default();
228 let mut counter = 0;
229
230 for &lit in &clause {
232 if self.impl_graph.get_level(lit) == current_level {
233 counter += 1;
234 }
235 seen.insert(lit);
236 }
237
238 let mut asserting_lit = TermId::from(0);
240
241 while counter > 1 {
242 let resolve_lit = clause
244 .iter()
245 .copied()
246 .find(|&lit| {
247 self.impl_graph.get_level(lit) == current_level
248 && !self.impl_graph.is_decision(lit)
249 })
250 .ok_or("No literal to resolve on")?;
251
252 let reason = self
254 .impl_graph
255 .get_reason(resolve_lit)
256 .ok_or("No reason for propagated literal")?;
257
258 let reason_lits = self.get_clause_literals(reason)?;
259
260 clause.retain(|&lit| lit != resolve_lit);
262 counter -= 1;
263
264 for &reason_lit in &reason_lits {
265 if reason_lit != resolve_lit && !seen.contains(&reason_lit) {
266 clause.push(reason_lit);
267 seen.insert(reason_lit);
268
269 if self.impl_graph.get_level(reason_lit) == current_level {
270 counter += 1;
271 }
272 }
273 }
274 }
275
276 for &lit in &clause {
278 if self.impl_graph.get_level(lit) == current_level {
279 asserting_lit = lit;
280 break;
281 }
282 }
283
284 let mut levels: Vec<usize> = clause
286 .iter()
287 .map(|&lit| self.impl_graph.get_level(lit))
288 .collect();
289 levels.sort_unstable();
290 levels.dedup();
291
292 let backtrack_level = if levels.len() > 1 {
293 levels[levels.len() - 2]
294 } else {
295 0
296 };
297
298 Ok((clause, asserting_lit, backtrack_level))
299 }
300
301 fn minimize_clause(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
303 if !self.config.enable_minimization {
304 return Ok(clause.to_vec());
305 }
306
307 let mut minimized = clause.to_vec();
308
309 minimized.retain(|&lit| !self.is_redundant(lit, clause));
311
312 if self.config.enable_recursive_minimization {
314 minimized = self.recursive_minimize(&minimized)?;
315 }
316
317 Ok(minimized)
318 }
319
320 fn is_redundant(&mut self, lit: TermId, clause: &[TermId]) -> bool {
322 if let Some(reason) = self.impl_graph.get_reason(lit)
324 && let Ok(reason_lits) = self.get_clause_literals(reason)
325 {
326 return reason_lits
327 .iter()
328 .all(|&r_lit| r_lit == lit || clause.contains(&r_lit));
329 }
330
331 false
332 }
333
334 fn recursive_minimize(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
336 self.minimizer.seen.clear();
337 self.minimizer.analyze_stack.clear();
338
339 for &lit in clause {
341 self.minimizer.seen.insert(lit);
342 }
343
344 let mut minimized = Vec::new();
345
346 for &lit in clause {
347 if !self.minimizer.can_remove(lit, &self.impl_graph)? {
348 minimized.push(lit);
349 }
350 }
351
352 Ok(minimized)
353 }
354
355 fn compute_lbd(&self, clause: &[TermId]) -> usize {
357 let mut levels = FxHashSet::default();
358
359 for &lit in clause {
360 let level = self.impl_graph.get_level(lit);
361 levels.insert(level);
362 }
363
364 levels.len()
365 }
366
367 fn get_clause_literals(&self, _clause_id: ClauseId) -> Result<Vec<TermId>, String> {
369 Ok(vec![])
371 }
372
373 pub fn subsume_clauses(&mut self) -> Result<(), String> {
375 if !self.config.enable_subsumption {
376 return Ok(());
377 }
378
379 let mut to_remove = Vec::new();
380
381 for i in 0..self.learned_db.clauses.len() {
383 for j in (i + 1)..self.learned_db.clauses.len() {
384 if self.learned_db.clauses[i].locked || self.learned_db.clauses[j].locked {
385 continue;
386 }
387
388 let clause_i = &self.learned_db.clauses[i].literals;
389 let clause_j = &self.learned_db.clauses[j].literals;
390
391 if Self::subsumes(clause_i, clause_j) {
393 to_remove.push(j);
394 self.stats.clauses_subsumed += 1;
395 } else if Self::subsumes(clause_j, clause_i) {
396 to_remove.push(i);
397 self.stats.clauses_subsumed += 1;
398 break;
399 }
400 }
401 }
402
403 to_remove.sort_unstable();
405 to_remove.dedup();
406 for &idx in to_remove.iter().rev() {
407 self.learned_db.clauses.remove(idx);
408 self.learned_db.activity.remove(idx);
409 }
410
411 Ok(())
412 }
413
414 fn subsumes(a: &[TermId], b: &[TermId]) -> bool {
416 if a.len() > b.len() {
417 return false;
418 }
419
420 let b_set: FxHashSet<TermId> = b.iter().copied().collect();
421
422 a.iter().all(|lit| b_set.contains(lit))
423 }
424
425 pub fn strengthen_clauses(&mut self) -> Result<(), String> {
427 if !self.config.enable_strengthening {
428 return Ok(());
429 }
430
431 for idx in 0..self.learned_db.clauses.len() {
432 if self.learned_db.clauses[idx].locked {
433 continue;
434 }
435
436 let original_literals = self.learned_db.clauses[idx].literals.clone();
439 let original_len = original_literals.len();
440
441 let mut new_literals = Vec::with_capacity(original_len);
442 for &lit in &original_literals {
443 if !self.can_remove_literal(lit, &original_literals) {
444 new_literals.push(lit);
445 }
446 }
447
448 if new_literals.len() < original_len {
449 self.learned_db.clauses[idx].literals = new_literals;
450 self.stats.clauses_strengthened += 1;
451 }
452 }
453
454 Ok(())
455 }
456
457 fn can_remove_literal(&self, lit: TermId, clause: &[TermId]) -> bool {
467 let mut visited = FxHashSet::default();
468 self.can_remove_literal_rec(lit, clause, &mut visited)
469 }
470
471 fn can_remove_literal_rec(
472 &self,
473 lit: TermId,
474 clause: &[TermId],
475 visited: &mut FxHashSet<TermId>,
476 ) -> bool {
477 if !visited.insert(lit) {
478 return true;
480 }
481
482 let reason_id = match self.impl_graph.get_reason(lit) {
484 Some(r) => r,
485 None => return false,
486 };
487
488 let reason_lits = match self.get_clause_literals(reason_id) {
493 Ok(lits) => lits,
494 Err(_) => return false,
495 };
496
497 for other_lit in &reason_lits {
500 if *other_lit == lit {
501 continue;
502 }
503 if !clause.contains(other_lit)
504 && !self.can_remove_literal_rec(*other_lit, clause, visited)
505 {
506 return false;
507 }
508 }
509
510 true
511 }
512
513 pub fn reduce_database(&mut self) -> Result<(), String> {
515 self.stats.db_reductions += 1;
516
517 self.learned_db.reduce();
519
520 Ok(())
521 }
522
523 pub fn bump_clause(&mut self, clause_id: ClauseId) {
525 self.learned_db.bump_activity(clause_id);
526 }
527
528 pub fn stats(&self) -> &ClauseLearningStats {
530 &self.stats
531 }
532}
533
534impl ImplicationGraph {
535 pub fn new() -> Self {
537 Self {
538 nodes: FxHashMap::default(),
539 predecessors: FxHashMap::default(),
540 levels: FxHashMap::default(),
541 current_level: 0,
542 }
543 }
544
545 pub fn add_node(
547 &mut self,
548 var: TermId,
549 value: bool,
550 level: usize,
551 reason: Option<ClauseId>,
552 is_decision: bool,
553 ) {
554 self.nodes.insert(
555 var,
556 ImplicationNode {
557 var,
558 value,
559 level,
560 reason,
561 is_decision,
562 },
563 );
564
565 self.levels.insert(var, level);
566 }
567
568 pub fn get_level(&self, var: TermId) -> usize {
570 self.levels.get(&var).copied().unwrap_or(0)
571 }
572
573 pub fn is_decision(&self, var: TermId) -> bool {
575 self.nodes.get(&var).is_some_and(|n| n.is_decision)
576 }
577
578 pub fn get_reason(&self, var: TermId) -> Option<ClauseId> {
580 self.nodes.get(&var).and_then(|n| n.reason)
581 }
582
583 pub fn set_level(&mut self, level: usize) {
585 self.current_level = level;
586 }
587}
588
589impl LearnedDatabase {
590 pub fn new(decay_factor: f64) -> Self {
592 Self {
593 clauses: Vec::new(),
594 activity: Vec::new(),
595 clause_map: FxHashMap::default(),
596 bump_increment: 1.0,
597 decay_factor,
598 }
599 }
600
601 pub fn add_clause(&mut self, clause: LearnedClause) {
603 let clause_id = self.clauses.len();
604
605 self.clause_map.insert(clause.literals.clone(), clause_id);
606 self.activity.push(clause.activity);
607 self.clauses.push(clause);
608 }
609
610 pub fn bump_activity(&mut self, clause_id: ClauseId) {
612 if clause_id < self.activity.len() {
613 self.activity[clause_id] += self.bump_increment;
614
615 if self.activity[clause_id] > 1e20 {
617 for act in &mut self.activity {
618 *act *= 1e-20;
619 }
620 self.bump_increment *= 1e-20;
621 }
622 }
623 }
624
625 pub fn decay(&mut self) {
627 self.bump_increment /= self.decay_factor;
628 }
629
630 pub fn reduce(&mut self) {
632 let mut sorted_indices: Vec<usize> = (0..self.clauses.len()).collect();
633
634 sorted_indices.sort_by(|&a, &b| {
636 self.activity[b]
637 .partial_cmp(&self.activity[a])
638 .unwrap_or(core::cmp::Ordering::Equal)
639 });
640
641 let keep_count = self.clauses.len() / 2;
643
644 let mut to_keep = FxHashSet::default();
645 for &idx in sorted_indices.iter().take(keep_count) {
646 to_keep.insert(idx);
647 }
648
649 for (idx, clause) in self.clauses.iter().enumerate() {
651 if clause.locked {
652 to_keep.insert(idx);
653 }
654 }
655
656 let mut new_clauses = Vec::new();
658 let mut new_activity = Vec::new();
659
660 for (idx, clause) in self.clauses.iter().enumerate() {
661 if to_keep.contains(&idx) {
662 new_clauses.push(clause.clone());
663 new_activity.push(self.activity[idx]);
664 }
665 }
666
667 self.clauses = new_clauses;
668 self.activity = new_activity;
669 self.clause_map.clear();
670
671 for (idx, clause) in self.clauses.iter().enumerate() {
673 self.clause_map.insert(clause.literals.clone(), idx);
674 }
675 }
676}
677
678impl ClauseMinimizer {
679 pub fn new() -> Self {
681 Self {
682 seen: FxHashSet::default(),
683 analyze_stack: Vec::new(),
684 cache: FxHashMap::default(),
685 }
686 }
687
688 fn can_remove(&mut self, _lit: TermId, _graph: &ImplicationGraph) -> Result<bool, String> {
690 Ok(false)
692 }
693}
694
695impl Default for ClauseLearner {
696 fn default() -> Self {
697 Self::new(ClauseLearningConfig::default())
698 }
699}
700
701impl Default for ImplicationGraph {
702 fn default() -> Self {
703 Self::new()
704 }
705}
706
707impl Default for ClauseMinimizer {
708 fn default() -> Self {
709 Self::new()
710 }
711}
712
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clause_learner() {
        let learner = ClauseLearner::default();
        let analyzed = learner.stats.conflicts_analyzed;
        assert_eq!(analyzed, 0);
    }

    #[test]
    fn test_implication_graph() {
        let decision_var = TermId::from(1);
        let mut graph = ImplicationGraph::new();
        graph.add_node(decision_var, true, 1, None, true);

        assert_eq!(graph.get_level(decision_var), 1);
        assert!(graph.is_decision(decision_var));
    }

    #[test]
    fn test_learned_database() {
        let mut db = LearnedDatabase::new(0.95);
        let clause = LearnedClause {
            literals: vec![TermId::from(1), TermId::from(2)],
            asserting_lit: TermId::from(1),
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 2,
        };

        db.add_clause(clause);

        assert_eq!(db.clauses.len(), 1);
    }

    #[test]
    fn test_subsumption() {
        let smaller = [TermId::from(1), TermId::from(2)];
        let larger = [TermId::from(1), TermId::from(2), TermId::from(3)];

        assert!(ClauseLearner::subsumes(&smaller, &larger));
        assert!(!ClauseLearner::subsumes(&larger, &smaller));
    }

    #[test]
    fn test_lbd_computation() {
        let learner = ClauseLearner::default();
        // No node was added, so every literal reports level 0: one distinct level.
        let literals = [TermId::from(1), TermId::from(2), TermId::from(3)];

        assert_eq!(learner.compute_lbd(&literals), 1);
    }
}