Skip to main content

oxiz_proof/
rules.rs

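//! Proof rule definitions and validators.
//!
//! This module provides validation logic for standard proof rules used in SMT solving,
//! including resolution, unit propagation, CNF transformation, and theory-specific rules.
//! Resolution, unit propagation and double-negation elimination are checked concretely;
//! the remaining CNF and theory validators are currently placeholders that perform
//! little or no checking.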
5
6use std::collections::HashSet;
7use std::fmt;
8
/// A literal in a clause: a variable index paired with a polarity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Literal {
    /// Variable index
    pub var: u32,
    /// `true` for the positive literal `x`, `false` for the negated literal `¬x`
    pub sign: bool,
}

impl Literal {
    /// Build the positive literal for variable `var`.
    #[must_use]
    pub const fn pos(var: u32) -> Self {
        Literal { var, sign: true }
    }

    /// Build the negated literal for variable `var`.
    #[must_use]
    pub const fn neg(var: u32) -> Self {
        Literal { var, sign: false }
    }

    /// Return the literal with the same variable and the opposite polarity.
    #[must_use]
    pub const fn negate(self) -> Self {
        Self {
            sign: !self.sign,
            ..self
        }
    }

    /// `true` when `self` and `other` share a variable but differ in polarity.
    #[must_use]
    pub const fn is_complementary(self, other: Self) -> bool {
        let same_var = self.var == other.var;
        let opposite_sign = self.sign ^ other.sign;
        same_var && opposite_sign
    }
}
46
47impl fmt::Display for Literal {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        if self.sign {
50            write!(f, "{}", self.var)
51        } else {
52            write!(f, "-{}", self.var)
53        }
54    }
55}
56
57/// A clause (disjunction of literals)
58#[derive(Debug, Clone, PartialEq, Eq)]
59pub struct Clause {
60    /// Literals in the clause
61    pub literals: Vec<Literal>,
62}
63
64impl Clause {
65    /// Create a new clause
66    #[must_use]
67    pub fn new(literals: Vec<Literal>) -> Self {
68        Self { literals }
69    }
70
71    /// Create an empty clause (false)
72    #[must_use]
73    pub const fn empty() -> Self {
74        Self {
75            literals: Vec::new(),
76        }
77    }
78
79    /// Create a unit clause
80    #[must_use]
81    pub fn unit(lit: Literal) -> Self {
82        Self {
83            literals: vec![lit],
84        }
85    }
86
87    /// Check if the clause is empty
88    #[must_use]
89    pub fn is_empty(&self) -> bool {
90        self.literals.is_empty()
91    }
92
93    /// Check if the clause is a unit clause
94    #[must_use]
95    pub fn is_unit(&self) -> bool {
96        self.literals.len() == 1
97    }
98
99    /// Get the unit literal (if this is a unit clause)
100    #[must_use]
101    pub fn unit_literal(&self) -> Option<Literal> {
102        if self.is_unit() {
103            self.literals.first().copied()
104        } else {
105            None
106        }
107    }
108
109    /// Check if the clause is a tautology
110    #[must_use]
111    pub fn is_tautology(&self) -> bool {
112        let mut seen = HashSet::new();
113        for &lit in &self.literals {
114            if seen.contains(&lit.negate()) {
115                return true;
116            }
117            seen.insert(lit);
118        }
119        false
120    }
121
122    /// Remove duplicate literals
123    pub fn normalize(&mut self) {
124        let mut seen = HashSet::new();
125        self.literals.retain(|&lit| seen.insert(lit));
126        self.literals.sort_by_key(|l| (l.var, !l.sign));
127    }
128}
129
130impl fmt::Display for Clause {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        write!(f, "[")?;
133        for (i, lit) in self.literals.iter().enumerate() {
134            if i > 0 {
135                write!(f, " ∨ ")?;
136            }
137            write!(f, "{}", lit)?;
138        }
139        write!(f, "]")
140    }
141}
142
/// Outcome of checking a single proof-rule application.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RuleValidation {
    /// Rule application is valid
    Valid,
    /// Rule application is invalid, with a human-readable reason
    Invalid(String),
}

impl RuleValidation {
    /// `true` for [`RuleValidation::Valid`].
    #[must_use]
    pub const fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Invalid(_) => false,
        }
    }

    /// The failure reason for [`RuleValidation::Invalid`], `None` otherwise.
    #[must_use]
    pub fn error(&self) -> Option<&str> {
        if let Self::Invalid(reason) = self {
            Some(reason.as_str())
        } else {
            None
        }
    }
}
168
169/// Resolution rule validator
170pub struct ResolutionValidator;
171
172impl ResolutionValidator {
173    /// Validate a resolution step
174    ///
175    /// Resolution: C1 ∨ x, C2 ∨ ¬x ⊢ C1 ∨ C2
176    #[must_use]
177    pub fn validate(c1: &Clause, c2: &Clause, pivot: Literal, result: &Clause) -> RuleValidation {
178        // Find the pivot literal in c1 and its negation in c2
179        let has_pivot_in_c1 = c1.literals.contains(&pivot);
180        let has_neg_pivot_in_c2 = c2.literals.contains(&pivot.negate());
181
182        if !has_pivot_in_c1 {
183            return RuleValidation::Invalid(format!("Pivot {} not found in first clause", pivot));
184        }
185
186        if !has_neg_pivot_in_c2 {
187            return RuleValidation::Invalid(format!(
188                "Negated pivot {} not found in second clause",
189                pivot.negate()
190            ));
191        }
192
193        // Build expected resolvent
194        let mut expected = Vec::new();
195        for &lit in &c1.literals {
196            if lit != pivot {
197                expected.push(lit);
198            }
199        }
200        for &lit in &c2.literals {
201            if lit != pivot.negate() {
202                expected.push(lit);
203            }
204        }
205
206        // Normalize and compare
207        let mut expected_clause = Clause::new(expected);
208        expected_clause.normalize();
209
210        let mut result_normalized = result.clone();
211        result_normalized.normalize();
212
213        if expected_clause == result_normalized {
214            RuleValidation::Valid
215        } else {
216            RuleValidation::Invalid(format!(
217                "Expected resolvent {}, got {}",
218                expected_clause, result_normalized
219            ))
220        }
221    }
222}
223
224/// Unit propagation validator
225pub struct UnitPropagationValidator;
226
227impl UnitPropagationValidator {
228    /// Validate a unit propagation step
229    ///
230    /// Unit propagation: C ∨ x, ¬x ⊢ C
231    #[must_use]
232    pub fn validate(clause: &Clause, unit: Literal, result: &Clause) -> RuleValidation {
233        // Check that unit is indeed a literal
234        let neg_unit = unit.negate();
235
236        // Build expected result (clause with neg_unit removed)
237        let expected: Vec<Literal> = clause
238            .literals
239            .iter()
240            .copied()
241            .filter(|&lit| lit != neg_unit)
242            .collect();
243
244        if expected.len() == clause.literals.len() {
245            return RuleValidation::Invalid(format!(
246                "Unit literal {} not found in clause",
247                neg_unit
248            ));
249        }
250
251        let mut expected_clause = Clause::new(expected);
252        expected_clause.normalize();
253
254        let mut result_normalized = result.clone();
255        result_normalized.normalize();
256
257        if expected_clause == result_normalized {
258            RuleValidation::Valid
259        } else {
260            RuleValidation::Invalid(format!(
261                "Expected {}, got {}",
262                expected_clause, result_normalized
263            ))
264        }
265    }
266}
267
268/// CNF transformation validator
269pub struct CnfValidator;
270
271impl CnfValidator {
272    /// Validate negation normal form transformation
273    ///
274    /// ¬(¬A) ⟺ A
275    #[must_use]
276    pub fn validate_not_not(input: &str, output: &str) -> RuleValidation {
277        if input.starts_with("¬¬") && output == &input[4..] {
278            RuleValidation::Valid
279        } else {
280            RuleValidation::Invalid("Invalid ¬¬ elimination".to_string())
281        }
282    }
283
284    /// Validate De Morgan's law (AND)
285    ///
286    /// ¬(A ∧ B) ⟺ ¬A ∨ ¬B
287    #[must_use]
288    pub fn validate_demorgan_and(_input: &str, _output: &str) -> RuleValidation {
289        // Simplified validation - in practice would parse formulas
290        RuleValidation::Valid
291    }
292
293    /// Validate De Morgan's law (OR)
294    ///
295    /// ¬(A ∨ B) ⟺ ¬A ∧ ¬B
296    #[must_use]
297    pub fn validate_demorgan_or(_input: &str, _output: &str) -> RuleValidation {
298        // Simplified validation - in practice would parse formulas
299        RuleValidation::Valid
300    }
301
302    /// Validate distributivity
303    ///
304    /// A ∨ (B ∧ C) ⟺ (A ∨ B) ∧ (A ∨ C)
305    #[must_use]
306    pub fn validate_distributivity(_input: &str, _output: &str) -> RuleValidation {
307        // Simplified validation - in practice would parse formulas
308        RuleValidation::Valid
309    }
310}
311
312/// Theory lemma validator
313pub struct TheoryLemmaValidator;
314
315impl TheoryLemmaValidator {
316    /// Validate an arithmetic Farkas lemma
317    ///
318    /// Given inequalities and coefficients, check that the combination is valid
319    #[must_use]
320    pub fn validate_farkas(
321        _inequalities: &[String],
322        _coefficients: &[f64],
323        _result: &str,
324    ) -> RuleValidation {
325        // Simplified - in practice would check arithmetic
326        RuleValidation::Valid
327    }
328
329    /// Validate congruence closure
330    ///
331    /// a = b, f(a) ⊢ f(a) = f(b)
332    #[must_use]
333    pub fn validate_congruence(_equalities: &[String], _result: &str) -> RuleValidation {
334        // Simplified - in practice would check congruence
335        RuleValidation::Valid
336    }
337
338    /// Validate transitivity of equality
339    ///
340    /// a = b, b = c ⊢ a = c
341    #[must_use]
342    pub fn validate_transitivity(_eq1: &str, _eq2: &str, _result: &str) -> RuleValidation {
343        // Simplified - in practice would parse and check
344        RuleValidation::Valid
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a clause from a slice of literals.
    fn clause_of(lits: &[Literal]) -> Clause {
        Clause::new(lits.to_vec())
    }

    #[test]
    fn test_literal_creation() {
        let positive = Literal::pos(5);
        assert_eq!((positive.var, positive.sign), (5, true));

        let negative = Literal::neg(5);
        assert_eq!((negative.var, negative.sign), (5, false));
    }

    #[test]
    fn test_literal_negate() {
        let flipped = Literal::pos(3).negate();
        assert_eq!(flipped.var, 3);
        assert!(!flipped.sign);
    }

    #[test]
    fn test_literal_complementary() {
        let (p, not_p, q) = (Literal::pos(5), Literal::neg(5), Literal::pos(6));
        assert!(p.is_complementary(not_p));
        assert!(not_p.is_complementary(p));
        assert!(!p.is_complementary(q));
    }

    #[test]
    fn test_clause_empty() {
        let empty = Clause::empty();
        assert!(empty.is_empty());
        assert!(!empty.is_unit());
    }

    #[test]
    fn test_clause_unit() {
        let unit = Clause::unit(Literal::pos(1));
        assert!(unit.is_unit());
        assert_eq!(unit.unit_literal(), Some(Literal::pos(1)));
    }

    #[test]
    fn test_clause_tautology() {
        assert!(clause_of(&[Literal::pos(1), Literal::neg(1)]).is_tautology());
        assert!(!clause_of(&[Literal::pos(1), Literal::pos(2)]).is_tautology());
    }

    #[test]
    fn test_clause_normalize() {
        // The second `pos(2)` duplicates the first and must be dropped.
        let mut with_dup = clause_of(&[Literal::pos(2), Literal::pos(1), Literal::pos(2)]);
        with_dup.normalize();
        assert_eq!(with_dup.literals.len(), 2);
    }

    #[test]
    fn test_resolution_valid() {
        // (p ∨ q) ∧ (¬p ∨ r) ⊢ (q ∨ r) with p = 1, q = 2, r = 3
        let (p, q, r) = (1, 2, 3);
        let left = clause_of(&[Literal::pos(p), Literal::pos(q)]);
        let right = clause_of(&[Literal::neg(p), Literal::pos(r)]);
        let resolvent = clause_of(&[Literal::pos(q), Literal::pos(r)]);

        let outcome = ResolutionValidator::validate(&left, &right, Literal::pos(p), &resolvent);
        assert!(outcome.is_valid());
    }

    #[test]
    fn test_resolution_invalid_pivot() {
        let left = clause_of(&[Literal::pos(1), Literal::pos(2)]);
        // The second clause does not contain ¬1, so the pivot cannot be resolved on.
        let right = clause_of(&[Literal::neg(3), Literal::pos(4)]);
        let claimed = clause_of(&[Literal::pos(2), Literal::pos(4)]);

        let outcome = ResolutionValidator::validate(&left, &right, Literal::pos(1), &claimed);
        assert!(!outcome.is_valid());
    }

    #[test]
    fn test_unit_propagation_valid() {
        // (p ∨ q ∨ r) with unit ¬p ⊢ (q ∨ r)
        let source = clause_of(&[Literal::pos(1), Literal::pos(2), Literal::pos(3)]);
        let propagated = clause_of(&[Literal::pos(2), Literal::pos(3)]);

        let outcome = UnitPropagationValidator::validate(&source, Literal::neg(1), &propagated);
        assert!(outcome.is_valid());
    }

    #[test]
    fn test_unit_propagation_invalid() {
        let source = clause_of(&[Literal::pos(1), Literal::pos(2)]);
        // Variable 3 does not occur in the clause at all.
        let unrelated_unit = Literal::neg(3);
        let unchanged = clause_of(&[Literal::pos(1), Literal::pos(2)]);

        let outcome = UnitPropagationValidator::validate(&source, unrelated_unit, &unchanged);
        assert!(!outcome.is_valid());
    }

    #[test]
    fn test_cnf_not_not() {
        assert!(CnfValidator::validate_not_not("¬¬A", "A").is_valid());
        assert!(!CnfValidator::validate_not_not("¬A", "A").is_valid());
    }

    #[test]
    fn test_literal_display() {
        assert_eq!(Literal::pos(5).to_string(), "5");
        assert_eq!(Literal::neg(5).to_string(), "-5");
    }

    #[test]
    fn test_clause_display() {
        let rendered = clause_of(&[Literal::pos(1), Literal::neg(2), Literal::pos(3)]).to_string();
        for needle in &["1", "-2", "3"] {
            assert!(rendered.contains(*needle));
        }
    }
}