use std::collections::{hash_map::Entry, HashMap};
use smtlib_lowlevel::{
ast::{self, Identifier, QualIdentifier},
backend,
lexicon::Symbol,
tokio::TokioDriver,
Storage,
};
use crate::{Bool, Error, Logic, Model, SatResult, SatResultWithModel, Sorted};
/// Asynchronous solver front-end that talks to a backend `B` through a
/// [`TokioDriver`] and keeps track of the constants it has seen in assertions.
#[derive(Debug)]
pub struct TokioSolver<'st, B> {
/// Driver used to execute every command issued by this solver.
driver: TokioDriver<'st, B>,
/// Sort recorded for each sorted identifier encountered by `assert`, keyed by
/// the identifier; consulted to avoid declaring the same constant twice.
decls: HashMap<Identifier<'st>, ast::Sort<'st>>,
}
impl<'st, B> TokioSolver<'st, B>
where
B: backend::tokio::TokioBackend,
{
pub async fn new(st: &'st Storage, backend: B) -> Result<Self, Error> {
Ok(Self {
driver: TokioDriver::new(st, backend).await?,
decls: Default::default(),
})
}
pub fn st(&self) -> &'st Storage {
self.driver.st()
}
pub async fn set_logic(&mut self, logic: Logic) -> Result<(), Error> {
let cmd = ast::Command::SetLogic(Symbol(self.st().alloc_str(&logic.name())));
match self.driver.exec(cmd).await? {
ast::GeneralResponse::Success => Ok(()),
ast::GeneralResponse::SpecificSuccessResponse(_) => todo!(),
ast::GeneralResponse::Unsupported => todo!(),
ast::GeneralResponse::Error(_) => todo!(),
}
}
pub async fn run_command(
&mut self,
cmd: ast::Command<'st>,
) -> Result<ast::GeneralResponse<'st>, Error> {
Ok(self.driver.exec(cmd).await?)
}
pub async fn assert(&mut self, b: Bool<'st>) -> Result<(), Error> {
let term = b.term();
for q in term.all_consts() {
match q {
QualIdentifier::Identifier(_) => {}
QualIdentifier::Sorted(i, s) => match self.decls.entry(*i) {
Entry::Occupied(stored) => assert_eq!(s, stored.get()),
Entry::Vacant(v) => {
v.insert(*s);
match i {
Identifier::Simple(sym) => {
self.driver
.exec(ast::Command::DeclareConst(*sym, *s))
.await?;
}
Identifier::Indexed(_, _) => todo!(),
}
}
},
}
}
let cmd = ast::Command::Assert(term);
match self.driver.exec(cmd).await? {
ast::GeneralResponse::Success => Ok(()),
ast::GeneralResponse::Error(e) => Err(Error::Smt(e.to_string(), cmd.to_string())),
_ => todo!(),
}
}
pub async fn check_sat(&mut self) -> Result<SatResult, Error> {
let cmd = ast::Command::CheckSat;
match self.driver.exec(cmd).await? {
ast::GeneralResponse::SpecificSuccessResponse(
ast::SpecificSuccessResponse::CheckSatResponse(res),
) => Ok(match res {
ast::CheckSatResponse::Sat => SatResult::Sat,
ast::CheckSatResponse::Unsat => SatResult::Unsat,
ast::CheckSatResponse::Unknown => SatResult::Unknown,
}),
ast::GeneralResponse::Error(msg) => Err(Error::Smt(msg.to_string(), format!("{cmd}"))),
res => todo!("{res:?}"),
}
}
pub async fn check_sat_with_model(&mut self) -> Result<SatResultWithModel, Error> {
match self.check_sat().await? {
SatResult::Unsat => Ok(SatResultWithModel::Unsat),
SatResult::Sat => Ok(SatResultWithModel::Sat(self.get_model().await?)),
SatResult::Unknown => Ok(SatResultWithModel::Unknown),
}
}
pub async fn get_model(&mut self) -> Result<Model, Error> {
match self.driver.exec(ast::Command::GetModel).await? {
ast::GeneralResponse::SpecificSuccessResponse(
ast::SpecificSuccessResponse::GetModelResponse(model),
) => Ok(Model::new(self.st(), model)),
res => todo!("{res:?}"),
}
}
}
#[cfg(test)]
mod tests {
    use smtlib_lowlevel::backend::z3_binary::tokio::Z3BinaryTokio;

    use super::TokioSolver;
    use crate::{terms::StaticSorted, Int, Sorted};

    type Result<T, E = crate::Error> = std::result::Result<T, E>;

    /// Asserts `x = 10` and `y = x + 2` against a `z3` binary and checks the
    /// displayed model.
    #[tokio::test]
    async fn basic() -> Result<(), Box<dyn std::error::Error>> {
        let storage = smtlib_lowlevel::Storage::new();
        let z3 = Z3BinaryTokio::new("z3").await?;
        let mut solver = TokioSolver::new(&storage, z3).await?;

        let x = Int::new_const(&storage, "x");
        let y = Int::new_const(&storage, "y");

        let x_is_ten = x._eq(10);
        solver.assert(x_is_ten).await?;
        let y_is_x_plus_two = y._eq(x + 2);
        solver.assert(y_is_x_plus_two).await?;

        let outcome = solver.check_sat_with_model().await?;
        let model = outcome.expect_sat()?;
        insta::assert_display_snapshot!(model, @"{ x: 10, y: 12 }");

        Ok(())
    }
}