1use crate::kernel::expr::PredicateKind;
9use crate::kernel::{ExprData, ExprId, ExprPool};
10use std::collections::HashMap;
11use std::fmt;
12
/// Error returned by [`formula_from_expr`] when an expression cannot be
/// represented as a [`Formula`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LogicError {
    /// The expression uses an unsupported construct; the payload is a static,
    /// human-readable reason.
    UnsupportedExpr(&'static str),
}

impl fmt::Display for LogicError {
    /// Writes the bare reason text, with no prefix and no padding applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let LogicError::UnsupportedExpr(reason) = self;
        f.write_str(reason)
    }
}

impl std::error::Error for LogicError {}
33
34impl crate::errors::AlkahestError for LogicError {
35 fn code(&self) -> &'static str {
36 "E-LOGIC-001"
37 }
38}
39
/// A formula tree over predicate atoms, produced by [`formula_from_expr`] and
/// lowered back to an expression by `Formula::to_expr`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Formula {
    /// A predicate `kind` applied to expression arguments (e.g. `x < 0`).
    Atom {
        kind: PredicateKind,
        args: Vec<ExprId>,
    },
    /// Conjunction of two subformulas.
    And(Box<Formula>, Box<Formula>),
    /// Disjunction of two subformulas.
    Or(Box<Formula>, Box<Formula>),
    /// Negation of a subformula.
    Not(Box<Formula>),
    /// The constant true formula.
    True,
    /// The constant false formula.
    False,
    /// Universal quantification of `body` over the bound variable `var`
    /// (a symbol expression in this module's usage).
    Forall {
        var: ExprId,
        body: Box<Formula>,
    },
    /// Existential quantification of `body` over the bound variable `var`.
    Exists {
        var: ExprId,
        body: Box<Formula>,
    },
}
65
66impl Formula {
67 pub fn and(a: Formula, b: Formula) -> Self {
68 Formula::And(Box::new(a), Box::new(b))
69 }
70
71 pub fn or(a: Formula, b: Formula) -> Self {
72 Formula::Or(Box::new(a), Box::new(b))
73 }
74
75 #[allow(clippy::should_implement_trait)] pub fn not(a: Formula) -> Self {
77 Formula::Not(Box::new(a))
78 }
79
80 pub fn to_expr(&self, pool: &ExprPool) -> ExprId {
82 match self {
83 Formula::True => pool.pred_true(),
84 Formula::False => pool.pred_false(),
85 Formula::Atom { kind, args } => pool.predicate(kind.clone(), args.clone()),
86 Formula::And(l, r) => pool.pred_and(vec![l.to_expr(pool), r.to_expr(pool)]),
87 Formula::Or(l, r) => pool.pred_or(vec![l.to_expr(pool), r.to_expr(pool)]),
88 Formula::Not(x) => pool.pred_not(x.to_expr(pool)),
89 Formula::Forall { var, body } => pool.forall(*var, body.to_expr(pool)),
90 Formula::Exists { var, body } => pool.exists(*var, body.to_expr(pool)),
91 }
92 }
93}
94
95pub fn formula_from_expr(expr: ExprId, pool: &ExprPool) -> Result<Formula, LogicError> {
97 match pool.get(expr) {
98 ExprData::Predicate { kind, args } => match kind {
99 PredicateKind::True => Ok(Formula::True),
100 PredicateKind::False => Ok(Formula::False),
101 PredicateKind::And => {
102 if args.is_empty() {
103 Ok(Formula::True)
104 } else {
105 let mut it = args.into_iter();
106 let first = formula_from_expr(it.next().unwrap(), pool)?;
107 it.try_fold(first, |acc, e| {
108 Ok(Formula::and(acc, formula_from_expr(e, pool)?))
109 })
110 }
111 }
112 PredicateKind::Or => {
113 if args.is_empty() {
114 Ok(Formula::False)
115 } else {
116 let mut it = args.into_iter();
117 let first = formula_from_expr(it.next().unwrap(), pool)?;
118 it.try_fold(first, |acc, e| {
119 Ok(Formula::or(acc, formula_from_expr(e, pool)?))
120 })
121 }
122 }
123 PredicateKind::Not => {
124 if args.len() != 1 {
125 return Err(LogicError::UnsupportedExpr("Not predicate arity must be 1"));
126 }
127 Ok(Formula::not(formula_from_expr(args[0], pool)?))
128 }
129 PredicateKind::Lt
130 | PredicateKind::Le
131 | PredicateKind::Gt
132 | PredicateKind::Ge
133 | PredicateKind::Eq
134 | PredicateKind::Ne => {
135 if args.len() != 2 {
136 return Err(LogicError::UnsupportedExpr("relation arity must be 2"));
137 }
138 Ok(Formula::Atom { kind, args })
139 }
140 },
141 ExprData::Forall { var, body } => Ok(Formula::Forall {
142 var,
143 body: Box::new(formula_from_expr(body, pool)?),
144 }),
145 ExprData::Exists { var, body } => Ok(Formula::Exists {
146 var,
147 body: Box::new(formula_from_expr(body, pool)?),
148 }),
149 _ => Err(LogicError::UnsupportedExpr(
150 "expression is not a predicate or quantified formula",
151 )),
152 }
153}
154
/// One endpoint of a [`VarInterval`]. `strict` excludes the endpoint itself.
#[derive(Clone, Debug)]
enum Bound {
    /// `val < x` when `strict`, otherwise `val <= x`.
    Lower { val: rug::Rational, strict: bool },
    /// `x < val` when `strict`, otherwise `x <= val`.
    Upper { val: rug::Rational, strict: bool },
}
164
/// The values a single variable may take: an optional lower and an optional
/// upper bound (`None` means unbounded on that side; `Default` is unbounded).
#[derive(Clone, Debug, Default)]
struct VarInterval {
    // Expected to hold only `Bound::Lower`; any other variant makes
    // `intersect` give up and return `None`.
    lower: Option<Bound>,
    // Expected to hold only `Bound::Upper`.
    upper: Option<Bound>,
}
170
171impl VarInterval {
172 fn is_empty(&self) -> bool {
173 match (&self.lower, &self.upper) {
174 (
175 Some(Bound::Lower {
176 val: lo,
177 strict: ls,
178 }),
179 Some(Bound::Upper {
180 val: hi,
181 strict: us,
182 }),
183 ) => {
184 if lo > hi {
185 return true;
186 }
187 if lo < hi {
188 return false;
189 }
190 *ls || *us
191 }
192 _ => false,
193 }
194 }
195
196 fn intersect(&self, other: &VarInterval) -> Option<VarInterval> {
197 let lower = match (&self.lower, &other.lower) {
198 (None, b) => b.clone(),
199 (a, None) => a.clone(),
200 (
201 Some(Bound::Lower { val: a, strict: sa }),
202 Some(Bound::Lower { val: b, strict: sb }),
203 ) => {
204 if a > b {
205 Some(Bound::Lower {
206 val: a.clone(),
207 strict: *sa,
208 })
209 } else if b > a {
210 Some(Bound::Lower {
211 val: b.clone(),
212 strict: *sb,
213 })
214 } else {
215 Some(Bound::Lower {
216 val: a.clone(),
217 strict: *sa || *sb,
218 })
219 }
220 }
221 _ => return None,
222 };
223 let upper = match (&self.upper, &other.upper) {
224 (None, b) => b.clone(),
225 (a, None) => a.clone(),
226 (
227 Some(Bound::Upper { val: a, strict: sa }),
228 Some(Bound::Upper { val: b, strict: sb }),
229 ) => {
230 if a < b {
231 Some(Bound::Upper {
232 val: a.clone(),
233 strict: *sa,
234 })
235 } else if b < a {
236 Some(Bound::Upper {
237 val: b.clone(),
238 strict: *sb,
239 })
240 } else {
241 Some(Bound::Upper {
242 val: a.clone(),
243 strict: *sa || *sb,
244 })
245 }
246 }
247 _ => return None,
248 };
249 let r = VarInterval { lower, upper };
250 if r.is_empty() {
251 None
252 } else {
253 Some(r)
254 }
255 }
256}
257
258fn rat_atom(pool: &ExprPool, id: ExprId) -> Option<rug::Rational> {
259 match pool.get(id) {
260 ExprData::Integer(n) => Some(rug::Rational::from(n.0)),
261 ExprData::Rational(r) => Some(r.0.clone()),
262 _ => None,
263 }
264}
265
266fn symbol_key(pool: &ExprPool, id: ExprId) -> Option<String> {
267 pool.with(id, |d| match d {
268 ExprData::Symbol { name, .. } => Some(name.clone()),
269 _ => None,
270 })
271}
272
273fn atom_to_interval(
274 pool: &ExprPool,
275 kind: PredicateKind,
276 args: &[ExprId],
277) -> Option<(ExprId, VarInterval)> {
278 if args.len() != 2 {
279 return None;
280 }
281 let (a, b) = (args[0], args[1]);
282 let (var, c_id, swapped) = if symbol_key(pool, a).is_some() && rat_atom(pool, b).is_some() {
283 (a, b, false)
284 } else if rat_atom(pool, a).is_some() && symbol_key(pool, b).is_some() {
285 (b, a, true)
286 } else {
287 return None;
288 };
289 let c = rat_atom(pool, c_id)?;
290 let iv = match (kind, swapped) {
291 (PredicateKind::Lt, false) => VarInterval {
292 lower: None,
293 upper: Some(Bound::Upper {
294 val: c,
295 strict: true,
296 }),
297 },
298 (PredicateKind::Le, false) => VarInterval {
299 lower: None,
300 upper: Some(Bound::Upper {
301 val: c,
302 strict: false,
303 }),
304 },
305 (PredicateKind::Gt, false) => VarInterval {
306 lower: Some(Bound::Lower {
307 val: c,
308 strict: true,
309 }),
310 upper: None,
311 },
312 (PredicateKind::Ge, false) => VarInterval {
313 lower: Some(Bound::Lower {
314 val: c,
315 strict: false,
316 }),
317 upper: None,
318 },
319 (PredicateKind::Eq, false) => VarInterval {
320 lower: Some(Bound::Lower {
321 val: c.clone(),
322 strict: false,
323 }),
324 upper: Some(Bound::Upper {
325 val: c,
326 strict: false,
327 }),
328 },
329 (PredicateKind::Lt, true) => VarInterval {
330 lower: Some(Bound::Lower {
331 val: c,
332 strict: true,
333 }),
334 upper: None,
335 },
336 (PredicateKind::Le, true) => VarInterval {
337 lower: Some(Bound::Lower {
338 val: c,
339 strict: false,
340 }),
341 upper: None,
342 },
343 (PredicateKind::Gt, true) => VarInterval {
344 lower: None,
345 upper: Some(Bound::Upper {
346 val: c,
347 strict: true,
348 }),
349 },
350 (PredicateKind::Ge, true) => VarInterval {
351 lower: None,
352 upper: Some(Bound::Upper {
353 val: c,
354 strict: false,
355 }),
356 },
357 _ => return None,
358 };
359 Some((var, iv))
360}
361
362fn is_rel(k: &PredicateKind) -> bool {
363 matches!(
364 k,
365 PredicateKind::Lt
366 | PredicateKind::Le
367 | PredicateKind::Gt
368 | PredicateKind::Ge
369 | PredicateKind::Eq
370 | PredicateKind::Ne
371 )
372}
373
374fn dual_kind(kind: PredicateKind) -> PredicateKind {
375 use PredicateKind::*;
376 match kind {
377 Lt => Ge,
378 Le => Gt,
379 Gt => Le,
380 Ge => Lt,
381 Eq => Ne,
382 Ne => Eq,
383 other => other,
384 }
385}
386
387fn nnf(f: Formula) -> Formula {
388 match f {
389 Formula::Not(inner) => match *inner {
390 Formula::True => Formula::False,
391 Formula::False => Formula::True,
392 Formula::Not(g) => nnf(*g),
393 Formula::And(a, b) => nnf(Formula::or(Formula::not(*a), Formula::not(*b))),
394 Formula::Or(a, b) => nnf(Formula::and(Formula::not(*a), Formula::not(*b))),
395 Formula::Forall { var, body } => nnf(Formula::Exists {
396 var,
397 body: Box::new(Formula::not(*body)),
398 }),
399 Formula::Exists { var, body } => nnf(Formula::Forall {
400 var,
401 body: Box::new(Formula::not(*body)),
402 }),
403 Formula::Atom {
404 kind: PredicateKind::True,
405 ..
406 } => Formula::False,
407 Formula::Atom {
408 kind: PredicateKind::False,
409 ..
410 } => Formula::True,
411 Formula::Atom { kind, args } if is_rel(&kind) => Formula::Atom {
412 kind: dual_kind(kind),
413 args,
414 },
415 inner => Formula::Not(Box::new(inner)),
416 },
417 Formula::And(a, b) => Formula::and(nnf(*a), nnf(*b)),
418 Formula::Or(a, b) => Formula::or(nnf(*a), nnf(*b)),
419 Formula::Forall { var, body } => Formula::Forall {
420 var,
421 body: Box::new(nnf(*body)),
422 },
423 Formula::Exists { var, body } => Formula::Exists {
424 var,
425 body: Box::new(nnf(*body)),
426 },
427 other => other,
428 }
429}
430
431fn witness_rational(iv: &VarInterval) -> Option<rug::Rational> {
432 let eps = || rug::Rational::from((1, 10_000));
433 match (&iv.lower, &iv.upper) {
434 (None, None) => Some(rug::Rational::from(0)),
435 (Some(Bound::Lower { val: lo, strict: s }), None) => {
436 let e = eps();
437 Some(if *s { lo.clone() + &e } else { lo.clone() })
438 }
439 (None, Some(Bound::Upper { val: hi, strict: s })) => {
440 let e = eps();
441 Some(if *s { hi.clone() - &e } else { hi.clone() })
442 }
443 (
444 Some(Bound::Lower {
445 val: lo,
446 strict: sl,
447 }),
448 Some(Bound::Upper {
449 val: hi,
450 strict: su,
451 }),
452 ) => {
453 if lo > hi {
454 return None;
455 }
456 if lo < hi {
457 return Some((lo.clone() + hi.clone()) / rug::Rational::from(2));
458 }
459 if *sl || *su {
461 None
462 } else {
463 Some(lo.clone())
464 }
465 }
466 _ => None,
467 }
468}
469
470fn map_to_witness(
471 m: &HashMap<ExprId, VarInterval>,
472 pool: &ExprPool,
473) -> Result<HashMap<String, String>, SatFail> {
474 let mut out = HashMap::new();
475 for (&id, iv) in m {
476 let name = symbol_key(pool, id).ok_or(SatFail::Unknown)?;
477 let w = witness_rational(iv).ok_or(SatFail::Unknown)?;
478 out.insert(name, w.to_string());
479 }
480 Ok(out)
481}
482
/// Outcome of [`satisfiable`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Satisfiability {
    /// Satisfiable; maps each constrained symbol's name to a witness value
    /// rendered with `rug::Rational`'s `Display` form.
    Sat(HashMap<String, String>),
    /// Shown to have no satisfying assignment.
    Unsat,
    /// Not decided: unsupported constructs or quantifiers were encountered.
    Unknown,
}
489
/// Internal failure reason of the interval search, mapped by `satisfiable`
/// onto [`Satisfiability::Unsat`] / [`Satisfiability::Unknown`].
///
/// Derives the standard value traits so results carrying it can be printed,
/// compared, and `unwrap`ped in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SatFail {
    /// The constraints were proven contradictory.
    Unsat,
    /// The constraints could not be decided.
    Unknown,
}
494
495fn merge_maps(
496 mut a: HashMap<ExprId, VarInterval>,
497 b: HashMap<ExprId, VarInterval>,
498) -> Result<HashMap<ExprId, VarInterval>, SatFail> {
499 for (k, vb) in b {
500 match a.remove(&k) {
501 None => {
502 a.insert(k, vb);
503 }
504 Some(va) => {
505 let m = va.intersect(&vb).ok_or(SatFail::Unsat)?;
506 a.insert(k, m);
507 }
508 }
509 }
510 Ok(a)
511}
512
513fn sat_intervals(f: &Formula, pool: &ExprPool) -> Result<HashMap<ExprId, VarInterval>, SatFail> {
514 match f {
515 Formula::True => Ok(HashMap::new()),
516 Formula::False => Err(SatFail::Unsat),
517 Formula::Forall { .. } => Err(SatFail::Unknown),
518 Formula::Exists { body, .. } => sat_intervals(body, pool),
519 Formula::And(a, b) => {
520 let ma = sat_intervals(a, pool)?;
521 let mb = sat_intervals(b, pool)?;
522 merge_maps(ma, mb)
523 }
524 Formula::Or(a, b) => match sat_intervals(a, pool) {
525 Ok(m) => Ok(m),
526 Err(SatFail::Unsat) => sat_intervals(b, pool),
527 Err(SatFail::Unknown) => match sat_intervals(b, pool) {
528 Ok(m) => Ok(m),
529 Err(SatFail::Unsat) => Err(SatFail::Unknown),
530 Err(SatFail::Unknown) => Err(SatFail::Unknown),
531 },
532 },
533 Formula::Not(inner) => {
534 if let Formula::Atom { kind, args } = inner.as_ref() {
535 if is_rel(kind) {
536 let dual = Formula::Atom {
537 kind: dual_kind(kind.clone()),
538 args: args.clone(),
539 };
540 return sat_intervals(&dual, pool);
541 }
542 }
543 Err(SatFail::Unknown)
544 }
545 Formula::Atom { kind, args } => {
546 if matches!(
547 kind,
548 PredicateKind::And | PredicateKind::Or | PredicateKind::Not
549 ) {
550 return Err(SatFail::Unknown);
551 }
552 if matches!(kind, PredicateKind::True) {
553 return Ok(HashMap::new());
554 }
555 if matches!(kind, PredicateKind::False) {
556 return Err(SatFail::Unsat);
557 }
558 let (v, iv) = atom_to_interval(pool, kind.clone(), args).ok_or(SatFail::Unknown)?;
559 if iv.is_empty() {
560 return Err(SatFail::Unsat);
561 }
562 let mut m = HashMap::new();
563 m.insert(v, iv);
564 Ok(m)
565 }
566 }
567}
568
569fn simplify_formula_constants(f: Formula) -> Formula {
570 match f {
571 Formula::And(a, b) => {
572 let la = simplify_formula_constants(*a);
573 let lb = simplify_formula_constants(*b);
574 match (&la, &lb) {
575 (Formula::False, _) | (_, Formula::False) => Formula::False,
576 (Formula::True, x) => x.clone(),
577 (x, Formula::True) => x.clone(),
578 _ => Formula::and(la, lb),
579 }
580 }
581 Formula::Or(a, b) => {
582 let la = simplify_formula_constants(*a);
583 let lb = simplify_formula_constants(*b);
584 match (&la, &lb) {
585 (Formula::True, _) | (_, Formula::True) => Formula::True,
586 (Formula::False, x) => x.clone(),
587 (x, Formula::False) => x.clone(),
588 _ => Formula::or(la, lb),
589 }
590 }
591 Formula::Not(x) => Formula::not(simplify_formula_constants(*x)),
592 Formula::Forall { var, body } => Formula::Forall {
593 var,
594 body: Box::new(simplify_formula_constants(*body)),
595 },
596 Formula::Exists { var, body } => Formula::Exists {
597 var,
598 body: Box::new(simplify_formula_constants(*body)),
599 },
600 other => other,
601 }
602}
603
604pub fn satisfiable(expr: ExprId, pool: &ExprPool) -> Satisfiability {
606 let f = match formula_from_expr(expr, pool) {
607 Ok(f) => f,
608 Err(_) => return Satisfiability::Unknown,
609 };
610 let f = nnf(simplify_formula_constants(f));
611 match sat_intervals(&f, pool).and_then(|m| map_to_witness(&m, pool)) {
612 Ok(w) => Satisfiability::Sat(w),
613 Err(SatFail::Unsat) => Satisfiability::Unsat,
614 Err(SatFail::Unknown) => Satisfiability::Unknown,
615 }
616}
617
/// A DIMACS-style literal: `v` asserts that variable `v` (1-based) is true and
/// `-v` asserts that it is false. `0` is not a valid literal.
pub type BoolLit = i32;

