use itertools::Itertools;
use super::mkmvmap::MKMVMap;
use super::constraints::Constraint;
use crate::{
core::{AnyVal, Fork, Unify, Value, VarId},
LVarList, ReadyState,
};
use std::rc::Rc;
/// The working state threaded through unification, constraints and forks.
///
/// All three fields use cheaply clonable containers, and `State` itself
/// derives `Clone`; the tests in this file clone a state to explore
/// alternative unifications.
#[derive(Clone)]
pub struct State {
/// Variable bindings: `unify` inserts `key.id -> value` here when it binds a
/// variable, and `resolve` looks values up through this map.
pub(crate) values: im_rc::HashMap<VarId, AnyVal>,
/// Forks registered through `State::fork`. A state holding any fork is not
/// considered ready (see `is_ready`).
pub(crate) forks: im_rc::Vector<Rc<dyn Fork>>,
/// Constraints that could not be resolved when `constrain` attempted them,
/// stored under the key(s) returned by the failed attempt. `unify` extracts
/// and re-runs the entries for a variable when that variable gets bound.
constraints: MKMVMap<VarId, Rc<dyn Constraint>>,
}
impl State {
    /// Create an empty state: no bindings, no forks and no pending constraints.
    pub fn new() -> Self {
        State {
            values: im_rc::HashMap::new(),
            forks: im_rc::Vector::new(),
            constraints: MKMVMap::new(),
        }
    }

    /// Hand this state to `func` and return whatever `func` returns.
    ///
    /// `func` is invoked exactly once, so the bound is `FnOnce`: closures that
    /// consume captured values are accepted in addition to every `Fn`/`FnMut`
    /// closure or function pointer.
    pub fn apply<F>(self, func: F) -> Option<Self>
    where
        F: FnOnce(Self) -> Option<Self>,
    {
        func(self)
    }

    /// Resolve `val` against the bindings currently held in this state.
    ///
    /// The value is converted to an `AnyVal`, followed through the binding map
    /// with [`resolve_any`], and converted back to a `Value<T>`.
    ///
    /// # Panics
    ///
    /// Panics if the resolved `AnyVal` cannot be converted back into a
    /// `Value<T>`.
    pub fn resolve<T: Unify>(&self, val: &Value<T>) -> Value<T> {
        resolve_any(&self.values, &val.to_anyval())
            .to_value()
            .expect("AnyVal resolved to unexpected Value<T>")
    }

    /// Unify `a` with `b`, returning the updated state or `None` on failure.
    ///
    /// Both sides are resolved first, then:
    ///
    /// * two resolved values are delegated to [`Unify::unify`];
    /// * the same variable on both sides leaves the state unchanged;
    /// * otherwise the variable side is bound to the other side (when both are
    ///   distinct variables, `a`'s variable is bound to `b`), and every
    ///   constraint that was waiting on that variable is attempted again.
    ///   The first re-run constraint that returns `None` makes the whole
    ///   unification return `None`.
    pub fn unify<T: Unify>(mut self, a: &Value<T>, b: &Value<T>) -> Option<Self> {
        let a = self.resolve(a);
        let b = self.resolve(b);
        match (a, b) {
            (Value::Resolved(a), Value::Resolved(b)) => Unify::unify(self, a, b),
            (Value::Var(a), Value::Var(b)) if a == b => Some(self),
            (Value::Var(key), value) | (value, Value::Var(key)) => {
                self.values.insert(key.id, value.to_anyval());
                // Binding `key` may unblock constraints parked on it: pull them
                // out of the map and give each one another attempt.
                if let Some(constraints) = self.constraints.extract(&key.id) {
                    constraints
                        .into_iter()
                        .try_fold(self, |state, constraint| state.constrain(constraint))
                } else {
                    Some(self)
                }
            }
        }
    }

    /// Attempt `constraint` against this state.
    ///
    /// If the attempt succeeds, the function it yields is applied to the state
    /// and its result (possibly `None`) is returned. If the attempt fails, the
    /// constraint is stored in the constraint map under the first field of the
    /// returned watch value (presumably the variables it is waiting on) and
    /// the state is returned.
    pub fn constrain(mut self, constraint: Rc<dyn Constraint>) -> Option<Self> {
        match constraint.attempt(&self) {
            Ok(resolve) => resolve(self),
            Err(watch) => {
                self.constraints.add(watch.0, constraint);
                Some(self)
            }
        }
    }

