use mathcompile::final_tagless::{ASTRepr, DirectEval};
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🧮 Generic Operator Overloading Demo");
println!("=====================================\n");
println!("📊 Working with f64 expressions:");
demo_f64_expressions()?;
println!("\n📊 Working with f32 expressions:");
demo_f32_expressions()?;
println!("\n🔄 Comparing f64 vs f32 precision:");
compare_precision()?;
println!("\n✅ All demonstrations completed successfully!");
Ok(())
}
/// Builds three `f64` expressions with overloaded operators and evaluates them.
///
/// Variables 0, 1 and 2 stand for x, y and z. The expressions are
/// `2x + 3y`, `x² + y²` and `sin(πx) + cos(y) - z²`. Each is evaluated with
/// `DirectEval` at x = 1.0, y = 2.0, z = 3.0, and the results are printed.
///
/// # Errors
///
/// The current body never produces an `Err`; the `Result` return type keeps
/// the signature uniform with the other demos driven by `main`.
fn demo_f64_expressions() -> Result<(), Box<dyn std::error::Error>> {
    let x = ASTRepr::<f64>::Variable(0);
    let y = ASTRepr::<f64>::Variable(1);
    let z = ASTRepr::<f64>::Variable(2);
    let two = ASTRepr::<f64>::Constant(2.0);
    let three = ASTRepr::<f64>::Constant(3.0);
    let pi = ASTRepr::<f64>::Constant(std::f64::consts::PI);

    println!(" Building expressions with natural syntax:");
    let linear = &two * &x + &three * &y;
    println!(" Linear: 2x + 3y");
    let quadratic = &x * &x + &y * &y;
    println!(" Quadratic: x² + y²");
    // Scale x by π before taking the sine so the tree matches the printed
    // label sin(πx); `pi.sin()` alone would be a constant independent of x.
    let complex = (&pi * &x).sin() + y.cos() - &z * &z;
    println!(" Complex: sin(πx) + cos(y) - z²");

    println!("\n Evaluating with x=1.0, y=2.0, z=3.0:");
    // Index i of this slice supplies the value of Variable(i).
    let vars = [1.0_f64, 2.0, 3.0];
    let linear_result = DirectEval::eval_with_vars(&linear, &vars);
    let quadratic_result = DirectEval::eval_with_vars(&quadratic, &vars);
    let complex_result = DirectEval::eval_with_vars(&complex, &vars);
    println!(" Linear result: {linear_result:.6}");
    println!(" Quadratic result: {quadratic_result:.6}");
    println!(" Complex result: {complex_result:.6}");
    Ok(())
}
/// Builds three `f32` expressions with overloaded operators and evaluates them.
///
/// Variables 0 and 1 stand for x and y. The expressions are `2x + 0.5`,
/// `(x + y)(x - y)` and `exp(x) + ln(y)`. Each is evaluated with `DirectEval`
/// at x = 2.0, y = 3.0, and the results are printed.
///
/// # Errors
///
/// The current body never produces an `Err`; the `Result` return type keeps
/// the signature uniform with the other demos driven by `main`.
fn demo_f32_expressions() -> Result<(), Box<dyn std::error::Error>> {
    let x = ASTRepr::<f32>::Variable(0);
    let y = ASTRepr::<f32>::Variable(1);
    let two = ASTRepr::<f32>::Constant(2.0_f32);
    let half = ASTRepr::<f32>::Constant(0.5_f32);

    println!(" Building f32 expressions:");
    // `half` is not needed afterwards, so it is moved into the sum
    // (owned + owned, the same operator form used elsewhere in this file).
    let expr1 = &two * &x + half;
    println!(" Expression 1: 2x + 0.5");
    let expr2 = (&x + &y) * (&x - &y);
    println!(" Expression 2: (x + y)(x - y)");
    let expr3 = x.exp() + y.ln();
    println!(" Expression 3: exp(x) + ln(y)");

    println!("\n Evaluating with x=2.0, y=3.0:");
    // Index i of this slice supplies the value of Variable(i).
    let vars = [2.0_f32, 3.0_f32];
    let result1 = DirectEval::eval_with_vars(&expr1, &vars);
    let result2 = DirectEval::eval_with_vars(&expr2, &vars);
    let result3 = DirectEval::eval_with_vars(&expr3, &vars);
    println!(" Expression 1 result: {result1:.6}");
    println!(" Expression 2 result: {result2:.6}");
    println!(" Expression 3 result: {result3:.6}");
    Ok(())
}
/// Evaluates x^10 in both `f64` and `f32` and reports each result's error.
///
/// The power is built as an explicit product of ten copies of `Variable(0)`
/// rather than a single power node. The input 1.1 is supplied in each
/// precision, and both results are compared against `f64::powi(1.1, 10)`.
/// The f32 error therefore also includes the error from representing 1.1 as
/// an `f32`.
///
/// # Errors
///
/// The current body never produces an `Err`; the `Result` return type keeps
/// the signature uniform with the other demos driven by `main`.
fn compare_precision() -> Result<(), Box<dyn std::error::Error>> {
    let base_f64 = ASTRepr::<f64>::Variable(0);
    let base_f32 = ASTRepr::<f32>::Variable(0);

    // Start from x and multiply by x nine more times, giving x^10.
    let power_f64 = (1..10).fold(base_f64.clone(), |acc, _| &acc * &base_f64);
    let power_f32 = (1..10).fold(base_f32.clone(), |acc, _| &acc * &base_f32);

    println!(" Comparing precision for x^10:");
    let input_f64 = 1.1_f64;
    let input_f32 = 1.1_f32;
    let got_f64 = DirectEval::eval_with_vars(&power_f64, &[input_f64]);
    // Widen losslessly once so the f32 result can be printed and compared
    // alongside the f64 values.
    let got_f32_widened = f64::from(DirectEval::eval_with_vars(&power_f32, &[input_f32]));
    let reference = input_f64.powi(10);

    println!(" Input value: 1.1");
    println!(" f64 result: {got_f64:.15}");
    println!(" f32 result: {got_f32_widened:.15}");
    println!(" Reference: {reference:.15}");
    println!(" f64 error: {:.2e}", (got_f64 - reference).abs());
    println!(" f32 error: {:.2e}", (got_f32_widened - reference).abs());
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each overloaded operator on `&ASTRepr<f64>` yields the matching AST node.
    #[test]
    fn test_generic_operators_f64() {
        let x = ASTRepr::<f64>::Variable(0);
        let y = ASTRepr::<f64>::Variable(1);
        let c = ASTRepr::<f64>::Constant(5.0);
        let add_expr = &x + &y;
        let sub_expr = &x - &y;
        let mul_expr = &x * &c;
        let div_expr = &x / &c;
        let neg_expr = -&x;
        assert!(matches!(add_expr, ASTRepr::Add(_, _)));
        assert!(matches!(sub_expr, ASTRepr::Sub(_, _)));
        assert!(matches!(mul_expr, ASTRepr::Mul(_, _)));
        assert!(matches!(div_expr, ASTRepr::Div(_, _)));
        assert!(matches!(neg_expr, ASTRepr::Neg(_)));
    }

    /// Each overloaded operator on `&ASTRepr<f32>` yields the matching AST node.
    #[test]
    fn test_generic_operators_f32() {
        let x = ASTRepr::<f32>::Variable(0);
        let y = ASTRepr::<f32>::Variable(1);
        let c = ASTRepr::<f32>::Constant(5.0_f32);
        let add_expr = &x + &y;
        let sub_expr = &x - &y;
        let mul_expr = &x * &c;
        let div_expr = &x / &c;
        let neg_expr = -&x;
        assert!(matches!(add_expr, ASTRepr::Add(_, _)));
        assert!(matches!(sub_expr, ASTRepr::Sub(_, _)));
        assert!(matches!(mul_expr, ASTRepr::Mul(_, _)));
        assert!(matches!(div_expr, ASTRepr::Div(_, _)));
        assert!(matches!(neg_expr, ASTRepr::Neg(_)));
    }

    /// The transcendental helper methods yield the matching AST node for both precisions.
    #[test]
    fn test_transcendental_functions() {
        let x_f64 = ASTRepr::<f64>::Variable(0);
        let x_f32 = ASTRepr::<f32>::Variable(0);
        let sin_f64 = x_f64.sin();
        let cos_f32 = x_f32.cos();
        let exp_f64 = x_f64.exp();
        let ln_f32 = x_f32.ln();
        let sqrt_f64 = x_f64.sqrt();
        assert!(matches!(sin_f64, ASTRepr::Sin(_)));
        assert!(matches!(cos_f32, ASTRepr::Cos(_)));
        assert!(matches!(exp_f64, ASTRepr::Exp(_)));
        assert!(matches!(ln_f32, ASTRepr::Ln(_)));
        assert!(matches!(sqrt_f64, ASTRepr::Sqrt(_)));
    }

    /// A mixed arithmetic/transcendental tree evaluates to its analytic value.
    #[test]
    fn test_mixed_operations() {
        let x = ASTRepr::<f64>::Variable(0);
        let y = ASTRepr::<f64>::Variable(1);
        let expr = (&x + &y).sin() * (&x - &y).cos() + (&x * &y).sqrt();
        let result = DirectEval::eval_with_vars(&expr, &[2.0, 3.0]);
        // At x = 2, y = 3 the expression is sin(5)·cos(-1) + sqrt(6).
        let expected = 5.0_f64.sin() * (-1.0_f64).cos() + 6.0_f64.sqrt();
        assert!(result.is_finite());
        assert!(
            (result - expected).abs() < 1e-12,
            "got {result}, expected {expected}"
        );
    }

    /// Every demonstration section completes without returning an error.
    #[test]
    fn test_demos_run() {
        assert!(demo_f64_expressions().is_ok());
        assert!(demo_f32_expressions().is_ok());
        assert!(compare_precision().is_ok());
    }
}