use tensorlogic_ir::TLExpr;
/// Relative per-operation costs used when estimating how expensive an
/// expression is to evaluate.
///
/// Each field is the cost charged for a single occurrence of the matching
/// operation kind; see `ExpressionComplexity::total_cost_with_weights` for how
/// the weights are combined with operation counts.
#[derive(Debug, Clone)]
pub struct CostWeights {
    /// Cost of one addition or one subtraction.
    pub add_sub: f64,
    /// Cost of one multiplication.
    pub mul: f64,
    /// Cost of one division.
    pub div: f64,
    /// Cost of one power operation.
    pub pow: f64,
    /// Cost of one exponential.
    pub exp: f64,
    /// Cost of one logarithm.
    pub log: f64,
    /// Cost of one square root.
    pub sqrt: f64,
    /// Cost of one absolute value.
    pub abs: f64,
    /// Cost of one comparison (also charged for each min/max operation).
    pub cmp: f64,
    /// Cost of one reduction (charged once per quantifier).
    pub reduction: f64,
}

impl Default for CostWeights {
    /// Baseline weights: unit cost for add/sub, abs and comparisons, with
    /// progressively higher costs for mul, div/sqrt, pow and exp/log.
    fn default() -> Self {
        CostWeights {
            add_sub: 1.0,
            mul: 2.0,
            div: 4.0,
            pow: 8.0,
            exp: 10.0,
            log: 10.0,
            sqrt: 4.0,
            abs: 1.0,
            cmp: 1.0,
            reduction: 5.0,
        }
    }
}

impl CostWeights {
    /// Weight preset for GPU-style execution.
    ///
    /// Compared with [`CostWeights::default`], element-wise operations are
    /// cheaper while reductions are twice as expensive (10.0 instead of 5.0).
    pub fn gpu_optimized() -> Self {
        Self {
            mul: 1.0,
            div: 2.0,
            pow: 4.0,
            exp: 3.0,
            log: 3.0,
            sqrt: 2.0,
            reduction: 10.0,
            // add_sub, abs and cmp keep the default unit cost of 1.0.
            ..Self::default()
        }
    }

    /// Weight preset for SIMD-style execution.
    ///
    /// Compared with [`CostWeights::default`], every non-unit weight is
    /// lowered, including reductions (3.0 instead of 5.0).
    pub fn simd_optimized() -> Self {
        Self {
            mul: 1.0,
            div: 3.0,
            pow: 6.0,
            exp: 8.0,
            log: 8.0,
            sqrt: 3.0,
            reduction: 3.0,
            // add_sub, abs and cmp keep the default unit cost of 1.0.
            ..Self::default()
        }
    }
}
/// Operation counts and structural metrics gathered from an expression tree.
///
/// All counters start at zero (`Default`) and are filled in by
/// `analyze_complexity`.
#[derive(Debug, Clone, Default)]
pub struct ExpressionComplexity {
    /// Number of addition nodes.
    pub additions: usize,
    /// Number of subtraction nodes.
    pub subtractions: usize,
    /// Number of multiplication nodes.
    pub multiplications: usize,
    /// Number of division nodes.
    pub divisions: usize,
    /// Number of power nodes.
    pub powers: usize,
    /// Number of exponential nodes.
    pub exponentials: usize,
    /// Number of logarithm nodes.
    pub logarithms: usize,
    /// Number of square-root nodes.
    pub square_roots: usize,
    /// Number of absolute-value nodes.
    pub absolute_values: usize,
    /// Number of arithmetic negations (included in the arithmetic total but
    /// not charged by the cost model).
    pub negations: usize,
    /// Number of comparison nodes (eq, lt, lte, gt, gte).
    pub comparisons: usize,
    /// Number of logical conjunctions.
    pub logical_ands: usize,
    /// Number of logical disjunctions (an implication also adds one).
    pub logical_ors: usize,
    /// Number of logical negations (an implication also adds one).
    pub logical_nots: usize,
    /// Number of existential quantifiers and other nodes the analyzer counts
    /// as existential-style reductions.
    pub existential_quantifiers: usize,
    /// Number of universal quantifiers and other nodes the analyzer counts as
    /// universal-style reductions.
    pub universal_quantifiers: usize,
    /// Number of if-then-else nodes.
    pub conditionals: usize,
    /// Number of predicate occurrences (duplicates included).
    pub predicates: usize,
    /// Number of constant leaves.
    pub constants: usize,
    /// Number of variable leaves.
    pub variables: usize,
    /// Number of binary `min` nodes.
    pub min_operations: usize,
    /// Number of binary `max` nodes.
    pub max_operations: usize,
    /// Depth of the deepest visited node, with the root at depth 0.
    pub max_depth: usize,
    /// Number of distinct variable names encountered.
    pub unique_variables: usize,
    /// Number of distinct predicate names encountered.
    pub unique_predicates: usize,
}

