use crate::expr::TLExpr;
use crate::term::Term;
use crate::unification::Substitution;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Applies `subst` to the terms inside `expr`, producing a new expression.
///
/// Predicate arguments are rewritten term-by-term with `Substitution::apply`;
/// `And`, `Or`, `Not` and `Imply` are traversed structurally. An `Exists` or
/// `ForAll` whose bound variable is in `subst.domain()` is returned unchanged,
/// so bound occurrences are never replaced. Every other expression variant is
/// cloned untouched.
///
/// NOTE(review): when a quantifier's variable is in the domain, the *entire*
/// substitution is skipped for its body (including bindings for other
/// variables), and no renaming is done when a substituted term itself
/// mentions the bound variable, so variable capture is possible.
fn substitute_in_expr(expr: &TLExpr, subst: &Substitution) -> TLExpr {
    // Recurse into a (boxed) sub-expression and re-box the result.
    fn sub(inner: &TLExpr, subst: &Substitution) -> Box<TLExpr> {
        Box::new(substitute_in_expr(inner, subst))
    }

    match expr {
        TLExpr::Pred { name, args } => TLExpr::Pred {
            name: name.clone(),
            args: args.iter().map(|t| subst.apply(t)).collect(),
        },
        TLExpr::Not(inner) => TLExpr::Not(sub(inner, subst)),
        TLExpr::And(lhs, rhs) => TLExpr::And(sub(lhs, subst), sub(rhs, subst)),
        TLExpr::Or(lhs, rhs) => TLExpr::Or(sub(lhs, subst), sub(rhs, subst)),
        TLExpr::Imply(lhs, rhs) => TLExpr::Imply(sub(lhs, subst), sub(rhs, subst)),
        // The quantifier binds a variable being substituted: leave it alone.
        TLExpr::Exists { var, .. } if subst.domain().contains(var) => expr.clone(),
        TLExpr::ForAll { var, .. } if subst.domain().contains(var) => expr.clone(),
        TLExpr::Exists { var, domain, body } => TLExpr::Exists {
            var: var.clone(),
            domain: domain.clone(),
            body: sub(body, subst),
        },
        TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
            var: var.clone(),
            domain: domain.clone(),
            body: sub(body, subst),
        },
        _ => expr.clone(),
    }
}
/// A sequent `antecedents ⊢ consequents`.
///
/// Both sides are ordered `Vec`s rather than sets: duplicate formulas are
/// kept (they are what `contract_left` / `contract_right` remove), and the
/// derived `PartialEq` compares the two sides element-by-element, in order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Sequent {
/// Formulas to the left of the turnstile.
pub antecedents: Vec<TLExpr>,
/// Formulas to the right of the turnstile.
pub consequents: Vec<TLExpr>,
}
impl Sequent {
    /// Builds the sequent `antecedents ⊢ consequents`.
    pub fn new(antecedents: Vec<TLExpr>, consequents: Vec<TLExpr>) -> Self {
        Self {
            antecedents,
            consequents,
        }
    }

    /// Builds the identity sequent `formula ⊢ formula`.
    pub fn identity(formula: TLExpr) -> Self {
        let duplicate = formula.clone();
        Self::new(vec![duplicate], vec![formula])
    }

    /// Returns `true` when some formula occurs (by `==`) on both sides.
    pub fn is_axiom(&self) -> bool {
        self.antecedents
            .iter()
            .any(|a| self.consequents.iter().any(|c| a == c))
    }

    /// Appends `formula` to the antecedents (left weakening).
    pub fn weaken_left(self, formula: TLExpr) -> Self {
        let mut antecedents = self.antecedents;
        antecedents.push(formula);
        Self {
            antecedents,
            ..self
        }
    }

    /// Appends `formula` to the consequents (right weakening).
    pub fn weaken_right(self, formula: TLExpr) -> Self {
        let mut consequents = self.consequents;
        consequents.push(formula);
        Self {
            consequents,
            ..self
        }
    }

