// mathexpr 0.1.1
//
// A fast, safe mathematical expression parser and evaluator with bytecode compilation.
// Documentation
//! Tests for mathexpr.

use crate::{eval, Expression};

/// Parses `input`, compiles it with no variables, and evaluates it.
///
/// # Panics
///
/// Panics if parsing, compilation, or evaluation fails. The panic message names
/// the failing stage and quotes `input`, so a failing table-driven test points
/// straight at the offending expression instead of a generic `unwrap` message.
fn eval_expr(input: &str) -> f64 {
    Expression::parse(input)
        .unwrap_or_else(|e| panic!("failed to parse {input:?}: {e:?}"))
        .compile_no_vars()
        .unwrap_or_else(|e| panic!("failed to compile {input:?}: {e:?}"))
        .eval(&[])
        .unwrap_or_else(|e| panic!("failed to evaluate {input:?}: {e:?}"))
}

/// Parses `input`, compiles it against `var_names`, and evaluates it with `vars`,
/// where `vars[i]` is the value bound to `var_names[i]`.
///
/// # Panics
///
/// Panics if parsing, compilation, or evaluation fails. The message names the
/// failing stage and includes the expression (and variable names where relevant)
/// to make test failures self-explanatory.
fn eval_with_vars(input: &str, var_names: &[&str], vars: &[f64]) -> f64 {
    Expression::parse(input)
        .unwrap_or_else(|e| panic!("failed to parse {input:?}: {e:?}"))
        .compile(var_names)
        .unwrap_or_else(|e| {
            panic!("failed to compile {input:?} with variables {var_names:?}: {e:?}")
        })
        .eval(vars)
        .unwrap_or_else(|e| panic!("failed to evaluate {input:?} with values {vars:?}: {e:?}"))
}

// ===== Basic parsing tests =====

/// Integer and decimal literals evaluate to their numeric value.
///
/// The decimals are exactly representable in binary, so the exact comparison
/// does not depend on the parser rounding identically to rustc. A literal like
/// `3.14` is also rejected by clippy's deny-by-default `approx_constant` lint.
#[test]
fn test_parse_number() {
    assert_eq!(eval_expr("42"), 42.0);
    assert_eq!(eval_expr("2.5"), 2.5);
    assert_eq!(eval_expr("0.125"), 0.125);
}

/// A leading minus sign negates a numeric literal.
#[test]
fn test_parse_negative_number() {
    let negated = eval_expr("-42");
    assert_eq!(negated, -42.0);
}

/// A bare identifier evaluates to the value supplied for that variable.
#[test]
fn test_parse_variable() {
    let names = ["foo"];
    let values = [10.0];
    assert_eq!(eval_with_vars("foo", &names, &values), 10.0);
}

/// The `_` placeholder takes the value passed as the first argument of
/// `eval_with_current`.
#[test]
fn test_parse_current_value() {
    let parsed = Expression::parse("_ * 2").unwrap();
    let compiled = parsed.compile_no_vars().unwrap();
    // _ = 5, so _ * 2 = 10.
    let doubled = compiled.eval_with_current(5.0, &[]).unwrap();
    assert_eq!(doubled, 10.0);
}

// ===== Operator precedence tests =====

/// Multiplication binds tighter than addition.
#[test]
fn test_operator_precedence() {
    // Must be read as 2 + (3 * 4) = 14, not (2 + 3) * 4 = 20.
    let value = eval_expr("2 + 3 * 4");
    assert_eq!(value, 14.0);
}

/// Parentheses override the default operator precedence.
#[test]
fn test_parentheses() {
    // Grouping forces the addition first: (2 + 3) * 4 = 5 * 4 = 20.
    let grouped = eval_expr("(2 + 3) * 4");
    assert_eq!(grouped, 20.0);
}

/// `^` is right-associative and binds tighter than `*`.
#[test]
fn test_power_precedence() {
    let cases = [
        // Right-associative: 2 ^ (3 ^ 2) = 2 ^ 9 = 512.
        ("2 ^ 3 ^ 2", 512.0),
        // Power before multiply: 2 * (3 ^ 2) = 2 * 9 = 18.
        ("2 * 3 ^ 2", 18.0),
    ];
    for &(src, expected) in cases.iter() {
        assert_eq!(eval_expr(src), expected, "{src}");
    }
}

/// The `%` operator yields the remainder of division.
#[test]
fn test_modulo() {
    for &(src, remainder) in &[("10 % 3", 1.0), ("17 % 5", 2.0)] {
        assert_eq!(eval_expr(src), remainder, "{src}");
    }
}

// ===== Variable tests =====

/// An expression can combine several named variables.
#[test]
fn test_eval_with_variables() {
    let sum = eval_with_vars("a + b", &["a", "b"], &[10.0, 20.0]);
    assert_eq!(sum, 30.0);
}

// ===== Error tests =====

