use std::collections::{HashMap, HashSet};
use crate::ast::{ArithExpr, Clause, Goal, Term};
/// How a variable occurrence is categorised.
///
/// NOTE(review): this enum is not referenced anywhere in this module; the
/// meanings below are inferred from the variant names and should be
/// confirmed against the code generator that consumes it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarOccurrence {
    /// Presumably the first time the variable is met in the clause.
    First,
    /// Presumably any later occurrence of an already-seen variable.
    Subsequent,
    /// Presumably an anonymous variable (see `is_void_var`).
    Void,
}
/// Storage class assigned to a clause variable by `analyze_clause`.
///
/// Each class is numbered independently, starting at 0; the payload is the
/// slot index within that class.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarClass {
    /// The variable occurs in more than one chunk of the clause (the head or
    /// a body goal).
    Permanent(u8),
    /// The variable is confined to a single chunk.
    Temporary(u8),
}
/// Result of `analyze_clause`: how each variable of a single clause is stored.
#[derive(Debug)]
pub struct ClauseAnalysis {
    /// Class (and slot index within that class) of every variable occurring in
    /// the clause head or body.
    pub var_classes: HashMap<String, VarClass>,
    /// Number of `VarClass::Permanent` slots handed out.
    pub permanent_count: u8,
    /// Number of `VarClass::Temporary` slots handed out.
    pub temporary_count: u8,
    /// Distinct variables of the head, in the order they were collected.
    pub head_vars: Vec<String>,
    /// One entry per body goal, in body order: the variables that goal mentions.
    pub goal_vars: Vec<HashSet<String>>,
}
impl ClauseAnalysis {
    /// Looks up the class assigned to variable `name`.
    ///
    /// Returns `None` when `name` does not occur in the analysed clause.
    pub fn get_class(&self, name: &str) -> Option<VarClass> {
        let class = self.var_classes.get(name)?;
        Some(*class)
    }
    /// Reports whether `name` is absent from `seen`, i.e. this is the first
    /// time the caller encounters it.
    ///
    /// Only `seen` is consulted; the analysis itself is not read.
    pub fn is_first_occurrence(&self, name: &str, seen: &HashSet<String>) -> bool {
        seen.get(name).is_none()
    }
}
pub fn analyze_clause(clause: &Clause) -> ClauseAnalysis {
let mut all_vars: HashMap<String, Vec<usize>> = HashMap::new();
let mut head_vars = Vec::new();
let mut goal_vars = Vec::new();
collect_term_vars(&clause.head, &mut head_vars);
for var in &head_vars {
all_vars.entry(var.clone()).or_default().push(0);
}
for (i, goal) in clause.body.iter().enumerate() {
let mut vars = HashSet::new();
collect_goal_vars(goal, &mut vars);
for var in &vars {
all_vars.entry(var.clone()).or_default().push(i + 1);
}
goal_vars.push(vars);
}
let mut var_classes = HashMap::new();
let mut permanent_count = 0u8;
let mut temporary_count = 0u8;
for (name, occurrences) in &all_vars {
let unique_goals: HashSet<_> = occurrences.iter().collect();
if unique_goals.len() > 1 {
var_classes.insert(name.clone(), VarClass::Permanent(permanent_count));
permanent_count += 1;
} else {
var_classes.insert(name.clone(), VarClass::Temporary(temporary_count));
temporary_count += 1;
}
}
ClauseAnalysis {
var_classes,
permanent_count,
temporary_count,
head_vars,
goal_vars,
}
}
/// Appends each variable of `term` to `vars` the first time it is met.
///
/// Nodes are visited depth-first, left to right (compound arguments in order,
/// list head before tail), so `vars` ends up in first-occurrence order. Names
/// already present in `vars` are not added again. Variants other than
/// variables, compounds and cons cells contribute nothing.
fn collect_term_vars(term: &Term, vars: &mut Vec<String>) {
    // Explicit work stack; children are pushed in reverse so they pop in order.
    let mut pending = vec![term];
    while let Some(current) = pending.pop() {
        match current {
            Term::Variable(name) => {
                let already_listed = vars.iter().any(|listed| listed == name);
                if !already_listed {
                    vars.push(name.clone());
                }
            }
            Term::Compound { args, .. } => {
                for arg in args.iter().rev() {
                    pending.push(arg);
                }
            }
            Term::Cons(head, tail) => {
                pending.push(tail);
                pending.push(head);
            }
            _ => {}
        }
    }
}
/// Inserts into `vars` every variable mentioned by `goal`.
///
/// Term arguments of `Call`, `Write`, `Unify` and `NotUnify` are scanned with
/// `collect_term_vars`. For `Is`, the target name and the arithmetic
/// expression both count. For `Compare`, both arithmetic sides count.
/// `Cut`, `Nl`, `Fail` and `True` contribute nothing.
fn collect_goal_vars(goal: &Goal, vars: &mut HashSet<String>) {
    // Term-shaped arguments go through the order-preserving Vec collector; the
    // order is irrelevant once they are merged into the set below.
    let mut term_vars: Vec<String> = Vec::new();
    match goal {
        Goal::Call(term) => collect_term_vars(term, &mut term_vars),
        Goal::Write(term) => collect_term_vars(term, &mut term_vars),
        Goal::Unify(lhs, rhs) | Goal::NotUnify(lhs, rhs) => {
            collect_term_vars(lhs, &mut term_vars);
            collect_term_vars(rhs, &mut term_vars);
        }
        Goal::Is(target, expr) => {
            vars.insert(target.clone());
            collect_arith_vars(expr, vars);
        }
        Goal::Compare(_, lhs, rhs) => {
            collect_arith_vars(lhs, vars);
            collect_arith_vars(rhs, vars);
        }
        Goal::Cut | Goal::Nl | Goal::Fail | Goal::True => {}
    }
    vars.extend(term_vars);
}
/// Inserts into `vars` every variable referenced by the arithmetic expression
/// `expr`, descending through binary operators and negation. Integer literals
/// contribute nothing.
fn collect_arith_vars(expr: &ArithExpr, vars: &mut HashSet<String>) {
    // Visiting order does not matter because the result is a set.
    let mut work = vec![expr];
    while let Some(node) = work.pop() {
        match node {
            ArithExpr::Variable(name) => {
                vars.insert(name.clone());
            }
            ArithExpr::BinOp(_, lhs, rhs) => {
                work.push(lhs);
                work.push(rhs);
            }
            ArithExpr::Neg(operand) => work.push(operand),
            ArithExpr::Integer(_) => {}
        }
    }
}
/// Reports whether `name` is treated as a void (anonymous) variable.
///
/// Any name beginning with an underscore qualifies: the bare `_` as well as
/// names such as `_X` or `_Unused`.
///
/// NOTE(review): in standard Prolog `_X` is a named variable whose occurrences
/// are linked. Confirm that treating it as void is intended by callers.
pub fn is_void_var(name: &str) -> bool {
    // The bare "_" is covered by the prefix check as well.
    name.starts_with('_')
}
pub fn vars_to_save(
analysis: &ClauseAnalysis,
goal_index: usize,
seen: &HashSet<String>,
) -> HashSet<String> {
let mut to_save = HashSet::new();
for (name, class) in &analysis.var_classes {
if let VarClass::Permanent(_) = class {
if seen.contains(name) {
for (i, goal_vars) in analysis.goal_vars.iter().enumerate() {
if i > goal_index && goal_vars.contains(name) {
to_save.insert(name.clone());
break;
}
}
}
}
}
to_save
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse;
    /// Parses `source` and analyses its first clause.
    fn analyze_first(source: &str) -> ClauseAnalysis {
        let program = parse(source).unwrap();
        analyze_clause(&program.clauses[0])
    }
    /// True when `name` was classified as a permanent variable.
    fn is_permanent(analysis: &ClauseAnalysis, name: &str) -> bool {
        matches!(analysis.get_class(name), Some(VarClass::Permanent(_)))
    }
    #[test]
    fn test_simple_fact_analysis() {
        let analysis = analyze_first("parent(tom, bob).");
        assert_eq!(analysis.permanent_count, 0);
        assert_eq!(analysis.temporary_count, 0);
    }
    #[test]
    fn test_rule_with_shared_vars() {
        let analysis = analyze_first("grandparent(X, Z) :- parent(X, Y), parent(Y, Z).");
        assert_eq!(analysis.permanent_count, 3);
        for name in ["X", "Y", "Z"].iter() {
            assert!(is_permanent(&analysis, name), "{} should be permanent", name);
        }
    }
    #[test]
    fn test_temporary_var() {
        // Despite the test name, both variables end up permanent. The head
        // counts as its own chunk, so X spans head + bar/2 and Y spans
        // bar/2 + baz/1.
        let analysis = analyze_first("foo(X) :- bar(X, Y), baz(Y).");
        assert!(is_permanent(&analysis, "X"));
        assert!(is_permanent(&analysis, "Y"));
    }
    #[test]
    fn test_arithmetic_vars() {
        // X and Y appear in the head and in the `is` goal: two chunks each.
        let analysis = analyze_first("double(X, Y) :- Y is X * 2.");
        assert!(is_permanent(&analysis, "X"));
        assert!(is_permanent(&analysis, "Y"));
    }
    #[test]
    fn test_void_detection() {
        for void in ["_", "_X", "_Unused"].iter() {
            assert!(is_void_var(void), "{} should be void", void);
        }
        for named in ["X", "Var"].iter() {
            assert!(!is_void_var(named), "{} should not be void", named);
        }
    }
}