use crate::math::jacobian::{DifferenceMethod, PerturbationStrategy};
use nalgebra::{DMatrix, DVector, SMatrix, SVector};
/// Boxed closure producing an `S × P` sensitivity matrix for statically sized
/// state (`S`) and parameter (`P`) vectors, given `(t, state, params)`.
type SSensitivityFn<const S: usize, const P: usize> = Box<
    dyn Fn(f64, &SVector<f64, S>, &SVector<f64, P>) -> SMatrix<f64, S, P> + Send + Sync,
>;

/// Boxed closure producing a dynamically sized sensitivity matrix from `(t, state, params)`.
type DSensitivityFn = Box<
    dyn Fn(f64, &DVector<f64>, &DVector<f64>) -> DMatrix<f64> + Send + Sync,
>;

/// Boxed parameterised dynamics `f(t, state, params)` returning a length-`S` vector.
type SDynamicsWithParams<const S: usize, const P: usize> = Box<
    dyn Fn(f64, &SVector<f64, S>, &SVector<f64, P>) -> SVector<f64, S> + Send + Sync,
>;

/// Boxed parameterised dynamics `f(t, state, params)` on dynamically sized vectors.
type DDynamicsWithParams = Box<
    dyn Fn(f64, &DVector<f64>, &DVector<f64>) -> DVector<f64> + Send + Sync,
>;
/// Supplies the parameter sensitivity matrix for statically sized problems.
///
/// The result has one row per state component (`S`) and one column per parameter (`P`).
/// The `Send + Sync` supertraits allow a provider to be shared between threads.
pub trait SSensitivityProvider<const S: usize, const P: usize>: Send + Sync {
    /// Evaluates the `S × P` sensitivity matrix at time `t` for the given `state`
    /// and `params`.
    fn compute(
        &self,
        t: f64,
        state: &SVector<f64, S>,
        params: &SVector<f64, P>,
    ) -> SMatrix<f64, S, P>;
}
/// Supplies the parameter sensitivity matrix for dynamically sized problems.
///
/// The `Send + Sync` supertraits allow a provider to be shared between threads.
pub trait DSensitivityProvider: Send + Sync {
    /// Evaluates the sensitivity matrix at time `t` for the given `state` and `params`.
    fn compute(&self, t: f64, state: &DVector<f64>, params: &DVector<f64>) -> DMatrix<f64>;
}
/// Statically sized sensitivity provider that forwards every request to a
/// user-supplied closure (no finite differencing involved).
pub struct SAnalyticSensitivity<const S: usize, const P: usize> {
    /// Closure evaluated by [`SSensitivityProvider::compute`].
    sensitivity_fn: SSensitivityFn<S, P>,
}

impl<const S: usize, const P: usize> SAnalyticSensitivity<S, P> {
    /// Wraps `sensitivity_fn`; its result is returned unchanged by `compute`.
    pub fn new(sensitivity_fn: SSensitivityFn<S, P>) -> Self {
        Self { sensitivity_fn }
    }
}

impl<const S: usize, const P: usize> SSensitivityProvider<S, P> for SAnalyticSensitivity<S, P> {
    /// Delegates directly to the wrapped closure.
    fn compute(
        &self,
        t: f64,
        state: &SVector<f64, S>,
        params: &SVector<f64, P>,
    ) -> SMatrix<f64, S, P> {
        let evaluate = &self.sensitivity_fn;
        evaluate(t, state, params)
    }
}
/// Dynamically sized sensitivity provider that forwards every request to a
/// user-supplied closure (no finite differencing involved).
pub struct DAnalyticSensitivity {
    /// Closure evaluated by [`DSensitivityProvider::compute`].
    sensitivity_fn: DSensitivityFn,
}

impl DAnalyticSensitivity {
    /// Wraps `sensitivity_fn`; its result is returned unchanged by `compute`.
    pub fn new(sensitivity_fn: DSensitivityFn) -> Self {
        Self { sensitivity_fn }
    }
}