    /// Append `fork` to this state's list of forks.
    ///
    /// Always returns `Some`; the `Option` return type matches the other
    /// state-transforming methods.
    pub fn fork(mut self, fork: impl Fork) -> Option<Self> {
        self.forks.push_back(Rc::new(fork));
        Some(self)
    }

    /// List the ids of every variable this state knows about: the keys of the
    /// binding map followed by the keys of the constraint map, de-duplicated.
    pub fn vars(&self) -> LVarList {
        let vars = self.values.keys();
        let watched_ids = self.constraints.keys();
        let ids: Vec<_> = vars.chain(watched_ids).unique().copied().collect();
        LVarList(ids)
    }

    /// Return `true` when the state has neither forks nor pending constraints.
    pub fn is_ready(&self) -> bool {
        self.forks.is_empty() && self.constraints.is_empty()
    }

    /// Convert into a [`ReadyState`] built from the bindings, or `None` if
    /// [`State::is_ready`] is `false`.
    pub fn ready(self) -> Option<ReadyState> {
        if self.is_ready() {
            Some(ReadyState::new(self.values))
        } else {
            None
        }
    }
}
/// Follow variable bindings in `values`, starting at `val`, until reaching
/// either a resolved value or a variable that cannot be followed further.
///
/// A variable cannot be followed further when it has no entry in `values` or
/// when its entry is the variable itself; in both cases the reference to that
/// variable is returned.
pub(crate) fn resolve_any<'a>(
    values: &'a im_rc::HashMap<VarId, AnyVal>,
    val: &'a AnyVal,
) -> &'a AnyVal {
    let mut current = val;
    loop {
        // Resolved values terminate the walk immediately.
        let var = match current {
            AnyVal::Resolved(_) => return current,
            AnyVal::Var(var) => var,
        };
        match values.get(var) {
            // Unbound variable: nothing more to follow.
            None => return current,
            // Variable bound to itself: stop rather than chase it again.
            Some(AnyVal::Var(bound)) if bound == var => return current,
            // Bound to something else: continue from the binding.
            Some(bound) => current = bound,
        }
    }
}
impl Default for State {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `State`: construction, unification, forking, `apply`, and
// the variable listing returned by `vars`.
#[cfg(test)]
mod test {
use crate::{
core::{LVar, Query, StateIter, StateIterator, Value},
goals::{assert_1, Goal},
};
use super::*;
// A default-constructed state has no forks or constraints, so it is ready.
#[test]
fn state_default() {
let state = State::default();
assert!(state.is_ready());
}
// Unifying a fresh variable with 1 makes the variable resolve to 1.
#[test]
fn basic_unify() {
let x = Value::var();
let state = State::new();
let state = state.unify(&x, &Value::new(1)).unwrap();
assert_eq!(state.resolve(&x), Value::new(1));
}
// A fork whose closure offers two alternatives (x = 1 and x = 2).
#[test]
fn basic_fork() {
let x = LVar::new();
let state: State = State::new();
let forked = state.fork(move |s: &State| -> StateIter {
let s1 = s.clone().unify(&x.into(), &Value::new(1));
let s2 = s.clone().unify(&x.into(), &Value::new(2));
Box::new(s1.into_iter().chain(s2.into_iter()))
});
// With a fork registered the state is not ready, so `ready()` is `None`.
assert!(forked.clone().unwrap().ready().is_none());
// `into_states` is defined outside this file; presumably it runs the fork,
// yielding one state per alternative in the order the closure produced them.
let results = forked
.into_states()
.map(|s| s.resolve(&x.into()))
.collect::<Vec<_>>();
assert_eq!(results, vec![Value::new(1), Value::new(2)]);
}
// `apply` runs the closure on the state; querying `x` afterwards yields the
// value it was unified with.
#[test]
fn basic_apply() {
let x = LVar::new();
let state: State = State::new();
let results: Vec<_> = state
.apply(move |s| s.unify(&x.into(), &1.into()))
.query(x)
.collect();
assert_eq!(results, vec![1]);
}
// A variable bound through `unify` is reported by `vars`.
#[test]
fn unify_vars() {
let x = LVar::new();
let state: State = State::new();
let results = state
.apply(move |s| s.unify(&x.into(), &1.into()))
.unwrap()
.vars();
assert_eq!(results.0, vec![x.id]);
}
// `x` is never unified here, so it can only be reported by `vars` through the
// constraint map keys (presumably `assert_1` parks a constraint watching `x`).
#[test]
fn constraint_vars() {
let x: LVar<usize> = LVar::new();
let state: State = State::new();
let results = assert_1(x, |x| *x == 1).apply(state).unwrap().vars();
assert_eq!(results.0, vec![x.id]);
}
}