use dslcompile::prelude::*;
use dslcompile::ANFConverter; use std::f64::consts::PI;
use std::time::Instant;
/// Wall-clock cost, in milliseconds, of each stage of building a compiled model.
///
/// Every field starts at `0.0`; the code that runs a stage is responsible for
/// filling in the matching field.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CompilationTiming {
    /// Time spent constructing the symbolic expression.
    symbolic_construction_ms: f64,
    /// Time spent in the symbolic optimizer.
    symbolic_optimization_ms: f64,
    /// Time spent generating Rust source code.
    code_generation_ms: f64,
    /// Time spent compiling and loading the generated Rust code.
    rust_compilation_ms: f64,
    /// End-to-end time; also the denominator of the percentage breakdown.
    total_compilation_ms: f64,
}

impl CompilationTiming {
    /// Creates a record with every stage set to `0.0` ms.
    fn new() -> Self {
        Self::default()
    }

    /// Prints the per-stage timings and, when the total is positive, each
    /// stage's share of the total as a percentage.
    fn print_summary(&self) {
        // One table drives both the absolute and the relative report so the
        // labels cannot drift apart.
        let stages = [
            ("Symbolic construction", self.symbolic_construction_ms),
            ("Symbolic optimization", self.symbolic_optimization_ms),
            ("Code generation", self.code_generation_ms),
            ("Rust compilation", self.rust_compilation_ms),
        ];

        println!("โฑ๏ธ Compilation Timing Summary:");
        for (label, ms) in stages {
            println!(" {label}: {ms:.2}ms");
        }
        println!(" โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ");
        println!(" Total compilation: {:.2}ms", self.total_compilation_ms);

        let total = self.total_compilation_ms;
        // Skip the breakdown when nothing was timed: dividing by zero would
        // only print NaN/inf percentages.
        if total > 0.0 {
            println!("\n๐ Time Distribution:");
            for (label, ms) in stages {
                println!(" {label}: {:.1}%", (ms / total) * 100.0);
            }
        }
    }
}
/// A Bayesian linear regression model whose log-posterior is available both as
/// a symbolic expression and as natively compiled code.
pub struct BayesianLinearRegression {
    /// Natively compiled log-posterior, produced by the constructor.
    log_posterior_compiled: CompiledRustFunction,
    /// Compiled log-posterior for a partial-evaluation scenario; `None` until
    /// `apply_partial_evaluation` has been called.
    log_posterior_partial: Option<CompiledRustFunction>,
    /// Symbolic log-posterior evaluated by `log_posterior_direct`.
    log_posterior_symbolic: ASTRepr<f64>,
    /// The observations as `(x, y)` pairs.
    data: Vec<(f64, f64)>,
    /// Number of parameters the evaluators expect; the constructor sets it to 3.
    n_params: usize,
    /// Timings recorded while the model was built.
    timing: CompilationTiming,
    /// Name of the constraint set passed at construction or applied later, if any.
    partial_context: Option<String>,
}
impl BayesianLinearRegression {
/// Builds and compiles the model for `data` without any partial-evaluation
/// constraints.
///
/// # Errors
///
/// Propagates any error from [`Self::new_with_partial_eval`].
pub fn new(data: Vec<(f64, f64)>) -> Result<Self> {
    let no_constraints: Option<&str> = None;
    Self::new_with_partial_eval(data, no_constraints)
}
/// Builds the model: constructs the symbolic log-posterior, optimizes it,
/// generates Rust source for the optimized form and compiles it to native code,
/// printing progress and timings along the way.
///
/// `partial_constraints` is only printed and remembered as the model's partial
/// context; no partially evaluated function is compiled here (the partial slot
/// starts out as `None`).
///
/// The stored symbolic expression is the one built *before* optimization, while
/// the compiled function comes from the optimized expression. The ANF
/// conversion is run for reporting only; its result is not used for code
/// generation and its duration is not assigned to any individual stage.
///
/// # Errors
///
/// Propagates failures from expression construction, symbolic optimization,
/// ANF conversion, code generation and compilation.
pub fn new_with_partial_eval(
    data: Vec<(f64, f64)>,
    partial_constraints: Option<&str>,
) -> Result<Self> {
    let total_start = Instant::now();
    let mut timing = CompilationTiming::new();
    println!("๐๏ธ Building Bayesian Linear Regression Model");
    println!(" Data points: {}", data.len());
    if let Some(constraints) = partial_constraints {
        println!(" Partial evaluation: {constraints}");
    }

    // Stage 0: symbolic construction.
    println!("\n๐ง Stage 0: Symbolic construction (natural expressions)...");
    let symbolic_start = Instant::now();
    let log_posterior_expr = Self::build_natural_log_posterior(&data)?;
    timing.symbolic_construction_ms = symbolic_start.elapsed().as_secs_f64() * 1000.0;
    println!(
        "๐ Log-posterior built naturally in {:.2}ms",
        timing.symbolic_construction_ms
    );
    // Count once and reuse: the count is needed for several report lines.
    let ops_before = log_posterior_expr.count_operations();
    println!(" Operations before optimization: {}", ops_before);

    // Stage 1: symbolic optimization.
    println!("โก Stage 1: Symbolic optimization...");
    let opt_start = Instant::now();
    let mut config = OptimizationConfig::default();
    config.egglog_optimization = true;
    config.enable_expansion_rules = false;
    config.enable_distribution_rules = false;
    let mut symbolic_optimizer = SymbolicOptimizer::with_config(config)?;
    let optimized_expr = symbolic_optimizer.optimize(&log_posterior_expr)?;
    let symbolic_time = opt_start.elapsed().as_secs_f64() * 1000.0;
    timing.symbolic_optimization_ms = symbolic_time;
    println!(" Completed in {symbolic_time:.2}ms");
    let ops_after = optimized_expr.count_operations();
    println!(" Operations after optimization: {}", ops_after);
    let reduction_pct = if ops_before > 0 {
        ((ops_before as f64 - ops_after as f64) / ops_before as f64) * 100.0
    } else {
        0.0
    };
    println!(
        " Operation reduction: {:.1}% ({} โ {} ops)",
        reduction_pct, ops_before, ops_after
    );

    // Informational ANF pass: reports how many let bindings the optimized
    // expression needs.
    println!("\n๐ง Testing ANF/CSE recovery...");
    let anf_start = Instant::now();
    let mut anf_converter = ANFConverter::new();
    let anf_expr = anf_converter.convert(&optimized_expr)?;
    let anf_time = anf_start.elapsed().as_secs_f64() * 1000.0;
    let anf_let_bindings = anf_expr.let_count();
    println!(" ANF conversion: {anf_time:.2}ms");
    println!(" ANF let bindings: {anf_let_bindings}");
    let anf_reduction_pct = if ops_after > 0 {
        // Each let binding counts as one operation, plus the final expression.
        let anf_effective_ops = anf_let_bindings + 1;
        ((ops_after as f64 - anf_effective_ops as f64) / ops_after as f64) * 100.0
    } else {
        0.0
    };
    println!(
        " ANF reduction: {anf_reduction_pct:.1}% ({} ops โ {} let bindings + 1 final expr)",
        ops_after,
        anf_let_bindings
    );

    // Stage 2: code generation and native compilation.
    println!("\n๐ง Stage 2: Compiling to native code...");
    let rust_generator = RustCodeGenerator::new();
    let rust_compiler = RustCompiler::new();
    println!(" Stage 2a: Generating Rust code...");
    let codegen_start = Instant::now();
    let posterior_code = rust_generator.generate_function(&optimized_expr, "log_posterior")?;
    timing.code_generation_ms = codegen_start.elapsed().as_secs_f64() * 1000.0;
    println!(" Completed in {:.2}ms", timing.code_generation_ms);
    println!(" Stage 2b: Compiling to native code...");
    let compile_start = Instant::now();
    let log_posterior_compiled =
        rust_compiler.compile_and_load(&posterior_code, "log_posterior")?;
    timing.rust_compilation_ms = compile_start.elapsed().as_secs_f64() * 1000.0;
    println!(" Completed in {:.2}ms", timing.rust_compilation_ms);

    timing.total_compilation_ms = total_start.elapsed().as_secs_f64() * 1000.0;
    println!("\nโ
Compilation complete!");
    timing.print_summary();

    Ok(Self {
        log_posterior_compiled,
        log_posterior_partial: None,
        log_posterior_symbolic: log_posterior_expr,
        data,
        n_params: 3,
        timing,
        partial_context: partial_constraints.map(String::from),
    })
}
#[must_use]
pub fn timing(&self) -> &CompilationTiming {
&self.timing
}
/// Builds the symbolic log-posterior of the regression as a function of three
/// variables, created in this order: `beta0`, `beta1`, `sigma_sq`.
///
/// The data enter the expression only through constants computed here
/// (`n`, Σx, Σy, Σx², Σy², Σxy), so the size of the expression does not depend
/// on the number of observations.
///
/// The expression is the sum of
/// * a Gaussian log-likelihood written in terms of the expanded residual sum
///   of squares,
/// * two priors on `beta0` and `beta1` of the form
///   `-0.5·ln(2π·100) - (0.5/100)·β²`, and
/// * a prior on `sigma_sq` of the form `-2·ln(σ²) - 1/σ²`.
fn build_natural_log_posterior(data: &[(f64, f64)]) -> Result<ASTRepr<f64>> {
    use dslcompile::final_tagless::variables::ExpressionBuilder;
    let builder = ExpressionBuilder::new();
    let beta0 = builder.expr_from(builder.typed_var::<f64>());
    let beta1 = builder.expr_from(builder.typed_var::<f64>());
    let sigma_sq = builder.expr_from(builder.typed_var::<f64>());
    println!(
        " Building naive summation expression with {} data points",
        data.len()
    );
    println!(" (egglog will automatically discover sufficient statistics)");

    let n = data.len() as f64;
    // Stream over `data` directly instead of first copying the x and y columns
    // into temporary vectors; the summation order is unchanged.
    let sum_y = builder.constant(data.iter().map(|&(_, y)| y).sum::<f64>());
    let sum_x = builder.constant(data.iter().map(|&(x, _)| x).sum::<f64>());
    let sum_x_sq = builder.constant(data.iter().map(|&(x, _)| x * x).sum::<f64>());
    let sum_y_sq = builder.constant(data.iter().map(|&(_, y)| y * y).sum::<f64>());
    let sum_xy = builder.constant(data.iter().map(|(x, y)| x * y).sum::<f64>());
    let n_const = builder.constant(n);

    // Σ(y - β₀ - β₁x)² expanded so that only the sums above are needed.
    let residual_sum = &sum_y_sq
        - &(builder.constant(2.0) * &beta0 * &sum_y)
        - &(builder.constant(2.0) * &beta1 * &sum_xy)
        + &(&n_const * &beta0 * &beta0)
        + &(builder.constant(2.0) * &beta0 * &beta1 * &sum_x)
        + &(&beta1 * &beta1 * &sum_x_sq);

    let log_likelihood = builder.constant(-n / 2.0 * (2.0 * PI).ln())
        - &(builder.constant(n / 2.0) * sigma_sq.clone().ln())
        - &(builder.constant(0.5) * &residual_sum / &sigma_sq);

    let prior_beta0 = builder.constant(-0.5 * (2.0 * PI * 100.0).ln())
        - &(builder.constant(0.5 / 100.0) * &beta0 * &beta0);
    let prior_beta1 = builder.constant(-0.5 * (2.0 * PI * 100.0).ln())
        - &(builder.constant(0.5 / 100.0) * &beta1 * &beta1);
    let prior_sigma =
        builder.constant(-2.0) * sigma_sq.clone().ln() - (builder.constant(1.0) / &sigma_sq);
    let log_prior = &prior_beta0 + &prior_beta1 + &prior_sigma;

    let log_posterior: dslcompile::final_tagless::variables::TypedBuilderExpr<f64> =
        log_likelihood + log_prior;
    Ok(log_posterior.into_ast())
}
/// Evaluates the log-posterior at `params` using the natively compiled
/// function.
///
/// # Errors
///
/// Returns `DSLCompileError::InvalidInput` when `params.len()` differs from
/// [`Self::n_params`]; otherwise returns whatever the compiled call returns.
pub fn log_posterior_compiled(&self, params: &[f64]) -> Result<f64> {
    let expected = self.n_params;
    let got = params.len();
    if got == expected {
        self.log_posterior_compiled.call_multi_vars(params)
    } else {
        Err(DSLCompileError::InvalidInput(format!(
            "Expected {} parameters, got {}",
            expected, got
        )))
    }
}
/// Evaluates the log-posterior at `params` by interpreting the stored symbolic
/// expression with `DirectEval`.
///
/// # Errors
///
/// Returns `DSLCompileError::InvalidInput` when `params.len()` differs from
/// [`Self::n_params`].
pub fn log_posterior_direct(&self, params: &[f64]) -> Result<f64> {
    match params.len() {
        got if got == self.n_params => {
            let value = DirectEval::eval_with_vars(&self.log_posterior_symbolic, params);
            Ok(value)
        }
        got => Err(DSLCompileError::InvalidInput(format!(
            "Expected {} parameters, got {}",
            self.n_params, got
        ))),
    }
}
/// Returns the observations the model was built from, as `(x, y)` pairs.
#[must_use]
pub fn data(&self) -> &[(f64, f64)] {
    self.data.as_slice()
}
#[must_use]
pub fn n_params(&self) -> usize {
self.n_params
}
pub fn performance_comparison(&self, params: &[f64], n_evals: usize) -> Result<()> {
println!("\n๐ Performance Comparison: DirectEval vs Compiled Code");
println!(" Evaluations: {n_evals}");
println!("\n๐ DirectEval Performance:");
let direct_start = Instant::now();
let mut direct_result = 0.0;
for _ in 0..n_evals {
direct_result = self.log_posterior_direct(params)?;
}
let direct_time = direct_start.elapsed();
let direct_ms = direct_time.as_secs_f64() * 1000.0;
let direct_rate = n_evals as f64 / direct_time.as_secs_f64();
println!(" Time: {direct_ms:.2}ms");
println!(" Rate: {direct_rate:.1} evals/sec");
println!(
" Per eval: {:.3}ฮผs",
direct_time.as_micros() as f64 / n_evals as f64
);
println!(" Result: {direct_result:.6}");
println!("\n๐ Compiled Code Performance:");
let compiled_start = Instant::now();
let mut compiled_result = 0.0;
for _ in 0..n_evals {
compiled_result = self.log_posterior_compiled(params)?;
}
let compiled_time = compiled_start.elapsed();
let compiled_ms = compiled_time.as_secs_f64() * 1000.0;
let compiled_rate = n_evals as f64 / compiled_time.as_secs_f64();
println!(" Time: {compiled_ms:.2}ms");
println!(" Rate: {compiled_rate:.1} evals/sec");
println!(
" Per eval: {:.3}ฮผs",
compiled_time.as_micros() as f64 / n_evals as f64
);
println!(" Result: {compiled_result:.6}");
let speedup = direct_time.as_secs_f64() / compiled_time.as_secs_f64();
println!("\n๐ Comparison:");
println!(" Speedup: {speedup:.1}x faster");
println!(
" Results match: {}",
(direct_result - compiled_result).abs() < 1e-6 );
let compilation_cost_evals = self.timing.total_compilation_ms
/ (compiled_time.as_secs_f64() * 1000.0 / n_evals as f64);
println!(" Compilation amortized over: {compilation_cost_evals:.0} evaluations");
Ok(())
}
pub fn find_map_estimate(&self) -> Result<Vec<f64>> {
println!("๐ Finding MAP estimate via grid search...");
let search_start = Instant::now();
let mut best_params = vec![0.0, 0.0, 1.0];
let mut best_log_posterior = self.log_posterior_compiled(&best_params)?;
let mut evaluations = 0;
for beta0 in (-5..=5).map(f64::from) {
for beta1 in (-3..=3).map(|x| f64::from(x) * 0.5) {
for sigma_sq in [0.1, 0.5, 1.0, 2.0, 5.0] {
let params = vec![beta0, beta1, sigma_sq];
if let Ok(log_post) = self.log_posterior_compiled(¶ms) {
evaluations += 1;
if log_post > best_log_posterior {
best_log_posterior = log_post;
best_params = params;
}
}
}
}
}
let search_time = search_start.elapsed().as_secs_f64() * 1000.0;
println!(
" MAP estimate: ฮฒโ={:.3}, ฮฒโ={:.3}, ฯยฒ={:.3}",
best_params[0], best_params[1], best_params[2]
);
println!(" Log-posterior: {best_log_posterior:.3}");
println!(" Search completed in {search_time:.2}ms ({evaluations} evaluations)");
println!(
" Evaluation rate: {:.1} evals/ms",
f64::from(evaluations) / search_time
);
Ok(best_params)
}
/// Compiles a second copy of the log-posterior for the named constraint set
/// and records `constraints` as the model's partial context.
///
/// Recognised names are `"positive_variance"`, `"bounded_coefficients"` and
/// `"unit_variance"`. Only `"unit_variance"` routes the expression through
/// `substitute_unit_variance`; every other name, including unknown ones,
/// compiles a clone of the stored symbolic expression.
///
/// # Errors
///
/// Propagates failures from the substitution, code generation and
/// compilation steps; in that case the model is left unchanged.
pub fn apply_partial_evaluation(&mut self, constraints: &str) -> Result<()> {
    println!("\n๐ฌ Applying Partial Evaluation");
    println!(" Constraints: {constraints}");
    let partial_start = Instant::now();

    // Pick the message for this constraint and whether the variance
    // substitution applies.
    let (description, fix_variance) = match constraints {
        "positive_variance" => (
            " Constraint: ฯยฒ > 0 (variance must be positive)",
            false,
        ),
        "bounded_coefficients" => (
            " Constraint: ฮฒโ, ฮฒโ โ [-10, 10] (bounded coefficients)",
            false,
        ),
        "unit_variance" => (" Constraint: ฯยฒ = 1 (fixed unit variance)", true),
        _ => (
            " Unknown constraint type, using original expression",
            false,
        ),
    };
    println!("{description}");

    let optimized_expr = if fix_variance {
        self.substitute_unit_variance(&self.log_posterior_symbolic)?
    } else {
        self.log_posterior_symbolic.clone()
    };

    let generator = RustCodeGenerator::new();
    let compiler = RustCompiler::new();
    let source = generator.generate_function(&optimized_expr, "log_posterior_partial")?;
    let compiled = compiler.compile_and_load(&source, "log_posterior_partial")?;

    let partial_time = partial_start.elapsed().as_secs_f64() * 1000.0;
    println!(" Partial evaluation completed in {partial_time:.2}ms");
    println!(
        " Operations in partial form: {}",
        optimized_expr.count_operations()
    );

    self.log_posterior_partial = Some(compiled);
    self.partial_context = Some(constraints.to_string());
    Ok(())
}
/// Placeholder for fixing the variance at 1: it currently returns an
/// unmodified clone of `expr`, so the result still depends on every variable
/// of the input expression.
fn substitute_unit_variance(&self, expr: &ASTRepr<f64>) -> Result<ASTRepr<f64>> {
    let unchanged = expr.clone();
    Ok(unchanged)
}
/// Evaluates the partially evaluated log-posterior at `params`.
///
/// When the partial context is `"unit_variance"`, only the first two entries
/// of `params` are passed to the compiled function; otherwise `params` is
/// passed through unchanged.
///
/// NOTE(review): `substitute_unit_variance` currently returns its input
/// unchanged, so the function compiled for `"unit_variance"` appears to be
/// built from the full three-variable expression while only two values are
/// supplied here — confirm how `call_multi_vars` handles that.
///
/// # Errors
///
/// Returns `DSLCompileError::InvalidInput` when no partial evaluation has been
/// applied, or when the unit-variance model is given fewer than two
/// parameters; otherwise returns whatever the compiled call returns.
pub fn log_posterior_partial(&self, params: &[f64]) -> Result<f64> {
    let Some(partial_func) = self.log_posterior_partial.as_ref() else {
        return Err(DSLCompileError::InvalidInput(
            "No partial evaluation has been applied".to_string(),
        ));
    };

    if self.partial_context.as_deref() == Some("unit_variance") {
        if params.len() < 2 {
            return Err(DSLCompileError::InvalidInput(
                "Unit variance model requires at least 2 parameters (ฮฒโ, ฮฒโ)".to_string(),
            ));
        }
        partial_func.call_multi_vars(&params[0..2])
    } else {
        partial_func.call_multi_vars(params)
    }
}
/// Returns the name of the constraint set associated with this model, if any.
#[must_use]
pub fn partial_context(&self) -> Option<&str> {
    self.partial_context.as_ref().map(String::as_str)
}
}
/// Generates `n` deterministic `(x, y)` observations of
/// `y = true_beta0 + true_beta1 * x + true_sigma * z`.
///
/// `x` is evenly spaced over `[-5, 5)` and `z` comes from a Box–Muller
/// transform of two draws from a fixed-seed linear congruential generator, so
/// repeated calls with the same arguments return identical data. Returns an
/// empty vector when `n` is zero.
fn generate_synthetic_data(
    n: usize,
    true_beta0: f64,
    true_beta1: f64,
    true_sigma: f64,
) -> Vec<(f64, f64)> {
    /// Advances the LCG state and maps it to a value in `[0, 1]`.
    fn next_uniform(state: &mut u64) -> f64 {
        *state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        (*state as f64) / (u64::MAX as f64)
    }

    let mut rng_state = 12345u64;
    // The final length is known up front, so allocate once.
    let mut data = Vec::with_capacity(n);
    for i in 0..n {
        let x = i as f64 / n as f64 * 10.0 - 5.0;
        // Box–Muller takes ln(u1); clamp away from zero so a zero draw cannot
        // turn into an infinite or NaN observation. Non-zero draws are
        // unaffected.
        let u1 = next_uniform(&mut rng_state).max(f64::MIN_POSITIVE);
        let u2 = next_uniform(&mut rng_state);
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        let noise = true_sigma * z;
        data.push((x, true_beta0 + true_beta1 * x + noise));
    }
    data
}
/// Runs the demo: generates synthetic data, builds and benchmarks a baseline
/// model, builds one model per partial-evaluation scenario, and prints the
/// baseline model's compilation timings as a table.
///
/// Exits early with `Ok(())` when no Rust compiler is available.
fn main() -> Result<()> {
    println!("๐ DSLCompile: Partial Evaluation Demo");
    println!("=======================================\n");
    if !RustCompiler::is_available() {
        println!("โ Rust compiler not available - this demo requires rustc");
        println!(" Please install Rust toolchain to run this example");
        return Ok(());
    }

    println!("๐ Generating synthetic data...");
    let data_start = Instant::now();
    let true_beta0 = 2.0;
    let true_beta1 = 1.5;
    let true_sigma = 0.8;
    let n_data = 10_000_000;
    let data = generate_synthetic_data(n_data, true_beta0, true_beta1, true_sigma);
    let data_time = data_start.elapsed().as_secs_f64() * 1000.0;
    println!(" True parameters: ฮฒโ={true_beta0}, ฮฒโ={true_beta1}, ฯ={true_sigma}");
    println!(
        " Generated {} data points in {:.2}ms\n",
        data.len(),
        data_time
    );
    // The model is parameterised by the variance, not the standard deviation.
    let true_params = vec![true_beta0, true_beta1, true_sigma * true_sigma];

    println!("๐ฌ DEMONSTRATION: Partial Evaluation & Abstract Interpretation");
    println!("==============================================================\n");
    println!("๐ PART 1: Standard Compilation (Baseline)");
    println!("-------------------------------------------");
    let model = BayesianLinearRegression::new(data.clone())?;
    println!("\n๐งช Testing evaluation at true parameters...");
    let compiled_result = model.log_posterior_compiled(&true_params)?;
    let direct_result = model.log_posterior_direct(&true_params)?;
    println!(" Compiled result: {compiled_result:.6}");
    println!(" DirectEval result: {direct_result:.6}");
    println!(
        " Results match: {}",
        (compiled_result - direct_result).abs() < 1e-6
    );
    model.performance_comparison(&true_params, 10000)?;

    println!("\n\n๐ PART 2: Partial Evaluation Scenarios");
    println!("----------------------------------------");
    println!("\n๐ฌ Scenario 1: Positive Variance Constraint");
    let mut model_positive = BayesianLinearRegression::new(data.clone())?;
    model_positive.apply_partial_evaluation("positive_variance")?;
    println!("\n๐ฌ Scenario 2: Bounded Coefficients");
    let mut model_bounded = BayesianLinearRegression::new(data.clone())?;
    model_bounded.apply_partial_evaluation("bounded_coefficients")?;
    println!("\n๐ฌ Scenario 3: Unit Variance Model");
    // Last use of `data`: move it instead of copying the whole dataset again.
    let mut model_unit = BayesianLinearRegression::new(data)?;
    model_unit.apply_partial_evaluation("unit_variance")?;

    println!("\n\n๐ PART 3: Performance Analysis");
    println!("===============================");
    println!("\nโฑ๏ธ Compilation Timing Comparison:");
    println!(" โ Standard Model โ Notes");
    println!("โโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโ");
    let timing = model.timing();
    println!(
        "Symbolic construction โ {:>13.2}ms โ Efficient sufficient stats",
        timing.symbolic_construction_ms
    );
    println!(
        "Symbolic optimization โ {:>13.2}ms โ Basic algebraic rules",
        timing.symbolic_optimization_ms
    );
    println!(
        "Code generation โ {:>13.2}ms โ Rust code generation",
        timing.code_generation_ms
    );
    println!(
        "Rust compilation โ {:>13.2}ms โ LLVM optimization",
        timing.rust_compilation_ms
    );
    println!("โโโโโโโโโโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโโโ");
    println!(
        "TOTAL โ {:>13.2}ms โ",
        timing.total_compilation_ms
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The compiled and interpreted log-posteriors agree on a small dataset,
    /// and building the model records a positive total compilation time.
    /// Skipped (passes trivially) when no Rust compiler is available.
    #[test]
    fn test_natural_bayesian_linear_regression() -> Result<()> {
        if !RustCompiler::is_available() {
            return Ok(());
        }
        let data = vec![(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)];
        let model = BayesianLinearRegression::new(data)?;
        let params = vec![1.0, 2.0, 1.0];
        let compiled_result = model.log_posterior_compiled(&params)?;
        let direct_result = model.log_posterior_direct(&params)?;
        assert!(compiled_result.is_finite());
        assert!(direct_result.is_finite());
        assert!((compiled_result - direct_result).abs() < 1e-10);
        let timing = model.timing();
        assert!(timing.total_compilation_ms > 0.0);
        Ok(())
    }

    /// Applying a constraint set records its name as the partial context.
    /// Skipped (passes trivially) when no Rust compiler is available.
    #[test]
    fn test_partial_evaluation() -> Result<()> {
        if !RustCompiler::is_available() {
            return Ok(());
        }
        let data = vec![(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)];
        let mut model = BayesianLinearRegression::new(data)?;
        model.apply_partial_evaluation("positive_variance")?;
        assert_eq!(model.partial_context(), Some("positive_variance"));
        Ok(())
    }
}