use std::cmp::Reverse;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use super::TLExpr;
use crate::util::ExprStats;
/// Priority level attached to a rewrite rule.
///
/// Every variant has an explicit discriminant, and the derived ordering
/// follows those discriminants: `Critical` compares greatest and `Minimal`
/// compares smallest. `Normal` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum RulePriority {
/// Highest priority (discriminant 100).
Critical = 100,
/// Above-normal priority (discriminant 75).
High = 75,
/// Default priority (discriminant 50).
#[default]
Normal = 50,
/// Below-normal priority (discriminant 25).
Low = 25,
/// Lowest priority (discriminant 0).
Minimal = 0,
}
/// Guard of a conditional rule: receives a map from names to expressions
/// (the bindings) and returns whether the rule is allowed to fire.
pub type GuardPredicate = fn(&HashMap<String, TLExpr>) -> bool;
/// Transformation of a conditional rule: returns `Some(rewritten)` when it
/// rewrites the given expression and `None` when it does not apply.
pub type TransformFn = fn(&TLExpr) -> Option<TLExpr>;
/// A named rewrite rule made of a guard, a transformation and a priority.
///
/// The rule also counts how many times it has been applied successfully.
#[derive(Clone)]
pub struct ConditionalRule {
/// Name of the rule; used as the key in per-rule statistics and in trace output.
pub name: String,
/// Function that attempts the rewrite.
pub transform: TransformFn,
/// Predicate that must return `true` before `transform` is tried.
pub guard: GuardPredicate,
/// Priority used to order rules inside a rewrite system.
pub priority: RulePriority,
/// Optional human-readable description.
pub description: Option<String>,
// Number of times `apply` produced a rewritten expression.
applications: usize,
}
impl ConditionalRule {
pub fn new(name: impl Into<String>, transform: TransformFn, guard: GuardPredicate) -> Self {
Self {
name: name.into(),
transform,
guard,
priority: RulePriority::default(),
description: None,
applications: 0,
}
}
pub fn with_priority(mut self, priority: RulePriority) -> Self {
self.priority = priority;
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn apply(&mut self, expr: &TLExpr) -> Option<TLExpr> {
let bindings = HashMap::new(); if (self.guard)(&bindings) {
if let Some(result) = (self.transform)(expr) {
self.applications += 1;
return Some(result);
}
}
None
}
pub fn application_count(&self) -> usize {
self.applications
}
pub fn reset_counter(&mut self) {
self.applications = 0;
}
}
impl std::fmt::Debug for ConditionalRule {
    /// Prints the name, priority, description and application counter.
    ///
    /// The `transform` and `guard` function pointers are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ConditionalRule");
        builder.field("name", &self.name);
        builder.field("priority", &self.priority);
        builder.field("description", &self.description);
        builder.field("applications", &self.applications);
        builder.finish()
    }
}
/// Traversal strategy used by the rewrite system. `BottomUp` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RewriteStrategy {
/// Iterate: rewrite inside the children first and try the node itself only
/// when no child changed; repeat until nothing changes or a limit is hit.
Innermost,
/// Iterate: try the root first and descend into the children only when no
/// rule applies at the root; repeat until nothing changes or a limit is hit.
Outermost,
/// Single pass: transform the children, then try the rules on the rebuilt node.
#[default]
BottomUp,
/// Single pass: try the rules on the node, then transform the children of the result.
TopDown,
/// Apply rules at a node until none applies, then recurse into its children.
FixpointPerNode,
/// Repeat bottom-up passes until a pass leaves the expression unchanged
/// (or a limit stops the loop).
GlobalFixpoint,
}
/// Settings that control a rewrite run.
#[derive(Debug, Clone)]
pub struct RewriteConfig {
/// Step budget; rewriting stops once the step counter reaches this value.
pub max_steps: usize,
/// Traversal strategy to use.
pub strategy: RewriteStrategy,
/// Whether expressions are hashed and checked against previously seen hashes.
pub detect_cycles: bool,
/// Whether each rule application is reported on standard error.
pub trace: bool,
/// Optional cap on the node count of an expression; `None` disables the check.
pub max_expr_size: Option<usize>,
}
impl Default for RewriteConfig {
fn default() -> Self {
Self {
max_steps: 10000,
strategy: RewriteStrategy::default(),
detect_cycles: true,
trace: false,
max_expr_size: Some(100000), }
}
}
/// Counters and flags collected during one rewrite run.
#[derive(Debug, Clone, Default)]
pub struct RewriteStats {
    /// Number of steps taken.
    pub steps: usize,
    /// Total number of successful rule applications.
    pub rule_applications: usize,
    /// Successful applications per rule name.
    pub rule_counts: HashMap<String, usize>,
    /// Whether the run ended without hitting the step budget or a cycle.
    pub reached_fixpoint: bool,
    /// Whether a previously seen expression was encountered again.
    pub cycle_detected: bool,
    /// Whether an expression exceeded the configured node-count cap.
    pub size_limit_exceeded: bool,
    /// Node count of the input expression.
    pub initial_size: usize,
    /// Node count of the resulting expression.
    pub final_size: usize,
}