    /// Removes the antecedent at `index`, provided an equal antecedent exists
    /// at some other position. Returns `None` if `index` is out of range or
    /// the formula has no duplicate.
    pub fn contract_left(mut self, index: usize) -> Option<Self> {
        if Self::contract_in(&mut self.antecedents, index) {
            Some(self)
        } else {
            None
        }
    }

    /// Removes the consequent at `index`, provided an equal consequent exists
    /// at some other position. Returns `None` if `index` is out of range or
    /// the formula has no duplicate.
    pub fn contract_right(mut self, index: usize) -> Option<Self> {
        if Self::contract_in(&mut self.consequents, index) {
            Some(self)
        } else {
            None
        }
    }

    /// Removes `side[index]` when another entry equals it; reports whether it did.
    fn contract_in(side: &mut Vec<TLExpr>, index: usize) -> bool {
        let duplicated = match side.get(index) {
            Some(target) => side
                .iter()
                .enumerate()
                .any(|(i, f)| i != index && f == target),
            None => false,
        };
        if duplicated {
            side.remove(index);
        }
        duplicated
    }

    /// Union of the free variables of every formula on either side.
    pub fn free_vars(&self) -> HashSet<String> {
        self.antecedents
            .iter()
            .chain(&self.consequents)
            .flat_map(|f| f.free_vars())
            .collect()
    }

