use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
use crate::matrix::{jacobian, Matrix};
use crate::ode::{OdeError, ODE};
use crate::simplify::engine::simplify;
/// Builds the forward sensitivity system of `ode` with respect to `params`.
///
/// For `dy/dt = f(y, p)` with `m = ode.order()` states, every parameter `p_j`
/// gets a column of `m` sensitivity symbols `dS_<y_i>_<p_j>` (with derivative
/// symbols `ddS_<y_i>_<p_j>/dt`) whose right-hand sides are
///
/// `dS_j/dt = (∂f/∂y) · S_j + ∂f/∂p_j`, simplified entry by entry.
///
/// Non-symbol states/parameters fall back to the names `y<i>` / `p<j>`.
///
/// The returned `extended_ode` lists the original states first. The
/// sensitivity columns follow, one parameter at a time
/// (`S_0j, …, S_(m-1)j` for `j = 0, 1, …`).
///
/// Only `ode.initial_conditions` are copied over. No initial conditions are
/// added for the sensitivity variables. NOTE(review): callers must supply
/// `S(0)` (mathematically `∂y(0)/∂p`) themselves.
///
/// # Errors
///
/// Returns `OdeError::DiffError` if `∂f/∂y` or `∂f/∂p` cannot be computed.
///
/// # Panics
///
/// Panics if `Matrix::new` rejects a sensitivity column or `Matrix::mul` reports
/// incompatible shapes (the `expect` calls below).
pub fn sensitivity_system(
    ode: &ODE,
    params: &[ExprId],
    pool: &ExprPool,
) -> Result<SensitivitySystem, OdeError> {
    // Name of `id` when it is a `Symbol`, otherwise `fallback()`. The fallback
    // runs outside `pool.with` so it never executes while the pool is being read.
    fn symbol_name_or(
        pool: &ExprPool,
        id: ExprId,
        fallback: impl FnOnce() -> String,
    ) -> String {
        pool.with(id, |d| match d {
            ExprData::Symbol { name, .. } => Some(name.clone()),
            _ => None,
        })
        .unwrap_or_else(fallback)
    }

    let m = ode.order();
    let n_params = params.len();
    let jac_y = jacobian(&ode.rhs, &ode.state_vars, pool)
        .map_err(|e| OdeError::DiffError(e.to_string()))?;
    let jac_p =
        jacobian(&ode.rhs, params, pool).map_err(|e| OdeError::DiffError(e.to_string()))?;

    // State names do not depend on the parameter, so resolve them once instead
    // of once per (state, parameter) pair.
    let state_names: Vec<String> = ode
        .state_vars
        .iter()
        .take(m)
        .enumerate()
        .map(|(i, &y)| symbol_name_or(pool, y, || format!("y{i}")))
        .collect();

    // One column of sensitivity symbols (and their derivative symbols) per parameter.
    let mut sens_vars: Vec<Vec<ExprId>> = Vec::with_capacity(n_params);
    let mut sens_derivs: Vec<Vec<ExprId>> = Vec::with_capacity(n_params);
    for (j, &param) in params.iter().enumerate() {
        let pname = symbol_name_or(pool, param, || format!("p{j}"));
        let col_vars: Vec<ExprId> = state_names
            .iter()
            .map(|yname| pool.symbol(format!("dS_{yname}_{pname}"), Domain::Real))
            .collect();
        // Derivative names are read back from the pool so they match the stored
        // symbol name exactly ("d?/dt" if the node is somehow not a symbol).
        let col_derivs: Vec<ExprId> = col_vars
            .iter()
            .map(|&v| {
                let name = symbol_name_or(pool, v, || "?".to_string());
                pool.symbol(format!("d{name}/dt"), Domain::Real)
            })
            .collect();
        sens_vars.push(col_vars);
        sens_derivs.push(col_derivs);
    }

    let extra = n_params * m;
    let mut extended_vars: Vec<ExprId> = ode.state_vars.clone();
    let mut extended_derivs: Vec<ExprId> = ode.derivatives.clone();
    let mut extended_rhs: Vec<ExprId> = ode.rhs.clone();
    extended_vars.reserve(extra);
    extended_derivs.reserve(extra);
    extended_rhs.reserve(extra);

    for (j, (col_vars, col_derivs)) in sens_vars.iter().zip(&sens_derivs).enumerate() {
        // dS_j/dt = J_y · S_j + ∂f/∂p_j, simplified entry by entry.
        let s_j = Matrix::new(col_vars.iter().map(|&v| vec![v]).collect())
            .expect("single-column matrix");
        let jac_sj = jac_y.mul(&s_j, pool).expect("compatible shapes");
        extended_vars.extend_from_slice(col_vars);
        extended_derivs.extend_from_slice(col_derivs);
        extended_rhs.extend((0..m).map(|i| {
            simplify(pool.add(vec![jac_sj.get(i, 0), jac_p.get(i, j)]), pool).value
        }));
    }

    Ok(SensitivitySystem {
        extended_ode: ODE {
            state_vars: extended_vars,
            derivatives: extended_derivs,
            rhs: extended_rhs,
            time_var: ode.time_var,
            initial_conditions: ode.initial_conditions.clone(),
        },
        original_dim: m,
        n_params,
        param_vars: params.to_vec(),
        sensitivity_vars: sens_vars,
    })
}
/// Builds the adjoint (costate) system of `ode`.
///
/// One adjoint variable `lambda_<y_i>` is introduced per state `y_i` (falling
/// back to `lambda_y<i>` for non-symbol states). Its derivative symbol is
/// `dlambda_<y_i>/dt`. The right-hand side is `dλ/dt = -(∂f/∂y)ᵀ · λ`,
/// with each component simplified.
///
/// `objective_grad[i]` is paired with the i-th adjoint variable. The pairs are
/// returned as `terminal_conditions` and are also stored as the adjoint ODE's
/// `initial_conditions`.
///
/// NOTE(review): the pairing uses `zip`. If `objective_grad.len()` differs from
/// `ode.order()`, the surplus entries on the longer side are silently dropped.
///
/// # Errors
///
/// Returns `OdeError::DiffError` if `∂f/∂y` cannot be computed.
///
/// # Panics
///
/// Panics if `Matrix::new` rejects the λ column or `Matrix::mul` reports
/// incompatible shapes.
pub fn adjoint_system(
    ode: &ODE,
    objective_grad: &[ExprId],
    pool: &ExprPool,
) -> Result<AdjointSystem, OdeError> {
    let dim = ode.order();
    let df_dy = match jacobian(&ode.rhs, &ode.state_vars, pool) {
        Ok(jac) => jac,
        Err(err) => return Err(OdeError::DiffError(err.to_string())),
    };

    // Adjoint variables, one per state.
    let mut lambda_vars: Vec<ExprId> = Vec::with_capacity(dim);
    for idx in 0..dim {
        let state_name = pool.with(ode.state_vars[idx], |data| match data {
            ExprData::Symbol { name, .. } => name.clone(),
            _ => format!("y{idx}"),
        });
        lambda_vars.push(pool.symbol(format!("lambda_{state_name}"), Domain::Real));
    }

    // Matching derivative symbols, named after the stored adjoint symbol.
    let mut lambda_derivs: Vec<ExprId> = Vec::with_capacity(dim);
    for &lam in &lambda_vars {
        let deriv_name = pool.with(lam, |data| match data {
            ExprData::Symbol { name, .. } => format!("d{name}/dt"),
            _ => "d?/dt".to_string(),
        });
        lambda_derivs.push(pool.symbol(&deriv_name, Domain::Real));
    }

    // -(J^T λ), one simplified entry per state.
    let transposed = df_dy.transpose();
    let lambda_column = Matrix::new(lambda_vars.iter().map(|&lam| vec![lam]).collect())
        .expect("column matrix");
    let product = transposed.mul(&lambda_column, pool).expect("compatible shapes");
    let minus_one = pool.integer(-1_i32);
    let mut adjoint_rhs = Vec::with_capacity(dim);
    for row in 0..dim {
        let negated = pool.mul(vec![minus_one, product.get(row, 0)]);
        adjoint_rhs.push(simplify(negated, pool).value);
    }

    let terminal_conditions: Vec<(ExprId, ExprId)> = lambda_vars
        .iter()
        .copied()
        .zip(objective_grad.iter().copied())
        .collect();

    Ok(AdjointSystem {
        adjoint_ode: ODE {
            state_vars: lambda_vars.clone(),
            derivatives: lambda_derivs,
            rhs: adjoint_rhs,
            time_var: ode.time_var,
            initial_conditions: terminal_conditions.clone(),
        },
        lambda_vars,
        terminal_conditions,
    })
}
/// Forward sensitivity system produced by `sensitivity_system`.
#[derive(Clone, Debug)]
pub struct SensitivitySystem {
    /// The original ODE extended with the sensitivity equations.
    /// The original states come first, followed by one column of
    /// `original_dim` sensitivity variables per parameter.
    pub extended_ode: ODE,
    /// Number of states in the original ODE (`ode.order()`).
    pub original_dim: usize,
    /// Number of parameters the system was differentiated against.
    pub n_params: usize,
    /// The parameter expressions, in the order passed to `sensitivity_system`.
    pub param_vars: Vec<ExprId>,
    /// `sensitivity_vars[j][i]` is the symbol standing for ∂y_i/∂p_j
    /// (named `dS_<y_i>_<p_j>`).
    pub sensitivity_vars: Vec<Vec<ExprId>>,
}
impl SensitivitySystem {
    /// Returns the sensitivity symbol for state `state_idx` with respect to
    /// parameter `param_idx`, i.e. `sensitivity_vars[param_idx][state_idx]`.
    ///
    /// # Panics
    ///
    /// Panics if `param_idx` or `state_idx` is out of range.
    pub fn get_sensitivity(&self, state_idx: usize, param_idx: usize) -> ExprId {
        let column = &self.sensitivity_vars[param_idx];
        column[state_idx]
    }