impl ExpressionComplexity {
    /// Total number of arithmetic operations (including negations).
    pub fn arithmetic_operations(&self) -> usize {
        [
            self.additions,
            self.subtractions,
            self.multiplications,
            self.divisions,
            self.powers,
            self.exponentials,
            self.logarithms,
            self.square_roots,
            self.absolute_values,
            self.negations,
        ]
        .iter()
        .sum()
    }

    /// Total number of logical connectives (and, or, not).
    pub fn logical_operations(&self) -> usize {
        [self.logical_ands, self.logical_ors, self.logical_nots]
            .iter()
            .sum()
    }

    /// Total number of operations: arithmetic, logical, comparisons,
    /// conditionals and min/max. Quantifiers and leaves are not included.
    pub fn total_operations(&self) -> usize {
        let arithmetic = self.arithmetic_operations();
        let logical = self.logical_operations();
        arithmetic
            + logical
            + self.comparisons
            + self.conditionals
            + self.min_operations
            + self.max_operations
    }

    /// Estimated cost using [`CostWeights::default`].
    pub fn total_cost(&self) -> f64 {
        let baseline = CostWeights::default();
        self.total_cost_with_weights(&baseline)
    }

    /// Estimated cost as the weighted sum of operation counts.
    ///
    /// Min/max operations are charged at the comparison weight and every
    /// quantifier at the reduction weight. Negations, logical connectives,
    /// conditionals and leaves contribute nothing.
    pub fn total_cost_with_weights(&self, weights: &CostWeights) -> f64 {
        let quantifiers = self.existential_quantifiers + self.universal_quantifiers;
        // (occurrence count, unit cost) pairs; accumulated front to back.
        let terms = [
            (self.additions + self.subtractions, weights.add_sub),
            (self.multiplications, weights.mul),
            (self.divisions, weights.div),
            (self.powers, weights.pow),
            (self.exponentials, weights.exp),
            (self.logarithms, weights.log),
            (self.square_roots, weights.sqrt),
            (self.absolute_values, weights.abs),
            (self.comparisons, weights.cmp),
            (quantifiers, weights.reduction),
            (self.min_operations, weights.cmp),
            (self.max_operations, weights.cmp),
        ];
        terms
            .iter()
            .fold(0.0, |total, &(count, unit_cost)| total + count as f64 * unit_cost)
    }

    /// Number of leaves: constants, variables and predicate occurrences.
    pub fn leaf_count(&self) -> usize {
        self.constants + self.variables + self.predicates
    }

    /// Heuristic: `true` when the expression has more than 5 operations and
    /// is nested deeper than 3 levels.
    pub fn cse_potential(&self) -> bool {
        let enough_operations = self.total_operations() > 5;
        enough_operations && self.max_depth > 3
    }

    /// Heuristic: `true` when the expression contains any power, more than
    /// two divisions, or any exponential/logarithm.
    pub fn strength_reduction_potential(&self) -> bool {
        self.powers > 0
            || self.divisions > 2
            || self.exponentials + self.logarithms > 0
    }

