1mod arith;
27mod array;
28mod bv;
29mod proof;
30mod quant;
31
32pub use arith::{ArithCheckConfig, ArithChecker};
33pub use array::ArrayChecker;
34pub use bv::BvChecker;
35pub use proof::{ProofChecker, ProofStep, ProofStepKind};
36pub use quant::QuantChecker;
37
38use oxiz_core::ast::TermId;
39use std::collections::HashSet;
40
/// Outcome of checking a single solver step (conflict, propagation, model, lemma).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckResult {
    /// The step was accepted by the checker.
    Valid,
    /// The step was rejected; the payload explains why.
    Invalid(String),
    /// The checker could not decide; the payload explains why.
    Unknown(String),
}

impl CheckResult {
    /// Returns `true` only for [`CheckResult::Valid`].
    pub fn is_valid(&self) -> bool {
        self == &Self::Valid
    }

    /// Returns `true` only for [`CheckResult::Invalid`].
    ///
    /// Note that [`CheckResult::Unknown`] is neither valid nor invalid.
    pub fn is_invalid(&self) -> bool {
        matches!(self, Self::Invalid(_))
    }

    /// Returns the attached message for `Invalid` and `Unknown`, or `None` for `Valid`.
    pub fn error_message(&self) -> Option<&str> {
        match self {
            Self::Valid => None,
            Self::Invalid(reason) | Self::Unknown(reason) => Some(reason.as_str()),
        }
    }
}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub struct Literal {
76 pub term: TermId,
78 pub positive: bool,
80}
81
82impl Literal {
83 pub fn pos(term: TermId) -> Self {
85 Self {
86 term,
87 positive: true,
88 }
89 }
90
91 pub fn neg(term: TermId) -> Self {
93 Self {
94 term,
95 positive: false,
96 }
97 }
98
99 pub fn negate(self) -> Self {
101 Self {
102 term: self.term,
103 positive: !self.positive,
104 }
105 }
106}
107
108pub trait TheoryChecker: Send + Sync {
110 fn name(&self) -> &'static str;
112
113 fn check_conflict(&self, clause: &[Literal]) -> CheckResult;
118
119 fn check_propagation(&self, literal: Literal, explanation: &[Literal]) -> CheckResult;
123
124 fn check_model(&self, assignments: &[(TermId, bool)]) -> CheckResult;
126
127 fn check_lemma(&self, clause: &[Literal]) -> CheckResult {
129 self.check_conflict(clause)
131 }
132
133 fn stats(&self) -> CheckerStats;
135
136 fn reset_stats(&mut self);
138}
139
/// Counters describing the work done by a checker.
#[derive(Debug, Clone, Default)]
pub struct CheckerStats {
    /// Number of conflict clauses checked.
    pub conflict_checks: u64,
    /// Conflict checks that were accepted.
    pub valid_conflicts: u64,
    /// Conflict checks that were rejected.
    pub invalid_conflicts: u64,
    /// Number of propagations checked.
    pub propagation_checks: u64,
    /// Propagation checks that were accepted.
    pub valid_propagations: u64,
    /// Number of models checked.
    pub model_checks: u64,
    /// Accumulated checking time (presumably microseconds, per the `_us` suffix).
    pub check_time_us: u64,
}

impl CheckerStats {
    /// Adds every counter of `other` into `self`, field by field.
    pub fn merge(&mut self, other: &CheckerStats) {
        // Destructuring names every field, so a newly added counter that is
        // not merged here shows up as a compile error rather than being dropped.
        let CheckerStats {
            conflict_checks,
            valid_conflicts,
            invalid_conflicts,
            propagation_checks,
            valid_propagations,
            model_checks,
            check_time_us,
        } = *other;

        self.conflict_checks += conflict_checks;
        self.valid_conflicts += valid_conflicts;
        self.invalid_conflicts += invalid_conflicts;
        self.propagation_checks += propagation_checks;
        self.valid_propagations += valid_propagations;
        self.model_checks += model_checks;
        self.check_time_us += check_time_us;
    }