/// A disjunction of literals. An empty clause is unsatisfiable.
pub type BoolClause = Vec<BoolLit>;

/// Decides a CNF formula with a small DPLL search (unit propagation plus
/// chronological backtracking).
///
/// `clauses` is a conjunction of [`BoolClause`]s over variables `1..=n_vars`.
/// Returns `Some(assignment)`, where `assignment[v - 1]` is the value of
/// variable `v`, or `None` when the clauses are unsatisfiable. Each decision
/// tries `false` before `true`, so unconstrained variables come back `false`.
///
/// # Panics
///
/// Panics if any literal is `0` or names a variable greater than `n_vars`.
pub fn dpll_sat(clauses: Vec<BoolClause>, n_vars: u32) -> Option<Vec<bool>> {
    /// Zero-based variable index and polarity of a (validated) literal.
    fn decode(lit: BoolLit) -> (usize, bool) {
        (lit.unsigned_abs() as usize - 1, lit > 0)
    }

    /// Repeatedly assigns the sole open literal of every unit clause.
    /// Returns `false` as soon as some clause has every literal falsified.
    fn unit_propagate(clauses: &[BoolClause], assign: &mut [Option<bool>]) -> bool {
        loop {
            let mut progressed = false;
            for clause in clauses {
                let mut satisfied = false;
                let mut open = 0usize;
                let mut last_open = (0usize, false);
                for &lit in clause {
                    let (v, positive) = decode(lit);
                    match assign[v] {
                        Some(value) if value == positive => {
                            satisfied = true;
                            break;
                        }
                        Some(_) => {}
                        None => {
                            open += 1;
                            last_open = (v, positive);
                        }
                    }
                }
                if satisfied {
                    continue;
                }
                match open {
                    0 => return false,
                    1 => {
                        assign[last_open.0] = Some(last_open.1);
                        progressed = true;
                    }
                    _ => {}
                }
            }
            if !progressed {
                return true;
            }
        }
    }

    /// Depth-first search over decisions.
    ///
    /// Before each decision the current assignment is snapshotted and
    /// restored on failure. This discards implications derived under a refuted
    /// decision, so they cannot poison the sibling branch.
    fn search(clauses: &[BoolClause], assign: &mut [Option<bool>]) -> bool {
        // A propagation run that ends without conflict has just checked every
        // clause against the final assignment, so no separate conflict scan
        // is needed.
        if !unit_propagate(clauses, assign) {
            return false;
        }
        match assign.iter().position(Option::is_none) {
            None => true,
            Some(i) => {
                let snapshot = assign.to_vec();
                for value in [false, true] {
                    assign[i] = Some(value);
                    if search(clauses, assign) {
                        return true;
                    }
                    assign.copy_from_slice(&snapshot);
                }
                false
            }
        }
    }

    for &lit in clauses.iter().flatten() {
        assert!(
            lit != 0 && lit.unsigned_abs() <= n_vars,
            "dpll_sat: literal {lit} is out of range for {n_vars} variables"
        );
    }

    let mut assign = vec![None; n_vars as usize];
    if search(&clauses, &mut assign) {
        Some(assign.into_iter().map(|x| x.unwrap_or(false)).collect())
    } else {
        None
    }
}
713
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::Domain;

    /// `x > 0 ∧ x < 0` has no solution.
    #[test]
    fn and_contradiction_unsat() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let above = pool.pred_gt(x, zero);
        let below = pool.pred_lt(x, zero);
        let both = pool.pred_and(vec![above, below]);
        assert_eq!(satisfiable(both, &pool), Satisfiability::Unsat);
    }

    /// `x > 0 ∨ x <= 0` is satisfiable, and the witness must mention `x`.
    #[test]
    fn or_cover_sat() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let above = pool.pred_gt(x, zero);
        let at_most = pool.pred_le(x, zero);
        let either = pool.pred_or(vec![above, at_most]);
        match satisfiable(either, &pool) {
            Satisfiability::Sat(witness) => assert!(witness.contains_key("x")),
            other => panic!("expected Sat, got {other:?}"),
        }
    }

    /// Universal quantifiers are outside the decision procedure.
    #[test]
    fn forall_unknown() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let body = pool.pred_gt(x, zero);
        let quantified = pool.forall(x, body);
        assert_eq!(satisfiable(quantified, &pool), Satisfiability::Unknown);
    }

    /// Lowering an existential to an expression and parsing it back is lossless.
    #[test]
    fn formula_quant_round_trip() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let zero = pool.integer(0_i32);
        let body = pool.pred_gt(x, zero);
        let inner = formula_from_expr(body, &pool).unwrap();
        let original = Formula::Exists {
            var: x,
            body: Box::new(inner),
        };
        let lowered = original.to_expr(&pool);
        let back = formula_from_expr(lowered, &pool).unwrap();
        assert_eq!(back, original);
    }

    /// `(x1 ∨ x2) ∧ (¬x1 ∨ x2)` is satisfiable.
    #[test]
    fn dpll_tiny_sat() {
        let clauses = vec![vec![1, 2], vec![-1, 2]];
        assert!(dpll_sat(clauses, 2).is_some());
    }
}