Skip to main content

oxiz_theories/checking/
mod.rs

1//! Theory Checking Framework
2//!
3//! Provides verification infrastructure for theory solver correctness:
4//! - Conflict certification per theory
5//! - Proof step validation
6//! - Explanation checking
7//! - Integration with DRAT/Alethe proof generation
8//!
9//! # Design
10//!
11//! Each theory provides a `TheoryChecker` implementation that can verify:
12//! - Conflict clauses are valid (negation is T-unsatisfiable)
13//! - Propagation explanations are correct
14//! - Model assignments satisfy theory constraints
15//!
16//! # Example
17//!
//! ```ignore
//! use oxiz_theories::checking::{ArithChecker, CheckResult, TheoryChecker};
//!
//! let checker = ArithChecker::new();
//! // `clause: &[Literal]` is a valid conflict when the conjunction of its
//! // negated literals is T-unsatisfiable.
//! let result = checker.check_conflict(&clause);
//! assert!(result.is_valid());
//! ```
25
26mod arith;
27mod array;
28mod bv;
29mod proof;
30mod quant;
31
32pub use arith::{ArithCheckConfig, ArithChecker};
33pub use array::ArrayChecker;
34pub use bv::BvChecker;
35pub use proof::{ProofChecker, ProofStep, ProofStepKind};
36pub use quant::QuantChecker;
37
38use oxiz_core::ast::TermId;
39use std::collections::HashSet;
40
/// Result of checking a theory inference.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckResult {
    /// The inference is valid.
    Valid,
    /// The inference is invalid; the payload explains why.
    Invalid(String),
    /// The check could not be performed (e.g. missing information); the
    /// payload explains why.
    Unknown(String),
}

impl CheckResult {
    /// Returns `true` if the inference was verified as valid.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        matches!(self, CheckResult::Valid)
    }

    /// Returns `true` if the inference was shown to be invalid.
    #[must_use]
    pub fn is_invalid(&self) -> bool {
        matches!(self, CheckResult::Invalid(_))
    }

    /// Returns `true` if the check could not be performed.
    #[must_use]
    pub fn is_unknown(&self) -> bool {
        matches!(self, CheckResult::Unknown(_))
    }

    /// Returns the explanatory message carried by a non-valid result.
    ///
    /// Both [`CheckResult::Invalid`] and [`CheckResult::Unknown`] carry a
    /// message, so both yield `Some`; [`CheckResult::Valid`] yields `None`.
    pub fn error_message(&self) -> Option<&str> {
        match self {
            CheckResult::Invalid(msg) | CheckResult::Unknown(msg) => Some(msg),
            CheckResult::Valid => None,
        }
    }
}
72
73/// A literal (signed term)
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub struct Literal {
76    /// The term
77    pub term: TermId,
78    /// Is the literal positive (term = true) or negative (term = false)
79    pub positive: bool,
80}
81
82impl Literal {
83    /// Create a positive literal
84    pub fn pos(term: TermId) -> Self {
85        Self {
86            term,
87            positive: true,
88        }
89    }
90
91    /// Create a negative literal
92    pub fn neg(term: TermId) -> Self {
93        Self {
94            term,
95            positive: false,
96        }
97    }
98
99    /// Negate this literal
100    pub fn negate(self) -> Self {
101        Self {
102            term: self.term,
103            positive: !self.positive,
104        }
105    }
106}
107
108/// Trait for theory-specific checkers
109pub trait TheoryChecker: Send + Sync {
110    /// Name of the theory
111    fn name(&self) -> &'static str;
112
113    /// Check a conflict clause
114    ///
115    /// A conflict clause is valid if the conjunction of negated literals is
116    /// T-unsatisfiable.
117    fn check_conflict(&self, clause: &[Literal]) -> CheckResult;
118
119    /// Check a propagation explanation
120    ///
121    /// Given that `explanation => literal`, verify this is correct.
122    fn check_propagation(&self, literal: Literal, explanation: &[Literal]) -> CheckResult;
123
124    /// Check that a model satisfies theory constraints
125    fn check_model(&self, assignments: &[(TermId, bool)]) -> CheckResult;
126
127    /// Check a lemma (clause that should be T-valid)
128    fn check_lemma(&self, clause: &[Literal]) -> CheckResult {
129        // Default: check as conflict (all negated should be T-unsat)
130        self.check_conflict(clause)
131    }
132
133    /// Get statistics
134    fn stats(&self) -> CheckerStats;
135
136    /// Reset statistics
137    fn reset_stats(&mut self);
138}
139
/// Counters collected while checking theory inferences.
#[derive(Debug, Clone, Default)]
pub struct CheckerStats {
    /// Number of conflict checks performed.
    pub conflict_checks: u64,
    /// Number of conflicts found valid.
    pub valid_conflicts: u64,
    /// Number of conflicts found invalid.
    pub invalid_conflicts: u64,
    /// Number of propagation checks performed.
    pub propagation_checks: u64,
    /// Number of propagations found valid.
    pub valid_propagations: u64,
    /// Number of model checks performed.
    pub model_checks: u64,
    /// Total time spent checking, in microseconds.
    pub check_time_us: u64,
}