impl RewriteStats {
    /// Size reduction as a percentage of the initial size.
    ///
    /// Returns `0.0` when the initial size is zero. The value is negative
    /// when the expression grew.
    pub fn reduction_percentage(&self) -> f64 {
        match self.initial_size {
            0 => 0.0,
            initial => {
                let remaining = self.final_size as f64 / initial as f64;
                100.0 * (1.0 - remaining)
            }
        }
    }

    /// `true` when a fixpoint was reached and neither a cycle nor a size
    /// overflow was reported.
    pub fn is_successful(&self) -> bool {
        let aborted = self.cycle_detected || self.size_limit_exceeded;
        self.reached_fixpoint && !aborted
    }
}
/// Hashes an expression through its `Debug` rendering.
///
/// The expression is formatted with `{:?}` and the resulting string is fed
/// to the standard library's default hasher.
fn expr_hash(expr: &TLExpr) -> u64 {
    let rendered: String = format!("{:?}", expr);
    let mut state = std::collections::hash_map::DefaultHasher::new();
    rendered.hash(&mut state);
    state.finish()
}
/// Rewrite engine that applies prioritized conditional rules to expressions
/// according to a configurable strategy.
pub struct AdvancedRewriteSystem {
// Registered rules, kept sorted from highest to lowest priority.
rules: Vec<ConditionalRule>,
// Strategy, limits and diagnostics settings for rewrite runs.
config: RewriteConfig,
// Expression hashes recorded for cycle detection; cleared at the start of every `apply` call.
seen_hashes: HashSet<u64>,
}
impl AdvancedRewriteSystem {
    /// Creates a rewrite system with no rules and the default [`RewriteConfig`].
    pub fn new() -> Self {
        Self::with_config(RewriteConfig::default())
    }

    /// Creates a rewrite system with no rules and the given configuration.
    pub fn with_config(config: RewriteConfig) -> Self {
        Self {
            rules: Vec::new(),
            config,
            seen_hashes: HashSet::new(),
        }
    }

    /// Registers `rule` and returns the system (builder style).
    ///
    /// Rules are kept sorted from highest to lowest priority. The sort is
    /// stable, so rules of equal priority stay in insertion order.
    pub fn add_rule(mut self, rule: ConditionalRule) -> Self {
        self.rules.push(rule);
        self.rules.sort_by_key(|r| Reverse(r.priority));
        self
    }

    /// Rewrites `expr` with the configured strategy.
    ///
    /// Returns the resulting expression together with the statistics of the
    /// run. `reached_fixpoint` is set only when the run stopped on its own:
    /// the step budget was not exhausted, no cycle was detected and the size
    /// limit was not exceeded. For the single-pass strategies this means the
    /// pass completed within the limits.
    pub fn apply(&mut self, expr: &TLExpr) -> (TLExpr, RewriteStats) {
        let mut stats = RewriteStats {
            initial_size: ExprStats::compute(expr).node_count,
            ..Default::default()
        };
        // Cycle bookkeeping never carries over from a previous run.
        self.seen_hashes.clear();

        let result = match self.config.strategy {
            RewriteStrategy::Innermost => self.apply_innermost(expr, &mut stats),
            RewriteStrategy::Outermost => self.apply_outermost(expr, &mut stats),
            RewriteStrategy::BottomUp => self.apply_bottom_up(expr, &mut stats),
            RewriteStrategy::TopDown => self.apply_top_down(expr, &mut stats),
            RewriteStrategy::FixpointPerNode => self.apply_fixpoint_per_node(expr, &mut stats),
            RewriteStrategy::GlobalFixpoint => self.apply_global_fixpoint(expr, &mut stats),
        };

        stats.final_size = ExprStats::compute(&result).node_count;
        // A run aborted by any limit has not reached a fixpoint.
        stats.reached_fixpoint = stats.steps < self.config.max_steps
            && !stats.cycle_detected
            && !stats.size_limit_exceeded;
        (result, stats)
    }

