use dslcompile::prelude::*;
/// Evaluates `1 + 2x + 3x^2` at `x = 3` with the builder, then checks that the
/// generated Rust source mentions the requested function name.
///
/// When a Rust compiler is available, the generated source is also compiled and
/// loaded, and the native result must equal the interpreted one.
#[test]
fn test_symbolic_to_numeric_optimization() -> Result<()> {
    let builder = MathBuilder::new();
    let input = builder.var();
    let quadratic = builder.poly(&[1.0, 2.0, 3.0], &input);

    let interpreted = builder.eval(&quadratic, &[3.0]);
    assert_eq!(interpreted, 34.0);

    let ast = quadratic.into_ast();
    let generator = RustCodeGenerator::new();
    let source = generator.generate_function(&ast, "test_function")?;
    assert!(source.contains("test_function"));

    // Native compilation is optional; stop here on machines without a Rust compiler.
    if !RustCompiler::is_available() {
        return Ok(());
    }
    let compiler = RustCompiler::new();
    let native = compiler.compile_and_load(&source, "test_function")?;
    assert_eq!(native.call(3.0)?, interpreted);
    Ok(())
}
/// Evaluates `x^2 + 2x + 1` at `x = 3` on every available backend.
///
/// The interpreter always runs. The Rust code-generation backend runs only when
/// `RustCompiler::is_available()`. The Cranelift JIT runs only with the
/// `cranelift` feature. Every backend must produce `16.0`.
#[test]
fn test_basic_usage_example() -> Result<()> {
    let math = MathBuilder::new();
    let x = math.var();
    let expr = &x * &x + 2.0 * &x + 1.0;
    let result = math.eval(&expr, &[3.0]);
    assert_eq!(result, 16.0);

    let ast_expr = expr.into_ast();
    let codegen = RustCodeGenerator::new();
    let rust_code = codegen.generate_function(&ast_expr, "test_poly")?;
    if RustCompiler::is_available() {
        let compiler = RustCompiler::new();
        let compiled_func = compiler.compile_and_load(&rust_code, "test_poly")?;
        let compiled_result = compiled_func.call(3.0)?;
        assert_eq!(compiled_result, 16.0);
    }

    // Code inside a disabled `cfg` block must still tokenize, so this block has
    // to be valid Rust even when the `cranelift` feature is off.
    #[cfg(feature = "cranelift")]
    {
        let mut compiler = CraneliftCompiler::new_default()?;
        let registry = VariableRegistry::for_expression(&ast_expr);
        let compiled = compiler.compile_expression(&ast_expr, &registry)?;
        let fast_result = compiled.call(&[3.0]).unwrap();
        assert_eq!(fast_result, 16.0);
    }
    Ok(())
}
/// Runs symbolic differentiation on `1 + 2x + x^2` and checks that at least one
/// first derivative is produced.
///
/// The shared-subexpression count in the result stats is read but not asserted on.
#[test]
fn test_automatic_differentiation_example() -> Result<()> {
    let builder = MathBuilder::new();
    let var_x = builder.var();
    let poly = builder.poly(&[1.0, 2.0, 1.0], &var_x);
    let ast = poly.into_ast();

    let mut differentiator = SymbolicAD::new()?;
    let outcome = differentiator.compute_with_derivatives(&ast)?;

    let _shared = outcome.stats.shared_subexpressions_count;
    assert!(!outcome.first_derivatives.is_empty());
    Ok(())
}
/// Checks that `2x + 1` evaluates to `7.0` at `x = 3` on the compiled backends.
///
/// The Cranelift JIT is exercised only with the `cranelift` feature. The Rust
/// code-generation backend always generates source, and additionally compiles and
/// runs it when `RustCompiler::is_available()`.
#[test]
fn test_multiple_backends_example() -> Result<()> {
    let math = MathBuilder::new();
    let x = math.var();
    let expr = 2.0 * &x + 1.0;
    let ast_expr = expr.into_ast();

    // Code inside a disabled `cfg` block must still tokenize, so this block has
    // to be valid Rust even when the `cranelift` feature is off.
    #[cfg(feature = "cranelift")]
    {
        let mut compiler = CraneliftCompiler::new_default()?;
        let registry = VariableRegistry::for_expression(&ast_expr);
        let jit_func = compiler.compile_expression(&ast_expr, &registry)?;
        let fast_result = jit_func.call(&[3.0]).unwrap();
        assert_eq!(fast_result, 7.0);
    }

    let codegen = RustCodeGenerator::new();
    let rust_code = codegen.generate_function(&ast_expr, "test_backends")?;
    assert!(rust_code.contains("test_backends"));
    if RustCompiler::is_available() {
        let compiler = RustCompiler::new();
        let compiled_func = compiler.compile_and_load(&rust_code, "test_backends")?;
        let compiled_result = compiled_func.call(3.0)?;
        assert_eq!(compiled_result, 7.0);
    }
    Ok(())
}
/// Exercises the compile-and-load API on `3x`.
///
/// Source generation always runs. When a Rust compiler is available, the loaded
/// function must return `12.0` for `x = 4` and report the name it was loaded under.
#[test]
fn test_compile_and_load_api() -> Result<()> {
    let builder = MathBuilder::new();
    let input = builder.var();
    let tripled = 3.0 * &input;
    let ast = tripled.into_ast();

    let generator = RustCodeGenerator::new();
    let source = generator.generate_function(&ast, "test_api")?;

    // Without a Rust compiler there is nothing further to check.
    if !RustCompiler::is_available() {
        return Ok(());
    }
    let compiler = RustCompiler::new();
    let loaded = compiler.compile_and_load(&source, "test_api")?;
    assert_eq!(loaded.call(4.0)?, 12.0);
    assert_eq!(loaded.name(), "test_api");
    Ok(())
}
/// README basic usage: `x^2 + 2x + y` at `(x, y) = (3, 1)` evaluates to `16`.
#[test]
fn test_readme_basic_usage() {
    let builder = MathBuilder::new();
    let var_x = builder.var();
    let var_y = builder.var();

    let expr = &var_x * &var_x + 2.0 * &var_x + &var_y;
    let inputs = [3.0, 1.0];
    assert_eq!(builder.eval(&expr, &inputs), 16.0);
}
/// README optimisation example: adding `0.0` to a variable leaves its value unchanged.
#[test]
fn test_readme_optimization() {
    let builder = MathBuilder::new();
    let var_x = builder.var();
    let plus_zero = &var_x + 0.0;
    let value = builder.eval(&plus_zero, &[5.0]);
    assert_eq!(value, 5.0);
}
/// README compilation example: `x^2 + 2x + y` at `(x, y) = (2, 3)` evaluates to `11`.
#[test]
fn test_readme_compilation() {
    let builder = MathBuilder::new();
    let var_x = builder.var();
    let var_y = builder.var();

    let polynomial = &var_x * &var_x + 2.0 * &var_x + &var_y;
    let value = builder.eval(&polynomial, &[2.0, 3.0]);
    assert_eq!(value, 11.0);
}
/// README trigonometry example: `sin(x) + cos(y)` at the origin matches the value
/// computed directly with `f64` methods, within `1e-10`.
#[test]
fn test_readme_complex_example() {
    let builder = MathBuilder::new();
    let var_x = builder.var();
    let var_y = builder.var();
    let trig = var_x.sin() + var_y.cos();

    let expected = 0.0_f64.sin() + 0.0_f64.cos();
    let actual = builder.eval(&trig, &[0.0, 0.0]);
    assert!((actual - expected).abs() < 1e-10);
}
/// Builds and evaluates `2x + 1` at `x = 3` repeatedly (1000 iterations).
///
/// Every iteration checks its result. The former trailing `assert!(true)` was a
/// constant assertion that could never fail, and the test verified nothing.
#[test]
fn test_readme_performance() {
    let math = MathBuilder::new();
    let x = math.var();
    for _ in 0..1000 {
        let expr = &x * 2.0 + 1.0;
        let result = math.eval(&expr, &[3.0]);
        assert_eq!(result, 7.0);
    }
}
/// README variable-management example with three variables from one builder.
///
/// `x*y + z^2` at `(2, 3, 4)` gives `22`. An expression using only the first
/// variable, `2x`, is then evaluated with a single input value.
#[test]
fn test_readme_variable_management() {
    let builder = MathBuilder::new();
    let (a, b, c) = (builder.var(), builder.var(), builder.var());

    let combined = &a * &b + &c * &c;
    assert_eq!(builder.eval(&combined, &[2.0, 3.0, 4.0]), 22.0);

    // Only the first variable appears here, and one input value is supplied.
    let doubled = &a * 2.0;
    assert_eq!(builder.eval(&doubled, &[5.0]), 10.0);
}
/// README operator-precedence example: multiplication binds tighter than addition.
#[test]
fn test_readme_operator_precedence() {
    let builder = MathBuilder::new();
    let var_x = builder.var();

    // 2 * 3 + 1
    let scaled_then_shifted = 2.0 * &var_x + 1.0;
    assert_eq!(builder.eval(&scaled_then_shifted, &[3.0]), 7.0);

    // 2 + 2 * 3
    let shifted_by_product = 2.0 + &var_x * 3.0;
    assert_eq!(builder.eval(&shifted_by_product, &[2.0]), 8.0);
}
/// README math-function examples.
///
/// `ln(exp(x))` recovers `x = 2.5` within `1e-10`. A second, independent
/// builder then evaluates `3x` at `x = 4` as `12`.
#[test]
fn test_readme_mathematical_functions() {
    let first = MathBuilder::new();
    let u = first.var();
    let round_trip = u.exp().ln();
    let recovered = first.eval(&round_trip, &[2.5]);
    assert!((recovered - 2.5).abs() < 1e-10);

    // A separate builder for the second example.
    let second = MathBuilder::new();
    let v = second.var();
    let tripled = 3.0 * &v;
    assert_eq!(second.eval(&tripled, &[4.0]), 12.0);
}