#[cfg(test)]
mod test;
use crate::{
ast::{self, Query, Var},
term_arena::{self, TermArena, TermId},
};
pub fn query_dfs<R: Resolver>(resolver: R, query: &Query) -> SolutionIter<R> {
let max_var = query.count_var_slots();
let mut solution = SolutionState::new(max_var);
let mut scratch = Vec::new();
SolutionIter {
resolver,
unresolved_goals: query
.goals
.iter()
.rev() .map(|app| solution.terms.insert_ast_term(&mut scratch, app))
.map(|goal| GoalFrame { goal, cut_level: 0 })
.collect(),
checkpoints: vec![],
solution,
}
}
/// Strategy used by [`SolutionIter`] to resolve application goals.
///
/// The search calls [`Resolver::resolve`] when an application goal is popped from the goal
/// stack. If that produces a choice (`Resolved::SuccessRetry`), the search later calls
/// [`Resolver::resume`] with it while backtracking, to try the next alternative.
pub trait Resolver {
    /// Opaque per-goal state the resolver needs in order to continue with further
    /// alternatives in [`Resolver::resume`].
    type Choice: std::fmt::Debug;

    /// Resolve the goal `goal_id`, whose term is `goal_term`.
    ///
    /// Returning `None` means the goal failed and the search backtracks.
    /// `Some(Resolved::Success)` means the goal was resolved with no alternatives left.
    /// `Some(Resolved::SuccessRetry(choice))` means `choice` is handed to `resume` on
    /// backtracking. New subgoals and bindings are recorded through `context`.
    fn resolve(
        &mut self,
        goal_id: term_arena::TermId,
        goal_term: term_arena::AppTerm,
        context: &mut ResolveContext,
    ) -> Option<Resolved<Self::Choice>>;

    /// Try the next alternative for a goal that was resolved with `SuccessRetry`.
    ///
    /// Before this is called, the solution state and goal stack have been rolled back to how
    /// they were when the goal was first resolved. Return `true` if another alternative
    /// succeeded. Return `false` if the choice is exhausted, which discards the checkpoint.
    fn resume(
        &mut self,
        choice: &mut Self::Choice,
        goal_id: term_arena::TermId,
        context: &mut ResolveContext,
    ) -> bool;
}
/// Mutable view of the search state handed to a [`Resolver`] while it resolves one goal.
///
/// Gives access to the [`SolutionState`] and lets the resolver push subgoals. Every subgoal
/// pushed through the context gets the context's cut level.
pub struct ResolveContext<'c> {
    /// Variable bindings and term arena of the search.
    solution: &'c mut SolutionState,
    /// Goals still to be resolved; the last element is resolved next.
    goal_stack: &'c mut Vec<GoalFrame>,
    /// Solution state from before the current goal was resolved; used by `reset`.
    checkpoint: &'c SolutionCheckpoint,
    /// Goal stack length from before the current goal was resolved; used by `reset`.
    goal_len: usize,
    /// Cut level assigned to every goal pushed through this context.
    cut_level: usize,
}
impl<'c> ResolveContext<'c> {
    /// Mutable access to the solution state (bindings and term arena).
    #[inline(always)]
    pub fn solution_mut(&mut self) -> &mut SolutionState {
        self.solution
    }

    /// Shared access to the solution state (bindings and term arena).
    #[inline(always)]
    pub fn solution(&self) -> &SolutionState {
        self.solution
    }

    /// Push one subgoal. It is popped, and therefore resolved, before any goal already on
    /// the stack.
    #[inline(always)]
    pub fn push_goal(&mut self, goal: term_arena::TermId) {
        let cut_level = self.cut_level;
        let frame = GoalFrame { goal, cut_level };
        self.goal_stack.push(frame);
    }

    /// Push several subgoals in iteration order. The goal stack is popped from the end, so
    /// the last goal yielded by `new_goals` is resolved first.
    #[inline(always)]
    pub fn extend_goals(&mut self, new_goals: impl Iterator<Item = term_arena::TermId>) {
        let cut_level = self.cut_level;
        let frames = new_goals.map(move |goal| GoalFrame { goal, cut_level });
        self.goal_stack.extend(frames);
    }

