use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::rc::Rc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::bindings::{BindingManager, Bsp, FollowerId, VariableState};
use crate::counter::Counter;
use crate::error::{PolarError, PolarResult};
use crate::events::QueryEvent;
use crate::formatting::ToPolarString;
use crate::kb::Bindings;
use crate::partial::simplify_bindings;
use crate::runnable::Runnable;
use crate::terms::{Operation, Operator, Term, Value};
use crate::vm::{Goals, PolarVirtualMachine};
/// Runs a set of goals on its own copy of a VM and records the bindings of
/// every result, so they can be turned into inverted constraints when the
/// query finishes (see the `Runnable` impl below).
///
/// NOTE(review): presumably this backs Polar's `not` operator — confirm
/// against the caller that constructs it.
#[derive(Clone)]
pub struct Inverter {
/// VM created in `new` via `clone_with_goals`, with `inverting` set to `true`.
vm: PolarVirtualMachine,
/// Binding stack point supplied to `new`; consumed (and reset to
/// `Bsp::default()`) when the inverted constraints are filtered.
bsp: Bsp,
/// One `BindingManager` per `QueryEvent::Result` produced by `vm`, obtained
/// by removing the current binding follower.
results: Vec<BindingManager>,
/// Shared bindings map that receives the inverted constraints when the run
/// ends with a non-empty constraint set.
add_constraints: Rc<RefCell<Bindings>>,
/// Id of the binding follower currently registered on `vm`; `None` until
/// the first call to `run`.
follower: Option<FollowerId>,
/// Per-instance id drawn from the global `ID` counter. Not read by any code
/// visible in this file.
_debug_id: u64,
}
/// Global counter incremented once per `Inverter::new` call; its previous
/// value becomes the new inverter's `_debug_id`.
static ID: AtomicU64 = AtomicU64::new(0);
impl Inverter {
    /// Creates an inverter that will run `goals` on a copy of `vm`.
    ///
    /// * `vm` - VM to copy; the copy is made with `clone_with_goals(goals)` and
    ///   has its `inverting` flag set.
    /// * `goals` - goals the copied VM will execute.
    /// * `add_constraints` - shared map that `run` extends with inverted
    ///   constraints, if any are produced.
    /// * `bsp` - binding stack point later used to filter those constraints.
    ///
    /// The binding follower is not registered here; `run` does that lazily.
    pub fn new(
        vm: &PolarVirtualMachine,
        goals: Goals,
        add_constraints: Rc<RefCell<Bindings>>,
        bsp: Bsp,
    ) -> Self {
        let mut inverting_vm = vm.clone_with_goals(goals);
        inverting_vm.inverting = true;

        // `fetch_add` returns the value before the increment, so ids start at 0.
        let debug_id = ID.fetch_add(1, Ordering::AcqRel);

        Self {
            vm: inverting_vm,
            bsp,
            results: Vec::new(),
            add_constraints,
            follower: None,
            _debug_id: debug_id,
        }
    }
}
fn results_to_constraints(results: Vec<BindingManager>) -> Bindings {
let inverted = results.into_iter().map(invert_partials).collect();
let reduced = reduce_constraints(inverted);
let simplified = simplify_bindings(reduced).unwrap_or_else(Bindings::new);
simplified
.into_iter()
.map(|(k, v)| match v.value() {
Value::Expression(_) => (k, v),
_ => (
k.clone(),
v.clone_with_value(Value::Expression(op!(Unify, term!(k), v.clone()))),
),
})
.collect()
}
fn invert_partials(bindings: BindingManager) -> Bindings {
let mut new_bindings = Bindings::new();
for var in bindings.variables() {
let constraint = bindings.get_constraints(&var);
new_bindings.insert(var.clone(), term!(constraint));
}
let simplified = simplify_bindings(new_bindings).unwrap_or_else(Bindings::new);
simplified
.into_iter()
.map(|(k, v)| match v.value() {
Value::Expression(e) => (k, e.invert().into()),
_ => (
k.clone(),
term!(op!(And, term!(op!(Neq, term!(k), v.clone())))),
),
})
.collect::<Bindings>()
}
/// Collapses a list of binding maps into a single map.
///
/// The first value seen for a variable is stored as-is. When the same variable
/// shows up again, both the stored and the incoming value must be expressions;
/// they are combined with `merge_constraints` and the entry is replaced by the
/// incoming term carrying the merged expression.
///
/// # Panics
///
/// Panics if a variable occurs more than once and either the stored or the
/// incoming value is not a `Value::Expression`.
fn reduce_constraints(bindings: Vec<Bindings>) -> Bindings {
    let mut reduced = Bindings::new();

    for result in bindings {
        for (var, value) in result {
            match reduced.entry(var.clone()) {
                Entry::Vacant(slot) => {
                    slot.insert(value);
                }
                Entry::Occupied(mut slot) => match (slot.get().value(), value.value()) {
                    (Value::Expression(current), Value::Expression(incoming)) => {
                        let combined = current.clone().merge_constraints(incoming.clone());
                        slot.insert(value.clone_with_value(value!(combined)));
                    }
                    (existing, new) => panic!(
                        "Illegal state reached while reducing constraints for {}: {} → {}",
                        var,
                        existing.to_polar(),
                        new.to_polar()
                    ),
                },
            }
        }
    }

    reduced
}
/// Removes constraints on variables that were unbound or bound to a value at
/// `bsp`, according to `vm.variable_state_at_point`.
///
/// Constraints on variables in any other state are returned unchanged.
fn filter_inverted_constraints(
    constraints: Bindings,
    vm: &PolarVirtualMachine,
    bsp: Bsp,
) -> Bindings {
    let mut kept = constraints;
    kept.retain(|var, _| match vm.variable_state_at_point(var, &bsp) {
        VariableState::Unbound | VariableState::Bound(_) => false,
        _ => true,
    });
    kept
}
impl Runnable for Inverter {
    /// Drives the inner VM until it finishes or emits an event that must be
    /// handled by the caller.
    ///
    /// * `QueryEvent::Result` - the current binding follower is removed from
    ///   the VM, stored in `results`, and a fresh follower is registered; the
    ///   loop then continues.
    /// * `QueryEvent::Done` - produces the inverted outcome:
    ///   - no results were recorded: `Done { result: true }`;
    ///   - results were recorded and they yield constraints after filtering:
    ///     the constraints are added to `add_constraints` and
    ///     `Done { result: true }` is returned;
    ///   - results were recorded but no constraints survive:
    ///     `Done { result: false }`.
    /// * any other event is returned to the caller unchanged.
    ///
    /// The `Counter` argument is ignored; the inner VM is always run with
    /// `None`.
    ///
    /// # Panics
    ///
    /// Panics if the VM no longer knows the follower id stored in `follower`.
    fn run(&mut self, _: Option<&mut Counter>) -> PolarResult<QueryEvent> {
        // Lazily register the follower that records bindings for the first result.
        if self.follower.is_none() {
            self.follower = Some(self.vm.add_binding_follower());
        }

        loop {
            let event = self.vm.run(None)?;
            match event {
                QueryEvent::Result { .. } => {
                    let follower_id = self.follower.unwrap();
                    let recorded = self.vm.remove_binding_follower(&follower_id).unwrap();
                    self.results.push(recorded);
                    // Start recording afresh for the next result.
                    self.follower = Some(self.vm.add_binding_follower());
                }
                QueryEvent::Done { .. } => {
                    // Nothing matched, so the inverted query succeeds.
                    if self.results.is_empty() {
                        return Ok(QueryEvent::Done { result: true });
                    }

                    let recorded = std::mem::take(&mut self.results);
                    let constraints = results_to_constraints(recorded);
                    // Move the bsp out instead of cloning it.
                    let bsp = std::mem::take(&mut self.bsp);
                    let constraints = filter_inverted_constraints(constraints, &self.vm, bsp);

                    // Results without any remaining constraints: the inverted
                    // query fails.
                    if constraints.is_empty() {
                        return Ok(QueryEvent::Done { result: false });
                    }

                    // Hand the inverted constraints over through the shared map.
                    self.add_constraints.borrow_mut().extend(constraints);
                    return Ok(QueryEvent::Done { result: true });
                }
                other => return Ok(other),
            }
        }
    }

    /// Forwards the answer to an external question to the inner VM.
    fn external_question_result(&mut self, call_id: u64, answer: bool) -> PolarResult<()> {
        self.vm.external_question_result(call_id, answer)
    }

    /// Forwards the result of an external call to the inner VM.
    fn external_call_result(&mut self, call_id: u64, term: Option<Term>) -> PolarResult<()> {
        self.vm.external_call_result(call_id, term)
    }

    /// Forwards a debugger command to the inner VM.
    fn debug_command(&mut self, command: &str) -> PolarResult<()> {
        self.vm.debug_command(command)
    }

    /// Returns a boxed clone of this inverter.
    fn clone_runnable(&self) -> Box<dyn Runnable> {
        Box::new(self.clone())
    }

    /// Delegates error handling to the inner VM.
    fn handle_error(&mut self, error: PolarError) -> PolarResult<QueryEvent> {
        self.vm.handle_error(error)
    }
}