/// Dividing by zero compiles fine but is reported as an `Err` from `eval`.
#[test]
fn test_division_by_zero() {
    let compiled = Expression::parse("1 / 0").unwrap().compile_no_vars().unwrap();
    let outcome = compiled.eval(&[]);
    assert!(outcome.is_err());
}

/// Compiling an expression that references an undeclared variable fails.
#[test]
fn test_unknown_variable() {
    let parsed = Expression::parse("unknown").unwrap();
    // No variable names are declared, so `unknown` cannot be resolved.
    let compiled = parsed.compile(&[]);
    assert!(compiled.is_err());
}

// ===== Tier 1: Core function tests =====

/// `abs` strips the sign from negative inputs and leaves positive ones alone.
#[test]
fn test_abs() {
    let cases: [(&str, f64); 2] = [("abs(-5)", 5.0), ("abs(5)", 5.0)];
    for (src, want) in cases.iter().copied() {
        assert_eq!(eval_expr(src), want, "{src}");
    }
}

/// `sqrt` of a perfect square is exact.
#[test]
fn test_sqrt() {
    let root = eval_expr("sqrt(16)");
    assert_eq!(root, 4.0);
}

/// `log` computes the natural logarithm; `ln` is accepted as well.
#[test]
fn test_log() {
    use core::f64::consts::E;
    // Feed the full-precision value of e instead of a truncated literal such as
    // 2.718281828. The truncated value is what forced the old 1e-4 tolerance,
    // which was loose enough to hide a real precision bug in `log`.
    let ln_e = eval_expr(&format!("log({})", E));
    assert!((ln_e - 1.0).abs() < 1e-12, "log(e) = {}", ln_e);
    assert_eq!(eval_expr("ln(1)"), 0.0);
}

/// `log10` of an exact power of ten yields the exponent.
#[test]
fn test_log10() {
    let exponent = eval_expr("log10(100)");
    assert_eq!(exponent, 2.0);
}

/// `exp(0)` is exactly one.
#[test]
fn test_exp() {
    let one = eval_expr("exp(0)");
    assert_eq!(one, 1.0);
}

/// `min` and `max` pick the smaller and the larger of two arguments.
#[test]
fn test_min_max() {
    let cases = [("min(3, 7)", 3.0), ("max(3, 7)", 7.0)];
    for &(src, expected) in &cases {
        assert_eq!(eval_expr(src), expected, "{src}");
    }
}

/// The two-argument `pow` function raises its first argument to its second.
#[test]
fn test_pow_function() {
    let cubed = eval_expr("pow(2, 3)");
    assert_eq!(cubed, 8.0);
}

/// The `mod` function form agrees with the `%` operator on positive operands.
#[test]
fn test_mod_function() {
    let remainder = eval_expr("mod(10, 3)");
    assert_eq!(remainder, 1.0);
}

/// Function calls can be nested as arguments to other calls.
#[test]
fn test_nested_functions() {
    // abs(-16) = 16, then sqrt(16) = 4.
    let nested = eval_expr("sqrt(abs(-16))");
    assert_eq!(nested, 4.0);
}

// ===== Tier 2: Trigonometric function tests =====

/// Basic trigonometric functions at zero, within a small tolerance.
#[test]
fn test_sin_cos_tan() {
    const TOL: f64 = 1e-10;
    for &(src, want) in &[("sin(0)", 0.0), ("cos(0)", 1.0), ("tan(0)", 0.0)] {
        assert!((eval_expr(src) - want).abs() < TOL, "{src}");
    }
}

/// Inverse trigonometric functions at the inputs that map to zero.
#[test]
fn test_asin_acos_atan() {
    let near_zero = |value: f64| value.abs() < 1e-10;
    assert!(near_zero(eval_expr("asin(0)")));
    assert!(near_zero(eval_expr("acos(1)")));
    assert!(near_zero(eval_expr("atan(0)")));
}

/// `asin`/`acos` outside [-1, 1] evaluate to NaN rather than returning an error.
#[test]
fn test_asin_acos_domain_nan() {
    let out_of_range = ["asin(2)", "acos(1.5)"];
    assert!(out_of_range.iter().all(|src| eval_expr(src).is_nan()));
}

/// Hyperbolic functions at zero, within a small tolerance.
#[test]
fn test_sinh_cosh_tanh() {
    let close_to = |actual: f64, target: f64| (actual - target).abs() < 1e-10;
    assert!(close_to(eval_expr("sinh(0)"), 0.0));
    assert!(close_to(eval_expr("cosh(0)"), 1.0));
    assert!(close_to(eval_expr("tanh(0)"), 0.0));
}

// ===== Tier 3: Rounding & utility function tests =====

