use xad_rs::Jet1Vec;
use xad_rs::NamedForwardTape;
/// Arithmetic-only expression `f(x, y) = x*y + x - 2*y` evaluated at `(2, 3)`.
///
/// The positionally seeded `Jet1Vec` (slot 0 = x, slot 1 = y, two slots) is the
/// reference; the name-keyed tape must match its value and both partials exactly.
#[test]
fn test_expr1_arithmetic_only() {
    // Reference result from positional jets.
    let x_ref = Jet1Vec::variable(2.0_f64, 0, 2);
    let y_ref = Jet1Vec::variable(3.0_f64, 1, 2);
    let prod_ref = &x_ref * &y_ref;
    let expected = &prod_ref + &x_ref - 2.0 * &y_ref;

    // Same expression through the named forward tape.
    let mut tape = NamedForwardTape::<f64>::new();
    let hx = tape.declare_jet1vec("x", 2.0);
    let hy = tape.declare_jet1vec("y", 3.0);
    let scope = tape.into_scope();
    let (x, y) = (scope.jet1vec(hx), scope.jet1vec(hy));
    let prod = x * y;
    let actual = &prod + x - 2.0 * y;

    assert_eq!(actual.real(), expected.real);
    assert_eq!(actual.partial("x"), expected.partial(0));
    assert_eq!(actual.partial("y"), expected.partial(1));
}
/// Quotient `f(x, y) = (x + 1) / (y - 3)` evaluated at `(5, 7)`.
///
/// Compares the named tape against the positional `Jet1Vec` reference for the
/// value and both partials.
#[test]
fn test_expr2_division() {
    // Reference: numerator and denominator built separately, then divided.
    let x_ref = Jet1Vec::variable(5.0, 0, 2);
    let y_ref = Jet1Vec::variable(7.0, 1, 2);
    let num_ref = &x_ref + 1.0;
    let den_ref = &y_ref - 3.0;
    let expected = &num_ref / &den_ref;

    // Named-tape version of the same quotient.
    let mut tape = NamedForwardTape::<f64>::new();
    let hx = tape.declare_jet1vec("x", 5.0);
    let hy = tape.declare_jet1vec("y", 7.0);
    let scope = tape.into_scope();
    let x = scope.jet1vec(hx);
    let y = scope.jet1vec(hy);
    let num = x + 1.0;
    let den = y - 3.0;
    let actual = &num / &den;

    assert_eq!(actual.real(), expected.real);
    assert_eq!(actual.partial("x"), expected.partial(0));
    assert_eq!(actual.partial("y"), expected.partial(1));
}
/// Transcendental product `f(x, y) = sin(x) * exp(y)` evaluated at `(0.5, 0.3)`.
///
/// The named tape must reproduce the positional `Jet1Vec` value and partials
/// exactly (`assert_eq!`, no tolerance).
#[test]
fn test_expr3_trig_exp() {
    let x_ref = Jet1Vec::variable(0.5, 0, 2);
    let y_ref = Jet1Vec::variable(0.3, 1, 2);
    let sin_x_ref = x_ref.sin();
    let exp_y_ref = y_ref.exp();
    let expected = &sin_x_ref * &exp_y_ref;

    let mut tape = NamedForwardTape::<f64>::new();
    let hx = tape.declare_jet1vec("x", 0.5);
    let hy = tape.declare_jet1vec("y", 0.3);
    let scope = tape.into_scope();
    let x = scope.jet1vec(hx);
    let y = scope.jet1vec(hy);
    let sin_x = x.sin();
    let exp_y = y.exp();
    let actual = &sin_x * &exp_y;

    assert_eq!(actual.real(), expected.real);
    assert_eq!(actual.partial("x"), expected.partial(0));
    assert_eq!(actual.partial("y"), expected.partial(1));
}
/// Euclidean norm `f(x, y) = sqrt(x*x + y*y)` evaluated at `(3, 4)`.
///
/// Exercises `sqrt` applied to a compound sum, comparing the named tape with
/// the positional `Jet1Vec` reference.
#[test]
fn test_expr4_sqrt_compound() {
    // Reference: square each coordinate, add, then take the root.
    let x_ref = Jet1Vec::variable(3.0, 0, 2);
    let y_ref = Jet1Vec::variable(4.0, 1, 2);
    let xx_ref = &x_ref * &x_ref;
    let yy_ref = &y_ref * &y_ref;
    let sum_ref = &xx_ref + &yy_ref;
    let expected = sum_ref.sqrt();

    // Named-tape version.
    let mut tape = NamedForwardTape::<f64>::new();
    let hx = tape.declare_jet1vec("x", 3.0);
    let hy = tape.declare_jet1vec("y", 4.0);
    let scope = tape.into_scope();
    let x = scope.jet1vec(hx);
    let y = scope.jet1vec(hy);
    let xx = x * x;
    let yy = y * y;
    let sum = &xx + &yy;
    let actual = sum.sqrt();

    assert_eq!(actual.real(), expected.real);
    assert_eq!(actual.partial("x"), expected.partial(0));
    assert_eq!(actual.partial("y"), expected.partial(1));
}
/// `gradient()` returns `(name, partial)` pairs in declaration order.
///
/// The names "z", "a", "m" are declared in that non-alphabetical order, and the
/// gradient is expected in the same order. Each variable gets a distinct
/// coefficient, so the partials differ (1, 2, 3). If every partial were equal,
/// a gradient whose values were paired with the wrong names would still pass.
#[test]
fn test_gradient_returns_insertion_order() {
    let mut ft = NamedForwardTape::<f64>::new();
    let z_h = ft.declare_jet1vec("z", 1.0);
    let a_h = ft.declare_jet1vec("a", 2.0);
    let m_h = ft.declare_jet1vec("m", 3.0);
    let scope = ft.into_scope();
    let z = scope.jet1vec(z_h);
    let a = scope.jet1vec(a_h);
    let m = scope.jet1vec(m_h);

    // f = z + 2a + 3m  =>  df/dz = 1, df/da = 2, df/dm = 3; f(1, 2, 3) = 1 + 4 + 9 = 14.
    let f = z + &(2.0 * a) + &(3.0 * m);
    // Read the value before `gradient()` in case that call consumes `f`.
    assert_eq!(f.real(), 14.0);

    let grad = f.gradient();
    assert_eq!(grad.len(), 3);
    assert_eq!(grad[0].0, "z");
    assert_eq!(grad[1].0, "a");
    assert_eq!(grad[2].0, "m");
    assert_eq!(grad[0].1, 1.0);
    assert_eq!(grad[1].1, 2.0);
    assert_eq!(grad[2].1, 3.0);
}
/// A constant created from the scope keeps its value and has zero partials
/// with respect to every declared variable.
#[test]
fn test_real_and_constant() {
    let mut tape = NamedForwardTape::<f64>::new();
    // Declared only so the registry contains "x" and "y"; the handles are not used.
    let _hx = tape.declare_jet1vec("x", 0.0);
    let _hy = tape.declare_jet1vec("y", 0.0);
    let scope = tape.into_scope();

    let k = scope.constant_jet1vec(42.0);
    assert_eq!(k.real(), 42.0);
    assert_eq!(k.partial("x"), 0.0);
    assert_eq!(k.partial("y"), 0.0);
}
/// Requesting the partial for a name that was never declared must panic.
///
/// The panic message is expected to contain "not present in registry".
#[test]
#[should_panic(expected = "not present in registry")]
fn test_partial_unknown_name_panics() {
    let mut tape = NamedForwardTape::<f64>::new();
    let hx = tape.declare_jet1vec("x", 1.0);
    let scope = tape.into_scope();
    // Only "x" exists on this tape; "missing" was never declared.
    let _ = scope.jet1vec(hx).partial("missing");
}
/// Adding jets that come from two independently built tapes must panic in
/// debug builds.
///
/// Both tapes declare a variable named "x", so this is a registry mismatch,
/// not a name mismatch. The test is compiled only when `debug_assertions`
/// is enabled.
#[cfg(debug_assertions)]
#[test]
#[should_panic(expected = "cross-registry forward-mode op detected")]
fn test_cross_registry_add_panics_in_debug() {
    // First tape and its scope.
    let mut tape_a = NamedForwardTape::<f64>::new();
    let ha = tape_a.declare_jet1vec("x", 2.0);
    let scope_a = tape_a.into_scope();

    // Second, unrelated tape and its scope.
    let mut tape_b = NamedForwardTape::<f64>::new();
    let hb = tape_b.declare_jet1vec("x", 3.0);
    let scope_b = tape_b.into_scope();

    // Mixing operands from the two scopes is expected to trip the check.
    let _ = scope_a.jet1vec(ha) + scope_b.jet1vec(hb);
}
/// Scalar-on-the-left operators (`f64 op jet`) evaluated at `x = 2`, `y = 3`.
///
/// For each result, checks:
/// - the value;
/// - the partial with respect to the variable it was built from;
/// - that the partial with respect to the other variable stays zero, so a
///   derivative leaking into the wrong slot is caught.
#[test]
fn test_scalar_on_lhs_ops() {
    let mut ft = NamedForwardTape::<f64>::new();
    let x_h = ft.declare_jet1vec("x", 2.0);
    let y_h = ft.declare_jet1vec("y", 3.0);
    let scope = ft.into_scope();
    let x = scope.jet1vec(x_h);
    let y = scope.jet1vec(y_h);

    // 1 + x: d/dx = 1, d/dy = 0. `x.clone()` passes an owned operand
    // (presumably to cover the owned-RHS operator impl).
    let a = 1.0 + x.clone();
    assert_eq!(a.real(), 3.0);
    assert_eq!(a.partial("x"), 1.0);
    assert_eq!(a.partial("y"), 0.0);

    // 10 - y: d/dy = -1, d/dx = 0 (a -0.0 result also compares equal to 0.0).
    let b = 10.0 - y;
    assert_eq!(b.real(), 7.0);
    assert_eq!(b.partial("y"), -1.0);
    assert_eq!(b.partial("x"), 0.0);

    // 3 * x: d/dx = 3, d/dy = 0.
    let c = 3.0 * x.clone();
    assert_eq!(c.real(), 6.0);
    assert_eq!(c.partial("x"), 3.0);
    assert_eq!(c.partial("y"), 0.0);

    // 12 / y: d/dy = -12 / y^2 = -12 / 9, d/dx = 0.
    let d = 12.0 / y;
    assert_eq!(d.real(), 4.0);
    assert_eq!(d.partial("y"), -12.0 / 9.0);
    assert_eq!(d.partial("x"), 0.0);
}