mathcompile 0.1.2

High-performance symbolic mathematics with final tagless design, egglog optimization, and Rust hot-loading compilation
Documentation
//! Basic usage example for `MathCompile`
//!
//! This example demonstrates both the traditional final tagless approach and the new
//! ergonomic `MathBuilder` API:
//! - `DirectEval`: Immediate evaluation
//! - `PrettyPrint`: String representation
//! - `MathBuilder`: Ergonomic expression building with operator overloading

use mathcompile::prelude::*;
use mathcompile::{DirectEval, PrettyPrint, StatisticalExpr};

/// Builds the polynomial `2x² + 3x + 1` through the final tagless `MathExpr` interface.
///
/// Because it is generic over the interpreter `E`, the same definition can be
/// evaluated numerically or rendered as text (as done in `main` with `DirectEval`
/// and `PrettyPrint`). `x` is cloned once since it appears in both the squared
/// term and the linear term.
fn quadratic_traditional<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
where
    E::Repr<f64>: Clone,
{
    // 2 * x^2
    let x_squared = E::pow(x.clone(), E::constant(2.0));
    let square_term = E::mul(E::constant(2.0), x_squared);

    // 3 * x
    let linear_term = E::mul(E::constant(3.0), x);

    // (2x² + 3x) + 1 — same grouping as a left-to-right sum
    let without_offset = E::add(square_term, linear_term);
    E::add(without_offset, E::constant(1.0))
}

/// Logistic-regression style expression: applies `E::logistic` to `theta * x`.
///
/// Generic over any interpreter implementing the `StatisticalExpr` extension, so
/// the same model can be evaluated or pretty-printed.
fn logistic_regression<E: StatisticalExpr>(x: E::Repr<f64>, theta: E::Repr<f64>) -> E::Repr<f64> {
    // Linear predictor: the parameter scales the input before the link function.
    let linear_predictor = E::mul(theta, x);
    E::logistic(linear_predictor)
}

