use super::differentiate::differentiate_equation;
use super::{DummyDerivative, StructuralAnalysis};
use crate::ir::ast::{ComponentReference, Equation, Expression};
use crate::ir::visitor::{Visitable, Visitor};
use std::collections::{HashMap, HashSet, VecDeque};
/// Structural summary of a single equation, as consumed by the index-reduction pass.
#[derive(Debug, Clone)]
struct EquationStructure {
/// The equation being described: either one of the input equations or one
/// produced by `differentiate_equation` during reduction.
equation: Equation,
/// Names of all component references found on either side, plus a synthetic
/// `der(v)` entry for every `v` in `derivatives`.
variables: HashSet<String>,
/// Base names `v` for every `der(v)` call whose argument is a component reference.
derivatives: HashSet<String>,
/// `true` when no `der(...)` call was found (i.e. `derivatives` is empty).
is_constraint: bool,
/// How many times this equation has been differentiated relative to the
/// input: 0 for input equations, parent level + 1 for differentiated ones.
diff_level: usize,
}
/// Visitor that records the textual name of every component reference it
/// encounters, except the callee name of a function call.
///
/// Entering an `Expression::FunctionCall` arms a one-shot flag that causes the
/// next component reference to be ignored. NOTE(review): this assumes the AST
/// traversal visits a call's `comp` (e.g. `der`) before its arguments — confirm
/// against the `Visitable` implementation for `Expression`.
struct VariableFinder {
    /// Collected names, as rendered by `ComponentReference::to_string`.
    variables: HashSet<String>,
    /// One-shot request to ignore the next component reference.
    skip_next_cref: bool,
}

impl VariableFinder {
    /// Create a finder with no collected names and no pending skip.
    fn new() -> Self {
        Self {
            skip_next_cref: false,
            variables: HashSet::default(),
        }
    }
}

impl Visitor for VariableFinder {
    fn enter_expression(&mut self, node: &Expression) {
        if let Expression::FunctionCall { .. } = node {
            self.skip_next_cref = true;
        }
    }

    fn enter_component_reference(&mut self, node: &ComponentReference) {
        // Read and clear the pending flag in one step; record the name only
        // when no skip was pending.
        let skip_this_one = std::mem::take(&mut self.skip_next_cref);
        if !skip_this_one {
            self.variables.insert(node.to_string());
        }
    }
}
/// Visitor that gathers the base names `v` of every `der(v)` call whose first
/// argument is a plain component reference.
struct DerivativeCollector {
    /// Base names of differentiated references, without the `der(...)` wrapper.
    derivatives: HashSet<String>,
}

impl DerivativeCollector {
    /// Create an empty collector.
    fn new() -> Self {
        Self {
            derivatives: HashSet::default(),
        }
    }
}

