use std::time::Duration;
use thiserror::Error;
use smallvec::SmallVec;
/// The sort (type) of a term: Boolean or integer.
///
/// `Hash` is derived alongside `Eq` so a sort can serve as (part of) a key in
/// hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Sort {
    /// Boolean sort, displayed as `Bool`.
    Bool,
    /// Integer sort, displayed as `Int`.
    Int,
}

impl std::fmt::Display for Sort {
    /// Writes the sort's name, `Bool` or `Int`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Sort::Bool => "Bool",
            Sort::Int => "Int",
        };
        f.write_str(name)
    }
}
/// Error type covering parse failures, solver failures, timeouts, I/O errors,
/// sort mismatches and invalid terms.
#[derive(Debug, Error)]
pub enum LogicError {
    /// Input could not be parsed; `line`/`col` locate the problem.
    #[error("parse error at {line}:{col}: {msg}")]
    Parse { line: usize, col: usize, msg: String },

    /// Failure reported by the solver, with its message.
    #[error("solver error: {0}")]
    Solver(String),

    /// The operation did not finish within the given duration.
    #[error("timeout after {0:?}")]
    Timeout(Duration),

    /// Underlying I/O failure; `#[from]` lets `?` convert `std::io::Error`.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),

    /// A term had a different sort than the one required.
    #[error("sort mismatch: expected {expected}, got {got}")]
    SortMismatch { expected: Sort, got: Sort },

