use dslcompile::final_tagless::{ASTEval, ASTMathExpr};
use dslcompile::symbolic::symbolic_ad::convenience;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🎯 DSLCompile: Comprehensive Gradient Computation Demo");
println!("===================================================\n");
println!("1️⃣ Basic Multivariate Gradients");
println!("--------------------------------");
let multivar_func = ASTEval::add(
ASTEval::add(
ASTEval::add(
ASTEval::add(
ASTEval::add(
ASTEval::pow(ASTEval::var(0), ASTEval::constant(2.0)), ASTEval::pow(ASTEval::var(1), ASTEval::constant(2.0)), ),
ASTEval::pow(ASTEval::var(2), ASTEval::constant(2.0)), ),
ASTEval::mul(
ASTEval::constant(2.0),
ASTEval::mul(ASTEval::var(0), ASTEval::var(1)), ),
),
ASTEval::mul(
ASTEval::constant(3.0),
ASTEval::mul(ASTEval::var(0), ASTEval::var(2)), ),
),
ASTEval::mul(ASTEval::var(1), ASTEval::var(2)), );
println!("Function: f(x,y,z) = x² + y² + z² + 2xy + 3xz + yz");
println!("Using index-based variables: x=var(0), y=var(1), z=var(2)");
println!("Expected gradient:");
println!(" ∂f/∂x = 2x + 2y + 3z");
println!(" ∂f/∂y = 2y + 2x + z");
println!(" ∂f/∂z = 2z + 3x + y");
let gradient = convenience::gradient(&multivar_func, &["0", "1", "2"])?;
let test_point = [1.0, 2.0, 3.0];
let f_val = multivar_func.eval_with_vars(&test_point);
println!(
"\nAt point ({}, {}, {}):",
test_point[0], test_point[1], test_point[2]
);
let df_dx_val = gradient["0"].eval_with_vars(&test_point);
let df_dy_val = gradient["1"].eval_with_vars(&test_point);
let df_dz_val = gradient["2"].eval_with_vars(&test_point);
println!(" ∂f/∂x = {df_dx_val:.3}");
println!(" ∂f/∂y = {df_dy_val:.3}");
println!(" ∂f/∂z = {df_dz_val:.3}");
println!(" Expected ∂f/∂x = 15.0, ∂f/∂y = 9.0, ∂f/∂z = 11.0");
println!();
println!("2️⃣ Machine Learning Loss Function Gradients");
println!("--------------------------------------------");
let x_input = 2.0; let y_target = 5.0;
let prediction = ASTEval::add(
ASTEval::mul(ASTEval::var(0), ASTEval::constant(x_input)), ASTEval::var(1), );
let error = ASTEval::sub(ASTEval::constant(y_target), prediction);
let mse_loss = ASTEval::pow(error, ASTEval::constant(2.0));
println!("MSE Loss: L(w,b) = (y - (wx + b))²");
println!("With x = {x_input}, y = {y_target}");
println!("Using index-based variables: w=var(0), b=var(1)");
println!("L(w,b) = ({y_target} - (w*{x_input} + b))²");
println!("Expected gradients:");
println!(" ∂L/∂w = -2x(y - (wx + b)) = -2*{x_input}*({y_target} - (w*{x_input} + b))");
println!(" ∂L/∂b = -2(y - (wx + b)) = -2*({y_target} - (w*{x_input} + b))");
let mse_gradient = convenience::gradient(&mse_loss, &["0", "1"])?;
let wb_vals = [1.0, 0.5];
let loss_val = mse_loss.eval_with_vars(&wb_vals);
let dl_dw = mse_gradient["0"].eval_with_vars(&wb_vals);
let dl_db = mse_gradient["1"].eval_with_vars(&wb_vals);
println!("\nAt w = {}, b = {}:", wb_vals[0], wb_vals[1]);
println!(" Loss = {loss_val:.3}");
println!(" ∂L/∂w = {dl_dw:.3}");
println!(" ∂L/∂b = {dl_db:.3}");
println!(" Expected Loss = 6.25");
println!(" Expected ∂L/∂w = -10.0");
println!(" Expected ∂L/∂b = -5.0");
println!();
println!("3️⃣ Optimization Problem: Rosenbrock Function");
println!("---------------------------------------------");
let a = 1.0;
let b = 100.0;
let term1 = ASTEval::pow(
ASTEval::sub(ASTEval::constant(a), ASTEval::var(0)), ASTEval::constant(2.0),
);
let x_squared = ASTEval::pow(ASTEval::var(0), ASTEval::constant(2.0)); let term2 = ASTEval::mul(
ASTEval::constant(b),
ASTEval::pow(
ASTEval::sub(ASTEval::var(1), x_squared), ASTEval::constant(2.0),
),
);
let rosenbrock = ASTEval::add(term1, term2);
println!("Rosenbrock function: f(x,y) = (1-x)² + 100(y-x²)²");
println!("Using index-based variables: x=var(0), y=var(1)");
println!("This is a classic optimization test function with global minimum at (1,1)");
println!("Expected gradients:");
println!(" ∂f/∂x = -2(1-x) + 100*2(y-x²)*(-2x) = -2(1-x) - 400x(y-x²)");
println!(" ∂f/∂y = 100*2(y-x²) = 200(y-x²)");
let rosenbrock_grad = convenience::gradient(&rosenbrock, &["0", "1"])?;
let test_points = [[0.0, 0.0], [0.5, 0.25], [1.0, 1.0], [1.5, 2.0]];
for point in test_points {
let f_val = rosenbrock.eval_with_vars(&point);
let df_dx = rosenbrock_grad["0"].eval_with_vars(&point);
let df_dy = rosenbrock_grad["1"].eval_with_vars(&point);
println!("\nAt ({:.1}, {:.2}):", point[0], point[1]);
println!(" f = {f_val:.3}");
println!(" ∇f = [{df_dx:.3}, {df_dy:.3}]");
let grad_magnitude = (df_dx * df_dx + df_dy * df_dy).sqrt();
println!(" |∇f| = {grad_magnitude:.3}");
if grad_magnitude < 0.1 {
println!(" → Near critical point! 🎯");
}
}
println!();
println!("4️⃣ Logistic Regression Gradient");
println!("--------------------------------");
let x_data = 1.5;
let y_label = 1.0;
let linear_output = ASTEval::add(
ASTEval::mul(ASTEval::var(0), ASTEval::constant(x_data)), ASTEval::var(1), );
let logistic_loss = ASTEval::pow(
ASTEval::sub(linear_output, ASTEval::constant(y_label)),
ASTEval::constant(2.0),
);
println!("Simplified logistic loss: L(w,b) = (wx + b - y)²");
println!("Using index-based variables: w=var(0), b=var(1)");
println!("With x = {x_data}, y = {y_label}");
let logistic_grad = convenience::gradient(&logistic_loss, &["0", "1"])?;
let wb_test = [0.8, 0.2];
let loss_val = logistic_loss.eval_with_vars(&wb_test);
let dl_dw = logistic_grad["0"].eval_with_vars(&wb_test);
let dl_db = logistic_grad["1"].eval_with_vars(&wb_test);
println!("\nAt w = {}, b = {}:", wb_test[0], wb_test[1]);
println!(" Loss = {loss_val:.3}");
println!(" ∂L/∂w = {dl_dw:.3}");
println!(" ∂L/∂b = {dl_db:.3}");
println!(" Expected Loss ≈ 0.16");
println!(" Expected ∂L/∂w ≈ 1.2");
println!(" Expected ∂L/∂b ≈ 0.8");
println!();
println!("5️⃣ Higher-Dimensional Example");
println!("------------------------------");
let high_dim_func = ASTEval::add(
ASTEval::add(
ASTEval::add(
ASTEval::mul(ASTEval::var(0), ASTEval::var(1)), ASTEval::mul(ASTEval::var(1), ASTEval::var(2)), ),
ASTEval::mul(ASTEval::var(2), ASTEval::var(3)), ),
ASTEval::mul(ASTEval::var(3), ASTEval::var(0)), );
println!("Function: f(x₁,x₂,x₃,x₄) = x₁x₂ + x₂x₃ + x₃x₄ + x₄x₁");
println!("Using index-based variables: x₁=var(0), x₂=var(1), x₃=var(2), x₄=var(3)");
println!("Expected gradient:");
println!(" ∂f/∂x₁ = x₂ + x₄");
println!(" ∂f/∂x₂ = x₁ + x₃");
println!(" ∂f/∂x₃ = x₂ + x₄");
println!(" ∂f/∂x₄ = x₃ + x₁");
let high_dim_grad = convenience::gradient(&high_dim_func, &["0", "1", "2", "3"])?;
let test_4d = [1.0, 2.0, 3.0, 4.0];
let f_val = high_dim_func.eval_with_vars(&test_4d);
println!(
"\nAt point ({}, {}, {}, {}):",
test_4d[0], test_4d[1], test_4d[2], test_4d[3]
);
println!(" f = {f_val:.3}");
for i in 0..4 {
let grad_val = high_dim_grad[&i.to_string()].eval_with_vars(&test_4d);
println!(" ∂f/∂x₊{} = {grad_val:.3}", i + 1);
}
println!(" Expected: [6, 4, 6, 4]");
println!();
println!("=== Demo Complete ===");
println!("✅ Successfully demonstrated index-based gradient computation");
println!("✅ All gradient calculations use modern variable indexing");
println!("✅ No string-based variable lookups required");
Ok(())
}