    /// Returns a copy of this sequent with `var` replaced by `term` in every
    /// formula (via `substitute_in_expr`); `self` is left unchanged.
    pub fn substitute(&self, var: &str, term: &Term) -> Self {
        let mut subst = Substitution::empty();
        subst.bind(var.to_string(), term.clone());
        let rewrite = |side: &[TLExpr]| -> Vec<TLExpr> {
            side.iter()
                .map(|f| substitute_in_expr(f, &subst))
                .collect()
        };
        Self::new(
            rewrite(self.antecedents.as_slice()),
            rewrite(self.consequents.as_slice()),
        )
    }
}
/// The sequent-calculus rule applied at a [`ProofTree`] node.
///
/// For the logical rules, `index` selects the principal formula in the
/// node's conclusion (antecedents for `*Left`, consequents for `*Right`).
/// NOTE(review): variant names (and, for non-self-describing formats, their
/// order) are part of the derived serde encoding — do not reorder or rename.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum InferenceRule {
/// Axiom: closes a sequent that has a formula on both sides (no premises).
Identity,
/// Structural weakening on the antecedent side (cf. `Sequent::weaken_left`).
WeakeningLeft,
/// Structural weakening on the consequent side (cf. `Sequent::weaken_right`).
WeakeningRight,
/// Structural contraction on the antecedent side (cf. `Sequent::contract_left`).
/// NOTE(review): `index` presumably refers to the duplicated formula — confirm.
ContractionLeft { index: usize },
/// Structural contraction on the consequent side (cf. `Sequent::contract_right`).
/// NOTE(review): `index` presumably refers to the duplicated formula — confirm.
ContractionRight { index: usize },
/// Structural exchange (reordering of formulas).
Exchange,
/// Cut rule (two premises).
/// NOTE(review): what `index` designates is not established here — confirm.
Cut { index: usize },
/// ∧-left: `index` is the position of the principal `And` in the antecedents.
AndLeft { index: usize },
/// ∧-right: `index` is the position of the principal `And` in the consequents.
AndRight { index: usize },
/// ∨-left: `index` is the position of the principal `Or` in the antecedents.
OrLeft { index: usize },
/// ∨-right: `index` is the position of the principal `Or` in the consequents.
OrRight { index: usize },
/// ¬-left: `index` is the position of the principal `Not` in the antecedents.
NotLeft { index: usize },
/// ¬-right: `index` is the position of the principal `Not` in the consequents.
NotRight { index: usize },
/// →-left: `index` is the position of the principal `Imply` in the antecedents.
ImplyLeft { index: usize },
/// →-right: `index` is the position of the principal `Imply` in the consequents.
ImplyRight { index: usize },
/// ∃-left. NOTE(review): `witness` presumably the fresh eigenvariable term — confirm.
ExistsLeft { index: usize, witness: Term },
/// ∃-right. NOTE(review): `witness` presumably the instantiating term — confirm.
ExistsRight { index: usize, witness: Term },
/// ∀-left. NOTE(review): `term` presumably the instantiating term — confirm.
ForAllLeft { index: usize, term: Term },
/// ∀-right. NOTE(review): `witness` presumably the fresh eigenvariable term — confirm.
ForAllRight { index: usize, witness: Term },
}
/// A sequent-calculus derivation: a conclusion, the rule that derives it,
/// and the sub-derivations of that rule's premises.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ProofTree {
/// The sequent proved at this node.
pub conclusion: Sequent,
/// The rule applied to obtain `conclusion`.
pub rule: InferenceRule,
/// One sub-proof per premise of `rule` (empty for `Identity` leaves).
pub premises: Vec<ProofTree>,
}
impl ProofTree {
pub fn identity(formula: TLExpr) -> Self {
ProofTree {
conclusion: Sequent::identity(formula),
rule: InferenceRule::Identity,
premises: vec![],
}
}
pub fn new(conclusion: Sequent, rule: InferenceRule, premises: Vec<ProofTree>) -> Self {
ProofTree {
conclusion,
rule,
premises,
}
}
pub fn is_valid(&self) -> bool {
match &self.rule {
InferenceRule::Identity => {
self.premises.is_empty() && self.conclusion.is_axiom()
}
InferenceRule::WeakeningLeft | InferenceRule::WeakeningRight => {
if self.premises.len() != 1 {
return false;
}
true
}
InferenceRule::AndLeft { index } => {
if self.premises.len() != 1 {
return false;
}
if *index >= self.conclusion.antecedents.len() {
return false;
}
matches!(self.conclusion.antecedents[*index], TLExpr::And(_, _))
}
InferenceRule::AndRight { .. } => {
self.premises.len() == 2
}
InferenceRule::OrLeft { .. } => {
self.premises.len() == 2
}
InferenceRule::OrRight { index } => {
if self.premises.len() != 1 {
return false;
}
if *index >= self.conclusion.consequents.len() {
return false;
}
matches!(self.conclusion.consequents[*index], TLExpr::Or(_, _))
}
InferenceRule::NotLeft { .. } | InferenceRule::NotRight { .. } => {
self.premises.len() == 1
}
InferenceRule::ImplyLeft { .. } => {
self.premises.len() == 2
}
InferenceRule::ImplyRight { .. } => {
self.premises.len() == 1
}
InferenceRule::Cut { .. } => {
self.premises.len() == 2
}
_ => true, }
}
pub fn depth(&self) -> usize {
if self.premises.is_empty() {
1
} else {
1 + self.premises.iter().map(|p| p.depth()).max().unwrap_or(0)
}
}
pub fn size(&self) -> usize {
1 + self.premises.iter().map(|p| p.size()).sum::<usize>()
}
pub fn uses_cut(&self) -> bool {
matches!(self.rule, InferenceRule::Cut { .. }) || self.premises.iter().any(|p| p.uses_cut())
}
}
/// Search strategy used by `ProofSearchEngine::search`.
///
/// `max_depth` bounds the search: the depth-first search abandons a sequent
/// once its depth (root = 0) reaches `max_depth`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProofSearchStrategy {
/// Plain depth-first backward search.
DepthFirst { max_depth: usize },
/// NOTE(review): currently runs the same depth-first search as `DepthFirst`;
/// no breadth-first frontier is implemented.
BreadthFirst { max_depth: usize },
/// NOTE(review): currently runs the same depth-first search as `DepthFirst`;
/// no heuristic ordering is implemented.
BestFirst { max_depth: usize },
/// Depth-first searches with bounds `1..=max_depth`, returning the first
/// proof found.
IterativeDeepening { max_depth: usize },
}
/// Backward proof-search engine for [`Sequent`]s.
///
/// Only a fragment of the calculus is searched: identity axioms plus the
/// one-premise rules ∧-left, ∨-right, ¬-left and ¬-right.
pub struct ProofSearchEngine {
/// Strategy chosen at construction time.
strategy: ProofSearchStrategy,
/// Budget on `stats.sequents_explored`; search gives up once it is reached.
max_steps: usize,
/// Search counters. The engine never resets them, so they (and the
/// `max_steps` budget they feed) accumulate across calls to `search`.
pub stats: ProofSearchStats,
}
/// Counters collected by `ProofSearchEngine` while searching.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ProofSearchStats {
/// Sequents visited, including ones immediately cut off by a depth/step limit.
pub sequents_explored: usize,
/// Proof-tree nodes built (one per node on each successful branch).
pub proofs_generated: usize,
/// Sequents abandoned, either by a limit or because no rule led to a proof.
pub backtracks: usize,
/// Search depth (root = 0) at which the most recent closing axiom was
/// found; `None` until some search succeeds.
pub proof_depth: Option<usize>,
}
impl ProofSearchEngine {
pub fn new(strategy: ProofSearchStrategy, max_steps: usize) -> Self {
ProofSearchEngine {
strategy,
max_steps,
stats: ProofSearchStats::default(),
}
}
pub fn search(&mut self, sequent: &Sequent) -> Option<ProofTree> {
match self.strategy {
ProofSearchStrategy::DepthFirst { max_depth } => self.dfs_search(sequent, 0, max_depth),
ProofSearchStrategy::BreadthFirst { max_depth } => self.bfs_search(sequent, max_depth),
ProofSearchStrategy::BestFirst { max_depth } => {
self.best_first_search(sequent, max_depth)
}
ProofSearchStrategy::IterativeDeepening { max_depth } => {
self.iterative_deepening_search(sequent, max_depth)
}
}
}
fn dfs_search(
&mut self,
sequent: &Sequent,
depth: usize,
max_depth: usize,
) -> Option<ProofTree> {
self.stats.sequents_explored += 1;
if depth >= max_depth || self.stats.sequents_explored >= self.max_steps {
self.stats.backtracks += 1;
return None;
}
if sequent.is_axiom() {
for ant in &sequent.antecedents {
if sequent.consequents.contains(ant) {
self.stats.proofs_generated += 1;
let proof = ProofTree::identity(ant.clone());
self.stats.proof_depth = Some(depth);
return Some(proof);
}
}
}
for (i, ant) in sequent.antecedents.iter().enumerate() {
if let TLExpr::And(a, b) = ant {
let mut new_ant = sequent.antecedents.clone();
new_ant.remove(i);
new_ant.push((**a).clone());
new_ant.push((**b).clone());
let new_sequent = Sequent::new(new_ant, sequent.consequents.clone());
if let Some(premise) = self.dfs_search(&new_sequent, depth + 1, max_depth) {
self.stats.proofs_generated += 1;
return Some(ProofTree::new(
sequent.clone(),
InferenceRule::AndLeft { index: i },
vec![premise],
));
}
}
}
for (i, cons) in sequent.consequents.iter().enumerate() {
if let TLExpr::Or(a, b) = cons {
let mut new_cons = sequent.consequents.clone();
new_cons.remove(i);
new_cons.push((**a).clone());
new_cons.push((**b).clone());
let new_sequent = Sequent::new(sequent.antecedents.clone(), new_cons);
if let Some(premise) = self.dfs_search(&new_sequent, depth + 1, max_depth) {
self.stats.proofs_generated += 1;
return Some(ProofTree::new(
sequent.clone(),
InferenceRule::OrRight { index: i },
vec![premise],
));
}
}
}
for (i, ant) in sequent.antecedents.iter().enumerate() {
if let TLExpr::Not(a) = ant {
let mut new_ant = sequent.antecedents.clone();
new_ant.remove(i);
let mut new_cons = sequent.consequents.clone();
new_cons.push((**a).clone());
let new_sequent = Sequent::new(new_ant, new_cons);
if let Some(premise) = self.dfs_search(&new_sequent, depth + 1, max_depth) {
self.stats.proofs_generated += 1;
return Some(ProofTree::new(
sequent.clone(),
InferenceRule::NotLeft { index: i },
vec![premise],
));
}
}
}
for (i, cons) in sequent.consequents.iter().enumerate() {
if let TLExpr::Not(a) = cons {
let mut new_cons = sequent.consequents.clone();
new_cons.remove(i);
let mut new_ant = sequent.antecedents.clone();
new_ant.push((**a).clone());
let new_sequent = Sequent::new(new_ant, new_cons);
if let Some(premise) = self.dfs_search(&new_sequent, depth + 1, max_depth) {
self.stats.proofs_generated += 1;
return Some(ProofTree::new(
sequent.clone(),
InferenceRule::NotRight { index: i },
vec![premise],
));
}
}
}
self.stats.backtracks += 1;
None
}
fn bfs_search(&mut self, sequent: &Sequent, max_depth: usize) -> Option<ProofTree> {
self.dfs_search(sequent, 0, max_depth)
}
fn best_first_search(&mut self, sequent: &Sequent, max_depth: usize) -> Option<ProofTree> {
self.dfs_search(sequent, 0, max_depth)
}
fn iterative_deepening_search(
&mut self,
sequent: &Sequent,
max_depth: usize,
) -> Option<ProofTree> {
for depth in 1..=max_depth {
if let Some(proof) = self.dfs_search(sequent, 0, depth) {
return Some(proof);
}
}
None
}
}
pub struct CutElimination;
impl CutElimination {
pub fn eliminate(proof: ProofTree) -> ProofTree {
if !proof.uses_cut() {
return proof;
}
let premises: Vec<ProofTree> = proof.premises.into_iter().map(Self::eliminate).collect();
if let InferenceRule::Cut { index } = proof.rule {
if premises.len() == 2 {
return ProofTree::new(proof.conclusion, InferenceRule::Cut { index }, premises);
}
}
ProofTree::new(proof.conclusion, proof.rule, premises)
}
pub fn is_cut_free(proof: &ProofTree) -> bool {
!proof.uses_cut()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TLExpr;

    #[test]
    fn test_identity_sequent() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let seq = Sequent::identity(p);
        assert!(seq.is_axiom());
        assert_eq!(seq.antecedents.len(), 1);
        assert_eq!(seq.consequents.len(), 1);
    }

