Skip to main content

alkahest_cas/ode/
mod.rs

1//! Phase 16 — ODE representation and manipulation.
2//!
3//! Provides the `ODE` type for first-order systems dy/dt = f(t, y) and
4//! helpers to lower higher-order ODEs to first-order systems.
5//!
6//! Phase 19 sensitivity analysis is also implemented here as
7//! `sensitivity_system`.
8
9pub mod sensitivity;
10
11use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
12use crate::simplify::engine::simplify;
13use std::fmt;
14
15// ---------------------------------------------------------------------------
16// Error type
17// ---------------------------------------------------------------------------
18
/// Errors produced while building or transforming an ODE system.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OdeError {
    /// The number of state variables differs from the number of RHS expressions.
    VariableCountMismatch,
    /// The operation needs a system of order at least one (returned by
    /// `lower_to_first_order` for an order-0 scalar ODE).
    NotFirstOrder,
    /// Symbolic differentiation failed; carries a human-readable message.
    DiffError(String),
}

impl fmt::Display for OdeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::VariableCountMismatch => "variable and RHS count mismatch",
            Self::NotFirstOrder => "ODE is not first-order",
            // The only variant with a payload needs real formatting.
            Self::DiffError(msg) => return write!(f, "differentiation error: {msg}"),
        };
        f.write_str(text)
    }
}