impl Visitor for DerivativeCollector {
    fn enter_expression(&mut self, node: &Expression) {
        let Expression::FunctionCall { comp, args } = node else {
            return;
        };
        if comp.to_string() != "der" {
            return;
        }
        // Arguments that are not a bare reference (e.g. `der(2*x)`) are ignored.
        if let Some(Expression::ComponentReference(cref)) = args.first() {
            self.derivatives.insert(cref.to_string());
        }
    }
}
/// Build the [`EquationStructure`] summary for one equation.
///
/// Only `Equation::Simple` is inspected. For any other variant, `variables` and
/// `derivatives` are empty, so the equation is reported as a constraint. For a
/// simple equation, `variables` holds every reference found by
/// `VariableFinder` on both sides, plus a synthetic `der(v)` entry for each `v`
/// collected by `DerivativeCollector`. `diff_level` is always 0 here; callers
/// that differentiate an equation overwrite it.
fn analyze_equation_structure(equation: &Equation) -> EquationStructure {
    let (variables, derivatives) = if let Equation::Simple { lhs, rhs, .. } = equation {
        let mut refs = VariableFinder::new();
        let mut ders = DerivativeCollector::new();
        // Each visitor sees lhs then rhs, exactly as if run in separate passes.
        for side in [lhs, rhs] {
            side.accept(&mut refs);
            side.accept(&mut ders);
        }
        let mut names = refs.variables;
        names.extend(ders.derivatives.iter().map(|base| format!("der({})", base)));
        (names, ders.derivatives)
    } else {
        (HashSet::new(), HashSet::new())
    };
    let is_constraint = derivatives.is_empty();
    EquationStructure {
        equation: equation.clone(),
        variables,
        derivatives,
        is_constraint,
        diff_level: 0,
    }
}
pub fn pantelides_index_reduction(
equations: &[Equation],
state_variables: &HashSet<String>,
algebraic_variables: Option<&HashSet<String>>,
) -> StructuralAnalysis {
let mut analysis = StructuralAnalysis::default();
let mut eq_structures: Vec<EquationStructure> =
equations.iter().map(analyze_equation_structure).collect();
let mut all_equation_vars: HashSet<String> = HashSet::new();
for eq_struct in &eq_structures {
all_equation_vars.extend(eq_struct.variables.clone());
}
let mut unknown_variables: HashSet<String> = HashSet::new();
for state in state_variables {
unknown_variables.insert(format!("der({})", state));
}
if let Some(alg_vars) = algebraic_variables {
unknown_variables.extend(alg_vars.clone());
} else {
for var in &all_equation_vars {
if !state_variables.contains(var) && !var.starts_with("der(") {
unknown_variables.insert(var.clone());
}
}
}
let mut iteration = 0;
let max_iterations = 10;
while iteration < max_iterations {
iteration += 1;
let (matching, unmatched_eqs) =
find_structural_matching(&eq_structures, &unknown_variables, state_variables);
let max_level = eq_structures
.iter()
.map(|e| e.diff_level)
.max()
.unwrap_or(0);
let highest_level_eqs: Vec<usize> = eq_structures
.iter()
.enumerate()
.filter(|(_, e)| e.diff_level == max_level)
.map(|(i, _)| i)
.collect();
let unmatched_at_highest: Vec<usize> = unmatched_eqs
.iter()
.filter(|&&idx| idx < eq_structures.len() && eq_structures[idx].diff_level == max_level)
.copied()
.collect();
if unmatched_at_highest.is_empty() && !highest_level_eqs.is_empty() {
analysis.dae_index = max_level;
break;
}
if matching.len() >= unknown_variables.len() {
analysis.dae_index = max_level;
break;
}
let mut eqs_to_diff: Vec<usize> = if !unmatched_at_highest.is_empty() {
unmatched_at_highest.clone()
} else {
find_equations_to_differentiate(&eq_structures, &unmatched_eqs)
};
if eqs_to_diff.is_empty() {
analysis.is_singular = true;
analysis.diagnostics.push(
"Structurally singular system: cannot find equations to differentiate".to_string(),
);
break;
}
let mut additional_eqs: Vec<usize> = Vec::new();
for &unmatched_idx in &eqs_to_diff {
if unmatched_idx < eq_structures.len() {
let unmatched_eq = &eq_structures[unmatched_idx];
for var in &unmatched_eq.variables {
if unknown_variables.contains(var) {
for (&eq_idx, matched_var) in &matching {
if matched_var == var
&& !eqs_to_diff.contains(&eq_idx)
&& !additional_eqs.contains(&eq_idx)
{
additional_eqs.push(eq_idx);
}
}
}
}
}
}
eqs_to_diff.extend(additional_eqs);
for eq_idx in eqs_to_diff {
if eq_idx < eq_structures.len() {
let eq_struct = &eq_structures[eq_idx];
*analysis
.equations_to_differentiate
.entry(eq_idx)
.or_insert(0) += 1;
if let Some(diff_eq) = differentiate_equation(&eq_struct.equation) {
let mut diff_struct = analyze_equation_structure(&diff_eq);
diff_struct.diff_level = eq_struct.diff_level + 1;
for var in &diff_struct.derivatives {
if state_variables.contains(var) {
let der_var = format!("der({})", var);
if unknown_variables.insert(der_var.clone()) {
analysis.dummy_derivatives.push(DummyDerivative {
name: der_var,
base_variable: var.clone(),
order: diff_struct.diff_level,
});
}
}
}
eq_structures.push(diff_struct);
}
}
}
analysis.dae_index = iteration;
}
if iteration >= max_iterations {
analysis.is_singular = true;
analysis
.diagnostics
.push("Index reduction did not converge".to_string());
}
analysis
}
fn find_structural_matching(
eq_structures: &[EquationStructure],
all_variables: &HashSet<String>,
state_variables: &HashSet<String>,
) -> (HashMap<usize, String>, Vec<usize>) {
let n_equations = eq_structures.len();
let vars: Vec<String> = all_variables.iter().cloned().collect();
let n_variables = vars.len();
let var_to_idx: HashMap<&String, usize> =
vars.iter().enumerate().map(|(i, v)| (v, i)).collect();
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n_equations];
for (eq_idx, eq_struct) in eq_structures.iter().enumerate() {
for var in &eq_struct.variables {
if state_variables.contains(var) {
continue;
}
if let Some(&var_idx) = var_to_idx.get(var) {
adj[eq_idx].push(var_idx);
}
}
for der_var in &eq_struct.derivatives {
if state_variables.contains(der_var) {
let der_var_name = format!("der({})", der_var);
if let Some(&var_idx) = var_to_idx.get(&der_var_name) {
adj[eq_idx].push(var_idx);
}
}
}
}
let mut hk = HopcroftKarp::new(n_equations, n_variables, adj);
hk.max_matching();
let mut matching = HashMap::new();
let mut unmatched = Vec::new();
for (eq_idx, var_idx) in hk.pair_eq.iter().enumerate() {
if *var_idx != NIL && *var_idx < vars.len() {
matching.insert(eq_idx, vars[*var_idx].clone());
} else {
unmatched.push(eq_idx);
}
}
(matching, unmatched)
}
/// Pick which unmatched equations should be differentiated next.
///
/// Indices in `unmatched_eqs` outside `eq_structures` are ignored. Of the
/// remaining ones, only those at the greatest `diff_level` are kept, in their
/// original order. When that level is 0 and at least one kept equation is a
/// constraint (`is_constraint`), only the constraints are returned.
fn find_equations_to_differentiate(
    eq_structures: &[EquationStructure],
    unmatched_eqs: &[usize],
) -> Vec<usize> {
    let level_of = |idx: usize| eq_structures[idx].diff_level;

    let in_range: Vec<usize> = unmatched_eqs
        .iter()
        .copied()
        .filter(|&idx| idx < eq_structures.len())
        .collect();
    let deepest = in_range
        .iter()
        .map(|&idx| level_of(idx))
        .max()
        .unwrap_or(0);
    let candidates: Vec<usize> = in_range
        .into_iter()
        .filter(|&idx| level_of(idx) == deepest)
        .collect();

    if deepest == 0 {
        // At the base level, prefer purely algebraic equations.
        let constraints: Vec<usize> = candidates
            .iter()
            .copied()
            .filter(|&idx| eq_structures[idx].is_constraint)
            .collect();
        if !constraints.is_empty() {
            return constraints;
        }
    }
    candidates
}
/// Sentinel meaning "no partner" in [`HopcroftKarp`]'s pairing tables.
const NIL: usize = usize::MAX;