impl DSensitivityProvider for DAnalyticSensitivity {
    /// Delegates directly to the wrapped closure.
    fn compute(&self, t: f64, state: &DVector<f64>, params: &DVector<f64>) -> DMatrix<f64> {
        let evaluate = &self.sensitivity_fn;
        evaluate(t, state, params)
    }
}
/// Finite-difference approximation of the parameter sensitivity `∂f/∂p` of a
/// statically sized parameterised dynamics function `f(t, x, p)`.
///
/// The difference scheme is chosen by the constructor (`new`/`central`, `forward`,
/// `backward`); the step-size rule can be replaced with [`Self::with_strategy`].
pub struct SNumericalSensitivity<const S: usize, const P: usize> {
    /// Dynamics `f(t, x, p)` differentiated with respect to `p`.
    dynamics_fn: SDynamicsWithParams<S, P>,
    /// Finite-difference scheme used for every column.
    method: DifferenceMethod,
    /// Rule that picks the step size for each parameter.
    strategy: PerturbationStrategy,
}

impl<const S: usize, const P: usize> SNumericalSensitivity<S, P> {
    /// Creates a central-difference provider with the default adaptive step
    /// `sqrt(eps) * max(|p_j|, 1)`.
    pub fn new(dynamics_fn: SDynamicsWithParams<S, P>) -> Self {
        Self {
            dynamics_fn,
            method: DifferenceMethod::Central,
            strategy: PerturbationStrategy::Adaptive {
                scale_factor: 1.0,
                min_value: 1.0,
            },
        }
    }

    /// Creates a central-difference provider; identical to [`Self::new`].
    pub fn central(dynamics_fn: SDynamicsWithParams<S, P>) -> Self {
        Self::new(dynamics_fn)
    }

    /// Creates a forward-difference provider with the default adaptive step.
    pub fn forward(dynamics_fn: SDynamicsWithParams<S, P>) -> Self {
        Self {
            method: DifferenceMethod::Forward,
            ..Self::new(dynamics_fn)
        }
    }

    /// Creates a backward-difference provider with the default adaptive step.
    pub fn backward(dynamics_fn: SDynamicsWithParams<S, P>) -> Self {
        Self {
            method: DifferenceMethod::Backward,
            ..Self::new(dynamics_fn)
        }
    }