    /// Tries the rules in priority order on `expr` itself (not its children).
    ///
    /// The first rule that produces a result wins; its application is
    /// recorded in `stats` and, when tracing is on, reported on stderr.
    fn try_apply_at_node(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
        for rule in &mut self.rules {
            if let Some(result) = rule.apply(expr) {
                stats.rule_applications += 1;
                *stats.rule_counts.entry(rule.name.clone()).or_insert(0) += 1;
                if self.config.trace {
                    eprintln!("Applied rule '{}' at step {}", rule.name, stats.steps);
                }
                return Some(result);
            }
        }
        None
    }

    /// Checks the size cap and the step budget for `expr`.
    ///
    /// Returns `false` when rewriting must stop. Exceeding the size cap is
    /// also recorded in `stats.size_limit_exceeded`.
    fn check_limits(&self, expr: &TLExpr, stats: &mut RewriteStats) -> bool {
        if let Some(max_size) = self.config.max_expr_size {
            if ExprStats::compute(expr).node_count > max_size {
                stats.size_limit_exceeded = true;
                return false;
            }
        }
        stats.steps < self.config.max_steps
    }

    /// History-based check for the strategies that iterate on the whole
    /// expression (innermost, outermost, global fixpoint).
    ///
    /// When cycle detection is on, the hash of `expr` is added to the set of
    /// states seen so far; meeting a state again flags a cycle. The size and
    /// step limits are checked afterwards. Returns `false` when rewriting
    /// must stop.
    fn check_constraints(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> bool {
        if self.config.detect_cycles && !self.seen_hashes.insert(expr_hash(expr)) {
            stats.cycle_detected = true;
            return false;
        }
        self.check_limits(expr, stats)
    }

    /// Path-based check for the recursive strategies (top-down, per-node
    /// fixpoint).
    ///
    /// Records `expr` as being on the rewrite path currently under
    /// construction and remembers the hash in `registered` so that
    /// [`Self::leave_path`] can remove it again. Returns `false` and flags a
    /// cycle when the same expression is already on the path, i.e. when
    /// rewriting an ancestor has reproduced it. Identical sub-expressions in
    /// sibling positions are not on each other's path and are not flagged.
    fn enter_path(
        &mut self,
        expr: &TLExpr,
        stats: &mut RewriteStats,
        registered: &mut Vec<u64>,
    ) -> bool {
        if !self.config.detect_cycles {
            return true;
        }
        let hash = expr_hash(expr);
        if self.seen_hashes.insert(hash) {
            registered.push(hash);
            true
        } else {
            stats.cycle_detected = true;
            false
        }
    }

    /// Removes the hashes a node registered via [`Self::enter_path`].
    fn leave_path(&mut self, registered: &[u64]) {
        for hash in registered {
            self.seen_hashes.remove(hash);
        }
    }

    /// Innermost strategy: repeats innermost rewrite passes until a pass
    /// changes nothing or a constraint stops the loop.
    fn apply_innermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
        let mut current = expr.clone();
        while stats.steps < self.config.max_steps {
            stats.steps += 1;
            if !self.check_constraints(&current, stats) {
                break;
            }
            match self.rewrite_innermost(&current, stats) {
                Some(rewritten) => current = rewritten,
                None => break,
            }
        }
        current
    }

