1use crate::literal::{Lit, Var};
8#[allow(unused_imports)]
9use crate::prelude::*;
10
11#[derive(Debug, Clone)]
13pub enum ExtensionType {
14 And(Lit, Lit),
16 Or(Lit, Lit),
18 Xor(Lit, Lit),
20 Ite {
22 cond: Lit,
24 then_lit: Lit,
26 else_lit: Lit,
28 },
29 Equiv(Lit, Lit),
31 General {
33 positive: Vec<Lit>,
35 negative: Vec<Lit>,
37 },
38}
39
40#[derive(Debug, Clone)]
42pub struct Extension {
43 pub var: Var,
45 pub def: ExtensionType,
47}
48
49impl Extension {
50 pub fn new(var: Var, def: ExtensionType) -> Self {
52 Self { var, def }
53 }
54
55 pub fn to_cnf(&self) -> Vec<Vec<Lit>> {
57 let z = Lit::pos(self.var);
58 let nz = Lit::neg(self.var);
59
60 match &self.def {
61 ExtensionType::And(x, y) => {
62 vec![vec![nz, *x], vec![nz, *y], vec![x.negate(), y.negate(), z]]
65 }
66 ExtensionType::Or(x, y) => {
67 vec![vec![nz, *x, *y], vec![x.negate(), z], vec![y.negate(), z]]
70 }
71 ExtensionType::Xor(x, y) => {
72 vec![
76 vec![nz, x.negate(), y.negate()],
77 vec![nz, *x, *y],
78 vec![*x, y.negate(), z],
79 vec![x.negate(), *y, z],
80 ]
81 }
82 ExtensionType::Ite {
83 cond,
84 then_lit,
85 else_lit,
86 } => {
87 vec![
93 vec![cond.negate(), then_lit.negate(), z],
94 vec![*cond, else_lit.negate(), z],
95 vec![nz, cond.negate(), *then_lit],
96 vec![nz, *cond, *else_lit],
97 ]
98 }
99 ExtensionType::Equiv(x, y) => {
100 vec![
105 vec![nz, x.negate(), *y],
106 vec![nz, *x, y.negate()],
107 vec![*x, *y, z],
108 vec![x.negate(), y.negate(), z],
109 ]
110 }
111 ExtensionType::General { positive, negative } => {
112 let mut clauses = Vec::new();
115
116 let mut forward = vec![nz];
118 forward.extend(positive.iter().copied());
119 for &lit in negative {
120 forward.push(lit.negate());
121 }
122 clauses.push(forward);
123
124 for &lit in positive {
126 clauses.push(vec![lit.negate(), z]);
127 }
128 for &lit in negative {
129 clauses.push(vec![lit, z]);
130 }
131
132 clauses
133 }
134 }
135 }
136}
137
138pub struct ExtendedResolution {
140 extensions: HashMap<Var, Extension>,
142 next_var: u32,
144 base_num_vars: u32,
146}
147
148impl ExtendedResolution {
149 pub fn new(num_vars: u32) -> Self {
151 Self {
152 extensions: HashMap::new(),
153 next_var: num_vars,
154 base_num_vars: num_vars,
155 }
156 }
157
158 pub fn add_extension(&mut self, def: ExtensionType) -> Var {
160 let var = Var(self.next_var);
161 self.next_var += 1;
162 self.extensions.insert(var, Extension::new(var, def));
163 var
164 }
165
166 pub fn add_and(&mut self, x: Lit, y: Lit) -> Var {
168 self.add_extension(ExtensionType::And(x, y))
169 }
170
171 pub fn add_or(&mut self, x: Lit, y: Lit) -> Var {
173 self.add_extension(ExtensionType::Or(x, y))
174 }
175
176 pub fn add_xor(&mut self, x: Lit, y: Lit) -> Var {
178 self.add_extension(ExtensionType::Xor(x, y))
179 }
180
181 pub fn add_ite(&mut self, cond: Lit, then_lit: Lit, else_lit: Lit) -> Var {
183 self.add_extension(ExtensionType::Ite {
184 cond,
185 then_lit,
186 else_lit,
187 })
188 }
189
190 pub fn add_equiv(&mut self, x: Lit, y: Lit) -> Var {
192 self.add_extension(ExtensionType::Equiv(x, y))
193 }
194
195 pub fn get_all_cnf(&self) -> Vec<Vec<Lit>> {
197 let mut clauses = Vec::new();
198 for ext in self.extensions.values() {
199 clauses.extend(ext.to_cnf());
200 }
201 clauses
202 }
203
204 pub fn get_extension(&self, var: Var) -> Option<&Extension> {
206 self.extensions.get(&var)
207 }
208
209 pub fn is_extension(&self, var: Var) -> bool {
211 var.0 >= self.base_num_vars
212 }
213
214 pub fn num_vars(&self) -> u32 {
216 self.next_var
217 }
218
219 pub fn num_extensions(&self) -> usize {
221 self.extensions.len()
222 }
223
224 pub fn get_extensions(&self) -> Vec<Var> {
226 let mut vars: Vec<Var> = self.extensions.keys().copied().collect();
227 vars.sort_by_key(|v| v.0);
228 vars
229 }
230
231 pub fn tseitin_and(&mut self, lits: &[Lit]) -> Var {
234 if lits.is_empty() {
235 return self.add_extension(ExtensionType::General {
238 positive: vec![],
239 negative: vec![],
240 });
241 }
242 if lits.len() == 1 {
243 return lits[0].var();
244 }
245
246 let mid = lits.len() / 2;
248 let left = self.tseitin_and(&lits[..mid]);
249 let right = self.tseitin_and(&lits[mid..]);
250 self.add_and(Lit::pos(left), Lit::pos(right))
251 }
252
253 pub fn tseitin_or(&mut self, lits: &[Lit]) -> Var {
255 if lits.is_empty() {
256 return self.add_extension(ExtensionType::General {
258 positive: vec![],
259 negative: vec![],
260 });
261 }
262 if lits.len() == 1 {
263 return lits[0].var();
264 }
265
266 let mid = lits.len() / 2;
268 let left = self.tseitin_or(&lits[..mid]);
269 let right = self.tseitin_or(&lits[mid..]);
270 self.add_or(Lit::pos(left), Lit::pos(right))
271 }
272}
273
274pub struct ClauseSubstitution {
276 substitutions: HashMap<(Lit, Lit), Var>,
278}
279
280impl ClauseSubstitution {
281 pub fn new() -> Self {
283 Self {
284 substitutions: HashMap::new(),
285 }
286 }
287
288 pub fn add(&mut self, x: Lit, y: Lit, z: Var) {
290 self.substitutions.insert((x, y), z);
291 self.substitutions.insert((y, x), z);
292 }
293
294 pub fn get(&self, x: Lit, y: Lit) -> Option<Var> {
296 self.substitutions.get(&(x, y)).copied()
297 }
298}
299
300impl Default for ClauseSubstitution {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Positive literal over the variable with the given index.
    fn pos(index: u32) -> Lit {
        Lit::pos(Var(index))
    }

    /// Number of clauses generated when `def` defines variable `var`.
    fn clause_count(var: u32, def: ExtensionType) -> usize {
        Extension::new(Var(var), def).to_cnf().len()
    }

    #[test]
    fn test_and_extension() {
        assert_eq!(clause_count(2, ExtensionType::And(pos(0), pos(1))), 3);
    }

    #[test]
    fn test_or_extension() {
        assert_eq!(clause_count(2, ExtensionType::Or(pos(0), pos(1))), 3);
    }

    #[test]
    fn test_xor_extension() {
        assert_eq!(clause_count(2, ExtensionType::Xor(pos(0), pos(1))), 4);
    }

    #[test]
    fn test_ite_extension() {
        let def = ExtensionType::Ite {
            cond: pos(0),
            then_lit: pos(1),
            else_lit: pos(2),
        };
        assert_eq!(clause_count(3, def), 4);
    }

    #[test]
    fn test_extended_resolution_manager() {
        let mut er = ExtendedResolution::new(10);
        let (x, y) = (pos(0), pos(1));

        // Each gate allocates exactly one new extension variable.
        let z = er.add_and(x, y);
        assert!(er.is_extension(z));
        assert_eq!(er.num_extensions(), 1);

        let w = er.add_or(x, y);
        assert!(er.is_extension(w));
        assert_eq!(er.num_extensions(), 2);
    }

    #[test]
    fn test_tseitin_and() {
        let mut er = ExtendedResolution::new(10);
        let inputs: Vec<Lit> = (0..4).map(pos).collect();

        let top = er.tseitin_and(&inputs);

        assert!(er.is_extension(top));
        assert!(er.num_extensions() >= 1);
    }

    #[test]
    fn test_clause_substitution() {
        let mut subst = ClauseSubstitution::new();
        let (x, y, z) = (pos(0), pos(1), Var(2));

        subst.add(x, y, z);

        // Lookup must succeed regardless of argument order.
        assert_eq!(subst.get(x, y), Some(z));
        assert_eq!(subst.get(y, x), Some(z));
    }
}