Skip to main content

alkahest_cas/dae/
mod.rs

1//! Phase 17 — DAE structural analysis.
2//!
3//! Represents a Differential-Algebraic Equation (DAE) system and implements
4//! the Pantelides algorithm for structural index reduction.
5//!
6//! A DAE is a system `g_i(t, y, y') = 0` where some equations may be purely
7//! algebraic (not involving any derivative).  The *structural index* measures
8//! how many times the system must be differentiated to convert it to an ODE.
9//!
10//! # Pantelides Algorithm
11//!
12//! The algorithm finds which equations need to be differentiated and creates
13//! new equations by differentiating them, until a perfect matching between
14//! equations and variables exists.
15//!
16//! References:
17//! - Pantelides (1988) "The consistent initialization of differential-algebraic systems"
18//! - Mattsson & Söderlind (1993) "Index reduction in differential-algebraic equations"
19
20use crate::diff::diff;
21use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
22use crate::simplify::engine::simplify;
23use std::collections::HashSet;
24use std::fmt;
25
26/// Extend `(variables, derivatives)` when an equation references a derivative
27/// symbol that has not yet been promoted to the *state* vector — same rule as
28/// the Pantelides inner loop (for higher-order derivative algebra).
29pub fn extend_derivative_state_vectors(
30    variables: &mut Vec<ExprId>,
31    derivatives: &mut Vec<ExprId>,
32    new_eq: ExprId,
33    pool: &ExprPool,
34) {
35    for (j, _) in variables.clone().iter().enumerate() {
36        let deriv = derivatives[j];
37        if structurally_depends(new_eq, deriv, pool) && !variables.contains(&deriv) {
38            let d2_name = pool.with(deriv, |d| match d {
39                ExprData::Symbol { name, .. } => format!("d{name}/dt"),
40                _ => "d?/dt".to_string(),
41            });
42            let d2 = pool.symbol(&d2_name, Domain::Real);
43            variables.push(deriv);
44            derivatives.push(d2);
45        }
46    }
47}
48
49/// [`extend_derivative_state_vectors`] on [`DAE::variables`] / [`DAE::derivatives`].
50pub fn extend_dae_for_derivative_symbols(dae: &mut DAE, new_eq: ExprId, pool: &ExprPool) {
51    extend_derivative_state_vectors(&mut dae.variables, &mut dae.derivatives, new_eq, pool);
52}
53
54// ---------------------------------------------------------------------------
55// DAE type
56// ---------------------------------------------------------------------------
57
/// A DAE system `g_i(t, y, y') = 0`.
///
/// Equations are in implicit form: `g_i = 0`.
///
/// `variables` and `derivatives` are parallel vectors: `derivatives[j]` is the
/// symbol that stands for the time derivative of `variables[j]`.  The struct
/// itself does not separate algebraic from differential variables; an
/// algebraic variable is simply one whose derivative symbol appears in no
/// equation.
#[derive(Clone, Debug)]
pub struct DAE {
    /// Implicit equations `g_i(t, y, y') = 0`
    pub equations: Vec<ExprId>,
    /// All variables (algebraic and differential)
    pub variables: Vec<ExprId>,
    /// Derivative symbols: `derivatives[j]` represents `d(variables[j])/dt`
    pub derivatives: Vec<ExprId>,
    /// The independent variable
    pub time_var: ExprId,
    /// Differentiation index (`None` = not yet computed; [`DAE::new`] always
    /// starts with `None`)
    pub index: Option<usize>,
}
77
/// Errors reported by the DAE structural-analysis routines.
#[derive(Debug, Clone)]
pub enum DaeError {
    /// Symbolic differentiation of an equation failed; carries the message of
    /// the underlying error.
    DiffError(String),
    /// Index reduction did not finish within the iteration limit.
    IndexTooHigh,
    /// The system is structurally inconsistent.
    StructurallyInconsistent,
}

impl fmt::Display for DaeError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DiffError(detail) => write!(f, "differentiation error: {detail}"),
            Self::IndexTooHigh => f.write_str("DAE structural index too high (> 10)"),
            Self::StructurallyInconsistent => f.write_str("DAE is structurally inconsistent"),
        }
    }
}