    /// Undo everything done through this context so far.
    ///
    /// Truncates the goal stack to its length before the current goal was resolved and
    /// restores the solution state from the matching checkpoint.
    #[inline(always)]
    pub fn reset(&mut self) {
        let goal_len = self.goal_len;
        self.goal_stack.truncate(goal_len);
        self.solution.restore(self.checkpoint);
    }
}
/// Outcome of successfully resolving a goal.
#[derive(Debug)]
pub enum Resolved<C> {
    /// The goal was resolved and no alternatives remain.
    Success,
    /// The goal was resolved. The contained choice can be used to try further alternatives
    /// when backtracking.
    SuccessRetry(C),
}

impl<C> Resolved<C> {
    /// Transform the choice carried by [`Resolved::SuccessRetry`] with `f`.
    ///
    /// [`Resolved::Success`] is returned unchanged, and `f` is not called in that case.
    pub fn map_choice<C2>(self, f: impl FnOnce(C) -> C2) -> Resolved<C2> {
        if let Resolved::SuccessRetry(choice) = self {
            Resolved::SuccessRetry(f(choice))
        } else {
            Resolved::Success
        }
    }
}
/// Forward every call to the referenced resolver.
///
/// This lets a `&mut R` be passed wherever a resolver is taken by value (such as
/// [`query_dfs`]) without giving up ownership of `R`.
impl<R: Resolver> Resolver for &mut R {
    type Choice = R::Choice;

    #[inline(always)]
    fn resolve(
        &mut self,
        goal_id: term_arena::TermId,
        goal_term: term_arena::AppTerm,
        context: &mut ResolveContext,
    ) -> Option<Resolved<Self::Choice>> {
        R::resolve(&mut **self, goal_id, goal_term, context)
    }

    #[inline(always)]
    fn resume(
        &mut self,
        choice: &mut Self::Choice,
        goal_id: term_arena::TermId,
        context: &mut ResolveContext,
    ) -> bool {
        R::resume(&mut **self, choice, goal_id, context)
    }
}
/// Depth-first search over the solutions of a query, created by [`query_dfs`].
///
/// Drive it step by step with [`SolutionIter::step`], or use it as an [`Iterator`] over
/// [`Solution`]s.
#[derive(Debug)]
pub struct SolutionIter<R: Resolver> {
    /// Resolves individual application goals.
    resolver: R,
    /// Goals still to be proven; the last element is resolved next.
    unresolved_goals: Vec<GoalFrame>,
    /// One entry per goal resolved on the current derivation path, oldest first.
    checkpoints: Vec<Checkpoint<R>>,
    /// Variable bindings and term storage.
    solution: SolutionState,
}
/// Record of one resolved goal, used to backtrack to it and possibly retry it.
#[derive(Debug)]
struct Checkpoint<R: Resolver> {
    /// The goal that was resolved.
    goal: term_arena::TermId,
    /// Resolver state for trying further alternatives. `None` when there are none, either
    /// because resolution had none or because a cut removed them.
    choice: Option<R::Choice>,
    /// Length of the goal stack right after the goal was popped.
    goals_checkpoint: usize,
    /// Solution state from before the goal was resolved.
    solution_checkpoint: SolutionCheckpoint,
    /// Cut level of the goal frame that was resolved.
    cut_level: usize,
}

impl<R: Resolver> Checkpoint<R> {
    /// Consume the checkpoint and rebuild the goal frame it was created from.
    fn restore_goal_frame(self) -> GoalFrame {
        let Checkpoint {
            goal, cut_level, ..
        } = self;
        GoalFrame { goal, cut_level }
    }
}
/// A goal on the goal stack, together with the cut level it was pushed at.
#[derive(Debug)]
struct GoalFrame {
    /// The goal term to resolve.
    goal: term_arena::TermId,
    /// Number of checkpoints that existed when the parent goal was resolved. If this goal is
    /// a cut, it clears the choices of every checkpoint from this index onward.
    cut_level: usize,
}
/// Variable values of one solution, indexed by variable ordinal.
///
/// An entry is `None` when the corresponding variable has no binding in that solution.
#[derive(PartialEq, Debug, Clone)]
pub struct Solution(pub Vec<Option<ast::Term>>);

impl Solution {
    /// All variable slots of the solution, indexed by variable ordinal.
    pub fn vars(&self) -> &Vec<Option<ast::Term>> {
        &self.0
    }

    /// Iterate over every variable of the solution and its value, in ordinal order.
    pub fn iter_vars(&self) -> impl Iterator<Item = (Var, Option<&ast::Term>)> {
        self.0
            .iter()
            .enumerate()
            .map(|(i, term)| (Var::from_ord(i), term.as_ref()))
    }

