use crate::expr::TLExpr;
use crate::term::Term;
use crate::unification::{unify_term_list, Substitution};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// A possibly negated atom: it asserts `atom` when `polarity` is `true` and
/// denies it when `polarity` is `false`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Literal {
    /// The underlying atomic formula.
    pub atom: TLExpr,
    /// `true` for a positive literal, `false` for a negated one.
    pub polarity: bool,
}
impl Literal {
    /// Builds a literal that asserts `atom` (polarity `true`).
    pub fn positive(atom: TLExpr) -> Self {
        Self { atom, polarity: true }
    }

    /// Builds a literal that denies `atom` (polarity `false`).
    pub fn negative(atom: TLExpr) -> Self {
        Self { atom, polarity: false }
    }

    /// Returns a copy of this literal with its polarity flipped; the atom is cloned.
    pub fn negate(&self) -> Self {
        let flipped = !self.polarity;
        Self {
            atom: self.atom.clone(),
            polarity: flipped,
        }
    }

    /// Returns `true` when the two literals have opposite polarity and
    /// structurally equal atoms (plain `==`, no unification is attempted).
    pub fn is_complementary(&self, other: &Literal) -> bool {
        self.polarity != other.polarity && self.atom == other.atom
    }

    /// Unifies the atoms of two literals of opposite polarity.
    ///
    /// Returns `None` when both literals share the same polarity, or when the
    /// atoms cannot be unified.
    pub fn try_unify(&self, other: &Literal) -> Option<Substitution> {
        if self.polarity != other.polarity {
            self.try_unify_atoms(&other.atom)
        } else {
            None
        }
    }

    /// Unifies this literal's atom with `other_atom`, argument by argument.
    ///
    /// Both atoms must be `TLExpr::Pred` with the same name and the same number
    /// of arguments; any other combination yields `None`, as does a failed
    /// [`unify_term_list`] call.
    fn try_unify_atoms(&self, other_atom: &TLExpr) -> Option<Substitution> {
        let (lhs_args, rhs_args) = match (&self.atom, other_atom) {
            (
                TLExpr::Pred {
                    name: lhs_name,
                    args: lhs_args,
                },
                TLExpr::Pred {
                    name: rhs_name,
                    args: rhs_args,
                },
            ) if lhs_name == rhs_name && lhs_args.len() == rhs_args.len() => {
                (lhs_args, rhs_args)
            }
            _ => return None,
        };
        let pairs: Vec<(Term, Term)> = lhs_args
            .iter()
            .cloned()
            .zip(rhs_args.iter().cloned())
            .collect();
        unify_term_list(&pairs).ok()
    }

    /// Returns a new literal whose atom has `subst` applied; polarity is kept.
    pub fn apply_substitution(&self, subst: &Substitution) -> Literal {
        Literal {
            polarity: self.polarity,
            atom: self.apply_subst_to_expr(&self.atom, subst),
        }
    }

    /// Applies `subst` to every argument of a `TLExpr::Pred`; any other
    /// expression kind is returned as an unchanged clone.
    fn apply_subst_to_expr(&self, expr: &TLExpr, subst: &Substitution) -> TLExpr {
        if let TLExpr::Pred { name, args } = expr {
            TLExpr::Pred {
                name: name.clone(),
                args: args.iter().map(|arg| subst.apply(arg)).collect(),
            }
        } else {
            expr.clone()
        }
    }

    /// Returns `true` for a positive (asserting) literal.
    pub fn is_positive(&self) -> bool {
        self.polarity
    }

    /// Returns `true` for a negated literal.
    pub fn is_negative(&self) -> bool {
        !self.polarity
    }

    /// Free variables of the atom, as reported by `TLExpr::free_vars`.
    pub fn free_vars(&self) -> HashSet<String> {
        self.atom.free_vars()
    }

    /// One-way matches this literal's atom (the pattern) against `other_atom`.
    ///
    /// Polarity is not consulted. Both atoms must be `TLExpr::Pred` with equal
    /// name and arity. Arguments are matched left to right with
    /// `try_one_way_match_terms`, threading a single substitution (starting
    /// empty) through every pair. Only variables in `allowed_vars` may be
    /// bound, and the first mismatching pair yields `None`.
    pub(super) fn try_one_way_match(
        &self,
        other_atom: &TLExpr,
        allowed_vars: &HashSet<String>,
    ) -> Option<Substitution> {
        let (pattern_args, target_args) = match (&self.atom, other_atom) {
            (
                TLExpr::Pred {
                    name: pattern_name,
                    args: pattern_args,
                },
                TLExpr::Pred {
                    name: target_name,
                    args: target_args,
                },
            ) if pattern_name == target_name && pattern_args.len() == target_args.len() => {
                (pattern_args, target_args)
            }
            _ => return None,
        };
        let mut subst = Substitution::empty();
        // `all` stops at the first `false`, mirroring an early return.
        let all_matched = pattern_args
            .iter()
            .zip(target_args.iter())
            .all(|(pattern, target)| {
                try_one_way_match_terms(pattern, target, allowed_vars, &mut subst)
            });
        if all_matched {
            Some(subst)
        } else {
            None
        }
    }
}
/// One-way matches pattern term `t1` against target term `t2`, extending `subst`.
///
/// This is matching, not unification:
/// - Only variables listed in `allowed_vars` may be bound, and only when they
///   occur on the pattern side (`t1`).
/// - Variables of `t2` are never bound. A target variable matches only an
///   identical variable or an allowed (bindable) pattern variable.
///
/// `subst` is applied to `t1` first, so bindings made earlier (e.g. for
/// previous arguments of the same atom) are respected.
///
/// Returns `true` on a match. This call adds a binding only on a success path,
/// so a `false` return leaves `subst` as it was on entry.
pub(super) fn try_one_way_match_terms(
    t1: &Term,
    t2: &Term,
    allowed_vars: &HashSet<String>,
    subst: &mut Substitution,
) -> bool {
    let t1_subst = subst.apply(t1);
    match (&t1_subst, t2) {
        (Term::Const(c1), Term::Const(c2)) => c1 == c2,
        // Identical variables match without adding a binding.
        (Term::Var(v1), Term::Var(v2)) if v1 == v2 => true,
        // A bindable pattern variable matches any target term, including a
        // *different* target variable. Previously an unguarded Var/Var arm
        // preceded this one and rejected that case.
        (Term::Var(v1), _) if allowed_vars.contains(v1) => {
            let after_subst = subst.apply(&t1_subst);
            if after_subst != t1_subst {
                // NOTE(review): defensive path, presumably only reachable if
                // `Substitution::apply` is not idempotent — confirm.
                &after_subst == t2
            } else {
                subst.bind(v1.clone(), t2.clone());
                true
            }
        }
        // Pattern variables outside `allowed_vars` are rigid.
        (Term::Var(_), _) => false,
        // Target variables are never bound by one-way matching.
        (_, Term::Var(_)) => false,
        (
            Term::Typed {
                value: inner1,
                type_annotation: ty1,
            },
            Term::Typed {
                value: inner2,
                type_annotation: ty2,
            },
        ) => ty1 == ty2 && try_one_way_match_terms(inner1, inner2, allowed_vars, subst),
        // A type annotation on only one side is looked through. Pattern and
        // target keep their positions so target variables are never moved into
        // the bindable (first) slot.
        (Term::Typed { value, .. }, _) => {
            try_one_way_match_terms(value, t2, allowed_vars, subst)
        }
        (_, Term::Typed { value, .. }) => {
            try_one_way_match_terms(&t1_subst, value, allowed_vars, subst)
        }
    }
}