mod causalize;
pub mod create_dae;
mod differentiate;
pub mod location;
mod matching;
mod pantelides;
mod scc;
mod tearing;
use crate::ir::ast::{ComponentReference, Equation, Expression};
use crate::ir::visitor::{Visitable, Visitor};
use causalize::{causalize_equation, check_if_needs_swap, normalize_derivative_equation};
use matching::find_maximum_matching;
use scc::tarjan_scc;
use std::collections::{HashMap, HashSet};
pub use causalize::has_der_call;
pub use differentiate::{differentiate_equation, differentiate_expression};
pub use pantelides::pantelides_index_reduction;
pub use tearing::{analyze_algebraic_loops, tear_algebraic_loop};
/// Visitor state that gathers the string form of every component reference it
/// visits, except a reference that directly follows a function-call node
/// (see its `Visitor` impl for the skipping rule).
#[derive(Default)]
struct VariableFinder {
    /// Distinct names collected so far, via `ComponentReference::to_string`.
    variables: HashSet<String>,
    /// When set, the next component reference visited is ignored (and the flag cleared).
    skip_next_cref: bool,
}

impl VariableFinder {
    /// Creates a finder with no collected names and the skip flag cleared.
    fn new() -> Self {
        Self::default()
    }
}
impl Visitor for VariableFinder {
    /// Arms the skip flag whenever a function-call expression is entered.
    fn enter_expression(&mut self, node: &Expression) {
        if let Expression::FunctionCall { .. } = node {
            self.skip_next_cref = true;
        }
    }

    /// Records `comp` by its string form unless the skip flag is armed. The flag
    /// is consumed either way.
    ///
    /// NOTE(review): this presumably drops the called function's name (`der`,
    /// `sin`, ...). It relies on the visitor visiting a call's name before its
    /// arguments; confirm in `ir::visitor`.
    fn enter_component_reference(&mut self, comp: &ComponentReference) {
        let skip = self.skip_next_cref;
        self.skip_next_cref = false;
        if !skip {
            self.variables.insert(comp.to_string());
        }
    }
}
/// Visitor state that records the argument names of `der(x)` calls.
#[derive(Default)]
struct DerivativeFinder {
    /// String form of each `x` seen in a `der(x)` call, one entry per occurrence.
    derivatives: Vec<String>,
}

