mathcompile 0.1.2

High-performance symbolic mathematics with final tagless design, egglog optimization, and Rust hot-loading compilation
Documentation
//! Gradient Computation Demo
//!
//! This example demonstrates comprehensive gradient computation capabilities
//! of the `MathCompile` symbolic AD system, including:
//! - Multivariate function gradients
//! - Machine learning loss function gradients
//! - Optimization problem gradients
//! - Higher-dimensional gradient examples

use mathcompile::final_tagless::{ASTEval, ASTMathExpr, DirectEval};
use mathcompile::symbolic_ad::convenience;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("🎯 MathCompile: Comprehensive Gradient Computation Demo");
    println!("===================================================\n");

    // 1. Basic Multivariate Gradients
    println!("1️⃣  Basic Multivariate Gradients");
    println!("--------------------------------");

    // f(x,y,z) = x² + y² + z² + 2xy + 3xz + yz
    let multivar_func = ASTEval::add(
        ASTEval::add(
            ASTEval::add(
                ASTEval::add(
                    ASTEval::add(
                        ASTEval::pow(ASTEval::var_by_name("x"), ASTEval::constant(2.0)),
                        ASTEval::pow(ASTEval::var_by_name("y"), ASTEval::constant(2.0)),
                    ),
                    ASTEval::pow(ASTEval::var_by_name("z"), ASTEval::constant(2.0)),
                ),
                ASTEval::mul(
                    ASTEval::constant(2.0),
                    ASTEval::mul(ASTEval::var_by_name("x"), ASTEval::var_by_name("y")),
                ),
            ),
            ASTEval::mul(
                ASTEval::constant(3.0),
                ASTEval::mul(ASTEval::var_by_name("x"), ASTEval::var_by_name("z")),
            ),
        ),
        ASTEval::mul(ASTEval::var_by_name("y"), ASTEval::var_by_name("z")),
    );

    println!("Function: f(x,y,z) = x² + y² + z² + 2xy + 3xz + yz");
    println!("Expected gradient:");
    println!("  ∂f/∂x = 2x + 2y + 3z");
    println!("  ∂f/∂y = 2y + 2x + z");
    println!("  ∂f/∂z = 2z + 3x + y");

    let gradient = convenience::gradient(&multivar_func, &["x", "y", "z"])?;

    // Test at point (1, 2, 3)
    let x_val = 1.0;
    let y_val = 2.0;
    let z_val = 3.0;

    let f_val = DirectEval::eval_two_vars(&multivar_func, x_val, y_val); // Note: eval_two_vars only handles x,y
    println!("\nAt point ({x_val}, {y_val}, {z_val}):");

    // For now, we'll evaluate at (x,y) = (1,2) and treat z as a parameter
    // This is a limitation of the current DirectEval::eval_two_vars function
    let df_dx_val = DirectEval::eval_two_vars(&gradient["x"], x_val, y_val);
    let df_dy_val = DirectEval::eval_two_vars(&gradient["y"], x_val, y_val);

    println!("  ∂f/∂x = {df_dx_val:.3}");
    println!("  ∂f/∂y = {df_dy_val:.3}");

    // Expected: ∂f/∂x = 2(1) + 2(2) + 3(3) = 2 + 4 + 9 = 15
    // Expected: ∂f/∂y = 2(2) + 2(1) + 3 = 4 + 2 + 3 = 9
    println!("  Expected ∂f/∂x ≈ 15.0 (with z=3)");
    println!("  Expected ∂f/∂y ≈ 9.0 (with z=3)");
    println!();

    // 2. Machine Learning Loss Function Gradients
    println!("2️⃣  Machine Learning Loss Function Gradients");
    println!("--------------------------------------------");

    // Mean Squared Error: L(w,b) = (y - (wx + b))²
    // where y is target, x is input, w is weight, b is bias
    let x_input = 2.0; // Input value
    let y_target = 5.0; // Target value

    // Create the loss function: L(w,b) = (5 - (w*2 + b))²
    let prediction = ASTEval::add(
        ASTEval::mul(ASTEval::var_by_name("w"), ASTEval::constant(x_input)),
        ASTEval::var_by_name("b"),
    );
    let error = ASTEval::sub(ASTEval::constant(y_target), prediction);
    let mse_loss = ASTEval::pow(error, ASTEval::constant(2.0));

    println!("MSE Loss: L(w,b) = (y - (wx + b))²");
    println!("With x = {x_input}, y = {y_target}");
    println!("L(w,b) = ({y_target} - (w*{x_input} + b))²");
    println!("Expected gradients:");
    println!("  ∂L/∂w = -2x(y - (wx + b)) = -2*{x_input}*({y_target} - (w*{x_input} + b))");
    println!("  ∂L/∂b = -2(y - (wx + b)) = -2*({y_target} - (w*{x_input} + b))");

    let mse_gradient = convenience::gradient(&mse_loss, &["w", "b"])?;

    // Test at w=1.0, b=0.5
    let w_val = 1.0;
    let b_val = 0.5;

    let loss_val = DirectEval::eval_two_vars(&mse_loss, w_val, b_val);
    let dl_dw = DirectEval::eval_two_vars(&mse_gradient["w"], w_val, b_val);
    let dl_db = DirectEval::eval_two_vars(&mse_gradient["b"], w_val, b_val);

    println!("\nAt w = {w_val}, b = {b_val}:");
    println!("  Loss = {loss_val:.3}");
    println!("  ∂L/∂w = {dl_dw:.3}");
    println!("  ∂L/∂b = {dl_db:.3}");

    // Manual calculation: prediction = 1.0*2.0 + 0.5 = 2.5, error = 5.0 - 2.5 = 2.5
    // Loss = 2.5² = 6.25
    // ∂L/∂w = -2*2.0*2.5 = -10.0
    // ∂L/∂b = -2*2.5 = -5.0
    println!("  Expected Loss = 6.25");
    println!("  Expected ∂L/∂w = -10.0");
    println!("  Expected ∂L/∂b = -5.0");
    println!();

    // 3. Optimization Problem: Rosenbrock Function
    println!("3️⃣  Optimization Problem: Rosenbrock Function");
    println!("---------------------------------------------");

    // Rosenbrock function: f(x,y) = (a-x)² + b(y-x²)²
    // Standard form: f(x,y) = (1-x)² + 100(y-x²)²
    let a = 1.0;
    let b = 100.0;

    let term1 = ASTEval::pow(
        ASTEval::sub(ASTEval::constant(a), ASTEval::var_by_name("x")),
        ASTEval::constant(2.0),
    );
    let x_squared = ASTEval::pow(ASTEval::var_by_name("x"), ASTEval::constant(2.0));
    let term2 = ASTEval::mul(
        ASTEval::constant(b),
        ASTEval::pow(
            ASTEval::sub(ASTEval::var_by_name("y"), x_squared),
            ASTEval::constant(2.0),
        ),
    );
    let rosenbrock = ASTEval::add(term1, term2);

    println!("Rosenbrock function: f(x,y) = (1-x)² + 100(y-x²)²");
    println!("This is a classic optimization test function with global minimum at (1,1)");
    println!("Expected gradients:");
    println!("  ∂f/∂x = -2(1-x) + 100*2(y-x²)*(-2x) = -2(1-x) - 400x(y-x²)");
    println!("  ∂f/∂y = 100*2(y-x²) = 200(y-x²)");

    let rosenbrock_grad = convenience::gradient(&rosenbrock, &["x", "y"])?;

    // Test at several points
    let test_points = [(0.0, 0.0), (0.5, 0.25), (1.0, 1.0), (1.5, 2.0)];

    for (x_test, y_test) in test_points {
        let f_val = DirectEval::eval_two_vars(&rosenbrock, x_test, y_test);
        let df_dx = DirectEval::eval_two_vars(&rosenbrock_grad["x"], x_test, y_test);
        let df_dy = DirectEval::eval_two_vars(&rosenbrock_grad["y"], x_test, y_test);

        println!("\nAt ({x_test:.1}, {y_test:.2}):");
        println!("  f = {f_val:.3}");
        println!("  ∇f = [{df_dx:.3}, {df_dy:.3}]");

        // Check if we're at the minimum (gradient should be near zero)
        let grad_magnitude = (df_dx * df_dx + df_dy * df_dy).sqrt();
        println!("  |∇f| = {grad_magnitude:.3}");

        if grad_magnitude < 0.1 {
            println!("  → Near critical point! 🎯");
        }
    }
    println!();

    // 4. Logistic Regression Gradient
    println!("4️⃣  Logistic Regression Gradient");
    println!("--------------------------------");

    // Logistic loss: L(w,b) = -y*log(σ(wx+b)) - (1-y)*log(1-σ(wx+b))
    // where σ(z) = 1/(1+exp(-z)) is the sigmoid function
    // For simplicity, we'll use a linear approximation or focus on the linear part

    // Simplified version: L(w,b) = (σ(wx+b) - y)² where σ(z) ≈ z for small z
    let x_data = 1.5;
    let y_label = 1.0;

    let linear_output = ASTEval::add(
        ASTEval::mul(ASTEval::var_by_name("w"), ASTEval::constant(x_data)),
        ASTEval::var_by_name("b"),
    );

    // For demonstration, use a quadratic loss: (wx + b - y)²
    let logistic_loss = ASTEval::pow(
        ASTEval::sub(linear_output, ASTEval::constant(y_label)),
        ASTEval::constant(2.0),
    );

    println!("Simplified logistic loss: L(w,b) = (wx + b - y)²");
    println!("With x = {x_data}, y = {y_label}");
    println!("Expected gradients:");
    println!("  ∂L/∂w = 2x(wx + b - y)");
    println!("  ∂L/∂b = 2(wx + b - y)");

    let logistic_grad = convenience::gradient(&logistic_loss, &["w", "b"])?;

    let w_val = 0.8;
    let b_val = 0.2;

    let loss_val = DirectEval::eval_two_vars(&logistic_loss, w_val, b_val);
    let dl_dw = DirectEval::eval_two_vars(&logistic_grad["w"], w_val, b_val);
    let dl_db = DirectEval::eval_two_vars(&logistic_grad["b"], w_val, b_val);

    println!("\nAt w = {w_val}, b = {b_val}:");
    println!("  Loss = {loss_val:.3}");
    println!("  ∂L/∂w = {dl_dw:.3}");
    println!("  ∂L/∂b = {dl_db:.3}");

    // Manual: prediction = 0.8*1.5 + 0.2 = 1.4, error = 1.4 - 1.0 = 0.4
    // Loss = 0.4² = 0.16
    // ∂L/∂w = 2*1.5*0.4 = 1.2
    // ∂L/∂b = 2*0.4 = 0.8
    println!("  Expected Loss = 0.16");
    println!("  Expected ∂L/∂w = 1.2");
    println!("  Expected ∂L/∂b = 0.8");
    println!();

    // 5. Performance Analysis
    println!("5️⃣  Gradient Computation Performance");
    println!("------------------------------------");

    // Test gradient computation timing for different numbers of variables
    let dimensions = [2, 3, 5, 8];

    for &dim in &dimensions {
        // Create a polynomial with `dim` variables
        let mut poly = ASTEval::constant(0.0);
        let mut var_names = Vec::new();

        for i in 0..dim {
            let var_name = format!("x{i}");
            var_names.push(var_name.clone());

            // Add x_i² term
            poly = ASTEval::add(
                poly,
                ASTEval::pow(ASTEval::var_by_name(&var_name), ASTEval::constant(2.0)),
            );

            // Add cross terms x_i * x_j for j > i
            for j in (i + 1)..dim {
                let var_j = format!("x{j}");
                poly = ASTEval::add(
                    poly,
                    ASTEval::mul(
                        ASTEval::var_by_name(&var_name),
                        ASTEval::var_by_name(&var_j),
                    ),
                );
            }
        }

        let var_refs: Vec<&str> = var_names.iter().map(std::string::String::as_str).collect();

        let start_time = std::time::Instant::now();
        let grad_result = convenience::gradient(&poly, &var_refs);
        let computation_time = start_time.elapsed();

        match grad_result {
            Ok(grad) => {
                println!(
                    "  {dim}D gradient: {} variables, {} μs",
                    grad.len(),
                    computation_time.as_micros()
                );
            }
            Err(e) => {
                println!("  {dim}D gradient: Error - {e}");
            }
        }
    }
    println!();

    // 6. Summary
    println!("6️⃣  Gradient Capabilities Summary");
    println!("---------------------------------");
    println!("✅ Multivariate function gradients (∇f for f: ℝⁿ → ℝ)");
    println!("✅ Machine learning loss function gradients");
    println!("✅ Optimization problem gradients (Rosenbrock, etc.)");
    println!("✅ Symbolic computation (exact derivatives)");
    println!("✅ Arbitrary number of variables");
    println!("✅ Integration with optimization pipeline");
    println!("✅ Caching for repeated computations");
    println!();

    println!("🎯 Perfect for:");
    println!("• Gradient descent optimization");
    println!("• Machine learning backpropagation");
    println!("• Scientific computing");
    println!("• Numerical optimization algorithms");
    println!("• Sensitivity analysis");

    Ok(())
}