use std::collections::{HashMap, HashSet};
use crate::error::{PolarResult, RuntimeError};
use crate::folder::{fold_list, fold_term, Folder};
use crate::terms::{has_rest_var, Operation, Operator, Symbol, Term, Value};
use crate::vm::Goal;
/// A single entry on the binding stack: variable `.0` is bound to term `.1`.
#[derive(Debug, Clone)]
pub struct Binding(pub Symbol, pub Term);

/// The binding stack; when a variable appears more than once, the most recent
/// entry is the one that counts.
pub type BindingStack = Vec<Binding>;
/// Variable-to-value map handed back to callers.
pub type Bindings = HashMap<Symbol, Term>;
/// Shorthand for a binding-stack snapshot.
pub type Bsp = Bsps;
/// Key under which a follower `BindingManager` is registered.
pub type FollowerId = usize;

/// Snapshot of a manager's binding-stack height together with the snapshots
/// of its followers, keyed by follower id.
#[derive(Default, PartialEq, Debug, Clone)]
pub struct Bsps {
    bindings_index: usize,
    followers: HashMap<FollowerId, Bsps>,
}
/// Public summary of what a variable currently resolves to.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum VariableState {
    /// No binding exists for the variable.
    Unbound,
    /// The variable resolves to this concrete value.
    Bound(Term),
    /// The variable is constrained or aliased to other variables, but has no
    /// concrete value.
    Partial,
}
/// Folder that replaces variables with the values they are bound to,
/// recursing into the values it substitutes.
struct Derefer<'a> {
    binding_manager: &'a BindingManager,
    /// Hashes of the variable terms currently being expanded; used to stop
    /// when a variable is reached again while it is still being expanded.
    seen: HashSet<u64>,
}

impl<'a> Derefer<'a> {
    /// Create a dereferencer that reads bindings from `binding_manager`.
    fn new(binding_manager: &'a BindingManager) -> Self {
        Derefer {
            seen: HashSet::default(),
            binding_manager,
        }
    }
}
impl<'a> Folder for Derefer<'a> {
    /// Dereference every element. If the list ends in a rest variable and that
    /// element dereferences to a list, its elements are spliced onto the end in
    /// its place; otherwise the dereferenced last element is kept as-is.
    fn fold_list(&mut self, list: Vec<Term>) -> Vec<Term> {
        let ends_with_rest = has_rest_var(&list);
        let mut folded = fold_list(list, self);
        if !ends_with_rest {
            return folded;
        }
        let tail = folded.pop().unwrap();
        match tail.value() {
            Value::List(elements) => folded.extend(elements.iter().cloned()),
            _ => folded.push(tail),
        }
        folded
    }

