1#![allow(dead_code)] use oxiz_core::ast::{TermId, TermManager};
12use rustc_hash::{FxHashMap, FxHashSet};
13
/// Conflict-driven clause learner.
///
/// Bundles the state used during conflict analysis: the implication graph that is
/// queried for levels/reasons/decision flags, the database learned clauses are added
/// to, scratch state for recursive minimization, configuration, and statistics.
pub struct ClauseLearner {
    /// Implication graph consulted for decision levels, reasons and decision flags.
    impl_graph: ImplicationGraph,
    /// Store of learned clauses together with their activity scores.
    learned_db: LearnedDatabase,
    /// Scratch state used by recursive clause minimization.
    minimizer: ClauseMinimizer,
    /// Feature switches and tuning parameters.
    config: ClauseLearningConfig,
    /// Counters updated by the analysis and database maintenance passes.
    stats: ClauseLearningStats,
}
27
/// Implication graph recording, per variable, how and at which decision level it
/// was assigned.
#[derive(Debug, Clone)]
pub struct ImplicationGraph {
    /// Assignment record per variable, replaced on each `add_node` for that variable.
    nodes: FxHashMap<TermId, ImplicationNode>,
    /// Predecessor lists per variable.
    /// NOTE(review): only initialized in `new`; nothing in this file populates it.
    predecessors: FxHashMap<TermId, Vec<TermId>>,
    /// Decision level per variable, mirroring the `level` passed to `add_node`.
    levels: FxHashMap<TermId, usize>,
    /// Current decision level, updated through `set_level`.
    current_level: usize,
}
40
/// A single assignment recorded in the [`ImplicationGraph`].
#[derive(Debug, Clone)]
pub struct ImplicationNode {
    /// The assigned variable.
    pub var: TermId,
    /// The truth value it was assigned.
    pub value: bool,
    /// Decision level at which the assignment happened.
    pub level: usize,
    /// Clause that propagated this assignment, if any
    /// (presumably `None` for decisions — confirm with callers of `add_node`).
    pub reason: Option<ClauseId>,
    /// Whether the assignment was a decision rather than a propagation.
    pub is_decision: bool,
}
55
/// Integer clause identifier. Inside [`LearnedDatabase`] it is the clause's index in
/// the clause vector (see `add_clause`); other stores may use their own numbering.
pub type ClauseId = usize;
58
/// Store of learned clauses with VSIDS-style activity scores used to decide which
/// clauses survive a database reduction.
#[derive(Debug, Clone)]
pub struct LearnedDatabase {
    /// Learned clauses; a clause's position is its `ClauseId`.
    clauses: Vec<LearnedClause>,
    /// Activity per clause, kept parallel to `clauses` (same length, same order).
    activity: Vec<f64>,
    /// Index from a clause's literal vector to its position in `clauses`.
    clause_map: FxHashMap<Vec<TermId>, ClauseId>,
    /// Amount added to a clause's activity on each bump.
    bump_increment: f64,
    /// Divisor applied to `bump_increment` on every `decay` call.
    decay_factor: f64,
}
73
/// A clause produced by conflict analysis.
#[derive(Debug, Clone)]
pub struct LearnedClause {
    /// Literals of the clause.
    pub literals: Vec<TermId>,
    /// The literal at the conflict level that becomes unit after backjumping.
    pub asserting_lit: TermId,
    /// Decision level to backjump to after learning this clause.
    pub backtrack_level: usize,
    /// Activity score (initially 0.0 for clauses built by `analyze_conflict`).
    pub activity: f64,
    /// Locked clauses are never removed by subsumption or database reduction.
    pub locked: bool,
    /// Literal Block Distance: number of distinct decision levels among the literals.
    pub lbd: usize,
}
90
/// Scratch state for recursive learned-clause minimization.
#[derive(Debug, Clone)]
pub struct ClauseMinimizer {
    /// Literals of the clause currently being minimized.
    seen: FxHashSet<TermId>,
    /// Work stack for the redundancy search; cleared before each minimization.
    /// NOTE(review): nothing in this file pushes onto it yet.
    analyze_stack: Vec<TermId>,
    /// Memoized "is this literal removable" answers.
    /// NOTE(review): not read or written anywhere in this file besides `new`.
    cache: FxHashMap<TermId, bool>,
}
101
/// Switches and tuning parameters for [`ClauseLearner`].
#[derive(Debug, Clone)]
pub struct ClauseLearningConfig {
    /// Run (local) minimization on learned clauses.
    pub enable_minimization: bool,
    /// Additionally run recursive minimization; only consulted when
    /// `enable_minimization` is true.
    pub enable_recursive_minimization: bool,
    /// When false, `subsume_clauses` does nothing.
    pub enable_subsumption: bool,
    /// When false, `strengthen_clauses` returns immediately.
    pub enable_strengthening: bool,
    /// Maximum learned clause size.
    /// NOTE(review): not read by any code in this file.
    pub max_learned_size: usize,
    /// LBD threshold.
    /// NOTE(review): not read by any code in this file.
    pub lbd_threshold: usize,
    /// Decay factor handed to the learned-clause database.
    pub activity_decay: f64,
}
120
121impl Default for ClauseLearningConfig {
122 fn default() -> Self {
123 Self {
124 enable_minimization: true,
125 enable_recursive_minimization: true,
126 enable_subsumption: true,
127 enable_strengthening: true,
128 max_learned_size: 1000,
129 lbd_threshold: 5,
130 activity_decay: 0.95,
131 }
132 }
133}
134
/// Counters accumulated by [`ClauseLearner`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct ClauseLearningStats {
    /// Calls to `analyze_conflict`.
    pub conflicts_analyzed: usize,
    /// Clauses successfully learned and added to the database.
    pub clauses_learned: usize,
    /// Total literal count of learned clauses before minimization.
    pub literals_before_minimization: usize,
    /// Total literal count of learned clauses after minimization.
    pub literals_after_minimization: usize,
    /// Learned clauses removed because another clause subsumed them.
    pub clauses_subsumed: usize,
    /// Clauses strengthened (the strengthening pass does not update this yet).
    pub clauses_strengthened: usize,
    /// Completed UIP computations.
    pub uip_computations: usize,
    /// Calls to `reduce_database`.
    pub db_reductions: usize,
}
155
156impl ClauseLearner {
157 pub fn new(config: ClauseLearningConfig) -> Self {
159 Self {
160 impl_graph: ImplicationGraph::new(),
161 learned_db: LearnedDatabase::new(config.activity_decay),
162 minimizer: ClauseMinimizer::new(),
163 config,
164 stats: ClauseLearningStats::default(),
165 }
166 }
167
168 pub fn analyze_conflict(
170 &mut self,
171 conflict_clause: ClauseId,
172 _tm: &TermManager,
173 ) -> Result<LearnedClause, String> {
174 self.stats.conflicts_analyzed += 1;
175
176 let conflict_lits = self.get_clause_literals(conflict_clause)?;
178
179 let (learned_lits, asserting_lit, backtrack_level) =
181 self.compute_first_uip(&conflict_lits)?;
182
183 self.stats.uip_computations += 1;
184 self.stats.literals_before_minimization += learned_lits.len();
185
186 let minimized_lits = if self.config.enable_minimization {
188 self.minimize_clause(&learned_lits)?
189 } else {
190 learned_lits
191 };
192
193 self.stats.literals_after_minimization += minimized_lits.len();
194
195 let lbd = self.compute_lbd(&minimized_lits);
197
198 let learned = LearnedClause {
200 literals: minimized_lits,
201 asserting_lit,
202 backtrack_level,
203 activity: 0.0,
204 locked: false,
205 lbd,
206 };
207
208 self.stats.clauses_learned += 1;
209
210 self.learned_db.add_clause(learned.clone());
212
213 Ok(learned)
214 }
215
216 fn compute_first_uip(
218 &mut self,
219 conflict_lits: &[TermId],
220 ) -> Result<(Vec<TermId>, TermId, usize), String> {
221 let current_level = self.impl_graph.current_level;
222
223 let mut clause = conflict_lits.to_vec();
225 let mut seen = FxHashSet::default();
226 let mut counter = 0;
227
228 for &lit in &clause {
230 if self.impl_graph.get_level(lit) == current_level {
231 counter += 1;
232 }
233 seen.insert(lit);
234 }
235
236 let mut asserting_lit = TermId::from(0);
238
239 while counter > 1 {
240 let resolve_lit = clause
242 .iter()
243 .copied()
244 .find(|&lit| {
245 self.impl_graph.get_level(lit) == current_level
246 && !self.impl_graph.is_decision(lit)
247 })
248 .ok_or("No literal to resolve on")?;
249
250 let reason = self
252 .impl_graph
253 .get_reason(resolve_lit)
254 .ok_or("No reason for propagated literal")?;
255
256 let reason_lits = self.get_clause_literals(reason)?;
257
258 clause.retain(|&lit| lit != resolve_lit);
260 counter -= 1;
261
262 for &reason_lit in &reason_lits {
263 if reason_lit != resolve_lit && !seen.contains(&reason_lit) {
264 clause.push(reason_lit);
265 seen.insert(reason_lit);
266
267 if self.impl_graph.get_level(reason_lit) == current_level {
268 counter += 1;
269 }
270 }
271 }
272 }
273
274 for &lit in &clause {
276 if self.impl_graph.get_level(lit) == current_level {
277 asserting_lit = lit;
278 break;
279 }
280 }
281
282 let mut levels: Vec<usize> = clause
284 .iter()
285 .map(|&lit| self.impl_graph.get_level(lit))
286 .collect();
287 levels.sort_unstable();
288 levels.dedup();
289
290 let backtrack_level = if levels.len() > 1 {
291 levels[levels.len() - 2]
292 } else {
293 0
294 };
295
296 Ok((clause, asserting_lit, backtrack_level))
297 }
298
299 fn minimize_clause(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
301 if !self.config.enable_minimization {
302 return Ok(clause.to_vec());
303 }
304
305 let mut minimized = clause.to_vec();
306
307 minimized.retain(|&lit| !self.is_redundant(lit, clause));
309
310 if self.config.enable_recursive_minimization {
312 minimized = self.recursive_minimize(&minimized)?;
313 }
314
315 Ok(minimized)
316 }
317
318 fn is_redundant(&mut self, lit: TermId, clause: &[TermId]) -> bool {
320 if let Some(reason) = self.impl_graph.get_reason(lit)
322 && let Ok(reason_lits) = self.get_clause_literals(reason)
323 {
324 return reason_lits
325 .iter()
326 .all(|&r_lit| r_lit == lit || clause.contains(&r_lit));
327 }
328
329 false
330 }
331
332 fn recursive_minimize(&mut self, clause: &[TermId]) -> Result<Vec<TermId>, String> {
334 self.minimizer.seen.clear();
335 self.minimizer.analyze_stack.clear();
336
337 for &lit in clause {
339 self.minimizer.seen.insert(lit);
340 }
341
342 let mut minimized = Vec::new();
343
344 for &lit in clause {
345 if !self.minimizer.can_remove(lit, &self.impl_graph)? {
346 minimized.push(lit);
347 }
348 }
349
350 Ok(minimized)
351 }
352
353 fn compute_lbd(&self, clause: &[TermId]) -> usize {
355 let mut levels = FxHashSet::default();
356
357 for &lit in clause {
358 let level = self.impl_graph.get_level(lit);
359 levels.insert(level);
360 }
361
362 levels.len()
363 }
364
365 fn get_clause_literals(&self, _clause_id: ClauseId) -> Result<Vec<TermId>, String> {
367 Ok(vec![])
369 }
370
371 pub fn subsume_clauses(&mut self) -> Result<(), String> {
373 if !self.config.enable_subsumption {
374 return Ok(());
375 }
376
377 let mut to_remove = Vec::new();
378
379 for i in 0..self.learned_db.clauses.len() {
381 for j in (i + 1)..self.learned_db.clauses.len() {
382 if self.learned_db.clauses[i].locked || self.learned_db.clauses[j].locked {
383 continue;
384 }
385
386 let clause_i = &self.learned_db.clauses[i].literals;
387 let clause_j = &self.learned_db.clauses[j].literals;
388
389 if Self::subsumes(clause_i, clause_j) {
391 to_remove.push(j);
392 self.stats.clauses_subsumed += 1;
393 } else if Self::subsumes(clause_j, clause_i) {
394 to_remove.push(i);
395 self.stats.clauses_subsumed += 1;
396 break;
397 }
398 }
399 }
400
401 to_remove.sort_unstable();
403 to_remove.dedup();
404 for &idx in to_remove.iter().rev() {
405 self.learned_db.clauses.remove(idx);
406 self.learned_db.activity.remove(idx);
407 }
408
409 Ok(())
410 }
411
412 fn subsumes(a: &[TermId], b: &[TermId]) -> bool {
414 if a.len() > b.len() {
415 return false;
416 }
417
418 let b_set: FxHashSet<TermId> = b.iter().copied().collect();
419
420 a.iter().all(|lit| b_set.contains(lit))
421 }
422
423 pub fn strengthen_clauses(&mut self) -> Result<(), String> {
425 if !self.config.enable_strengthening {
426 return Ok(());
427 }
428
429 Ok(())
449 }
450
451 fn can_remove_literal(&self, _lit: TermId, _clause: &[TermId]) -> bool {
453 false
455 }
456
457 pub fn reduce_database(&mut self) -> Result<(), String> {
459 self.stats.db_reductions += 1;
460
461 self.learned_db.reduce();
463
464 Ok(())
465 }
466
467 pub fn bump_clause(&mut self, clause_id: ClauseId) {
469 self.learned_db.bump_activity(clause_id);
470 }
471
472 pub fn stats(&self) -> &ClauseLearningStats {
474 &self.stats
475 }
476}
477
478impl ImplicationGraph {
479 pub fn new() -> Self {
481 Self {
482 nodes: FxHashMap::default(),
483 predecessors: FxHashMap::default(),
484 levels: FxHashMap::default(),
485 current_level: 0,
486 }
487 }
488
489 pub fn add_node(
491 &mut self,
492 var: TermId,
493 value: bool,
494 level: usize,
495 reason: Option<ClauseId>,
496 is_decision: bool,
497 ) {
498 self.nodes.insert(
499 var,
500 ImplicationNode {
501 var,
502 value,
503 level,
504 reason,
505 is_decision,
506 },
507 );
508
509 self.levels.insert(var, level);
510 }
511
512 pub fn get_level(&self, var: TermId) -> usize {
514 self.levels.get(&var).copied().unwrap_or(0)
515 }
516
517 pub fn is_decision(&self, var: TermId) -> bool {
519 self.nodes.get(&var).is_some_and(|n| n.is_decision)
520 }
521
522 pub fn get_reason(&self, var: TermId) -> Option<ClauseId> {
524 self.nodes.get(&var).and_then(|n| n.reason)
525 }
526
527 pub fn set_level(&mut self, level: usize) {
529 self.current_level = level;
530 }
531}
532
533impl LearnedDatabase {
534 pub fn new(decay_factor: f64) -> Self {
536 Self {
537 clauses: Vec::new(),
538 activity: Vec::new(),
539 clause_map: FxHashMap::default(),
540 bump_increment: 1.0,
541 decay_factor,
542 }
543 }
544
545 pub fn add_clause(&mut self, clause: LearnedClause) {
547 let clause_id = self.clauses.len();
548
549 self.clause_map.insert(clause.literals.clone(), clause_id);
550 self.activity.push(clause.activity);
551 self.clauses.push(clause);
552 }
553
554 pub fn bump_activity(&mut self, clause_id: ClauseId) {
556 if clause_id < self.activity.len() {
557 self.activity[clause_id] += self.bump_increment;
558
559 if self.activity[clause_id] > 1e20 {
561 for act in &mut self.activity {
562 *act *= 1e-20;
563 }
564 self.bump_increment *= 1e-20;
565 }
566 }
567 }
568
569 pub fn decay(&mut self) {
571 self.bump_increment /= self.decay_factor;
572 }
573
574 pub fn reduce(&mut self) {
576 let mut sorted_indices: Vec<usize> = (0..self.clauses.len()).collect();
577
578 sorted_indices.sort_by(|&a, &b| {
580 self.activity[b]
581 .partial_cmp(&self.activity[a])
582 .unwrap_or(std::cmp::Ordering::Equal)
583 });
584
585 let keep_count = self.clauses.len() / 2;
587
588 let mut to_keep = FxHashSet::default();
589 for &idx in sorted_indices.iter().take(keep_count) {
590 to_keep.insert(idx);
591 }
592
593 for (idx, clause) in self.clauses.iter().enumerate() {
595 if clause.locked {
596 to_keep.insert(idx);
597 }
598 }
599
600 let mut new_clauses = Vec::new();
602 let mut new_activity = Vec::new();
603
604 for (idx, clause) in self.clauses.iter().enumerate() {
605 if to_keep.contains(&idx) {
606 new_clauses.push(clause.clone());
607 new_activity.push(self.activity[idx]);
608 }
609 }
610
611 self.clauses = new_clauses;
612 self.activity = new_activity;
613 self.clause_map.clear();
614
615 for (idx, clause) in self.clauses.iter().enumerate() {
617 self.clause_map.insert(clause.literals.clone(), idx);
618 }
619 }
620}
621
622impl ClauseMinimizer {
623 pub fn new() -> Self {
625 Self {
626 seen: FxHashSet::default(),
627 analyze_stack: Vec::new(),
628 cache: FxHashMap::default(),
629 }
630 }
631
632 fn can_remove(&mut self, _lit: TermId, _graph: &ImplicationGraph) -> Result<bool, String> {
634 Ok(false)
636 }
637}
638
639impl Default for ClauseLearner {
640 fn default() -> Self {
641 Self::new(ClauseLearningConfig::default())
642 }
643}
644
645impl Default for ImplicationGraph {
646 fn default() -> Self {
647 Self::new()
648 }
649}
650
651impl Default for ClauseMinimizer {
652 fn default() -> Self {
653 Self::new()
654 }
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built learner has analyzed no conflicts.
    #[test]
    fn test_clause_learner() {
        let fresh = ClauseLearner::default();
        let counters = &fresh.stats;
        assert_eq!(counters.conflicts_analyzed, 0);
    }

    /// A recorded decision reports its level and decision flag.
    #[test]
    fn test_implication_graph() {
        let decided = TermId::from(1);
        let mut graph = ImplicationGraph::new();

        graph.add_node(decided, true, 1, None, true);

        assert!(graph.is_decision(decided));
        assert_eq!(graph.get_level(decided), 1);
    }

    /// Adding one clause grows the database to one entry.
    #[test]
    fn test_learned_database() {
        let mut database = LearnedDatabase::new(0.95);
        let binary = LearnedClause {
            literals: vec![TermId::from(1), TermId::from(2)],
            asserting_lit: TermId::from(1),
            backtrack_level: 0,
            activity: 0.0,
            locked: false,
            lbd: 2,
        };

        database.add_clause(binary);

        assert_eq!(database.clauses.len(), 1);
    }

    /// A subset subsumes its superset, but not the other way round.
    #[test]
    fn test_subsumption() {
        let smaller = vec![TermId::from(1), TermId::from(2)];
        let larger = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        assert!(!ClauseLearner::subsumes(&larger, &smaller));
        assert!(ClauseLearner::subsumes(&smaller, &larger));
    }

    /// Literals unknown to the implication graph all sit at level 0, so the LBD is 1.
    #[test]
    fn test_lbd_computation() {
        let learner = ClauseLearner::default();
        let unassigned = vec![TermId::from(1), TermId::from(2), TermId::from(3)];

        assert_eq!(learner.compute_lbd(&unassigned), 1);
    }
}