    /// Replaces the step-size rule while keeping the difference scheme.
    #[must_use]
    pub fn with_strategy(mut self, strategy: PerturbationStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Step size used to perturb a parameter whose current value is `value`.
    ///
    /// A zero step (e.g. `Percentage` applied to a parameter equal to `0.0`, or
    /// `Fixed(0.0)`) would make every difference quotient `0 / 0 = NaN`, so it is
    /// replaced by the unscaled adaptive step `sqrt(eps)`.
    fn compute_perturbation(&self, value: f64) -> f64 {
        let step = match self.strategy {
            PerturbationStrategy::Adaptive {
                scale_factor,
                min_value,
            } => scale_factor * f64::EPSILON.sqrt() * value.abs().max(min_value),
            PerturbationStrategy::Fixed(h) => h,
            PerturbationStrategy::Percentage(pct) => value.abs() * pct,
        };
        if step == 0.0 {
            f64::EPSILON.sqrt()
        } else {
            step
        }
    }
}
impl<const S: usize, const P: usize> SSensitivityProvider<S, P> for SNumericalSensitivity<S, P> {
    /// Approximates `∂f/∂p` one column at a time.
    ///
    /// Column `j` perturbs only `params[j]` by the step from `compute_perturbation` and
    /// differences the dynamics according to the configured scheme:
    /// forward `(f(p + h e_j) - f(p)) / h`, central `(f(p + h e_j) - f(p - h e_j)) / 2h`,
    /// backward `(f(p) - f(p - h e_j)) / h`.
    fn compute(
        &self,
        t: f64,
        state: &SVector<f64, S>,
        params: &SVector<f64, P>,
    ) -> SMatrix<f64, S, P> {
        // Dynamics evaluated with only `params[j]` shifted by `delta`.
        let shifted = |j: usize, delta: f64| {
            let mut perturbed = *params;
            perturbed[j] += delta;
            (self.dynamics_fn)(t, state, &perturbed)
        };

        let mut sensitivity = SMatrix::<f64, S, P>::zeros();
        match self.method {
            DifferenceMethod::Forward => {
                // The unperturbed evaluation is shared by every column, so compute it once.
                let f_0 = (self.dynamics_fn)(t, state, params);
                for (j, &p_j) in params.iter().enumerate() {
                    let h = self.compute_perturbation(p_j);
                    sensitivity.set_column(j, &((shifted(j, h) - f_0) / h));
                }
            }
            DifferenceMethod::Central => {
                for (j, &p_j) in params.iter().enumerate() {
                    let h = self.compute_perturbation(p_j);
                    let column = (shifted(j, h) - shifted(j, -h)) / (2.0 * h);
                    sensitivity.set_column(j, &column);
                }
            }
            DifferenceMethod::Backward => {
                // The unperturbed evaluation is shared by every column, so compute it once.
                let f_0 = (self.dynamics_fn)(t, state, params);
                for (j, &p_j) in params.iter().enumerate() {
                    let h = self.compute_perturbation(p_j);
                    sensitivity.set_column(j, &((f_0 - shifted(j, -h)) / h));
                }
            }
        }
        sensitivity
    }
}
/// Finite-difference approximation of the parameter sensitivity `∂f/∂p` of a
/// dynamically sized parameterised dynamics function `f(t, x, p)`.
///
/// The difference scheme is chosen by the constructor (`new`/`central`, `forward`,
/// `backward`); the step-size rule can be replaced with [`Self::with_strategy`].
pub struct DNumericalSensitivity {
    /// Dynamics `f(t, x, p)` differentiated with respect to `p`.
    dynamics_fn: DDynamicsWithParams,
    /// Finite-difference scheme used for every column.
    method: DifferenceMethod,
    /// Rule that picks the step size for each parameter.
    strategy: PerturbationStrategy,
}

impl DNumericalSensitivity {
    /// Creates a central-difference provider with the default adaptive step
    /// `sqrt(eps) * max(|p_j|, 1)`.
    pub fn new(dynamics_fn: DDynamicsWithParams) -> Self {
        Self {
            dynamics_fn,
            method: DifferenceMethod::Central,
            strategy: PerturbationStrategy::Adaptive {
                scale_factor: 1.0,
                min_value: 1.0,
            },
        }
    }

    /// Creates a central-difference provider; identical to [`Self::new`].
    pub fn central(dynamics_fn: DDynamicsWithParams) -> Self {
        Self::new(dynamics_fn)
    }

    /// Creates a forward-difference provider with the default adaptive step.
    pub fn forward(dynamics_fn: DDynamicsWithParams) -> Self {
        Self {
            method: DifferenceMethod::Forward,
            ..Self::new(dynamics_fn)
        }
    }

    /// Creates a backward-difference provider with the default adaptive step.
    pub fn backward(dynamics_fn: DDynamicsWithParams) -> Self {
        Self {
            method: DifferenceMethod::Backward,
            ..Self::new(dynamics_fn)
        }
    }