impl DerivativeFinder {
    /// Creates a finder with no recorded derivatives.
    fn new() -> Self {
        Self::default()
    }
}
impl Visitor for DerivativeFinder {
    /// Pushes `x` for every `der(x)` call whose first argument is a plain
    /// component reference. Other calls and other argument shapes are ignored,
    /// and every occurrence is pushed (no deduplication).
    fn enter_expression(&mut self, node: &Expression) {
        let Expression::FunctionCall { comp, args } = node else {
            return;
        };
        if comp.to_string() != "der" {
            return;
        }
        if let Some(Expression::ComponentReference(cref)) = args.first() {
            self.derivatives.push(cref.to_string());
        }
    }
}
/// Per-equation bookkeeping built by `blt_transform_with_info` and handed to
/// `find_maximum_matching` and `tarjan_scc`.
#[derive(Debug, Clone)]
pub(crate) struct EquationInfo {
    /// The equation as given to the transformation.
    pub equation: Equation,
    /// Variable names collected from the equation; derivatives are stored as
    /// `"der(x)"`. Empty for non-`Simple` equations.
    pub all_variables: HashSet<String>,
    /// The variable written alone on the left-hand side (`x` or `"der(x)"`),
    /// when the LHS has one of those shapes.
    pub lhs_variable: Option<String>,
    /// `true` when the left-hand side is `der(x)` for a component reference `x`.
    pub is_derivative: bool,
    /// The variable assigned to this equation by the maximum matching, if any.
    pub matched_variable: Option<String>,
}
/// An algebraic loop found during BLT analysis (see `BltResult::algebraic_loops`).
/// NOTE(review): it is built in the `tearing` module, which is not visible here;
/// the field notes below are assumptions to confirm there.
#[derive(Debug, Clone)]
pub struct AlgebraicLoop {
    /// Indices of the equations forming the loop (presumably the SCC's indices
    /// into the input equation list; TODO confirm).
    pub equation_indices: Vec<usize>,
    /// Variables involved in the loop (presumably those solved by these equations).
    pub variables: HashSet<String>,
    /// Variables selected for tearing (presumably chosen by `tear_algebraic_loop`).
    pub tearing_variables: Vec<String>,
    /// Presumably the remaining loop variables that are not torn; TODO confirm.
    pub residual_variables: Vec<String>,
    /// Loop size; presumably `equation_indices.len()`. TODO confirm.
    pub size: usize,
}
/// A dummy derivative variable.
/// NOTE(review): presumably introduced by index reduction (Pantelides /
/// dummy-derivative method); its producer is not visible in this file.
#[derive(Debug, Clone)]
pub struct DummyDerivative {
    /// Name of the dummy variable (naming scheme set by the producer; not shown here).
    pub name: String,
    /// Name of the variable whose derivative this replaces (presumably).
    pub base_variable: String,
    /// Differentiation order (presumably 1 for a first derivative).
    pub order: usize,
}
/// Summary of a structural analysis of an equation system.
/// NOTE(review): this struct is not populated anywhere in this file; the field
/// descriptions below are inferred from names and should be confirmed by the producer.
#[derive(Debug, Clone, Default)]
pub struct StructuralAnalysis {
    /// Presumably the differential index of the DAE.
    pub dae_index: usize,
    /// Presumably maps an equation index to how many times it must be differentiated.
    pub equations_to_differentiate: HashMap<usize, usize>,
    /// Dummy derivatives introduced by the analysis.
    pub dummy_derivatives: Vec<DummyDerivative>,
    /// Algebraic loops detected by the analysis.
    pub algebraic_loops: Vec<AlgebraicLoop>,
    /// Presumably set when the system is structurally singular.
    pub is_singular: bool,
    /// Free-form diagnostic messages.
    pub diagnostics: Vec<String>,
}
/// Output of `blt_transform_with_info`.
#[derive(Debug, Clone, Default)]
pub struct BltResult {
    /// The equations, causalized where possible, in the order given by the SCC analysis.
    pub equations: Vec<Equation>,
    /// Strongly connected components returned by `tarjan_scc` over the input
    /// equations (their indices refer to the input equation list).
    pub sccs: Vec<Vec<usize>>,
    /// Input equation index -> the variable it was matched to.
    pub matching: HashMap<usize, String>,
    /// `true` when every input equation received a matched variable. Non-`Simple`
    /// equations carry no variables, so their presence makes this `false`.
    pub is_complete_matching: bool,
    /// Algebraic loops reported by `tearing::analyze_algebraic_loops`.
    pub algebraic_loops: Vec<AlgebraicLoop>,
}
pub fn blt_transform(
equations: Vec<Equation>,
exclude_from_matching: &HashSet<String>,
) -> Vec<Equation> {
blt_transform_with_info(equations, exclude_from_matching).equations
}
pub fn blt_transform_with_info(
equations: Vec<Equation>,
exclude_from_matching: &HashSet<String>,
) -> BltResult {
let mut eq_infos: Vec<EquationInfo> = Vec::new();
let mut all_variables_set: HashSet<String> = HashSet::new();
for eq in equations.iter() {
if let Equation::Simple { lhs, rhs, .. } = eq {
let mut info = EquationInfo {
equation: eq.clone(),
all_variables: HashSet::new(),
lhs_variable: None,
is_derivative: false,
matched_variable: None,
};
match lhs {
Expression::ComponentReference(cref) => {
let var_name = cref.to_string();
info.lhs_variable = Some(var_name.clone());
info.all_variables.insert(var_name.clone());
all_variables_set.insert(var_name);
}
Expression::FunctionCall { comp, args } => {
if comp.to_string() == "der"
&& !args.is_empty()
&& let Expression::ComponentReference(cref) = &args[0]
{
let var_name = format!("der({})", cref);
info.lhs_variable = Some(var_name.clone());
info.all_variables.insert(var_name.clone());
all_variables_set.insert(var_name);
info.is_derivative = true;
}
}
_ => {
let mut lhs_finder = VariableFinder::new();
lhs.accept(&mut lhs_finder);
for var in lhs_finder.variables {
info.all_variables.insert(var.clone());
all_variables_set.insert(var);
}
}
}
let mut var_finder = VariableFinder::new();
rhs.accept(&mut var_finder);
for var in var_finder.variables {
info.all_variables.insert(var.clone());
all_variables_set.insert(var);
}
let mut der_finder = DerivativeFinder::new();
lhs.accept(&mut der_finder);
rhs.accept(&mut der_finder);
for der_var in &der_finder.derivatives {
let var_name = format!("der({})", der_var);
info.all_variables.insert(var_name.clone());
all_variables_set.insert(var_name);
}
eq_infos.push(info);
} else {
eq_infos.push(EquationInfo {
equation: eq.clone(),
all_variables: HashSet::new(),
lhs_variable: None,
is_derivative: false,
matched_variable: None,
});
}
}
let all_variables: Vec<String> = {
let mut vars: Vec<_> = all_variables_set
.into_iter()
.filter(|v| !exclude_from_matching.contains(v))
.collect();
vars.sort();
vars
};
let matching = find_maximum_matching(&eq_infos, &all_variables, exclude_from_matching);
for (eq_idx, var_name) in &matching {
eq_infos[*eq_idx].matched_variable = Some(var_name.clone());
}
let tarjan_result = tarjan_scc(&eq_infos);
let mut result_equations = Vec::new();
for idx in &tarjan_result.ordered_indices {
let info = &eq_infos[*idx];
if let Equation::Simple { lhs, rhs, .. } = &info.equation {
let needs_swap = check_if_needs_swap(lhs, rhs);
if needs_swap {
result_equations.push(Equation::Simple {
lhs: rhs.clone(),
rhs: lhs.clone(),
});
} else if let Some(normalized) = normalize_derivative_equation(lhs, rhs) {
result_equations.push(normalized);
} else {
if let Some(matched_var) = &info.matched_variable {
if let Some(causalized) = causalize_equation(lhs, rhs, matched_var) {
result_equations.push(causalized);
} else {
result_equations.push(info.equation.clone());
}
} else {
result_equations.push(info.equation.clone());
}
}
} else {
result_equations.push(info.equation.clone());
}
}
let is_complete_matching = matching.len() == eq_infos.len();
let algebraic_loops = tearing::analyze_algebraic_loops(&result_equations, &tarjan_result.sccs);
BltResult {
equations: result_equations,
sccs: tarjan_result.sccs,
matching,
is_complete_matching,
algebraic_loops,
}
}
// Unit tests for `blt_transform` and the causalization helper it relies on.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::ast::{ComponentRefPart, OpBinary, OpUnary, TerminalType, Token};

    /// Builds a single-part, non-local component reference expression named `name`.
    fn make_var(name: &str) -> Expression {
        Expression::ComponentReference(ComponentReference {
            local: false,
            parts: vec![ComponentRefPart {
                ident: Token {
                    text: name.to_string(),
                    ..Default::default()
                },
                subs: None,
            }],
        })
    }

    /// Wraps `var` in a `der(...)` function call.
    fn make_der(var: Expression) -> Expression {
        Expression::FunctionCall {
            comp: ComponentReference {
                local: false,
                parts: vec![ComponentRefPart {
                    ident: Token {
                        text: "der".to_string(),
                        ..Default::default()
                    },
                    subs: None,
                }],
            },
            args: vec![var],
        }
    }

    /// The unsigned-real literal `0`, used as the zero side of residual-form equations.
    fn make_zero() -> Expression {
        Expression::Terminal {
            terminal_type: TerminalType::UnsignedReal,
            token: Token {
                text: "0".to_string(),
                ..Default::default()
            },
        }
    }

    /// `v = der(h)` must come out with `der(h)` on the LHS and no `der` on the RHS.
    #[test]
    fn test_swap_derivative_equation() {
        let equations = vec![Equation::Simple {
            lhs: make_var("v"),
            rhs: make_der(make_var("h")),
        }];
        let result = blt_transform(equations, &HashSet::new());
        assert_eq!(result.len(), 1);
        if let Equation::Simple { lhs, rhs, .. } = &result[0] {
            assert!(has_der_call(lhs), "LHS should have der()");
            assert!(!has_der_call(rhs), "RHS should not have der()");
        } else {
            panic!("Expected Simple equation");
        }
    }

    /// `x = y; y = z; z = 1` must be ordered so that z precedes y, and y precedes x.
    #[test]
    fn test_blt_with_chain_dependencies() {
        let equations = vec![
            Equation::Simple {
                lhs: make_var("x"),
                rhs: make_var("y"),
            },
            Equation::Simple {
                lhs: make_var("y"),
                rhs: make_var("z"),
            },
            Equation::Simple {
                lhs: make_var("z"),
                rhs: Expression::Terminal {
                    terminal_type: TerminalType::UnsignedInteger,
                    token: Token {
                        text: "1".to_string(),
                        ..Default::default()
                    },
                },
            },
        ];
        let result = blt_transform(equations, &HashSet::new());
        assert_eq!(result.len(), 3);
        // LHS variable names in output order (equations with other LHS shapes are skipped).
        let order: Vec<String> = result
            .iter()
            .filter_map(|eq| {
                if let Equation::Simple {
                    lhs: Expression::ComponentReference(cref),
                    ..
                } = eq
                {
                    return Some(cref.to_string());
                }
                None
            })
            .collect();
        let z_pos = order.iter().position(|s| s == "z").unwrap();
        let y_pos = order.iter().position(|s| s == "y").unwrap();
        let x_pos = order.iter().position(|s| s == "x").unwrap();
        assert!(
            z_pos < y_pos,
            "z should be computed before y (z at {}, y at {})",
            z_pos,
            y_pos
        );
        assert!(
            y_pos < x_pos,
            "y should be computed before x (y at {}, x at {})",
            y_pos,
            x_pos
        );
    }

    /// `x = y + 1; y = x + 1` forms a loop.
    /// NOTE(review): this test only checks that both equations survive the
    /// transform; the loop itself is not asserted.
    #[test]
    fn test_blt_algebraic_loop_detection() {
        let one = Expression::Terminal {
            terminal_type: TerminalType::UnsignedInteger,
            token: Token {
                text: "1".to_string(),
                ..Default::default()
            },
        };
        let equations = vec![
            Equation::Simple {
                lhs: make_var("x"),
                rhs: Expression::Binary {
                    lhs: Box::new(make_var("y")),
                    op: OpBinary::Add(Token::default()),
                    rhs: Box::new(one.clone()),
                },
            },
            Equation::Simple {
                lhs: make_var("y"),
                rhs: Expression::Binary {
                    lhs: Box::new(make_var("x")),
                    op: OpBinary::Add(Token::default()),
                    rhs: Box::new(one),
                },
            },
        ];
        let result = blt_transform(equations, &HashSet::new());
        assert_eq!(result.len(), 2);
    }

    /// `a = b` solved for `a` needs no rewrite, so `causalize_equation` returns `None`.
    #[test]
    fn test_causalize_already_causal() {
        let lhs = make_var("a");
        let rhs = make_var("b");
        let result = causalize_equation(&lhs, &rhs, "a");
        assert!(
            result.is_none(),
            "Should return None for already causal equation"
        );
    }

    /// `a + b = 0` solved for `a` must yield `a = -b`.
    /// NOTE(review): the shape checks run only if the result is `Equation::Simple`.
    #[test]
    fn test_causalize_sum_to_zero() {
        let lhs = Expression::Binary {
            op: OpBinary::Add(Token::default()),
            lhs: Box::new(make_var("a")),
            rhs: Box::new(make_var("b")),
        };
        let rhs = make_zero();
        let result = causalize_equation(&lhs, &rhs, "a");
        assert!(
            result.is_some(),
            "Should be able to causalize a + b = 0 for a"
        );
        if let Some(Equation::Simple { lhs, rhs }) = result {
            if let Expression::ComponentReference(cref) = lhs {
                assert_eq!(cref.to_string(), "a");
            } else {
                panic!("LHS should be ComponentReference");
            }
            if let Expression::Unary {
                op: OpUnary::Minus(_),
                rhs,
            } = rhs
            {
                if let Expression::ComponentReference(cref) = *rhs {
                    assert_eq!(cref.to_string(), "b");
                } else {
                    panic!("RHS of negation should be ComponentReference");
                }
            } else {
                panic!("RHS should be Unary negation, got: {:?}", rhs);
            }
        }
    }

    /// `(a + b) + c = 0` solved for `a` must put `a` alone on the LHS.
    #[test]
    fn test_causalize_three_term_sum() {
        let inner = Expression::Binary {
            op: OpBinary::Add(Token::default()),
            lhs: Box::new(make_var("a")),
            rhs: Box::new(make_var("b")),
        };
        let lhs = Expression::Binary {
            op: OpBinary::Add(Token::default()),
            lhs: Box::new(inner),
            rhs: Box::new(make_var("c")),
        };
        let rhs = make_zero();
        let result = causalize_equation(&lhs, &rhs, "a");
        assert!(
            result.is_some(),
            "Should be able to causalize a + b + c = 0 for a"
        );
        if let Some(Equation::Simple { lhs, .. }) = result {
            if let Expression::ComponentReference(cref) = lhs {
                assert_eq!(cref.to_string(), "a");
            } else {
                panic!("LHS should be ComponentReference");
            }
        }
    }

    /// `0 = a + b` (zero on the LHS) solved for `a` must also yield `a = -b`.
    #[test]
    fn test_causalize_zero_on_lhs() {
        let lhs = make_zero();
        let rhs = Expression::Binary {
            op: OpBinary::Add(Token::default()),
            lhs: Box::new(make_var("a")),
            rhs: Box::new(make_var("b")),
        };
        let result = causalize_equation(&lhs, &rhs, "a");
        assert!(
            result.is_some(),
            "Should be able to causalize 0 = a + b for a"
        );
        if let Some(Equation::Simple { lhs, rhs }) = result {
            if let Expression::ComponentReference(cref) = lhs {
                assert_eq!(cref.to_string(), "a");
            } else {
                panic!("LHS should be ComponentReference");
            }
            if let Expression::Unary {
                op: OpUnary::Minus(_),
                rhs,
            } = rhs
            {
                if let Expression::ComponentReference(cref) = *rhs {
                    assert_eq!(cref.to_string(), "b");
                } else {
                    panic!("RHS of negation should be ComponentReference");
                }
            } else {
                panic!("RHS should be Unary negation, got: {:?}", rhs);
            }
        }
    }

    /// A Kirchhoff-current-law style `i1 + i2 = 0` must be causalized by the full
    /// BLT pipeline into an equation with a bare variable on the LHS.
    #[test]
    fn test_kcl_style_equation() {
        let equations = vec![Equation::Simple {
            lhs: Expression::Binary {
                op: OpBinary::Add(Token::default()),
                lhs: Box::new(make_var("R2_n_i")),
                rhs: Box::new(make_var("L1_p_i")),
            },
            rhs: make_zero(),
        }];
        let result = blt_transform(equations, &HashSet::new());
        assert_eq!(result.len(), 1);
        if let Equation::Simple { lhs, .. } = &result[0] {
            assert!(
                matches!(lhs, Expression::ComponentReference(_)),
                "LHS should be a simple variable after causalization, got: {:?}",
                lhs
            );
        } else {
            panic!("Expected Simple equation");
        }
    }

    /// Same as `test_kcl_style_equation`, but with the zero on the LHS (`0 = i1 + i2`).
    #[test]
    fn test_kcl_style_equation_zero_on_lhs() {
        let equations = vec![Equation::Simple {
            lhs: make_zero(),
            rhs: Expression::Binary {
                op: OpBinary::Add(Token::default()),
                lhs: Box::new(make_var("R2_n_i")),
                rhs: Box::new(make_var("L1_p_i")),
            },
        }];
        let result = blt_transform(equations, &HashSet::new());
        assert_eq!(result.len(), 1);
        if let Equation::Simple { lhs, .. } = &result[0] {
            assert!(
                matches!(lhs, Expression::ComponentReference(_)),
                "LHS should be a simple variable after causalization, got: {:?}",
                lhs
            );
        } else {
            panic!("Expected Simple equation");
        }
    }
}