/// Hopcroft–Karp maximum-cardinality bipartite matching.
///
/// Left vertices are equations `0..n_equations` and right vertices are
/// variables. `adj[eq]` lists the variable indices equation `eq` may be
/// matched to.
struct HopcroftKarp {
    /// Number of equations. Also the `dist` slot standing for the virtual NIL vertex.
    n_equations: usize,
    /// Adjacency lists from equation to variable indices.
    adj: Vec<Vec<usize>>,
    /// `pair_eq[eq]` is the variable matched to `eq`, or `NIL`.
    pair_eq: Vec<usize>,
    /// `pair_var[var]` is the equation matched to `var`, or `NIL`.
    pair_var: Vec<usize>,
    /// BFS layer of each equation (plus the NIL slot); `usize::MAX` = unreached.
    dist: Vec<usize>,
}

impl HopcroftKarp {
    /// Create a solver in which nothing is matched yet.
    ///
    /// Every index in `adj` must be `< n_variables`, and `adj` must have a row
    /// for each equation; otherwise indexing panics.
    fn new(n_equations: usize, n_variables: usize, adj: Vec<Vec<usize>>) -> Self {
        let pair_eq = vec![NIL; n_equations];
        let pair_var = vec![NIL; n_variables];
        let dist = vec![0; n_equations + 1];
        Self {
            n_equations,
            adj,
            pair_eq,
            pair_var,
            dist,
        }
    }

