use std::{
collections::{HashMap, HashSet},
fmt,
};
use crate::{
bindings::Bindings,
error::RuntimeError,
filter::singleton,
folder::{fold_term, Folder},
terms::*,
};
use super::partial::{invert_operation, FALSE, TRUE};
/// Module-local shorthand for results whose error type is `RuntimeError`.
type Result<T> = core::result::Result<T, RuntimeError>;
/// Performance-tracking switch used by `simplify_bindings_opt`, both for its own
/// `PerfCounters` and as the `track_performance` argument to `simplify_partial`.
/// The test module asserts that this stays `false`.
const TRACK_PERF: bool = false;
/// Gate for the `if_debug!` / `simplify_debug!` macros; when `false` the guarded
/// code is still compiled but never executed. The test module asserts that this
/// stays `false`.
const SIMPLIFY_DEBUG: bool = false;
/// Wraps the given tokens in `if SIMPLIFY_DEBUG { ... }` so they only run when
/// simplifier debugging is switched on.
macro_rules! if_debug {
($($e:tt)*) => {
if SIMPLIFY_DEBUG {
$($e)*
}
}
}
/// `eprintln!`-style tracing that is only executed when `SIMPLIFY_DEBUG` is `true`
/// (it expands to `if_debug!(eprintln!(...))`).
macro_rules! simplify_debug {
($($e:tt)*) => {
if_debug!(eprintln!($($e)*))
}
}
/// Verdict returned by `Simplifier::maybe_bind_constraint` for one constraint of a
/// conjunction; acted upon in `Simplifier::simplify_operation_variables`.
enum MaybeDrop {
/// Leave the constraint in the conjunction.
Keep,
/// Remove the constraint from the conjunction.
Drop,
/// Remove the constraint and bind the variable to the term.
Bind(Symbol, Term),
/// Bind the variable to the term and remove the constraint, but only if some
/// *other* constraint of the conjunction mentions the variable; otherwise the
/// constraint is kept.
Check(Symbol, Term),
}
/// A `Folder` that rewrites every occurrence of one particular variable
/// (plain or rest variable) to the symbol `_this`.
struct VariableSubber {
    /// The variable that gets replaced by `_this`.
    this_var: Symbol,
}

impl VariableSubber {
    /// Builds a folder that substitutes `_this` for `this_var`.
    pub fn new(this_var: Symbol) -> Self {
        Self { this_var }
    }

    /// Returns `_this` when `v` is the tracked variable, otherwise `v` unchanged.
    fn substitute(&self, v: Symbol) -> Symbol {
        if v != self.this_var {
            return v;
        }
        sym!("_this")
    }
}

impl Folder for VariableSubber {
    /// Plain variables: substitute the tracked variable.
    fn fold_variable(&mut self, v: Symbol) -> Symbol {
        self.substitute(v)
    }

