use mathcompile::prelude::*;
use mathcompile::{DirectEval, PrettyPrint, StatisticalExpr};
/// Builds the quadratic `2x² + 3x + 1` from the generic `MathExpr` primitives.
///
/// The body only uses `MathExpr` operations, so the same definition can be handed
/// to any interpreter `E`. `main` runs it through `DirectEval` to get a number and
/// through `PrettyPrint` to get a rendered expression.
///
/// `x` appears twice (in the squared term and in the linear term), which is why
/// `E::Repr<f64>` must be `Clone`.
fn quadratic_traditional<E: MathExpr>(x: E::Repr<f64>) -> E::Repr<f64>
where
    E::Repr<f64>: Clone,
{
    // Coefficients, created in the order: x² coefficient, x coefficient, constant.
    let (coeff_sq, coeff_lin, intercept) =
        (E::constant(2.0), E::constant(3.0), E::constant(1.0));
    let x_squared = E::pow(x.clone(), E::constant(2.0));
    let square_term = E::mul(coeff_sq, x_squared);
    let linear_term = E::mul(coeff_lin, x);
    E::add(E::add(square_term, linear_term), intercept)
}
/// Single-feature logistic-regression model: the interpreter's `logistic`
/// operation applied to the linear predictor `theta * x`.
///
/// The function is generic over any `StatisticalExpr` interpreter. `main` uses
/// this to evaluate it with `DirectEval` and render it with `PrettyPrint`.
fn logistic_regression<E: StatisticalExpr>(x: E::Repr<f64>, theta: E::Repr<f64>) -> E::Repr<f64> {
    // Note the operand order: theta first, then x.
    let linear_predictor = E::mul(theta, x);
    E::logistic(linear_predictor)
}
/// Absolute tolerance for results that should agree up to floating-point
/// rounding. This is the same tolerance section 3 already used.
const EPSILON: f64 = 1e-10;

/// Returns `true` when `a` and `b` differ by strictly less than `tol`.
///
/// Results produced by different expression shapes (e.g. `pow(x, 2)` vs. `x * x`,
/// or an expression before and after optimization) are not guaranteed to be
/// bit-identical. The example therefore compares them with a tolerance rather
/// than `assert_eq!`.
fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
    (a - b).abs() < tol
}

