use std::collections::HashMap;
use std::fmt;
use std::time::{Duration, Instant};
use tensorlogic_ir::TLExpr;
use crate::const_prop::{ConstPropConfig, ConstantPropagator};
use crate::dead_code::{DceConfig, DeadCodeEliminator};
use crate::inline::{InlineConfig, LetInliner};
use crate::optimize::OptimizationPipeline;
use crate::rewrite::RewriteEngine;
/// Identifies one optimization pass that the compiler pipeline can schedule.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CompilerPassId {
    /// Constant propagation / folding.
    ConstProp,
    /// Dead-code elimination.
    DeadCode,
    /// Let-binding inlining.
    Inline,
    /// Algebraic simplification via the optimization pipeline.
    Algebraic,
    /// Pattern-based rewriting with the built-in rule set.
    Rewrite,
}

impl fmt::Display for CompilerPassId {
    /// Writes the variant name.
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write!` so that width, fill and
    /// alignment flags (e.g. `{:<12}`) are honoured; `write!` silently ignores them.
    /// Output without flags (and therefore `to_string()`) is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            CompilerPassId::ConstProp => "ConstProp",
            CompilerPassId::DeadCode => "DeadCode",
            CompilerPassId::Inline => "Inline",
            CompilerPassId::Algebraic => "Algebraic",
            CompilerPassId::Rewrite => "Rewrite",
        };
        f.pad(name)
    }
}
#[derive(Debug, Clone)]
pub enum CompilerPassOrder {
CanonicalOrder,
InlineFirst,
AggressiveFold,
Custom(Vec<CompilerPassId>),
}
impl CompilerPassOrder {
pub fn to_pass_list(&self) -> Vec<CompilerPassId> {
match self {
CompilerPassOrder::CanonicalOrder => vec![
CompilerPassId::ConstProp,
CompilerPassId::DeadCode,
CompilerPassId::Inline,
CompilerPassId::Algebraic,
CompilerPassId::Rewrite,
],
CompilerPassOrder::InlineFirst => vec![
CompilerPassId::Inline,
CompilerPassId::ConstProp,
CompilerPassId::DeadCode,
CompilerPassId::Algebraic,
CompilerPassId::Rewrite,
],
CompilerPassOrder::AggressiveFold => vec![
CompilerPassId::ConstProp,
CompilerPassId::ConstProp,
CompilerPassId::DeadCode,
CompilerPassId::Inline,
CompilerPassId::ConstProp,
CompilerPassId::DeadCode,
CompilerPassId::Rewrite,
],
CompilerPassOrder::Custom(order) => order.clone(),
}
}
}
/// Configuration for a compiler pipeline: which passes run, in what order, and how often.
#[derive(Debug, Clone)]
pub struct CompilerPipelineConfig {
    /// Run the constant-propagation pass when it appears in the order.
    pub enable_const_prop: bool,
    /// Run the dead-code-elimination pass when it appears in the order.
    pub enable_dead_code: bool,
    /// Run the let-inlining pass when it appears in the order.
    pub enable_inline: bool,
    /// Run the algebraic-simplification pass when it appears in the order.
    pub enable_algebraic: bool,
    /// Run the rewrite-engine pass when it appears in the order.
    pub enable_rewrite: bool,
    /// Sequence of passes executed in each outer iteration.
    pub pass_order: CompilerPassOrder,
    /// Upper bound on outer iterations (the pipeline always runs at least one).
    pub max_outer_iterations: u32,
    /// Settings forwarded to the constant propagator.
    pub const_prop_config: ConstPropConfig,
    /// Settings forwarded to the dead-code eliminator.
    pub dce_config: DceConfig,
    /// Settings forwarded to the let inliner.
    pub inline_config: InlineConfig,
}

impl Default for CompilerPipelineConfig {
    /// Every pass enabled, canonical order, up to three outer iterations, and the
    /// default configuration for each individual pass.
    fn default() -> Self {
        const DEFAULT_MAX_OUTER_ITERATIONS: u32 = 3;
        let enabled = true;
        Self {
            enable_const_prop: enabled,
            enable_dead_code: enabled,
            enable_inline: enabled,
            enable_algebraic: enabled,
            enable_rewrite: enabled,
            pass_order: CompilerPassOrder::CanonicalOrder,
            max_outer_iterations: DEFAULT_MAX_OUTER_ITERATIONS,
            const_prop_config: Default::default(),
            dce_config: Default::default(),
            inline_config: Default::default(),
        }
    }
}
/// Timing and size measurements for a single execution of one pass.
#[derive(Debug, Clone)]
pub struct CompilerPassStats {
    /// The pass these measurements belong to.
    pub pass_id: CompilerPassId,
    /// Wall-clock time measured around the pass invocation.
    pub wall_time: Duration,
    /// Node count of the expression handed to the pass.
    pub nodes_before: u64,
    /// Node count of the expression the pass returned.
    pub nodes_after: u64,
    /// Pass-reported number of transformations (folds, eliminations, rewrites, ...).
    pub reductions: u64,
}

impl CompilerPassStats {
    /// Percentage of nodes removed by this pass.
    ///
    /// Returns `0.0` when `nodes_before` is zero, and is clamped at `0.0` when the
    /// pass grew the expression.
    pub fn reduction_pct(&self) -> f64 {
        match self.nodes_before {
            0 => 0.0,
            n => {
                let before = n as f64;
                let removed = before - self.nodes_after as f64;
                (removed / before * 100.0).max(0.0)
            }
        }
    }

    /// One-line, column-aligned description of this pass execution.
    pub fn summary(&self) -> String {
        let name = self.pass_id.to_string();
        let millis = self.wall_time.as_secs_f64() * 1_000.0;
        let pct = self.reduction_pct();
        format!(
            "{:<12} {:>8.3}ms nodes: {:>6} → {:>6} ({:>5.1}%) reductions: {}",
            name, millis, self.nodes_before, self.nodes_after, pct, self.reductions,
        )
    }
}
/// Aggregate statistics collected over one whole pipeline run.
#[derive(Debug, Clone)]
pub struct CompilerPipelineStats {
    /// One entry per executed (enabled) pass, in execution order.
    pub pass_stats: Vec<CompilerPassStats>,
    /// Wall-clock time of the whole run.
    pub total_wall_time: Duration,
    /// Number of outer iterations actually performed.
    pub outer_iterations: u32,
    /// `initial_node_count - final_node_count`; negative if the expression grew.
    pub total_node_reduction: i64,
    /// Node count of the input expression.
    pub initial_node_count: u64,
    /// Node count of the output expression.
    pub final_node_count: u64,
}

impl CompilerPipelineStats {
    /// Percentage of nodes removed over the whole run.
    ///
    /// Returns `0.0` for an empty input and is clamped at `0.0` if the expression grew.
    pub fn overall_reduction_pct(&self) -> f64 {
        if self.initial_node_count != 0 {
            let initial = self.initial_node_count as f64;
            let removed = initial - self.final_node_count as f64;
            (removed / initial * 100.0).max(0.0)
        } else {
            0.0
        }
    }

    /// The pass execution with the greatest wall time (the last one on ties), if any.
    pub fn slowest_pass(&self) -> Option<&CompilerPassStats> {
        self.pass_stats
            .iter()
            .max_by(|a, b| a.wall_time.cmp(&b.wall_time))
    }

    /// Renders the per-pass statistics plus a totals row as a box-drawn text table.
    pub fn format_table(&self) -> String {
        let total_row = format!(
            "│ TOTAL {:>8.3}ms {:>6} nodes → {:>6} ({:>5.1}% overall) │\n",
            self.total_wall_time.as_secs_f64() * 1_000.0,
            self.initial_node_count,
            self.final_node_count,
            self.overall_reduction_pct(),
        );

        let mut table = String::new();
        table.push_str("┌──────────────────────────────────────────────────────────────────┐\n");
        table.push_str("│ Pass Time(ms) Nodes Before → After Pct Reductions│\n");
        table.push_str("├──────────────────────────────────────────────────────────────────┤\n");
        table.extend(
            self.pass_stats
                .iter()
                .map(|pass| format!("│ {}\n", pass.summary())),
        );
        table.push_str("├──────────────────────────────────────────────────────────────────┤\n");
        table.push_str(&total_row);
        table.push_str("└──────────────────────────────────────────────────────────────────┘\n");
        table
    }

    /// Single-line overview of the run.
    pub fn summary(&self) -> String {
        let millis = self.total_wall_time.as_secs_f64() * 1_000.0;
        format!(
            "Pipeline: {} outer iterations, {:.3}ms total, {} → {} nodes ({:.1}% reduction)",
            self.outer_iterations,
            millis,
            self.initial_node_count,
            self.final_node_count,
            self.overall_reduction_pct(),
        )
    }
}
/// Output of a pipeline run: the optimized expression and how it was produced.
#[derive(Debug, Clone)]
pub struct CompilerPipelineResult {
    /// The optimized expression.
    pub expr: TLExpr,
    /// Timing and size statistics gathered during the run.
    pub stats: CompilerPipelineStats,
}
/// Aggregated timing for one pass across repeated pipeline runs.
#[derive(Debug, Clone)]
pub struct PassBenchmark {
    /// The benchmarked pass.
    pub pass_id: CompilerPassId,
    /// Number of times the pass executed across all runs. This can exceed the
    /// requested run count when the pass appears several times in the order or the
    /// pipeline performs several outer iterations.
    pub runs: usize,
    /// Fastest single execution, in nanoseconds.
    pub min_ns: u64,
    /// Slowest single execution, in nanoseconds.
    pub max_ns: u64,
    /// Mean execution time, in nanoseconds.
    pub mean_ns: u64,
    /// Sum of the pass-reported reductions over every execution.
    pub total_reductions: u64,
}

impl PassBenchmark {
    /// One-line report with timings converted to milliseconds.
    pub fn summary(&self) -> String {
        let to_ms = |ns: u64| ns as f64 / 1_000_000.0;
        let name = self.pass_id.to_string();
        format!(
            "{:<12} runs={:>4} min={:.3}ms mean={:.3}ms max={:.3}ms reductions={}",
            name,
            self.runs,
            to_ms(self.min_ns),
            to_ms(self.mean_ns),
            to_ms(self.max_ns),
            self.total_reductions,
        )
    }
}
pub struct CompilerPipeline {
config: CompilerPipelineConfig,
}
impl Default for CompilerPipeline {
fn default() -> Self {
Self::with_default()
}
}
impl CompilerPipeline {
    /// Creates a pipeline driven by `config`.
    pub fn new(config: CompilerPipelineConfig) -> Self {
        Self { config }
    }

    /// Creates a pipeline using `CompilerPipelineConfig::default()`.
    pub fn with_default() -> Self {
        Self::new(CompilerPipelineConfig::default())
    }

    /// Creates a pipeline with the default configuration, in which every pass is enabled.
    pub fn all_passes() -> Self {
        Self::with_default()
    }

    /// Creates a pipeline with every pass disabled.
    ///
    /// [`Self::run`] then returns its input unchanged and records no per-pass statistics.
    pub fn no_passes() -> Self {
        Self::new(CompilerPipelineConfig {
            enable_const_prop: false,
            enable_dead_code: false,
            enable_inline: false,
            enable_algebraic: false,
            enable_rewrite: false,
            ..CompilerPipelineConfig::default()
        })
    }

    /// Optimizes `expr` by running the configured pass order repeatedly.
    ///
    /// The sequence runs at least once and at most `max_outer_iterations` times. It stops
    /// early after the first iteration that does not strictly reduce the node count.
    /// Returns the optimized expression together with timing and size statistics.
    pub fn run(&self, expr: TLExpr) -> CompilerPipelineResult {
        let pipeline_start = Instant::now();
        let initial_node_count = Self::count_nodes(&expr);
        let mut stats = CompilerPipelineStats {
            pass_stats: Vec::new(),
            total_wall_time: Duration::ZERO,
            outer_iterations: 0,
            total_node_reduction: 0,
            initial_node_count,
            final_node_count: initial_node_count,
        };
        let order = self.config.pass_order.to_pass_list();
        let max_iters = self.config.max_outer_iterations.max(1);

        let mut current = expr;
        // Node count of `current`. It is carried between iterations so each intermediate
        // tree is traversed once instead of being recounted before and after every iteration.
        let mut node_count = initial_node_count;
        for _ in 0..max_iters {
            current = self.run_sequence(current, &order, &mut stats);
            stats.outer_iterations += 1;
            let nodes_after_iter = Self::count_nodes(&current);
            // Without strict shrinkage another iteration is not expected to help.
            let converged = nodes_after_iter >= node_count;
            node_count = nodes_after_iter;
            if converged {
                break;
            }
        }

        stats.final_node_count = node_count;
        stats.total_node_reduction = initial_node_count as i64 - node_count as i64;
        stats.total_wall_time = pipeline_start.elapsed();
        CompilerPipelineResult {
            expr: current,
            stats,
        }
    }

    /// Runs the whole pipeline `runs` times (at least once) on clones of `expr`.
    ///
    /// Per-pass timings are aggregated across every execution of that pass. Results are
    /// returned in first-appearance order of the configured pass order, one entry per
    /// distinct pass. Passes that never executed (e.g. disabled ones) are omitted.
    pub fn benchmark(&self, expr: TLExpr, runs: usize) -> Vec<PassBenchmark> {
        // Per-pass accumulator: (executions, min_ns, max_ns, sum_ns, total_reductions).
        let mut timings: HashMap<CompilerPassId, (u64, u64, u64, u64, u64)> = HashMap::new();
        for _ in 0..runs.max(1) {
            let result = self.run(expr.clone());
            for ps in &result.stats.pass_stats {
                // Saturate rather than silently truncating the u128 nanosecond count.
                let ns = ps.wall_time.as_nanos().min(u64::MAX as u128) as u64;
                let entry = timings
                    .entry(ps.pass_id.clone())
                    .or_insert((0, u64::MAX, 0, 0, 0));
                entry.0 += 1;
                entry.1 = entry.1.min(ns);
                entry.2 = entry.2.max(ns);
                entry.3 = entry.3.saturating_add(ns);
                entry.4 = entry.4.saturating_add(ps.reductions);
            }
        }

        let mut benchmarks: Vec<PassBenchmark> = Vec::new();
        for pass_id in self.config.pass_order.to_pass_list() {
            // `remove` both fetches the aggregate and guarantees a pass listed several
            // times in the order (e.g. AggressiveFold) is reported only once.
            if let Some((count, min_ns, max_ns, sum_ns, total_reductions)) =
                timings.remove(&pass_id)
            {
                benchmarks.push(PassBenchmark {
                    pass_id,
                    runs: count as usize,
                    min_ns,
                    max_ns,
                    mean_ns: sum_ns.checked_div(count).unwrap_or(0),
                    total_reductions,
                });
            }
        }
        benchmarks
    }

    /// Applies each pass in `order` once, threading the expression through them.
    fn run_sequence(
        &self,
        mut expr: TLExpr,
        order: &[CompilerPassId],
        stats: &mut CompilerPipelineStats,
    ) -> TLExpr {
        for pass_id in order {
            expr = self.run_single_pass(pass_id, expr, stats);
        }
        expr
    }

    /// Whether `pass_id` is switched on in the configuration.
    fn is_pass_enabled(&self, pass_id: &CompilerPassId) -> bool {
        match pass_id {
            CompilerPassId::ConstProp => self.config.enable_const_prop,
            CompilerPassId::DeadCode => self.config.enable_dead_code,
            CompilerPassId::Inline => self.config.enable_inline,
            CompilerPassId::Algebraic => self.config.enable_algebraic,
            CompilerPassId::Rewrite => self.config.enable_rewrite,
        }
    }

    /// Runs one pass and appends its measurements to `stats`.
    ///
    /// A disabled pass returns `expr` untouched and records nothing. The enabled check
    /// runs first so no node counting or timing is wasted on skipped passes.
    fn run_single_pass(
        &self,
        pass_id: &CompilerPassId,
        expr: TLExpr,
        stats: &mut CompilerPipelineStats,
    ) -> TLExpr {
        if !self.is_pass_enabled(pass_id) {
            return expr;
        }
        let nodes_before = Self::count_nodes(&expr);
        let t0 = Instant::now();
        let (new_expr, reductions) = match pass_id {
            CompilerPassId::ConstProp => {
                let propagator = ConstantPropagator::new(self.config.const_prop_config.clone());
                let (out, s) = propagator.run(expr);
                (out, s.total_folds())
            }
            CompilerPassId::DeadCode => {
                let eliminator = DeadCodeEliminator::new(self.config.dce_config.clone());
                let (out, s) = eliminator.run(expr);
                (out, s.total_eliminations())
            }
            CompilerPassId::Inline => {
                let inliner = LetInliner::new(self.config.inline_config.clone());
                let (out, s) = inliner.run(expr);
                (out, s.total())
            }
            CompilerPassId::Algebraic => {
                let alg_pipeline = OptimizationPipeline::new();
                let (out, s) = alg_pipeline.optimize(&expr);
                (out, s.total_optimizations() as u64)
            }
            CompilerPassId::Rewrite => {
                let engine = RewriteEngine::new().add_all_builtin_rules();
                let (out, s) = engine.rewrite(expr);
                (out, s.total_rewrites)
            }
        };
        let wall_time = t0.elapsed();
        let nodes_after = Self::count_nodes(&new_expr);
        stats.pass_stats.push(CompilerPassStats {
            pass_id: pass_id.clone(),
            wall_time,
            nodes_before,
            nodes_after,
            reductions,
        });
        new_expr
    }

    /// Node count of `expr`, as computed by `DeadCodeEliminator::count_nodes`.
    fn count_nodes(expr: &TLExpr) -> u64 {
        DeadCodeEliminator::count_nodes(expr)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    // (2 * 3) + 4: fully constant, foldable by constant propagation.
    fn simple_constant_expr() -> TLExpr {
        TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::Constant(4.0),
        )
    }

    // 1.0 AND p(x): expected to reduce to the predicate under dead-code elimination.
    fn dead_branch_expr() -> TLExpr {
        TLExpr::and(
            TLExpr::Constant(1.0),
            TLExpr::pred("p", vec![Term::var("x")]),
        )
    }

    // let y = 5.0 in y(): a let binding for the inliner.
    fn let_binding_expr() -> TLExpr {
        TLExpr::let_binding("y", TLExpr::Constant(5.0), TLExpr::pred("y", vec![]))
    }

    // 1.0 AND NOT NOT p(x): exercises several passes at once.
    fn non_trivial_expr() -> TLExpr {
        TLExpr::and(
            TLExpr::Constant(1.0),
            TLExpr::negate(TLExpr::negate(TLExpr::pred("p", vec![Term::var("x")]))),
        )
    }

    #[test]
    fn test_compiler_pipeline_config_default() {
        let cfg = CompilerPipelineConfig::default();
        assert!(cfg.enable_const_prop);
        assert!(cfg.enable_dead_code);
        assert!(cfg.enable_inline);
        assert!(cfg.enable_algebraic);
        assert!(cfg.enable_rewrite);
        assert_eq!(cfg.max_outer_iterations, 3);
    }

    #[test]
    fn test_compiler_pipeline_no_passes() {
        let pipeline = CompilerPipeline::no_passes();
        let expr = simple_constant_expr();
        let result = pipeline.run(expr.clone());
        assert_eq!(format!("{:?}", result.expr), format!("{:?}", expr),);
    }

    #[test]
    fn test_compiler_pipeline_const_prop_only() {
        let cfg = CompilerPipelineConfig {
            enable_const_prop: true,
            enable_dead_code: false,
            enable_inline: false,
            enable_algebraic: false,
            enable_rewrite: false,
            max_outer_iterations: 1,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let expr = simple_constant_expr();
        let result = pipeline.run(expr);
        assert!(matches!(result.expr, TLExpr::Constant(_)));
    }

    #[test]
    fn test_compiler_pipeline_dead_code_only() {
        let cfg = CompilerPipelineConfig {
            enable_const_prop: false,
            enable_dead_code: true,
            enable_inline: false,
            enable_algebraic: false,
            enable_rewrite: false,
            max_outer_iterations: 1,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let expr = dead_branch_expr();
        let result = pipeline.run(expr);
        assert!(matches!(result.expr, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_compiler_pipeline_inline_only() {
        let cfg = CompilerPipelineConfig {
            enable_const_prop: false,
            enable_dead_code: false,
            enable_inline: true,
            enable_algebraic: false,
            enable_rewrite: false,
            max_outer_iterations: 1,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let expr = let_binding_expr();
        let result = pipeline.run(expr);
        assert!(matches!(result.expr, TLExpr::Constant(v) if (v - 5.0).abs() < 1e-12));
    }

    #[test]
    fn test_compiler_pipeline_all_passes() {
        let pipeline = CompilerPipeline::all_passes();
        let expr = non_trivial_expr();
        let result = pipeline.run(expr);
        assert!(result.stats.outer_iterations > 0);
    }

    #[test]
    fn test_compiler_pipeline_result_has_stats() {
        let pipeline = CompilerPipeline::with_default();
        let expr = simple_constant_expr();
        let result = pipeline.run(expr);
        assert!(result.stats.initial_node_count > 0);
        assert!(!result.stats.pass_stats.is_empty());
    }

    #[test]
    fn test_pass_stats_reduction_pct() {
        let s = CompilerPassStats {
            pass_id: CompilerPassId::ConstProp,
            wall_time: Duration::from_millis(1),
            nodes_before: 100,
            nodes_after: 80,
            reductions: 5,
        };
        let pct = s.reduction_pct();
        assert!((pct - 20.0).abs() < 1e-6, "expected 20%, got {pct}");
    }

    #[test]
    fn test_pass_stats_reduction_pct_zero_before() {
        let s = CompilerPassStats {
            pass_id: CompilerPassId::DeadCode,
            wall_time: Duration::ZERO,
            nodes_before: 0,
            nodes_after: 0,
            reductions: 0,
        };
        assert_eq!(s.reduction_pct(), 0.0);
    }

    #[test]
    fn test_pass_stats_summary_nonempty() {
        let s = CompilerPassStats {
            pass_id: CompilerPassId::Inline,
            wall_time: Duration::from_micros(500),
            nodes_before: 10,
            nodes_after: 8,
            reductions: 2,
        };
        let summary = s.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("Inline"));
    }

    #[test]
    fn test_pipeline_stats_overall_reduction() {
        let pipeline = CompilerPipeline::with_default();
        let expr = simple_constant_expr();
        let result = pipeline.run(expr);
        let initial = result.stats.initial_node_count;
        let final_count = result.stats.final_node_count;
        assert!(
            initial >= final_count,
            "pipeline should not increase node count"
        );
        let pct = result.stats.overall_reduction_pct();
        assert!(pct >= 0.0);
    }

    #[test]
    fn test_pipeline_stats_format_table() {
        let pipeline = CompilerPipeline::with_default();
        let expr = simple_constant_expr();
        let result = pipeline.run(expr);
        let table = result.stats.format_table();
        // The header row and the TOTAL row are both emitted unconditionally, so each
        // must be present (the previous `||` accepted a table missing either one).
        assert!(
            table.contains("Pass"),
            "table should contain the header row, got: {table}"
        );
        assert!(
            table.contains("TOTAL"),
            "table should contain the TOTAL row, got: {table}"
        );
    }

    #[test]
    fn test_pipeline_stats_summary_nonempty() {
        let pipeline = CompilerPipeline::with_default();
        let expr = simple_constant_expr();
        let result = pipeline.run(expr);
        let summary = result.stats.summary();
        assert!(!summary.is_empty());
        assert!(summary.contains("Pipeline"));
    }

    #[test]
    fn test_pipeline_stats_slowest_pass() {
        let pipeline = CompilerPipeline::with_default();
        let expr = simple_constant_expr();
        let result = pipeline.run(expr);
        assert!(result.stats.slowest_pass().is_some());
    }

    #[test]
    fn test_compiler_pipeline_canonical_order() {
        let cfg = CompilerPipelineConfig {
            pass_order: CompilerPassOrder::CanonicalOrder,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let result = pipeline.run(non_trivial_expr());
        assert!(result.stats.outer_iterations >= 1);
    }

    #[test]
    fn test_compiler_pipeline_inline_first() {
        let cfg = CompilerPipelineConfig {
            pass_order: CompilerPassOrder::InlineFirst,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let result = pipeline.run(let_binding_expr());
        assert!(result.stats.outer_iterations >= 1);
    }

    #[test]
    fn test_compiler_pipeline_custom_order() {
        let cfg = CompilerPipelineConfig {
            pass_order: CompilerPassOrder::Custom(vec![
                CompilerPassId::ConstProp,
                CompilerPassId::DeadCode,
            ]),
            max_outer_iterations: 1,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let result = pipeline.run(simple_constant_expr());
        assert_eq!(result.stats.pass_stats.len(), 2);
    }

    #[test]
    fn test_compiler_pipeline_outer_iterations() {
        let cfg = CompilerPipelineConfig {
            max_outer_iterations: 5,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let result = pipeline.run(simple_constant_expr());
        assert!(result.stats.outer_iterations <= 5);
        assert!(result.stats.outer_iterations >= 1);
    }

    #[test]
    fn test_benchmark_runs_n_times() {
        let pipeline = CompilerPipeline::with_default();
        let expr = simple_constant_expr();
        let benchmarks = pipeline.benchmark(expr, 4);
        let order_len = CompilerPassOrder::CanonicalOrder.to_pass_list().len();
        assert_eq!(benchmarks.len(), order_len);
        for b in &benchmarks {
            assert!(
                b.runs >= 4,
                "expected >=4 runs for {}, got {}",
                b.pass_id,
                b.runs
            );
        }
    }

    #[test]
    fn test_pass_benchmark_summary_nonempty() {
        let pipeline = CompilerPipeline::with_default();
        let benchmarks = pipeline.benchmark(simple_constant_expr(), 2);
        for b in &benchmarks {
            let summary = b.summary();
            assert!(!summary.is_empty());
        }
    }

    // Re-running on already-optimized output must not make the expression larger.
    #[test]
    fn test_pipeline_idempotent() {
        let pipeline = CompilerPipeline::with_default();
        let expr = non_trivial_expr();
        let first = pipeline.run(expr);
        let second = pipeline.run(first.expr.clone());
        assert!(
            second.stats.final_node_count <= first.stats.final_node_count,
            "second run produced more nodes than first"
        );
    }

    #[test]
    fn test_compiler_pipeline_aggressive_fold() {
        let cfg = CompilerPipelineConfig {
            pass_order: CompilerPassOrder::AggressiveFold,
            max_outer_iterations: 2,
            ..CompilerPipelineConfig::default()
        };
        let pipeline = CompilerPipeline::new(cfg);
        let result = pipeline.run(simple_constant_expr());
        assert!(result.stats.outer_iterations >= 1);
    }

    #[test]
    fn test_benchmark_timing_invariants() {
        let pipeline = CompilerPipeline::with_default();
        let benchmarks = pipeline.benchmark(simple_constant_expr(), 3);
        for b in &benchmarks {
            assert!(b.min_ns <= b.mean_ns, "min_ns > mean_ns for {}", b.pass_id);
            assert!(b.mean_ns <= b.max_ns, "mean_ns > max_ns for {}", b.pass_id);
        }
    }
}