use std::collections::{HashMap, HashSet, VecDeque};
use super::rewriting::RewriteSystem;
use super::TLExpr;
/// Two expressions obtained by rewriting the same `overlap` term with two
/// (named) rules, together with the outcome of the joinability test, if any.
#[derive(Debug, Clone)]
pub struct CriticalPair {
    /// The term on which both rules were applied.
    pub overlap: TLExpr,
    /// First of the two results that are compared for joinability.
    pub result1: TLExpr,
    /// Second of the two results that are compared for joinability.
    pub result2: TLExpr,
    /// Name of the rule associated with `result1`.
    pub rule1_name: String,
    /// Name of the rule associated with `result2`.
    pub rule2_name: String,
    /// `None` until a joinability test has been recorded for this pair,
    /// afterwards `Some(true)` / `Some(false)`.
    pub joinable: Option<bool>,
}

impl CriticalPair {
    /// Creates a pair whose joinability has not been determined yet.
    pub fn new(
        overlap: TLExpr,
        result1: TLExpr,
        result2: TLExpr,
        rule1_name: String,
        rule2_name: String,
    ) -> Self {
        let joinable = None;
        Self {
            joinable,
            rule2_name,
            rule1_name,
            result2,
            result1,
            overlap,
        }
    }

    /// `true` when both results are already equal (`==`), so no rewriting is
    /// needed to join them.
    pub fn is_trivially_joinable(&self) -> bool {
        let CriticalPair {
            result1, result2, ..
        } = self;
        result1 == result2
    }

    /// Purely syntactic conflict test: the negation of
    /// [`CriticalPair::is_trivially_joinable`]. It does not consult
    /// `joinable`, so a pair shown joinable by rewriting still reports `true`.
    pub fn has_conflict(&self) -> bool {
        !self.is_trivially_joinable()
    }
}
/// Aggregated outcome of a confluence analysis.
///
/// [`ConfluenceReport::is_confluent`] combines the two flags as in Newman's
/// lemma: a terminating, locally confluent system is confluent.
#[derive(Debug, Clone)]
pub struct ConfluenceReport {
    /// Every critical pair that was examined.
    pub critical_pairs: Vec<CriticalPair>,
    /// Number of examined pairs recorded as joinable.
    pub joinable_count: usize,
    /// Number of examined pairs recorded as not joinable.
    pub non_joinable_count: usize,
    /// Whether the system was judged locally confluent.
    pub is_locally_confluent: bool,
    /// Whether the system was judged terminating.
    pub is_terminating: bool,
}

impl ConfluenceReport {
    /// Creates an empty report: no pairs, zero counts, both flags `false`.
    pub fn new() -> Self {
        let (joinable_count, non_joinable_count) = (0, 0);
        Self {
            critical_pairs: vec![],
            joinable_count,
            non_joinable_count,
            is_locally_confluent: false,
            is_terminating: false,
        }
    }

    /// Newman's lemma: termination together with local confluence implies
    /// confluence.
    pub fn is_confluent(&self) -> bool {
        self.is_locally_confluent && self.is_terminating
    }

    /// Renders a multi-line, human-readable summary of the report.
    pub fn summary(&self) -> String {
        let lines = [
            "Confluence Report:".to_string(),
            format!("- Critical pairs: {}", self.critical_pairs.len()),
            format!("- Joinable: {}", self.joinable_count),
            format!("- Non-joinable: {}", self.non_joinable_count),
            format!("- Locally confluent: {}", self.is_locally_confluent),
            format!("- Terminating: {}", self.is_terminating),
            format!("- Confluent: {}", self.is_confluent()),
        ];
        lines.join("\n")
    }
}