impl std::error::Error for OdeError {}
37
38impl crate::errors::AlkahestError for OdeError {
39    fn code(&self) -> &'static str {
40        match self {
41            OdeError::VariableCountMismatch => "E-ODE-001",
42            OdeError::NotFirstOrder => "E-ODE-002",
43            OdeError::DiffError(_) => "E-ODE-003",
44        }
45    }
46
47    fn remediation(&self) -> Option<&'static str> {
48        match self {
49            OdeError::VariableCountMismatch => Some(
50                "the number of state variables must equal the number of right-hand-side expressions",
51            ),
52            OdeError::NotFirstOrder => Some(
53                "use lower_to_first_order() to reduce higher-order ODEs to first-order form",
54            ),
55            OdeError::DiffError(_) => Some(
56                "check that all functions in the ODE are differentiable; unknown functions block lowering",
57            ),
58        }
59    }
60}
61
62// ---------------------------------------------------------------------------
63// ODE: first-order system dy/dt = f(t, y)
64// ---------------------------------------------------------------------------
65
66/// A first-order ODE system `dy_i/dt = rhs_i(t, y)`.
67///
68/// Invariants:
69/// - `state_vars.len() == derivatives.len() == rhs.len()`
70/// - `derivatives[i]` is a `Symbol` representing `d(state_vars[i])/dt`
71/// - All expressions live in the same pool
72#[derive(Clone, Debug)]
73pub struct ODE {
74    /// State variables `y_0, y_1, …`
75    pub state_vars: Vec<ExprId>,
76    /// Derivative symbols `dy_0/dt, dy_1/dt, …`
77    pub derivatives: Vec<ExprId>,
78    /// Right-hand-side expressions `f_i(t, y)`
79    pub rhs: Vec<ExprId>,
80    /// The independent variable (usually `t`)
81    pub time_var: ExprId,
82    /// Initial conditions: `(var, value)` pairs
83    pub initial_conditions: Vec<(ExprId, ExprId)>,
84}
85
86impl ODE {
87    /// Construct a first-order system directly.
88    ///
89    /// `state_vars` — the state `y_i`
90    /// `rhs`        — the right-hand sides `f_i(t, y)`
91    /// `time_var`   — the independent variable `t`
92    ///
93    /// Derivative symbols `d(y_i)/dt` are created automatically with the
94    /// naming convention `d{name}/dt`.
95    pub fn new(
96        state_vars: Vec<ExprId>,
97        rhs: Vec<ExprId>,
98        time_var: ExprId,
99        pool: &ExprPool,
100    ) -> Result<Self, OdeError> {
101        if state_vars.len() != rhs.len() {
102            return Err(OdeError::VariableCountMismatch);
103        }
104        let derivatives: Vec<ExprId> = state_vars
105            .iter()
106            .map(|&v| {
107                let name = pool.with(v, |d| match d {
108                    ExprData::Symbol { name, .. } => format!("d{name}/dt"),
109                    _ => "d?/dt".to_string(),
110                });
111                pool.symbol(&name, Domain::Real)
112            })
113            .collect();
114        Ok(ODE {
115            state_vars,
116            derivatives,
117            rhs,
118            time_var,
119            initial_conditions: vec![],
120        })
121    }
122
123    /// Add an initial condition `var = value`.
124    pub fn with_ic(mut self, var: ExprId, value: ExprId) -> Self {
125        self.initial_conditions.push((var, value));
126        self
127    }
128
129    /// Number of state variables.
130    pub fn order(&self) -> usize {
131        self.state_vars.len()
132    }
133
134    /// Return `true` if `t` does not appear in any RHS expression.
135    pub fn is_autonomous(&self, pool: &ExprPool) -> bool {
136        self.rhs
137            .iter()
138            .all(|&rhs| !contains(rhs, self.time_var, pool))
139    }
140
141    /// Simplify all RHS expressions in place.
142    pub fn simplify_rhs(&self, pool: &ExprPool) -> ODE {
143        let rhs: Vec<ExprId> = self.rhs.iter().map(|&r| simplify(r, pool).value).collect();
144        ODE {
145            state_vars: self.state_vars.clone(),
146            derivatives: self.derivatives.clone(),
147            rhs,
148            time_var: self.time_var,
149            initial_conditions: self.initial_conditions.clone(),
150        }
151    }
152
153    /// Display the system as a sequence of equations.
154    pub fn display(&self, pool: &ExprPool) -> String {
155        let mut lines: Vec<String> = self
156            .derivatives
157            .iter()
158            .zip(self.rhs.iter())
159            .map(|(&d, &r)| format!("  {} = {}", pool.display(d), pool.display(r)))
160            .collect();
161        for (v, val) in &self.initial_conditions {
162            lines.push(format!(
163                "  {}(0) = {}",
164                pool.display(*v),
165                pool.display(*val)
166            ));
167        }
168        lines.join("\n")
169    }
170}
171
172// ---------------------------------------------------------------------------
173// Higher-order ODE lowering
174// ---------------------------------------------------------------------------
175
176/// A higher-order scalar ODE `x^(n) = f(t, x, x', …, x^(n-1))`.
177pub struct ScalarODE {
178    /// The original variable `x`
179    pub var: ExprId,
180    /// `[x, x', x'', …, x^(n-1)]` — state symbols (created by `lower`)
181    pub aux_vars: Vec<ExprId>,
182    /// The highest-order RHS: `f(t, x, x', …, x^(n-1))`
183    pub rhs: ExprId,
184    /// Independent variable
185    pub time_var: ExprId,
186    /// Order of the ODE
187    pub order: usize,
188}
189
190/// Lower a higher-order scalar ODE to a first-order system by introducing
191/// auxiliary variables for each derivative.
192///
193/// For an `n`-th order ODE `x^(n) = f(t, x, x', …, x^(n-1))` the result is:
194///
195/// ```text
196/// dy_0/dt = y_1
197/// dy_1/dt = y_2
198/// …
199/// dy_{n-2}/dt = y_{n-1}
200/// dy_{n-1}/dt = f(t, y_0, y_1, …, y_{n-1})
201/// ```
202pub fn lower_to_first_order(scalar_ode: &ScalarODE, pool: &ExprPool) -> Result<ODE, OdeError> {
203    let n = scalar_ode.order;
204    if n == 0 {
205        return Err(OdeError::NotFirstOrder);
206    }
207    if n == 1 {
208        // Already first-order
209        return ODE::new(
210            vec![scalar_ode.var],
211            vec![scalar_ode.rhs],
212            scalar_ode.time_var,
213            pool,
214        );
215    }
216
217    // Create auxiliary variables y_0 = x, y_1 = x', …, y_{n-1} = x^{(n-1)}
218    let var_name = pool.with(scalar_ode.var, |d| match d {
219        ExprData::Symbol { name, .. } => name.clone(),
220        _ => "x".to_string(),
221    });
222    let aux: Vec<ExprId> = (0..n)
223        .map(|i| {
224            let suffix = if i == 0 {
225                var_name.clone()
226            } else {
227                format!("{var_name}_{i}")
228            };
229            pool.symbol(&suffix, Domain::Real)
230        })
231        .collect();
232
233    // Build RHS: dy_i/dt = y_{i+1} for i < n-1, and dy_{n-1}/dt = rhs
234    let mut rhs_vec: Vec<ExprId> = (0..n - 1).map(|i| aux[i + 1]).collect();
235    rhs_vec.push(scalar_ode.rhs);
236
237    ODE::new(aux, rhs_vec, scalar_ode.time_var, pool)
238}
239
240// ---------------------------------------------------------------------------
241// Helper: does `expr` contain `needle` as a sub-expression?
242// ---------------------------------------------------------------------------
243
244fn contains(expr: ExprId, needle: ExprId, pool: &ExprPool) -> bool {
245    if expr == needle {
246        return true;
247    }
248    let children = pool.with(expr, |data| match data {
249        ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => args.clone(),
250        ExprData::Pow { base, exp } => vec![*base, *exp],
251        _ => vec![],
252    });
253    children.into_iter().any(|c| contains(c, needle, pool))
254}
255
256// ---------------------------------------------------------------------------
257// Tests
258// ---------------------------------------------------------------------------
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::ExprPool;

    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Name of a `Symbol` node, or `"?"` for any other node kind.
    fn symbol_name(pool: &ExprPool, id: ExprId) -> String {
        pool.with(id, |d| match d {
            ExprData::Symbol { name, .. } => name.clone(),
            _ => "?".to_string(),
        })
    }

    #[test]
    fn ode_new_simple() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        // dx/dt = x
        let ode = ODE::new(vec![x], vec![x], t, &pool).unwrap();
        assert_eq!(ode.order(), 1);
        assert!(ode.is_autonomous(&pool));
        assert_eq!(symbol_name(&pool, ode.derivatives[0]), "dx/dt");
        assert!(ode.initial_conditions.is_empty());
    }

    #[test]
    fn ode_is_not_autonomous_with_t() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        // dx/dt = t*x (not autonomous)
        let rhs = pool.mul(vec![t, x]);
        let ode = ODE::new(vec![x], vec![rhs], t, &pool).unwrap();
        assert!(!ode.is_autonomous(&pool));
    }

    #[test]
    fn ode_mismatch_error() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        // 2 vars, 1 rhs — error
        let result = ODE::new(vec![x, y], vec![x], t, &pool);
        assert_eq!(result.unwrap_err(), OdeError::VariableCountMismatch);
    }

    #[test]
    fn lower_second_order() {
        // x'' = -x  (harmonic oscillator)
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![pool.integer(-1_i32), x]);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs,
            time_var: t,
            order: 2,
        };
        let sys = lower_to_first_order(&scalar, &pool).unwrap();
        assert_eq!(sys.order(), 2);
        // dy_0/dt = y_1, where y_1 is the auxiliary symbol x_1 ...
        assert_eq!(symbol_name(&pool, sys.rhs[0]), "x_1");
        assert_eq!(sys.rhs[0], sys.state_vars[1]);
        // ... and dy_1/dt is the original RHS, untouched.
        assert_eq!(sys.rhs[1], rhs);
    }

    #[test]
    fn lower_first_order_is_passthrough() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let rhs = pool.mul(vec![t, x]);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs,
            time_var: t,
            order: 1,
        };
        let sys = lower_to_first_order(&scalar, &pool).unwrap();
        assert_eq!(sys.state_vars, vec![x]);
        assert_eq!(sys.rhs, vec![rhs]);
        assert_eq!(sys.time_var, t);
    }

    #[test]
    fn lower_zero_order_is_rejected() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let scalar = ScalarODE {
            var: x,
            aux_vars: vec![],
            rhs: x,
            time_var: t,
            order: 0,
        };
        assert_eq!(
            lower_to_first_order(&scalar, &pool).unwrap_err(),
            OdeError::NotFirstOrder
        );
    }

    #[test]
    fn ode_display() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let ode = ODE::new(vec![x], vec![x], t, &pool).unwrap();
        let s = ode.display(&pool);
        let expected = format!("  {} = {}", pool.display(ode.derivatives[0]), pool.display(x));
        assert_eq!(s, expected);
    }

    #[test]
    fn ode_display_includes_initial_conditions() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let one = pool.integer(1_i32);
        let ode = ODE::new(vec![x], vec![x], t, &pool)
            .unwrap()
            .with_ic(x, one);
        let s = ode.display(&pool);
        let lines: Vec<&str> = s.split('\n').collect();
        assert_eq!(lines.len(), 2);
        let expected_ic = format!("  {}(0) = {}", pool.display(x), pool.display(one));
        assert_eq!(lines[1], expected_ic.as_str());
    }

    #[test]
    fn ode_with_ic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let one = pool.integer(1_i32);
        let ode = ODE::new(vec![x], vec![x], t, &pool)
            .unwrap()
            .with_ic(x, one);
        assert_eq!(ode.initial_conditions.len(), 1);
        assert_eq!(ode.initial_conditions[0], (x, one));
    }

    #[test]
    fn ode_simplify_rhs() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let zero = pool.integer(0_i32);
        // rhs = x + 0  → should simplify to x
        let rhs = pool.add(vec![x, zero]);
        let ode = ODE::new(vec![x], vec![rhs], t, &pool).unwrap();
        let simplified = ode.simplify_rhs(&pool);
        assert_eq!(simplified.rhs[0], x);
        // The receiver is left untouched: simplify_rhs returns a new system.
        assert_eq!(ode.rhs[0], rhs);
        assert_eq!(simplified.derivatives, ode.derivatives);
    }
}