use crate::ir::ast::{
ComponentRefPart, ComponentReference, Equation, Expression, OpBinary, OpUnary, TerminalType,
Token,
};
use crate::ir::visitor::{Visitable, Visitor};
/// Visitor that records every `der(v)` call whose first argument is a plain
/// component reference, storing `v` in its printed (`to_string`) form.
struct DerivativeFinder {
    /// Printed names of the differentiated references, in visit order.
    derivatives: Vec<String>,
}

impl DerivativeFinder {
    /// Creates a finder that has not recorded anything yet.
    fn new() -> Self {
        let derivatives = Vec::new();
        Self { derivatives }
    }
}

impl Visitor for DerivativeFinder {
    fn enter_expression(&mut self, node: &Expression) {
        let Expression::FunctionCall { comp, args } = node else {
            return;
        };
        if comp.to_string() != "der" {
            return;
        }
        // Only `der(<component reference>)` is recorded; `der(<other expr>)`
        // and argument-less `der()` are ignored.
        if let Some(Expression::ComponentReference(state)) = args.first() {
            self.derivatives.push(state.to_string());
        }
    }
}
pub fn has_der_call(expr: &Expression) -> bool {
let mut finder = DerivativeFinder::new();
expr.accept(&mut finder);
!finder.derivatives.is_empty()
}
/// Returns `true` when `rhs` contains a `der(...)` call (as detected by
/// `has_der_call`) and `lhs` does not.
///
/// Presumably callers use this to move the derivative onto the left-hand
/// side; confirm against call sites.
pub(super) fn check_if_needs_swap(lhs: &Expression, rhs: &Expression) -> bool {
    matches!((has_der_call(lhs), has_der_call(rhs)), (false, true))
}
/// Rewrites `c * der(x) = rhs` or `der(x) * c = rhs` into `der(x) = rhs / c`.
///
/// The `der` call must have exactly one argument. If both factors are such
/// calls, the right factor is treated as the derivative. Returns `None` when
/// `lhs` is not a product of that shape.
pub(super) fn normalize_derivative_equation(
    lhs: &Expression,
    rhs: &Expression,
) -> Option<Equation> {
    /// `true` for a single-argument `der(...)` call.
    fn is_unary_der(candidate: &Expression) -> bool {
        matches!(
            candidate,
            Expression::FunctionCall { comp, args } if comp.to_string() == "der" && args.len() == 1
        )
    }

    let Expression::Binary {
        op: OpBinary::Mul(_),
        lhs: left_factor,
        rhs: right_factor,
    } = lhs
    else {
        return None;
    };

    // The right factor is checked first, so `der(a) * der(b)` keeps `der(b)`.
    let (derivative, coefficient) = if is_unary_der(right_factor) {
        (right_factor, left_factor)
    } else if is_unary_der(left_factor) {
        (left_factor, right_factor)
    } else {
        return None;
    };

    Some(Equation::Simple {
        lhs: derivative.as_ref().clone(),
        rhs: Expression::Binary {
            op: OpBinary::Div(Token::default()),
            lhs: Box::new(rhs.clone()),
            rhs: coefficient.clone(),
        },
    })
}
pub(super) fn causalize_equation(
lhs: &Expression,
rhs: &Expression,
solve_for: &str,
) -> Option<Equation> {
if let Expression::ComponentReference(cref) = lhs
&& cref.to_string() == solve_for
{
return None; }
if let Expression::ComponentReference(cref) = rhs
&& cref.to_string() == solve_for
{
return Some(Equation::Simple {
lhs: rhs.clone(),
rhs: lhs.clone(),
});
}
let is_zero = |expr: &Expression| -> bool {
match expr {
Expression::Terminal { token, .. } => token.text == "0" || token.text == "0.0",
_ => false,
}
};
let rhs_is_zero = is_zero(rhs);
let lhs_is_zero = is_zero(lhs);
if rhs_is_zero {
if let Some((coeff, other_terms)) = extract_linear_term(lhs, solve_for) {
let new_rhs = if coeff > 0.0 {
negate_expression(&other_terms)
} else {
other_terms
};
return Some(Equation::Simple {
lhs: Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: solve_for.to_string(),
..Default::default()
},
subs: None,
}],
}),
rhs: new_rhs,
});
}
}
if lhs_is_zero {
if let Some((coeff, other_terms)) = extract_linear_term(rhs, solve_for) {
let new_rhs = if coeff > 0.0 {
negate_expression(&other_terms)
} else {
other_terms
};
return Some(Equation::Simple {
lhs: Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: solve_for.to_string(),
..Default::default()
},
subs: None,
}],
}),
rhs: new_rhs,
});
}
}
if let Expression::Binary {
op: OpBinary::Mul(_),
lhs: mult_lhs,
rhs: mult_rhs,
} = lhs
{
if let Expression::ComponentReference(cref) = mult_rhs.as_ref()
&& cref.to_string() == solve_for
{
return Some(Equation::Simple {
lhs: Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: solve_for.to_string(),
..Default::default()
},
subs: None,
}],
}),
rhs: Expression::Binary {
op: OpBinary::Div(Token::default()),
lhs: Box::new(rhs.clone()),
rhs: mult_lhs.clone(),
},
});
}
if let Expression::ComponentReference(cref) = mult_lhs.as_ref()
&& cref.to_string() == solve_for
{
return Some(Equation::Simple {
lhs: Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: solve_for.to_string(),
..Default::default()
},
subs: None,
}],
}),
rhs: Expression::Binary {
op: OpBinary::Div(Token::default()),
lhs: Box::new(rhs.clone()),
rhs: mult_rhs.clone(),
},
});
}
}
if let Some((coeff, other_terms)) = extract_linear_term(lhs, solve_for) {
let rhs_minus_other = if is_zero_expression(&other_terms) {
rhs.clone()
} else {
Expression::Binary {
op: OpBinary::Sub(Token::default()),
lhs: Box::new(rhs.clone()),
rhs: Box::new(other_terms),
}
};
let new_rhs = if (coeff - 1.0).abs() < 1e-10 {
rhs_minus_other
} else if (coeff + 1.0).abs() < 1e-10 {
negate_expression(&rhs_minus_other)
} else {
Expression::Binary {
op: OpBinary::Div(Token::default()),
lhs: Box::new(rhs_minus_other),
rhs: Box::new(Expression::Terminal {
terminal_type: TerminalType::UnsignedReal,
token: Token {
text: coeff.to_string(),
..Default::default()
},
}),
}
};
return Some(Equation::Simple {
lhs: Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: solve_for.to_string(),
..Default::default()
},
subs: None,
}],
}),
rhs: new_rhs,
});
}
if let Some((coeff, other_terms)) = extract_linear_term(rhs, solve_for) {
let lhs_minus_other = if is_zero_expression(&other_terms) {
lhs.clone()
} else {
Expression::Binary {
op: OpBinary::Sub(Token::default()),
lhs: Box::new(lhs.clone()),
rhs: Box::new(other_terms),
}
};
let new_rhs = if (coeff - 1.0).abs() < 1e-10 {
lhs_minus_other
} else if (coeff + 1.0).abs() < 1e-10 {
negate_expression(&lhs_minus_other)
} else {
Expression::Binary {
op: OpBinary::Div(Token::default()),
lhs: Box::new(lhs_minus_other),
rhs: Box::new(Expression::Terminal {
terminal_type: TerminalType::UnsignedReal,
token: Token {
text: coeff.to_string(),
..Default::default()
},
}),
}
};
return Some(Equation::Simple {
lhs: Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: solve_for.to_string(),
..Default::default()
},
subs: None,
}],
}),
rhs: new_rhs,
});
}
None
}
/// Returns `true` if `expr` is a terminal whose source text is exactly `0` or
/// `0.0`.
///
/// Other spellings of zero (e.g. `0.`, `0e0`) are not recognised.
fn is_zero_expression(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Terminal { token, .. } if matches!(token.text.as_str(), "0" | "0.0")
    )
}
fn negate_expression(expr: &Expression) -> Expression {
match expr {
Expression::Unary {
op: OpUnary::Minus(_),
rhs,
} => (**rhs).clone(),
Expression::Binary {
op: OpBinary::Sub(_),
lhs,
rhs,
} => Expression::Binary {
op: OpBinary::Sub(Token::default()),
lhs: rhs.clone(),
rhs: lhs.clone(),
},
_ => Expression::Unary {
op: OpUnary::Minus(Token::default()),
rhs: Box::new(expr.clone()),
},
}
}
/// Splits `expr` into `coeff * var_name + remainder`.
///
/// Only a bare reference to `var_name`, unary negation, `+` and `-` are
/// understood, so `coeff` is always `1.0` or `-1.0`. The first occurrence
/// found is the one extracted; the left operand is searched before the right.
/// The remainder is not checked for further occurrences of `var_name`. An
/// empty remainder is represented by the literal `0`. Returns `None` if
/// `var_name` cannot be reached through those operators.
fn extract_linear_term(expr: &Expression, var_name: &str) -> Option<(f64, Expression)> {
    // Literal `0`, used as the remainder when the variable stands alone.
    let zero_literal = || Expression::Terminal {
        terminal_type: TerminalType::UnsignedReal,
        token: Token {
            text: "0".to_string(),
            ..Default::default()
        },
    };
    let is_target = |candidate: &Expression| {
        matches!(candidate, Expression::ComponentReference(c) if c.to_string() == var_name)
    };

    match expr {
        Expression::ComponentReference(_) => is_target(expr).then(|| (1.0, zero_literal())),
        Expression::Unary {
            op: OpUnary::Minus(_),
            rhs: operand,
        } => {
            if is_target(operand.as_ref()) {
                Some((-1.0, zero_literal()))
            } else {
                extract_linear_term(operand, var_name)
                    .map(|(coeff, rest)| (-coeff, negate_expression(&rest)))
            }
        }
        Expression::Binary {
            op: OpBinary::Add(_),
            lhs: left,
            rhs: right,
        } => {
            if let Some((coeff, rest)) = extract_linear_term(left, var_name) {
                // (c*x + rest) + right
                let remainder = if is_zero_expression(&rest) {
                    right.as_ref().clone()
                } else {
                    Expression::Binary {
                        op: OpBinary::Add(Token::default()),
                        lhs: Box::new(rest),
                        rhs: right.clone(),
                    }
                };
                return Some((coeff, remainder));
            }
            // left + (c*x + rest)
            let (coeff, rest) = extract_linear_term(right, var_name)?;
            let remainder = if is_zero_expression(&rest) {
                left.as_ref().clone()
            } else {
                Expression::Binary {
                    op: OpBinary::Add(Token::default()),
                    lhs: left.clone(),
                    rhs: Box::new(rest),
                }
            };
            Some((coeff, remainder))
        }
        Expression::Binary {
            op: OpBinary::Sub(_),
            lhs: left,
            rhs: right,
        } => {
            if let Some((coeff, rest)) = extract_linear_term(left, var_name) {
                // (c*x + rest) - right
                let remainder = if is_zero_expression(&rest) {
                    negate_expression(right)
                } else {
                    Expression::Binary {
                        op: OpBinary::Sub(Token::default()),
                        lhs: Box::new(rest),
                        rhs: right.clone(),
                    }
                };
                return Some((coeff, remainder));
            }
            // left - (c*x + rest) == -c*x + (left - rest)
            let (coeff, rest) = extract_linear_term(right, var_name)?;
            let remainder = if is_zero_expression(&rest) {
                left.as_ref().clone()
            } else {
                Expression::Binary {
                    op: OpBinary::Sub(Token::default()),
                    lhs: left.clone(),
                    rhs: Box::new(rest),
                }
            };
            Some((-coeff, remainder))
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain, non-local, unsubscripted reference to `name`.
    fn make_var(name: &str) -> Expression {
        Expression::ComponentReference(ComponentReference {
            local: false,
            parts: vec![ComponentRefPart {
                ident: Token {
                    text: name.to_string(),
                    ..Default::default()
                },
                subs: None,
            }],
        })
    }

    /// Real literal with the given source text.
    fn make_real(text: &str) -> Expression {
        Expression::Terminal {
            terminal_type: TerminalType::UnsignedReal,
            token: Token {
                text: text.to_string(),
                ..Default::default()
            },
        }
    }

    fn make_zero() -> Expression {
        make_real("0")
    }

    fn binary(op: OpBinary, lhs: Expression, rhs: Expression) -> Expression {
        Expression::Binary {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }

    /// `true` if `expr` is a component reference printing as `name`.
    fn is_var(expr: &Expression, name: &str) -> bool {
        matches!(expr, Expression::ComponentReference(cref) if cref.to_string() == name)
    }

    /// Unwraps `Some(Equation::Simple { lhs, rhs })`, panicking otherwise.
    fn expect_simple(result: Option<Equation>) -> (Expression, Expression) {
        match result {
            Some(Equation::Simple { lhs, rhs }) => (lhs, rhs),
            _ => panic!("Expected Some(Equation::Simple)"),
        }
    }

    #[test]
    fn test_causalize_already_causal() {
        let result = causalize_equation(&make_var("x"), &make_var("y"), "x");
        assert!(result.is_none(), "Already causal, should return None");
    }

    #[test]
    fn test_causalize_swaps_sides() {
        // y = x  =>  x = y
        let (lhs, rhs) = expect_simple(causalize_equation(&make_var("y"), &make_var("x"), "x"));
        assert!(is_var(&lhs, "x"), "LHS should be the solved variable");
        assert!(is_var(&rhs, "y"), "RHS should be the former LHS");
    }

    #[test]
    fn test_causalize_sum_to_zero() {
        // a + b = 0  =>  a = -b
        let sum = binary(OpBinary::Add(Token::default()), make_var("a"), make_var("b"));
        let (lhs, rhs) = expect_simple(causalize_equation(&sum, &make_zero(), "a"));
        assert!(is_var(&lhs, "a"), "LHS should be a simple variable");
        assert!(
            matches!(&rhs, Expression::Unary { op: OpUnary::Minus(_), rhs: inner } if is_var(inner, "b")),
            "RHS should be -b"
        );
    }

    #[test]
    fn test_causalize_three_term_sum() {
        // (a + b) + c = 0  =>  a = -(b + c)
        let inner_sum = binary(OpBinary::Add(Token::default()), make_var("a"), make_var("b"));
        let lhs = binary(OpBinary::Add(Token::default()), inner_sum, make_var("c"));
        let (solved, value) = expect_simple(causalize_equation(&lhs, &make_zero(), "a"));
        assert!(is_var(&solved, "a"), "Should be able to solve for 'a'");
        let Expression::Unary {
            op: OpUnary::Minus(_),
            rhs: negated,
        } = &value
        else {
            panic!("RHS should be a negation");
        };
        assert!(
            matches!(
                negated.as_ref(),
                Expression::Binary { op: OpBinary::Add(_), lhs, rhs }
                    if is_var(lhs, "b") && is_var(rhs, "c")
            ),
            "RHS should be -(b + c)"
        );
    }

    #[test]
    fn test_causalize_zero_on_lhs() {
        // 0 = a + b  =>  a = -b
        let sum = binary(OpBinary::Add(Token::default()), make_var("a"), make_var("b"));
        let (lhs, rhs) = expect_simple(causalize_equation(&make_zero(), &sum, "a"));
        assert!(is_var(&lhs, "a"), "Should handle zero on LHS");
        assert!(
            matches!(&rhs, Expression::Unary { op: OpUnary::Minus(_), rhs: inner } if is_var(inner, "b")),
            "RHS should be -b"
        );
    }

    #[test]
    fn test_causalize_product_divides() {
        // 2 * x = y  =>  x = y / 2
        let product = binary(OpBinary::Mul(Token::default()), make_real("2"), make_var("x"));
        let (solved, value) = expect_simple(causalize_equation(&product, &make_var("y"), "x"));
        assert!(is_var(&solved, "x"));
        let Expression::Binary {
            op: OpBinary::Div(_),
            lhs: numerator,
            rhs: denominator,
        } = &value
        else {
            panic!("RHS should be a division");
        };
        assert!(is_var(numerator, "y"), "numerator should be y");
        assert!(
            matches!(denominator.as_ref(), Expression::Terminal { token, .. } if token.text == "2"),
            "denominator should be the literal 2"
        );
    }

    #[test]
    fn test_causalize_subtrahend() {
        // a - b = c  =>  b = a - c
        let difference = binary(OpBinary::Sub(Token::default()), make_var("a"), make_var("b"));
        let (solved, value) = expect_simple(causalize_equation(&difference, &make_var("c"), "b"));
        assert!(is_var(&solved, "b"));
        assert!(
            matches!(
                &value,
                Expression::Binary { op: OpBinary::Sub(_), lhs, rhs }
                    if is_var(lhs, "a") && is_var(rhs, "c")
            ),
            "RHS should be a - c"
        );
    }

    #[test]
    fn test_causalize_absent_variable() {
        // a * b = c has no z in it.
        let product = binary(OpBinary::Mul(Token::default()), make_var("a"), make_var("b"));
        assert!(causalize_equation(&product, &make_var("c"), "z").is_none());
    }

    #[test]
    fn test_extract_linear_term_negated_var() {
        let negated = Expression::Unary {
            op: OpUnary::Minus(Token::default()),
            rhs: Box::new(make_var("x")),
        };
        let (coeff, rest) =
            extract_linear_term(&negated, "x").expect("-x is linear in x");
        assert!((coeff + 1.0).abs() < 1e-12, "coefficient of -x should be -1");
        assert!(is_zero_expression(&rest), "remainder of -x should be 0");
        assert!(extract_linear_term(&negated, "y").is_none());
    }

    #[test]
    fn test_negate_expression_swaps_subtraction() {
        let difference = binary(OpBinary::Sub(Token::default()), make_var("a"), make_var("b"));
        assert!(matches!(
            negate_expression(&difference),
            Expression::Binary { op: OpBinary::Sub(_), lhs, rhs }
                if is_var(&lhs, "b") && is_var(&rhs, "a")
        ));
    }
}