1use crate::diff::diff;
21use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
22use crate::simplify::engine::simplify;
23use std::collections::HashSet;
24use std::fmt;
25
26pub fn extend_derivative_state_vectors(
30 variables: &mut Vec<ExprId>,
31 derivatives: &mut Vec<ExprId>,
32 new_eq: ExprId,
33 pool: &ExprPool,
34) {
35 for (j, _) in variables.clone().iter().enumerate() {
36 let deriv = derivatives[j];
37 if structurally_depends(new_eq, deriv, pool) && !variables.contains(&deriv) {
38 let d2_name = pool.with(deriv, |d| match d {
39 ExprData::Symbol { name, .. } => format!("d{name}/dt"),
40 _ => "d?/dt".to_string(),
41 });
42 let d2 = pool.symbol(&d2_name, Domain::Real);
43 variables.push(deriv);
44 derivatives.push(d2);
45 }
46 }
47}
48
49pub fn extend_dae_for_derivative_symbols(dae: &mut DAE, new_eq: ExprId, pool: &ExprPool) {
51 extend_derivative_state_vectors(&mut dae.variables, &mut dae.derivatives, new_eq, pool);
52}
53
/// A differential-algebraic system stored as residual expressions, each read as `expr = 0`.
///
/// `variables` and `derivatives` are parallel vectors: `derivatives[j]` is the symbol that
/// stands for the time derivative of `variables[j]`.
#[derive(Clone, Debug)]
pub struct DAE {
    /// Residual expressions; `display` renders each as `expr = 0`.
    pub equations: Vec<ExprId>,
    /// Unknowns of the system.
    pub variables: Vec<ExprId>,
    /// Time-derivative symbols, index-aligned with `variables`.
    pub derivatives: Vec<ExprId>,
    /// Independent variable that equations are differentiated with respect to.
    pub time_var: ExprId,
    /// `None` when built with `DAE::new`; `pantelides` sets it to the number of
    /// differentiation rounds it performed.
    pub index: Option<usize>,
}
77
/// Errors produced by DAE index reduction.
#[derive(Debug, Clone)]
pub enum DaeError {
    /// Symbolic differentiation of an equation failed; carries the underlying error text.
    DiffError(String),
    /// `pantelides` exhausted its differentiation-round budget with equations still unmatched.
    IndexTooHigh,
    /// The system is structurally inconsistent (not constructed by the code in this section).
    StructurallyInconsistent,
}
84
85impl fmt::Display for DaeError {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 match self {
88 DaeError::DiffError(s) => write!(f, "differentiation error: {s}"),
89 DaeError::IndexTooHigh => write!(f, "DAE structural index too high (> 10)"),
90 DaeError::StructurallyInconsistent => write!(f, "DAE is structurally inconsistent"),
91 }
92 }
93}
94
95impl std::error::Error for DaeError {}
96
97impl crate::errors::AlkahestError for DaeError {
98 fn code(&self) -> &'static str {
99 match self {
100 DaeError::DiffError(_) => "E-DAE-001",
101 DaeError::IndexTooHigh => "E-DAE-002",
102 DaeError::StructurallyInconsistent => "E-DAE-003",
103 }
104 }
105
106 fn remediation(&self) -> Option<&'static str> {
107 match self {
108 DaeError::DiffError(_) => Some(
109 "ensure all functions in the DAE are differentiable before calling pantelides()",
110 ),
111 DaeError::IndexTooHigh => {
112 Some("DAE index exceeds 10; reformulate the model or use an index-reduction tool")
113 }
114 DaeError::StructurallyInconsistent => Some(
115 "the DAE is structurally inconsistent; check constraint count vs. variable count",
116 ),
117 }
118 }
119}
120
121impl DAE {
122 pub fn new(
128 equations: Vec<ExprId>,
129 variables: Vec<ExprId>,
130 derivatives: Vec<ExprId>,
131 time_var: ExprId,
132 ) -> Self {
133 DAE {
134 equations,
135 variables,
136 derivatives,
137 time_var,
138 index: None,
139 }
140 }
141
142 pub fn n_equations(&self) -> usize {
144 self.equations.len()
145 }
146
147 pub fn n_variables(&self) -> usize {
149 self.variables.len()
150 }
151
152 pub fn incidence_matrix(&self, pool: &ExprPool) -> Vec<Vec<bool>> {
157 let m = self.equations.len();
158 let n = self.variables.len();
159 let mut inc = vec![vec![false; n]; m];
160 for (i, &eq) in self.equations.iter().enumerate() {
161 for (j, &var) in self.variables.iter().enumerate() {
162 let deriv = self.derivatives[j];
163 if structurally_depends(eq, var, pool) || structurally_depends(eq, deriv, pool) {
164 inc[i][j] = true;
165 }
166 }
167 }
168 inc
169 }
170
171 pub fn display(&self, pool: &ExprPool) -> String {
173 self.equations
174 .iter()
175 .map(|&eq| format!(" {} = 0", pool.display(eq)))
176 .collect::<Vec<_>>()
177 .join("\n")
178 }
179}
180
/// Output of `pantelides`.
#[derive(Clone, Debug)]
pub struct PantelidesResult {
    /// The augmented system (original plus differentiated equations, extended variables),
    /// with `index` set to the number of differentiation rounds performed.
    pub reduced_dae: DAE,
    /// Total number of individual equation differentiations across all rounds.
    pub differentiation_steps: usize,
    /// Per-equation differentiation order, parallel to `reduced_dae.equations`.
    /// Original equations have 0; a differentiated equation has its source's value plus 1.
    pub sigma: Vec<usize>,
}
195
196pub fn pantelides(dae: &DAE, pool: &ExprPool) -> Result<PantelidesResult, DaeError> {
200 let max_iter = 10;
201
202 let mut equations = dae.equations.clone();
203 let mut variables = dae.variables.clone();
204 let mut derivatives = dae.derivatives.clone();
205 let mut sigma = vec![0usize; equations.len()];
206 let mut total_steps = 0;
207
208 for iteration in 0..max_iter {
209 let n_eq = equations.len();
211 let n_var = variables.len();
212 let inc = incidence(&equations, &variables, &derivatives, pool);
213
214 let matching = maximum_matching(&inc, n_eq, n_var);
216
217 let unmatched_eqs: Vec<usize> = (0..n_eq)
219 .filter(|&i| matching.eq_to_var[i].is_none())
220 .collect();
221
222 if unmatched_eqs.is_empty() {
223 let mut reduced = DAE::new(equations, variables, derivatives, dae.time_var);
225 reduced.index = Some(iteration);
226 return Ok(PantelidesResult {
227 reduced_dae: reduced,
228 differentiation_steps: total_steps,
229 sigma,
230 });
231 }
232
233 for &eq_idx in &unmatched_eqs {
235 let new_eq = differentiate_equation(
236 equations[eq_idx],
237 &variables,
238 &derivatives,
239 dae.time_var,
240 pool,
241 )
242 .map_err(|e| DaeError::DiffError(e.to_string()))?;
243 equations.push(new_eq);
244 sigma.push(sigma[eq_idx] + 1);
245 total_steps += 1;
246
247 extend_derivative_state_vectors(&mut variables, &mut derivatives, new_eq, pool);
248 }
249 }
250
251 Err(DaeError::IndexTooHigh)
252}
253
/// Result of a maximum bipartite matching between equations and variables.
struct Matching {
    /// `eq_to_var[i]` is the variable matched to equation `i`, or `None` if unmatched.
    eq_to_var: Vec<Option<usize>>,
    // Inverse view (`var_to_eq[j]` = equation matched to variable `j`). It is kept for
    // diagnostics and callers, but is not read by `pantelides`, hence the lint allowance.
    #[allow(dead_code)]
    var_to_eq: Vec<Option<usize>>,
}
263
264fn incidence(
266 equations: &[ExprId],
267 variables: &[ExprId],
268 derivatives: &[ExprId],
269 pool: &ExprPool,
270) -> Vec<Vec<usize>> {
271 equations
272 .iter()
273 .map(|&eq| {
274 variables
275 .iter()
276 .zip(derivatives.iter())
277 .enumerate()
278 .filter(|(_, (&var, &deriv))| {
279 structurally_depends(eq, var, pool) || structurally_depends(eq, deriv, pool)
280 })
281 .map(|(j, _)| j)
282 .collect()
283 })
284 .collect()
285}
286
/// Searches for an augmenting path from equation `eq` (one step of Kuhn's algorithm).
///
/// `adj[eq]` lists the variables equation `eq` may be matched to, and `var_to_eq[v]` holds
/// the equation currently matched to variable `v`. Variables already in `visited` are
/// skipped, and every explored variable is added to it, so callers should start each root
/// search with an empty set.
///
/// Returns `true` and re-routes the matching along the path when a free variable is
/// reachable. Otherwise returns `false` and leaves `var_to_eq` untouched.
fn augment(
    eq: usize,
    adj: &[Vec<usize>],
    var_to_eq: &mut [Option<usize>],
    visited: &mut HashSet<usize>,
) -> bool {
    for &var in &adj[eq] {
        // `insert` returns false when `var` was already explored during this search.
        if !visited.insert(var) {
            continue;
        }
        let can_take = match var_to_eq[var] {
            None => true,
            // Occupied: try to move the current owner onto another variable.
            Some(owner) => augment(owner, adj, var_to_eq, visited),
        };
        if can_take {
            var_to_eq[var] = Some(eq);
            return true;
        }
    }
    false
}
306
307fn maximum_matching(adj: &[Vec<usize>], n_eq: usize, n_var: usize) -> Matching {
308 let mut var_to_eq: Vec<Option<usize>> = vec![None; n_var];
309 for eq in 0..n_eq {
310 let mut visited = HashSet::new();
311 augment(eq, adj, &mut var_to_eq, &mut visited);
312 }
313 let mut eq_to_var = vec![None; n_eq];
314 for (var, &opt_eq) in var_to_eq.iter().enumerate() {
315 if let Some(eq) = opt_eq {
316 eq_to_var[eq] = Some(var);
317 }
318 }
319 Matching {
320 eq_to_var,
321 var_to_eq,
322 }
323}
324
325pub(crate) fn differentiate_equation(
327 equation: ExprId,
328 variables: &[ExprId],
329 derivatives: &[ExprId],
330 time_var: ExprId,
331 pool: &ExprPool,
332) -> Result<ExprId, crate::diff::diff_impl::DiffError> {
333 let mut terms: Vec<ExprId> = Vec::new();
336
337 let dg_dt = diff(equation, time_var, pool)?.value;
339 if dg_dt != pool.integer(0_i32) {
340 terms.push(dg_dt);
341 }
342
343 for (&var, &deriv) in variables.iter().zip(derivatives.iter()) {
345 let dg_dyi = diff(equation, var, pool)?.value;
346 if dg_dyi != pool.integer(0_i32) {
347 let term = pool.mul(vec![dg_dyi, deriv]);
348 terms.push(term);
349 }
350 let dg_ddyi = diff(equation, deriv, pool)?.value;
352 if dg_ddyi != pool.integer(0_i32) {
353 let d2_name = pool.with(deriv, |d| match d {
355 ExprData::Symbol { name, .. } => format!("d{name}/dt"),
356 _ => "d?/dt".to_string(),
357 });
358 let d2 = pool.symbol(&d2_name, Domain::Real);
359 let term = pool.mul(vec![dg_ddyi, d2]);
360 terms.push(term);
361 }
362 }
363
364 let result = match terms.len() {
365 0 => pool.integer(0_i32),
366 1 => terms[0],
367 _ => pool.add(terms),
368 };
369 Ok(simplify(result, pool).value)
370}
371
372pub fn structurally_depends(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
374 if expr == var {
375 return true;
376 }
377 let children = pool.with(expr, |data| match data {
378 ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => args.clone(),
379 ExprData::Pow { base, exp } => vec![*base, *exp],
380 ExprData::BigO(inner) => vec![*inner],
381 _ => vec![],
382 });
383 children
384 .into_iter()
385 .any(|c| structurally_depends(c, var, pool))
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool for each test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn ode_is_index_0() {
        let pool = p();
        // dy/dt - y = 0: an explicit ODE, so no differentiation is needed.
        let y = pool.symbol("y", Domain::Real);
        let dy = pool.symbol("dy/dt", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        let eq = pool.add(vec![dy, neg_y]);
        let dae = DAE::new(vec![eq], vec![y], vec![dy], t);
        let result = pantelides(&dae, &pool).unwrap();
        assert_eq!(result.differentiation_steps, 0);
        assert_eq!(result.reduced_dae.index, Some(0));
        assert_eq!(result.sigma, vec![0]);
    }

    #[test]
    fn incidence_matrix_correct() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let dx = pool.symbol("dx/dt", Domain::Real);
        let dy = pool.symbol("dy/dt", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let g1 = pool.add(vec![x, y]);
        // g2 mentions x only through its derivative dx/dt, which still counts for column x.
        let g2 = pool.add(vec![dx, y]);
        let dae = DAE::new(vec![g1, g2], vec![x, y], vec![dx, dy], t);
        let inc = dae.incidence_matrix(&pool);
        assert!(inc[0][0]);
        assert!(inc[0][1]);
        assert!(inc[1][0]);
        assert!(inc[1][1]);
    }

    #[test]
    fn structurally_depends_nested() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let sin_x = pool.func("sin", vec![x]);
        let expr = pool.add(vec![sin_x, y]);
        assert!(structurally_depends(expr, x, &pool));
        assert!(structurally_depends(expr, y, &pool));
    }

    #[test]
    fn differentiate_equation_linear() {
        let pool = p();
        // d/dt (x + y) should involve the derivative symbols of x and y.
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let dx = pool.symbol("dx/dt", Domain::Real);
        let dy = pool.symbol("dy/dt", Domain::Real);
        let t = pool.symbol("t", Domain::Real);
        let eq = pool.add(vec![x, y]);
        let result = differentiate_equation(eq, &[x, y], &[dx, dy], t, &pool).unwrap();
        let s = pool.display(result).to_string();
        assert!(s.contains("dx") || s.contains("dy"), "got: {s}");
    }

    #[test]
    fn maximum_matching_follows_augmenting_path() {
        // Equation 1 can only use variable 0, so equation 0 must be moved to variable 1.
        let adj = vec![vec![0, 1], vec![0]];
        let m = maximum_matching(&adj, 2, 2);
        assert_eq!(m.eq_to_var, vec![Some(1), Some(0)]);
    }

    #[test]
    fn maximum_matching_reports_unmatched_equation() {
        // Two equations compete for variable 0; one of them must stay unmatched.
        let adj = vec![vec![0], vec![0], vec![1]];
        let m = maximum_matching(&adj, 3, 2);
        assert_eq!(m.eq_to_var, vec![Some(0), None, Some(1)]);
    }

    #[test]
    fn dae_error_display_and_codes() {
        use crate::errors::AlkahestError;
        assert_eq!(
            DaeError::DiffError("x".to_string()).to_string(),
            "differentiation error: x"
        );
        assert_eq!(
            DaeError::IndexTooHigh.to_string(),
            "DAE structural index too high (> 10)"
        );
        assert_eq!(DaeError::DiffError(String::new()).code(), "E-DAE-001");
        assert_eq!(DaeError::IndexTooHigh.code(), "E-DAE-002");
        assert_eq!(DaeError::StructurallyInconsistent.code(), "E-DAE-003");
    }
}