1use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
15use crate::matrix::{jacobian, Matrix};
16use crate::ode::{OdeError, ODE};
17use crate::simplify::engine::simplify;
18
19pub fn sensitivity_system(
33 ode: &ODE,
34 params: &[ExprId],
35 pool: &ExprPool,
36) -> Result<SensitivitySystem, OdeError> {
37 let m = ode.order();
38 let n_params = params.len();
39
40 let jac_y = jacobian(&ode.rhs, &ode.state_vars, pool)
42 .map_err(|e| OdeError::DiffError(e.to_string()))?;
43
44 let jac_p = jacobian(&ode.rhs, params, pool).map_err(|e| OdeError::DiffError(e.to_string()))?;
46
47 let mut sens_vars: Vec<Vec<ExprId>> = Vec::new(); let mut sens_derivs: Vec<Vec<ExprId>> = Vec::new();
51 for (j, ¶m) in params.iter().enumerate().take(n_params) {
52 let col_vars: Vec<ExprId> = (0..m)
53 .map(|i| {
54 let pname = pool.with(param, |d| match d {
55 ExprData::Symbol { name, .. } => name.clone(),
56 _ => format!("p{j}"),
57 });
58 let yname = pool.with(ode.state_vars[i], |d| match d {
59 ExprData::Symbol { name, .. } => name.clone(),
60 _ => format!("y{i}"),
61 });
62 pool.symbol(format!("dS_{yname}_{pname}"), Domain::Real)
63 })
64 .collect();
65 let col_derivs: Vec<ExprId> = col_vars
66 .iter()
67 .map(|&v| {
68 let name = pool.with(v, |d| match d {
69 ExprData::Symbol { name, .. } => format!("d{name}/dt"),
70 _ => "d?/dt".to_string(),
71 });
72 pool.symbol(name, Domain::Real)
73 })
74 .collect();
75 sens_vars.push(col_vars);
76 sens_derivs.push(col_derivs);
77 }
78
79 let mut extended_vars: Vec<ExprId> = ode.state_vars.clone();
81 let mut extended_derivs: Vec<ExprId> = ode.derivatives.clone();
82 let mut extended_rhs: Vec<ExprId> = ode.rhs.clone();
83
84 let mut sens_rhs_matrix: Vec<Vec<ExprId>> = Vec::new(); for j in 0..n_params {
87 let s_j = Matrix::new(sens_vars[j].iter().map(|&v| vec![v]).collect())
89 .expect("single-column matrix");
90
91 let jac_sj = jac_y.mul(&s_j, pool).expect("compatible shapes");
93
94 let df_dpj: Vec<ExprId> = (0..m).map(|i| jac_p.get(i, j)).collect();
96
97 let col_rhs: Vec<ExprId> = (0..m)
99 .map(|i| {
100 let jac_term = jac_sj.get(i, 0);
101 let param_term = df_dpj[i];
102 simplify(pool.add(vec![jac_term, param_term]), pool).value
103 })
104 .collect();
105
106 sens_rhs_matrix.push(col_rhs.clone());
107
108 for i in 0..m {
110 extended_vars.push(sens_vars[j][i]);
111 extended_derivs.push(sens_derivs[j][i]);
112 extended_rhs.push(col_rhs[i]);
113 }
114 }
115
116 Ok(SensitivitySystem {
117 extended_ode: ODE {
118 state_vars: extended_vars,
119 derivatives: extended_derivs,
120 rhs: extended_rhs,
121 time_var: ode.time_var,
122 initial_conditions: ode.initial_conditions.clone(),
123 },
124 original_dim: m,
125 n_params,
126 param_vars: params.to_vec(),
127 sensitivity_vars: sens_vars,
128 })
129}
130
131pub fn adjoint_system(
153 ode: &ODE,
154 objective_grad: &[ExprId], pool: &ExprPool,
156) -> Result<AdjointSystem, OdeError> {
157 let m = ode.order();
158
159 let jac_y = jacobian(&ode.rhs, &ode.state_vars, pool)
161 .map_err(|e| OdeError::DiffError(e.to_string()))?;
162
163 let lambda: Vec<ExprId> = (0..m)
165 .map(|i| {
166 let yname = pool.with(ode.state_vars[i], |d| match d {
167 ExprData::Symbol { name, .. } => name.clone(),
168 _ => format!("y{i}"),
169 });
170 pool.symbol(format!("lambda_{yname}"), Domain::Real)
171 })
172 .collect();
173
174 let lambda_derivs: Vec<ExprId> = lambda
175 .iter()
176 .map(|&v| {
177 let name = pool.with(v, |d| match d {
178 ExprData::Symbol { name, .. } => format!("d{name}/dt"),
179 _ => "d?/dt".to_string(),
180 });
181 pool.symbol(&name, Domain::Real)
182 })
183 .collect();
184
185 let jac_y_t = jac_y.transpose();
187 let lam_mat = Matrix::new(lambda.iter().map(|&v| vec![v]).collect()).expect("column matrix");
188 let jac_lam = jac_y_t.mul(&lam_mat, pool).expect("compatible shapes");
189
190 let neg_one = pool.integer(-1_i32);
191 let adjoint_rhs: Vec<ExprId> = (0..m)
192 .map(|i| simplify(pool.mul(vec![neg_one, jac_lam.get(i, 0)]), pool).value)
193 .collect();
194
195 let terminal_conditions: Vec<(ExprId, ExprId)> = lambda
197 .iter()
198 .zip(objective_grad.iter())
199 .map(|(&l, &g)| (l, g))
200 .collect();
201
202 let adjoint_ode = ODE {
203 state_vars: lambda.clone(),
204 derivatives: lambda_derivs,
205 rhs: adjoint_rhs,
206 time_var: ode.time_var,
207 initial_conditions: terminal_conditions.clone(),
208 };
209
210 Ok(AdjointSystem {
211 adjoint_ode,
212 lambda_vars: lambda,
213 terminal_conditions,
214 })
215}
216
/// Result of [`sensitivity_system`]: the original ODE augmented with
/// forward-sensitivity equations, plus bookkeeping to locate the sensitivity
/// variables inside it.
#[derive(Clone, Debug)]
pub struct SensitivitySystem {
    /// Augmented ODE. It holds the original states first, followed by one block
    /// of `original_dim` sensitivity variables per parameter, in parameter order.
    pub extended_ode: ODE,
    /// Number of states in the original ODE (`ode.order()` at construction).
    pub original_dim: usize,
    /// Number of parameters the sensitivities were taken with respect to.
    pub n_params: usize,
    /// Parameter expressions, in the order they were passed to [`sensitivity_system`].
    pub param_vars: Vec<ExprId>,
    /// Sensitivity symbols indexed as `[param_idx][state_idx]`.
    pub sensitivity_vars: Vec<Vec<ExprId>>,
}
235
236impl SensitivitySystem {
237 pub fn get_sensitivity(&self, state_idx: usize, param_idx: usize) -> ExprId {
239 self.sensitivity_vars[param_idx][state_idx]
240 }
241
242 pub fn display(&self, pool: &ExprPool) -> String {
244 self.extended_ode.display(pool)
245 }
246}
247
/// Result of [`adjoint_system`]: the adjoint ODE and its adjoint symbols.
#[derive(Clone, Debug)]
pub struct AdjointSystem {
    /// ODE whose states are the adjoint symbols and whose right-hand side is
    /// `-(J_y)^T · lambda`. Its `initial_conditions` are the terminal conditions.
    pub adjoint_ode: ODE,
    /// Adjoint symbols, one per state of the original ODE, in state order.
    pub lambda_vars: Vec<ExprId>,
    /// `(lambda_i, objective_grad[i])` pairs, formed by zipping the two sequences.
    pub terminal_conditions: Vec<(ExprId, ExprId)>,
}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};
    use crate::ode::ODE;

    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn sensitivity_linear_ode() {
        let pool = p();
        // y' = a * y, sensitivity with respect to a.
        let y = pool.symbol("y", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![a, y]);
        let ode = ODE::new(vec![y], vec![rhs], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 2);
        assert_eq!(sys.original_dim, 1);
        assert_eq!(sys.n_params, 1);
    }

    #[test]
    fn sensitivity_constant_ode() {
        let pool = p();
        // y' = p, so dS/dt = 0 * S + 1 simplifies to 1.
        let y = pool.symbol("y", Domain::Real);
        let p_sym = pool.symbol("p", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![y], vec![p_sym], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[p_sym], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 2);
        let s_rhs = sys.extended_ode.rhs[1];
        assert_eq!(s_rhs, pool.integer(1_i32));
    }

    #[test]
    fn adjoint_system_basic() {
        let pool = p();
        // y' = -y, so lambda' = -(-1) * lambda = lambda.
        let y = pool.symbol("y", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        let ode = ODE::new(vec![y], vec![neg_y], t, &pool).unwrap();
        let obj_grad = vec![pool.integer(1_i32)];
        let adj = adjoint_system(&ode, &obj_grad, &pool).unwrap();
        assert_eq!(adj.adjoint_ode.order(), 1);
        let lam = adj.lambda_vars[0];
        let rhs = adj.adjoint_ode.rhs[0];
        assert_eq!(rhs, lam);
    }

    #[test]
    fn sensitivity_two_params() {
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let b = pool.symbol("b", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.add(vec![pool.mul(vec![a, y]), b]);
        let ode = ODE::new(vec![y], vec![rhs], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a, b], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 3);
        assert_eq!(sys.n_params, 2);
    }

    #[test]
    fn sensitivity_layout_matches_lookup() {
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let b = pool.symbol("b", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.add(vec![pool.mul(vec![a, y]), b]);
        let ode = ODE::new(vec![y], vec![rhs], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a, b], &pool).unwrap();
        // Original state first, then one sensitivity column per parameter.
        assert_eq!(sys.extended_ode.state_vars[0], y);
        assert_eq!(sys.extended_ode.state_vars[1], sys.get_sensitivity(0, 0));
        assert_eq!(sys.extended_ode.state_vars[2], sys.get_sensitivity(0, 1));
        assert_eq!(sys.param_vars, vec![a, b]);
        assert_eq!(sys.extended_ode.derivatives.len(), 3);
        assert_eq!(sys.extended_ode.rhs.len(), 3);
    }

    #[test]
    fn adjoint_terminal_conditions_pair_lambda_with_gradient() {
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        let ode = ODE::new(vec![y], vec![neg_y], t, &pool).unwrap();
        let g = pool.integer(1_i32);
        let adj = adjoint_system(&ode, &[g], &pool).unwrap();
        let lam = adj.lambda_vars[0];
        assert_eq!(adj.terminal_conditions, vec![(lam, g)]);
        // The terminal conditions are mirrored into the adjoint ODE.
        assert_eq!(adj.adjoint_ode.initial_conditions, adj.terminal_conditions);
        assert_eq!(adj.adjoint_ode.state_vars, adj.lambda_vars);
    }
}