    #[test]
    fn test_weakening_left() {
        let p = TLExpr::pred("P", vec![]);
        let q = TLExpr::pred("Q", vec![]);
        let seq = Sequent::identity(p.clone()).weaken_left(q.clone());
        assert_eq!(seq.antecedents.len(), 2);
        assert!(seq.antecedents.contains(&q));
    }

    #[test]
    fn test_weakening_right() {
        let p = TLExpr::pred("P", vec![]);
        let q = TLExpr::pred("Q", vec![]);
        let seq = Sequent::identity(p.clone()).weaken_right(q.clone());
        assert_eq!(seq.consequents.len(), 2);
        assert!(seq.consequents.contains(&q));
    }

    #[test]
    fn test_contraction_left() {
        let p = TLExpr::pred("P", vec![]);
        let mut seq = Sequent::identity(p.clone());
        seq.antecedents.push(p.clone());
        assert_eq!(seq.antecedents.len(), 2);
        let contracted = seq.contract_left(0);
        assert!(contracted.is_some());
        assert_eq!(contracted.expect("unwrap").antecedents.len(), 1);
    }

    #[test]
    fn test_contraction_right() {
        let p = TLExpr::pred("P", vec![]);
        let mut seq = Sequent::identity(p.clone());
        seq.consequents.push(p);
        let contracted = seq.contract_right(1).expect("duplicate P should contract");
        assert_eq!(contracted.consequents.len(), 1);
    }