    /// Value of `var` in this solution.
    ///
    /// Returns `None` when `var` is unbound. It also returns `None`, rather than panicking,
    /// when `var`'s ordinal lies outside this solution.
    pub fn get(&self, var: Var) -> Option<&ast::Term> {
        self.0.get(var.ord()).and_then(Option::as_ref)
    }
}
/// Result of a single call to [`SolutionIter::step`].
///
/// Derives the common traits so callers can print, compare and copy step results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Step {
    /// All goals are resolved; the current state is a solution.
    Yield,
    /// Progress was made, but goals remain to be resolved.
    Continue,
    /// Every alternative has been exhausted; there are no further solutions.
    Done,
}
impl<R: Resolver> SolutionIter<R> {
    /// Perform a single step of the depth-first search.
    ///
    /// Pops the next goal and resolves it, recording a checkpoint on success. If the goal
    /// fails, or if no goals are left after a solution, the search backtracks to the most
    /// recent checkpoint that still has an alternative.
    ///
    /// Returns:
    /// - [`Step::Yield`] when all goals are resolved; read the solution with
    ///   [`SolutionIter::get_solution`].
    /// - [`Step::Continue`] while goals remain.
    /// - [`Step::Done`] once every alternative is exhausted.
    ///
    /// After `Done`, the goal stack and bindings are rolled back to their state before the
    /// first goal was resolved.
    pub fn step(&mut self) -> Step {
        if let Some(goal_frame) = self.unresolved_goals.pop() {
            let goal_term = self.solution.terms.get_term(goal_frame.goal);
            let solution_checkpoint = self.solution.checkpoint();
            let goals_checkpoint = self.unresolved_goals.len();
            let mut context = ResolveContext {
                solution: &mut self.solution,
                goal_stack: &mut self.unresolved_goals,
                checkpoint: &solution_checkpoint,
                goal_len: goals_checkpoint,
                // A cut among the subgoals pushed here clears the choice of this goal's own
                // checkpoint (pushed below) and of every later one.
                cut_level: self.checkpoints.len(),
            };
            let resolved = match goal_term {
                // A variable used as a goal: resolve the non-variable term it is bound to.
                // An unbound variable succeeds without adding goals.
                term_arena::Term::Var(v) => {
                    if let Some(new_goal) = context.solution.get_var(v) {
                        context.push_goal(new_goal);
                    }
                    Some(Resolved::Success)
                }
                term_arena::Term::App(app) => {
                    self.resolver.resolve(goal_frame.goal, app, &mut context)
                }
                term_arena::Term::Cut => {
                    // Remove the alternatives of every checkpoint created since the parent
                    // goal was resolved.
                    for checkpoint in self.checkpoints[goal_frame.cut_level..].iter_mut() {
                        checkpoint.choice = None;
                    }
                    Some(Resolved::Success)
                }
                // Integers cannot be resolved as goals.
                term_arena::Term::Int(_) => None,
            };
            let choice = match resolved {
                None => {
                    // The goal failed. Discard whatever the resolver did before giving up
                    // (bindings, allocations, pushed subgoals), so no partial work survives
                    // even when there is no earlier checkpoint to backtrack to. Then put the
                    // goal back where it was.
                    self.solution.restore(&solution_checkpoint);
                    self.unresolved_goals.truncate(goals_checkpoint);
                    self.unresolved_goals.push(goal_frame);
                    return self.resume_or_backtrack();
                }
                Some(Resolved::Success) => None,
                Some(Resolved::SuccessRetry(choice)) => Some(choice),
            };
            self.checkpoints.push(Checkpoint {
                goal: goal_frame.goal,
                choice,
                solution_checkpoint,
                goals_checkpoint,
                cut_level: goal_frame.cut_level,
            });
            self.yield_or_continue()
        } else {
            // No goals left: a solution was just yielded (or the query was empty), so look
            // for the next alternative.
            self.resume_or_backtrack()
        }
    }

    /// Extract the current values of the query variables.
    ///
    /// Intended to be called after `step` returned [`Step::Yield`].
    pub fn get_solution(&self) -> Solution {
        Solution(self.solution.get_solution())
    }

