use numra_core::Scalar;
use numra_ode::sensitivity::solve_forward_sensitivity_with;
pub use numra_ode::SensitivityResult;
use numra_ode::{AugmentedSystem, ClosureSystem, DoPri5, Solver, SolverOptions};
use crate::error::OcpError;
/// Signature of the user-supplied ODE right-hand side.
///
/// The callable is invoked as `model(t, y, dy, p)`: the current time, the
/// current state slice, a mutable slice the callee writes the derivative
/// into, and the parameter slice. Note that this order (output before
/// parameters) differs from the `(t, y, p, dy)` order of the closures handed
/// to the `numra_ode` routines in this file, which adapt between the two.
type ModelFn<S> = dyn Fn(S, &[S], &mut [S], &[S]);
/// Integrates an ODE model together with its forward parameter sensitivities
/// using the `DoPri5` solver.
///
/// # Arguments
///
/// * `model` - right-hand side, called as `model(t, y, dy, p)`.
/// * `y0` - initial state.
/// * `params` - parameter values the sensitivities are taken with respect to.
/// * `t0`, `tf` - integration interval; used only when `output_times` is `None`.
/// * `output_times` - optional list of times at which results are recorded.
///   When supplied, integration proceeds segment by segment through the list,
///   `y0` is taken as the state at its first entry, and `t0` / `tf` are not
///   forwarded to the integrator.
/// * `rtol`, `atol` - relative and absolute tolerances set on the solver options.
///
/// # Errors
///
/// Returns [`OcpError::IntegrationFailed`] carrying the solver's message when
/// the underlying integration fails, or when `output_times` is an empty slice.
// The flat argument list is the public calling convention of this function,
// so the argument-count lint is silenced rather than reshaping the API.
#[allow(clippy::too_many_arguments)]
pub fn forward_sensitivity<S: Scalar>(
    model: &ModelFn<S>,
    y0: &[S],
    params: &[S],
    t0: S,
    tf: S,
    output_times: Option<&[S]>,
    rtol: S,
    atol: S,
) -> Result<SensitivityResult<S>, OcpError> {
    let solver_opts = SolverOptions::default().rtol(rtol).atol(atol);

    // An explicit output grid takes the segment-wise path.
    if let Some(grid) = output_times {
        return integrate_at_output_times(model, y0, params, grid, &solver_opts);
    }

    // Otherwise hand the whole interval to the library routine, adapting the
    // argument order of the user model to the one the library expects.
    let outcome = solve_forward_sensitivity_with::<DoPri5, S, _>(
        |t: S, y: &[S], p: &[S], dy: &mut [S]| model(t, y, dy, p),
        y0,
        params,
        t0,
        tf,
        &solver_opts,
    );
    outcome.map_err(|err| OcpError::IntegrationFailed(err.to_string()))
}
/// Integrates the sensitivity-augmented system through consecutive pairs of
/// output times and records the state and sensitivities at every listed time.
///
/// The first entry of `te` is treated as the initial time: the augmented
/// initial vector built from `y0` is recorded there without integrating.
/// Each subsequent entry is reached by a fresh `DoPri5::solve` call started
/// from the previous final augmented state. Consecutive times closer than
/// `1e-15` are not integrated; the current state is recorded again instead.
///
/// In the returned result, `y` holds `n_states` values per output time and
/// `sensitivity` holds the remaining entries of the augmented vector per
/// output time, both flattened in output-time order. `stats` is overwritten
/// after each integrated segment, so it reflects only the last segment that
/// was actually integrated (or a fresh `SolverStats::new()` if none was).
///
/// # Errors
///
/// Returns [`OcpError::IntegrationFailed`] when `te` is empty, when a segment
/// solve returns an error or reports `success == false`, or when a segment
/// result has no final state.
fn integrate_at_output_times<S: Scalar>(
    model: &ModelFn<S>,
    y0: &[S],
    params: &[S],
    te: &[S],
    opts: &SolverOptions<S>,
) -> Result<SensitivityResult<S>, OcpError> {
    if te.is_empty() {
        return Err(OcpError::IntegrationFailed(
            "output_times must contain at least one entry".to_string(),
        ));
    }

    let n_states = y0.len();
    let n_params = params.len();
    let n_samples = te.len();

    // Wrap the user model so the library can build the augmented
    // (state + sensitivity) system around it.
    let base_system = ClosureSystem::new(
        |t: S, y: &[S], p: &[S], dy: &mut [S]| model(t, y, dy, p),
        params.to_vec(),
        n_states,
    );
    let augmented = AugmentedSystem::new(base_system);
    let aug_dim = augmented.augmented_dim();
    let mut state = augmented.initial_augmented(y0);

    // Segments shorter than this are treated as zero-length.
    let threshold = S::from_f64(1e-15);

    let mut t_out = Vec::with_capacity(n_samples);
    let mut y_out = Vec::with_capacity(n_samples * n_states);
    let mut sens_out = Vec::with_capacity(n_samples * n_states * n_params);
    let mut stats = numra_ode::SolverStats::new();

    // Seed the outputs with the initial sample at the first output time.
    t_out.push(te[0]);
    y_out.extend_from_slice(&state[..n_states]);
    sens_out.extend_from_slice(&state[n_states..aug_dim]);

    for window in te.windows(2) {
        let t_start = window[0];
        let t_end = window[1];

        // Kept as a `<` test (rather than negated into `>=`) so that a NaN
        // difference still goes down the integration path.
        let zero_length = (t_end - t_start).abs() < threshold;
        if !zero_length {
            let segment = DoPri5::solve(&augmented, t_start, t_end, &state, opts)
                .map_err(|err| OcpError::IntegrationFailed(err.to_string()))?;
            if !segment.success {
                return Err(OcpError::IntegrationFailed(segment.message));
            }
            state = segment
                .y_final()
                .ok_or_else(|| OcpError::IntegrationFailed("missing final state".to_string()))?;
            stats = segment.stats;
        }

        // Record the (possibly unchanged) augmented state at the segment end.
        t_out.push(t_end);
        y_out.extend_from_slice(&state[..n_states]);
        sens_out.extend_from_slice(&state[n_states..aug_dim]);
    }

    Ok(SensitivityResult {
        t: t_out,
        y: y_out,
        sensitivity: sens_out,
        n_states,
        n_params,
        stats,
        success: true,
        message: String::new(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use numra_ode::OdeProblem;

    /// Exponential decay `dy/dt = -k*y`, `y(0) = 1`.
    ///
    /// The exact solution is `y = exp(-k*t)`, so `dy/dk = -t * exp(-k*t)`.
    #[test]
    fn test_exponential_decay_sensitivity() {
        let k = 0.5_f64;
        let y0 = [1.0];
        let params = [k];
        let check_times = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let result = forward_sensitivity(
            &|_t: f64, y, dydt, p| {
                dydt[0] = -p[0] * y[0];
            },
            &y0,
            &params,
            0.0,
            5.0,
            Some(&check_times),
            1e-10,
            1e-12,
        )
        .expect("forward_sensitivity failed");
        assert_eq!(result.n_states, 1);
        assert_eq!(result.n_params, 1);
        // Index 0 is the initial sample; only the later times are compared.
        for (idx, &t) in check_times.iter().enumerate().skip(1) {
            let analytical = -t * (-k * t).exp();
            let computed = result.sensitivity_at(idx)[0];
            assert!(
                (computed - analytical).abs() < 1e-3,
                "t={t}: computed={computed}, analytical={analytical}, err={}",
                (computed - analytical).abs()
            );
        }
    }

    /// Linear model `dy/dt = -a*y + b` with `a = 1`, `b = 2`, `y(0) = 1`.
    ///
    /// The exact solution is `y = b/a + (y0 - b/a) * exp(-a*t)`, which gives
    /// `dy/db = (1 - exp(-a*t)) / a = 1 - exp(-t)` for `a = 1`.
    #[test]
    fn test_two_param_sensitivity() {
        let a = 1.0_f64;
        let b = 2.0_f64;
        let y0 = [1.0];
        let params = [a, b];
        let check_times = vec![0.0, 1.0, 2.0, 3.0];
        let result = forward_sensitivity(
            &|_t: f64, y, dydt, p| {
                dydt[0] = -p[0] * y[0] + p[1];
            },
            &y0,
            &params,
            0.0,
            3.0,
            Some(&check_times),
            1e-10,
            1e-12,
        )
        .expect("forward_sensitivity failed");
        assert_eq!(result.n_states, 1);
        assert_eq!(result.n_params, 2);
        // Index 0 is the initial sample; only the later times are compared.
        for (idx, &t) in check_times.iter().enumerate().skip(1) {
            let analytical_dydb = 1.0 - (-t).exp();
            // State 0 with respect to parameter 1 (`b`).
            let computed = result.dyi_dpj(idx, 0, 1);
            assert!(
                (computed - analytical_dydb).abs() < 1e-3,
                "t={t}: computed dy/db={computed}, analytical={analytical_dydb}, err={}",
                (computed - analytical_dydb).abs()
            );
        }
    }

    /// Nonlinear model `dy/dt = -p*y^2`: the forward sensitivity at the final
    /// time is compared with a central finite difference of two plain
    /// integrations at `p + h` and `p - h`.
    #[test]
    fn test_sensitivity_matches_finite_diff() {
        let p_val = 0.5_f64;
        let y0 = [1.0];
        let t_final = 2.0;
        let model = |_t: f64, y: &[f64], dydt: &mut [f64], p: &[f64]| {
            dydt[0] = -p[0] * y[0] * y[0];
        };
        let result = forward_sensitivity(
            &model,
            &y0,
            &[p_val],
            0.0,
            t_final,
            Some(&[0.0, t_final]),
            1e-10,
            1e-12,
        )
        .expect("forward_sensitivity failed");
        // Output index 1 corresponds to `t_final`.
        let sens_forward = result.dyi_dpj(1, 0, 0);

        // Reference: central difference with tighter tolerances than the
        // sensitivity solve above.
        let h = 1e-5;
        let opts = SolverOptions::default().rtol(1e-12).atol(1e-14);

        let p_plus = p_val + h;
        let problem_plus = OdeProblem::new(
            move |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -p_plus * y[0] * y[0];
            },
            0.0,
            t_final,
            vec![1.0],
        );
        let res_plus = DoPri5::solve(&problem_plus, 0.0, t_final, &[1.0], &opts)
            .expect("integration p+h failed");
        let y_plus = res_plus.y_final().unwrap()[0];

        let p_minus = p_val - h;
        let problem_minus = OdeProblem::new(
            move |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -p_minus * y[0] * y[0];
            },
            0.0,
            t_final,
            vec![1.0],
        );
        let res_minus = DoPri5::solve(&problem_minus, 0.0, t_final, &[1.0], &opts)
            .expect("integration p-h failed");
        let y_minus = res_minus.y_final().unwrap()[0];

        let fd_sens = (y_plus - y_minus) / (2.0 * h);
        assert!(
            (sens_forward - fd_sens).abs() < 1e-3,
            "forward sensitivity={sens_forward}, FD={fd_sens}, err={}",
            (sens_forward - fd_sens).abs()
        );
    }
}