    #[test]
    fn test_contraction_requires_duplicate_in_range() {
        let p = TLExpr::pred("P", vec![]);
        let seq = Sequent::identity(p);
        // A lone formula cannot be contracted, and out-of-range indices fail.
        assert!(seq.clone().contract_left(0).is_none());
        assert!(seq.clone().contract_right(0).is_none());
        assert!(seq.contract_left(5).is_none());
    }

    #[test]
    fn test_free_vars() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);
        let seq = Sequent::new(vec![p], vec![q]);
        let vars = seq.free_vars();
        assert_eq!(vars.len(), 2);
        assert!(vars.contains("x"));
        assert!(vars.contains("y"));
    }

    #[test]
    fn test_sequent_substitute() {
        let p_x = TLExpr::pred("P", vec![Term::var("x")]);
        let seq = Sequent::identity(p_x.clone());
        let substituted = seq.substitute("x", &Term::constant("a"));
        let p_a = TLExpr::pred("P", vec![Term::constant("a")]);
        assert_eq!(substituted.antecedents[0], p_a);
        assert_eq!(substituted.consequents[0], p_a);
        // The original sequent is not modified.
        assert_eq!(seq.antecedents[0], p_x);
    }

    #[test]
    fn test_sequent_substitute_capture_avoiding() {
        let p_x = TLExpr::pred("P", vec![Term::var("x")]);
        let exists_p = TLExpr::exists("x", "Domain", p_x);
        let q_x = TLExpr::pred("Q", vec![Term::var("x")]);
        let seq = Sequent::new(vec![exists_p.clone()], vec![q_x]);
        let substituted = seq.substitute("x", &Term::constant("a"));
        // `x` is bound under the quantifier, so that formula is untouched...
        assert_eq!(substituted.antecedents[0], exists_p);
        // ...while the free occurrence is replaced.
        let q_a = TLExpr::pred("Q", vec![Term::constant("a")]);
        assert_eq!(substituted.consequents[0], q_a);
    }

    #[test]
    fn test_sequent_substitute_multiple() {
        let p_x = TLExpr::pred("P", vec![Term::var("x")]);
        let q_x = TLExpr::pred("Q", vec![Term::var("x")]);
        let and_pq = TLExpr::and(p_x, q_x);
        let r_x = TLExpr::pred("R", vec![Term::var("x")]);
        let seq = Sequent::new(vec![and_pq], vec![r_x]);
        let substituted = seq.substitute("x", &Term::constant("b"));
        let p_b = TLExpr::pred("P", vec![Term::constant("b")]);
        let q_b = TLExpr::pred("Q", vec![Term::constant("b")]);
        let and_pq_b = TLExpr::and(p_b, q_b);
        let r_b = TLExpr::pred("R", vec![Term::constant("b")]);
        assert_eq!(substituted.antecedents[0], and_pq_b);
        assert_eq!(substituted.consequents[0], r_b);
    }

    #[test]
    fn test_substitution() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let seq = Sequent::identity(p.clone());
        let substituted = seq.substitute("x", &Term::constant("a"));
        let p_a = TLExpr::pred("P", vec![Term::constant("a")]);
        assert_eq!(substituted.antecedents[0], p_a);
        assert_eq!(substituted.consequents[0], p_a);
    }

    #[test]
    fn test_identity_proof_tree() {
        let p = TLExpr::pred("P", vec![]);
        let proof = ProofTree::identity(p);
        assert!(proof.is_valid());
        assert_eq!(proof.depth(), 1);
        assert_eq!(proof.size(), 1);
        assert!(!proof.uses_cut());
    }

    #[test]
    fn test_and_left_proof() {
        let p = TLExpr::pred("P", vec![]);
        let q = TLExpr::pred("Q", vec![]);
        let and_pq = TLExpr::and(p.clone(), q.clone());
        // ∧L on `P∧Q ⊢ P` needs a proof of `P, Q ⊢ P`, which is an axiom
        // because P occurs on both sides.
        let premise_seq = Sequent::new(vec![p.clone(), q], vec![p.clone()]);
        assert!(premise_seq.is_axiom());
        let premise = ProofTree::new(premise_seq, InferenceRule::Identity, vec![]);
        let conclusion_seq = Sequent::new(vec![and_pq], vec![p]);
        let proof = ProofTree::new(
            conclusion_seq,
            InferenceRule::AndLeft { index: 0 },
            vec![premise],
        );
        assert!(proof.is_valid());
        assert_eq!(proof.depth(), 2);
    }

    #[test]
    fn test_and_left_rejects_non_conjunction() {
        let p = TLExpr::pred("P", vec![]);
        // The formula at index 0 is `P`, not a conjunction.
        let proof = ProofTree::new(
            Sequent::identity(p.clone()),
            InferenceRule::AndLeft { index: 0 },
            vec![ProofTree::identity(p)],
        );
        assert!(!proof.is_valid());
    }

    #[test]
    fn test_proof_search_simple() {
        let p = TLExpr::pred("P", vec![]);
        let sequent = Sequent::identity(p);
        let mut engine =
            ProofSearchEngine::new(ProofSearchStrategy::DepthFirst { max_depth: 10 }, 1000);
        let proof = engine.search(&sequent);
        assert!(proof.is_some());
        assert!(proof.expect("unwrap").is_valid());
        assert!(engine.stats.proofs_generated > 0);
    }

    #[test]
    fn test_proof_search_and() {
        let p = TLExpr::pred("P", vec![]);
        let q = TLExpr::pred("Q", vec![]);
        let and_pq = TLExpr::and(p.clone(), q.clone());
        let sequent = Sequent::new(vec![and_pq], vec![p]);
        let mut engine =
            ProofSearchEngine::new(ProofSearchStrategy::DepthFirst { max_depth: 10 }, 1000);
        let proof = engine.search(&sequent);
        assert!(proof.is_some());
        let proof = proof.expect("unwrap");
        assert!(proof.is_valid());
        assert!(engine.stats.proofs_generated > 0);
    }

    #[test]
    fn test_proof_search_not() {
        let p = TLExpr::pred("P", vec![]);
        let not_p = TLExpr::negate(p.clone());
        let sequent = Sequent::identity(not_p);
        let mut engine =
            ProofSearchEngine::new(ProofSearchStrategy::DepthFirst { max_depth: 10 }, 1000);
        let proof = engine.search(&sequent);
        assert!(proof.is_some());
        assert!(proof.expect("unwrap").is_valid());
    }

    #[test]
    fn test_proof_search_not_right() {
        // `⊢ ¬P, P` needs ¬R (moving P to the left) before it closes.
        let p = TLExpr::pred("P", vec![]);
        let not_p = TLExpr::negate(p.clone());
        let sequent = Sequent::new(vec![], vec![not_p, p]);
        let mut engine =
            ProofSearchEngine::new(ProofSearchStrategy::DepthFirst { max_depth: 10 }, 1000);
        let proof = engine.search(&sequent).expect("⊢ ¬P, P is provable");
        assert!(proof.is_valid());
        assert_eq!(proof.conclusion, sequent);
        assert!(matches!(proof.rule, InferenceRule::NotRight { index: 0 }));
        assert_eq!(proof.depth(), 2);
    }

    #[test]
    fn test_cut_elimination_no_cut() {
        let p = TLExpr::pred("P", vec![]);
        let proof = ProofTree::identity(p);
        assert!(CutElimination::is_cut_free(&proof));
        let eliminated = CutElimination::eliminate(proof.clone());
        assert_eq!(eliminated, proof);
    }

    #[test]
    fn test_iterative_deepening_search() {
        let p = TLExpr::pred("P", vec![]);
        let q = TLExpr::pred("Q", vec![]);
        let and_pq = TLExpr::and(p.clone(), q.clone());
        let sequent = Sequent::new(vec![and_pq], vec![p]);
        let mut engine = ProofSearchEngine::new(
            ProofSearchStrategy::IterativeDeepening { max_depth: 10 },
            1000,
        );
        let proof = engine.search(&sequent);
        assert!(proof.is_some());
        assert!(proof.expect("unwrap").is_valid());
    }
}