1use std::collections::HashSet;
7use std::fmt;
8
/// A propositional literal: a variable index together with a polarity.
///
/// Displayed as the bare variable index for a positive literal and as the
/// index prefixed with `-` for a negative one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Literal {
    /// Index of the variable this literal refers to.
    pub var: u32,
    /// Polarity: `true` for the positive literal, `false` for the negated one.
    pub sign: bool,
}

impl Literal {
    /// Creates the positive literal of variable `var`.
    #[must_use]
    pub const fn pos(var: u32) -> Self {
        Self { var, sign: true }
    }

    /// Creates the negative literal of variable `var`.
    #[must_use]
    pub const fn neg(var: u32) -> Self {
        Self { var, sign: false }
    }

    /// Returns the literal over the same variable with the opposite polarity.
    #[must_use]
    pub const fn negate(self) -> Self {
        if self.sign {
            Self::neg(self.var)
        } else {
            Self::pos(self.var)
        }
    }

    /// Returns `true` when `other` is over the same variable but has the
    /// opposite polarity.
    #[must_use]
    pub const fn is_complementary(self, other: Self) -> bool {
        // XOR of the two signs is true exactly when the polarities differ.
        (self.sign ^ other.sign) && self.var == other.var
    }
}

impl fmt::Display for Literal {
    /// Writes the variable index, preceded by `-` for a negative literal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let prefix = if self.sign { "" } else { "-" };
        write!(f, "{}{}", prefix, self.var)
    }
}
56
57#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct Clause {
60 pub literals: Vec<Literal>,
62}
63
64impl Clause {
65 #[must_use]
67 pub fn new(literals: Vec<Literal>) -> Self {
68 Self { literals }
69 }
70
71 #[must_use]
73 pub const fn empty() -> Self {
74 Self {
75 literals: Vec::new(),
76 }
77 }
78
79 #[must_use]
81 pub fn unit(lit: Literal) -> Self {
82 Self {
83 literals: vec![lit],
84 }
85 }
86
87 #[must_use]
89 pub fn is_empty(&self) -> bool {
90 self.literals.is_empty()
91 }
92
93 #[must_use]
95 pub fn is_unit(&self) -> bool {
96 self.literals.len() == 1
97 }
98
99 #[must_use]
101 pub fn unit_literal(&self) -> Option<Literal> {
102 if self.is_unit() {
103 self.literals.first().copied()
104 } else {
105 None
106 }
107 }
108
109 #[must_use]
111 pub fn is_tautology(&self) -> bool {
112 let mut seen = HashSet::new();
113 for &lit in &self.literals {
114 if seen.contains(&lit.negate()) {
115 return true;
116 }
117 seen.insert(lit);
118 }
119 false
120 }
121
122 pub fn normalize(&mut self) {
124 let mut seen = HashSet::new();
125 self.literals.retain(|&lit| seen.insert(lit));
126 self.literals.sort_by_key(|l| (l.var, !l.sign));
127 }
128}
129
130impl fmt::Display for Clause {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 write!(f, "[")?;
133 for (i, lit) in self.literals.iter().enumerate() {
134 if i > 0 {
135 write!(f, " ∨ ")?;
136 }
137 write!(f, "{}", lit)?;
138 }
139 write!(f, "]")
140 }
141}
142
/// Outcome of checking a single proof step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RuleValidation {
    /// The step was accepted.
    Valid,
    /// The step was rejected; the payload is a human-readable reason.
    Invalid(String),
}