/// Runs the walkthrough: final tagless interpreters, the `MathBuilder` API,
/// statistical helpers, built-in convenience functions, and optimization.
///
/// Each section cross-checks its numeric results, either against another
/// evaluation path or against a plain-Rust reference value, using
/// tolerance-based comparisons.
///
/// # Errors
///
/// Propagates failures from `math.math_constant("pi")`,
/// `MathBuilder::with_optimization()`, and `math.optimize(..)`.
///
/// # Panics
///
/// Panics if any consistency check between evaluation paths fails.
fn main() -> Result<()> {
    // Absolute tolerance for results that should agree up to rounding. Exact `==`
    // is brittle when two evaluation paths (e.g. `pow` vs repeated `*`, or an
    // optimizer rewrite) may round differently.
    const TOLERANCE: f64 = 1e-10;
    // Looser tolerance for the Gaussian checks, in line with the 6-decimal output.
    const GAUSSIAN_TOLERANCE: f64 = 1e-6;

    // `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    println!("=== MathCompile Basic Usage Example ===\n");

    // 1. Traditional Final Tagless Approach
    println!("1. Traditional Final Tagless Approach:");
    let x_val: f64 = 2.0;
    // Plain-Rust reference for 2x² + 3x + 1, so the checks below remain meaningful
    // even if `x_val` is changed.
    let expected_quadratic = 2.0 * x_val * x_val + 3.0 * x_val + 1.0;
    let result_traditional = quadratic_traditional::<DirectEval>(DirectEval::var("x", x_val));
    println!("   quadratic({x_val}) = {result_traditional}");
    println!("   Expected: 2(4) + 3(2) + 1 = 15");
    assert!(
        approx_eq(result_traditional, expected_quadratic, TOLERANCE),
        "traditional quadratic({x_val}) = {result_traditional}, expected {expected_quadratic}"
    );

    let pretty_traditional = quadratic_traditional::<PrettyPrint>(PrettyPrint::var("x"));
    println!("   Expression: {pretty_traditional}\n");

    // 2. Modern MathBuilder Approach with Operator Overloading
    println!("2. Modern MathBuilder Approach:");
    let mut math = MathBuilder::new();
    let x = math.var("x");

    // Natural mathematical syntax using operator overloading!
    let quadratic_modern =
        math.constant(2.0) * &x * &x + math.constant(3.0) * &x + math.constant(1.0);

    let result_modern = math.eval(&quadratic_modern, &[("x", x_val)]);
    println!("   quadratic({x_val}) = {result_modern}");
    println!("   Expected: 2(4) + 3(2) + 1 = 15");

    // Verify both approaches agree with each other and with the reference value.
    assert!(
        approx_eq(result_modern, expected_quadratic, TOLERANCE),
        "MathBuilder quadratic({x_val}) = {result_modern}, expected {expected_quadratic}"
    );
    assert!(
        approx_eq(result_traditional, result_modern, TOLERANCE),
        "traditional ({result_traditional}) and MathBuilder ({result_modern}) disagree"
    );
    println!("   ✓ Both approaches produce identical results!\n");

    // 3. Statistical Functions
    println!("3. Statistical Functions:");
    let theta_val = 1.5;
    let logistic_result = logistic_regression::<DirectEval>(
        DirectEval::var("x", x_val),
        DirectEval::var("theta", theta_val),
    );
    println!("   logistic_regression({x_val}, {theta_val}) = {logistic_result}");

    let logistic_pretty =
        logistic_regression::<PrettyPrint>(PrettyPrint::var("x"), PrettyPrint::var("theta"));
    println!("   Expression: {logistic_pretty}");

    // Using MathBuilder for logistic regression
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let theta = math.var("theta");
    let logistic_modern = math.logistic(&(theta * &x));
    let logistic_result_modern = math.eval(&logistic_modern, &[("x", x_val), ("theta", theta_val)]);
    println!("   MathBuilder logistic({x_val}, {theta_val}) = {logistic_result_modern}");
    assert!(
        approx_eq(logistic_result, logistic_result_modern, TOLERANCE),
        "logistic mismatch: {logistic_result} vs {logistic_result_modern}"
    );
    println!("   ✓ Traditional and MathBuilder approaches match!\n");

    // 4. Complex Mathematical Expressions with Natural Syntax
    println!("4. Complex Expressions with Natural Syntax:");

    // Gaussian function: exp(-x²/2) / sqrt(2π) using MathBuilder
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let pi = math.math_constant("pi")?;

    let gaussian =
        math.exp(&(-(&x * &x) / math.constant(2.0))) / math.sqrt(&(math.constant(2.0) * &pi));

    let gaussian_result = math.eval(&gaussian, &[("x", 0.0)]);
    println!("   gaussian(0.0) = {gaussian_result:.6}");
    println!("   Expected: ~0.398942 (1/sqrt(2π))");
    // Derive the reference from the standard-library π instead of a truncated literal.
    let expected_gaussian_peak = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
    assert!(
        approx_eq(gaussian_result, expected_gaussian_peak, GAUSSIAN_TOLERANCE),
        "gaussian(0.0) = {gaussian_result}, expected {expected_gaussian_peak}"
    );
    println!("   ✓ Gaussian calculation correct!\n");

    // 5. High-Level Mathematical Functions
    println!("5. High-Level Mathematical Functions:");

    let mut math = MathBuilder::new();
    let x = math.var("x");

    // Polynomial using convenience function
    let poly = math.poly(&[1.0, 2.0, 3.0], &x); // 1 + 2x + 3x² (coefficients in ascending order)
    let poly_result = math.eval(&poly, &[("x", 2.0)]);
    println!("   polynomial 1 + 2x + 3x² at x=2: {poly_result}");
    println!("   Expected: 1 + 2(2) + 3(4) = 1 + 4 + 12 = 17");
    assert!(
        approx_eq(poly_result, 17.0, TOLERANCE),
        "polynomial at x=2 = {poly_result}, expected 17"
    );

    // Quadratic using convenience function
    let quad = math.quadratic(3.0, 2.0, 1.0, &x); // 3x² + 2x + 1
    let quad_result = math.eval(&quad, &[("x", 2.0)]);
    println!("   quadratic 3x² + 2x + 1 at x=2: {quad_result}");
    println!("   Expected: 3(4) + 2(2) + 1 = 12 + 4 + 1 = 17");
    assert!(
        approx_eq(poly_result, quad_result, TOLERANCE),
        "polynomial ({poly_result}) and quadratic ({quad_result}) disagree"
    );
    println!("   ✓ Polynomial and quadratic functions match!");

    // Gaussian using convenience function
    let gaussian_builtin = math.gaussian(0.0, 1.0, &x); // Standard normal
    let gaussian_builtin_result = math.eval(&gaussian_builtin, &[("x", 0.0)]);
    println!("   Built-in gaussian(0.0) = {gaussian_builtin_result:.6}");
    assert!(
        approx_eq(gaussian_result, gaussian_builtin_result, GAUSSIAN_TOLERANCE),
        "manual ({gaussian_result}) and built-in ({gaussian_builtin_result}) Gaussian disagree"
    );
    println!("   ✓ Manual and built-in Gaussian match!\n");

    // 6. Expression Validation and Optimization
    println!("6. Expression Validation and Optimization:");

    let mut math = MathBuilder::with_optimization()?;
    let x = math.var("x");

    // Create an expression that can be optimized
    let unoptimized = &x * math.constant(0.0) + &x * math.constant(1.0); // x*0 + x*1 = x
    println!("   Original: x*0 + x*1");

    // Validate the expression; a failure is reported but deliberately not fatal here.
    match math.validate(&unoptimized) {
        Ok(()) => println!("   ✓ Expression is valid"),
        Err(e) => println!("   ✗ Validation error: {e}"),
    }

    // Optimize the expression
    let optimized = math.optimize(&unoptimized)?;

    // Test that both give the same result (up to rounding: rewrites may reorder ops)
    let original_result = math.eval(&unoptimized, &[("x", 5.0)]);
    let optimized_result = math.eval(&optimized, &[("x", 5.0)]);

    println!("   Original result at x=5: {original_result}");
    println!("   Optimized result at x=5: {optimized_result}");
    assert!(
        approx_eq(original_result, optimized_result, TOLERANCE),
        "optimization changed the result: {original_result} vs {optimized_result}"
    );
    println!("   ✓ Optimization preserves correctness!");

    println!("\n=== Key Benefits of MathBuilder API ===");
    println!("✓ Natural mathematical syntax with operator overloading");
    println!("✓ Automatic variable management");
    println!("✓ Built-in mathematical functions and constants");
    println!("✓ Expression validation and optimization");
    println!("✓ Named variable evaluation");
    println!("✓ Type safety and helpful error messages");

    println!("\n=== Example Complete ===");
    Ok(())
}