impl CheckerStats {
    /// Adds every counter of `other` into `self`.
    pub fn merge(&mut self, other: &CheckerStats) {
        // Exhaustive destructuring: adding a field to the struct without
        // merging it here becomes a compile error instead of a silent omission.
        let CheckerStats {
            conflict_checks,
            valid_conflicts,
            invalid_conflicts,
            propagation_checks,
            valid_propagations,
            model_checks,
            check_time_us,
        } = other;

        self.conflict_checks += *conflict_checks;
        self.valid_conflicts += *valid_conflicts;
        self.invalid_conflicts += *invalid_conflicts;
        self.propagation_checks += *propagation_checks;
        self.valid_propagations += *valid_propagations;
        self.model_checks += *model_checks;
        self.check_time_us += *check_time_us;
    }

    /// Fraction of conflict and propagation checks that were valid.
    ///
    /// Model checks are not counted. Returns `1.0` when no conflict or
    /// propagation check has been recorded.
    pub fn success_rate(&self) -> f64 {
        let checked = self.conflict_checks + self.propagation_checks;
        if checked == 0 {
            return 1.0;
        }
        let passed = self.valid_conflicts + self.valid_propagations;
        passed as f64 / checked as f64
    }
}
182
183/// Combined theory checker that dispatches to appropriate theory
184#[derive(Debug)]
185pub struct CombinedChecker {
186    /// Arithmetic checker
187    pub arith: ArithChecker,
188    /// Array checker
189    pub array: ArrayChecker,
190    /// Bitvector checker
191    pub bv: BvChecker,
192    /// Quantifier checker
193    pub quant: QuantChecker,
194    /// Terms that belong to each theory
195    theory_terms: std::collections::HashMap<TermId, TheoryKind>,
196}
197
/// Theory a term belongs to, used to dispatch checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TheoryKind {
    /// Propositional (Boolean) structure.
    Bool,
    /// Linear integer/real arithmetic (LIA/LRA).
    Arith,
    /// Theory of arrays.
    Array,
    /// Fixed-width bitvectors.
    Bv,
    /// Quantified formulas.
    Quant,
    /// Uninterpreted functions.
    Uf,
}
214
215impl CombinedChecker {
216    /// Create a new combined checker
217    pub fn new() -> Self {
218        Self {
219            arith: ArithChecker::new(),
220            array: ArrayChecker::new(),
221            bv: BvChecker::new(),
222            quant: QuantChecker::new(),
223            theory_terms: std::collections::HashMap::new(),
224        }
225    }
226
227    /// Register a term with its theory
228    pub fn register_term(&mut self, term: TermId, kind: TheoryKind) {
229        self.theory_terms.insert(term, kind);
230    }
231
232    /// Get the theory kind for a term
233    pub fn get_theory(&self, term: TermId) -> Option<TheoryKind> {
234        self.theory_terms.get(&term).copied()
235    }
236
237    /// Check a conflict, dispatching to appropriate theory
238    pub fn check_conflict(&self, clause: &[Literal]) -> CheckResult {
239        // Determine theory from literals
240        let theories: HashSet<_> = clause
241            .iter()
242            .filter_map(|lit| self.theory_terms.get(&lit.term))
243            .collect();
244
245        if theories.len() > 1 {
246            // Multi-theory conflict - need combined checking
247            return CheckResult::Unknown("Multi-theory conflict".to_string());
248        }
249
250        match theories.iter().next() {
251            Some(TheoryKind::Arith) => self.arith.check_conflict(clause),
252            Some(TheoryKind::Array) => self.array.check_conflict(clause),
253            Some(TheoryKind::Bv) => self.bv.check_conflict(clause),
254            Some(TheoryKind::Quant) => self.quant.check_conflict(clause),
255            _ => CheckResult::Valid, // Bool/UF - assume valid
256        }
257    }
258
259    /// Get combined statistics
260    pub fn stats(&self) -> CheckerStats {
261        let mut stats = CheckerStats::default();
262        stats.merge(&self.arith.stats());
263        stats.merge(&self.array.stats());
264        stats.merge(&self.bv.stats());
265        stats.merge(&self.quant.stats());
266        stats
267    }
268
269    /// Reset all statistics
270    pub fn reset_stats(&mut self) {
271        self.arith.reset_stats();
272        self.array.reset_stats();
273        self.bv.reset_stats();
274        self.quant.reset_stats();
275    }
276}
277
278impl Default for CombinedChecker {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// All three `CheckResult` variants: predicates and error messages.
    #[test]
    fn test_check_result() {
        let valid = CheckResult::Valid;
        assert!(valid.is_valid());
        assert!(!valid.is_invalid());
        assert_eq!(valid.error_message(), None);

        let invalid = CheckResult::Invalid("test error".to_string());
        assert!(!invalid.is_valid());
        assert!(invalid.is_invalid());
        assert_eq!(invalid.error_message(), Some("test error"));

        // `Unknown` is neither valid nor invalid, but still carries a message.
        let unknown = CheckResult::Unknown("no info".to_string());
        assert!(!unknown.is_valid());
        assert!(!unknown.is_invalid());
        assert_eq!(unknown.error_message(), Some("no info"));
    }

    /// Literal construction and negation.
    #[test]
    fn test_literal() {
        let t = TermId::from(1u32);
        let pos = Literal::pos(t);
        let neg = Literal::neg(t);

        assert!(pos.positive);
        assert!(!neg.positive);
        assert_eq!(pos.term, t);
        assert_eq!(neg.term, t);
        assert_eq!(pos.negate(), neg);
        assert_eq!(neg.negate(), pos);
        assert_eq!(pos.negate().negate(), pos);
    }

    /// Merging adds every counter, including propagation/model/time fields.
    #[test]
    fn test_checker_stats() {
        let mut stats1 = CheckerStats {
            conflict_checks: 10,
            valid_conflicts: 8,
            ..Default::default()
        };

        let stats2 = CheckerStats {
            conflict_checks: 5,
            valid_conflicts: 5,
            propagation_checks: 3,
            valid_propagations: 2,
            model_checks: 1,
            check_time_us: 42,
            ..Default::default()
        };

        stats1.merge(&stats2);
        assert_eq!(stats1.conflict_checks, 15);
        assert_eq!(stats1.valid_conflicts, 13);
        assert_eq!(stats1.propagation_checks, 3);
        assert_eq!(stats1.valid_propagations, 2);
        assert_eq!(stats1.model_checks, 1);
        assert_eq!(stats1.check_time_us, 42);
    }

    /// Success rate covers both conflict and propagation checks.
    #[test]
    fn test_success_rate() {
        let mut stats = CheckerStats::default();
        assert_eq!(stats.success_rate(), 1.0);

        stats.conflict_checks = 10;
        stats.valid_conflicts = 8;
        assert!((stats.success_rate() - 0.8).abs() < 0.001);

        stats.propagation_checks = 10;
        stats.valid_propagations = 2;
        // (8 + 2) / (10 + 10)
        assert!((stats.success_rate() - 0.5).abs() < 0.001);
    }

    /// Term registration, lookup of unregistered terms, and re-registration.
    #[test]
    fn test_combined_checker() {
        let mut checker = CombinedChecker::new();
        let t = TermId::from(1u32);
        let other = TermId::from(2u32);

        checker.register_term(t, TheoryKind::Arith);
        assert_eq!(checker.get_theory(t), Some(TheoryKind::Arith));
        assert_eq!(checker.get_theory(other), None);

        // Re-registering a term replaces its theory.
        checker.register_term(t, TheoryKind::Bv);
        assert_eq!(checker.get_theory(t), Some(TheoryKind::Bv));
    }
}