    /// Try the next alternative of the most recent checkpoint.
    ///
    /// The caller must already have rolled the solution state and goal stack back to that
    /// checkpoint. On failure, the checkpoint is discarded and its goal is pushed back onto
    /// the goal stack.
    ///
    /// # Panics
    /// Panics if there is no checkpoint.
    fn resume_checkpoint(&mut self) -> bool {
        let checkpoint = self
            .checkpoints
            .last_mut()
            .expect("invariant: there is always a checkpoint when this is called");
        let success = match &mut checkpoint.choice {
            // No alternatives (none were produced, or a cut removed them).
            None => false,
            Some(choice) => {
                let mut context = ResolveContext {
                    solution: &mut self.solution,
                    goal_stack: &mut self.unresolved_goals,
                    checkpoint: &checkpoint.solution_checkpoint,
                    goal_len: checkpoint.goals_checkpoint,
                    cut_level: checkpoint.cut_level,
                };
                self.resolver.resume(choice, checkpoint.goal, &mut context)
            }
        };
        if success {
            true
        } else {
            let discarded = self.checkpoints.pop().expect("we know there is one here");
            self.unresolved_goals.push(discarded.restore_goal_frame());
            false
        }
    }

    /// Backtrack through the checkpoints, newest first, until one produces another
    /// alternative.
    ///
    /// Returns [`Step::Done`] when every checkpoint is exhausted.
    fn resume_or_backtrack(&mut self) -> Step {
        while let Some(checkpoint) = self.checkpoints.last() {
            self.solution.restore(&checkpoint.solution_checkpoint);
            self.unresolved_goals.truncate(checkpoint.goals_checkpoint);
            if self.resume_checkpoint() {
                return self.yield_or_continue();
            }
        }
        Step::Done
    }

    /// `Yield` if no goals remain, otherwise `Continue`.
    fn yield_or_continue(&self) -> Step {
        if self.unresolved_goals.is_empty() {
            Step::Yield
        } else {
            Step::Continue
        }
    }
}
/// Iterating a [`SolutionIter`] yields the query's solutions in depth-first order.
impl<R: Resolver> Iterator for SolutionIter<R> {
    type Item = Solution;

    /// Run search steps until the next solution is found (`Some`) or the search reports
    /// [`Step::Done`] (`None`).
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.step() {
                Step::Continue => {}
                Step::Yield => return Some(self.get_solution()),
                Step::Done => return None,
            }
        }
    }
}
/// Variable bindings and term storage of a search, with checkpoint/restore support.
#[derive(Debug)]
pub struct SolutionState {
    /// Binding of each variable slot, indexed by variable ordinal; `None` if unbound.
    variables: Vec<Option<term_arena::TermId>>,
    /// Variables in the order they were bound, so `restore` can unbind them.
    assignments: Vec<Var>,
    /// Number of leading variable slots that belong to the query and appear in solutions.
    goal_vars: usize,
    /// Arena holding the terms of the search.
    terms: TermArena,
    /// Work list for the occurs check. Kept as a field, presumably so its allocation is
    /// reused across calls.
    occurs_stack: Vec<term_arena::TermId>,
}
/// Snapshot of a [`SolutionState`] that [`SolutionState::restore`] can roll back to.
#[derive(Debug)]
pub struct SolutionCheckpoint {
    /// Number of recorded assignments when the snapshot was taken.
    operations_checkpoint: usize,
    /// Number of variable slots when the snapshot was taken.
    variables_checkpoint: usize,
    /// Snapshot of the term arena.
    terms_checkpoint: term_arena::Checkpoint,
}
impl SolutionState {
    /// Create a state with `goal_vars` unbound variable slots.
    ///
    /// These slots belong to the query and are the ones reported by `get_solution`.
    fn new(goal_vars: usize) -> Self {
        Self {
            assignments: vec![],
            variables: vec![None; goal_vars],
            goal_vars,
            terms: TermArena::new(),
            occurs_stack: Vec::new(),
        }
    }

    /// Allocate `num_vars` fresh unbound variables with consecutive ordinals and return
    /// the first one.
    #[inline(always)]
    pub fn allocate_vars(&mut self, num_vars: usize) -> Var {
        let start = self.variables.len();
        self.variables.resize(start + num_vars, None);
        Var::from_ord(start)
    }

    /// Allocate a single fresh unbound variable.
    #[inline(always)]
    pub fn allocate_var(&mut self) -> Var {
        self.allocate_vars(1)
    }

    /// Bind the currently unbound variable `var` to `value`.
    ///
    /// Returns `false` without binding if `var` occurs in `value` (occurs check).
    /// Otherwise records the binding, so [`SolutionState::restore`] can undo it, and returns
    /// `true`.
    pub fn set_var(&mut self, var: Var, value: term_arena::TermId) -> bool {
        debug_assert!(self.variables[var.ord()].is_none());
        if self.occurs(var, value) {
            return false;
        }
        self.variables[var.ord()] = Some(value);
        self.assignments.push(var);
        true
    }

