use indexmap::{map::Entry, IndexMap, IndexSet};
use itertools::Itertools;
use smtlib_lowlevel::{
ast::{self, Identifier, QualIdentifier},
backend,
lexicon::{Numeral, Symbol},
Driver, Logger, Storage,
};
use crate::{
funs, sorts,
terms::{qual_ident, Dynamic},
Bool, Error, Logic, Model, SatResult, SatResultWithModel, Sorted,
};
/// A stateful SMT solver session built on top of a low-level [`Driver`].
///
/// Besides forwarding commands, it remembers which constants and sorts it has
/// already declared, so that `assert` and `simplify` can emit the needed
/// declarations automatically before sending a term.
#[derive(Debug)]
pub struct Solver<'st, B> {
    /// Connection to the backend; every command is executed through it.
    driver: Driver<'st, B>,
    /// Entries recorded by `push` and consumed by `pop`, holding the sizes of
    /// `decls` and `declared_sorts` to truncate back to when a scope is popped.
    push_pop_stack: Vec<StackSizes>,
    /// Sort-annotated constants declared via `declare_const`, mapped to their sort.
    decls: IndexMap<Identifier<'st>, ast::Sort<'st>>,
    /// Sorts already processed by `declare_sort`, including built-in and indexed
    /// sorts for which no `declare-sort` command was sent.
    declared_sorts: IndexSet<ast::Sort<'st>>,
}
/// Snapshot of the lengths of [`Solver`]'s declaration caches taken at a `push`,
/// used by `pop` to forget declarations made inside the popped scope.
#[derive(Debug)]
struct StackSizes {
    /// Length of `Solver::decls` when the scope was pushed.
    decls: usize,
    /// Length of `Solver::declared_sorts` when the scope was pushed.
    declared_sorts: usize,
}
impl<'st, B> Solver<'st, B>
where
B: backend::Backend,
{
pub fn new(st: &'st Storage, backend: B) -> Result<Self, Error> {
Ok(Self {
driver: Driver::new(st, backend)?,
push_pop_stack: Vec::new(),
decls: Default::default(),
declared_sorts: Default::default(),
})
}
pub fn st(&self) -> &'st Storage {
self.driver.st()
}
pub fn set_logger(&mut self, logger: impl Logger) {
self.driver.set_logger(logger)
}
pub fn set_timeout(&mut self, ms: usize) -> Result<(), Error> {
let cmd = ast::Command::SetOption(ast::Option::Attribute(ast::Attribute::WithValue(
smtlib_lowlevel::lexicon::Keyword(":timeout"),
ast::AttributeValue::SpecConstant(ast::SpecConstant::Numeral(Numeral::from_usize(ms))),
)));
match self.driver.exec(cmd)? {
ast::GeneralResponse::Success => Ok(()),
ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
_ => todo!(),
}
}
pub fn set_logic(&mut self, logic: Logic) -> Result<(), Error> {
let cmd = ast::Command::SetLogic(Symbol(self.st().alloc_str(&logic.name())));
match self.driver.exec(cmd)? {
ast::GeneralResponse::Success => Ok(()),
ast::GeneralResponse::SpecificSuccessResponse(_) => todo!(),
ast::GeneralResponse::Unsupported => todo!(),
ast::GeneralResponse::Error(_) => todo!(),
}
}
pub fn run_command(
&mut self,
cmd: ast::Command<'st>,
) -> Result<ast::GeneralResponse<'st>, Error> {
Ok(self.driver.exec(cmd)?)
}
pub fn assert(&mut self, b: Bool<'st>) -> Result<(), Error> {
let term = b.term();
self.declare_all_consts(term)?;
let cmd = ast::Command::Assert(term);
match self.driver.exec(cmd)? {
ast::GeneralResponse::Success => Ok(()),
ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
_ => todo!(),
}
}
pub fn check_sat(&mut self) -> Result<SatResult, Error> {
let cmd = ast::Command::CheckSat;
match self.driver.exec(cmd)? {
ast::GeneralResponse::SpecificSuccessResponse(
ast::SpecificSuccessResponse::CheckSatResponse(res),
) => Ok(match res {
ast::CheckSatResponse::Sat => SatResult::Sat,
ast::CheckSatResponse::Unsat => SatResult::Unsat,
ast::CheckSatResponse::Unknown => SatResult::Unknown,
}),
ast::GeneralResponse::Error(msg) => Err(Error::Smt(msg.to_string(), format!("{cmd}"))),
res => todo!("{res:?}"),
}
}
pub fn check_sat_with_model(&mut self) -> Result<SatResultWithModel<'st>, Error> {
match self.check_sat()? {
SatResult::Unsat => Ok(SatResultWithModel::Unsat),
SatResult::Sat => Ok(SatResultWithModel::Sat(self.get_model()?)),
SatResult::Unknown => Ok(SatResultWithModel::Unknown),
}
}
pub fn get_model(&mut self) -> Result<Model<'st>, Error> {
match self.driver.exec(ast::Command::GetModel)? {
ast::GeneralResponse::SpecificSuccessResponse(
ast::SpecificSuccessResponse::GetModelResponse(model),
) => Ok(Model::new(self.st(), model)),
res => todo!("{res:?}"),
}
}
pub fn declare_fun(&mut self, fun: &funs::Fun<'st>) -> Result<(), Error> {
for var in fun.vars {
self.declare_sort(&var.ast())?;
}
self.declare_sort(&fun.return_sort.ast())?;
if fun.vars.is_empty() {
return self.declare_const(&qual_ident(fun.name, Some(fun.return_sort.ast())));
}
let cmd = ast::Command::DeclareFun(
Symbol(fun.name),
self.st()
.alloc_slice(&fun.vars.iter().map(|s| s.ast()).collect_vec()),
fun.return_sort.ast(),
);
match self.driver.exec(cmd)? {
ast::GeneralResponse::Success => Ok(()),
ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
_ => todo!(),
}
}
pub fn simplify(
&mut self,
t: Dynamic<'st>,
) -> Result<&'st smtlib_lowlevel::ast::Term<'st>, Error> {
self.declare_all_consts(t.term())?;
let cmd = ast::Command::Simplify(t.term());
match self.driver.exec(cmd)? {
ast::GeneralResponse::SpecificSuccessResponse(
ast::SpecificSuccessResponse::SimplifyResponse(t),
) => Ok(t.0),
res => todo!("{res:?}"),
}
}
pub fn scope<T>(
&mut self,
f: impl FnOnce(&mut Solver<'st, B>) -> Result<T, Error>,
) -> Result<T, Error> {
self.push(1)?;
let res = f(self)?;
self.pop(1)?;
Ok(res)
}
fn push(&mut self, levels: usize) -> Result<(), Error> {
self.push_pop_stack.push(StackSizes {
decls: self.decls.len(),
declared_sorts: self.declared_sorts.len(),
});
let cmd = ast::Command::Push(Numeral::from_usize(levels));
match self.driver.exec(cmd)? {
ast::GeneralResponse::Success => {}
ast::GeneralResponse::Error(e) => {
return Err(Error::Smt(e.to_string(), cmd.to_string()))
}
_ => todo!(),
};
Ok(())
}
fn pop(&mut self, levels: usize) -> Result<(), Error> {
if let Some(sizes) = self.push_pop_stack.pop() {
self.decls.truncate(sizes.decls);
self.declared_sorts.truncate(sizes.declared_sorts);
}
let cmd = ast::Command::Pop(Numeral::from_usize(levels));
match self.driver.exec(cmd)? {
ast::GeneralResponse::Success => {}
ast::GeneralResponse::Error(e) => {
return Err(Error::Smt(e.to_string(), cmd.to_string()))
}
_ => todo!(),
};
Ok(())
}
fn declare_all_consts(&mut self, t: &'st ast::Term<'st>) -> Result<(), Error> {
for q in t.all_consts() {
self.declare_const(q)?;
}
Ok(())
}
fn declare_const(&mut self, q: &QualIdentifier<'st>) -> Result<(), Error> {
match q {
QualIdentifier::Identifier(_) => {}
QualIdentifier::Sorted(i, s) => {
self.declare_sort(s)?;
match self.decls.entry(*i) {
Entry::Occupied(stored) => assert_eq!(s, stored.get()),
Entry::Vacant(v) => {
v.insert(*s);
match i {
Identifier::Simple(sym) => {
self.driver.exec(ast::Command::DeclareConst(*sym, *s))?;
}
Identifier::Indexed(_, _) => todo!(),
}
}
}
}
};
Ok(())
}
fn declare_sort(&mut self, s: &ast::Sort<'st>) -> Result<(), Error> {
if self.declared_sorts.contains(s) {
return Ok(());
}
self.declared_sorts.insert(*s);
let cmd = match s {
ast::Sort::Sort(ident) => {
let sym = match ident {
Identifier::Simple(sym) => sym,
Identifier::Indexed(_, _) => {
return Ok(());
}
};
if sorts::is_built_in_sort(sym.0) {
return Ok(());
}
ast::Command::DeclareSort(*sym, Numeral::from_usize(0))
}
ast::Sort::Parametric(ident, params) => {
let sym = match ident {
Identifier::Simple(sym) => sym,
Identifier::Indexed(_, _) => {
return Ok(());
}
};
if sorts::is_built_in_sort(sym.0) {
return Ok(());
}
ast::Command::DeclareSort(*sym, Numeral::from_usize(params.len()))
}
};
match self.driver.exec(cmd)? {
ast::GeneralResponse::Success => Ok(()),
ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
_ => todo!(),
}
}
}
#[cfg(test)]
mod tests {
    use smtlib_lowlevel::{backend::z3_binary::Z3Binary, Storage};
    use super::Solver;
    use crate::{terms::StaticSorted, Int, SatResult, Sorted};

    /// Assertions made inside `scope` must be retracted once the scope ends.
    #[test]
    fn scope() -> Result<(), crate::Error> {
        let storage = Storage::new();
        let backend = Z3Binary::new("z3").unwrap();
        let mut solver = Solver::new(&storage, backend)?;

        let x = Int::new_const(&storage, "x");
        solver.assert(x._eq(10))?;

        // Inside the scope, x = 10 and x = 20 contradict each other...
        solver.scope(|inner| {
            inner.assert(x._eq(20))?;
            assert_eq!(inner.check_sat()?, SatResult::Unsat);
            Ok(())
        })?;

        // ...but after popping, only x = 10 remains.
        let outcome = solver.check_sat()?;
        assert_eq!(outcome, SatResult::Sat);
        Ok(())
    }
}