impl Default for ConfluenceReport {
    /// Same as [`ConfluenceReport::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Analyses a [`RewriteSystem`] for (local) confluence.
///
/// Joinability of two expressions is decided by a depth-bounded,
/// bidirectional breadth-first search for a common reduct. Expressions are
/// identified by their `Debug` rendering during the search, and results are
/// memoised in `joinability_cache`.
pub struct ConfluenceChecker {
    /// Maximum number of rewrite steps explored from each side of a pair.
    max_depth: usize,
    /// Settable via [`ConfluenceChecker::with_max_expr_size`]; not currently
    /// consulted by any search in this type.
    max_expr_size: usize,
    /// Memoised joinability results keyed by the (sorted) `Debug` renderings
    /// of the two expressions. Entries are only meaningful for the rewrite
    /// system and depth limit they were computed with.
    joinability_cache: HashMap<(String, String), bool>,
}

impl ConfluenceChecker {
    /// Creates a checker with `max_depth = 10`, `max_expr_size = 1000` and an
    /// empty cache.
    pub fn new() -> Self {
        Self {
            max_depth: 10,
            max_expr_size: 1000,
            joinability_cache: HashMap::new(),
        }
    }

    /// Sets the per-side rewrite depth limit.
    ///
    /// Clears the cache: negative results computed under a different limit
    /// may not hold under the new one.
    pub fn with_max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self.joinability_cache.clear();
        self
    }

    /// Sets `max_expr_size` (currently not consulted by any search).
    pub fn with_max_expr_size(mut self, size: usize) -> Self {
        self.max_expr_size = size;
        self
    }

    /// Forgets all memoised joinability results.
    ///
    /// Call this before reusing the checker's [`ConfluenceChecker::test_joinability`]
    /// with a different rewrite system; [`ConfluenceChecker::check`] does so itself.
    pub fn clear_cache(&mut self) {
        self.joinability_cache.clear();
    }

    /// Runs the confluence analysis on `system` and returns a report.
    ///
    /// Each critical pair is marked joinable when its results are equal or
    /// when [`ConfluenceChecker::test_joinability`] finds a common reduct.
    /// `is_locally_confluent` is `true` iff no pair was found non-joinable.
    /// Note that `find_critical_pairs_basic` is currently a placeholder that
    /// records no pairs, and `check_termination_heuristic` always returns
    /// `true`, so the report is currently confluent for every system.
    pub fn check(&mut self, system: &RewriteSystem) -> ConfluenceReport {
        // Cached answers belong to whatever system they were computed for.
        self.clear_cache();
        let mut report = ConfluenceReport::new();
        self.find_critical_pairs_basic(system, &mut report);
        for pair in &mut report.critical_pairs {
            // Short-circuit: trivially joinable pairs skip the search.
            let joinable = pair.is_trivially_joinable()
                || self.test_joinability(&pair.result1, &pair.result2, system);
            pair.joinable = Some(joinable);
            if joinable {
                report.joinable_count += 1;
            } else {
                report.non_joinable_count += 1;
            }
        }
        report.is_locally_confluent = report.non_joinable_count == 0;
        report.is_terminating = self.check_termination_heuristic(system);
        report
    }

    /// Placeholder for critical-pair computation: currently adds nothing to
    /// `_report`.
    fn find_critical_pairs_basic(&self, _system: &RewriteSystem, _report: &mut ConfluenceReport) {
    }

    /// Returns `true` if `expr1` and `expr2` are equal or rewrite, in at most
    /// `max_depth` steps each, to a common expression.
    ///
    /// A `false` answer only means no common reduct was found within the
    /// depth limit. Results are cached until [`ConfluenceChecker::clear_cache`],
    /// [`ConfluenceChecker::check`] or [`ConfluenceChecker::with_max_depth`]
    /// clears them.
    pub fn test_joinability(
        &mut self,
        expr1: &TLExpr,
        expr2: &TLExpr,
        system: &RewriteSystem,
    ) -> bool {
        let key = Self::cache_key(expr1, expr2);
        if let Some(&cached) = self.joinability_cache.get(&key) {
            return cached;
        }
        let joinable = expr1 == expr2 || self.search_common_reduct(expr1, expr2, system);
        self.joinability_cache.insert(key, joinable);
        joinable
    }

    /// Builds an order-independent cache key (joinability is symmetric).
    fn cache_key(expr1: &TLExpr, expr2: &TLExpr) -> (String, String) {
        let (a, b) = (format!("{:?}", expr1), format!("{:?}", expr2));
        if a <= b {
            (a, b)
        } else {
            (b, a)
        }
    }