impl std::error::Error for DaeError {}
96
97impl crate::errors::AlkahestError for DaeError {
98    fn code(&self) -> &'static str {
99        match self {
100            DaeError::DiffError(_) => "E-DAE-001",
101            DaeError::IndexTooHigh => "E-DAE-002",
102            DaeError::StructurallyInconsistent => "E-DAE-003",
103        }
104    }
105
106    fn remediation(&self) -> Option<&'static str> {
107        match self {
108            DaeError::DiffError(_) => Some(
109                "ensure all functions in the DAE are differentiable before calling pantelides()",
110            ),
111            DaeError::IndexTooHigh => {
112                Some("DAE index exceeds 10; reformulate the model or use an index-reduction tool")
113            }
114            DaeError::StructurallyInconsistent => Some(
115                "the DAE is structurally inconsistent; check constraint count vs. variable count",
116            ),
117        }
118    }
119}
120
121impl DAE {
122    /// Create a new DAE.
123    ///
124    /// `equations` — implicit equations `g_i = 0`
125    /// `variables` — all variables (algebraic + differential)
126    /// `derivatives` — derivative symbols for each variable
127    pub fn new(
128        equations: Vec<ExprId>,
129        variables: Vec<ExprId>,
130        derivatives: Vec<ExprId>,
131        time_var: ExprId,
132    ) -> Self {
133        DAE {
134            equations,
135            variables,
136            derivatives,
137            time_var,
138            index: None,
139        }
140    }
141
142    /// Number of equations.
143    pub fn n_equations(&self) -> usize {
144        self.equations.len()
145    }
146
147    /// Number of variables.
148    pub fn n_variables(&self) -> usize {
149        self.variables.len()
150    }
151
152    /// Build the structural incidence matrix.
153    ///
154    /// `incidence[i][j]` is `true` if equation `i` structurally depends on
155    /// variable `j` or its derivative.
156    pub fn incidence_matrix(&self, pool: &ExprPool) -> Vec<Vec<bool>> {
157        let m = self.equations.len();
158        let n = self.variables.len();
159        let mut inc = vec![vec![false; n]; m];
160        for (i, &eq) in self.equations.iter().enumerate() {
161            for (j, &var) in self.variables.iter().enumerate() {
162                let deriv = self.derivatives[j];
163                if structurally_depends(eq, var, pool) || structurally_depends(eq, deriv, pool) {
164                    inc[i][j] = true;
165                }
166            }
167        }
168        inc
169    }
170
171    /// Display the DAE.
172    pub fn display(&self, pool: &ExprPool) -> String {
173        self.equations
174            .iter()
175            .map(|&eq| format!("  {} = 0", pool.display(eq)))
176            .collect::<Vec<_>>()
177            .join("\n")
178    }
179}
180
181// ---------------------------------------------------------------------------
182// Pantelides algorithm for structural index reduction
183// ---------------------------------------------------------------------------
184
/// Result of applying the Pantelides algorithm.
#[derive(Clone, Debug)]
pub struct PantelidesResult {
    /// The index-reduced DAE (index ≤ 1).  It holds the original equations
    /// followed by every differentiated equation that was appended, and its
    /// `index` field is set to the number of passes that were needed.
    pub reduced_dae: DAE,
    /// Number of differentiation steps applied (one per appended equation)
    pub differentiation_steps: usize,
    /// Differentiation order of each equation in `reduced_dae.equations`:
    /// `0` for an original equation, and `sigma[src] + 1` for an equation
    /// obtained by differentiating equation `src`.
    pub sigma: Vec<usize>, // sigma[i] = differentiation order of equation i of the reduced system
}
195
196/// Apply the Pantelides algorithm to reduce a DAE to differentiation index ≤ 1.
197///
198/// Returns the reduced DAE together with bookkeeping information.
199pub fn pantelides(dae: &DAE, pool: &ExprPool) -> Result<PantelidesResult, DaeError> {
200    let max_iter = 10;
201
202    let mut equations = dae.equations.clone();
203    let mut variables = dae.variables.clone();
204    let mut derivatives = dae.derivatives.clone();
205    let mut sigma = vec![0usize; equations.len()];
206    let mut total_steps = 0;
207
208    for iteration in 0..max_iter {
209        // Build incidence structure
210        let n_eq = equations.len();
211        let n_var = variables.len();
212        let inc = incidence(&equations, &variables, &derivatives, pool);
213
214        // Find maximum matching using Hopcroft-Karp
215        let matching = maximum_matching(&inc, n_eq, n_var);
216
217        // Check if perfect matching exists
218        let unmatched_eqs: Vec<usize> = (0..n_eq)
219            .filter(|&i| matching.eq_to_var[i].is_none())
220            .collect();
221
222        if unmatched_eqs.is_empty() {
223            // Perfect matching found → index ≤ 1
224            let mut reduced = DAE::new(equations, variables, derivatives, dae.time_var);
225            reduced.index = Some(iteration);
226            return Ok(PantelidesResult {
227                reduced_dae: reduced,
228                differentiation_steps: total_steps,
229                sigma,
230            });
231        }
232
233        // Differentiate unmatched equations
234        for &eq_idx in &unmatched_eqs {
235            let new_eq = differentiate_equation(
236                equations[eq_idx],
237                &variables,
238                &derivatives,
239                dae.time_var,
240                pool,
241            )
242            .map_err(|e| DaeError::DiffError(e.to_string()))?;
243            equations.push(new_eq);
244            sigma.push(sigma[eq_idx] + 1);
245            total_steps += 1;
246
247            extend_derivative_state_vectors(&mut variables, &mut derivatives, new_eq, pool);
248        }
249    }
250
251    Err(DaeError::IndexTooHigh)
252}
253
254// ---------------------------------------------------------------------------
255// Helpers
256// ---------------------------------------------------------------------------
257
/// A bipartite matching between equations and variables, stored in both
/// directions.
struct Matching {
    /// `eq_to_var[i]` is the variable index matched to equation `i`, or
    /// `None` if equation `i` is unmatched.
    eq_to_var: Vec<Option<usize>>,
    /// `var_to_eq[j]` is the equation index matched to variable `j`, or
    /// `None` if variable `j` is unmatched.  Not read by the code in this
    /// file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    var_to_eq: Vec<Option<usize>>,
}
263
264/// Build an incidence list: `result[i]` = set of variable indices that equation `i` depends on.
265fn incidence(
266    equations: &[ExprId],
267    variables: &[ExprId],
268    derivatives: &[ExprId],
269    pool: &ExprPool,
270) -> Vec<Vec<usize>> {
271    equations
272        .iter()
273        .map(|&eq| {
274            variables
275                .iter()
276                .zip(derivatives.iter())
277                .enumerate()
278                .filter(|(_, (&var, &deriv))| {
279                    structurally_depends(eq, var, pool) || structurally_depends(eq, deriv, pool)
280                })
281                .map(|(j, _)| j)
282                .collect()
283        })
284        .collect()
285}
286
/// Depth-first search for an augmenting path starting at equation `eq`.
///
/// * `adj[e]` lists the variable indices equation `e` may be matched to.
/// * `var_to_eq[v]` is the equation currently matched to variable `v`.
/// * `visited` records the variables already tried during this search so
///   that each is explored at most once.
///
/// Returns `true` if `eq` was matched, possibly after re-routing equations
/// matched earlier to other variables; in that case `var_to_eq` is updated
/// along the path.  Returns `false` and leaves `var_to_eq` unchanged when no
/// augmenting path exists.
///
/// # Panics
///
/// Panics if `eq` is out of range for `adj`, or if `adj` mentions a variable
/// index that is out of range for `var_to_eq`.
fn augment(
    eq: usize,
    adj: &[Vec<usize>],
    var_to_eq: &mut [Option<usize>],
    visited: &mut HashSet<usize>,
) -> bool {
    for &var in &adj[eq] {
        // `insert` returns `false` when the variable was already visited.
        if !visited.insert(var) {
            continue;
        }

        // The variable is usable if it is free, or if its current owner can
        // be moved to some other variable.
        let owner = var_to_eq[var];
        let available = match owner {
            None => true,
            Some(other_eq) => augment(other_eq, adj, var_to_eq, visited),
        };

        if available {
            var_to_eq[var] = Some(eq);
            return true;
        }
    }
    false
}
306
307fn maximum_matching(adj: &[Vec<usize>], n_eq: usize, n_var: usize) -> Matching {
308    let mut var_to_eq: Vec<Option<usize>> = vec![None; n_var];
309    for eq in 0..n_eq {
310        let mut visited = HashSet::new();
311        augment(eq, adj, &mut var_to_eq, &mut visited);
312    }
313    let mut eq_to_var = vec![None; n_eq];
314    for (var, &opt_eq) in var_to_eq.iter().enumerate() {
315        if let Some(eq) = opt_eq {
316            eq_to_var[eq] = Some(var);
317        }
318    }
319    Matching {
320        eq_to_var,
321        var_to_eq,
322    }
323}
324
325/// Differentiate `equation` with respect to time, using `d(var)/dt = deriv`.
326pub(crate) fn differentiate_equation(
327    equation: ExprId,
328    variables: &[ExprId],
329    derivatives: &[ExprId],
330    time_var: ExprId,
331    pool: &ExprPool,
332) -> Result<ExprId, crate::diff::diff_impl::DiffError> {
333    // d(g)/dt = Σ_i (∂g/∂y_i) * (dy_i/dt)  +  ∂g/∂t
334    // Use chain rule symbolically
335    let mut terms: Vec<ExprId> = Vec::new();
336
337    // ∂g/∂t
338    let dg_dt = diff(equation, time_var, pool)?.value;
339    if dg_dt != pool.integer(0_i32) {
340        terms.push(dg_dt);
341    }
342
343    // For each variable y_i: (∂g/∂y_i) * (dy_i/dt)
344    for (&var, &deriv) in variables.iter().zip(derivatives.iter()) {
345        let dg_dyi = diff(equation, var, pool)?.value;
346        if dg_dyi != pool.integer(0_i32) {
347            let term = pool.mul(vec![dg_dyi, deriv]);
348            terms.push(term);
349        }
350        // Also differentiate w.r.t. the derivative (for higher-index terms)
351        let dg_ddyi = diff(equation, deriv, pool)?.value;
352        if dg_ddyi != pool.integer(0_i32) {
353            // d(dy_i/dt)/dt is a new symbol — use the naming convention
354            let d2_name = pool.with(deriv, |d| match d {
355                ExprData::Symbol { name, .. } => format!("d{name}/dt"),
356                _ => "d?/dt".to_string(),
357            });
358            let d2 = pool.symbol(&d2_name, Domain::Real);
359            let term = pool.mul(vec![dg_ddyi, d2]);
360            terms.push(term);
361        }
362    }
363
364    let result = match terms.len() {
365        0 => pool.integer(0_i32),
366        1 => terms[0],
367        _ => pool.add(terms),
368    };
369    Ok(simplify(result, pool).value)
370}
371
372/// True if `expr` structurally contains `var` as a sub-expression.
373pub fn structurally_depends(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
374    if expr == var {
375        return true;
376    }
377    let children = pool.with(expr, |data| match data {
378        ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => args.clone(),
379        ExprData::Pow { base, exp } => vec![*base, *exp],
380        ExprData::BigO(inner) => vec![*inner],
381        _ => vec![],
382    });
383    children
384        .into_iter()
385        .any(|c| structurally_depends(c, var, pool))
386}
387
388// ---------------------------------------------------------------------------
389// Tests
390// ---------------------------------------------------------------------------
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool for a test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn ode_is_index_0() {
        // An explicit ODE y' - f(y) = 0  has differentiation index 0 (or 1 in some conventions)
        let pool = p();
        let y = pool.symbol("y", Domain::Real);
        let dy = pool.symbol("dy/dt", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        // Equation: dy/dt - y = 0  →  dy - y
        let eq = pool.add(vec![dy, neg_y]);
        let dae = DAE::new(vec![eq], vec![y], vec![dy], t);
        let result = pantelides(&dae, &pool).unwrap();
        // The single equation is matched on the first pass, so nothing is
        // differentiated and the bookkeeping stays at its initial values.
        assert_eq!(result.differentiation_steps, 0);
        assert_eq!(result.reduced_dae.index, Some(0));
        assert_eq!(result.sigma, vec![0]);
    }

    #[test]
    fn incidence_matrix_correct() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let dx = pool.symbol("dx/dt", Domain::Real);
        let dy = pool.symbol("dy/dt", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        // g1 = x + y,  g2 = dx + y
        let g1 = pool.add(vec![x, y]);
        let g2 = pool.add(vec![dx, y]);
        let dae = DAE::new(vec![g1, g2], vec![x, y], vec![dx, dy], t);
        let inc = dae.incidence_matrix(&pool);
        // g1 depends on x (j=0) and y (j=1)
        assert!(inc[0][0]);
        assert!(inc[0][1]);
        // g2 depends on dx (structurally related to j=0) and y (j=1)
        assert!(inc[1][0]); // dx is deriv of x
        assert!(inc[1][1]); // y
    }

    #[test]
    fn structurally_depends_nested() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let sin_x = pool.func("sin", vec![x]);
        let expr = pool.add(vec![sin_x, y]);
        assert!(structurally_depends(expr, x, &pool));
        assert!(structurally_depends(expr, y, &pool));
    }

    #[test]
    fn differentiate_equation_linear() {
        // g(x, y) = x + y,  variables = [x, y], derivatives = [dx, dy]
        // dg/dt = dx + dy
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let dx = pool.symbol("dx/dt", Domain::Real);
        let dy = pool.symbol("dy/dt", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let eq = pool.add(vec![x, y]);
        let result = differentiate_equation(eq, &[x, y], &[dx, dy], t, &pool).unwrap();
        // Should give dx + dy (in some order): both terms must be present,
        // so require both rather than either.
        let s = pool.display(result).to_string();
        assert!(s.contains("dx") && s.contains("dy"), "got: {s}");
    }
}
464}