use mathcompile::prelude::*;
/// End-to-end pipeline check: build a symbolic polynomial, optimize it,
/// evaluate it with `DirectEval`, generate Rust source for it, and — only when
/// a Rust compiler is available — compile that source and require the native
/// result to equal the interpreted one.
#[test]
fn test_symbolic_to_numeric_optimization() -> Result<()> {
    let mut math = MathBuilder::new();
    let x = math.var("x");

    // Coefficients [1.0, 2.0, 3.0]; the asserted 34.0 at x = 3.0 matches
    // 1 + 2x + 3x^2.
    let polynomial = math.poly(&[1.0, 2.0, 3.0], &x);
    let optimized = math.optimize(&polynomial)?;

    let interpreted = DirectEval::eval_with_vars(&optimized, &[3.0]);
    assert_eq!(interpreted, 34.0);

    let generator = RustCodeGenerator::new();
    let source = generator.generate_function(&optimized, "test_function")?;
    assert!(source.contains("test_function"));

    // Native compilation is optional; without a toolchain the test ends here.
    if !RustCompiler::is_available() {
        return Ok(());
    }

    let native_compiler = RustCompiler::new();
    let native_fn = native_compiler.compile_and_load(&source, "test_function")?;
    assert_eq!(native_fn.call(3.0)?, interpreted);
    Ok(())
}
/// Mirrors the basic-usage example: build `x^2 + 2x + 1` by hand, optimize it,
/// and require every available backend to evaluate it to 16.0 at x = 3.0.
///
/// Backends exercised:
/// - `DirectEval` (always),
/// - generated Rust source (always generated and checked for the function
///   name; compiled and called only when a Rust compiler is available),
/// - the Cranelift JIT (only with the `cranelift` feature).
#[test]
fn test_basic_usage_example() -> Result<()> {
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let expr = math.add(
        &math.add(&math.mul(&x, &x), &math.mul(&math.constant(2.0), &x)),
        &math.constant(1.0),
    );
    let optimized = math.optimize(&expr)?;

    let result = DirectEval::eval_with_vars(&optimized, &[3.0]);
    assert_eq!(result, 16.0);

    let codegen = RustCodeGenerator::new();
    let rust_code = codegen.generate_function(&optimized, "test_quadratic")?;
    // Without this check, the generated source goes unverified whenever no
    // Rust compiler is installed; the sibling tests assert the same property.
    assert!(rust_code.contains("test_quadratic"));

    if RustCompiler::is_available() {
        let compiler = RustCompiler::new();
        let compiled_func = compiler.compile_and_load(&rust_code, "test_quadratic")?;
        let compiled_result = compiled_func.call(3.0)?;
        assert_eq!(compiled_result, 16.0);
    }

    #[cfg(feature = "cranelift")]
    {
        let compiler = JITCompiler::new()?;
        let compiled = compiler.compile_single_var(&optimized, "x")?;
        let fast_result = compiled.call_single(3.0);
        assert_eq!(fast_result, 16.0);
    }

    Ok(())
}
/// Builds `1 + 2x + x^2` (coefficients [1.0, 2.0, 1.0]), optimizes it, and runs
/// `SymbolicAD` over the optimized expression, requiring at least one first
/// derivative in the result.
#[test]
fn test_automatic_differentiation_example() -> Result<()> {
    let mut math = MathBuilder::new();
    let x = math.var("x");

    let quadratic = math.poly(&[1.0, 2.0, 1.0], &x);
    let optimized_quadratic = math.optimize(&quadratic)?;

    let mut differentiator = SymbolicAD::new()?;
    let ad_output = differentiator.compute_with_derivatives(&optimized_quadratic)?;

    // The shared-subexpression statistic is read but intentionally not asserted on.
    let _shared_subexprs = ad_output.stats.shared_subexpressions_count;

    let has_first_derivatives = !ad_output.first_derivatives.is_empty();
    assert!(has_first_derivatives);

    Ok(())
}
/// Builds `2x + 1`, optimizes it once, and evaluates it at x = 3.0 through:
/// - the Cranelift JIT (only with the `cranelift` feature), and
/// - generated Rust source, which is always checked for the function name and
///   is compiled and called only when a Rust compiler is available.
///
/// Each backend that runs must produce 7.0.
#[test]
fn test_multiple_backends_example() -> Result<()> {
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let doubled = math.mul(&math.constant(2.0), &x);
    let expr = math.add(&doubled, &math.constant(1.0));
    let optimized = math.optimize(&expr)?;

    #[cfg(feature = "cranelift")]
    {
        let jit = JITCompiler::new()?;
        let jitted = jit.compile_single_var(&optimized, "x")?;
        assert_eq!(jitted.call_single(3.0), 7.0);
    }

    let generator = RustCodeGenerator::new();
    let source = generator.generate_function(&optimized, "test_backends")?;
    assert!(source.contains("test_backends"));

    if RustCompiler::is_available() {
        let native_compiler = RustCompiler::new();
        let native_fn = native_compiler.compile_and_load(&source, "test_backends")?;
        assert_eq!(native_fn.call(3.0)?, 7.0);
    }

    Ok(())
}
/// Exercises the `compile_and_load` API on `3x`.
///
/// Generates Rust source (always checked for the function name). When a Rust
/// compiler is available, it compiles and loads the source, checks that the
/// loaded function can be called more than once with correct results, and
/// checks that it reports the name it was loaded under.
#[test]
fn test_compile_and_load_api() -> Result<()> {
    let mut math = MathBuilder::new();
    let x = math.var("x");
    let expr = math.mul(&math.constant(3.0), &x);
    let optimized = math.optimize(&expr)?;

    let codegen = RustCodeGenerator::new();
    let rust_code = codegen.generate_function(&optimized, "test_api")?;
    // Keep the codegen output verified even when no Rust compiler is present,
    // matching the checks in the sibling tests.
    assert!(rust_code.contains("test_api"));

    if RustCompiler::is_available() {
        let compiler = RustCompiler::new();
        let compiled_func = compiler.compile_and_load(&rust_code, "test_api")?;

        let result1 = compiled_func.call(4.0)?;
        assert_eq!(result1, 12.0);

        // A second call confirms the loaded function is reusable; 3 * -2 = -6
        // is exactly representable, so exact comparison is safe.
        let result2 = compiled_func.call(-2.0)?;
        assert_eq!(result2, -6.0);

        assert_eq!(compiled_func.name(), "test_api");
    }

    Ok(())
}