    /// Bidirectional BFS from both expressions, alternating one expansion per
    /// side, until the two explored sets intersect or both frontiers are empty.
    fn search_common_reduct(&self, expr1: &TLExpr, expr2: &TLExpr, system: &RewriteSystem) -> bool {
        let key1 = format!("{:?}", expr1);
        let key2 = format!("{:?}", expr2);
        if key1 == key2 {
            return true;
        }
        let mut seen1 = HashSet::new();
        let mut seen2 = HashSet::new();
        seen1.insert(key1);
        seen2.insert(key2);
        let mut queue1 = VecDeque::new();
        let mut queue2 = VecDeque::new();
        queue1.push_back((expr1.clone(), 0));
        queue2.push_back((expr2.clone(), 0));
        while !queue1.is_empty() || !queue2.is_empty() {
            if self.expand_frontier(&mut queue1, &mut seen1, &seen2, system)
                || self.expand_frontier(&mut queue2, &mut seen2, &seen1, system)
            {
                return true;
            }
        }
        false
    }

    /// Expands the oldest queued expression by one rewrite step.
    ///
    /// New rewrites are recorded in `seen` (and enqueued while they may still
    /// be expanded). Returns `true` as soon as a rewrite is already in
    /// `other_seen`, i.e. the two searches have met.
    fn expand_frontier(
        &self,
        queue: &mut VecDeque<(TLExpr, usize)>,
        seen: &mut HashSet<String>,
        other_seen: &HashSet<String>,
        system: &RewriteSystem,
    ) -> bool {
        let (current, depth) = match queue.pop_front() {
            Some(entry) => entry,
            None => return false,
        };
        if depth >= self.max_depth {
            return false;
        }
        for rewrite in self.get_all_rewrites(&current, system) {
            let rewrite_key = format!("{:?}", rewrite);
            // Meeting is detected on discovery, not on dequeue, so reducts
            // sitting exactly at the depth limit still count.
            if other_seen.contains(&rewrite_key) {
                return true;
            }
            if seen.insert(rewrite_key) && depth + 1 < self.max_depth {
                queue.push_back((rewrite, depth + 1));
            }
        }
        false
    }

    /// Collects one-step rewrites of `expr`: the result of
    /// `system.apply_once` at the root (if any), plus rewrites inside the
    /// operands of `And`, `Or` and `Not`. Other variants are not descended into.
    ///
    /// NOTE(review): if `apply_once` applies only a single rule, alternative
    /// rewrites by other rules at the same position are not enumerated — confirm.
    // `self` is only threaded through the recursion; kept as a method for API symmetry.
    #[allow(clippy::only_used_in_recursion)]
    fn get_all_rewrites(&self, expr: &TLExpr, system: &RewriteSystem) -> Vec<TLExpr> {
        let mut results = Vec::new();
        if let Some(rewritten) = system.apply_once(expr) {
            results.push(rewritten);
        }
        match expr {
            TLExpr::And(l, r) => {
                results.extend(
                    self.get_all_rewrites(l, system)
                        .into_iter()
                        .map(|l_rewrite| TLExpr::and(l_rewrite, (**r).clone())),
                );
                results.extend(
                    self.get_all_rewrites(r, system)
                        .into_iter()
                        .map(|r_rewrite| TLExpr::and((**l).clone(), r_rewrite)),
                );
            }
            TLExpr::Or(l, r) => {
                results.extend(
                    self.get_all_rewrites(l, system)
                        .into_iter()
                        .map(|l_rewrite| TLExpr::or(l_rewrite, (**r).clone())),
                );
                results.extend(
                    self.get_all_rewrites(r, system)
                        .into_iter()
                        .map(|r_rewrite| TLExpr::or((**l).clone(), r_rewrite)),
                );
            }
            TLExpr::Not(e) => {
                results.extend(
                    self.get_all_rewrites(e, system)
                        .into_iter()
                        .map(TLExpr::negate),
                );
            }
            _ => {}
        }
        results
    }