/// The rounding family: floor, ceil, round (3.5 rounds to 4), and trunc.
#[test]
fn test_floor_ceil_round_trunc() {
    let cases = [
        ("floor(3.7)", 3.0),
        ("ceil(3.2)", 4.0),
        ("round(3.5)", 4.0),
        ("trunc(3.7)", 3.0),
    ];
    for &(src, expected) in cases.iter() {
        assert_eq!(eval_expr(src), expected, "{src}");
    }
}

/// `signum` returns +1 for positive and -1 for negative inputs.
#[test]
fn test_signum() {
    for &(src, sign) in &[("signum(5)", 1.0), ("signum(-5)", -1.0)] {
        assert_eq!(eval_expr(src), sign, "{src}");
    }
}

/// `cbrt` handles both positive and negative perfect cubes exactly.
#[test]
fn test_cbrt() {
    let cases: [(&str, f64); 2] = [("cbrt(8)", 2.0), ("cbrt(-8)", -2.0)];
    for (src, root) in cases.iter().copied() {
        assert_eq!(eval_expr(src), root, "{src}");
    }
}

/// `log2` of an exact power of two yields the exponent; `log2(0)` is NaN.
#[test]
fn test_log2() {
    let exponent = eval_expr("log2(8)");
    assert_eq!(exponent, 3.0);

    let at_zero = eval_expr("log2(0)");
    assert!(at_zero.is_nan());
}

/// `clamp(x, lo, hi)` keeps in-range values and pins others to the nearest bound.
#[test]
fn test_clamp() {
    let cases = [
        ("clamp(5, 0, 10)", 5.0),   // inside the range
        ("clamp(-5, 0, 10)", 0.0),  // below: pinned to lo
        ("clamp(15, 0, 10)", 10.0), // above: pinned to hi
    ];
    for &(src, expected) in &cases {
        assert_eq!(eval_expr(src), expected, "{src}");
    }
}

// ===== Constants tests =====

/// `pi` works both as a zero-argument call and as a bare identifier.
#[test]
fn test_pi_constant() {
    use core::f64::consts::PI;
    for spelling in ["pi()", "pi"].iter() {
        assert!((eval_expr(spelling) - PI).abs() < 1e-15, "{spelling}");
    }
}

/// `e` works both as a zero-argument call and as a bare identifier.
#[test]
fn test_e_constant() {
    use core::f64::consts::E;
    let matches_e = |src: &str| (eval_expr(src) - E).abs() < 1e-15;
    assert!(matches_e("e()"));
    assert!(matches_e("e"));
}

/// Out-of-domain arguments to `sqrt` and `log` produce NaN instead of an error.
#[test]
fn test_domain_errors_return_nan() {
    let out_of_domain = ["sqrt(-1)", "log(0)", "log(-1)"];
    for src in out_of_domain.iter().copied() {
        assert!(eval_expr(src).is_nan(), "{src}");
    }
}

// ===== Builder API tests =====

/// The parse -> compile -> eval chain works end to end on a two-variable expression.
#[test]
fn test_builder_api() {
    // Hypotenuse of the 3-4-5 right triangle.
    let parsed = Expression::parse("sqrt(x^2 + y^2)").unwrap();
    let compiled = parsed.compile(&["x", "y"]).unwrap();
    let hypotenuse = compiled.eval(&[3.0, 4.0]).unwrap();
    assert_eq!(hypotenuse, 5.0);
}

/// The free function `eval` takes source, variable names and values directly.
#[test]
fn test_eval_convenience() {
    let names = ["a", "b"];
    let outcome = eval("a + b", &names, &[1.0, 2.0]);
    assert_eq!(outcome.unwrap(), 3.0);
}

/// Parentheses, subtraction and division combine correctly across four variables.
#[test]
fn test_complex_expression() {
    let names = ["a", "b", "c", "d"];
    let values = [1.0, 2.0, 10.0, 4.0];
    // (1 + 2) * (10 - 4) / 2 = 3 * 6 / 2 = 9
    let result = eval_with_vars("(a + b) * (c - d) / 2", &names, &values);
    assert_eq!(result, 9.0);
}

/// One compiled expression can be evaluated repeatedly with different inputs.
#[test]
fn test_reusable_expression() {
    let square = Expression::parse("x^2").unwrap().compile(&["x"]).unwrap();
    for &(x, expected) in &[(2.0, 4.0), (3.0, 9.0), (4.0, 16.0)] {
        assert_eq!(square.eval(&[x]).unwrap(), expected);
    }
}

/// An expression that mentions `_` reports so through `uses_current_value`, and
/// `_` and named variables can be mixed in one expression.
#[test]
fn test_uses_current_value() {
    let compiled = Expression::parse("_ * 2 + x").unwrap().compile(&["x"]).unwrap();
    assert!(compiled.uses_current_value());

    // _ = 5, x = 10: 5 * 2 + 10 = 20.
    let value = compiled.eval_with_current(5.0, &[10.0]).unwrap();
    assert_eq!(value, 20.0);
}