use crate::tensor::Tensor;
use cjc_repro::kahan_sum_f64;
pub fn ode_step_euler(y: &Tensor, dydt: &Tensor, h: f64) -> Tensor {
let y_data = y.to_vec();
let dy_data = dydt.to_vec();
assert_eq!(y_data.len(), dy_data.len(), "ode_step_euler: y and dydt must have same length");
let result: Vec<f64> = y_data.iter().zip(dy_data.iter())
.map(|(&yi, &dyi)| yi + h * dyi)
.collect();
Tensor::from_vec_unchecked(result, y.shape())
}
pub fn ode_step_rk4(y: &Tensor, k1: &Tensor, k2: &Tensor, k3: &Tensor, k4: &Tensor, h: f64) -> Tensor {
let y_data = y.to_vec();
let k1_data = k1.to_vec();
let k2_data = k2.to_vec();
let k3_data = k3.to_vec();
let k4_data = k4.to_vec();
let n = y_data.len();
assert_eq!(k1_data.len(), n);
assert_eq!(k2_data.len(), n);
assert_eq!(k3_data.len(), n);
assert_eq!(k4_data.len(), n);
let h6 = h / 6.0;
let result: Vec<f64> = (0..n)
.map(|i| {
let terms = [
k1_data[i],
2.0 * k2_data[i],
2.0 * k3_data[i],
k4_data[i],
];
y_data[i] + h6 * kahan_sum_f64(&terms)
})
.collect();
Tensor::from_vec_unchecked(result, y.shape())
}
pub fn pde_laplacian_1d(u: &Tensor, dx: f64) -> Tensor {
let data = u.to_vec();
let n = data.len();
let dx2_inv = 1.0 / (dx * dx);
let mut lap = vec![0.0_f64; n];
for i in 1..n - 1 {
lap[i] = (data[i - 1] - 2.0 * data[i] + data[i + 1]) * dx2_inv;
}
Tensor::from_vec_unchecked(lap, u.shape())
}
pub fn pde_step_diffusion(u: &Tensor, alpha: f64, dt: f64, dx: f64) -> Tensor {
let lap = pde_laplacian_1d(u, dx);
let u_data = u.to_vec();
let lap_data = lap.to_vec();
let result: Vec<f64> = u_data.iter().zip(lap_data.iter())
.map(|(&ui, &li)| ui + dt * alpha * li)
.collect();
Tensor::from_vec_unchecked(result, u.shape())
}
#[derive(Debug, Clone, PartialEq)]
pub enum SymExpr {
Const(f64),
Var(String),
Add(Box<SymExpr>, Box<SymExpr>),
Mul(Box<SymExpr>, Box<SymExpr>),
Pow(Box<SymExpr>, f64),
Sin(Box<SymExpr>),
Cos(Box<SymExpr>),
Exp(Box<SymExpr>),
Ln(Box<SymExpr>),
Neg(Box<SymExpr>),
}
impl SymExpr {
pub fn differentiate(&self, var: &str) -> SymExpr {
match self {
SymExpr::Const(_) => SymExpr::Const(0.0),
SymExpr::Var(name) => {
if name == var {
SymExpr::Const(1.0)
} else {
SymExpr::Const(0.0)
}
}
SymExpr::Add(a, b) => SymExpr::Add(
Box::new(a.differentiate(var)),
Box::new(b.differentiate(var)),
),
SymExpr::Mul(a, b) => {
SymExpr::Add(
Box::new(SymExpr::Mul(
Box::new(a.differentiate(var)),
b.clone(),
)),
Box::new(SymExpr::Mul(
a.clone(),
Box::new(b.differentiate(var)),
)),
)
}
SymExpr::Pow(base, exp) => {
SymExpr::Mul(
Box::new(SymExpr::Mul(
Box::new(SymExpr::Const(*exp)),
Box::new(SymExpr::Pow(base.clone(), exp - 1.0)),
)),
Box::new(base.differentiate(var)),
)
}
SymExpr::Sin(inner) => {
SymExpr::Mul(
Box::new(SymExpr::Cos(inner.clone())),
Box::new(inner.differentiate(var)),
)
}
SymExpr::Cos(inner) => {
SymExpr::Mul(
Box::new(SymExpr::Neg(Box::new(SymExpr::Sin(inner.clone())))),
Box::new(inner.differentiate(var)),
)
}
SymExpr::Exp(inner) => {
SymExpr::Mul(
Box::new(SymExpr::Exp(inner.clone())),
Box::new(inner.differentiate(var)),
)
}
SymExpr::Ln(inner) => {
SymExpr::Mul(
Box::new(SymExpr::Pow(inner.clone(), -1.0)),
Box::new(inner.differentiate(var)),
)
}
SymExpr::Neg(inner) => {
SymExpr::Neg(Box::new(inner.differentiate(var)))
}
}
}
pub fn eval(&self, bindings: &std::collections::BTreeMap<String, f64>) -> f64 {
match self {
SymExpr::Const(c) => *c,
SymExpr::Var(name) => *bindings.get(name).unwrap_or(&0.0),
SymExpr::Add(a, b) => a.eval(bindings) + b.eval(bindings),
SymExpr::Mul(a, b) => a.eval(bindings) * b.eval(bindings),
SymExpr::Pow(base, exp) => base.eval(bindings).powf(*exp),
SymExpr::Sin(inner) => inner.eval(bindings).sin(),
SymExpr::Cos(inner) => inner.eval(bindings).cos(),
SymExpr::Exp(inner) => inner.eval(bindings).exp(),
SymExpr::Ln(inner) => inner.eval(bindings).ln(),
SymExpr::Neg(inner) => -inner.eval(bindings),
}
}
}
#[allow(dead_code)]
fn tensor_add(a: &Tensor, b: &Tensor) -> Tensor {
let a_data = a.to_vec();
let b_data = b.to_vec();
debug_assert_eq!(a_data.len(), b_data.len());
let result: Vec<f64> = a_data.iter().zip(b_data.iter()).map(|(&ai, &bi)| ai + bi).collect();
Tensor::from_vec_unchecked(result, a.shape())
}
fn tensor_scale(a: &Tensor, scalar: f64) -> Tensor {
let a_data = a.to_vec();
let result: Vec<f64> = a_data.iter().map(|&ai| scalar * ai).collect();
Tensor::from_vec_unchecked(result, a.shape())
}
fn tensor_add_scaled(a: &Tensor, b: &Tensor, scalar: f64) -> Tensor {
let a_data = a.to_vec();
let b_data = b.to_vec();
debug_assert_eq!(a_data.len(), b_data.len());
let result: Vec<f64> = a_data.iter().zip(b_data.iter()).map(|(&ai, &bi)| ai + scalar * bi).collect();
Tensor::from_vec_unchecked(result, a.shape())
}
fn tensor_norm(a: &Tensor) -> f64 {
let data = a.to_vec();
let terms: Vec<f64> = data.iter().map(|&x| x * x).collect();
kahan_sum_f64(&terms).sqrt()
}
pub fn ode_solve_rk4<F>(
mut f: F,
y0: &Tensor,
t_span: (f64, f64),
n_steps: usize,
) -> (Vec<f64>, Vec<Tensor>)
where
F: FnMut(f64, &Tensor) -> Tensor,
{
assert!(n_steps > 0, "ode_solve_rk4: n_steps must be > 0");
let (t0, t1) = t_span;
let h = (t1 - t0) / n_steps as f64;
let mut ts = Vec::with_capacity(n_steps + 1);
let mut ys = Vec::with_capacity(n_steps + 1);
ts.push(t0);
ys.push(y0.clone());
let mut t = t0;
let mut y = y0.clone();
for _ in 0..n_steps {
let k1 = f(t, &y);
let y2 = tensor_add_scaled(&y, &k1, h * 0.5);
let k2 = f(t + h * 0.5, &y2);
let y3 = tensor_add_scaled(&y, &k2, h * 0.5);
let k3 = f(t + h * 0.5, &y3);
let y4 = tensor_add_scaled(&y, &k3, h);
let k4 = f(t + h, &y4);
y = ode_step_rk4(&y, &k1, &k2, &k3, &k4, h);
t += h;
ts.push(t);
ys.push(y.clone());
}
(ts, ys)
}
mod dp5 {
pub const C2: f64 = 1.0 / 5.0;
pub const C3: f64 = 3.0 / 10.0;
pub const C4: f64 = 4.0 / 5.0;
pub const C5: f64 = 8.0 / 9.0;
pub const A21: f64 = 1.0 / 5.0;
pub const A31: f64 = 3.0 / 40.0;
pub const A32: f64 = 9.0 / 40.0;
pub const A41: f64 = 44.0 / 45.0;
pub const A42: f64 = -56.0 / 15.0;
pub const A43: f64 = 32.0 / 9.0;
pub const A51: f64 = 19372.0 / 6561.0;
pub const A52: f64 = -25360.0 / 2187.0;
pub const A53: f64 = 64448.0 / 6561.0;
pub const A54: f64 = -212.0 / 729.0;
pub const A61: f64 = 9017.0 / 3168.0;
pub const A62: f64 = -355.0 / 33.0;
pub const A63: f64 = 46732.0 / 5247.0;
pub const A64: f64 = 49.0 / 176.0;
pub const A65: f64 = -5103.0 / 18656.0;
pub const B1: f64 = 35.0 / 384.0;
pub const B3: f64 = 500.0 / 1113.0;
pub const B4: f64 = 125.0 / 192.0;
pub const B5: f64 = -2187.0 / 6784.0;
pub const B6: f64 = 11.0 / 84.0;
pub const E1: f64 = 71.0 / 57600.0;
pub const E3: f64 = -71.0 / 16695.0;
pub const E4: f64 = 71.0 / 1920.0;
pub const E5: f64 = -17253.0 / 339200.0;
pub const E6: f64 = 22.0 / 525.0;
pub const E7: f64 = -1.0 / 40.0;
}
pub fn ode_solve_rk45<F>(
mut f: F,
y0: &Tensor,
t_span: (f64, f64),
rtol: f64,
atol: f64,
) -> (Vec<f64>, Vec<Tensor>)
where
F: FnMut(f64, &Tensor) -> Tensor,
{
let (t0, t1) = t_span;
assert!(t1 > t0, "ode_solve_rk45: t1 must be > t0");
let mut ts = Vec::new();
let mut ys = Vec::new();
ts.push(t0);
ys.push(y0.clone());
let n = y0.to_vec().len();
let f0 = f(t0, y0);
let f0_norm = tensor_norm(&f0).max(1e-300);
let mut h = (0.01 * (t1 - t0)).min(0.1 / f0_norm);
h = h.max(1e-12);
let mut t = t0;
let mut y = y0.clone();
let safety = 0.9_f64;
let max_factor = 10.0_f64;
let min_factor = 0.2_f64;
let max_steps = 1_000_000_usize;
let mut step_count = 0;
while t < t1 && step_count < max_steps {
if t + h > t1 {
h = t1 - t;
}
let k1 = f(t, &y);
let y2 = tensor_add_scaled(&y, &k1, h * dp5::A21);
let k2 = f(t + dp5::C2 * h, &y2);
let mut y3_data = y.to_vec();
let k1d = k1.to_vec(); let k2d = k2.to_vec();
for i in 0..n {
y3_data[i] += h * (dp5::A31 * k1d[i] + dp5::A32 * k2d[i]);
}
let y3 = Tensor::from_vec_unchecked(y3_data, y.shape());
let k3 = f(t + dp5::C3 * h, &y3);
let k3d = k3.to_vec();
let mut y4_data = y.to_vec();
for i in 0..n {
y4_data[i] += h * (dp5::A41 * k1d[i] + dp5::A42 * k2d[i] + dp5::A43 * k3d[i]);
}
let y4 = Tensor::from_vec_unchecked(y4_data, y.shape());
let k4 = f(t + dp5::C4 * h, &y4);
let k4d = k4.to_vec();
let mut y5_data = y.to_vec();
for i in 0..n {
y5_data[i] += h * (dp5::A51 * k1d[i] + dp5::A52 * k2d[i] + dp5::A53 * k3d[i] + dp5::A54 * k4d[i]);
}
let y5 = Tensor::from_vec_unchecked(y5_data, y.shape());
let k5 = f(t + dp5::C5 * h, &y5);
let k5d = k5.to_vec();
let mut y6_data = y.to_vec();
for i in 0..n {
y6_data[i] += h * (dp5::A61 * k1d[i] + dp5::A62 * k2d[i] + dp5::A63 * k3d[i] + dp5::A64 * k4d[i] + dp5::A65 * k5d[i]);
}
let y6 = Tensor::from_vec_unchecked(y6_data, y.shape());
let k6 = f(t + h, &y6);
let k6d = k6.to_vec();
let y_data = y.to_vec();
let mut y5th_data = vec![0.0_f64; n];
for i in 0..n {
let terms = [
dp5::B1 * k1d[i],
dp5::B3 * k3d[i],
dp5::B4 * k4d[i],
dp5::B5 * k5d[i],
dp5::B6 * k6d[i],
];
y5th_data[i] = y_data[i] + h * kahan_sum_f64(&terms);
}
let y5th = Tensor::from_vec_unchecked(y5th_data.clone(), y.shape());
let k7 = f(t + h, &y5th);
let k7d = k7.to_vec();
let mut err_sq_acc = 0.0_f64;
for i in 0..n {
let e_terms = [
dp5::E1 * k1d[i],
dp5::E3 * k3d[i],
dp5::E4 * k4d[i],
dp5::E5 * k5d[i],
dp5::E6 * k6d[i],
dp5::E7 * k7d[i],
];
let e_i = h * kahan_sum_f64(&e_terms);
let sc = atol + rtol * y5th_data[i].abs().max(y_data[i].abs());
err_sq_acc += (e_i / sc) * (e_i / sc);
}
let err_norm = (err_sq_acc / n as f64).sqrt();
if err_norm <= 1.0 {
t += h;
y = y5th;
ts.push(t);
ys.push(y.clone());
step_count += 1;
let factor = safety * err_norm.powf(-0.2).min(max_factor).max(min_factor);
h = (h * factor).min(t1 - t);
if h < 1e-14 {
break;
}
} else {
let factor = (safety * err_norm.powf(-0.25)).max(min_factor);
h *= factor;
if h < 1e-14 {
break;
}
}
}
(ts, ys)
}
pub fn adjoint_solve<F, G>(
mut f: F,
mut grad_f: G,
y_final: &Tensor,
t_span: (f64, f64),
n_steps: usize,
) -> (Tensor, Tensor)
where
F: FnMut(f64, &Tensor) -> Tensor,
G: FnMut(f64, &Tensor, &Tensor) -> (Tensor, Tensor),
{
assert!(n_steps > 0, "adjoint_solve: n_steps must be > 0");
let (t0, t1) = t_span;
let h = (t1 - t0) / n_steps as f64;
let n = y_final.to_vec().len();
let a0 = Tensor::from_vec_unchecked(vec![0.0_f64; n], y_final.shape());
let mut t = t1;
let mut y = y_final.clone();
let mut a = a0;
for _ in 0..n_steps {
let t_prev = t - h;
let ky1 = tensor_scale(&f(t, &y), -1.0);
let (ka1, _) = grad_f(t, &y, &a);
let y2 = tensor_add_scaled(&y, &ky1, h * 0.5);
let a2 = tensor_add_scaled(&a, &ka1, h * 0.5);
let ky2 = tensor_scale(&f(t - h * 0.5, &y2), -1.0);
let (ka2, _) = grad_f(t - h * 0.5, &y2, &a2);
let y3 = tensor_add_scaled(&y, &ky2, h * 0.5);
let a3 = tensor_add_scaled(&a, &ka2, h * 0.5);
let ky3 = tensor_scale(&f(t - h * 0.5, &y3), -1.0);
let (ka3, _) = grad_f(t - h * 0.5, &y3, &a3);
let y4 = tensor_add_scaled(&y, &ky3, h);
let a4 = tensor_add_scaled(&a, &ka3, h);
let ky4 = tensor_scale(&f(t_prev, &y4), -1.0);
let (ka4, _) = grad_f(t_prev, &y4, &a4);
y = ode_step_rk4(&y, &ky1, &ky2, &ky3, &ky4, h);
a = ode_step_rk4(&a, &ka1, &ka2, &ka3, &ka4, h);
t = t_prev;
}
(y, a)
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::BTreeMap;
#[test]
fn test_euler_step() {
let y = Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]);
let dydt = Tensor::from_vec_unchecked(vec![0.0, 1.0], &[2]);
let y1 = ode_step_euler(&y, &dydt, 0.1);
let result = y1.to_vec();
assert!((result[0] - 1.0).abs() < 1e-15);
assert!((result[1] - 0.1).abs() < 1e-15);
}
#[test]
fn test_rk4_step_constant() {
let y = Tensor::from_vec_unchecked(vec![1.0], &[1]);
let k = Tensor::from_vec_unchecked(vec![2.0], &[1]);
let y1 = ode_step_rk4(&y, &k, &k, &k, &k, 0.1);
assert!((y1.to_vec()[0] - 1.2).abs() < 1e-14);
}
#[test]
fn test_laplacian_1d() {
let u = Tensor::from_vec_unchecked(vec![0.0, 1.0, 4.0, 9.0, 16.0], &[5]);
let lap = pde_laplacian_1d(&u, 1.0);
let data = lap.to_vec();
assert!((data[0] - 0.0).abs() < 1e-14); assert!((data[1] - 2.0).abs() < 1e-14);
assert!((data[2] - 2.0).abs() < 1e-14);
assert!((data[3] - 2.0).abs() < 1e-14);
assert!((data[4] - 0.0).abs() < 1e-14); }
#[test]
fn test_symbolic_diff_polynomial() {
let expr = SymExpr::Pow(Box::new(SymExpr::Var("x".into())), 3.0);
let deriv = expr.differentiate("x");
let mut bindings = BTreeMap::new();
bindings.insert("x".into(), 2.0);
let val = deriv.eval(&bindings);
assert!((val - 12.0).abs() < 1e-12); }
#[test]
fn test_symbolic_diff_sin() {
let expr = SymExpr::Sin(Box::new(SymExpr::Var("x".into())));
let deriv = expr.differentiate("x");
let mut bindings = BTreeMap::new();
bindings.insert("x".into(), 0.0);
let val = deriv.eval(&bindings);
assert!((val - 1.0).abs() < 1e-12); }
#[test]
fn test_rk4_exponential_decay() {
let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
let f = |_t: f64, y: &Tensor| -> Tensor {
tensor_scale(y, -1.0)
};
let (ts, ys) = ode_solve_rk4(f, &y0, (0.0, 1.0), 100);
assert_eq!(ts.len(), 101);
assert_eq!(ys.len(), 101);
assert!((ts[0] - 0.0).abs() < 1e-15);
assert!((ts[100] - 1.0).abs() < 1e-12);
let y_final = ys[100].to_vec()[0];
let exact = (-1.0_f64).exp();
assert!(
(y_final - exact).abs() < 1e-8,
"RK4 decay: got {}, expected {}",
y_final, exact
);
}
#[test]
fn test_rk4_harmonic_oscillator() {
let y0 = Tensor::from_vec_unchecked(vec![1.0, 0.0], &[2]);
let f = |_t: f64, y: &Tensor| -> Tensor {
let d = y.to_vec();
Tensor::from_vec_unchecked(vec![d[1], -d[0]], &[2])
};
let t_end = std::f64::consts::PI / 2.0;
let (ts, ys) = ode_solve_rk4(f, &y0, (0.0, t_end), 1000);
let y_end = ys.last().unwrap().to_vec();
assert!(
y_end[0].abs() < 1e-7,
"harmonic y(pi/2) should be ~0, got {}",
y_end[0]
);
assert!(
(y_end[1] - (-1.0)).abs() < 1e-7,
"harmonic v(pi/2) should be ~-1, got {}",
y_end[1]
);
let _ = ts;
}
#[test]
fn test_rk45_exponential_decay() {
let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
let f = |_t: f64, y: &Tensor| -> Tensor {
tensor_scale(y, -1.0)
};
let (ts, ys) = ode_solve_rk45(f, &y0, (0.0, 1.0), 1e-8, 1e-10);
assert!(!ts.is_empty(), "RK45 should produce at least one step");
let y_final = ys.last().unwrap().to_vec()[0];
let t_final = *ts.last().unwrap();
let exact = (-t_final).exp();
assert!(
(y_final - exact).abs() < 1e-6,
"RK45 decay: got {} at t={}, expected {}",
y_final, t_final, exact
);
}
#[test]
fn test_rk45_fewer_steps_than_rk4_fixed() {
let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
let f_adaptive = |_t: f64, y: &Tensor| -> Tensor { tensor_scale(y, -1.0) };
let f_fixed = |_t: f64, y: &Tensor| -> Tensor { tensor_scale(y, -1.0) };
let (ts_adaptive, _) = ode_solve_rk45(f_adaptive, &y0, (0.0, 1.0), 1e-6, 1e-8);
let (ts_fixed, _) = ode_solve_rk4(f_fixed, &y0, (0.0, 1.0), 1000);
assert!(
ts_adaptive.len() < ts_fixed.len(),
"RK45 adaptive ({} steps) should take fewer steps than RK4 fixed ({} steps)",
ts_adaptive.len() - 1,
ts_fixed.len() - 1
);
}
#[test]
fn test_rk4_determinism() {
let y0 = Tensor::from_vec_unchecked(vec![1.0, 0.5], &[2]);
let f = |_t: f64, y: &Tensor| -> Tensor {
let d = y.to_vec();
Tensor::from_vec_unchecked(vec![-0.5 * d[0], -0.3 * d[1]], &[2])
};
let (ts1, ys1) = ode_solve_rk4(|t, y| { let d = y.to_vec(); Tensor::from_vec_unchecked(vec![-0.5*d[0], -0.3*d[1]], &[2]) }, &y0, (0.0, 1.0), 50);
let (ts2, ys2) = ode_solve_rk4(|t, y| { let d = y.to_vec(); Tensor::from_vec_unchecked(vec![-0.5*d[0], -0.3*d[1]], &[2]) }, &y0, (0.0, 1.0), 50);
assert_eq!(ts1, ts2, "RK4 time points must be bit-identical");
for (y1, y2) in ys1.iter().zip(ys2.iter()) {
assert_eq!(y1.to_vec(), y2.to_vec(), "RK4 solutions must be bit-identical");
}
let _ = f;
}
#[test]
fn test_rk45_determinism() {
let y0 = Tensor::from_vec_unchecked(vec![1.0], &[1]);
let run = || ode_solve_rk45(
|_t, y| tensor_scale(y, -1.0),
&y0,
(0.0, 2.0),
1e-6,
1e-9,
);
let (ts1, ys1) = run();
let (ts2, ys2) = run();
assert_eq!(ts1, ts2, "RK45 time points must be bit-identical");
for (y1, y2) in ys1.iter().zip(ys2.iter()) {
assert_eq!(y1.to_vec(), y2.to_vec(), "RK45 solutions must be bit-identical");
}
}
#[test]
fn test_adjoint_linear_ode() {
let t0 = 0.0_f64;
let t1 = 1.0_f64;
let y_final = Tensor::from_vec_unchecked(vec![(-t1).exp()], &[1]);
let (y0_rec, _adj) = adjoint_solve(
|_t, y| tensor_scale(y, -1.0),
|_t, y, a| {
let adj_y = tensor_scale(a, 1.0); let adj_theta = Tensor::from_vec_unchecked(vec![0.0], &[1]);
(adj_y, adj_theta)
},
&y_final,
(t0, t1),
1000,
);
let y0_val = y0_rec.to_vec()[0];
assert!(
(y0_val - 1.0).abs() < 1e-6,
"adjoint_solve should recover y(0)=1.0, got {}",
y0_val
);
}
#[test]
fn test_adjoint_gradient_vs_finite_diff() {
let t1 = 0.5_f64;
let alpha = 1.0_f64;
let y_final_val = (alpha * t1).exp();
let y_final = Tensor::from_vec_unchecked(vec![y_final_val], &[1]);
let eps = 1e-5;
let l_plus = ((alpha + eps) * t1).exp();
let l_minus = ((alpha - eps) * t1).exp();
let fd_grad = (l_plus - l_minus) / (2.0 * eps);
let a_terminal = Tensor::from_vec_unchecked(vec![1.0_f64], &[1]);
let n_steps = 500;
let h = t1 / n_steps as f64;
let mut t = t1;
let mut y = y_final.clone();
let mut a = a_terminal;
let mut grad_alpha_acc = 0.0_f64;
for _ in 0..n_steps {
let t_prev = t - h;
let ky1 = tensor_scale(&tensor_scale(&y, alpha), -1.0);
let ka1 = tensor_scale(&a, -(-alpha));
let y2 = tensor_add_scaled(&y, &ky1, h * 0.5);
let a2 = tensor_add_scaled(&a, &ka1, h * 0.5);
let ky2 = tensor_scale(&tensor_scale(&y2, alpha), -1.0);
let ka2 = tensor_scale(&a2, alpha);
let y3 = tensor_add_scaled(&y, &ky2, h * 0.5);
let a3 = tensor_add_scaled(&a, &ka2, h * 0.5);
let ky3 = tensor_scale(&tensor_scale(&y3, alpha), -1.0);
let ka3 = tensor_scale(&a3, alpha);
let y4 = tensor_add_scaled(&y, &ky3, h);
let a4 = tensor_add_scaled(&a, &ka3, h);
let ky4 = tensor_scale(&tensor_scale(&y4, alpha), -1.0);
let ka4 = tensor_scale(&a4, alpha);
let ay = a.to_vec()[0] * y.to_vec()[0];
grad_alpha_acc += h * ay;
y = ode_step_rk4(&y, &ky1, &ky2, &ky3, &ky4, h);
a = ode_step_rk4(&a, &ka1, &ka2, &ka3, &ka4, h);
t = t_prev;
}
assert!(
(grad_alpha_acc - fd_grad).abs() / fd_grad.abs() < 1e-4,
"adjoint gradient {} should match finite diff {} (rel err = {})",
grad_alpha_acc, fd_grad,
(grad_alpha_acc - fd_grad).abs() / fd_grad.abs()
);
}
#[test]
fn test_adjoint_determinism() {
let y_final = Tensor::from_vec_unchecked(vec![(-1.0_f64).exp()], &[1]);
let run = || adjoint_solve(
|_t, y| tensor_scale(y, -1.0),
|_t, _y, a| (tensor_scale(a, 1.0), Tensor::from_vec_unchecked(vec![0.0], &[1])),
&y_final,
(0.0, 1.0),
100,
);
let (y1, a1) = run();
let (y2, a2) = run();
assert_eq!(y1.to_vec(), y2.to_vec(), "adjoint_solve y0 must be bit-identical");
assert_eq!(a1.to_vec(), a2.to_vec(), "adjoint_solve adjoint must be bit-identical");
}
}