    /// Placeholder termination check: unconditionally returns `true`.
    fn check_termination_heuristic(&self, _system: &RewriteSystem) -> bool {
        true
    }
}

impl Default for ConfluenceChecker {
    /// Same as [`ConfluenceChecker::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Convenience wrapper: decides whether `expr1` and `expr2` are joinable
/// under `system` using a fresh checker with default limits (so no cached
/// results from earlier calls are involved).
pub fn are_joinable(expr1: &TLExpr, expr2: &TLExpr, system: &RewriteSystem) -> bool {
    ConfluenceChecker::new().test_joinability(expr1, expr2, system)
}
pub fn normalize(expr: &TLExpr, system: &RewriteSystem, max_steps: usize) -> Option<TLExpr> {
let mut current = expr.clone();
let mut steps = 0;
while steps < max_steps {
if let Some(next) = system.apply_once(¤t) {
current = next;
steps += 1;
} else {
return Some(current); }
}
None }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Pattern, RewriteRule, Term};

    /// Builds the unary predicate `name(x)` used throughout these tests.
    fn pred_x(name: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var("x")])
    }

    /// `¬¬P(x)`.
    fn double_negated_p() -> TLExpr {
        TLExpr::negate(TLExpr::negate(pred_x("P")))
    }

    /// A rewrite system whose only rule is double-negation elimination `¬¬A → A`.
    fn double_negation_system() -> RewriteSystem {
        RewriteSystem::new().add_rule(RewriteRule::new(
            Pattern::negation(Pattern::negation(Pattern::var("A"))),
            |bindings| bindings.get("A").expect("unwrap").clone(),
        ))
    }

    /// A critical pair over the overlap `P(x)` between `rule1` and `rule2`.
    fn pair_over_p(result1: TLExpr, result2: TLExpr) -> CriticalPair {
        CriticalPair::new(
            pred_x("P"),
            result1,
            result2,
            "rule1".to_string(),
            "rule2".to_string(),
        )
    }

    #[test]
    fn test_critical_pair_trivial_joinable() {
        let pair = pair_over_p(pred_x("Q"), pred_x("Q"));
        assert!(pair.is_trivially_joinable());
        assert!(!pair.has_conflict());
    }

    #[test]
    fn test_critical_pair_with_conflict() {
        let pair = pair_over_p(pred_x("Q"), pred_x("R"));
        assert!(!pair.is_trivially_joinable());
        assert!(pair.has_conflict());
    }

    #[test]
    fn test_joinability_identical() {
        let system = RewriteSystem::new();
        let expr = pred_x("P");
        assert!(ConfluenceChecker::new().test_joinability(&expr, &expr, &system));
    }

    #[test]
    fn test_joinability_via_rewriting() {
        let system = double_negation_system();
        let (expr1, expr2) = (double_negated_p(), pred_x("P"));
        let mut checker = ConfluenceChecker::new();
        assert!(checker.test_joinability(&expr1, &expr2, &system));
    }

    #[test]
    fn test_normalize_to_normal_form() {
        let system = double_negation_system();
        let normal_form = normalize(&double_negated_p(), &system, 100).expect("unwrap");
        assert!(matches!(normal_form, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_confluence_report_summary() {
        let report = ConfluenceReport {
            joinable_count: 5,
            non_joinable_count: 2,
            is_locally_confluent: false,
            is_terminating: true,
            ..ConfluenceReport::new()
        };
        let summary = report.summary();
        for expected in ["Joinable: 5", "Non-joinable: 2", "Confluent: false"].iter() {
            assert!(summary.contains(*expected));
        }
    }

    #[test]
    fn test_confluence_via_newmans_lemma() {
        let mut report = ConfluenceReport::new();
        // (terminating, locally confluent, expected confluent)
        let cases = [(true, true, true), (false, true, false), (true, false, false)];
        for &(terminating, locally_confluent, expected) in cases.iter() {
            report.is_terminating = terminating;
            report.is_locally_confluent = locally_confluent;
            assert_eq!(report.is_confluent(), expected);
        }
    }
}