    /// Replaces the step-size rule while keeping the difference scheme.
    #[must_use]
    pub fn with_strategy(mut self, strategy: PerturbationStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Step size used to perturb a parameter whose current value is `value`.
    ///
    /// A zero step (e.g. `Percentage` applied to a parameter equal to `0.0`, or
    /// `Fixed(0.0)`) would make every difference quotient `0 / 0 = NaN`, so it is
    /// replaced by the unscaled adaptive step `sqrt(eps)`.
    fn compute_perturbation(&self, value: f64) -> f64 {
        let step = match self.strategy {
            PerturbationStrategy::Adaptive {
                scale_factor,
                min_value,
            } => scale_factor * f64::EPSILON.sqrt() * value.abs().max(min_value),
            PerturbationStrategy::Fixed(h) => h,
            PerturbationStrategy::Percentage(pct) => value.abs() * pct,
        };
        if step == 0.0 {
            f64::EPSILON.sqrt()
        } else {
            step
        }
    }
}
impl DSensitivityProvider for DNumericalSensitivity {
    /// Approximates `∂f/∂p` one column at a time.
    ///
    /// The result is `state.len() × params.len()`. Column `j` perturbs only `params[j]` by
    /// the step from `compute_perturbation` and differences the dynamics according to the
    /// configured scheme: forward `(f(p + h e_j) - f(p)) / h`,
    /// central `(f(p + h e_j) - f(p - h e_j)) / 2h`, backward `(f(p) - f(p - h e_j)) / h`.
    ///
    /// # Panics
    ///
    /// Panics (inside `set_column`) if `dynamics_fn` returns a vector whose length differs
    /// from `state.len()`.
    fn compute(&self, t: f64, state: &DVector<f64>, params: &DVector<f64>) -> DMatrix<f64> {
        // Dynamics evaluated with only `params[j]` shifted by `delta`.
        let shifted = |j: usize, delta: f64| {
            let mut perturbed = params.clone();
            perturbed[j] += delta;
            (self.dynamics_fn)(t, state, &perturbed)
        };

        let mut sensitivity = DMatrix::<f64>::zeros(state.len(), params.len());
        match self.method {
            DifferenceMethod::Forward => {
                // The unperturbed evaluation is shared by every column, so compute it once.
                let f_0 = (self.dynamics_fn)(t, state, params);
                for (j, &p_j) in params.iter().enumerate() {
                    let h = self.compute_perturbation(p_j);
                    sensitivity.set_column(j, &((shifted(j, h) - &f_0) / h));
                }
            }
            DifferenceMethod::Central => {
                for (j, &p_j) in params.iter().enumerate() {
                    let h = self.compute_perturbation(p_j);
                    let column = (shifted(j, h) - shifted(j, -h)) / (2.0 * h);
                    sensitivity.set_column(j, &column);
                }
            }
            DifferenceMethod::Backward => {
                // The unperturbed evaluation is shared by every column, so compute it once.
                let f_0 = (self.dynamics_fn)(t, state, params);
                for (j, &p_j) in params.iter().enumerate() {
                    let h = self.compute_perturbation(p_j);
                    sensitivity.set_column(j, &((&f_0 - shifted(j, -h)) / h));
                }
            }
        }
        sensitivity
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the analytic and finite-difference sensitivity providers.
    // Expected values are the exact partial derivatives of simple linear/quadratic dynamics.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_dynamic_numerical_sensitivity() {
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            params[0] * state
        };
        let provider = DNumericalSensitivity::central(Box::new(dynamics));
        let state = DVector::from_vec(vec![1.0, 2.0]);
        let params = DVector::from_vec(vec![3.0]);
        let sens = provider.compute(0.0, &state, &params);
        assert_eq!(sens.nrows(), 2);
        assert_eq!(sens.ncols(), 1);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 0)], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dynamic_analytical_sensitivity() {
        let sensitivity_fn =
            |_t: f64, state: &DVector<f64>, _params: &DVector<f64>| -> DMatrix<f64> {
                DMatrix::from_column_slice(state.len(), 1, state.as_slice())
            };
        let provider = DAnalyticSensitivity::new(Box::new(sensitivity_fn));
        let state = DVector::from_vec(vec![1.0, 2.0]);
        let params = DVector::from_vec(vec![3.0]);
        let sens = provider.compute(0.0, &state, &params);
        assert_eq!(sens.nrows(), 2);
        assert_eq!(sens.ncols(), 1);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sens[(1, 0)], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_static_numerical_sensitivity() {
        let dynamics =
            |_t: f64, state: &SVector<f64, 2>, params: &SVector<f64, 2>| -> SVector<f64, 2> {
                SVector::<f64, 2>::new(params[0] * state[0], params[1] * state[1])
            };
        let provider = SNumericalSensitivity::central(Box::new(dynamics));
        let state = SVector::<f64, 2>::new(1.0, 2.0);
        let params = SVector::<f64, 2>::new(3.0, 4.0);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(0, 1)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 0)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 1)], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_static_analytical_sensitivity() {
        let sensitivity_fn =
            |_t: f64, state: &SVector<f64, 2>, _params: &SVector<f64, 2>| -> SMatrix<f64, 2, 2> {
                SMatrix::<f64, 2, 2>::new(state[0], 0.0, 0.0, state[1])
            };
        let provider = SAnalyticSensitivity::new(Box::new(sensitivity_fn));
        let state = SVector::<f64, 2>::new(1.0, 2.0);
        let params = SVector::<f64, 2>::new(3.0, 4.0);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sens[(0, 1)], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sens[(1, 0)], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sens[(1, 1)], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_forward_vs_central_difference() {
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            DVector::from_vec(vec![params[0] * params[0] * state[0]])
        };
        let state = DVector::from_vec(vec![1.0]);
        let params = DVector::from_vec(vec![2.0]);
        let forward = DNumericalSensitivity::forward(Box::new(dynamics));
        let central = DNumericalSensitivity::central(Box::new(dynamics));
        let sens_forward = forward.compute(0.0, &state, &params);
        let sens_central = central.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens_central[(0, 0)], 4.0, epsilon = 1e-8);
        assert_abs_diff_eq!(sens_forward[(0, 0)], 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_static_numerical_sensitivity_forward() {
        let dynamics =
            |_t: f64, state: &SVector<f64, 2>, params: &SVector<f64, 2>| -> SVector<f64, 2> {
                SVector::<f64, 2>::new(params[0] * state[0], params[1] * state[1])
            };
        let provider = SNumericalSensitivity::forward(Box::new(dynamics));
        let state = SVector::<f64, 2>::new(1.0, 2.0);
        let params = SVector::<f64, 2>::new(3.0, 4.0);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(0, 1)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 0)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 1)], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_static_numerical_sensitivity_backward() {
        let dynamics =
            |_t: f64, state: &SVector<f64, 2>, params: &SVector<f64, 2>| -> SVector<f64, 2> {
                SVector::<f64, 2>::new(params[0] * state[0], params[1] * state[1])
            };
        let provider = SNumericalSensitivity::backward(Box::new(dynamics));
        let state = SVector::<f64, 2>::new(1.0, 2.0);
        let params = SVector::<f64, 2>::new(3.0, 4.0);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(0, 1)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 0)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 1)], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_static_numerical_sensitivity_with_strategy_fixed() {
        let dynamics =
            |_t: f64, state: &SVector<f64, 2>, params: &SVector<f64, 2>| -> SVector<f64, 2> {
                SVector::<f64, 2>::new(params[0] * state[0], params[1] * state[1])
            };
        let provider = SNumericalSensitivity::new(Box::new(dynamics))
            .with_strategy(PerturbationStrategy::Fixed(1e-6));
        let state = SVector::<f64, 2>::new(1.0, 2.0);
        let params = SVector::<f64, 2>::new(3.0, 4.0);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-5);
        assert_abs_diff_eq!(sens[(0, 1)], 0.0, epsilon = 1e-5);
        assert_abs_diff_eq!(sens[(1, 0)], 0.0, epsilon = 1e-5);
        assert_abs_diff_eq!(sens[(1, 1)], 2.0, epsilon = 1e-5);
    }

    #[test]
    fn test_static_numerical_sensitivity_with_strategy_percentage() {
        let dynamics =
            |_t: f64, state: &SVector<f64, 2>, params: &SVector<f64, 2>| -> SVector<f64, 2> {
                SVector::<f64, 2>::new(params[0] * state[0], params[1] * state[1])
            };
        let provider = SNumericalSensitivity::new(Box::new(dynamics))
            .with_strategy(PerturbationStrategy::Percentage(1e-6));
        let state = SVector::<f64, 2>::new(1.0, 2.0);
        let params = SVector::<f64, 2>::new(3.0, 4.0);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-5);
        assert_abs_diff_eq!(sens[(0, 1)], 0.0, epsilon = 1e-5);
        assert_abs_diff_eq!(sens[(1, 0)], 0.0, epsilon = 1e-5);
        assert_abs_diff_eq!(sens[(1, 1)], 2.0, epsilon = 1e-5);
    }

    #[test]
    fn test_dynamic_numerical_sensitivity_backward() {
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            params[0] * state
        };
        let provider = DNumericalSensitivity::backward(Box::new(dynamics));
        let state = DVector::from_vec(vec![1.0, 2.0]);
        let params = DVector::from_vec(vec![3.0]);
        let sens = provider.compute(0.0, &state, &params);
        assert_eq!(sens.nrows(), 2);
        assert_eq!(sens.ncols(), 1);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 0)], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dynamic_numerical_sensitivity_forward_multi_parameter() {
        // Every column must be differenced against the same unperturbed evaluation.
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            DVector::from_vec(vec![params[0] * state[0] + params[1], params[1] * state[1]])
        };
        let provider = DNumericalSensitivity::forward(Box::new(dynamics));
        let state = DVector::from_vec(vec![1.0, 2.0]);
        let params = DVector::from_vec(vec![3.0, 4.0]);
        let sens = provider.compute(0.0, &state, &params);
        assert_eq!(sens.nrows(), 2);
        assert_eq!(sens.ncols(), 2);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(0, 1)], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 0)], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sens[(1, 1)], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dynamic_numerical_sensitivity_with_strategy() {
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            params[0] * state
        };
        let provider = DNumericalSensitivity::new(Box::new(dynamics))
            .with_strategy(PerturbationStrategy::Fixed(1e-6));
        let state = DVector::from_vec(vec![1.0, 2.0]);
        let params = DVector::from_vec(vec![3.0]);
        let sens = provider.compute(0.0, &state, &params);
        assert_eq!(sens.nrows(), 2);
        assert_eq!(sens.ncols(), 1);
        assert_abs_diff_eq!(sens[(0, 0)], 1.0, epsilon = 1e-5);
        assert_abs_diff_eq!(sens[(1, 0)], 2.0, epsilon = 1e-5);
    }

    #[test]
    fn test_dynamic_numerical_sensitivity_fixed_perturbation() {
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            params[0] * params[0] * state
        };
        let provider = DNumericalSensitivity::central(Box::new(dynamics))
            .with_strategy(PerturbationStrategy::Fixed(1e-5));
        let state = DVector::from_vec(vec![1.0]);
        let params = DVector::from_vec(vec![2.0]);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_dynamic_numerical_sensitivity_percentage_perturbation() {
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            params[0] * params[0] * state
        };
        let provider = DNumericalSensitivity::central(Box::new(dynamics))
            .with_strategy(PerturbationStrategy::Percentage(1e-6));
        let state = DVector::from_vec(vec![1.0]);
        let params = DVector::from_vec(vec![2.0]);
        let sens = provider.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens[(0, 0)], 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_static_forward_vs_central_vs_backward() {
        let dynamics =
            |_t: f64, state: &SVector<f64, 1>, params: &SVector<f64, 1>| -> SVector<f64, 1> {
                SVector::<f64, 1>::new(params[0] * params[0] * state[0])
            };
        let state = SVector::<f64, 1>::new(1.0);
        let params = SVector::<f64, 1>::new(2.0);
        let forward = SNumericalSensitivity::forward(Box::new(dynamics));
        let central = SNumericalSensitivity::central(Box::new(dynamics));
        let backward = SNumericalSensitivity::backward(Box::new(dynamics));
        let sens_forward = forward.compute(0.0, &state, &params);
        let sens_central = central.compute(0.0, &state, &params);
        let sens_backward = backward.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens_central[(0, 0)], 4.0, epsilon = 1e-8);
        assert_abs_diff_eq!(sens_forward[(0, 0)], 4.0, epsilon = 1e-4);
        assert_abs_diff_eq!(sens_backward[(0, 0)], 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_dynamic_backward_vs_central() {
        let dynamics = |_t: f64, state: &DVector<f64>, params: &DVector<f64>| -> DVector<f64> {
            DVector::from_vec(vec![params[0] * params[0] * state[0]])
        };
        let state = DVector::from_vec(vec![1.0]);
        let params = DVector::from_vec(vec![2.0]);
        let backward = DNumericalSensitivity::backward(Box::new(dynamics));
        let central = DNumericalSensitivity::central(Box::new(dynamics));
        let sens_backward = backward.compute(0.0, &state, &params);
        let sens_central = central.compute(0.0, &state, &params);
        assert_abs_diff_eq!(sens_central[(0, 0)], 4.0, epsilon = 1e-8);
        assert_abs_diff_eq!(sens_backward[(0, 0)], 4.0, epsilon = 1e-4);
    }
}