    /// One innermost pass over `expr`: rewrites inside the children first and
    /// tries the node itself only when no child changed.
    ///
    /// Returns `None` when nothing was rewritten.
    fn rewrite_innermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
        if let Some(new_expr) = self.rewrite_children(expr, stats) {
            return Some(new_expr);
        }
        self.try_apply_at_node(expr, stats)
    }

    /// Outermost strategy: on every iteration tries the root first and
    /// descends into the children only when no rule applies at the root.
    fn apply_outermost(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
        let mut current = expr.clone();
        while stats.steps < self.config.max_steps {
            stats.steps += 1;
            if !self.check_constraints(&current, stats) {
                break;
            }
            if let Some(rewritten) = self.try_apply_at_node(&current, stats) {
                current = rewritten;
                continue;
            }
            match self.rewrite_children(&current, stats) {
                Some(rewritten) => current = rewritten,
                None => break,
            }
        }
        current
    }

    /// Bottom-up strategy: a single pass that transforms the children and
    /// then tries the rules on the rebuilt node.
    ///
    /// Only the size and step limits are checked here. Every node visited is
    /// a strict sub-expression of the pass input, so the pass cannot revisit
    /// a state, and identical sub-expressions in different positions are
    /// legitimate rather than a cycle.
    fn apply_bottom_up(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
        stats.steps += 1;
        if !self.check_limits(expr, stats) {
            return expr.clone();
        }
        let rebuilt = self.map_children(expr, stats, Self::apply_bottom_up);
        match self.try_apply_at_node(&rebuilt, stats) {
            Some(result) => result,
            None => rebuilt,
        }
    }

    /// Top-down strategy: a single pass that tries the rules on the node and
    /// then transforms the children of the result.
    ///
    /// A rule whose output contains its own input would make this recursion
    /// reproduce the same node forever; the path-based cycle check stops it.
    fn apply_top_down(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
        stats.steps += 1;
        let mut registered: Vec<u64> = Vec::new();
        let result = if self.enter_path(expr, stats, &mut registered)
            && self.check_limits(expr, stats)
        {
            let current = self
                .try_apply_at_node(expr, stats)
                .unwrap_or_else(|| expr.clone());
            self.map_children(&current, stats, Self::apply_top_down)
        } else {
            expr.clone()
        };
        self.leave_path(&registered);
        result
    }

    /// Per-node fixpoint strategy: applies rules at the node until none
    /// applies, then recurses into the children of the final form.
    fn apply_fixpoint_per_node(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
        let mut registered: Vec<u64> = Vec::new();
        let result = self.fixpoint_at_node(expr, stats, &mut registered);
        // The forms of this node leave the path once its subtree is done.
        self.leave_path(&registered);
        result
    }

    /// Body of [`Self::apply_fixpoint_per_node`].
    ///
    /// Every form the node takes is registered on the rewrite path, so both a
    /// rule chain that loops at this node and a descendant that reproduces
    /// one of these forms are reported as cycles.
    fn fixpoint_at_node(
        &mut self,
        expr: &TLExpr,
        stats: &mut RewriteStats,
        registered: &mut Vec<u64>,
    ) -> TLExpr {
        let mut current = expr.clone();
        if !self.enter_path(&current, stats, registered) {
            return current;
        }
        while let Some(rewritten) = self.try_apply_at_node(&current, stats) {
            current = rewritten;
            stats.steps += 1;
            if !self.enter_path(&current, stats, registered)
                || !self.check_limits(&current, stats)
            {
                return current;
            }
        }
        self.map_children(&current, stats, Self::apply_fixpoint_per_node)
    }

    /// Global fixpoint strategy: repeats bottom-up passes until a pass leaves
    /// the expression unchanged or a constraint stops the loop.
    ///
    /// Cycle detection works on the sequence of whole expressions produced by
    /// the passes; the passes themselves do not touch the seen-set.
    fn apply_global_fixpoint(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> TLExpr {
        let mut current = expr.clone();
        loop {
            stats.steps += 1;
            if !self.check_constraints(&current, stats) {
                break;
            }
            let next = self.apply_bottom_up(&current, stats);
            if next == current {
                break;
            }
            current = next;
        }
        current
    }

    /// Runs an innermost pass on both operands of a binary node.
    ///
    /// Returns the new operand pair when at least one operand changed (the
    /// unchanged one is cloned), or `None` when neither changed.
    fn rewrite_pair(
        &mut self,
        left: &TLExpr,
        right: &TLExpr,
        stats: &mut RewriteStats,
    ) -> Option<(TLExpr, TLExpr)> {
        let new_left = self.rewrite_innermost(left, stats);
        let new_right = self.rewrite_innermost(right, stats);
        if new_left.is_none() && new_right.is_none() {
            return None;
        }
        Some((
            new_left.unwrap_or_else(|| left.clone()),
            new_right.unwrap_or_else(|| right.clone()),
        ))
    }

    /// Runs an innermost pass inside the children of `expr`.
    ///
    /// Handles `And`, `Or`, `Imply` and `Not`; every other variant is treated
    /// as having no rewritable children. Returns `None` when no child changed.
    fn rewrite_children(&mut self, expr: &TLExpr, stats: &mut RewriteStats) -> Option<TLExpr> {
        match expr {
            TLExpr::And(l, r) => self
                .rewrite_pair(l, r, stats)
                .map(|(left, right)| TLExpr::and(left, right)),
            TLExpr::Or(l, r) => self
                .rewrite_pair(l, r, stats)
                .map(|(left, right)| TLExpr::or(left, right)),
            TLExpr::Imply(l, r) => self
                .rewrite_pair(l, r, stats)
                .map(|(left, right)| TLExpr::imply(left, right)),
            TLExpr::Not(e) => self.rewrite_innermost(e, stats).map(TLExpr::negate),
            _ => None,
        }
    }

    /// Rebuilds `expr` with every child replaced by `recurse(child)`.
    ///
    /// Children are processed left to right. Handles `And`, `Or`, `Imply` and
    /// `Not`; every other variant is returned as an unchanged clone.
    fn map_children(
        &mut self,
        expr: &TLExpr,
        stats: &mut RewriteStats,
        recurse: fn(&mut Self, &TLExpr, &mut RewriteStats) -> TLExpr,
    ) -> TLExpr {
        match expr {
            TLExpr::And(l, r) => {
                let left = recurse(self, l, stats);
                let right = recurse(self, r, stats);
                TLExpr::and(left, right)
            }
            TLExpr::Or(l, r) => {
                let left = recurse(self, l, stats);
                let right = recurse(self, r, stats);
                TLExpr::or(left, right)
            }
            TLExpr::Imply(l, r) => {
                let left = recurse(self, l, stats);
                let right = recurse(self, r, stats);
                TLExpr::imply(left, right)
            }
            TLExpr::Not(e) => {
                let inner = recurse(self, e, stats);
                TLExpr::negate(inner)
            }
            _ => expr.clone(),
        }
    }

    /// Returns `(rule name, application count)` for every registered rule,
    /// in priority order.
    pub fn rule_statistics(&self) -> Vec<(&str, usize)> {
        self.rules
            .iter()
            .map(|r| (r.name.as_str(), r.application_count()))
            .collect()
    }

    /// Resets the application counter of every registered rule.
    pub fn reset_statistics(&mut self) {
        for rule in &mut self.rules {
            rule.reset_counter();
        }
    }
}
impl Default for AdvancedRewriteSystem {
fn default() -> Self {
Self::new()
}
}
// Unit tests for conditional rules and the rewrite system.
#[cfg(test)]
mod tests {
use super::*;
use crate::{TLExpr, Term};
// A single rule removes a double negation and counts the application.
#[test]
fn test_conditional_rule_basic() {
let mut rule = ConditionalRule::new(
"remove_double_neg",
|expr| {
if let TLExpr::Not(inner) = expr {
if let TLExpr::Not(inner_inner) = &**inner {
return Some((**inner_inner).clone());
}
}
None
},
|_| true,
);
let expr = TLExpr::negate(TLExpr::negate(TLExpr::pred("P", vec![Term::var("x")])));
let result = rule.apply(&expr).expect("unwrap");
assert!(matches!(result, TLExpr::Pred { .. }));
assert_eq!(rule.application_count(), 1);
}
// Rules are added in low, high, critical order and must be stored highest
// priority first.
#[test]
fn test_priority_ordering() {
let mut system = AdvancedRewriteSystem::new();
system = system.add_rule(
ConditionalRule::new("low", |_| None, |_| true).with_priority(RulePriority::Low),
);
system = system.add_rule(
ConditionalRule::new("high", |_| None, |_| true).with_priority(RulePriority::High),
);
system = system.add_rule(
ConditionalRule::new("critical", |_| None, |_| true)
.with_priority(RulePriority::Critical),
);
assert_eq!(system.rules[0].priority, RulePriority::Critical);
assert_eq!(system.rules[1].priority, RulePriority::High);
assert_eq!(system.rules[2].priority, RulePriority::Low);
}
// Four nested negations must collapse to the bare predicate with exactly
// two applications of the double-negation rule.
#[test]
fn test_bottom_up_strategy() {
let mut system = AdvancedRewriteSystem::with_config(RewriteConfig {
strategy: RewriteStrategy::BottomUp,
max_steps: 100,
..Default::default()
});
system = system.add_rule(ConditionalRule::new(
"double_neg",
|expr| {
if let TLExpr::Not(inner) = expr {
if let TLExpr::Not(inner_inner) = &**inner {
return Some((**inner_inner).clone());
}
}
None
},
|_| true,
));
let expr = TLExpr::negate(TLExpr::negate(TLExpr::negate(TLExpr::negate(
TLExpr::pred("P", vec![Term::var("x")]),
))));
let (result, stats) = system.apply(&expr);
assert!(matches!(result, TLExpr::Pred { .. }));
assert_eq!(stats.rule_applications, 2); }
// The two rules undo each other, so the run can never settle; the test
// accepts either a reported cycle or an exhausted step budget.
#[test]
fn test_cycle_detection() {
let mut system = AdvancedRewriteSystem::with_config(RewriteConfig {
strategy: RewriteStrategy::GlobalFixpoint,
detect_cycles: true,
max_steps: 1000,
..Default::default()
});
system = system.add_rule(ConditionalRule::new(
"add_double_neg",
|expr| {
if let TLExpr::Pred { .. } = expr {
return Some(TLExpr::negate(TLExpr::negate(expr.clone())));
}
None
},
|_| true,
));
system = system.add_rule(ConditionalRule::new(
"remove_double_neg",
|expr| {
if let TLExpr::Not(inner) = expr {
if let TLExpr::Not(inner_inner) = &**inner {
return Some((**inner_inner).clone());
}
}
None
},
|_| true,
));
let expr = TLExpr::pred("P", vec![Term::var("x")]);
let (_result, stats) = system.apply(&expr);
assert!(stats.cycle_detected || stats.steps >= 1000);
}
// The rule turns every predicate into a conjunction of two copies, so the
// expression keeps growing; the run must stop through the size limit or
// the step budget.
#[test]
fn test_size_limit() {
let mut system = AdvancedRewriteSystem::with_config(RewriteConfig {
strategy: RewriteStrategy::Innermost, max_expr_size: Some(10),
detect_cycles: false, ..Default::default()
});
system = system.add_rule(ConditionalRule::new(
"duplicate",
|expr| {
if let TLExpr::Pred { .. } = expr {
return Some(TLExpr::and(expr.clone(), expr.clone()));
}
None
},
|_| true,
));
let expr = TLExpr::pred("P", vec![Term::var("x")]);
let (_result, stats) = system.apply(&expr);
assert!(stats.size_limit_exceeded || stats.steps >= system.config.max_steps);
}
// With the default configuration, removing one double negation must be
// reported as a successful, size-reducing run with a single application.
#[test]
fn test_rewrite_stats() {
let mut system = AdvancedRewriteSystem::new();
system = system.add_rule(ConditionalRule::new(
"test_rule",
|expr| {
if let TLExpr::Not(inner) = expr {
if let TLExpr::Not(inner_inner) = &**inner {
return Some((**inner_inner).clone());
}
}
None
},
|_| true,
));
let expr = TLExpr::negate(TLExpr::negate(TLExpr::pred("P", vec![Term::var("x")])));
let (_result, stats) = system.apply(&expr);
assert!(stats.is_successful());
assert!(stats.reduction_percentage() > 0.0);
assert_eq!(stats.rule_applications, 1);
}
}