impl RuleValidation {
    /// Returns `true` for [`RuleValidation::Valid`].
    #[must_use]
    pub const fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Invalid(_) => false,
        }
    }

    /// Returns the rejection reason, or `None` when the step was accepted.
    #[must_use]
    pub fn error(&self) -> Option<&str> {
        if let Self::Invalid(reason) = self {
            Some(reason.as_str())
        } else {
            None
        }
    }
}
168
169pub struct ResolutionValidator;
171
172impl ResolutionValidator {
173 #[must_use]
177 pub fn validate(c1: &Clause, c2: &Clause, pivot: Literal, result: &Clause) -> RuleValidation {
178 let has_pivot_in_c1 = c1.literals.contains(&pivot);
180 let has_neg_pivot_in_c2 = c2.literals.contains(&pivot.negate());
181
182 if !has_pivot_in_c1 {
183 return RuleValidation::Invalid(format!("Pivot {} not found in first clause", pivot));
184 }
185
186 if !has_neg_pivot_in_c2 {
187 return RuleValidation::Invalid(format!(
188 "Negated pivot {} not found in second clause",
189 pivot.negate()
190 ));
191 }
192
193 let mut expected = Vec::new();
195 for &lit in &c1.literals {
196 if lit != pivot {
197 expected.push(lit);
198 }
199 }
200 for &lit in &c2.literals {
201 if lit != pivot.negate() {
202 expected.push(lit);
203 }
204 }
205
206 let mut expected_clause = Clause::new(expected);
208 expected_clause.normalize();
209
210 let mut result_normalized = result.clone();
211 result_normalized.normalize();
212
213 if expected_clause == result_normalized {
214 RuleValidation::Valid
215 } else {
216 RuleValidation::Invalid(format!(
217 "Expected resolvent {}, got {}",
218 expected_clause, result_normalized
219 ))
220 }
221 }
222}
223
224pub struct UnitPropagationValidator;
226
227impl UnitPropagationValidator {
228 #[must_use]
232 pub fn validate(clause: &Clause, unit: Literal, result: &Clause) -> RuleValidation {
233 let neg_unit = unit.negate();
235
236 let expected: Vec<Literal> = clause
238 .literals
239 .iter()
240 .copied()
241 .filter(|&lit| lit != neg_unit)
242 .collect();
243
244 if expected.len() == clause.literals.len() {
245 return RuleValidation::Invalid(format!(
246 "Unit literal {} not found in clause",
247 neg_unit
248 ));
249 }
250
251 let mut expected_clause = Clause::new(expected);
252 expected_clause.normalize();
253
254 let mut result_normalized = result.clone();
255 result_normalized.normalize();
256
257 if expected_clause == result_normalized {
258 RuleValidation::Valid
259 } else {
260 RuleValidation::Invalid(format!(
261 "Expected {}, got {}",
262 expected_clause, result_normalized
263 ))
264 }
265 }
266}
267
268pub struct CnfValidator;
270
271impl CnfValidator {
272 #[must_use]
276 pub fn validate_not_not(input: &str, output: &str) -> RuleValidation {
277 if input.starts_with("¬¬") && output == &input[4..] {
278 RuleValidation::Valid
279 } else {
280 RuleValidation::Invalid("Invalid ¬¬ elimination".to_string())
281 }
282 }
283
284 #[must_use]
288 pub fn validate_demorgan_and(_input: &str, _output: &str) -> RuleValidation {
289 RuleValidation::Valid
291 }
292
293 #[must_use]
297 pub fn validate_demorgan_or(_input: &str, _output: &str) -> RuleValidation {
298 RuleValidation::Valid
300 }
301
302 #[must_use]
306 pub fn validate_distributivity(_input: &str, _output: &str) -> RuleValidation {
307 RuleValidation::Valid
309 }
310}
311
312pub struct TheoryLemmaValidator;
314
315impl TheoryLemmaValidator {
316 #[must_use]
320 pub fn validate_farkas(
321 _inequalities: &[String],
322 _coefficients: &[f64],
323 _result: &str,
324 ) -> RuleValidation {
325 RuleValidation::Valid
327 }
328
329 #[must_use]
333 pub fn validate_congruence(_equalities: &[String], _result: &str) -> RuleValidation {
334 RuleValidation::Valid
336 }
337
338 #[must_use]
342 pub fn validate_transitivity(_eq1: &str, _eq2: &str, _result: &str) -> RuleValidation {
343 RuleValidation::Valid
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a clause from a slice of literals.
    fn clause_of(lits: &[Literal]) -> Clause {
        Clause::new(lits.to_vec())
    }

    #[test]
    fn test_literal_creation() {
        let positive = Literal::pos(5);
        assert_eq!(positive.var, 5);
        assert!(positive.sign);

        let negative = Literal::neg(5);
        assert_eq!(negative.var, 5);
        assert!(!negative.sign);
    }

    #[test]
    fn test_literal_negate() {
        let flipped = Literal::pos(3).negate();
        assert_eq!(flipped.var, 3);
        assert!(!flipped.sign);
    }

    #[test]
    fn test_literal_complementary() {
        let (x, not_x) = (Literal::pos(5), Literal::neg(5));
        assert!(x.is_complementary(not_x));
        assert!(not_x.is_complementary(x));

        // Same polarity on a different variable is not complementary.
        assert!(!x.is_complementary(Literal::pos(6)));
    }

    #[test]
    fn test_clause_empty() {
        let nothing = Clause::empty();
        assert!(nothing.is_empty());
        assert!(!nothing.is_unit());
    }

    #[test]
    fn test_clause_unit() {
        let lit = Literal::pos(1);
        let single = Clause::unit(lit);
        assert!(single.is_unit());
        assert_eq!(single.unit_literal(), Some(lit));
    }

    #[test]
    fn test_clause_tautology() {
        assert!(clause_of(&[Literal::pos(1), Literal::neg(1)]).is_tautology());
        assert!(!clause_of(&[Literal::pos(1), Literal::pos(2)]).is_tautology());
    }

    #[test]
    fn test_clause_normalize() {
        // The second `pos(2)` is a duplicate and must be dropped.
        let mut with_dup = clause_of(&[Literal::pos(2), Literal::pos(1), Literal::pos(2)]);
        with_dup.normalize();
        assert_eq!(with_dup.literals.len(), 2);
    }

    #[test]
    fn test_resolution_valid() {
        let left = clause_of(&[Literal::pos(1), Literal::pos(2)]);
        let right = clause_of(&[Literal::neg(1), Literal::pos(3)]);
        let resolvent = clause_of(&[Literal::pos(2), Literal::pos(3)]);

        let verdict = ResolutionValidator::validate(&left, &right, Literal::pos(1), &resolvent);
        assert!(verdict.is_valid());
    }

    #[test]
    fn test_resolution_invalid_pivot() {
        let left = clause_of(&[Literal::pos(1), Literal::pos(2)]);
        // The second premise lacks the negated pivot.
        let right = clause_of(&[Literal::neg(3), Literal::pos(4)]);
        let claimed = clause_of(&[Literal::pos(2), Literal::pos(4)]);

        let verdict = ResolutionValidator::validate(&left, &right, Literal::pos(1), &claimed);
        assert!(!verdict.is_valid());
    }

    #[test]
    fn test_unit_propagation_valid() {
        let before = clause_of(&[Literal::pos(1), Literal::pos(2), Literal::pos(3)]);
        let after = clause_of(&[Literal::pos(2), Literal::pos(3)]);

        let verdict = UnitPropagationValidator::validate(&before, Literal::neg(1), &after);
        assert!(verdict.is_valid());
    }

    #[test]
    fn test_unit_propagation_invalid() {
        let before = clause_of(&[Literal::pos(1), Literal::pos(2)]);
        // The negation of `neg(3)` does not occur in the clause.
        let after = clause_of(&[Literal::pos(1), Literal::pos(2)]);

        let verdict = UnitPropagationValidator::validate(&before, Literal::neg(3), &after);
        assert!(!verdict.is_valid());
    }

    #[test]
    fn test_cnf_not_not() {
        assert!(CnfValidator::validate_not_not("¬¬A", "A").is_valid());
        assert!(!CnfValidator::validate_not_not("¬A", "A").is_valid());
    }

    #[test]
    fn test_literal_display() {
        assert_eq!(Literal::pos(5).to_string(), "5");
        assert_eq!(Literal::neg(5).to_string(), "-5");
    }

    #[test]
    fn test_clause_display() {
        let shown = clause_of(&[Literal::pos(1), Literal::neg(2), Literal::pos(3)]).to_string();
        for needle in ["1", "-2", "3"].iter() {
            assert!(shown.contains(*needle));
        }
    }
}