/// Runs the MathCompile tour.
///
/// The sections cover final-tagless interpreters, the `MathBuilder` API,
/// statistical and high-level built-in functions, and expression
/// validation/optimization. Each section prints its results and asserts them
/// against the expected values.
///
/// # Errors
///
/// Returns any error produced by `math_constant("pi")`,
/// `MathBuilder::with_optimization`, or `optimize`.
///
/// # Panics
///
/// Panics if a demonstrated result disagrees with its expected value.
fn main() -> Result<()> {
    println!("=== MathCompile Basic Usage Example ===\n");

    // 1. One generic definition, interpreted by two different backends.
    println!("1. Traditional Final Tagless Approach:");
    let x_val = 2.0;
    let result_traditional = quadratic_traditional::<DirectEval>(DirectEval::var("x", x_val));
    println!(" quadratic({x_val}) = {result_traditional}");
    println!(" Expected: 2(4) + 3(2) + 1 = 15");
    // Actually check the claim printed above (closed form of 2x² + 3x + 1).
    let expected_quadratic = 2.0 * x_val * x_val + 3.0 * x_val + 1.0;
    assert!(
        approx_eq(result_traditional, expected_quadratic, EPSILON),
        "traditional quadratic gave {result_traditional}, expected {expected_quadratic}"
    );
    let pretty_traditional = quadratic_traditional::<PrettyPrint>(PrettyPrint::var("x"));
    println!(" Expression: {pretty_traditional}\n");

    // 2. The same polynomial written with operator overloading.
    println!("2. Modern MathBuilder Approach:");
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let quadratic_modern =
        math.constant(2.0) * &x * &x + math.constant(3.0) * &x + math.constant(1.0);
    let result_modern = math.eval(&quadratic_modern, &[("x", x_val)]);
    println!(" quadratic({x_val}) = {result_modern}");
    println!(" Expected: 2(4) + 3(2) + 1 = 15");
    // `pow(x, 2)` and `x * x` need not round identically, so compare with a tolerance.
    assert!(
        approx_eq(result_traditional, result_modern, EPSILON),
        "traditional ({result_traditional}) and MathBuilder ({result_modern}) results differ"
    );
    println!(" ✓ Both approaches produce identical results!\n");

    // 3. Statistical functions, again through both APIs.
    println!("3. Statistical Functions:");
    let theta_val = 1.5;
    let logistic_result = logistic_regression::<DirectEval>(
        DirectEval::var("x", x_val),
        DirectEval::var("theta", theta_val),
    );
    println!(" logistic_regression({x_val}, {theta_val}) = {logistic_result}");
    let logistic_pretty =
        logistic_regression::<PrettyPrint>(PrettyPrint::var("x"), PrettyPrint::var("theta"));
    println!(" Expression: {logistic_pretty}");
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let theta = math.var("theta");
    let logistic_modern = math.logistic(&(theta * &x));
    let logistic_result_modern =
        math.eval(&logistic_modern, &[("x", x_val), ("theta", theta_val)]);
    println!(" MathBuilder logistic({x_val}, {theta_val}) = {logistic_result_modern}");
    assert!(
        approx_eq(logistic_result, logistic_result_modern, EPSILON),
        "traditional ({logistic_result}) and MathBuilder ({logistic_result_modern}) logistic differ"
    );
    println!(" ✓ Traditional and MathBuilder approaches match!\n");

    // 4. A standard normal density written out by hand.
    println!("4. Complex Expressions with Natural Syntax:");
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let pi = math.math_constant("pi")?;
    let gaussian =
        math.exp(&(-(&x * &x) / math.constant(2.0))) / math.sqrt(&(math.constant(2.0) * &pi));
    let gaussian_result = math.eval(&gaussian, &[("x", 0.0)]);
    println!(" gaussian(0.0) = {gaussian_result:.6}");
    println!(" Expected: ~0.398942 (1/sqrt(2π))");
    // The reference value is only given to 6 digits, hence the loose tolerance.
    assert!(
        approx_eq(gaussian_result, 0.398942, 0.001),
        "manual gaussian(0.0) gave {gaussian_result}, expected ~0.398942"
    );
    println!(" ✓ Gaussian calculation correct!\n");

    // 5. Built-in helpers: polynomial, quadratic and Gaussian.
    println!("5. High-Level Mathematical Functions:");
    let mut math = MathBuilder::new();
    let x = math.var("x");

    let poly = math.poly(&[1.0, 2.0, 3.0], &x);
    let poly_result = math.eval(&poly, &[("x", 2.0)]);
    println!(" polynomial 1 + 2x + 3x² at x=2: {poly_result}");
    println!(" Expected: 1 + 2(2) + 3(4) = 1 + 4 + 12 = 17");
    assert!(
        approx_eq(poly_result, 17.0, EPSILON),
        "polynomial gave {poly_result}, expected 17"
    );

    let quad = math.quadratic(3.0, 2.0, 1.0, &x);
    let quad_result = math.eval(&quad, &[("x", 2.0)]);
    println!(" quadratic 3x² + 2x + 1 at x=2: {quad_result}");
    println!(" Expected: 3(4) + 2(2) + 1 = 12 + 4 + 1 = 17");
    assert!(
        approx_eq(poly_result, quad_result, EPSILON),
        "polynomial ({poly_result}) and quadratic ({quad_result}) differ"
    );
    println!(" ✓ Polynomial and quadratic functions match!");

    let gaussian_builtin = math.gaussian(0.0, 1.0, &x);
    let gaussian_builtin_result = math.eval(&gaussian_builtin, &[("x", 0.0)]);
    println!(" Built-in gaussian(0.0) = {gaussian_builtin_result:.6}");
    assert!(
        approx_eq(gaussian_result, gaussian_builtin_result, 1e-6),
        "manual ({gaussian_result}) and built-in ({gaussian_builtin_result}) gaussian differ"
    );
    println!(" ✓ Manual and built-in Gaussian match!\n");

    // 6. Validation and optimization must not change the computed value.
    println!("6. Expression Validation and Optimization:");
    let mut math = MathBuilder::with_optimization()?;
    let x = math.var("x");
    let unoptimized = &x * math.constant(0.0) + &x * math.constant(1.0);
    println!(" Original: x*0 + x*1");
    match math.validate(&unoptimized) {
        Ok(()) => println!(" ✓ Expression is valid"),
        Err(e) => println!(" ✗ Validation error: {e}"),
    }
    let optimized = math.optimize(&unoptimized)?;
    let original_result = math.eval(&unoptimized, &[("x", 5.0)]);
    let optimized_result = math.eval(&optimized, &[("x", 5.0)]);
    println!(" Original result at x=5: {original_result}");
    println!(" Optimized result at x=5: {optimized_result}");
    // The optimizer may rewrite or fold operations, so compare with a tolerance.
    assert!(
        approx_eq(original_result, optimized_result, EPSILON),
        "optimization changed the result: {original_result} vs {optimized_result}"
    );
    println!(" ✓ Optimization preserves correctness!");

    println!("\n=== Key Benefits of MathBuilder API ===");
    println!("✓ Natural mathematical syntax with operator overloading");
    println!("✓ Automatic variable management");
    println!("✓ Built-in mathematical functions and constants");
    println!("✓ Expression validation and optimization");
    println!("✓ Named variable evaluation");
    println!("✓ Type safety and helpful error messages");
    println!("\n=== Example Complete ===");
    Ok(())
}