1use crate::clause::{ClauseDatabase, ClauseId};
15use crate::literal::{LBool, Lit};
16#[allow(unused_imports)]
17use crate::prelude::*;
18use crate::trail::Trail;
19use crate::watched::WatchLists;
20
/// Counters describing what the distillation pass has done so far.
#[derive(Clone, Debug, Default)]
pub struct DistillationStats {
    /// Number of clauses from which distillation removed at least one literal.
    pub clauses_distilled: u64,
    /// Number of literal removals performed by distillation.
    pub literals_removed: u64,
    /// Number of clauses the distillation pass reports as deleted.
    pub clauses_deleted: u64,
}
31
32pub struct Distillation {
34 #[allow(dead_code)]
36 max_depth: u32,
37 stats: DistillationStats,
39}
40
41impl Distillation {
42 pub fn new(max_depth: u32) -> Self {
44 Self {
45 max_depth,
46 stats: DistillationStats::default(),
47 }
48 }
49
50 #[must_use]
52 pub fn stats(&self) -> &DistillationStats {
53 &self.stats
54 }
55
56 #[allow(clippy::too_many_arguments)]
60 pub fn distill_clause(
61 &mut self,
62 clause_id: ClauseId,
63 clauses: &mut ClauseDatabase,
64 trail: &mut Trail,
65 watches: &mut WatchLists,
66 assignment: &[LBool],
67 ) -> bool {
68 let Some(clause) = clauses.get(clause_id) else {
69 return false;
70 };
71
72 if clause.deleted || clause.len() <= 2 {
73 return false;
75 }
76
77 let original_lits: Vec<Lit> = clause.lits.iter().copied().collect();
78 let mut strengthened = false;
79
80 for &lit in &original_lits {
82 if assignment[lit.var().index()] != LBool::Undef {
84 continue;
85 }
86
87 if self.try_remove_literal(lit, clause_id, clauses, trail, watches, assignment) {
89 strengthened = true;
90 self.stats.literals_removed += 1;
91 }
92 }
93
94 if strengthened {
95 self.stats.clauses_distilled += 1;
96
97 if let Some(clause) = clauses.get(clause_id)
99 && clause.len() <= 1
100 {
101 self.stats.clauses_deleted += 1;
102 return true;
103 }
104 }
105
106 strengthened
107 }
108
109 #[allow(clippy::too_many_arguments)]
111 fn try_remove_literal(
112 &self,
113 lit: Lit,
114 clause_id: ClauseId,
115 clauses: &mut ClauseDatabase,
116 trail: &mut Trail,
117 _watches: &mut WatchLists,
118 assignment: &[LBool],
119 ) -> bool {
120 let saved_level = trail.decision_level();
122
123 trail.new_decision_level();
125
126 let _test_lit = lit.negate();
128
129 let Some(clause) = clauses.get(clause_id) else {
131 trail.backtrack_to(saved_level);
132 return false;
133 };
134
135 for &other_lit in &clause.lits {
136 if other_lit == lit {
137 continue;
138 }
139
140 if assignment[other_lit.var().index()] == LBool::from(!other_lit.sign()) {
142 trail.backtrack_to(saved_level);
143 return false;
144 }
145
146 if assignment[other_lit.var().index()] == LBool::from(other_lit.sign()) {
149 if let Some(clause) = clauses.get_mut(clause_id) {
151 clause.lits.retain(|l| *l != lit);
152 }
153 trail.backtrack_to(saved_level);
154 return true;
155 }
156 }
157
158 trail.backtrack_to(saved_level);
161
162 false
164 }
165
166 pub fn distill_all(
168 &mut self,
169 clauses: &mut ClauseDatabase,
170 trail: &mut Trail,
171 watches: &mut WatchLists,
172 assignment: &[LBool],
173 ) -> u64 {
174 let mut total_strengthened = 0;
175
176 let clause_ids: Vec<ClauseId> = clauses
178 .iter_ids()
179 .filter(|&id| {
180 if let Some(clause) = clauses.get(id) {
181 clause.learned && !clause.deleted
182 } else {
183 false
184 }
185 })
186 .collect();
187
188 for clause_id in clause_ids {
189 if self.distill_clause(clause_id, clauses, trail, watches, assignment) {
190 total_strengthened += 1;
191 }
192 }
193
194 total_strengthened
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Var;

    /// A freshly built distiller reports no work done.
    #[test]
    fn test_distillation_stats() {
        let distiller = Distillation::new(10);
        let DistillationStats {
            clauses_distilled,
            literals_removed,
            ..
        } = distiller.stats();
        assert_eq!(*clauses_distilled, 0);
        assert_eq!(*literals_removed, 0);
    }

    /// The constructor stores the requested depth bound.
    #[test]
    fn test_distillation_creation() {
        assert_eq!(Distillation::new(5).max_depth, 5);
    }

    /// Binary clauses are never distilled.
    #[test]
    fn test_distillation_binary_clause() {
        let mut distiller = Distillation::new(10);
        let mut db = ClauseDatabase::new();
        let mut trail = Trail::new(10);
        let mut watches = WatchLists::new(10);
        let assignment = vec![LBool::Undef; 10];

        let binary = [Lit::pos(Var::new(0)), Lit::neg(Var::new(1))];
        let clause_id = db.add_learned(binary);

        let changed =
            distiller.distill_clause(clause_id, &mut db, &mut trail, &mut watches, &assignment);
        assert!(!changed, "binary clauses must be left untouched");
    }
}