    /// Follow the chain of variable bindings that starts at `var`.
    ///
    /// Returns the first bound term that is not itself a variable, or `None` if the chain
    /// ends in an unbound variable.
    pub fn get_var(&self, mut var: Var) -> Option<TermId> {
        while let Some(term) = self.variables[var.ord()] {
            match self.terms.get_term(term) {
                term_arena::Term::Var(next) => var = next,
                _ => return Some(term),
            }
        }
        None
    }

    /// Occurs check: whether `var` appears anywhere in `term`, looking through bound
    /// variables.
    ///
    /// Uses `occurs_stack` as an explicit work list instead of recursion. The work list is
    /// empty again when this returns.
    fn occurs(&mut self, var: Var, mut term: term_arena::TermId) -> bool {
        loop {
            match self.terms.get_term(term) {
                term_arena::Term::Var(v) => {
                    if v == var {
                        // Leave the work list empty for the next call.
                        self.occurs_stack.clear();
                        return true;
                    } else if let Some(value) = self.variables[v.ord()] {
                        // Bound variable: inspect its value instead.
                        term = value;
                        continue;
                    }
                }
                term_arena::Term::App(term_arena::AppTerm(_, args)) => {
                    let terms = &self.terms;
                    self.occurs_stack.extend(terms.get_args(args))
                }
                term_arena::Term::Int(_) => {}
                term_arena::Term::Cut => {}
            }
            match self.occurs_stack.pop() {
                Some(next) => term = next,
                None => return false,
            }
        }
    }

    /// Take a snapshot that [`SolutionState::restore`] can roll back to.
    pub fn checkpoint(&self) -> SolutionCheckpoint {
        SolutionCheckpoint {
            operations_checkpoint: self.assignments.len(),
            variables_checkpoint: self.variables.len(),
            terms_checkpoint: self.terms.checkpoint(),
        }
    }

    /// Roll back to `checkpoint`.
    ///
    /// Unbinds variables bound since the snapshot, drops variable slots allocated since,
    /// and releases the term arena back to its snapshot.
    pub fn restore(&mut self, checkpoint: &SolutionCheckpoint) {
        for var in self.assignments.drain(checkpoint.operations_checkpoint..) {
            self.variables[var.ord()] = None;
        }
        self.variables.truncate(checkpoint.variables_checkpoint);
        self.terms.release(&checkpoint.terms_checkpoint);
    }

    /// Convert `term` into an [`ast::Term`], replacing bound variables with their values.
    ///
    /// Unbound variables become `ast::Term::Var`.
    pub fn extract_term(&self, term: term_arena::TermId) -> ast::Term {
        match self.terms.get_term(term) {
            term_arena::Term::Var(v) => {
                if let Some(value) = &self.variables[v.ord()] {
                    self.extract_term(*value)
                } else {
                    ast::Term::Var(v)
                }
            }
            term_arena::Term::App(app) => ast::Term::App(self.extract_app_term(app)),
            term_arena::Term::Int(i) => ast::Term::Int(i),
            term_arena::Term::Cut => ast::Term::Cut,
        }
    }

    /// Convert an application term and its arguments into an [`ast::AppTerm`], replacing
    /// bound variables with their values.
    pub fn extract_app_term(&self, term: term_arena::AppTerm) -> ast::AppTerm {
        ast::AppTerm {
            functor: term.0,
            args: self
                .terms
                .get_args(term.1)
                .map(|arg| self.extract_term(arg))
                .collect(),
        }
    }

    /// Values of the query variables (the first `goal_vars` slots).
    ///
    /// Slots that are not bound are `None`.
    fn get_solution(&self) -> Vec<Option<ast::Term>> {
        self.variables
            .iter()
            .take(self.goal_vars)
            .map(|val| val.as_ref().map(|t| self.extract_term(*t)))
            .collect()
    }

    /// Dereference `term` through bound variables.
    ///
    /// Returns the id and term of the first non-variable term, or the unbound variable,
    /// that is reached.
    pub fn follow_vars(
        &self,
        mut term: term_arena::TermId,
    ) -> (term_arena::TermId, term_arena::Term) {
        loop {
            match self.terms.get_term(term) {
                term_arena::Term::Var(var) => {
                    if let Some(value) = self.variables[var.ord()] {
                        term = value;
                    } else {
                        return (term, term_arena::Term::Var(var));
                    }
                }
                other => return (term, other),
            }
        }
    }

