pub mod config;
pub mod helpers;
pub mod substitute;
pub mod traversal;
pub use config::{InlineConfig, InlineStats};
pub use traversal::LetInliner;
#[cfg(test)]
mod tests {
    //! Unit tests for the let-inlining pass: `InlineStats` bookkeeping,
    //! `InlineConfig` defaults, the `LetInliner` helpers
    //! (`count_free_occurrences`, `substitute`) and end-to-end `LetInliner::run`.

    use super::*;
    use tensorlogic_ir::TLExpr;

    /// Builds the nullary predicate `name()`.
    ///
    /// These tests use it to stand for a reference to the let-bound variable
    /// `name`: `count_free_occurrences("x", &var("x"))` is expected to be 1 and
    /// `substitute("x", …)` is expected to replace it.
    fn var(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![])
    }

    /// Builds a left-nested sum `((1 + 1) + 1) + …` that contains exactly
    /// `depth` `add` nodes; `depth == 0` yields the bare constant `1.0`.
    ///
    /// Used to create a binding value intended to exceed a small
    /// `max_inline_depth` (the exact depth metric is defined by the inliner).
    fn deep_add(depth: usize) -> TLExpr {
        if depth == 0 {
            TLExpr::Constant(1.0)
        } else {
            TLExpr::add(deep_add(depth - 1), TLExpr::Constant(1.0))
        }
    }

    /// A default `InlineStats` has every counter at zero.
    #[test]
    fn test_inline_stats_default() {
        let stats = InlineStats::default();
        assert_eq!(stats.single_use_inlines, 0);
        assert_eq!(stats.constant_inlines, 0);
        assert_eq!(stats.variable_inlines, 0);
        assert_eq!(stats.total(), 0);
        assert_eq!(stats.nodes_before, 0);
        assert_eq!(stats.nodes_after, 0);
        assert_eq!(stats.passes, 0);
    }

    /// `summary()` mentions the pass count, node ratio and each inline kind.
    #[test]
    fn test_inline_stats_summary_nonempty() {
        let stats = InlineStats {
            single_use_inlines: 2,
            constant_inlines: 3,
            variable_inlines: 1,
            nodes_before: 20,
            nodes_after: 14,
            passes: 2,
        };
        let summary = stats.summary();
        assert!(summary.contains("2 passes"));
        assert!(summary.contains("14/20"));
        assert!(summary.contains("2 single-use"));
        assert!(summary.contains("3 constant"));
        assert!(summary.contains("1 variable-alias"));
    }

    /// `total()` sums the three per-kind inline counters.
    #[test]
    fn test_total_inlines() {
        let stats = InlineStats {
            single_use_inlines: 4,
            constant_inlines: 5,
            variable_inlines: 3,
            ..Default::default()
        };
        assert_eq!(stats.total(), 12);
    }

    /// Going from 100 to 60 nodes is a 40% reduction.
    #[test]
    fn test_reduction_pct() {
        let stats = InlineStats {
            nodes_before: 100,
            nodes_after: 60,
            ..Default::default()
        };
        let pct = stats.reduction_pct();
        assert!((pct - 40.0).abs() < 1e-9, "expected ~40%, got {pct}");
    }

    /// Pins the default configuration: all inline kinds on, 20 passes, depth 10.
    #[test]
    fn test_inline_config_default() {
        let cfg = InlineConfig::default();
        assert!(cfg.inline_single_use);
        assert!(cfg.inline_constants);
        assert!(cfg.inline_vars);
        assert_eq!(cfg.max_passes, 20);
        assert_eq!(cfg.max_inline_depth, 10);
    }

    /// `LetInliner::with_default` uses a config with single-use inlining on.
    #[test]
    fn test_inliner_with_default() {
        let inliner = LetInliner::with_default();
        assert!(inliner.config.inline_single_use);
    }

    /// A variable that never appears has zero free occurrences.
    #[test]
    fn test_count_free_occurrences_zero() {
        let expr = var("p");
        assert_eq!(LetInliner::count_free_occurrences("z", &expr), 0);
    }

    /// A lone reference counts as one free occurrence.
    #[test]
    fn test_count_free_occurrences_one() {
        let expr = var("x");
        assert_eq!(LetInliner::count_free_occurrences("x", &expr), 1);
    }

    /// Every reference is counted, not just the first.
    #[test]
    fn test_count_free_occurrences_multi() {
        let expr = TLExpr::add(var("x"), var("x"));
        assert_eq!(LetInliner::count_free_occurrences("x", &expr), 2);
    }

    /// Substituting into a bare reference yields the replacement itself.
    #[test]
    fn test_substitute_simple() {
        let body = var("x");
        let result = LetInliner::substitute("x", &TLExpr::Constant(7.0), body);
        assert_eq!(result, TLExpr::Constant(7.0));
    }

    /// A quantifier that rebinds `x` shields its body from substitution.
    #[test]
    fn test_substitute_shadowed() {
        let inner = TLExpr::exists("x", "D", var("x"));
        let result = LetInliner::substitute("x", &TLExpr::Constant(7.0), inner.clone());
        assert_eq!(result, inner);
    }

    /// A constant binding is inlined at every use and counted once.
    #[test]
    fn test_inline_constant_binding() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            TLExpr::Constant(5.0),
            TLExpr::add(var("x"), var("x")),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.constant_inlines, 1);
        assert_eq!(
            result,
            TLExpr::add(TLExpr::Constant(5.0), TLExpr::Constant(5.0))
        );
    }

    /// A binding that merely aliases another variable is inlined.
    #[test]
    fn test_inline_variable_binding() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            var("y"),
            TLExpr::add(var("x"), TLExpr::Constant(1.0)),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.variable_inlines, 1);
        assert_eq!(result, TLExpr::add(var("y"), TLExpr::Constant(1.0)));
    }

    /// A compound binding used exactly once is inlined as a single-use inline.
    #[test]
    fn test_inline_single_use() {
        let inliner = LetInliner::with_default();
        let binding_val = TLExpr::add(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        // Clone needed: `binding_val` is reused to build the expected result.
        let expr = TLExpr::let_binding("x", binding_val.clone(), TLExpr::sqrt(var("x")));
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.single_use_inlines, 1);
        assert_eq!(result, TLExpr::sqrt(binding_val));
    }

    /// With only single-use inlining enabled, a compound binding referenced
    /// twice must not be inlined. Despite the name, this uses a non-default
    /// config that disables constant and variable inlining.
    #[test]
    fn test_no_inline_multi_use_by_default() {
        let cfg = InlineConfig {
            inline_single_use: true,
            inline_constants: false,
            inline_vars: false,
            max_passes: 5,
            max_inline_depth: 10,
        };
        let inliner = LetInliner::new(cfg);
        let binding_val = TLExpr::add(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        // `binding_val` is not needed afterwards, so it is moved rather than cloned.
        let expr = TLExpr::let_binding("x", binding_val, TLExpr::add(var("x"), var("x")));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.single_use_inlines, 0);
        assert_eq!(stats.total(), 0);
    }

    /// A binding value deeper than `max_inline_depth` is left in place even
    /// though it is used once.
    #[test]
    fn test_inline_depth_limit() {
        let cfg = InlineConfig {
            inline_single_use: true,
            inline_constants: true,
            inline_vars: true,
            max_passes: 5,
            max_inline_depth: 3,
        };
        let inliner = LetInliner::new(cfg);
        let deep = deep_add(5);
        let expr = TLExpr::let_binding("x", deep, TLExpr::sqrt(var("x")));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.total(), 0, "deep binding should not be inlined");
    }

    /// The inner `let x = 2.0` shadows the outer `let x = 5.0`, so the result
    /// is the inner value.
    #[test]
    fn test_shadowing_respected() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            TLExpr::Constant(5.0),
            TLExpr::let_binding("x", TLExpr::Constant(2.0), var("x")),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(result, TLExpr::Constant(2.0));
        assert!(stats.constant_inlines >= 2);
    }

    /// Chained bindings (`b` aliases `a`, `a` is a constant) collapse fully.
    #[test]
    fn test_run_fixed_point() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "a",
            TLExpr::Constant(1.0),
            TLExpr::let_binding("b", var("a"), TLExpr::add(var("b"), var("b"))),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(
            result,
            TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(1.0))
        );
        assert!(stats.total() >= 2);
    }

    /// An expression without any `let` is returned unchanged with no inlines.
    #[test]
    fn test_run_preserves_non_let() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::Constant(1.0));
        let (result, stats) = inliner.run(expr.clone());
        assert_eq!(result, expr);
        assert_eq!(stats.total(), 0);
    }

    /// With every inline kind disabled, nothing is inlined.
    #[test]
    fn test_inline_disabled() {
        let cfg = InlineConfig {
            inline_single_use: false,
            inline_constants: false,
            inline_vars: false,
            max_passes: 5,
            max_inline_depth: 10,
        };
        let inliner = LetInliner::new(cfg);
        let expr = TLExpr::let_binding("x", TLExpr::Constant(99.0), var("x"));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.total(), 0, "all flags disabled => no inlining");
    }

    /// Inlining a constant `let` does not grow the tree, and the recorded
    /// node counts yield a positive reduction.
    #[test]
    fn test_reduction_pct_after_inlining() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            TLExpr::Constant(3.0),
            TLExpr::add(var("x"), var("x")),
        );
        let (_, stats) = inliner.run(expr);
        assert!(
            stats.nodes_after <= stats.nodes_before,
            "nodes should not grow: before={}, after={}",
            stats.nodes_before,
            stats.nodes_after
        );
        assert!(stats.reduction_pct() > 0.0, "should have some reduction");
    }
}