use crate::DType;
use numr::autograd::{DualTensor, Var, backward};
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use super::checkpointing::CheckpointManager;
use crate::common::jacobian::vjp_with_params;
use crate::integrate::error::{IntegrateError, IntegrateResult};
use crate::integrate::impl_generic::ode::ODEResultTensor;
use crate::integrate::impl_generic::ode::stiff_client::StiffSolverClient;
use crate::integrate::ode::{BDFOptions, LSODAOptions, ODEMethod, ODEOptions, RadauOptions};
use crate::integrate::sensitivity::traits::{SensitivityOptions, SensitivityResult};
/// Plain-tensor view of a `Var`-based ODE right-hand side.
///
/// Holds the parameter tensor `p` and the compute client by reference, so the forward solvers
/// can call the user function as `(t, y) -> f(t, y, p, client)` on bare tensors.
struct ForwardWrapper<'a, R, C, F>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R>,
    F: Fn(&Var<R>, &Var<R>, &Var<R>, &C) -> Result<Var<R>>,
{
    /// User right-hand side `f(t, y, p, client)`.
    f: &'a F,
    /// Parameters passed unchanged to every evaluation.
    p: &'a Tensor<R>,
    /// Client forwarded to `f` on every call.
    client: &'a C,
}

impl<'a, R, C, F> ForwardWrapper<'a, R, C, F>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R>,
    F: Fn(&Var<R>, &Var<R>, &Var<R>, &C) -> Result<Var<R>>,
{
    /// Bundles `f`, `p` and `client` (all borrowed).
    fn new(f: &'a F, p: &'a Tensor<R>, client: &'a C) -> Self {
        ForwardWrapper { f, p, client }
    }

    /// Evaluates `f(t, y, p)` and returns a clone of the output tensor.
    ///
    /// Every input is wrapped with `Var::new(.., false)`; the forward solve does not need
    /// gradients, unlike the terminal-cost code, which passes `true`.
    ///
    /// # Errors
    /// Propagates any error returned by `f`.
    fn eval(&self, t: &Tensor<R>, y: &Tensor<R>) -> Result<Tensor<R>> {
        let untracked = |x: &Tensor<R>| Var::new(x.clone(), false);
        let (t_in, y_in, p_in) = (untracked(t), untracked(y), untracked(self.p));
        let output = (self.f)(&t_in, &y_in, &p_in, self.client)?;
        Ok(output.tensor().clone())
    }
}
/// Gradient of a terminal cost with respect to ODE parameters via the adjoint method.
///
/// Solves `dy/dt = f(t, y, p)` forward over `t_span` while recording checkpoints. It then
/// evaluates the cost `J = g(y(tf))` and integrates the adjoint state backward across the
/// checkpoints to accumulate `dJ/dp`.
///
/// # Arguments
/// * `client` - compute client used for all tensor operations and solver calls.
/// * `f` - right-hand side `f(t, y, p, client)` written with autograd `Var`s.
/// * `g` - terminal cost `g(y_final, client)`; must produce a single-element tensor.
/// * `t_span` - `[t0, tf]`; both endpoints must be finite with `t0 < tf`.
/// * `y0` - initial state.
/// * `p` - parameters to differentiate with respect to.
/// * `ode_opts` - forward solver options, including the integration method.
/// * `sens_opts` - checkpoint count/strategy and adjoint settings.
///
/// # Returns
/// A [`SensitivityResult`] holding:
/// * the gradient,
/// * the cost,
/// * the final state,
/// * the forward and adjoint evaluation counts,
/// * the number of stored checkpoints.
///
/// # Errors
/// * `InvalidInterval` if either endpoint is NaN/infinite or `t0 >= tf`.
/// * `NumericalError` if the forward solve yields no states or a later stage fails.
/// * Errors propagated from the forward solver (e.g. `InvalidInput` for symplectic methods).
#[allow(clippy::too_many_arguments)]
pub fn adjoint_sensitivity_impl<R, C, F, G>(
    client: &C,
    f: F,
    g: G,
    t_span: [f64; 2],
    y0: &Tensor<R>,
    p: &Tensor<R>,
    ode_opts: &ODEOptions,
    sens_opts: &SensitivityOptions,
) -> IntegrateResult<SensitivityResult<R>>
where
    R: Runtime<DType = DType>,
    C: StiffSolverClient<R>,
    R::Client: TensorOps<R>,
    F: Fn(&Var<R>, &Var<R>, &Var<R>, &C) -> Result<Var<R>>,
    G: Fn(&Var<R>, &C) -> Result<Var<R>>,
{
    let [t0, tf] = t_span;
    // `t0 >= tf` alone would let NaN endpoints through, since every NaN comparison is false.
    // An infinite span would also make `checkpoint_tol` below meaningless.
    if !t0.is_finite() || !tf.is_finite() || t0 >= tf {
        return Err(IntegrateError::InvalidInterval {
            a: t0,
            b: tf,
            context: "adjoint_sensitivity".to_string(),
        });
    }
    let mut checkpoint_manager = CheckpointManager::new(
        sens_opts.n_checkpoints,
        sens_opts.checkpoint_strategy,
        t_span,
    );
    let forward_wrapper = ForwardWrapper::new(&f, p, client);
    // Solver outputs closer than this (relative to the span length) to the previous
    // checkpoint are not stored as separate checkpoints.
    let checkpoint_tol = (tf - t0) * 1e-8;
    let forward_result = forward_with_checkpoints(
        client,
        &forward_wrapper,
        t_span,
        y0,
        ode_opts,
        &mut checkpoint_manager,
        checkpoint_tol,
    )?;
    let n_steps = forward_result.y.shape()[0];
    // Guard `n_steps - 1` below against underflow on an empty trajectory.
    if n_steps == 0 {
        return Err(IntegrateError::NumericalError {
            message: "Forward integration returned no states".to_string(),
        });
    }
    let y_final = forward_result
        .y
        .narrow(0, n_steps - 1, 1)
        .map_err(|e| IntegrateError::NumericalError {
            message: format!("Failed to extract final state: {}", e),
        })?
        .squeeze(Some(0))
        .contiguous()?;
    let nfev_forward = forward_result.nfev;
    // dJ/dy(tf) seeds the adjoint state for the backward sweep.
    let (cost, lambda_t) = compute_terminal_adjoint(client, &g, &y_final)?;
    let (gradient, nfev_adjoint) =
        backward_adjoint_pass(client, &f, p, &checkpoint_manager, &lambda_t, sens_opts)?;
    Ok(SensitivityResult {
        gradient,
        cost,
        y_final,
        nfev_forward,
        nfev_adjoint,
        n_checkpoints: checkpoint_manager.len(),
    })
}
#[allow(clippy::needless_borrows_for_generic_args)]
fn forward_with_checkpoints<R, C, F>(
client: &C,
wrapper: &ForwardWrapper<'_, R, C, F>,
t_span: [f64; 2],
y0: &Tensor<R>,
options: &ODEOptions,
checkpoint_manager: &mut CheckpointManager<R>,
checkpoint_tol: f64,
) -> IntegrateResult<ODEResultTensor<R>>
where
R: Runtime<DType = DType>,
C: StiffSolverClient<R>,
R::Client: TensorOps<R>,
F: Fn(&Var<R>, &Var<R>, &Var<R>, &C) -> Result<Var<R>>,
{
checkpoint_manager.add_checkpoint(t_span[0], y0.clone());
let f_tensor = |t: &Tensor<R>, y: &Tensor<R>| -> Result<Tensor<R>> { wrapper.eval(t, y) };
let implicit_opts = {
let mut o = options.clone();
o.max_steps = o.max_steps.max(500_000);
o
};
let result = match options.method {
ODEMethod::RK45 => {
crate::integrate::impl_generic::ode::rk45_impl(client, &f_tensor, t_span, y0, options)?
}
ODEMethod::RK23 => {
crate::integrate::impl_generic::ode::rk23_impl(client, &f_tensor, t_span, y0, options)?
}
ODEMethod::DOP853 => crate::integrate::impl_generic::ode::dop853_impl(
client, &f_tensor, t_span, y0, options,
)?,
ODEMethod::BDF => {
let f_dual =
|t_d: &DualTensor<R>, y_d: &DualTensor<R>, c: &C| -> Result<DualTensor<R>> {
let t_primal = t_d.primal();
let y_primal = y_d.primal();
let f_primal = f_tensor(t_primal, y_primal)?;
let tangent_out = if let Some(v) = y_d.tangent() {
let eps = 1e-7_f64;
let v_eps = c.mul_scalar(v, eps)?;
let y_pert = c.add(y_primal, &v_eps)?;
let f_pert = f_tensor(t_primal, &y_pert)?;
let diff = c.sub(&f_pert, &f_primal)?;
Some(c.mul_scalar(&diff, 1.0 / eps)?)
} else {
None
};
Ok(DualTensor::new(f_primal, tangent_out))
};
let bdf_opts =
BDFOptions::default().newton_params((implicit_opts.atol * 1e-2).min(1e-10), 20);
crate::integrate::impl_generic::ode::bdf_impl(
client,
f_dual,
t_span,
y0,
&implicit_opts,
&bdf_opts,
)?
}
ODEMethod::Radau => {
let f_dual =
|t_d: &DualTensor<R>, y_d: &DualTensor<R>, c: &C| -> Result<DualTensor<R>> {
let t_primal = t_d.primal();
let y_primal = y_d.primal();
let f_primal = f_tensor(t_primal, y_primal)?;
let tangent_out = if let Some(v) = y_d.tangent() {
let eps = 1e-7_f64;
let v_eps = c.mul_scalar(v, eps)?;
let y_pert = c.add(y_primal, &v_eps)?;
let f_pert = f_tensor(t_primal, &y_pert)?;
let diff = c.sub(&f_pert, &f_primal)?;
Some(c.mul_scalar(&diff, 1.0 / eps)?)
} else {
None
};
Ok(DualTensor::new(f_primal, tangent_out))
};
let radau_opts =
RadauOptions::default().newton_params((implicit_opts.atol * 1e-2).min(1e-10), 20);
crate::integrate::impl_generic::ode::radau_impl(
client,
f_dual,
t_span,
y0,
&implicit_opts,
&radau_opts,
)?
}
ODEMethod::LSODA => {
let f_dual =
|t_d: &DualTensor<R>, y_d: &DualTensor<R>, c: &C| -> Result<DualTensor<R>> {
let t_primal = t_d.primal();
let y_primal = y_d.primal();
let f_primal = f_tensor(t_primal, y_primal)?;
let tangent_out = if let Some(v) = y_d.tangent() {
let eps = 1e-7_f64;
let v_eps = c.mul_scalar(v, eps)?;
let y_pert = c.add(y_primal, &v_eps)?;
let f_pert = f_tensor(t_primal, &y_pert)?;
let diff = c.sub(&f_pert, &f_primal)?;
Some(c.mul_scalar(&diff, 1.0 / eps)?)
} else {
None
};
Ok(DualTensor::new(f_primal, tangent_out))
};
crate::integrate::impl_generic::ode::lsoda_impl(
client,
f_dual,
t_span,
y0,
&implicit_opts,
&LSODAOptions::default(),
)?
}
ODEMethod::Verlet | ODEMethod::Leapfrog => {
return Err(IntegrateError::InvalidInput {
context: format!(
"Symplectic method {:?} cannot be used for adjoint sensitivity: \
symplectic integrators require separate position/momentum coordinates \
(q, p) and are not general ODE solvers. Use RK45, RK23, DOP853, \
BDF, Radau, or LSODA instead.",
options.method
),
});
}
};
let t_vec: Vec<f64> = result.t.to_vec();
let n_steps = t_vec.len();
let mut last_t = t_span[0]; for (idx, &t_val) in t_vec.iter().enumerate().take(n_steps) {
if (t_val - last_t).abs() < checkpoint_tol {
continue;
}
let y_checkpoint = result
.y
.narrow(0, idx, 1)
.map_err(|e| IntegrateError::NumericalError {
message: format!("Failed to extract checkpoint state at step {}: {}", idx, e),
})?
.squeeze(Some(0))
.contiguous()?;
checkpoint_manager.add_checkpoint(t_val, y_checkpoint);
last_t = t_val;
}
if n_steps > 0 {
let t_last = t_vec[n_steps - 1];
if checkpoint_manager.checkpoints().last().map(|c| c.t) != Some(t_last) {
let y_final = result
.y
.narrow(0, n_steps - 1, 1)
.map_err(|e| IntegrateError::NumericalError {
message: format!("Failed to extract final checkpoint: {}", e),
})?
.squeeze(Some(0))
.contiguous()?;
checkpoint_manager.add_checkpoint(t_last, y_final);
}
}
Ok(result)
}
/// Evaluates the terminal cost `g(y_final)` and its gradient with respect to `y_final`.
///
/// The gradient `dJ/dy(tf)` is the starting value of the adjoint state for the backward pass.
///
/// # Returns
/// `(cost, lambda_t)`: the scalar cost value and the gradient tensor.
///
/// # Errors
/// * `NumericalError` in any of these cases:
///   * evaluating `g` fails;
///   * the reverse-mode pass fails;
///   * no gradient reaches `y_final`.
/// * `InvalidInput` in either of these cases:
///   * `g` returns a tensor with other than one element;
///   * the cost value cannot be read as `f64`. The underlying error is included.
fn compute_terminal_adjoint<R, C, G>(
    client: &C,
    g: &G,
    y_final: &Tensor<R>,
) -> IntegrateResult<(f64, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R>,
    G: Fn(&Var<R>, &C) -> Result<Var<R>>,
{
    // `true` marks y_final for gradient tracking, so `backward` produces dJ/dy_final.
    let y_var = Var::new(y_final.clone(), true);
    let cost_var = g(&y_var, client).map_err(|e| IntegrateError::NumericalError {
        message: format!("Cost function evaluation failed: {}", e),
    })?;
    let cost_tensor = cost_var.tensor();
    // Check the shape explicitly: a non-scalar cost gets a precise message. Any other
    // `item` failure keeps its real cause instead of being reported as a shape problem.
    let n_elems = cost_tensor.numel();
    if n_elems != 1 {
        return Err(IntegrateError::InvalidInput {
            context: format!(
                "Cost function must return a scalar, got a tensor with {} elements",
                n_elems
            ),
        });
    }
    let cost = cost_tensor
        .item::<f64>()
        .map_err(|e| IntegrateError::InvalidInput {
            context: format!("Failed to read cost function value as f64: {}", e),
        })?;
    let grads = backward(&cost_var, client).map_err(|e| IntegrateError::NumericalError {
        message: format!("Backward pass for terminal condition failed: {}", e),
    })?;
    let lambda_t =
        grads
            .get(y_var.id())
            .cloned()
            .ok_or_else(|| IntegrateError::NumericalError {
                message: "No gradient for y_final in cost function".to_string(),
            })?;
    Ok((cost, lambda_t))
}
/// Sweeps the adjoint state backward through consecutive checkpoint pairs, accumulating the
/// parameter gradient.
///
/// The sweep starts from `lambda_t` at the last checkpoint and walks each pair
/// `(earlier, later)` from the end of the trajectory toward the start. Pairs whose times
/// differ by less than `1e-14` are skipped.
///
/// # Returns
/// `(gradient, nfev_adjoint)`:
/// * `gradient` has shape `[p.numel()]` and the dtype/device of `lambda_t`.
/// * `nfev_adjoint` is the summed evaluation count reported by each interval.
///
/// # Errors
/// `NumericalError` if any of these occur:
/// * fewer than two checkpoints are stored;
/// * an interval integration fails;
/// * gradient accumulation fails.
fn backward_adjoint_pass<R, C, F>(
    client: &C,
    f: &F,
    p: &Tensor<R>,
    checkpoint_manager: &CheckpointManager<R>,
    lambda_t: &Tensor<R>,
    sens_opts: &SensitivityOptions,
) -> IntegrateResult<(Tensor<R>, usize)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R>,
    F: Fn(&Var<R>, &Var<R>, &Var<R>, &C) -> Result<Var<R>>,
{
    let checkpoints = checkpoint_manager.checkpoints();
    if checkpoints.len() < 2 {
        return Err(IntegrateError::NumericalError {
            message: "Need at least 2 checkpoints for adjoint pass".to_string(),
        });
    }

    let mut grad_acc = Tensor::<R>::zeros(&[p.numel()], lambda_t.dtype(), lambda_t.device());
    let mut adjoint = lambda_t.clone();
    let mut evals = 0usize;

    // `windows(2).rev()` yields (earlier, later) pairs from the final interval backward.
    for pair in checkpoints.windows(2).rev() {
        let (earlier, later) = (&pair[0], &pair[1]);
        if (later.t - earlier.t).abs() < 1e-14 {
            continue;
        }
        // Integrate from the later checkpoint (current adjoint) back to the earlier one.
        let (next_adjoint, interval_grad, interval_evals) = integrate_adjoint_interval(
            client,
            f,
            p,
            &adjoint,
            &later.y,
            &earlier.y,
            later.t,
            earlier.t,
            sens_opts,
        )?;
        adjoint = next_adjoint;
        grad_acc = client
            .add(&grad_acc, &interval_grad)
            .map_err(|e| IntegrateError::NumericalError {
                message: format!("Gradient accumulation failed: {}", e),
            })?;
        evals += interval_evals;
    }

    Ok((grad_acc, evals))
}
/// Advances the adjoint state across one checkpoint interval with a single Heun step. The
/// interval's parameter-gradient contribution is computed with the trapezoidal rule.
///
/// Here `vjp_y` and `vjp_p` are the vector-Jacobian products returned by `vjp_with_params` at
/// `(t, y, λ)`. The adjoint slope is `-vjp_y`. With `dt = t_end - t_start` (negative when
/// integrating backward in time):
/// * predictor: `λ* = λ_start + dt * k1`, with `k1 = -vjp_y(t_start, y_start, λ_start)`
/// * corrector: `λ_end = λ_start + (dt / 2) * (k1 + k2)`, with `k2 = -vjp_y(t_end, y_end, λ*)`
/// * gradient: `(|dt| / 2) * (vjp_p(t_start, y_start, λ_start) + vjp_p(t_end, y_end, λ*))`
///
/// # Returns
/// `(λ_end, interval_gradient, 2)`. The `2` counts the two VJP evaluations made.
/// `_sens_opts` is currently unused.
///
/// # Errors
/// `NumericalError` on any VJP or tensor-arithmetic failure.
#[allow(clippy::too_many_arguments)]
fn integrate_adjoint_interval<R, C, F>(
    client: &C,
    f: &F,
    p: &Tensor<R>,
    lambda_start: &Tensor<R>,
    y_start: &Tensor<R>,
    y_end: &Tensor<R>,
    t_start: f64,
    t_end: f64,
    _sens_opts: &SensitivityOptions,
) -> IntegrateResult<(Tensor<R>, Tensor<R>, usize)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R>,
    F: Fn(&Var<R>, &Var<R>, &Var<R>, &C) -> Result<Var<R>>,
{
    // Adjoint right-hand side at (t, y, λ): returns (-vjp_y, vjp_p).
    let adjoint_rhs = |t: f64,
                       y: &Tensor<R>,
                       lam: &Tensor<R>|
     -> IntegrateResult<(Tensor<R>, Tensor<R>)> {
        let (_, vjp_y, vjp_p) = vjp_with_params(client, f, t, y, p, lam).map_err(|e| {
            IntegrateError::NumericalError {
                message: format!("VJP computation failed at t={}: {}", t, e),
            }
        })?;
        let neg_vjp_y =
            client
                .mul_scalar(&vjp_y, -1.0)
                .map_err(|e| IntegrateError::NumericalError {
                    message: format!("Scalar multiply failed: {}", e),
                })?;
        Ok((neg_vjp_y, vjp_p))
    };
    // Builds an error mapper that prefixes the failing stage name.
    let wrap = |ctx: &'static str| {
        move |e: numr::error::Error| IntegrateError::NumericalError {
            message: format!("{}: {}", ctx, e),
        }
    };

    let dt = t_end - t_start;
    let half_dt = dt * 0.5;
    let half_width = dt.abs() * 0.5;

    // Predictor: explicit Euler step from the interval start.
    let (k1, grad_p_start) = adjoint_rhs(t_start, y_start, lambda_start)?;
    let k1_dt = client.mul_scalar(&k1, dt).map_err(wrap("predictor scale"))?;
    let lambda_pred = client
        .add(lambda_start, &k1_dt)
        .map_err(wrap("predictor add"))?;

    // Corrector: average the start slope and the slope at the predicted end state.
    let (k2, grad_p_end) = adjoint_rhs(t_end, y_end, &lambda_pred)?;
    let slope_sum = client.add(&k1, &k2).map_err(wrap("corrector sum"))?;
    let increment = client
        .mul_scalar(&slope_sum, half_dt)
        .map_err(wrap("corrector scale"))?;
    let lambda_end = client
        .add(lambda_start, &increment)
        .map_err(wrap("corrector add"))?;

    // Trapezoidal quadrature of the parameter sensitivity; |dt| keeps the contribution's
    // sign independent of the integration direction.
    let grad_p_sum = client
        .add(&grad_p_start, &grad_p_end)
        .map_err(wrap("gradient sum"))?;
    let interval_gradient = client
        .mul_scalar(&grad_p_sum, half_width)
        .map_err(wrap("gradient scale"))?;

    Ok((lambda_end, interval_gradient, 2))
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::autograd::{var_mul, var_mul_scalar};
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Fresh CPU device plus a client bound to it.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// One-element f64 tensor of shape `[1]` holding `value`.
    fn scalar_tensor(value: f64, device: &CpuDevice) -> Tensor<CpuRuntime> {
        Tensor::<CpuRuntime>::from_slice(&[value], &[1], device)
    }

    /// Linear decay right-hand side: `dy/dt = -p * y`.
    fn decay_rhs(
        _t: &Var<CpuRuntime>,
        y: &Var<CpuRuntime>,
        p: &Var<CpuRuntime>,
        c: &CpuClient,
    ) -> Result<Var<CpuRuntime>> {
        let py = var_mul(p, y, c)?;
        var_mul_scalar(&py, -1.0, c)
    }

    /// Terminal cost `J(y) = y * y`.
    fn squared_cost(y: &Var<CpuRuntime>, c: &CpuClient) -> Result<Var<CpuRuntime>> {
        var_mul(y, y, c)
    }

    /// Solves the decay problem (`y0 = 1`, `k = 0.5`, `t in [0, 1]`) with `ode_opts`.
    /// Returns `(adjoint dJ/dk, analytical dJ/dk)`.
    fn run_exponential_decay_adjoint(ode_opts: ODEOptions) -> (f64, f64) {
        let (device, client) = setup();
        let k_val = 0.5f64;
        let t_final = 1.0f64;
        let y0 = scalar_tensor(1.0, &device);
        let k = scalar_tensor(k_val, &device);
        let sens_opts = SensitivityOptions::default()
            .with_checkpoints(10)
            .with_adjoint_tolerances(1e-6, 1e-8);
        let result = adjoint_sensitivity_impl(
            &client,
            decay_rhs,
            squared_cost,
            [0.0, t_final],
            &y0,
            &k,
            &ode_opts,
            &sens_opts,
        )
        .expect("adjoint_sensitivity_impl should not return Err");
        let y_analytical = (-k_val * t_final).exp();
        let grad_analytical = -2.0 * t_final * y_analytical * y_analytical;
        (result.gradient.to_vec::<f64>()[0], grad_analytical)
    }

    /// Runs the decay problem with `method` and default tolerances, returning the raw result.
    fn run_with_method(method: ODEMethod) -> IntegrateResult<SensitivityResult<CpuRuntime>> {
        let (device, client) = setup();
        let y0 = scalar_tensor(1.0, &device);
        let k = scalar_tensor(0.5, &device);
        let ode_opts = ODEOptions::with_method(method);
        let sens_opts = SensitivityOptions::default().with_checkpoints(5);
        adjoint_sensitivity_impl(
            &client,
            decay_rhs,
            squared_cost,
            [0.0, 1.0],
            &y0,
            &k,
            &ode_opts,
            &sens_opts,
        )
    }

    /// Asserts that the adjoint gradient computed with `method` is finite.
    fn assert_finite_gradient(method: ODEMethod, label: &str) {
        let ode_opts = ODEOptions::with_tolerances(1e-6, 1e-8).method(method);
        let (grad_val, _) = run_exponential_decay_adjoint(ode_opts);
        assert!(
            grad_val.is_finite(),
            "{} adjoint gradient should be finite, got {}",
            label,
            grad_val
        );
    }

    /// Asserts that `method` is rejected with an error mentioning symplectic solvers or `label`.
    fn assert_symplectic_rejected(method: ODEMethod, label: &str) {
        let msg = match run_with_method(method) {
            Ok(_) => panic!("{} should return Err for adjoint sensitivity", label),
            Err(err) => format!("{:?}", err),
        };
        assert!(
            msg.contains("symplectic") || msg.contains(label),
            "Error should mention symplectic nature: {}",
            msg
        );
    }

    #[test]
    fn test_adjoint_exponential_decay() {
        let (device, client) = setup();
        let k_val: f64 = 0.5;
        let t_final: f64 = 1.0;
        let y0 = scalar_tensor(1.0, &device);
        let k = scalar_tensor(k_val, &device);
        let ode_opts = ODEOptions::with_tolerances(1e-8, 1e-10);
        let sens_opts = SensitivityOptions::default()
            .with_checkpoints(10)
            .with_adjoint_tolerances(1e-6, 1e-8);
        let result = adjoint_sensitivity_impl(
            &client,
            decay_rhs,
            squared_cost,
            [0.0, t_final],
            &y0,
            &k,
            &ode_opts,
            &sens_opts,
        )
        .unwrap();

        let y_analytical = (-k_val * t_final).exp();
        let cost_analytical = y_analytical * y_analytical;
        let grad_analytical = -2.0 * t_final * cost_analytical;
        let y_final_val = result.y_final.to_vec::<f64>()[0];
        let grad_val = result.gradient.to_vec::<f64>()[0];

        assert!(
            (y_final_val - y_analytical).abs() < 1e-5,
            "y_final: expected {}, got {}",
            y_analytical,
            y_final_val
        );
        assert!(
            (result.cost - cost_analytical).abs() < 1e-5,
            "cost: expected {}, got {}",
            cost_analytical,
            result.cost
        );
        assert!(
            (grad_val - grad_analytical).abs() < 0.05 * grad_analytical.abs(),
            "gradient: expected {}, got {} (error = {}%)",
            grad_analytical,
            grad_val,
            100.0 * (grad_val - grad_analytical).abs() / grad_analytical.abs()
        );
    }

    #[test]
    fn test_bdf_no_longer_returns_invalid_input() {
        assert_finite_gradient(ODEMethod::BDF, "BDF");
    }

    #[test]
    fn test_radau_no_longer_returns_invalid_input() {
        assert_finite_gradient(ODEMethod::Radau, "Radau");
    }

    #[test]
    fn test_lsoda_no_longer_returns_invalid_input() {
        assert_finite_gradient(ODEMethod::LSODA, "LSODA");
    }

    #[test]
    fn test_bdf_adjoint_stiff_linear_ode() {
        let (device, client) = setup();
        let k_val = 50.0f64;
        let t_span = [0.0, 0.1];
        let y0 = scalar_tensor(1.0, &device);
        let k = scalar_tensor(k_val, &device);
        let ode_opts = ODEOptions::with_tolerances(1e-8, 1e-10).method(ODEMethod::BDF);
        let sens_opts = SensitivityOptions::default()
            .with_checkpoints(20)
            .with_adjoint_tolerances(1e-6, 1e-8);
        let result = adjoint_sensitivity_impl(
            &client,
            decay_rhs,
            squared_cost,
            t_span,
            &y0,
            &k,
            &ode_opts,
            &sens_opts,
        )
        .expect("BDF adjoint should succeed on stiff linear ODE");
        let adjoint_grad = result.gradient.to_vec::<f64>()[0];

        let t_final = t_span[1];
        let y_analytical = (-k_val * t_final).exp();
        let grad_analytical = -2.0 * t_final * y_analytical * y_analytical;
        let rel_err = (adjoint_grad - grad_analytical).abs() / grad_analytical.abs();
        assert!(
            rel_err < 0.05,
            "BDF adjoint gradient: expected {:.6e}, got {:.6e} (rel error {:.2}%)",
            grad_analytical,
            adjoint_grad,
            rel_err * 100.0
        );

        // Central finite difference of the cost with respect to k, using tighter tolerances.
        let eps = 1e-4;
        let ode_opts_fd = ODEOptions::with_tolerances(1e-10, 1e-12).method(ODEMethod::BDF);
        let sens_opts_fd = SensitivityOptions::default().with_checkpoints(5);
        let cost_at = |k_shifted: f64, what: &str| -> f64 {
            let k_tensor = scalar_tensor(k_shifted, &device);
            adjoint_sensitivity_impl(
                &client,
                decay_rhs,
                squared_cost,
                t_span,
                &y0,
                &k_tensor,
                &ode_opts_fd,
                &sens_opts_fd,
            )
            .expect(what)
            .cost
        };
        let cost_plus = cost_at(k_val + eps, "BDF adjoint (k+eps) should succeed");
        let cost_minus = cost_at(k_val - eps, "BDF adjoint (k-eps) should succeed");
        let fd_grad = (cost_plus - cost_minus) / (2.0 * eps);

        let fd_err = (fd_grad - grad_analytical).abs() / grad_analytical.abs();
        assert!(
            fd_err < 1e-3,
            "BDF finite-difference gradient: expected {:.6e}, got {:.6e} (rel error {:.4}%)",
            grad_analytical,
            fd_grad,
            fd_err * 100.0
        );
        let adj_fd_err = (adjoint_grad - fd_grad).abs() / fd_grad.abs();
        assert!(
            adj_fd_err < 0.05,
            "BDF adjoint vs FD: adjoint = {:.6e}, fd = {:.6e} (rel error {:.2}%)",
            adjoint_grad,
            fd_grad,
            adj_fd_err * 100.0
        );
    }

    #[test]
    fn test_verlet_returns_meaningful_error() {
        assert_symplectic_rejected(ODEMethod::Verlet, "Verlet");
    }

    #[test]
    fn test_leapfrog_returns_meaningful_error() {
        assert_symplectic_rejected(ODEMethod::Leapfrog, "Leapfrog");
    }
}