    /// `dist` slot for an equation index, mapping `NIL` to the sink slot.
    fn dist_slot(&self, eq: usize) -> usize {
        if eq == NIL {
            self.n_equations
        } else {
            eq
        }
    }

    /// Repeatedly augment until no augmenting path remains.
    ///
    /// Returns the number of successful augmentations. Starting from the empty
    /// matching built by `new`, this equals the matching size.
    fn max_matching(&mut self) -> usize {
        let mut augmented = 0;
        loop {
            if !self.bfs() {
                break augmented;
            }
            for eq in 0..self.n_equations {
                if self.pair_eq[eq] != NIL {
                    continue;
                }
                if self.dfs(eq) {
                    augmented += 1;
                }
            }
        }
    }

    /// Layer the graph by BFS from every free equation.
    ///
    /// Returns `true` when the NIL slot (a free variable) was reached, i.e. at
    /// least one augmenting path exists.
    fn bfs(&mut self) -> bool {
        let sink = self.n_equations;
        let mut frontier = VecDeque::new();
        for eq in 0..sink {
            if self.pair_eq[eq] != NIL {
                self.dist[eq] = usize::MAX;
                continue;
            }
            self.dist[eq] = 0;
            frontier.push_back(eq);
        }
        self.dist[sink] = usize::MAX;

        while let Some(eq) = frontier.pop_front() {
            // Do not expand beyond the shortest augmenting-path length found so far.
            if self.dist[eq] >= self.dist[sink] {
                continue;
            }
            for &var in &self.adj[eq] {
                let partner = self.pair_var[var];
                let slot = self.dist_slot(partner);
                if self.dist[slot] != usize::MAX {
                    continue;
                }
                self.dist[slot] = self.dist[eq] + 1;
                if partner != NIL {
                    frontier.push_back(partner);
                }
            }
        }
        self.dist[sink] != usize::MAX
    }

