use arael_sym::{E, symbol};
// Numeric body (sin x) of the symbolic function `f_mut`. The attribute names
// the symbolic function and declares its derivative as `g_mut(x)`; `g_mut` is
// declared right below and in turn lists `-f_mut(x)` as its derivative, so the
// two functions refer to each other. `extern_cross_ref_runtime` checks that
// this mutual reference resolves and evaluates correctly.
// NOTE(review): presumably the macro emits a symbolic `f_mut(E) -> E` (the test
// calls it with a symbol) that evaluates via this function — confirm in arael.
#[arael::function(f_mut, derivs = [g_mut(x)])]
fn f_mut_eval(x: f64) -> f64 { x.sin() }
// Numeric body (cos x) of the symbolic function `g_mut`, with declared
// derivative `-f_mut(x)` (d/dx cos x = -sin x). This is the second half of the
// mutual reference with `f_mut`, whose declared derivative is `g_mut(x)`.
#[arael::function(g_mut, derivs = [-f_mut(x)])]
fn g_mut_eval(x: f64) -> f64 { x.cos() }
// Checks that `f_mut` and `g_mut`, whose `derivs` lists name each other,
// differentiate into references to the other function, and that those
// derivatives evaluate to the analytic values at y = 0.5:
//   d/dy sin(y) =  cos(y)
//   d/dy cos(y) = -sin(y)
#[test]
fn extern_cross_ref_runtime() {
    let y = symbol("y");

    let de = f_mut(y.clone()).diff("y");
    let s = format!("{de}");
    assert!(
        s.contains("g_mut(y)"),
        "expected f_mut.diff to reference g_mut(y), got {s:?}"
    );

    let dg = g_mut(y).diff("y");
    let sg = format!("{dg}");
    assert!(
        sg.contains("f_mut(y)"),
        "expected g_mut.diff to reference f_mut(y), got {sg:?}"
    );

    let env = std::collections::HashMap::from([("y", 0.5)]);

    // Each expected value is computed once, so the check and the failure
    // message cannot drift apart.
    let expected = 0.5_f64.cos();
    let v = de.eval(&env).expect("f_mut derivative must evaluate");
    assert!(
        (v - expected).abs() < 1e-12,
        "f_mut'(0.5) = {v}, expected {}",
        expected
    );

    // Explicit parentheses: method calls bind tighter than unary minus, so the
    // old `-0.5_f64.sin()` already meant `-(0.5_f64.sin())`, but it read like
    // `(-0.5).sin()`.
    let expected = -(0.5_f64.sin());
    let v = dg.eval(&env).expect("g_mut derivative must evaluate");
    assert!(
        (v - expected).abs() < 1e-12,
        "g_mut'(0.5) = {v}, expected {}",
        expected
    );
}
// Symbolic function written directly over `E`: later_fn(z) * 2.
// `later_fn` is only declared further down the file, so this exercises a
// forward reference. `symbolic_forward_ref_runtime` expects
// applied_later(3) = 8 and applied_later'(3) = 2.
#[arael::function]
fn applied_later(z: E) -> E { later_fn(z) * 2.0 }
// Numeric body of the symbolic function `later_fn`: x + 1, with the constant
// declared derivative 1.0. It is referenced from `applied_later`, which is
// declared earlier in the file, and from the `derivs` list of
// `with_xref_deriv`.
#[arael::function(later_fn, derivs = [1.0])]
fn later_fn_eval(x: f64) -> f64 { x + 1.0 }
// `applied_later` uses `later_fn` before `later_fn` is declared. Both its
// derivative and its value must still evaluate at y = 3:
//   applied_later'(y) = 1 * 2       -> 2
//   applied_later(y)  = (y + 1) * 2 -> 8
#[test]
fn symbolic_forward_ref_runtime() {
    let env = std::collections::HashMap::from([("y", 3.0)]);

    let e = applied_later(symbol("y"));
    let de = e.diff("y");

    let v = de.eval(&env).expect("derivative must evaluate");
    assert!(
        (v - 2.0).abs() < 1e-12,
        "applied_later'(3.0) = {v}, expected 2.0"
    );

    let v = e.eval(&env).expect("forward must evaluate");
    assert!(
        (v - 8.0).abs() < 1e-12,
        "applied_later(3.0) = {v}, expected 8.0"
    );
}
// Symbolic a * a whose derivative is declared explicitly as `later_fn(a)`,
// i.e. a + 1. That is not the true derivative (2a). The mismatch lets
// `symbolic_explicit_deriv_xref` confirm that the declared derivative is the
// one used: it expects 5 at a = 4, not 8.
#[arael::function(derivs = [later_fn(a)])]
fn with_xref_deriv(a: E) -> E { a * a }
// The declared derivative `later_fn(a)` (= a + 1) must be used instead of the
// body's true derivative 2a. At y = 4 that gives 5, not 8.
#[test]
fn symbolic_explicit_deriv_xref() {
    let env = std::collections::HashMap::from([("y", 4.0)]);
    let v = with_xref_deriv(symbol("y"))
        .diff("y")
        .eval(&env)
        .expect("explicit deriv must evaluate");
    assert!(
        (v - 5.0).abs() < 1e-12,
        "with_xref_deriv'(4.0) = {v}, expected 5.0"
    );
}
// Symbolic square written in terms of its parameter `x`.
// `symbolic_arg_substituted` checks that applying it to the symbol `y` yields
// an expression that mentions only `y` and evaluates to 9 at y = 3.
#[arael::function]
fn square2(x: E) -> E { x * x }
// Applying `square2` (written over parameter `x`) to the symbol `y` must give
// an expression in `y` alone, with no trace of the parameter name `x`. That
// expression must evaluate to 3 * 3 at y = 3.
#[test]
fn symbolic_arg_substituted() {
    let e = square2(symbol("y"));

    let s = e.to_string();
    let mentions_arg = s.contains("y");
    let leaks_param = s.contains('x');
    assert!(
        mentions_arg && !leaks_param,
        "square2(y) must substitute y for x, got {s:?}"
    );

    let env = std::collections::HashMap::from([("y", 3.0)]);
    let v = e.eval(&env).expect("forward must evaluate");
    assert!((v - 9.0).abs() < 1e-12, "square2(3) = {v}, expected 9");
}