use expr_solver::{Compiled, Linked, Number, Program, SymTable, eval, eval_with_table, num};
use indoc::indoc;
/// Evaluates `expr` with the standard symbol table and returns its value.
///
/// Panics if evaluation fails. `#[track_caller]` makes the panic point at the
/// calling test line (which shows the offending expression) rather than here.
#[track_caller]
fn eval_ok(expr: &str) -> Number {
    eval(expr).expect("Evaluation should be successful")
}
/// Evaluates `expr`, which is expected to fail, and returns the error message.
///
/// Colour output from the `colored` crate is forced off first so the message
/// can be compared as plain text. Panics if evaluation unexpectedly succeeds;
/// `#[track_caller]` makes that panic report the calling test line.
#[track_caller]
fn eval_err(expr: &str) -> String {
    colored::control::set_override(false);
    eval(expr).expect_err("Evaluation should fail")
}
/// Evaluates `expr` against the caller-supplied `table` and returns its value.
///
/// Panics if evaluation fails; `#[track_caller]` makes the panic report the
/// calling test line instead of this helper.
#[track_caller]
fn eval_with_custom_table_ok(expr: &str, table: SymTable) -> Number {
    eval_with_table(expr, table).expect("Evaluation should be successful")
}
#[cfg(feature = "f64-floats")]
/// Returns `true` when `a` and `b` differ by strictly less than `epsilon`
/// (f64 backend).
fn approx_eq(a: Number, b: Number, epsilon: f64) -> bool {
    let distance = (a - b).abs();
    distance < epsilon
}
#[cfg(feature = "decimal-precision")]
/// Returns `true` when `a` and `b` differ by strictly less than `epsilon`
/// (decimal backend, where the tolerance is itself a `Number`).
fn approx_eq(a: Number, b: Number, epsilon: Number) -> bool {
    let distance = (a - b).abs();
    distance < epsilon
}
#[test]
fn test_arithmetic_and_precedence() {
    // Basic binary operators.
    assert_eq!(eval_ok("1 + 2"), num!(3));
    assert_eq!(eval_ok("10 - 5"), num!(5));
    assert_eq!(eval_ok("2 * 3"), num!(6));
    assert_eq!(eval_ok("10 / 2"), num!(5));
    assert_eq!(eval_ok("2 ^ 3"), num!(8));

    // Unary minus, including directly after a binary operator.
    assert_eq!(eval_ok("-5"), num!(-5));
    assert_eq!(eval_ok("1 + -5"), num!(-4));

    // Postfix factorial.
    assert_eq!(eval_ok("5!"), num!(120));
    assert_eq!(eval_ok("0!"), num!(1));

    // Precedence and grouping.
    assert_eq!(eval_ok("1 + 2 * 3"), num!(7));
    assert_eq!(eval_ok("(1 + 2) * 3"), num!(9));
    // `^` is right-associative: 2 ^ (3 ^ 2).
    assert_eq!(eval_ok("2 ^ 3 ^ 2"), num!(512));
    // Unary minus binds looser than `^`: -(2 ^ 2).
    assert_eq!(eval_ok("-2 ^ 2"), num!(-4));
    assert_eq!(eval_ok("3 + 4 * 2 / (1 - 5) ^ 2"), num!(3.5));
}
#[test]
fn test_comparisons() {
    // Comparison operators yield 1 for true and 0 for false, and bind looser
    // than arithmetic (last case).
    let cases = [
        ("1 > 0", num!(1)),
        ("1 < 0", num!(0)),
        ("1 == 1", num!(1)),
        ("1 != 1", num!(0)),
        ("1 >= 1", num!(1)),
        ("1 <= 1", num!(1)),
        ("1 + 1 == 2", num!(1)),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_constants() {
    // Built-in constants are checked through their rendered text so the test
    // does not depend on the numeric backend's exact precision.
    let expected_prefixes = [("pi", "3.14"), ("e", "2.71"), ("tau", "6.28"), ("ln2", "0.69")];
    for (name, prefix) in expected_prefixes {
        let rendered = eval_ok(name).to_string();
        assert!(rendered.starts_with(prefix), "{name} rendered as {rendered}");
    }
}
#[test]
fn test_functions() {
    // Functions whose results are compared exactly.
    let exact = [
        ("sqrt(16)", num!(4)),
        ("abs(-5)", num!(5)),
        ("pow(2, 3)", num!(8)),
        ("round(3.5)", num!(4)),
        ("floor(3.9)", num!(3)),
        ("ceil(3.1)", num!(4)),
        ("max(1, 10, 3, -5)", num!(10)),
        ("min(1, 10, 3, -5)", num!(-5)),
        ("sum(1, 2, 3, 4)", num!(10)),
        ("avg(1, 2, 3, 4)", num!(2.5)),
        ("cos(0)", num!(1)),
        ("clamp(5, 1, 10)", num!(5)),
        ("clamp(0, 1, 10)", num!(1)),
        ("clamp(12, 1, 10)", num!(10)),
    ];
    for (expr, expected) in exact {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }

    // Functions whose results are only checked to within a small tolerance.
    let approximate = [
        ("sin(pi)", num!(0)),
        ("tan(0)", num!(0)),
        ("log10(100)", num!(2)),
    ];
    for (expr, expected) in approximate {
        assert!(
            approx_eq(eval_ok(expr), expected, num!(0.0001)),
            "expression: {expr}"
        );
    }
}
#[test]
fn test_decimal_native_functions() {
    // (expression, expected value, tolerance)
    let approximate = [
        ("log2(1024)", num!(10), num!(0.00001)),
        ("exp2(10)", num!(1024), num!(0.001)),
        ("sinh(1)", num!(1.175201193), num!(0.0001)),
        ("cosh(1)", num!(1.543080634), num!(0.0001)),
        ("tanh(1)", num!(0.761594156), num!(0.0001)),
        ("cbrt(10)", num!(2.154434690), num!(0.0001)),
    ];
    for (expr, expected, tolerance) in approximate {
        assert!(
            approx_eq(eval_ok(expr), expected, tolerance),
            "expression: {expr}"
        );
    }

    // For a large input, tanh only needs to be close to 1.
    assert!(eval_ok("tanh(10)") > num!(0.99));

    // Perfect cubes and Pythagorean triples must come out exact.
    let exact = [
        ("cbrt(27)", num!(3)),
        ("cbrt(-8)", num!(-2)),
        ("hypot(3, 4)", num!(5)),
        ("hypot(5, 12)", num!(13)),
    ];
    for (expr, expected) in exact {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_complex_expressions() {
    // 1 + (-1): compared approximately because `pi` is a rounded constant.
    let trig = eval_ok("sin(pi / 2) + cos(pi)");
    assert!(approx_eq(trig, num!(0), num!(0.0001)));

    let exact = [
        ("max(sqrt(25), pow(2, 4), 10)", num!(16)),
        ("sum(1, 2, 3, max(4, 5))", num!(11)),
        ("floor(abs(-3.7)) + ceil(2.1)", num!(6)),
        // Factorial binds tighter than unary minus: -(1!) and -((3!) ^ 2).
        ("-1!", num!(-1)),
        ("-3!^2", num!(-36)),
    ];
    for (expr, expected) in exact {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_custom_symbols() {
    // Extend the standard library with one constant and one unary function.
    let mut symbols = SymTable::stdlib();
    symbols.add_const("my_const", num!(123), false).unwrap();
    symbols
        .add_func("add_one", 1, false, |args| Ok(args[0] + num!(1)), false)
        .unwrap();

    let from_const = eval_with_custom_table_ok("my_const + 10", symbols.clone());
    assert_eq!(from_const, num!(133));

    // The last use can take the table by value.
    let from_func = eval_with_custom_table_ok("add_one(my_const)", symbols);
    assert_eq!(from_func, num!(124));
}
#[test]
fn test_emoji_identifiers() {
    // Identifiers may contain non-ASCII characters such as emoji.
    let mut symbols = SymTable::stdlib();
    symbols.add_const("x😀", num!(10), false).unwrap();
    symbols
        .add_func("add🚀", 2, false, |args| Ok(args[0] + args[1]), false)
        .unwrap();

    let shifted = eval_with_custom_table_ok("x😀 + 5", symbols.clone());
    assert_eq!(shifted, num!(15));

    let summed = eval_with_custom_table_ok("add🚀(x😀, 2)", symbols);
    assert_eq!(summed, num!(12));
}
// Parser diagnostics. Each expected message is a summary line, then the
// offending source line prefixed with `1 |`, then a line with a `^` marker.
// `eval_err` forces colour output off, so the text is compared verbatim.
// The raw strings are whitespace-sensitive (indoc only strips their common
// leading indentation), so keep their spacing exactly as written.
#[test]
#[rustfmt::skip]
fn test_syntax_errors() {
// An operator where an operand is required.
assert_eq!(eval_err("1 + * 2"), indoc! {r#"
Unexpected token: unexpected token '*', expected an expression
1 | 1 + * 2
| ^"#
});
// Unclosed parenthesis: end of input is reached while expecting ')'.
assert_eq!(eval_err("(1 + 2"), indoc! {r#"
Unexpected token: unexpected token 'EOF', expected ')'
1 | (1 + 2
| ^"#
});
// Two operands with no operator between them.
assert_eq!(eval_err("1 2"), indoc! {r#"
Unexpected token: unexpected token '2', expected 'EOF'
1 | 1 2
| ^"#
});
// Empty parentheses are not an expression.
assert_eq!(eval_err("()"), indoc! {r#"
Unexpected token: unexpected token ')', expected an expression
1 | ()
| ^"#
});
// Input ends inside a function call's argument list.
assert_eq!(eval_err("sin("), indoc! {r#"
Unexpected token: unexpected token 'EOF', expected an expression
1 | sin(
| ^"#
});
// Input ends after a binary operator.
assert_eq!(eval_err("1 + "), indoc! {r#"
Unexpected token: unexpected token 'EOF', expected an expression
1 | 1 +
| ^"#
});
}
#[test]
#[rustfmt::skip]
fn test_semantic_errors() {
    // Unknown names, wrong arities, and constant/function kind mismatches.
    let cases = [
        ("foo()", "Symbol 'foo' not found"),
        ("bar", "Symbol 'bar' not found"),
        ("sin(1, 2)", "Link error: Type mismatch for symbol 'sin': expected exactly 1 arguments, found 2 arguments provided"),
        ("max()", "Link error: Type mismatch for symbol 'max': expected at least 1 arguments, found 0 arguments provided"),
        ("pi()", "Link error: Type mismatch for symbol 'pi': expected function, found constant"),
        ("1 + sin", "Link error: Type mismatch for symbol 'sin': expected constant, found function"),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_err(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_runtime_errors() {
    // Errors reported regardless of the numeric backend.
    assert_eq!(eval_err("1 / 0"), "Division by zero");
    let factorial_err = eval_err("1.5!");
    assert_eq!(
        factorial_err,
        "Invalid factorial: 1.5 (must be a non-negative integer)"
    );

    // The decimal backend rejects out-of-domain inputs with an error...
    #[cfg(feature = "decimal-precision")]
    {
        let domain_cases = [
            (
                "log(-1)",
                "Function error: Domain error in function 'log': invalid input -1",
            ),
            (
                "sqrt(-4)",
                "Function error: Square root of negative number: -4",
            ),
        ];
        for (expr, expected) in domain_cases {
            assert_eq!(eval_err(expr), expected, "expression: {expr}");
        }
    }

    // ...whereas the f64 backend evaluates them to NaN.
    #[cfg(feature = "f64-floats")]
    {
        for expr in ["log(-1)", "sqrt(-4)"] {
            assert!(eval_ok(expr).is_nan(), "expected NaN from {expr}");
        }
    }
}
#[test]
fn test_if_expressions() {
    // `if(cond, then, else)`: any non-zero condition selects `then`.
    let cases = [
        ("if(1, 10, 20)", num!(10)),
        ("if(0, 10, 20)", num!(20)),
        ("if(0.5, 10, 20)", num!(10)),
        ("if(5 > 3, 100, 200)", num!(100)),
        ("if(5 == 5, 1, 0)", num!(1)),
        ("if(5 != 3, 1, 0)", num!(1)),
        ("if(5 - 5, 1, 0)", num!(0)),
        ("if(1, 2 + 3, 4 * 5)", num!(5)),
        ("if(0, 2 + 3, 4 * 5) + 10", num!(30)),
        ("if(abs(-5), 1, 0)", num!(1)),
        ("if(1, abs(-10), abs(-20))", num!(10)),
        ("if(max(1, 2) > 0, 42, 0)", num!(42)),
        // The keyword is case-insensitive.
        ("IF(1, 10, 20)", num!(10)),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_if_nested() {
    // `if` nests in the condition and in either branch.
    let cases = [
        ("if(1, if(1, 10, 20), 30)", num!(10)),
        ("if(0, 10, if(1, 20, 30))", num!(20)),
        ("if(if(1, 1, 0), 100, 200)", num!(100)),
        ("if(1, if(1, if(1, 1, 2), 3), 4)", num!(1)),
        ("if(1, if(1, if(0, 1, 2), 3), 4)", num!(2)),
        ("if(0, if(1, if(1, 1, 2), 3), 4)", num!(4)),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_if_short_circuit() {
    // The branch not taken contains `1/0`, which would otherwise raise a
    // division-by-zero error; these evaluations must still succeed.
    for expr in ["if(1, 42, 1/0)", "if(0, 1/0, 42)"] {
        assert_eq!(eval_ok(expr), num!(42), "expression: {expr}");
    }
}
#[test]
fn test_if_error_cases() {
    // Only two arguments: either of these parser diagnostics is accepted.
    let err = eval_err("if(1, 2)");
    assert!(
        err.contains("expected ')'") || err.contains("expected ','"),
        "unexpected error for `if(1, 2)`: {err}"
    );

    // `if` without call parentheses is rejected.
    let err = eval_err("if 1, 2, 3");
    assert!(
        err.contains("expected '('"),
        "unexpected error for `if 1, 2, 3`: {err}"
    );
}
/// Compiles `expr` and links it against `table`.
///
/// A failure at either stage is flattened into that error's display string.
fn load_with_table(
    expr: &'static str,
    table: SymTable,
) -> Result<Program<'static, Linked>, String> {
    Program::new_from_source(expr)
        .map_err(|compile_err| compile_err.to_string())
        .and_then(|compiled| compiled.link(table).map_err(|link_err| link_err.to_string()))
}
/// Compiles `expr` without linking it.
///
/// A compile error is returned as its display string.
fn load(expr: &'static str) -> Result<Program<'static, Compiled>, String> {
    let compiled = Program::new_from_source(expr);
    compiled.map_err(|compile_err| compile_err.to_string())
}
#[test]
fn test_program_compile_link_execute() {
    // Run the full compile -> link -> execute pipeline against the stdlib.
    let cases = [("2 + 3 * 4", num!(14)), ("sqrt(16) + sin(0)", num!(4))];
    for (source, expected) in cases {
        let mut linked = load_with_table(source, SymTable::stdlib()).expect("link failed");
        let value = linked.execute().expect("execution failed");
        assert_eq!(value, expected);
    }
}
#[test]
fn test_program_symtable_mutation() {
    let compiled = load("x + y").expect("compilation failed");

    // A minimal table that defines exactly the two referenced constants.
    let mut symbols = SymTable::new();
    for (name, value) in [("x", num!(10)), ("y", num!(20))] {
        symbols.add_const(name, value, false).unwrap();
    }

    let mut linked = compiled.link(symbols).expect("link failed");
    assert_eq!(linked.execute().expect("execution failed"), num!(30));

    // Adding an unreferenced constant after linking must leave the result of
    // a re-execution unchanged.
    linked
        .symtable_mut()
        .add_const("z", num!(100), false)
        .unwrap();
    assert_eq!(linked.execute().expect("execution failed"), num!(30));
}
#[test]
#[cfg(feature = "serialization")]
fn test_program_serialization() {
    // Execute, serialize to bytecode, reload and relink against the same
    // stdlib table, then check the reloaded program computes the same value.
    let mut program = load_with_table("sqrt(pi) + 2", SymTable::stdlib()).expect("link failed");
    let result1 = program.execute().expect("execution failed");
    let bytes = program.to_bytecode().expect("serialization failed");
    let mut program2 = Program::new_from_bytecode(&bytes)
        .expect("deserialization failed")
        .link(SymTable::stdlib())
        .expect("link failed");
    let result2 = program2.execute().expect("execution failed");
    assert_eq!(result1, result2);
}
#[test]
fn test_program_assembly() {
    // The disassembly of `2 + 3` should mention both a push and an add.
    let linked = load_with_table("2 + 3", SymTable::stdlib()).expect("link failed");
    let listing = linked.get_assembly();
    for mnemonic in ["PUSH", "ADD"] {
        assert!(listing.contains(mnemonic), "missing {mnemonic}");
    }
}
#[test]
fn test_program_link_validation() {
    let compiled = load("x + y").expect("compilation failed");
    // Neither `x` nor `y` exists in an empty table, so linking must fail.
    let err = compiled
        .link(SymTable::new())
        .expect_err("linking against an empty symbol table should fail");
    let message = err.to_string();
    assert!(
        message.contains("not found"),
        "unexpected link error: {message}"
    );
}
#[test]
fn test_let_simple() {
    // A single `let` binding, used directly and inside an expression.
    let cases = [
        ("let x = 10 then x", num!(10)),
        ("let x = 5 * 2 then x + 1", num!(11)),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_let_multiple_declarations() {
    // Comma-separated bindings in one `let`.
    let cases = [
        ("let x = 2, y = 3 then x + y", num!(5)),
        ("let x = 1, y = 2, z = 3 then x + y + z", num!(6)),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_let_reference_previous() {
    // Later bindings may refer to earlier ones in the same `let`.
    let cases = [
        ("let x = 1, y = x + 1 then y", num!(2)),
        ("let x = 2, y = x * 3, z = y + 1 then z", num!(7)),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_let_with_globals() {
    // Bindings can use stdlib constants and functions; expected values that
    // depend on `pi` are computed through the evaluator itself.
    let cases = [
        ("let x = pi then x", eval_ok("pi")),
        ("let x = sin(pi / 2) then x", num!(1)),
        ("let r = 5, area = pi * r ^ 2 then area", eval_ok("pi * 25")),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_let_complex_expressions() {
    // Initialisers may contain `if`, operators and function calls.
    let cases = [
        ("let x = if(1 < 2, 10, 20) then x", num!(10)),
        ("let x = 2, y = x ^ 3 then y * 2", num!(16)),
        ("let x = sqrt(16), y = x + 4 then y", num!(8)),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_ok(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_let_error_shadowing_global() {
    // A `let` binding may not reuse the name of a stdlib constant.
    let cases = [
        ("let pi = 3 then pi", "Duplicate symbol definition: 'pi'"),
        ("let e = 2 then e", "Duplicate symbol definition: 'e'"),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_err(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_let_error_duplicate_names() {
    // The same name may not be bound twice, adjacent or not.
    let cases = [
        ("let x = 1, x = 2 then x", "Duplicate symbol definition: 'x'"),
        ("let x = 1, y = 2, x = 3 then x + y", "Duplicate symbol definition: 'x'"),
    ];
    for (expr, expected) in cases {
        assert_eq!(eval_err(expr), expected, "expression: {expr}");
    }
}
#[test]
fn test_let_error_forward_reference() {
    // Referencing `y` before its binding is rejected.
    // NOTE(review): this is reported as a duplicate definition rather than
    // "not found" (unlike the self-reference case); confirm that this is the
    // intended diagnostic.
    let message = eval_err("let x = y, y = 1 then x");
    assert_eq!(message, "Duplicate symbol definition: 'y'");
}
#[test]
fn test_let_error_self_reference() {
    // A binding is not in scope inside its own initialiser.
    let message = eval_err("let x = x + 1 then x");
    assert_eq!(message, "Symbol 'x' not found");
}
#[test]
fn test_let_with_custom_table() {
    // `let` initialisers can read constants registered by the caller.
    let mut symbols = SymTable::stdlib();
    symbols.add_const("custom", num!(42), false).unwrap();
    assert_eq!(
        eval_with_custom_table_ok("let x = custom * 2 then x", symbols),
        num!(84)
    );
}
#[test]
fn test_let_case_insensitive_keywords() {
    // `let` and `then` are accepted in any letter case; every case yields 10.
    let sources = [
        "LET x = 10 THEN x",
        "Let x = 5 Then x * 2",
        "leT x = 3, y = 7 tHeN x + y",
    ];
    for expr in sources {
        assert_eq!(eval_ok(expr), num!(10), "expression: {expr}");
    }
}
#[test]
#[cfg(feature = "serialization")]
fn test_let_if_serialization_roundtrip() {
    // a = 10 (5 > 3); b = a + 5 = 15 (a > 15 is false); result = b * 3 = 45.
    let source =
        "let a = if(5 > 3, 10, 20), b = if(a > 15, a * 2, a + 5) then if(b < 20, b * 3, b - 10)";
    let program = Program::new_from_source(source).expect("Failed to compile");
    let table = SymTable::stdlib();
    // The table is needed again for relinking, so link a clone first.
    let mut linked = program.link(table.clone()).expect("Failed to link");
    let result1 = linked.execute().expect("Failed to execute");
    let bytecode = linked.to_bytecode().expect("Failed to serialize");
    let loaded_program = Program::new_from_bytecode(&bytecode).expect("Failed to deserialize");
    let mut relinked = loaded_program.link(table).expect("Failed to relink");
    let result2 = relinked.execute().expect("Failed to execute reloaded");
    assert_eq!(
        result1, result2,
        "Serialization roundtrip produced different result"
    );
    assert_eq!(result1, num!(45));
}