    /// Rest variables get exactly the same treatment as plain variables.
    fn fold_rest_variable(&mut self, v: Symbol) -> Symbol {
        self.substitute(v)
    }
}
/// Replaces every occurrence of the variable `this` inside `term` with `_this`.
///
/// If `term.as_symbol()` succeeds and yields `this` itself, the term is handed
/// back untouched instead of being folded.
pub fn sub_this(this: Symbol, term: Term) -> Term {
    let is_this_itself = matches!(term.as_symbol(), Ok(s) if s == &this);
    if is_this_itself {
        term
    } else {
        let mut subber = VariableSubber::new(this);
        fold_term(term, &mut subber)
    }
}
/// Collapses a few trivial `Unify` expressions; anything else is returned as-is.
///
/// * `v = v` (any mix of plain/rest variable with the same name) becomes `TRUE`.
/// * `this = <ground term>` becomes the ground term.
/// * `<ground term> = this` becomes the ground term.
///
/// The cases are tried in that order.
fn simplify_trivial_constraint(this: Symbol, term: Term) -> Term {
    use {Operator::*, Value::*};

    // Decide on a replacement first; `None` means "keep the input term".
    let replacement: Option<Term> = match term.value() {
        Expression(operation) if operation.operator == Unify => {
            let (lhs, rhs) = (&operation.args[0], &operation.args[1]);
            match (lhs.value(), rhs.value()) {
                (Variable(a), Variable(b))
                | (Variable(a), RestVariable(b))
                | (RestVariable(a), Variable(b))
                | (RestVariable(a), RestVariable(b))
                    if a == b =>
                {
                    Some(TRUE.into())
                }
                (Variable(name), _) | (RestVariable(name), _)
                    if name == &this && rhs.is_ground() =>
                {
                    Some(rhs.clone())
                }
                (_, Variable(name)) | (_, RestVariable(name))
                    if name == &this && lhs.is_ground() =>
                {
                    Some(lhs.clone())
                }
                _ => None,
            }
        }
        _ => None,
    };

    replacement.unwrap_or(term)
}
/// Simplifies the partial `term` bound to `var`.
///
/// A fresh `Simplifier` (configured with `output_vars` and `track_performance`)
/// is run over the term, followed by `simplify_trivial_constraint`. If the
/// outcome is an expression whose operator is not `And`, it is wrapped in a
/// single-argument `And`.
///
/// Returns the simplified term together with the simplifier's performance
/// counters (`None` when tracking is disabled).
pub fn simplify_partial(
    var: &Symbol,
    mut term: Term,
    output_vars: HashSet<Symbol>,
    track_performance: bool,
) -> (Term, Option<PerfCounters>) {
    let mut simplifier = Simplifier::new(output_vars, track_performance);

    simplify_debug!("*** simplify partial {:?}", var);
    simplifier.simplify_partial(&mut term);
    term = simplify_trivial_constraint(var.clone(), term);
    simplify_debug!("simplify partial done {:?}, {}", var, term);

    let needs_and_wrapper = match term.value() {
        Value::Expression(e) => e.operator != Operator::And,
        _ => false,
    };
    let result: Term = if needs_and_wrapper {
        op!(And, term).into()
    } else {
        term
    };
    (result, simplifier.perf_counters())
}
/// Simplifies every binding (the `all = true` mode of `simplify_bindings_opt`).
///
/// Returns `None` when `simplify_bindings_opt` reports `Ok(None)`.
///
/// # Panics
///
/// Panics if `simplify_bindings_opt` returns an error.
pub fn simplify_bindings(bindings: Bindings) -> Option<Bindings> {
    let outcome = simplify_bindings_opt(bindings, true);
    outcome.expect("unexpected error thrown by the simplifier when simplifying all bindings")
}
/// Simplifies the values of a set of bindings.
///
/// * `all == true`: every binding is simplified, each with only its own
///   variable as output variable.
/// * `all == false`: only bindings of non-temporary variables are simplified,
///   with every non-temporary key of `bindings` as output variable.
///
/// Returns `Ok(None)` if any simplified expression equals `FALSE`, otherwise
/// `Ok(Some(..))` with the simplified bindings.
///
/// # Errors
///
/// With `all == false`, returns `RuntimeError::UnhandledPartial` when a
/// temporary variable is bound to an expression whose variables are all
/// temporary.
///
/// # Panics
///
/// Panics if a binding that gets simplified holds an expression whose operator
/// is not `And`.
pub fn simplify_bindings_opt(bindings: Bindings, all: bool) -> Result<Option<Bindings>> {
let mut perf = PerfCounters::new(TRACK_PERF);
simplify_debug!("simplify bindings");
if_debug! {
eprintln!("before simplified");
for (k, v) in bindings.iter() {
eprintln!("{:?} {}", k, v);
}
}
// Set by the closure below as soon as one binding simplifies to `FALSE`.
let mut unsatisfiable = false;
// Simplifies the value bound to a single variable.
let mut simplify_var = |bindings: &Bindings, var: &Symbol, value: &Term| match value.value() {
Value::Expression(o) => {
assert_eq!(o.operator, Operator::And);
let output_vars = if all {
singleton(var.clone())
} else {
bindings
.keys()
.filter(|v| !v.is_temporary_var())
.cloned()
.collect::<HashSet<_>>()
};
let (simplified, p) = simplify_partial(var, value.clone(), output_vars, TRACK_PERF);
if let Some(p) = p {
perf.merge(p);
}
match simplified.as_expression() {
Ok(o) if o == &FALSE => unsatisfiable = true,
_ => (),
}
simplified
}
// A temporary variable that is itself bound to another variable is replaced
// by that variable (a single hop; the chain is not followed further).
Value::Variable(v) | Value::RestVariable(v)
if v.is_temporary_var()
&& bindings.contains_key(v)
&& matches!(
bindings[v].value(),
Value::Variable(_) | Value::RestVariable(_)
) =>
{
bindings[v].clone()
}
// Everything else is copied unchanged.
_ => value.clone(),
};
simplify_debug!("simplify bindings {}", if all { "all" } else { "not all" });
let mut simplified_bindings = HashMap::new();
for (var, value) in &bindings {
if !var.is_temporary_var() || all {
let simplified = simplify_var(&bindings, var, value);
simplified_bindings.insert(var.clone(), simplified);
} else if let Value::Expression(e) = value.value() {
// A skipped temporary whose expression mentions only temporaries is
// reported as an error; it never reaches the output.
if e.variables().iter().all(|v| v.is_temporary_var()) {
return Err(RuntimeError::UnhandledPartial {
var: var.clone(),
term: value.clone(),
});
}
}
}
if unsatisfiable {
Ok(None)
} else {
if_debug! {
eprintln!("after simplified");
for (k, v) in simplified_bindings.iter() {
eprintln!("{:?} {}", k, v);
}
}
Ok(Some(simplified_bindings))
}
}
/// Call counters collected while simplifying terms.
///
/// The `acc_*` fields accumulate calls for the term currently being simplified;
/// `finish_acc` stores them in the per-term maps and resets them to zero.
#[derive(Clone, Default)]
pub struct PerfCounters {
// When `false`, every recording method is a no-op.
enabled: bool,
// Per-term value of `acc_simplify_term` at the time `finish_acc` was called.
simplify_term: HashMap<Term, u64>,
// Per-term value of `acc_preprocess_and` at the time `finish_acc` was called.
preprocess_and: HashMap<Term, u64>,
// Running count of `simplify_term()` calls since the last `finish_acc`.
acc_simplify_term: u64,
// Running count of `preprocess_and()` calls since the last `finish_acc`.
acc_preprocess_and: u64,
}
impl fmt::Display for PerfCounters {
    /// Writes a `perf { ... }` report with one section per counter map; each
    /// section lists `term: count` lines prefixed by a tab.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sections = [
            ("simplify term", &self.simplify_term),
            ("preprocess and", &self.preprocess_and),
        ];

        writeln!(f, "perf {{")?;
        for (title, counts) in sections.iter() {
            writeln!(f, "{}", title)?;
            for (term, ncalls) in counts.iter() {
                writeln!(f, "\t{}: {}", term, ncalls)?;
            }
        }
        writeln!(f, "}}")
    }
}
impl PerfCounters {
    /// Creates empty counters; when `enabled` is `false` all recording methods
    /// do nothing.
    fn new(enabled: bool) -> Self {
        Self {
            enabled,
            ..Default::default()
        }
    }

