Skip to main content

alkahest_cas/ode/
sensitivity.rs

1//! Phase 19 — Forward sensitivity analysis.
2//!
3//! Given an ODE `dy/dt = f(t, y, p)` and a parameter vector `p`, the
4//! forward sensitivity equations are:
5//!
6//! ```text
7//! dS_j/dt = (∂f/∂y) · S_j + ∂f/∂p_j
8//! ```
9//!
10//! where `S_j = ∂y/∂p_j` is an m-vector of sensitivities.
11//!
12//! `sensitivity_system` returns an extended ODE whose state is `(y, S)`.
13
14use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
15use crate::matrix::{jacobian, Matrix};
16use crate::ode::{OdeError, ODE};
17use crate::simplify::engine::simplify;
18
19// ---------------------------------------------------------------------------
20// Sensitivity system
21// ---------------------------------------------------------------------------
22
23/// Build the forward-sensitivity ODE for `ode` with respect to `params`.
24///
25/// Returns an extended `ODE` whose state vector is `[y_0…y_{m-1}, S_{0,0}…]`
26/// where `S_{i,j} = ∂y_i/∂params[j]`.
27///
28/// # Errors
29///
30/// Returns an `OdeError` if differentiation of any RHS fails (e.g. an unknown
31/// function that cannot be differentiated).
32pub fn sensitivity_system(
33    ode: &ODE,
34    params: &[ExprId],
35    pool: &ExprPool,
36) -> Result<SensitivitySystem, OdeError> {
37    let m = ode.order();
38    let n_params = params.len();
39
40    // Jacobian of f w.r.t. state variables: ∂f_i/∂y_j  (m × m)
41    let jac_y = jacobian(&ode.rhs, &ode.state_vars, pool)
42        .map_err(|e| OdeError::DiffError(e.to_string()))?;
43
44    // Jacobian of f w.r.t. parameters: ∂f_i/∂p_j  (m × n_params)
45    let jac_p = jacobian(&ode.rhs, params, pool).map_err(|e| OdeError::DiffError(e.to_string()))?;
46
47    // Sensitivity variables: S_{i,j} = ∂y_i/∂p_j  (m × n_params matrix)
48    // Stored column-major as separate ODE states
49    let mut sens_vars: Vec<Vec<ExprId>> = Vec::new(); // sens_vars[j][i] = S_{i,j}
50    let mut sens_derivs: Vec<Vec<ExprId>> = Vec::new();
51    for (j, &param) in params.iter().enumerate().take(n_params) {
52        let col_vars: Vec<ExprId> = (0..m)
53            .map(|i| {
54                let pname = pool.with(param, |d| match d {
55                    ExprData::Symbol { name, .. } => name.clone(),
56                    _ => format!("p{j}"),
57                });
58                let yname = pool.with(ode.state_vars[i], |d| match d {
59                    ExprData::Symbol { name, .. } => name.clone(),
60                    _ => format!("y{i}"),
61                });
62                pool.symbol(format!("dS_{yname}_{pname}"), Domain::Real)
63            })
64            .collect();
65        let col_derivs: Vec<ExprId> = col_vars
66            .iter()
67            .map(|&v| {
68                let name = pool.with(v, |d| match d {
69                    ExprData::Symbol { name, .. } => format!("d{name}/dt"),
70                    _ => "d?/dt".to_string(),
71                });
72                pool.symbol(name, Domain::Real)
73            })
74            .collect();
75        sens_vars.push(col_vars);
76        sens_derivs.push(col_derivs);
77    }
78
79    // Build sensitivity RHS: dS_j/dt = J_y · S_j + ∂f/∂p_j
80    let mut extended_vars: Vec<ExprId> = ode.state_vars.clone();
81    let mut extended_derivs: Vec<ExprId> = ode.derivatives.clone();
82    let mut extended_rhs: Vec<ExprId> = ode.rhs.clone();
83
84    let mut sens_rhs_matrix: Vec<Vec<ExprId>> = Vec::new(); // [j][i]
85
86    for j in 0..n_params {
87        // S_j is the j-th column of the sensitivity matrix, as a vector
88        let s_j = Matrix::new(sens_vars[j].iter().map(|&v| vec![v]).collect())
89            .expect("single-column matrix");
90
91        // J_y · S_j  (m×m times m×1 = m×1)
92        let jac_sj = jac_y.mul(&s_j, pool).expect("compatible shapes");
93
94        // ∂f/∂p_j  (column j of jac_p, shape m×1)
95        let df_dpj: Vec<ExprId> = (0..m).map(|i| jac_p.get(i, j)).collect();
96
97        // dS_j/dt = J_y * S_j + ∂f/∂p_j
98        let col_rhs: Vec<ExprId> = (0..m)
99            .map(|i| {
100                let jac_term = jac_sj.get(i, 0);
101                let param_term = df_dpj[i];
102                simplify(pool.add(vec![jac_term, param_term]), pool).value
103            })
104            .collect();
105
106        sens_rhs_matrix.push(col_rhs.clone());
107
108        // Append to extended system
109        for i in 0..m {
110            extended_vars.push(sens_vars[j][i]);
111            extended_derivs.push(sens_derivs[j][i]);
112            extended_rhs.push(col_rhs[i]);
113        }
114    }
115
116    Ok(SensitivitySystem {
117        extended_ode: ODE {
118            state_vars: extended_vars,
119            derivatives: extended_derivs,
120            rhs: extended_rhs,
121            time_var: ode.time_var,
122            initial_conditions: ode.initial_conditions.clone(),
123        },
124        original_dim: m,
125        n_params,
126        param_vars: params.to_vec(),
127        sensitivity_vars: sens_vars,
128    })
129}
130
131// ---------------------------------------------------------------------------
132// Adjoint sensitivity (reverse-mode)
133// ---------------------------------------------------------------------------
134
135/// Build the adjoint (reverse) sensitivity system for a scalar objective.
136///
137/// Given `ode` and a scalar objective `obj = g(y(T))`, the adjoint equations
138/// are:
139///
140/// ```text
141/// dλ/dt = -(∂f/∂y)ᵀ · λ
142/// λ(T)  = ∂g/∂y(T)
143/// ```
144///
145/// The gradient w.r.t. parameters is then:
146///
147/// ```text
148/// ∂J/∂p = ∫₀ᵀ (∂f/∂p)ᵀ · λ dt
149/// ```
150///
151/// This function returns the adjoint ODE (to integrate backward in time).
152pub fn adjoint_system(
153    ode: &ODE,
154    objective_grad: &[ExprId], // ∂g/∂y_i  at terminal time
155    pool: &ExprPool,
156) -> Result<AdjointSystem, OdeError> {
157    let m = ode.order();
158
159    // Jacobian ∂f/∂y  (m × m)
160    let jac_y = jacobian(&ode.rhs, &ode.state_vars, pool)
161        .map_err(|e| OdeError::DiffError(e.to_string()))?;
162
163    // Adjoint variables λ_i
164    let lambda: Vec<ExprId> = (0..m)
165        .map(|i| {
166            let yname = pool.with(ode.state_vars[i], |d| match d {
167                ExprData::Symbol { name, .. } => name.clone(),
168                _ => format!("y{i}"),
169            });
170            pool.symbol(format!("lambda_{yname}"), Domain::Real)
171        })
172        .collect();
173
174    let lambda_derivs: Vec<ExprId> = lambda
175        .iter()
176        .map(|&v| {
177            let name = pool.with(v, |d| match d {
178                ExprData::Symbol { name, .. } => format!("d{name}/dt"),
179                _ => "d?/dt".to_string(),
180            });
181            pool.symbol(&name, Domain::Real)
182        })
183        .collect();
184
185    // Adjoint RHS: dλ/dt = -(J_y)ᵀ · λ  (backward in time)
186    let jac_y_t = jac_y.transpose();
187    let lam_mat = Matrix::new(lambda.iter().map(|&v| vec![v]).collect()).expect("column matrix");
188    let jac_lam = jac_y_t.mul(&lam_mat, pool).expect("compatible shapes");
189
190    let neg_one = pool.integer(-1_i32);
191    let adjoint_rhs: Vec<ExprId> = (0..m)
192        .map(|i| simplify(pool.mul(vec![neg_one, jac_lam.get(i, 0)]), pool).value)
193        .collect();
194
195    // Terminal conditions: λ(T) = ∂g/∂y
196    let terminal_conditions: Vec<(ExprId, ExprId)> = lambda
197        .iter()
198        .zip(objective_grad.iter())
199        .map(|(&l, &g)| (l, g))
200        .collect();
201
202    let adjoint_ode = ODE {
203        state_vars: lambda.clone(),
204        derivatives: lambda_derivs,
205        rhs: adjoint_rhs,
206        time_var: ode.time_var,
207        initial_conditions: terminal_conditions.clone(),
208    };
209
210    Ok(AdjointSystem {
211        adjoint_ode,
212        lambda_vars: lambda,
213        terminal_conditions,
214    })
215}
216
217// ---------------------------------------------------------------------------
218// Result types
219// ---------------------------------------------------------------------------
220
221/// The extended ODE system for forward sensitivity analysis.
222#[derive(Clone, Debug)]
223pub struct SensitivitySystem {
224    /// Extended ODE: state = [y, S_0, S_1, …, S_{n-1}]
225    pub extended_ode: ODE,
226    /// Dimension of the original state
227    pub original_dim: usize,
228    /// Number of parameters
229    pub n_params: usize,
230    /// The parameter variables
231    pub param_vars: Vec<ExprId>,
232    /// `sensitivity_vars[j][i]` = ExprId for S_{i,j} = ∂y_i/∂p_j
233    pub sensitivity_vars: Vec<Vec<ExprId>>,
234}
235
236impl SensitivitySystem {
237    /// Get the sensitivity variable S_{i,j} = ∂y_i/∂p_j.
238    pub fn get_sensitivity(&self, state_idx: usize, param_idx: usize) -> ExprId {
239        self.sensitivity_vars[param_idx][state_idx]
240    }
241
242    /// Display the sensitivity system.
243    pub fn display(&self, pool: &ExprPool) -> String {
244        self.extended_ode.display(pool)
245    }
246}
247
248/// The adjoint ODE system for reverse-mode sensitivity.
249#[derive(Clone, Debug)]
250pub struct AdjointSystem {
251    /// Adjoint ODE: dλ/dt = -(∂f/∂y)ᵀ · λ, integrated backward
252    pub adjoint_ode: ODE,
253    /// Adjoint variables λ_i
254    pub lambda_vars: Vec<ExprId>,
255    /// Terminal conditions λ(T) = ∂g/∂y
256    pub terminal_conditions: Vec<(ExprId, ExprId)>,
257}
258
259// ---------------------------------------------------------------------------
260// Tests
261// ---------------------------------------------------------------------------
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};
    use crate::ode::ODE;

    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn sensitivity_linear_ode() {
        // dy/dt = a*y,  param = a
        // Sensitivity: dS/dt = a*S + y
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![a, y]);
        let ode = ODE::new(vec![y], vec![rhs], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a], &pool).unwrap();
        // Extended state: [y, S_{y,a}]
        assert_eq!(sys.extended_ode.order(), 2);
        assert_eq!(sys.extended_ode.derivatives.len(), 2);
        assert_eq!(sys.extended_ode.rhs.len(), 2);
        assert_eq!(sys.original_dim, 1);
        assert_eq!(sys.n_params, 1);
        assert_eq!(sys.param_vars, vec![a]);
        // The original equation is carried over unchanged in slot 0 ...
        assert_eq!(sys.extended_ode.state_vars[0], ode.state_vars[0]);
        assert_eq!(sys.extended_ode.rhs[0], ode.rhs[0]);
        // ... and the sensitivity symbol occupies slot 1.
        let s = sys.get_sensitivity(0, 0);
        assert_eq!(sys.extended_ode.state_vars[1], s);
        assert_ne!(s, y);
    }

    #[test]
    fn sensitivity_constant_ode() {
        // dy/dt = p  (constant RHS), param = p
        // ∂f/∂y = 0, ∂f/∂p = 1
        // dS/dt = 0 * S + 1 = 1
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let p_sym = pool.symbol("p", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![y], vec![p_sym], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[p_sym], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 2);
        assert_eq!(sys.extended_ode.rhs[0], ode.rhs[0]);
        // dS/dt should simplify to 1
        let s_rhs = sys.extended_ode.rhs[1];
        assert_eq!(s_rhs, pool.integer(1_i32));
    }

    #[test]
    fn sensitivity_two_states_layout() {
        // dy1/dt = a*y2, dy2/dt = y1,  param = a
        // Extended state: [y1, y2, S_{y1,a}, S_{y2,a}]
        let pool = p();
        let y1 = pool.symbol("y1", Domain::Real);
        let y2 = pool.symbol("y2", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![y1, y2], vec![pool.mul(vec![a, y2]), y1], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a], &pool).unwrap();
        assert_eq!(sys.extended_ode.order(), 4);
        assert_eq!(sys.original_dim, 2);
        assert_eq!(sys.extended_ode.state_vars[2], sys.get_sensitivity(0, 0));
        assert_eq!(sys.extended_ode.state_vars[3], sys.get_sensitivity(1, 0));
    }

    #[test]
    fn adjoint_system_basic() {
        // dy/dt = -y, objective ∂g/∂y = 1
        // Adjoint: dλ/dt = -(-1)*λ = λ
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        let ode = ODE::new(vec![y], vec![neg_y], t, &pool).unwrap();
        let obj_grad = vec![pool.integer(1_i32)];
        let adj = adjoint_system(&ode, &obj_grad, &pool).unwrap();
        assert_eq!(adj.adjoint_ode.order(), 1);
        // dλ/dt = λ  (Jacobian is -1, negated → 1 * λ)
        let lam = adj.lambda_vars[0];
        let rhs = adj.adjoint_ode.rhs[0];
        assert_eq!(rhs, lam);
        // λ is the adjoint state, and λ(T) = ∂g/∂y is stored as its condition.
        assert_eq!(adj.adjoint_ode.state_vars, adj.lambda_vars);
        assert_eq!(adj.terminal_conditions, vec![(lam, obj_grad[0])]);
        assert_eq!(adj.adjoint_ode.initial_conditions, adj.terminal_conditions);
    }

    #[test]
    fn sensitivity_two_params() {
        // dy/dt = a*y + b,  params = [a, b]
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let a = pool.symbol("a", Domain::Real);
        let b = pool.symbol("b", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.add(vec![pool.mul(vec![a, y]), b]);
        let ode = ODE::new(vec![y], vec![rhs], t, &pool).unwrap();
        let sys = sensitivity_system(&ode, &[a, b], &pool).unwrap();
        // Extended state: [y, S_{y,a}, S_{y,b}]
        assert_eq!(sys.extended_ode.order(), 3);
        assert_eq!(sys.n_params, 2);
        assert_eq!(sys.param_vars, vec![a, b]);
        assert_eq!(sys.sensitivity_vars.len(), 2);
        assert_eq!(sys.extended_ode.state_vars[1], sys.get_sensitivity(0, 0));
        assert_eq!(sys.extended_ode.state_vars[2], sys.get_sensitivity(0, 1));
        assert_ne!(sys.get_sensitivity(0, 0), sys.get_sensitivity(0, 1));
    }
}