use std::collections::HashMap;
use std::convert::TryInto;
use crate::ast::Sym;
use crate::search::{Resolved, Resolver, SolutionState};
use crate::term_arena::{AppTerm, ArgRange, Term, TermId};
use crate::universe::SymbolStorage;
/// Built-in resolver for integer arithmetic goals.
///
/// Holds the two symbol tables created by [`ArithmeticResolver::new`]: one for the
/// predicates this resolver answers and one for the operators that may appear inside
/// an arithmetic expression.
#[derive(Clone)]
pub struct ArithmeticResolver {
    /// Predicate symbols handled by this resolver (only `is` is registered by `new`).
    pred_map: HashMap<Sym, Pred>,
    /// Operator symbols recognised while evaluating an expression term.
    exp_map: HashMap<Sym, Exp>,
}
impl ArithmeticResolver {
    /// Creates a resolver whose operator and predicate tables are built via `symbols`.
    ///
    /// The binary operators `add`, `sub`, `mul`, `div`, `rem` and `pow` are registered
    /// first, followed by the predicate `is`.
    pub fn new<T: SymbolStorage>(symbols: &mut T) -> Self {
        let exp_map = symbols.build_sym_map([
            ("add", Exp::Add),
            ("sub", Exp::Sub),
            ("mul", Exp::Mul),
            ("div", Exp::Div),
            ("rem", Exp::Rem),
            ("pow", Exp::Pow),
        ]);
        let pred_map = symbols.build_sym_map([("is", Pred::Is)]);
        Self { exp_map, pred_map }
    }

    /// Evaluates the arithmetic expression term `exp` to an integer.
    ///
    /// Variable bindings are followed first. Integer terms evaluate to themselves.
    /// For an application of a known operator, both arguments are evaluated recursively
    /// and then combined with the matching checked `i64` operation.
    ///
    /// Returns `None` in these cases:
    /// - `exp` resolves to an unbound variable or another non-arithmetic term;
    /// - the functor is not a registered operator;
    /// - `get_args_fixed` rejects the argument list (presumably an arity other than two);
    /// - the operation itself fails: overflow, a zero divisor for `div`/`rem`, or a
    ///   `pow` exponent that does not fit in a `u32`.
    fn eval_exp(&self, solution: &SolutionState, exp: TermId) -> Option<i64> {
        let (_, term) = solution.follow_vars(exp);
        let AppTerm(sym, arg_range) = match term {
            Term::Int(value) => return Some(value),
            Term::App(app) => app,
            _ => return None,
        };
        let op = self.exp_map.get(&sym)?;
        let [lhs_id, rhs_id] = solution.terms().get_args_fixed(arg_range)?;
        let lhs = self.eval_exp(solution, lhs_id)?;
        let rhs = self.eval_exp(solution, rhs_id)?;
        match op {
            Exp::Add => lhs.checked_add(rhs),
            Exp::Sub => lhs.checked_sub(rhs),
            Exp::Mul => lhs.checked_mul(rhs),
            Exp::Div => lhs.checked_div(rhs),
            Exp::Rem => lhs.checked_rem(rhs),
            // `checked_pow` takes a `u32`, so negative or oversized exponents fail here.
            Exp::Pow => lhs.checked_pow(rhs.try_into().ok()?),
        }
    }