    /// Fraction of conflict and propagation checks that were valid.
    ///
    /// Model checks are not counted. With no recorded checks this returns `1.0`.
    pub fn success_rate(&self) -> f64 {
        let attempted = self.conflict_checks + self.propagation_checks;
        if attempted == 0 {
            return 1.0;
        }
        let succeeded = self.valid_conflicts + self.valid_propagations;
        succeeded as f64 / attempted as f64
    }
}
182
183#[derive(Debug)]
185pub struct CombinedChecker {
186 pub arith: ArithChecker,
188 pub array: ArrayChecker,
190 pub bv: BvChecker,
192 pub quant: QuantChecker,
194 theory_terms: std::collections::HashMap<TermId, TheoryKind>,
196}
197
/// The theory a term is registered under in [`CombinedChecker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TheoryKind {
    /// Pure Boolean structure.
    Bool,
    /// Arithmetic.
    Arith,
    /// Arrays.
    Array,
    /// Bit-vectors.
    Bv,
    /// Quantifiers.
    Quant,
    /// Uninterpreted functions.
    Uf,
}
214
215impl CombinedChecker {
216 pub fn new() -> Self {
218 Self {
219 arith: ArithChecker::new(),
220 array: ArrayChecker::new(),
221 bv: BvChecker::new(),
222 quant: QuantChecker::new(),
223 theory_terms: std::collections::HashMap::new(),
224 }
225 }
226
227 pub fn register_term(&mut self, term: TermId, kind: TheoryKind) {
229 self.theory_terms.insert(term, kind);
230 }
231
232 pub fn get_theory(&self, term: TermId) -> Option<TheoryKind> {
234 self.theory_terms.get(&term).copied()
235 }
236
237 pub fn check_conflict(&self, clause: &[Literal]) -> CheckResult {
239 let theories: HashSet<_> = clause
241 .iter()
242 .filter_map(|lit| self.theory_terms.get(&lit.term))
243 .collect();
244
245 if theories.len() > 1 {
246 return CheckResult::Unknown("Multi-theory conflict".to_string());
248 }
249
250 match theories.iter().next() {
251 Some(TheoryKind::Arith) => self.arith.check_conflict(clause),
252 Some(TheoryKind::Array) => self.array.check_conflict(clause),
253 Some(TheoryKind::Bv) => self.bv.check_conflict(clause),
254 Some(TheoryKind::Quant) => self.quant.check_conflict(clause),
255 _ => CheckResult::Valid, }
257 }
258
259 pub fn stats(&self) -> CheckerStats {
261 let mut stats = CheckerStats::default();
262 stats.merge(&self.arith.stats());
263 stats.merge(&self.array.stats());
264 stats.merge(&self.bv.stats());
265 stats.merge(&self.quant.stats());
266 stats
267 }
268
269 pub fn reset_stats(&mut self) {
271 self.arith.reset_stats();
272 self.array.reset_stats();
273 self.bv.reset_stats();
274 self.quant.reset_stats();
275 }
276}
277
278impl Default for CombinedChecker {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_result() {
        let valid = CheckResult::Valid;
        assert!(valid.is_valid());
        assert!(!valid.is_invalid());
        assert_eq!(valid.error_message(), None);

        let invalid = CheckResult::Invalid("test error".to_string());
        assert!(!invalid.is_valid());
        assert!(invalid.is_invalid());
        assert_eq!(invalid.error_message(), Some("test error"));

        // `Unknown` is neither valid nor invalid, but still carries a message.
        let unknown = CheckResult::Unknown("undecided".to_string());
        assert!(!unknown.is_valid());
        assert!(!unknown.is_invalid());
        assert_eq!(unknown.error_message(), Some("undecided"));
    }

    #[test]
    fn test_literal() {
        let t = TermId::from(1u32);
        let pos = Literal::pos(t);
        let neg = Literal::neg(t);

        assert!(pos.positive);
        assert!(!neg.positive);
        assert_eq!(pos.negate(), neg);
        assert_eq!(neg.negate(), pos);
    }

    #[test]
    fn test_checker_stats() {
        let mut stats1 = CheckerStats {
            conflict_checks: 10,
            valid_conflicts: 8,
            invalid_conflicts: 2,
            model_checks: 1,
            ..Default::default()
        };

        let stats2 = CheckerStats {
            conflict_checks: 5,
            valid_conflicts: 5,
            propagation_checks: 4,
            valid_propagations: 3,
            model_checks: 2,
            check_time_us: 70,
            ..Default::default()
        };

        stats1.merge(&stats2);
        assert_eq!(stats1.conflict_checks, 15);
        assert_eq!(stats1.valid_conflicts, 13);
        assert_eq!(stats1.invalid_conflicts, 2);
        assert_eq!(stats1.propagation_checks, 4);
        assert_eq!(stats1.valid_propagations, 3);
        assert_eq!(stats1.model_checks, 3);
        assert_eq!(stats1.check_time_us, 70);
    }

    #[test]
    fn test_success_rate() {
        let mut stats = CheckerStats::default();
        assert_eq!(stats.success_rate(), 1.0);

        stats.conflict_checks = 10;
        stats.valid_conflicts = 8;
        assert!((stats.success_rate() - 0.8).abs() < 0.001);

        // Propagations count toward the rate as well: (8 + 6) / (10 + 10).
        stats.propagation_checks = 10;
        stats.valid_propagations = 6;
        assert!((stats.success_rate() - 0.7).abs() < 0.001);
    }

    #[test]
    fn test_combined_checker() {
        let mut checker = CombinedChecker::new();
        let t = TermId::from(1u32);

        assert_eq!(checker.get_theory(t), None);
        checker.register_term(t, TheoryKind::Arith);
        assert_eq!(checker.get_theory(t), Some(TheoryKind::Arith));

        // Re-registering replaces the previous theory.
        checker.register_term(t, TheoryKind::Bv);
        assert_eq!(checker.get_theory(t), Some(TheoryKind::Bv));
    }

    #[test]
    fn test_combined_checker_multi_theory_conflict() {
        let mut checker = CombinedChecker::new();
        let a = TermId::from(1u32);
        let b = TermId::from(2u32);
        checker.register_term(a, TheoryKind::Arith);
        checker.register_term(b, TheoryKind::Bv);

        let result = checker.check_conflict(&[Literal::pos(a), Literal::neg(b)]);
        assert_eq!(
            result,
            CheckResult::Unknown("Multi-theory conflict".to_string())
        );
    }

    #[test]
    fn test_combined_checker_unregistered_terms() {
        let checker = CombinedChecker::new();
        let t = TermId::from(3u32);

        // With no registered theory there is no checker to consult.
        assert!(checker.check_conflict(&[]).is_valid());
        assert!(checker.check_conflict(&[Literal::pos(t)]).is_valid());
    }
}