    /// A term was malformed; the string describes why.
    #[error("invalid term: {0}")]
    InvalidTerm(String),
}
#[derive(Debug, Clone, PartialEq)]
pub struct And2(
)
pub Box<Term>,
pub Box<Term>
);
#[derive(Debug, Clone, PartialEq)]
pub struct Or2(
)
pub Box<Term>,
pub Box<Term>
);
/// A term of the logic: a literal, a sorted variable, or a compound formula.
#[derive(Clone, Debug, PartialEq)]
pub enum Term {
    /// Boolean literal.
    Bool(bool),
    /// Integer literal.
    Int(i64),
    /// A named variable together with its sort.
    Var(String, Sort),
    /// Negation of a single operand.
    Not(Box<Term>),
    /// Conjunction: two mandatory operands plus any number of extra ones.
    And(And2, SmallVec<[Box<Term>; 2]>),
    /// Disjunction: two mandatory operands plus any number of extra ones.
    Or(Or2, SmallVec<[Box<Term>; 2]>),
    /// Equality between two terms.
    Eq(Box<Term>, Box<Term>),
    /// If-then-else: condition, then branch, else branch.
    Ite(Box<Term>, Box<Term>, Box<Term>),
}
impl Term {
    /// Returns the sort of this term.
    ///
    /// - Literals report their own sort.
    /// - Variables report the sort they were declared with.
    /// - Every connective and equality is `Bool`.
    /// - An `ite` takes the sort of its then branch; the else branch is not inspected.
    pub fn sort(&self) -> Sort {
        match self {
            Term::Int(_) => Sort::Int,
            Term::Var(_, declared) => *declared,
            Term::Ite(_, then_branch, _) => then_branch.sort(),
            Term::Bool(_)
            | Term::Not(_)
            | Term::And(..)
            | Term::Or(..)
            | Term::Eq(..) => Sort::Bool,
        }
    }
}
impl Term {
#[allow(clippy::should_implement_trait)] pub fn not(self) -> Term {
crate::assert_invariant!(
self.sort() == Sort::Bool,
"not requires Bool sort",
"term_not_sort"
);
Term::Not(Box::new(self))
}
pub fn and(self, other: Term) -> Term {
crate::assert_invariant!(
self.sort() == Sort::Bool,
"and requires Bool sort for self",
"term_and_sort_self"
);
crate::assert_invariant!(
other.sort() == Sort::Bool,
"and requires Bool sort for other",
"term_and_sort_other"
);
Term::And(And2(Box::new(self), Box::new(other)), SmallVec::new())
}
pub fn or(self, other: Term) -> Term {
crate::assert_invariant!(
self.sort() == Sort::Bool,
"or requires Bool sort for self",
"term_or_sort_self"
);
crate::assert_invariant!(
other.sort() == Sort::Bool,
"or requires Bool sort for other",
"term_or_sort_other"
);
Term::Or(Or2(Box::new(self), Box::new(other)), SmallVec::new())
}
pub fn eq(self, other: Term) -> Term {
crate::assert_invariant!(
self.sort() == other.sort(),
"eq requires matching sorts",
"term_eq_sort"
);
Term::Eq(Box::new(self), Box::new(other))
}
pub fn implies(self, other: Term) -> Term {
crate::assert_invariant!(
self.sort() == Sort::Bool,
"implies requires Bool sort for self",
"term_implies_sort_self"
);
crate::assert_invariant!(
other.sort() == Sort::Bool,
"implies requires Bool sort for other",
"term_implies_sort_other"
);
self.not().or(other)
}
pub fn and_many(self, others: Vec<Term>) -> Term {
crate::assert_invariant!(
self.sort() == Sort::Bool,
"and_many requires Bool sort for self",
"term_and_many_sort_self"
);
for (i, t) in others.iter().enumerate() {
crate::assert_invariant!(
t.sort() == Sort::Bool,
&format!("and_many requires Bool sort for term {}", i),
"term_and_many_sort_other"
);
}
if others.is_empty() {
Term::And(And2(Box::new(self), Box::new(Term::Bool(true))), SmallVec::new())
} else {
let mut iter = others.into_iter();
let second = iter.next().unwrap();
let rest: SmallVec<[Box<Term>; 2]> = iter.map(Box::new).collect();
Term::And(And2(Box::new(self), Box::new(second)), rest)
}
}
pub fn or_many(self, others: Vec<Term>) -> Term {
crate::assert_invariant!(
self.sort() == Sort::Bool,
"or_many requires Bool sort for self",
"term_or_many_sort_self"
);
for (i, t) in others.iter().enumerate() {
crate::assert_invariant!(
t.sort() == Sort::Bool,
&format!("or_many requires Bool sort for term {}", i),
"term_or_many_sort_other"
);
}
if others.is_empty() {
Term::Or(Or2(Box::new(self), Box::new(Term::Bool(false))), SmallVec::new())
} else {
let mut iter = others.into_iter();
let second = iter.next().unwrap();
let rest: SmallVec<[Box<Term>; 2]> = iter.map(Box::new).collect();
Term::Or(Or2(Box::new(self), Box::new(second)), rest)
}
}
}
impl std::fmt::Display for Term {
    /// Renders the term in SMT-LIB-style prefix syntax, e.g. `(and x (not y))`
    /// or `(ite c 1 (- 2))`.
    ///
    /// Negative integers are written as `(- n)`. SMT-LIB numerals are
    /// non-negative, and a bare `-2` is lexically a symbol, not a number.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Shared printer for the variadic connectives: `(op first second rest...)`.
        fn write_nary(
            f: &mut std::fmt::Formatter<'_>,
            op: &str,
            first: &Term,
            second: &Term,
            rest: &[Box<Term>],
        ) -> std::fmt::Result {
            write!(f, "({} {} {}", op, first, second)?;
            for t in rest {
                write!(f, " {}", t)?;
            }
            write!(f, ")")
        }

        match self {
            Term::Bool(true) => write!(f, "true"),
            Term::Bool(false) => write!(f, "false"),
            // `unsigned_abs` avoids the overflow `-i` would hit on `i64::MIN`.
            Term::Int(i) if *i < 0 => write!(f, "(- {})", i.unsigned_abs()),
            Term::Int(i) => write!(f, "{}", i),
            Term::Var(name, _) => write!(f, "{}", name),
            Term::Not(inner) => write!(f, "(not {})", inner),
            Term::And(And2(a, b), rest) => write_nary(f, "and", a, b, rest),
            Term::Or(Or2(a, b), rest) => write_nary(f, "or", a, b, rest),
            Term::Eq(a, b) => write!(f, "(= {} {})", a, b),
            Term::Ite(cond, then_b, else_b) => {
                write!(f, "(ite {} {} {})", cond, then_b, else_b)
            }
        }
    }
}