    /// Renders the extended (original + sensitivity) ODE by delegating to its
    /// own `display`.
    pub fn display(&self, pool: &ExprPool) -> String {
        let system = &self.extended_ode;
        system.display(pool)
    }
}
/// Adjoint (costate) system produced by `adjoint_system`.
#[derive(Clone, Debug)]
pub struct AdjointSystem {
    /// ODE `dλ/dt = -(∂f/∂y)ᵀ λ` over `lambda_vars`. Its `initial_conditions`
    /// field holds the same pairs as `terminal_conditions`.
    /// NOTE(review): the caller presumably integrates this backward from the
    /// final time. Confirm against the solver.
    pub adjoint_ode: ODE,
    /// Adjoint variables, one per original state, named `lambda_<y_i>`.
    pub lambda_vars: Vec<ExprId>,
    /// `(lambda_i, objective_grad[i])` pairs, built with `zip`, so their length
    /// is the shorter of the two inputs.
    pub terminal_conditions: Vec<(ExprId, ExprId)>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
    use crate::ode::ODE;

    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Name of a symbol node, or an empty string for any other node.
    fn name_of(pool: &ExprPool, id: ExprId) -> String {
        pool.with(id, |d| match d {
            ExprData::Symbol { name, .. } => name.clone(),
            _ => String::new(),
        })
    }