    /// Replace a variable with its bound value (if any) and keep folding the
    /// result. Expressions are returned untouched. A variable already being
    /// expanded higher up in the recursion is returned as-is to avoid looping
    /// forever.
    fn fold_term(&mut self, t: Term) -> Term {
        match t.value() {
            Value::Expression(_) => t,
            Value::Variable(v) | Value::RestVariable(v) => {
                let key = t.hash_value();
                // `insert` is false when the key is already present, i.e. we
                // are inside the expansion of this same variable.
                if !self.seen.insert(key) {
                    return t;
                }
                let resolved = self.binding_manager.lookup(v).unwrap_or(t);
                let expanded = fold_term(resolved, self);
                self.seen.remove(&key);
                expanded
            }
            _ => fold_term(t, self),
        }
    }
}
/// Build an `And` of `Unify` constraints linking each variable in `cycle` to
/// the one after it. A cycle with fewer than two variables yields an empty
/// `And`.
fn cycle_constraints(cycle: Vec<Symbol>) -> Operation {
    let mut conjunction = op!(And);
    for pair in cycle.windows(2) {
        let (left, right) = (&pair[0], &pair[1]);
        conjunction.add_constraint(op!(Unify, term!(left.clone()), term!(right.clone())));
    }
    conjunction
}
impl From<BindingManagerVariableState<'_>> for VariableState {
fn from(other: BindingManagerVariableState) -> Self {
match other {
BindingManagerVariableState::Unbound => VariableState::Unbound,
BindingManagerVariableState::Bound(b) => VariableState::Bound(b),
BindingManagerVariableState::Cycle(_) => VariableState::Partial,
BindingManagerVariableState::Partial(_) => VariableState::Partial,
}
}
}
/// Internal, more detailed view of a variable's state, computed by following
/// variable-to-variable bindings on the stack.
#[derive(PartialEq, Eq, Clone, Debug)]
enum BindingManagerVariableState<'a> {
    /// No binding for the variable on the stack.
    Unbound,
    /// The variable resolves to this concrete value.
    Bound(Term),
    /// The variables visited before returning to the queried one, starting
    /// with the queried variable itself.
    Cycle(Vec<Symbol>),
    /// The variable resolves to this constraint expression, borrowed from the
    /// binding stack.
    Partial(&'a Operation),
}
/// Stack-based store of variable bindings. Registered followers are other
/// managers to which `bind` and `add_constraint` also apply their changes.
#[derive(Default, Debug, Clone)]
pub struct BindingManager {
    /// Bindings in the order they were made; pushed by `add_binding` and cut
    /// back by `backtrack`.
    bindings: BindingStack,
    /// Registered followers, keyed by the id returned from `add_follower`.
    followers: HashMap<FollowerId, BindingManager>,
    /// Id handed out to the next follower.
    next_follower_id: FollowerId,
}
impl BindingManager {
pub fn new() -> Self {
Self::default()
}
fn partial_bind(&mut self, partial: Operation, var: &Symbol, val: Term) -> PolarResult<Goal> {
match partial.ground(var, val.clone()) {
None => Err(RuntimeError::IncompatibleBindings {
msg: "Grounding failed A".into(),
}
.into()),
Some(grounded) => {
self.add_binding(var, val);
Ok(Goal::Query {
term: grounded.into(),
})
}
}
}
pub fn bind(&mut self, var: &Symbol, val: Term) -> PolarResult<Option<Goal>> {
let mut goal = None;
if let Ok(symbol) = val.value().as_symbol() {
goal = self.bind_variables(var, symbol)?;
} else {
match self._variable_state(var) {
BindingManagerVariableState::Partial(p) => {
let p = p.clone();
let val = val.clone();
goal = Some(self.partial_bind(p, var, val)?)
}
BindingManagerVariableState::Bound(_) => {
return Err(RuntimeError::IncompatibleBindings {
msg: format!("Cannot rebind {:?}", var),
}
.into())
}
_ => self.add_binding(var, val.clone()),
}
}
self.do_followers(|_, follower| {
follower.bind(var, val.clone())?;
Ok(())
})
.unwrap();
Ok(goal)
}
pub fn unsafe_rebind(&mut self, var: &Symbol, val: Term) {
assert!(matches!(
self._variable_state(var),
BindingManagerVariableState::Unbound | BindingManagerVariableState::Bound(_)
));
self.add_binding(var, val);
}
pub fn add_constraint(&mut self, term: &Term) -> PolarResult<()> {
self.do_followers(|_, follower| follower.add_constraint(term))?;
assert!(term.value().as_expression().is_ok());
let mut op = op!(And, term.clone());
for var in op.variables().iter().rev() {
match self._variable_state(var) {
BindingManagerVariableState::Cycle(c) => {
op = cycle_constraints(c).merge_constraints(op)
}
BindingManagerVariableState::Partial(e) => op = e.clone().merge_constraints(op),
_ => {}
}
}
let vars = op.variables();
let mut varset = vars.iter().collect::<HashSet<_>>();
for var in vars.iter() {
if let BindingManagerVariableState::Bound(val) = self._variable_state(var) {
varset.remove(var);
match op.ground(var, val) {
Some(o) => op = o,
None => {
return Err(RuntimeError::IncompatibleBindings {
msg: "Grounding failed B".into(),
}
.into())
}
}
}
}
for var in varset {
self.add_binding(var, op.clone().into())
}
Ok(())
}
pub fn backtrack(&mut self, to: &Bsp) {
self.do_followers(|follower_id, follower| {
if let Some(follower_to) = to.followers.get(&follower_id) {
follower.backtrack(follower_to);
} else {
follower.backtrack(&Bsp::default());
}
Ok(())
})
.unwrap();
self.bindings.truncate(to.bindings_index)
}
pub fn deep_deref(&self, term: &Term) -> Term {
Derefer::new(self).fold_term(term.clone())
}
pub fn get_constraints(&self, variable: &Symbol) -> Operation {
match self._variable_state(variable) {
BindingManagerVariableState::Unbound => op!(And),
BindingManagerVariableState::Bound(val) => {
op!(And, term!(op!(Unify, term!(variable.clone()), val)))
}
BindingManagerVariableState::Partial(expr) => expr.clone(),
BindingManagerVariableState::Cycle(c) => cycle_constraints(c),
}
}
pub fn variable_state(&self, variable: &Symbol) -> VariableState {
self.variable_state_at_point(variable, &self.bsp())
}
pub fn variable_state_at_point(&self, variable: &Symbol, bsp: &Bsp) -> VariableState {
let index = bsp.bindings_index;
let mut next = variable;
while let Some(value) = self.value(next, index) {
match value.value() {
Value::Expression(_) => return VariableState::Partial,
Value::Variable(v) | Value::RestVariable(v) => {
if v == variable {
return VariableState::Partial;
} else {
next = v;
}
}
_ => return VariableState::Bound(value.clone()),
}
}
VariableState::Unbound
}
pub fn variables(&self) -> HashSet<Symbol> {
self.bindings
.iter()
.map(|Binding(v, _)| v.clone())
.collect()
}
pub fn bsp(&self) -> Bsp {
let follower_bsps = self
.followers
.iter()
.map(|(id, f)| (*id, f.bsp()))
.collect::<HashMap<_, _>>();
Bsps {
bindings_index: self.bindings.len(),
followers: follower_bsps,
}
}
pub fn bindings(&self, include_temps: bool) -> Bindings {
self.bindings_after(include_temps, &Bsp::default())
}
pub fn bindings_after(&self, include_temps: bool, after: &Bsp) -> Bindings {
let mut bindings = HashMap::new();
for Binding(var, value) in &self.bindings[after.bindings_index..] {
if !include_temps && var.is_temporary_var() {
continue;
}
bindings.insert(var.clone(), self.deep_deref(value));
}
bindings
}
pub fn variable_bindings(&self, variables: &HashSet<Symbol>) -> Bindings {
let mut bindings = HashMap::new();
for var in variables.iter() {
let value = self.value(var, self.bsp().bindings_index);
if let Some(value) = value {
bindings.insert(var.clone(), self.deep_deref(value));
}
}
bindings
}
pub fn bindings_debug(&self) -> &BindingStack {
&self.bindings
}
pub fn add_follower(&mut self, follower: BindingManager) -> FollowerId {
let follower_id = self.next_follower_id;
self.followers.insert(follower_id, follower);
self.next_follower_id += 1;
follower_id
}
pub fn remove_follower(&mut self, follower_id: &FollowerId) -> Option<BindingManager> {
self.followers.remove(follower_id)
}
}
impl BindingManager {
fn bind_variables(&mut self, left: &Symbol, right: &Symbol) -> PolarResult<Option<Goal>> {
let mut goal = None;
match (self._variable_state(left), self._variable_state(right)) {
(
BindingManagerVariableState::Bound(left_value),
BindingManagerVariableState::Unbound,
) => {
self.add_binding(right, left_value);
}
(
BindingManagerVariableState::Unbound,
BindingManagerVariableState::Bound(right_value),
) => {
self.add_binding(left, right_value);
}
(BindingManagerVariableState::Unbound, BindingManagerVariableState::Unbound) => {
if left != right {
self.add_binding(left, term!(right.clone()));
self.add_binding(right, term!(left.clone()));
}
}
(BindingManagerVariableState::Cycle(cycle), BindingManagerVariableState::Unbound) => {
let last = cycle.last().unwrap();
assert_ne!(last, left);
self.add_binding(last, term!(right.clone()));
self.add_binding(right, term!(left.clone()));
}
(BindingManagerVariableState::Unbound, BindingManagerVariableState::Cycle(cycle)) => {
let last = cycle.last().unwrap();
assert_ne!(last, right);
self.add_binding(last, term!(left.clone()));
self.add_binding(left, term!(right.clone()));
}
(
BindingManagerVariableState::Cycle(left_cycle),
BindingManagerVariableState::Cycle(right_cycle),
) => {
let iter_left = left_cycle.iter().collect::<HashSet<&Symbol>>();
let iter_right = right_cycle.iter().collect::<HashSet<&Symbol>>();
if iter_left.intersection(&iter_right).next().is_some() {
assert_eq!(iter_left, iter_right);
} else {
let last_left = left_cycle.last().unwrap();
let last_right = right_cycle.last().unwrap();
assert_ne!(last_left, left);
assert_ne!(last_right, right);
self.add_binding(last_left, term!(right.clone()));
self.add_binding(last_right, term!(left.clone()));
}
}
(
BindingManagerVariableState::Cycle(_),
BindingManagerVariableState::Bound(right_value),
) => {
self.add_binding(left, right_value);
}
(
BindingManagerVariableState::Bound(left_value),
BindingManagerVariableState::Cycle(_),
) => {
self.add_binding(right, left_value);
}
(BindingManagerVariableState::Bound(_), BindingManagerVariableState::Bound(_)) => {
return Err(RuntimeError::IncompatibleBindings {
msg: format!("{} and {} are both bound", left, right),
}
.into());
}
(
BindingManagerVariableState::Bound(left_value),
BindingManagerVariableState::Partial(p),
) => {
let p = p.clone();
goal = Some(self.partial_bind(p, right, left_value)?);
}
(
BindingManagerVariableState::Partial(p),
BindingManagerVariableState::Bound(right_value),
) => {
let p = p.clone();
goal = Some(self.partial_bind(p, left, right_value)?);
}
(BindingManagerVariableState::Partial(_), _)
| (_, BindingManagerVariableState::Partial(_)) => {
self.add_constraint(&op!(Unify, term!(left.clone()), term!(right.clone())).into())?;
}
}
Ok(goal)
}
fn add_binding(&mut self, var: &Symbol, val: Term) {
self.bindings.push(Binding(var.clone(), val));
}
fn lookup(&self, var: &Symbol) -> Option<Term> {
match self.variable_state(var) {
VariableState::Bound(val) => Some(val),
_ => None,
}
}
fn value(&self, variable: &Symbol, bsp: usize) -> Option<&Term> {
self.bindings[..bsp]
.iter()
.rev()
.find(|Binding(var, _)| var == variable)
.map(|Binding(_, val)| val)
}
fn _variable_state(&self, variable: &Symbol) -> BindingManagerVariableState {
self._variable_state_at_point(variable, &self.bsp())
}
fn _variable_state_at_point(
&self,
variable: &Symbol,
bsp: &Bsp,
) -> BindingManagerVariableState {
let index = bsp.bindings_index;
let mut path = vec![variable];
while let Some(value) = self.value(path.last().unwrap(), index) {
match value.value() {
Value::Expression(e) => return BindingManagerVariableState::Partial(e),
Value::Variable(v) | Value::RestVariable(v) => {
if v == variable {
return BindingManagerVariableState::Cycle(
path.into_iter().cloned().collect(),
);
} else {
path.push(v);
}
}
_ => return BindingManagerVariableState::Bound(value.clone()),
}
}
BindingManagerVariableState::Unbound
}
fn do_followers<F>(&mut self, mut func: F) -> PolarResult<()>
where
F: FnMut(FollowerId, &mut BindingManager) -> PolarResult<()>,
{
for (id, follower) in self.followers.iter_mut() {
func(*id, follower)?
}
Ok(())
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::formatting::to_polar::ToPolarString;

    /// `_variable_state` follows the most recent binding of each variable and
    /// distinguishes unbound, bound, cyclic, and partial states.
    #[test]
    fn variable_state() {
        let mut bindings = BindingManager::new();
        let x = sym!("x");
        let y = sym!("y");
        let z = sym!("z");

        assert_eq!(
            bindings._variable_state(&x),
            BindingManagerVariableState::Unbound
        );

        bindings.add_binding(&x, term!(1));
        assert_eq!(
            bindings._variable_state(&x),
            BindingManagerVariableState::Bound(term!(1))
        );

        // A variable bound to itself is a 1-cycle.
        bindings.add_binding(&x, term!(x.clone()));
        assert_eq!(
            bindings._variable_state(&x),
            BindingManagerVariableState::Cycle(vec![x.clone()])
        );

        // Cycles are reported starting from the queried variable.
        bindings.add_binding(&x, term!(y.clone()));
        bindings.add_binding(&y, term!(x.clone()));
        assert_eq!(
            bindings._variable_state(&x),
            BindingManagerVariableState::Cycle(vec![x.clone(), y.clone()])
        );
        assert_eq!(
            bindings._variable_state(&y),
            BindingManagerVariableState::Cycle(vec![y.clone(), x.clone()])
        );

        bindings.add_binding(&x, term!(y.clone()));
        bindings.add_binding(&y, term!(z.clone()));
        bindings.add_binding(&z, term!(x.clone()));
        assert_eq!(
            bindings._variable_state(&x),
            BindingManagerVariableState::Cycle(vec![x.clone(), y.clone(), z.clone()])
        );
        assert_eq!(
            bindings._variable_state(&y),
            BindingManagerVariableState::Cycle(vec![y.clone(), z.clone(), x.clone()])
        );
        assert_eq!(
            bindings._variable_state(&z),
            BindingManagerVariableState::Cycle(vec![z.clone(), x.clone(), y])
        );

        // An expression binding makes the variable partial.
        bindings.add_binding(&x, term!(op!(And)));
        assert_eq!(
            bindings._variable_state(&x),
            BindingManagerVariableState::Partial(&op!(And))
        );
    }

    /// Followers only see bindings and constraints made after they are added.
    #[test]
    fn test_followers() {
        // Plain value bindings.
        let mut b1 = BindingManager::new();
        b1.bind(&sym!("x"), term!(1)).unwrap();
        b1.bind(&sym!("y"), term!(2)).unwrap();
        assert_eq!(
            b1._variable_state(&sym!("x")),
            BindingManagerVariableState::Bound(term!(1))
        );
        assert_eq!(
            b1._variable_state(&sym!("y")),
            BindingManagerVariableState::Bound(term!(2))
        );

        let b2 = BindingManager::new();
        let b2_id = b1.add_follower(b2);
        b1.bind(&sym!("z"), term!(3)).unwrap();
        assert_eq!(
            b1._variable_state(&sym!("x")),
            BindingManagerVariableState::Bound(term!(1))
        );
        assert_eq!(
            b1._variable_state(&sym!("y")),
            BindingManagerVariableState::Bound(term!(2))
        );
        assert_eq!(
            b1._variable_state(&sym!("z")),
            BindingManagerVariableState::Bound(term!(3))
        );

        let b2 = b1.remove_follower(&b2_id).unwrap();
        assert_eq!(
            b2._variable_state(&sym!("x")),
            BindingManagerVariableState::Unbound
        );
        assert_eq!(
            b2._variable_state(&sym!("y")),
            BindingManagerVariableState::Unbound
        );
        assert_eq!(
            b2._variable_state(&sym!("z")),
            BindingManagerVariableState::Bound(term!(3))
        );

        // Cycles.
        let mut b1 = BindingManager::new();
        b1.bind(&sym!("x"), term!(sym!("y"))).unwrap();
        b1.bind(&sym!("x"), term!(sym!("z"))).unwrap();

        let b2 = BindingManager::new();
        let b2_id = b1.add_follower(b2);
        assert!(matches!(
            b1._variable_state(&sym!("x")),
            BindingManagerVariableState::Cycle(_)
        ));
        assert!(matches!(
            b1._variable_state(&sym!("y")),
            BindingManagerVariableState::Cycle(_)
        ));
        assert!(matches!(
            b1._variable_state(&sym!("z")),
            BindingManagerVariableState::Cycle(_)
        ));

        b1.bind(&sym!("x"), term!(sym!("a"))).unwrap();
        if let BindingManagerVariableState::Cycle(c) = b1._variable_state(&sym!("a")) {
            assert_eq!(
                c,
                vec![sym!("a"), sym!("x"), sym!("y"), sym!("z")],
                "c was {:?}",
                c
            );
        } else {
            // Previously this branch was silently skipped.
            panic!("unexpected");
        }

        let b2 = b1.remove_follower(&b2_id).unwrap();
        if let BindingManagerVariableState::Cycle(c) = b2._variable_state(&sym!("a")) {
            assert_eq!(c, vec![sym!("a"), sym!("x")], "c was {:?}", c);
        } else {
            panic!("unexpected");
        }
        if let BindingManagerVariableState::Cycle(c) = b2._variable_state(&sym!("x")) {
            assert_eq!(c, vec![sym!("x"), sym!("a")], "c was {:?}", c);
        } else {
            panic!("unexpected");
        }

        // Constraints.
        let mut b1 = BindingManager::new();
        b1.bind(&sym!("x"), term!(sym!("y"))).unwrap();
        b1.bind(&sym!("x"), term!(sym!("z"))).unwrap();

        let b2 = BindingManager::new();
        let b2_id = b1.add_follower(b2);
        assert!(matches!(
            b1._variable_state(&sym!("x")),
            BindingManagerVariableState::Cycle(_)
        ));
        assert!(matches!(
            b1._variable_state(&sym!("y")),
            BindingManagerVariableState::Cycle(_)
        ));
        assert!(matches!(
            b1._variable_state(&sym!("z")),
            BindingManagerVariableState::Cycle(_)
        ));

        b1.add_constraint(&term!(op!(Gt, term!(sym!("x")), term!(sym!("y")))))
            .unwrap();
        let b2 = b1.remove_follower(&b2_id).unwrap();
        if let BindingManagerVariableState::Partial(p) = b1._variable_state(&sym!("x")) {
            assert_eq!(p.to_polar(), "x = y and y = z and z = x and x > y");
        } else {
            panic!("unexpected");
        }
        if let BindingManagerVariableState::Partial(p) = b2._variable_state(&sym!("x")) {
            assert_eq!(p.to_polar(), "x > y");
        } else {
            panic!("unexpected");
        }
    }

    /// Single-level and chained dereferencing.
    #[test]
    fn old_deref() {
        let mut bm = BindingManager::default();
        let value = term!(1);
        let x = sym!("x");
        let y = sym!("y");
        let term_x = term!(x.clone());
        let term_y = term!(y.clone());

        // Unbound and variable-to-variable bindings dereference to themselves.
        assert_eq!(bm.deep_deref(&term_x), term_x);
        bm.bind(&x, term_y.clone()).unwrap();
        assert_eq!(bm.deep_deref(&term_x), term_x);
        assert_eq!(bm.deep_deref(&value), value.clone());

        let mut bm = BindingManager::default();
        bm.bind(&x, value.clone()).unwrap();
        assert_eq!(bm.deep_deref(&term_x), value);

        // A chain x -> y -> value resolves all the way.
        let mut bm = BindingManager::default();
        bm.bind(&x, term_y).unwrap();
        bm.bind(&y, value.clone()).unwrap();
        assert_eq!(bm.deep_deref(&term_x), value);
    }

    /// Dereferencing recurses into lists and dictionaries.
    #[test]
    fn deep_deref() {
        let mut bm = BindingManager::default();
        let one = term!(1);
        // Distinct from `one`, so a swap of the x/y entries would be detected.
        let two = term!(2);
        let one_var = sym!("one");
        let two_var = sym!("two");
        bm.bind(&one_var, one.clone()).unwrap();
        bm.bind(&two_var, two.clone()).unwrap();
        let dict = btreemap! {
            sym!("x") => term!(one_var),
            sym!("y") => term!(two_var),
        };
        let list = term!([dict]);
        assert_eq!(
            bm.deep_deref(&list).value().clone(),
            Value::List(vec![term!(btreemap! {
                sym!("x") => one,
                sym!("y") => two,
            })])
        );
    }

    /// Binding one variable leaves others unbound.
    #[test]
    fn bind() {
        let x = sym!("x");
        let y = sym!("y");
        let zero = term!(0);
        let mut bm = BindingManager::default();
        bm.bind(&x, zero.clone()).unwrap();
        assert_eq!(bm.variable_state(&x), VariableState::Bound(zero));
        assert_eq!(bm.variable_state(&y), VariableState::Unbound);
    }

    /// Backtracking the leader also backtracks its followers.
    #[test]
    fn test_backtrack_followers() {
        let mut b1 = BindingManager::new();
        b1.bind(&sym!("x"), term!(sym!("y"))).unwrap();
        b1.bind(&sym!("z"), term!(sym!("x"))).unwrap();

        let b2 = BindingManager::new();
        let b2_id = b1.add_follower(b2);

        b1.add_constraint(&term!(op!(Gt, term!(sym!("x")), term!(1))))
            .unwrap();

        let bsp = b1.bsp();

        b1.bind(&sym!("a"), term!(sym!("x"))).unwrap();
        assert!(matches!(
            b1.variable_state(&sym!("a")),
            VariableState::Partial
        ));

        b1.backtrack(&bsp);
        let b2 = b1.remove_follower(&b2_id).unwrap();
        assert!(matches!(
            b2.variable_state(&sym!("a")),
            VariableState::Unbound
        ));
    }
}