    /// Records one `preprocess_and` pass for the term currently in progress.
    fn preprocess_and(&mut self) {
        if !self.enabled {
            return;
        }
        self.acc_preprocess_and += 1;
    }

    /// Records one simplification pass for the term currently in progress.
    fn simplify_term(&mut self) {
        if !self.enabled {
            return;
        }
        self.acc_simplify_term += 1;
    }

    /// Stores the accumulated counts under `term` and resets the accumulators.
    fn finish_acc(&mut self, term: Term) {
        if !self.enabled {
            return;
        }
        self.simplify_term
            .insert(term.clone(), self.acc_simplify_term);
        self.preprocess_and.insert(term, self.acc_preprocess_and);
        self.acc_preprocess_and = 0;
        self.acc_simplify_term = 0;
    }

    /// Folds the per-term counts of `other` into `self`; entries from `other`
    /// overwrite entries for the same term. The running accumulators of
    /// `other` are not merged.
    fn merge(&mut self, other: PerfCounters) {
        if !self.enabled {
            return;
        }
        // `extend` accepts any `IntoIterator`, so the maps are passed directly.
        self.simplify_term.extend(other.simplify_term);
        self.preprocess_and.extend(other.preprocess_and);
    }

    /// Returns whether counting is switched on.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}
/// State carried through the simplification of one partial term.
#[derive(Clone)]
pub struct Simplifier {
// Variable -> replacement term, recorded by `bind` and applied by `deref`.
bindings: Bindings,
// Variables for which `is_output` returns true; `maybe_bind_constraint`
// consults this before proposing a binding.
output_vars: HashSet<Symbol>,
// Terms `simplify_term` is currently working on; a term found here is
// skipped instead of being simplified again.
seen: HashSet<Term>,
// Performance counters, see `PerfCounters`.
counters: PerfCounters,
}
/// Callback used by the operation simplifiers to simplify a nested term.
type TermSimplifier = dyn Fn(&mut Simplifier, &mut Term);
impl Simplifier {
/// Creates a simplifier with no bindings and an empty seen-set.
///
/// `output_vars` are the variables treated as outputs by `is_output`;
/// `track_performance` switches the performance counters on or off.
pub fn new(output_vars: HashSet<Symbol>, track_performance: bool) -> Self {
    let counters = PerfCounters::new(track_performance);
    let bindings = Bindings::new();
    let seen = HashSet::new();
    Self {
        bindings,
        output_vars,
        seen,
        counters,
    }
}
/// Hands out the collected counters, leaving fresh enabled counters behind.
/// Returns `None` when performance tracking is disabled.
fn perf_counters(&mut self) -> Option<PerfCounters> {
    if self.counters.is_enabled() {
        let fresh = PerfCounters::new(true);
        Some(std::mem::replace(&mut self.counters, fresh))
    } else {
        None
    }
}
/// Binds `var` to `value` (dereferenced one step through the existing
/// bindings). An already-bound variable is left alone: rebinding is ignored.
pub fn bind(&mut self, var: Symbol, value: Term) {
    if self.is_bound(&var) {
        return;
    }
    let resolved = self.deref(&value);
    self.bindings.insert(var, resolved);
}
/// Looks `term` up in the bindings if it is a (rest) variable and returns a
/// clone of the bound term; any other term, or an unbound variable, is cloned
/// as-is. Only a single lookup is performed.
pub fn deref(&self, term: &Term) -> Term {
    let bound = match term.value() {
        Value::Variable(name) | Value::RestVariable(name) => self.bindings.get(name),
        _ => None,
    };
    bound.unwrap_or(term).clone()
}
/// Returns `true` if `var` already has an entry in this simplifier's bindings.
fn is_bound(&self, var: &Symbol) -> bool {
self.bindings.contains_key(var)
}
/// Returns `true` if `t` is a (rest) variable that belongs to the output
/// variables; every non-variable term yields `false`.
fn is_output(&self, t: &Term) -> bool {
    matches!(
        t.value(),
        Value::Variable(v) | Value::RestVariable(v) if self.output_vars.contains(v)
    )
}
/// Decides what to do with one constraint of a conjunction.
///
/// Only empty `And`s and `Unify`/`Eq` constraints are ever dropped or turned
/// into bindings; everything else is kept. This method only computes the
/// verdict; the caller performs the actual binding. The guard arms below are
/// order-sensitive.
fn maybe_bind_constraint(&mut self, constraint: &Operation) -> MaybeDrop {
match constraint.operator {
// An empty conjunction carries no information.
Operator::And if constraint.args.is_empty() => MaybeDrop::Drop,
Operator::Unify | Operator::Eq => {
let left = &constraint.args[0];
let right = &constraint.args[1];
if left == right {
// Both sides are identical terms.
MaybeDrop::Drop
} else {
match (left.value(), right.value()) {
// A unification of two output variables is always kept.
(Value::Variable(_), Value::Variable(_))
if self.is_output(left) && self.is_output(right) =>
{
MaybeDrop::Keep
}
// Unbound, non-output variable on the left: bind it to the right side.
(Value::Variable(l), _) if !self.is_bound(l) && !self.is_output(left) => {
simplify_debug!("*** 1");
MaybeDrop::Bind(l.clone(), right.clone())
}
// Unbound, non-output variable on the right: bind it to the left side.
(_, Value::Variable(r)) if !self.is_bound(r) && !self.is_output(right) => {
simplify_debug!("*** 2");
MaybeDrop::Bind(r.clone(), left.clone())
}
// Unbound variable (an output variable, given the arms above) equal to a
// ground value: let the caller check whether it can be substituted.
(Value::Variable(var), val) if val.is_ground() && !self.is_bound(var) => {
simplify_debug!("*** 3");
MaybeDrop::Check(var.clone(), right.clone())
}
// Mirror image of the previous arm.
(val, Value::Variable(var)) if val.is_ground() && !self.is_bound(var) => {
simplify_debug!("*** 4");
MaybeDrop::Check(var.clone(), left.clone())
}
_ => MaybeDrop::Keep,
}
}
}
_ => MaybeDrop::Keep,
}
}
/// Simplifies an operation by turning unifications into variable bindings and
/// substituting them.
///
/// `simplify_term` is the callback used to simplify the operation's arguments.
///
/// # Panics
///
/// Panics if an argument of an `And`/`Or` is not an expression, if a
/// `Unify`/`Eq` inside one does not have exactly two arguments, if a `Not`
/// does not have exactly one argument, or if the simplified operand of a `Not`
/// is not an expression.
pub fn simplify_operation_variables(
&mut self,
o: &mut Operation,
simplify_term: &TermSimplifier,
) {
// Removes `Unify`/`Eq` constraints whose two sides are identical.
fn toss_trivial_unifies(args: &mut TermList) {
args.retain(|c| {
let o = c.as_expression().unwrap();
match o.operator {
Operator::Unify | Operator::Eq => {
assert_eq!(o.args.len(), 2);
let left = &o.args[0];
let right = &o.args[1];
left != right
}
_ => true,
}
});
}
if o.operator == Operator::And || o.operator == Operator::Or {
toss_trivial_unifies(&mut o.args);
}
match o.operator {
// Nothing left to simplify.
Operator::And | Operator::Or if o.args.is_empty() => (),
// A junction with a single expression argument is replaced by that
// expression, which is then simplified in turn.
Operator::And | Operator::Or if o.args.len() == 1 => {
if let Value::Expression(operation) = o.args[0].value() {
*o = operation.clone();
self.simplify_operation_variables(o, simplify_term);
}
}
Operator::And if o.args.len() > 1 => {
// keep[i]: constraint i survives; references[j]: constraint j mentions a
// variable bound through a `Check` and is therefore retained as well.
let mut keep = o.args.iter().map(|_| true).collect::<Vec<bool>>();
let mut references = o.args.iter().map(|_| false).collect::<Vec<bool>>();
for (i, arg) in o.args.iter().enumerate() {
match self.maybe_bind_constraint(arg.as_expression().unwrap()) {
MaybeDrop::Keep => (),
MaybeDrop::Drop => keep[i] = false,
MaybeDrop::Bind(var, value) => {
keep[i] = false;
simplify_debug!("bind {:?}, {}", var, value);
self.bind(var, value);
}
MaybeDrop::Check(var, value) => {
simplify_debug!("check {}, {}", var, value);
// Bind only if another constraint mentions the variable; the first such
// constraint is recorded and the search stops.
for (j, arg) in o.args.iter().enumerate() {
if j != i && arg.contains_variable(&var) {
simplify_debug!("check bind {}, {} ref: {}", var, value, j);
self.bind(var, value);
keep[i] = false;
references[j] = true;
break;
}
}
}
}
}
// `retain` hands over no index, so one is tracked by hand.
let mut i = 0;
o.args.retain(|_| {
i += 1;
keep[i - 1] || references[i - 1]
});
// Simplify the constraints that survived.
for arg in &mut o.args {
simplify_term(self, arg);
}
}
Operator::Not => {
assert_eq!(o.args.len(), 1);
let mut simplified = o.args[0].clone();
// The operand is simplified with a clone of this simplifier, so bindings made
// inside the negation do not end up in `self`.
let mut simplifier = self.clone();
simplifier.simplify_partial(&mut simplified);
*o = invert_operation(
simplified
.as_expression()
.expect("a simplified expression")
.clone(),
)
}
// Any other operator: simplify each argument in place.
_ => {
for arg in &mut o.args {
simplify_term(self, arg);
}
}
}
}
/// Removes redundant constraints from a conjunction and recurses into the
/// operation's arguments through `simplify_term`.
///
/// # Panics
///
/// Panics if an argument of an `And` is not an expression.
pub fn deduplicate_operation(&mut self, o: &mut Operation, simplify_term: &TermSimplifier) {
    /// Drops constraints equal to `TRUE`, constraints whose `mirror()` has
    /// already been kept, and exact duplicates. Hash values stand in for the
    /// terms themselves in the seen-set.
    fn preprocess_and(args: &mut TermList) {
        let mut seen: HashSet<u64> = HashSet::with_capacity(args.len());
        args.retain(|a| {
            let o = a.as_expression().unwrap();
            if o == &TRUE {
                return false;
            }
            let mirrored = Term::from(o.mirror()).hash_value();
            if seen.contains(&mirrored) {
                return false;
            }
            // `insert` returns false for a hash that is already present.
            seen.insert(a.hash_value())
        });
    }

    if o.operator == Operator::And {
        self.counters.preprocess_and();
        preprocess_and(&mut o.args);
    }

    let is_junction = matches!(o.operator, Operator::And | Operator::Or);

    // An empty junction has nothing to deduplicate.
    if is_junction && o.args.is_empty() {
        return;
    }

    // A junction with one expression argument collapses to that expression.
    if is_junction && o.args.len() == 1 {
        if let Value::Expression(operation) = o.args[0].value() {
            *o = operation.clone();
            self.deduplicate_operation(o, simplify_term);
        }
        return;
    }

    for arg in &mut o.args {
        simplify_term(self, arg);
    }
}
/// Simplifies `term` in place.
///
/// The term is first replaced by its binding (see `deref`). Dictionaries,
/// calls (positional and keyword arguments) and lists are then simplified
/// element by element; expressions are handed to `simplify_operation` together
/// with a callback that re-enters this method for nested terms.
///
/// A term that is already in the `seen` set is returned untouched.
pub fn simplify_term<F>(&mut self, term: &mut Term, simplify_operation: F)
where
F: Fn(&mut Self, &mut Operation, &TermSimplifier) + 'static + Clone,
{
if self.seen.contains(term) {
return;
}
// Remember the term as it was on entry: it is the key in `seen` and the value
// restored below if substitution produced a self-reference.
let orig = term.clone();
self.seen.insert(term.clone());
let de = self.deref(term);
*term = de;
match term.mut_value() {
Value::Dictionary(dict) => {
for (_, v) in dict.fields.iter_mut() {
self.simplify_term(v, simplify_operation.clone());
}
}
Value::Call(call) => {
for arg in call.args.iter_mut() {
self.simplify_term(arg, simplify_operation.clone());
}
if let Some(kwargs) = &mut call.kwargs {
for (_, v) in kwargs.iter_mut() {
self.simplify_term(v, simplify_operation.clone());
}
}
}
Value::List(list) => {
for elem in list.iter_mut() {
self.simplify_term(elem, simplify_operation.clone());
}
}
Value::Expression(operation) => {
// `so` performs this call; `simplify_operation` itself is moved into the
// continuation used for nested terms.
let so = simplify_operation.clone();
let cont = move |s: &mut Self, term: &mut Term| {
s.simplify_term(term, simplify_operation.clone())
};
so(self, operation, &cont);
}
_ => (),
}
// If the original term was a symbol and its replacement still mentions that
// symbol, undo the substitution and keep the original term.
if let Ok(sym) = orig.as_symbol() {
if term.contains_variable(sym) {
*term = orig.clone()
}
}
self.seen.remove(&orig);
}
/// Simplifies `term` to a fixed point and then deduplicates it.
///
/// Variable simplification is repeated until neither the term's hash value
/// nor the number of bindings changes between two rounds. Because the term is
/// compared by hash value only, a hash collision would go unnoticed. The
/// per-term performance counters are finalised at the end.
pub fn simplify_partial(&mut self, term: &mut Term) {
    // Snapshot of (term hash, binding count) from before the current round.
    let mut previous = (term.hash_value(), self.bindings.len());
    loop {
        simplify_debug!("simplify loop {}", term);
        self.counters.simplify_term();
        self.simplify_term(term, Simplifier::simplify_operation_variables);

        let current = (term.hash_value(), self.bindings.len());
        if current == previous {
            break;
        }
        previous = current;
    }

    self.simplify_term(term, Simplifier::deduplicate_operation);
    self.counters.finish_acc(term.clone());
}
}
#[cfg(test)]
mod test {
use super::*;
/// Guards against committing with the debug or performance switches turned on.
#[test]
#[allow(clippy::bool_assert_comparison)]
fn test_debug_off() {
assert_eq!(SIMPLIFY_DEBUG, false);
assert_eq!(TRACK_PERF, false);
}
/// `x = x.x and x matches X` with `x` as output variable must come back
/// unchanged: the self-referential unification is not substituted away.
#[test]
fn test_simplify_circular_dot_with_isa() {
let op = term!(op!(Dot, var!("x"), str!("x")));
let op = term!(op!(Unify, var!("x"), op));
let op = term!(op!(
And,
op,
term!(op!(Isa, var!("x"), term!(pattern!(instance!("X")))))
));
let mut vs: HashSet<Symbol> = HashSet::new();
vs.insert(sym!("x"));
let (x, _) = simplify_partial(&sym!("x"), op, vs, false);
assert_eq!(
x,
term!(op!(
And,
term!(op!(Unify, var!("x"), term!(op!(Dot, var!("x"), str!("x"))))),
term!(op!(Isa, var!("x"), term!(pattern!(instance!("X")))))
))
);
}
}