    /// Resolves the goal `is(Left, Right)`.
    ///
    /// `Right` is evaluated with `eval_exp`.
    /// - If `Left` (after following bindings) is an unbound variable, a new integer term
    ///   holding the result is created and bound to it.
    /// - If `Left` is already an integer, the goal holds exactly when both values are
    ///   equal.
    ///
    /// Every other outcome yields `None`: a wrong argument count, an unevaluable
    /// right-hand side, a non-integer left-hand side, a value mismatch, or a binding
    /// that `set_var` rejects.
    fn resolve_is(
        &mut self,
        args: ArgRange,
        context: &mut crate::search::ResolveContext,
    ) -> Option<Resolved<()>> {
        let [lhs, rhs] = context.solution().terms().get_args_fixed(args)?;
        let value = self.eval_exp(context.solution(), rhs)?;
        let (_, lhs_term) = context.solution().follow_vars(lhs);
        let holds = match lhs_term {
            Term::Int(existing) => existing == value,
            Term::Var(var) => {
                let value_term = context.solution_mut().terms_mut().int(value);
                context.solution_mut().set_var(var, value_term)
            }
            _ => false,
        };
        holds.then_some(Resolved::Success)
    }
}
/// Binary integer operators understood by `ArithmeticResolver`.
///
/// Each variant is evaluated on `i64` operands with the corresponding checked
/// operation. Overflow and division by zero therefore make evaluation fail instead of
/// panicking.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Exp {
    /// `add(A, B)`: `A + B`.
    Add,
    /// `sub(A, B)`: `A - B`.
    Sub,
    /// `mul(A, B)`: `A * B`.
    Mul,
    /// `div(A, B)`: integer division, truncating toward zero.
    Div,
    /// `rem(A, B)`: remainder of `A / B`, carrying the sign of `A`.
    Rem,
    /// `pow(A, B)`: `A` raised to the power `B`; `B` must fit in a `u32`.
    Pow,
}
/// Predicates answered by `ArithmeticResolver`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Pred {
    /// `is(Left, Right)`: evaluates `Right`, then binds it to `Left` or compares it with
    /// an integer `Left`.
    Is,
}
impl Resolver for ArithmeticResolver {
    /// No per-goal choice state is needed, because `resolve` never leaves alternatives.
    type Choice = ();

    /// Dispatches a goal whose functor is one of this resolver's predicates.
    ///
    /// Goals with any other functor yield `None`. For `is`, the work is delegated to
    /// `resolve_is`.
    fn resolve(
        &mut self,
        _goal_id: crate::term_arena::TermId,
        AppTerm(sym, args): crate::term_arena::AppTerm,
        context: &mut crate::search::ResolveContext,
    ) -> Option<Resolved<Self::Choice>> {
        match self.pred_map.get(&sym) {
            Some(Pred::Is) => self.resolve_is(args, context),
            None => None,
        }
    }

    /// Always returns `false`.
    ///
    /// `resolve` only ever produces `Resolved::Success` or `None`, so there is never a
    /// stored choice to resume from.
    fn resume(
        &mut self,
        _choice: &mut Self::Choice,
        _goal_id: crate::term_arena::TermId,
        _context: &mut crate::search::ResolveContext,
    ) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::ArithmeticResolver;
    use crate::ast::Term;
    use crate::query_dfs;
    use crate::resolve::ResolverExt;
    use crate::search::Solution;
    use crate::textual::TextualUniverse;

    /// One nested expression exercising every operator evaluates to a single integer.
    #[test]
    fn simple() {
        let universe = TextualUniverse::new();
        let mut query = universe
            .prepare_query("is(X, add(3, mul(3, sub(6, div(10, rem(10, pow(2,3))))))).")
            .unwrap();
        // pow(2,3) = 8, rem(10,8) = 2, div(10,2) = 5, sub(6,5) = 1, mul(3,1) = 3, add(3,3) = 6.
        let arith = ArithmeticResolver::new(&mut query.symbols_mut());
        let mut solutions = query_dfs(arith.or_else(universe.resolver()), query.query());
        assert_eq!(solutions.next(), Some(Solution(vec![Some(Term::Int(6))])));
        assert!(solutions.next().is_none());
    }

    /// `is` used from user-defined rules, covering each direction of binding.
    #[test]
    fn complex() {
        let mut universe = TextualUniverse::new();
        let mut arith = ArithmeticResolver::new(&mut universe.symbols);
        universe
            .load_str(
                r"
eq(Exp1, Exp2) :- is(X, Exp1), is(X, Exp2), !.
eq(Exp1, Exp2) :- is(Exp1, Exp2), !.
eq(Exp1, Exp2) :- is(Exp2, Exp1), !.
",
            )
            .unwrap();

        // Each query must produce exactly one solution, equal to the expected one.
        let cases = [
            ("eq(add(2, 2), pow(2, 2)).", Solution(vec![])),
            ("eq(X, pow(2, 2)).", Solution(vec![Some(Term::Int(4))])),
            ("eq(add(2, 2), X).", Solution(vec![Some(Term::Int(4))])),
            ("eq(2, 2).", Solution(vec![])),
        ];
        for (text, expected) in cases {
            let query = universe.prepare_query(text).unwrap();
            let mut solutions =
                query_dfs(arith.by_ref().or_else(universe.resolver()), query.query());
            assert_eq!(solutions.next(), Some(expected));
            assert!(solutions.next().is_none());
        }
    }
}