    /// Buckets [`Self::total_operations`] into a named complexity level.
    pub fn complexity_level(&self) -> &'static str {
        match self.total_operations() {
            0..=3 => "trivial",
            4..=10 => "simple",
            11..=30 => "moderate",
            31..=100 => "complex",
            _ => "very_complex",
        }
    }
}

impl std::fmt::Display for ExpressionComplexity {
    /// Writes a multi-line, human-readable summary of the analysis.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!(
                "Expression Complexity Analysis:\n",
                " Total operations: {}\n",
                " Arithmetic operations: {}\n",
                " Logical operations: {}\n",
                " Maximum depth: {}\n",
                " Estimated cost: {:.2}\n",
                " Complexity level: {}\n",
            ),
            self.total_operations(),
            self.arithmetic_operations(),
            self.logical_operations(),
            self.max_depth,
            self.total_cost(),
            self.complexity_level(),
        )
    }
}
pub fn analyze_complexity(expr: &TLExpr) -> ExpressionComplexity {
let mut complexity = ExpressionComplexity::default();
let mut var_names = std::collections::HashSet::new();
let mut pred_names = std::collections::HashSet::new();
analyze_complexity_impl(expr, &mut complexity, 0, &mut var_names, &mut pred_names);
complexity.unique_variables = var_names.len();
complexity.unique_predicates = pred_names.len();
complexity
}
/// Recursive worker behind `analyze_complexity`.
///
/// Visits `expr` at nesting level `depth`, bumps the matching counters in
/// `complexity`, records variable and predicate names in `var_names` /
/// `pred_names`, and recurses into child expressions at `depth + 1`.
/// Variants without an explicit arm are ignored and their children, if any,
/// are not traversed.
fn analyze_complexity_impl(
    expr: &TLExpr,
    complexity: &mut ExpressionComplexity,
    depth: usize,
    var_names: &mut std::collections::HashSet<String>,
    pred_names: &mut std::collections::HashSet<String>,
) {
    // Visit each listed child one level deeper, sharing every accumulator.
    macro_rules! descend {
        ($($child:expr),+ $(,)?) => {{
            $(
                analyze_complexity_impl($child, complexity, depth + 1, var_names, pred_names);
            )+
        }};
    }

    if depth > complexity.max_depth {
        complexity.max_depth = depth;
    }

    match expr {
        // Binary arithmetic.
        TLExpr::Add(lhs, rhs) => {
            complexity.additions += 1;
            descend!(lhs, rhs);
        }
        TLExpr::Sub(lhs, rhs) => {
            complexity.subtractions += 1;
            descend!(lhs, rhs);
        }
        TLExpr::Mul(lhs, rhs) => {
            complexity.multiplications += 1;
            descend!(lhs, rhs);
        }
        TLExpr::Div(lhs, rhs) => {
            complexity.divisions += 1;
            descend!(lhs, rhs);
        }
        TLExpr::Pow(base, exp) => {
            complexity.powers += 1;
            descend!(base, exp);
        }

        // Unary arithmetic.
        TLExpr::Exp(inner) => {
            complexity.exponentials += 1;
            descend!(inner);
        }
        TLExpr::Log(inner) => {
            complexity.logarithms += 1;
            descend!(inner);
        }
        TLExpr::Sqrt(inner) => {
            complexity.square_roots += 1;
            descend!(inner);
        }
        TLExpr::Abs(inner) => {
            complexity.absolute_values += 1;
            descend!(inner);
        }

        // Logical connectives.
        TLExpr::And(lhs, rhs) => {
            complexity.logical_ands += 1;
            descend!(lhs, rhs);
        }
        TLExpr::Or(lhs, rhs) => {
            complexity.logical_ors += 1;
            descend!(lhs, rhs);
        }
        TLExpr::Not(inner) => {
            complexity.logical_nots += 1;
            descend!(inner);
        }
        // An implication is tallied as one negation plus one disjunction.
        TLExpr::Imply(lhs, rhs) => {
            complexity.logical_nots += 1;
            complexity.logical_ors += 1;
            descend!(lhs, rhs);
        }

        // Comparisons all share a single counter.
        TLExpr::Eq(lhs, rhs)
        | TLExpr::Lt(lhs, rhs)
        | TLExpr::Lte(lhs, rhs)
        | TLExpr::Gt(lhs, rhs)
        | TLExpr::Gte(lhs, rhs) => {
            complexity.comparisons += 1;
            descend!(lhs, rhs);
        }

        TLExpr::Min(lhs, rhs) => {
            complexity.min_operations += 1;
            descend!(lhs, rhs);
        }
        TLExpr::Max(lhs, rhs) => {
            complexity.max_operations += 1;
            descend!(lhs, rhs);
        }

        // Binders: the bound name is recorded as a variable.
        TLExpr::Exists { var, body, .. } => {
            complexity.existential_quantifiers += 1;
            var_names.insert(var.clone());
            descend!(body);
        }
        TLExpr::ForAll { var, body, .. } => {
            complexity.universal_quantifiers += 1;
            var_names.insert(var.clone());
            descend!(body);
        }
        TLExpr::Let {
            var, value, body, ..
        } => {
            var_names.insert(var.clone());
            descend!(value, body);
        }

        TLExpr::IfThenElse {
            condition,
            then_branch,
            else_branch,
        } => {
            complexity.conditionals += 1;
            descend!(condition, then_branch, else_branch);
        }

        // Leaves.
        TLExpr::Pred { name, args } => {
            complexity.predicates += 1;
            pred_names.insert(name.clone());
            // Only variable terms contribute names; other term kinds are skipped.
            var_names.extend(args.iter().filter_map(|term| match term {
                tensorlogic_ir::Term::Var(v) => Some(v.clone()),
                _ => None,
            }));
        }
        TLExpr::Constant(_) => {
            complexity.constants += 1;
        }

        // Both of these operators are tallied on the universal counter.
        TLExpr::Box(inner) | TLExpr::Diamond(inner) => {
            complexity.universal_quantifiers += 1;
            descend!(inner);
        }
        // These operators are tallied on the existential counter.
        TLExpr::Next(inner) | TLExpr::Eventually(inner) | TLExpr::Always(inner) => {
            complexity.existential_quantifiers += 1;
            descend!(inner);
        }
        TLExpr::Until { before, after } => {
            complexity.existential_quantifiers += 1;
            descend!(before, after);
        }

        // Unary nodes that are traversed without touching any counter.
        TLExpr::Score(inner)
        | TLExpr::Floor(inner)
        | TLExpr::Ceil(inner)
        | TLExpr::Round(inner)
        | TLExpr::Sin(inner)
        | TLExpr::Cos(inner)
        | TLExpr::Tan(inner)
        | TLExpr::FuzzyNot { expr: inner, .. } => {
            descend!(inner);
        }

        // Binary nodes that are traversed without touching any counter.
        TLExpr::Mod(lhs, rhs)
        | TLExpr::TNorm {
            left: lhs,
            right: rhs,
            ..
        }
        | TLExpr::TCoNorm {
            left: lhs,
            right: rhs,
            ..
        }
        | TLExpr::FuzzyImplication {
            premise: lhs,
            conclusion: rhs,
            ..
        }
        | TLExpr::Release {
            released: lhs,
            releaser: rhs,
        }
        | TLExpr::WeakUntil {
            before: lhs,
            after: rhs,
        }
        | TLExpr::StrongRelease {
            released: lhs,
            releaser: rhs,
        } => {
            descend!(lhs, rhs);
        }

        // Single-body nodes tallied on the existential counter.
        TLExpr::Aggregate { body, .. }
        | TLExpr::SoftExists { body, .. }
        | TLExpr::SoftForAll { body, .. }
        | TLExpr::WeightedRule { rule: body, .. } => {
            complexity.existential_quantifiers += 1;
            descend!(body);
        }

        TLExpr::ProbabilisticChoice { alternatives } => {
            for (_, alternative) in alternatives {
                descend!(alternative);
            }
        }

        // Any remaining variant contributes nothing.
        _ => {}
    }
}
/// Orders two expressions by their estimated cost under the default weights.
///
/// Returns `Less` when `expr1` is cheaper than `expr2`, `Greater` when it is
/// more expensive, and `Equal` when the costs match or cannot be ordered.
pub fn compare_complexity(expr1: &TLExpr, expr2: &TLExpr) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let first_cost = analyze_complexity(expr1).total_cost();
    let second_cost = analyze_complexity(expr2).total_cost();

    match first_cost.partial_cmp(&second_cost) {
        Some(ordering) => ordering,
        // Incomparable floats are treated as a tie.
        None => Ordering::Equal,
    }
}
// Unit tests for the complexity analyzer: counters, depth, cost, heuristics,
// ordering, display output and name de-duplication.
#[cfg(test)]
mod tests {
use super::*;
use tensorlogic_ir::Term;
// A single addition of two constants: one operation, two leaves.
#[test]
fn test_simple_addition() {
let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
let complexity = analyze_complexity(&expr);
assert_eq!(complexity.additions, 1);
assert_eq!(complexity.constants, 2);
assert_eq!(complexity.total_operations(), 1);
}
// (x + 1) * (x - 2): each operator is counted once and the predicate `x`
// is counted per occurrence.
#[test]
fn test_nested_operations() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let expr = TLExpr::mul(
TLExpr::add(x.clone(), TLExpr::Constant(1.0)),
TLExpr::sub(x, TLExpr::Constant(2.0)),
);
let complexity = analyze_complexity(&expr);
assert_eq!(complexity.additions, 1);
assert_eq!(complexity.subtractions, 1);
assert_eq!(complexity.multiplications, 1);
assert_eq!(complexity.predicates, 2);
assert_eq!(complexity.constants, 2);
}
// A conjunction with a negated operand is counted on the logical counters.
#[test]
fn test_logical_operations() {
let a = TLExpr::pred("a", vec![Term::var("x")]);
let b = TLExpr::pred("b", vec![Term::var("y")]);
let expr = TLExpr::and(a, TLExpr::negate(b));
let complexity = analyze_complexity(&expr);
assert_eq!(complexity.logical_ands, 1);
assert_eq!(complexity.logical_nots, 1);
assert_eq!(complexity.predicates, 2);
}
// Nested exists/forall: one of each quantifier; the bound names coincide
// with the predicate arguments, so only two unique variables are seen.
#[test]
fn test_quantifiers() {
let pred = TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]);
let expr = TLExpr::exists("x", "D1", TLExpr::forall("y", "D2", pred));
let complexity = analyze_complexity(&expr);
assert_eq!(complexity.existential_quantifiers, 1);
assert_eq!(complexity.universal_quantifiers, 1);
assert_eq!(complexity.predicates, 1);
assert_eq!(complexity.unique_variables, 2);
}
// (x * 2) + 3: the root sits at depth 0, the product at depth 1 and its
// operands at depth 2.
#[test]
fn test_depth_calculation() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let expr = TLExpr::add(TLExpr::mul(x, TLExpr::Constant(2.0)), TLExpr::Constant(3.0));
let complexity = analyze_complexity(&expr);
assert_eq!(complexity.max_depth, 2);
}
// Default weights: one addition (1.0) plus one multiplication (2.0).
#[test]
fn test_cost_calculation() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let expr = TLExpr::add(TLExpr::mul(x, TLExpr::Constant(2.0)), TLExpr::Constant(3.0));
let complexity = analyze_complexity(&expr);
let cost = complexity.total_cost();
assert!(cost > 0.0);
assert_eq!(cost, 3.0);
}
// The GPU preset weighs reductions above multiplications.
#[test]
fn test_gpu_weights() {
let weights = CostWeights::gpu_optimized();
assert!(weights.reduction > weights.mul);
}
// One operation is "trivial"; a chain of 20 additions yields 20 operations.
#[test]
fn test_complexity_level() {
let simple = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
let complex = {
let mut expr = TLExpr::pred("x", vec![Term::var("i")]);
for _ in 0..20 {
expr = TLExpr::add(expr, TLExpr::Constant(1.0));
}
expr
};
let simple_c = analyze_complexity(&simple);
let complex_c = analyze_complexity(&complex);
assert_eq!(simple_c.complexity_level(), "trivial");
assert!(
complex_c.complexity_level() == "moderate" || complex_c.complexity_level() == "complex"
);
}
// CSE potential needs more than 5 operations and a depth above 3; the
// second expression has 7 operations and reaches depth 4.
#[test]
fn test_cse_potential() {
let simple = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
let simple_c = analyze_complexity(&simple);
assert!(!simple_c.cse_potential());
let x = TLExpr::pred("x", vec![Term::var("i")]);
let complex = TLExpr::mul(
TLExpr::exp(TLExpr::add(
TLExpr::mul(x.clone(), TLExpr::Constant(2.0)),
TLExpr::Constant(1.0),
)),
TLExpr::log(TLExpr::sub(
TLExpr::div(x, TLExpr::Constant(3.0)),
TLExpr::Constant(4.0),
)),
);
let complex_c = analyze_complexity(&complex);
assert!(complex_c.cse_potential());
}
// A power triggers the strength-reduction heuristic; a lone addition does not.
#[test]
fn test_strength_reduction_potential() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let expr = TLExpr::pow(x.clone(), TLExpr::Constant(2.0));
let c = analyze_complexity(&expr);
assert!(c.strength_reduction_potential());
let simple = TLExpr::add(x, TLExpr::Constant(1.0));
let simple_c = analyze_complexity(&simple);
assert!(!simple_c.strength_reduction_potential());
}
// The cheaper expression (cost 1.0) orders before the costlier one (cost 4.0).
#[test]
fn test_compare_complexity() {
let simple = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
let x = TLExpr::pred("x", vec![Term::var("i")]);
let complex = TLExpr::mul(
TLExpr::add(x.clone(), TLExpr::Constant(1.0)),
TLExpr::sub(x, TLExpr::Constant(2.0)),
);
let ordering = compare_complexity(&simple, &complex);
assert_eq!(ordering, std::cmp::Ordering::Less);
}
// The Display output carries the header and the total-operations line.
#[test]
fn test_display() {
let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
let complexity = analyze_complexity(&expr);
let display = format!("{}", complexity);
assert!(display.contains("Expression Complexity Analysis"));
assert!(display.contains("Total operations:"));
}
// Arithmetic and logical totals are tracked independently of each other.
#[test]
fn test_arithmetic_vs_logical() {
let arith = TLExpr::mul(
TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0)),
TLExpr::Constant(3.0),
);
let logic = TLExpr::and(
TLExpr::or(TLExpr::pred("a", vec![]), TLExpr::pred("b", vec![])),
TLExpr::pred("c", vec![]),
);
let arith_c = analyze_complexity(&arith);
let logic_c = analyze_complexity(&logic);
assert!(arith_c.arithmetic_operations() > 0);
assert_eq!(arith_c.logical_operations(), 0);
assert_eq!(logic_c.arithmetic_operations(), 0);
assert!(logic_c.logical_operations() > 0);
}
// `x` and `y` are bound by quantifiers, `z` appears only as a predicate
// argument; all three names are counted exactly once.
#[test]
fn test_unique_variables() {
let expr = TLExpr::exists(
"x",
"D",
TLExpr::forall(
"y",
"D",
TLExpr::pred("p", vec![Term::var("x"), Term::var("y"), Term::var("z")]),
),
);
let c = analyze_complexity(&expr);
assert_eq!(c.unique_variables, 3); }
// `foo` occurs twice: `unique_predicates` de-duplicates by name while
// `predicates` counts every occurrence.
#[test]
fn test_unique_predicates() {
let expr = TLExpr::and(
TLExpr::pred("foo", vec![Term::var("x")]),
TLExpr::or(
TLExpr::pred("bar", vec![Term::var("y")]),
TLExpr::pred("foo", vec![Term::var("z")]), ),
);
let c = analyze_complexity(&expr);
assert_eq!(c.unique_predicates, 2); assert_eq!(c.predicates, 3); }
}