    #[test]
    fn sensitivity_linear_ode() {
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![a, y]);
        let ode = ODE::new(vec![y], vec![rhs], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 2);
        assert_eq!(sys.original_dim, 1);
        assert_eq!(sys.n_params, 1);
        assert_eq!(sys.param_vars, vec![a]);

        // Original state first, then the single sensitivity variable.
        let s = sys.get_sensitivity(0, 0);
        assert_eq!(sys.extended_ode.state_vars[0], ode.state_vars[0]);
        assert_eq!(sys.extended_ode.state_vars[1], s);
        assert_eq!(name_of(&pool, s), "dS_y_a");
        assert_eq!(name_of(&pool, sys.extended_ode.derivatives[1]), "ddS_y_a/dt");
    }

    #[test]
    fn sensitivity_constant_ode() {
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let p_sym = pool.symbol("p", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![y], vec![p_sym], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[p_sym], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 2);
        // The original right-hand side is carried over unchanged.
        assert_eq!(sys.extended_ode.rhs[0], ode.rhs[0]);
        let s_rhs = sys.extended_ode.rhs[1];
        assert_eq!(s_rhs, pool.integer(1_i32));
    }

    #[test]
    fn adjoint_system_basic() {
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        let ode = ODE::new(vec![y], vec![neg_y], t, &pool).unwrap();
        let obj_grad = vec![pool.integer(1_i32)];
        let adj = adjoint_system(&ode, &obj_grad, &pool).unwrap();
        assert_eq!(adj.adjoint_ode.order(), 1);
        let lam = adj.lambda_vars[0];
        let rhs = adj.adjoint_ode.rhs[0];
        assert_eq!(rhs, lam);
        assert_eq!(name_of(&pool, lam), "lambda_y");
        assert_eq!(adj.terminal_conditions, vec![(lam, obj_grad[0])]);
        assert_eq!(adj.adjoint_ode.initial_conditions, adj.terminal_conditions);
    }

    #[test]
    fn sensitivity_two_params() {
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let b = pool.symbol("b", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.add(vec![pool.mul(vec![a, y]), b]);
        let ode = ODE::new(vec![y], vec![rhs], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a, b], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 3);
        assert_eq!(sys.n_params, 2);

        // Sensitivity columns follow the original state, one per parameter, in order.
        assert_eq!(sys.extended_ode.state_vars[1], sys.get_sensitivity(0, 0));
        assert_eq!(sys.extended_ode.state_vars[2], sys.get_sensitivity(0, 1));
        assert_eq!(name_of(&pool, sys.get_sensitivity(0, 1)), "dS_y_b");
    }
}