    /// Search for an augmenting path from `eq` along consecutive BFS layers.
    ///
    /// On success, the path is flipped into the matching. `eq == NIL` means a
    /// free variable was reached. Dead-end equations get `dist = usize::MAX`,
    /// so the current phase does not retry them.
    fn dfs(&mut self, eq: usize) -> bool {
        if eq == NIL {
            return true;
        }
        let degree = self.adj[eq].len();
        for k in 0..degree {
            let var = self.adj[eq][k];
            let partner = self.pair_var[var];
            let slot = self.dist_slot(partner);
            let on_next_layer = self.dist[slot] == self.dist[eq] + 1;
            if on_next_layer && self.dfs(partner) {
                self.pair_var[var] = eq;
                self.pair_eq[eq] = var;
                return true;
            }
        }
        self.dist[eq] = usize::MAX;
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::ast::{ComponentRefPart, OpBinary, OpUnary, TerminalType, Token};

    /// Non-local, single-segment component reference called `name`.
    fn cref(name: &str) -> ComponentReference {
        ComponentReference {
            local: false,
            parts: vec![ComponentRefPart {
                ident: Token {
                    text: name.to_string(),
                    ..Default::default()
                },
                subs: None,
            }],
        }
    }

    fn make_var(name: &str) -> Expression {
        Expression::ComponentReference(cref(name))
    }

    fn make_der(var: Expression) -> Expression {
        Expression::FunctionCall {
            comp: cref("der"),
            args: vec![var],
        }
    }

    fn make_const(val: &str) -> Expression {
        Expression::Terminal {
            terminal_type: TerminalType::UnsignedInteger,
            token: Token {
                text: val.to_string(),
                ..Default::default()
            },
        }
    }

    fn make_binary(lhs: Expression, op: OpBinary, rhs: Expression) -> Expression {
        Expression::Binary {
            lhs: Box::new(lhs),
            op,
            rhs: Box::new(rhs),
        }
    }

    fn make_mul(lhs: Expression, rhs: Expression) -> Expression {
        make_binary(lhs, OpBinary::Mul(Token::default()), rhs)
    }

    fn make_add(lhs: Expression, rhs: Expression) -> Expression {
        make_binary(lhs, OpBinary::Add(Token::default()), rhs)
    }

    fn make_sub(lhs: Expression, rhs: Expression) -> Expression {
        make_binary(lhs, OpBinary::Sub(Token::default()), rhs)
    }

    /// Unary minus applied to `operand`.
    fn make_neg(operand: Expression) -> Expression {
        Expression::Unary {
            op: OpUnary::Minus(Token::default()),
            rhs: Box::new(operand),
        }
    }

    /// `lhs = rhs` as a simple equation.
    fn simple(lhs: Expression, rhs: Expression) -> Equation {
        Equation::Simple { lhs, rhs }
    }

    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    /// Pendulum length constraint `x*x + y*y = L*L`.
    fn pendulum_constraint() -> Equation {
        simple(
            make_add(
                make_mul(make_var("x"), make_var("x")),
                make_mul(make_var("y"), make_var("y")),
            ),
            make_mul(make_var("L"), make_var("L")),
        )
    }

    #[test]
    fn test_equation_structure_analysis() {
        let eq = simple(make_der(make_var("x")), make_var("v"));
        let structure = analyze_equation_structure(&eq);
        assert!(structure.variables.contains("x"));
        assert!(structure.variables.contains("v"));
        assert!(structure.derivatives.contains("x"));
        assert!(!structure.is_constraint);
    }

    #[test]
    fn test_constraint_detection() {
        let eq = simple(make_var("x"), make_var("y"));
        let structure = analyze_equation_structure(&eq);
        assert!(structure.is_constraint);
    }

    #[test]
    fn test_index_reduction_ode() {
        let equations = vec![simple(make_der(make_var("x")), make_neg(make_var("x")))];
        let states = name_set(&["x"]);
        let analysis = pantelides_index_reduction(&equations, &states, None);
        assert_eq!(analysis.dae_index, 0);
        assert!(analysis.equations_to_differentiate.is_empty());
        assert!(!analysis.is_singular);
    }

    #[test]
    fn test_pendulum_index3_dae() {
        let equations = vec![
            simple(make_der(make_var("x")), make_var("vx")),
            simple(make_der(make_var("y")), make_var("vy")),
            simple(
                make_der(make_var("vx")),
                make_neg(make_mul(make_var("lambda"), make_var("x"))),
            ),
            simple(
                make_der(make_var("vy")),
                make_sub(
                    make_neg(make_mul(make_var("lambda"), make_var("y"))),
                    make_var("g"),
                ),
            ),
            pendulum_constraint(),
        ];
        let states = name_set(&["x", "y", "vx", "vy"]);
        let algebraic = name_set(&["lambda"]);
        let analysis = pantelides_index_reduction(&equations, &states, Some(&algebraic));
        assert!(
            analysis.dae_index > 0,
            "Pendulum should be detected as high-index DAE (got index {})",
            analysis.dae_index
        );
        assert!(
            !analysis.equations_to_differentiate.is_empty(),
            "Should identify constraint equation for differentiation"
        );
        assert!(
            !analysis.is_singular,
            "Pendulum should not be structurally singular"
        );
    }

    #[test]
    fn test_pendulum_constraint_is_detected() {
        let structure = analyze_equation_structure(&pendulum_constraint());
        assert!(
            structure.is_constraint,
            "x^2 + y^2 = L^2 should be detected as constraint"
        );
        assert!(
            structure.derivatives.is_empty(),
            "Constraint should have no derivatives"
        );
        assert!(structure.variables.contains("x"));
        assert!(structure.variables.contains("y"));
        assert!(structure.variables.contains("L"));
    }

    #[test]
    fn test_index1_dae() {
        let equations = vec![
            simple(make_der(make_var("x")), make_neg(make_var("y"))),
            simple(make_add(make_var("x"), make_var("y")), make_const("1")),
        ];
        let states = name_set(&["x"]);
        let algebraic = name_set(&["y"]);
        let analysis = pantelides_index_reduction(&equations, &states, Some(&algebraic));
        assert!(
            analysis.dae_index <= 1,
            "Simple index-1 DAE should have index <= 1"
        );
    }
}