    /// Unify `goal_term` with `rule_term`, binding variables as needed.
    ///
    /// When both sides are distinct unbound variables, the rule-side variable is bound to
    /// the goal-side one. Bindings made before a failure are not undone here; roll them
    /// back with [`SolutionState::checkpoint`] and [`SolutionState::restore`].
    pub fn unify(&mut self, goal_term: term_arena::TermId, rule_term: term_arena::TermId) -> bool {
        let (goal_term_id, goal_term) = self.follow_vars(goal_term);
        let (rule_term_id, rule_term) = self.follow_vars(rule_term);
        match (goal_term, rule_term) {
            (term_arena::Term::Var(goal_var), term_arena::Term::Var(rule_var)) => {
                if goal_var != rule_var {
                    self.set_var(rule_var, goal_term_id)
                } else {
                    true
                }
            }
            (term_arena::Term::Var(goal_var), _) => self.set_var(goal_var, rule_term_id),
            (_, term_arena::Term::Var(rule_var)) => self.set_var(rule_var, goal_term_id),
            (term_arena::Term::App(goal_app), term_arena::Term::App(rule_app)) => {
                self.unify_app(goal_app, rule_app)
            }
            (term_arena::Term::Int(goal_int), term_arena::Term::Int(rule_int)) => {
                goal_int == rule_int
            }
            // Cut is a constant with no arguments, so it unifies with itself.
            (term_arena::Term::Cut, term_arena::Term::Cut) => true,
            (_, _) => false,
        }
    }

    /// Unify two application terms.
    ///
    /// They must have the same functor; their arguments are then unified pairwise.
    #[inline(always)]
    pub fn unify_app(
        &mut self,
        goal_term: term_arena::AppTerm,
        rule_term: term_arena::AppTerm,
    ) -> bool {
        if goal_term.0 == rule_term.0 {
            self.unify_args(goal_term.1, rule_term.1)
        } else {
            false
        }
    }

    /// Unify two argument lists pairwise.
    ///
    /// Fails if the lengths differ, and stops at the first pair that does not unify.
    #[inline(always)]
    pub fn unify_args(
        &mut self,
        goal_args: term_arena::ArgRange,
        rule_args: term_arena::ArgRange,
    ) -> bool {
        if goal_args.len() != rule_args.len() {
            return false;
        }
        goal_args.zip(rule_args).all(|(goal_arg, rule_arg)| {
            self.unify(self.terms.get_arg(goal_arg), self.terms.get_arg(rule_arg))
        })
    }

    /// Shared access to the term arena.
    #[inline(always)]
    pub fn terms(&self) -> &TermArena {
        &self.terms
    }

    /// Mutable access to the term arena.
    #[inline(always)]
    pub fn terms_mut(&mut self) -> &mut TermArena {
        &mut self.terms
    }
}
#[cfg(test)]
mod tests {
    use super::Step;
    use crate::ast::{Term, Var};
    use crate::search::{Resolver, Solution, SolutionIter};
    use crate::textual::TextualUniverse;

    /// Drive `solver` to completion, panicking with the solver state if it ever yields a
    /// solution.
    #[track_caller]
    fn assert_no_solution<R: Resolver + std::fmt::Debug>(mut solver: SolutionIter<R>) {
        loop {
            match solver.step() {
                Step::Continue => {}
                Step::Done => return,
                Step::Yield => {
                    panic!("occurs check should prevent solution: {:#?}", solver)
                }
            }
        }
    }

    /// Regression test for issue 15: unification must not bind a variable to a term that
    /// contains that same variable.
    #[test]
    fn occurs_check_issue_15() {
        let mut tu = TextualUniverse::new();
        tu.load_str("refl(f(X), g(X)).").unwrap();
        assert_no_solution(tu.query_dfs("refl(A, f(A)).").unwrap());
        assert_no_solution(tu.query_dfs("refl(f(A), A).").unwrap());
    }

    /// `Solution::get` and `Solution::iter_vars` report values by variable ordinal.
    #[test]
    fn solution_get() {
        let zero = Term::Int(0);
        let solution = Solution(vec![None, Some(Term::Int(0))]);
        assert_eq!(solution.get(Var::from_ord(1)), Some(&zero));
        let pairs: Vec<_> = solution.iter_vars().collect();
        assert_eq!(
            pairs,
            vec![(Var::from_ord(0), None), (Var::from_ord(1), Some(&zero))],
        );
    }
}