use mathcompile::final_tagless::{ASTEval, ASTMathExpr, DirectEval};
use mathcompile::symbolic_ad::convenience;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🎯 MathCompile: Comprehensive Gradient Computation Demo");
println!("===================================================\n");
println!("1️⃣ Basic Multivariate Gradients");
println!("--------------------------------");
let multivar_func = ASTEval::add(
ASTEval::add(
ASTEval::add(
ASTEval::add(
ASTEval::add(
ASTEval::pow(ASTEval::var_by_name("x"), ASTEval::constant(2.0)),
ASTEval::pow(ASTEval::var_by_name("y"), ASTEval::constant(2.0)),
),
ASTEval::pow(ASTEval::var_by_name("z"), ASTEval::constant(2.0)),
),
ASTEval::mul(
ASTEval::constant(2.0),
ASTEval::mul(ASTEval::var_by_name("x"), ASTEval::var_by_name("y")),
),
),
ASTEval::mul(
ASTEval::constant(3.0),
ASTEval::mul(ASTEval::var_by_name("x"), ASTEval::var_by_name("z")),
),
),
ASTEval::mul(ASTEval::var_by_name("y"), ASTEval::var_by_name("z")),
);
println!("Function: f(x,y,z) = x² + y² + z² + 2xy + 3xz + yz");
println!("Expected gradient:");
println!(" ∂f/∂x = 2x + 2y + 3z");
println!(" ∂f/∂y = 2y + 2x + z");
println!(" ∂f/∂z = 2z + 3x + y");
let gradient = convenience::gradient(&multivar_func, &["x", "y", "z"])?;
let x_val = 1.0;
let y_val = 2.0;
let z_val = 3.0;
let f_val = DirectEval::eval_two_vars(&multivar_func, x_val, y_val); println!("\nAt point ({x_val}, {y_val}, {z_val}):");
let df_dx_val = DirectEval::eval_two_vars(&gradient["x"], x_val, y_val);
let df_dy_val = DirectEval::eval_two_vars(&gradient["y"], x_val, y_val);
println!(" ∂f/∂x = {df_dx_val:.3}");
println!(" ∂f/∂y = {df_dy_val:.3}");
println!(" Expected ∂f/∂x ≈ 15.0 (with z=3)");
println!(" Expected ∂f/∂y ≈ 9.0 (with z=3)");
println!();
println!("2️⃣ Machine Learning Loss Function Gradients");
println!("--------------------------------------------");
let x_input = 2.0; let y_target = 5.0;
let prediction = ASTEval::add(
ASTEval::mul(ASTEval::var_by_name("w"), ASTEval::constant(x_input)),
ASTEval::var_by_name("b"),
);
let error = ASTEval::sub(ASTEval::constant(y_target), prediction);
let mse_loss = ASTEval::pow(error, ASTEval::constant(2.0));
println!("MSE Loss: L(w,b) = (y - (wx + b))²");
println!("With x = {x_input}, y = {y_target}");
println!("L(w,b) = ({y_target} - (w*{x_input} + b))²");
println!("Expected gradients:");
println!(" ∂L/∂w = -2x(y - (wx + b)) = -2*{x_input}*({y_target} - (w*{x_input} + b))");
println!(" ∂L/∂b = -2(y - (wx + b)) = -2*({y_target} - (w*{x_input} + b))");
let mse_gradient = convenience::gradient(&mse_loss, &["w", "b"])?;
let w_val = 1.0;
let b_val = 0.5;
let loss_val = DirectEval::eval_two_vars(&mse_loss, w_val, b_val);
let dl_dw = DirectEval::eval_two_vars(&mse_gradient["w"], w_val, b_val);
let dl_db = DirectEval::eval_two_vars(&mse_gradient["b"], w_val, b_val);
println!("\nAt w = {w_val}, b = {b_val}:");
println!(" Loss = {loss_val:.3}");
println!(" ∂L/∂w = {dl_dw:.3}");
println!(" ∂L/∂b = {dl_db:.3}");
println!(" Expected Loss = 6.25");
println!(" Expected ∂L/∂w = -10.0");
println!(" Expected ∂L/∂b = -5.0");
println!();
println!("3️⃣ Optimization Problem: Rosenbrock Function");
println!("---------------------------------------------");
let a = 1.0;
let b = 100.0;
let term1 = ASTEval::pow(
ASTEval::sub(ASTEval::constant(a), ASTEval::var_by_name("x")),
ASTEval::constant(2.0),
);
let x_squared = ASTEval::pow(ASTEval::var_by_name("x"), ASTEval::constant(2.0));
let term2 = ASTEval::mul(
ASTEval::constant(b),
ASTEval::pow(
ASTEval::sub(ASTEval::var_by_name("y"), x_squared),
ASTEval::constant(2.0),
),
);
let rosenbrock = ASTEval::add(term1, term2);
println!("Rosenbrock function: f(x,y) = (1-x)² + 100(y-x²)²");
println!("This is a classic optimization test function with global minimum at (1,1)");
println!("Expected gradients:");
println!(" ∂f/∂x = -2(1-x) + 100*2(y-x²)*(-2x) = -2(1-x) - 400x(y-x²)");
println!(" ∂f/∂y = 100*2(y-x²) = 200(y-x²)");
let rosenbrock_grad = convenience::gradient(&rosenbrock, &["x", "y"])?;
let test_points = [(0.0, 0.0), (0.5, 0.25), (1.0, 1.0), (1.5, 2.0)];
for (x_test, y_test) in test_points {
let f_val = DirectEval::eval_two_vars(&rosenbrock, x_test, y_test);
let df_dx = DirectEval::eval_two_vars(&rosenbrock_grad["x"], x_test, y_test);
let df_dy = DirectEval::eval_two_vars(&rosenbrock_grad["y"], x_test, y_test);
println!("\nAt ({x_test:.1}, {y_test:.2}):");
println!(" f = {f_val:.3}");
println!(" ∇f = [{df_dx:.3}, {df_dy:.3}]");
let grad_magnitude = (df_dx * df_dx + df_dy * df_dy).sqrt();
println!(" |∇f| = {grad_magnitude:.3}");
if grad_magnitude < 0.1 {
println!(" → Near critical point! 🎯");
}
}
println!();
println!("4️⃣ Logistic Regression Gradient");
println!("--------------------------------");
let x_data = 1.5;
let y_label = 1.0;
let linear_output = ASTEval::add(
ASTEval::mul(ASTEval::var_by_name("w"), ASTEval::constant(x_data)),
ASTEval::var_by_name("b"),
);
let logistic_loss = ASTEval::pow(
ASTEval::sub(linear_output, ASTEval::constant(y_label)),
ASTEval::constant(2.0),
);
println!("Simplified logistic loss: L(w,b) = (wx + b - y)²");
println!("With x = {x_data}, y = {y_label}");
println!("Expected gradients:");
println!(" ∂L/∂w = 2x(wx + b - y)");
println!(" ∂L/∂b = 2(wx + b - y)");
let logistic_grad = convenience::gradient(&logistic_loss, &["w", "b"])?;
let w_val = 0.8;
let b_val = 0.2;
let loss_val = DirectEval::eval_two_vars(&logistic_loss, w_val, b_val);
let dl_dw = DirectEval::eval_two_vars(&logistic_grad["w"], w_val, b_val);
let dl_db = DirectEval::eval_two_vars(&logistic_grad["b"], w_val, b_val);
println!("\nAt w = {w_val}, b = {b_val}:");
println!(" Loss = {loss_val:.3}");
println!(" ∂L/∂w = {dl_dw:.3}");
println!(" ∂L/∂b = {dl_db:.3}");
println!(" Expected Loss = 0.16");
println!(" Expected ∂L/∂w = 1.2");
println!(" Expected ∂L/∂b = 0.8");
println!();
println!("5️⃣ Gradient Computation Performance");
println!("------------------------------------");
let dimensions = [2, 3, 5, 8];
for &dim in &dimensions {
let mut poly = ASTEval::constant(0.0);
let mut var_names = Vec::new();
for i in 0..dim {
let var_name = format!("x{i}");
var_names.push(var_name.clone());
poly = ASTEval::add(
poly,
ASTEval::pow(ASTEval::var_by_name(&var_name), ASTEval::constant(2.0)),
);
for j in (i + 1)..dim {
let var_j = format!("x{j}");
poly = ASTEval::add(
poly,
ASTEval::mul(
ASTEval::var_by_name(&var_name),
ASTEval::var_by_name(&var_j),
),
);
}
}
let var_refs: Vec<&str> = var_names.iter().map(std::string::String::as_str).collect();
let start_time = std::time::Instant::now();
let grad_result = convenience::gradient(&poly, &var_refs);
let computation_time = start_time.elapsed();
match grad_result {
Ok(grad) => {
println!(
" {dim}D gradient: {} variables, {} μs",
grad.len(),
computation_time.as_micros()
);
}
Err(e) => {
println!(" {dim}D gradient: Error - {e}");
}
}
}
println!();
println!("6️⃣ Gradient Capabilities Summary");
println!("---------------------------------");
println!("✅ Multivariate function gradients (∇f for f: ℝⁿ → ℝ)");
println!("✅ Machine learning loss function gradients");
println!("✅ Optimization problem gradients (Rosenbrock, etc.)");
println!("✅ Symbolic computation (exact derivatives)");
println!("✅ Arbitrary number of variables");
println!("✅ Integration with optimization pipeline");
println!("✅ Caching for repeated computations");
println!();
println!("🎯 Perfect for:");
println!("• Gradient descent optimization");
println!("• Machine learning backpropagation");
println!("• Scientific computing");
println!("• Numerical optimization algorithms");
println!("• Sensitivity analysis");
Ok(())
}