use crate::ir::ast::{
ComponentRefPart, ComponentReference, Equation, Expression, OpBinary, TerminalType, Token,
};
/// Differentiates both sides of an [`Equation::Simple`] using [`differentiate_expression`].
///
/// The result is a new `Simple` equation whose `lhs` and `rhs` are the derivatives of the
/// original sides. Returns `None` for every other kind of equation, which is not handled here.
pub fn differentiate_equation(equation: &Equation) -> Option<Equation> {
    match equation {
        Equation::Simple { lhs, rhs, .. } => {
            let lhs = differentiate_expression(lhs);
            let rhs = differentiate_expression(rhs);
            Some(Equation::Simple { lhs, rhs })
        }
        _ => None,
    }
}
pub fn differentiate_expression(expr: &Expression) -> Expression {
match expr {
Expression::ComponentReference(cref) => {
Expression::FunctionCall {
comp: ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: "der".to_string(),
..Default::default()
},
subs: None,
}],
},
args: vec![Expression::ComponentReference(cref.clone())],
}
}
Expression::Binary { lhs, op, rhs } => {
match op {
OpBinary::Add(_) | OpBinary::Sub(_) => {
Expression::Binary {
lhs: Box::new(differentiate_expression(lhs)),
op: op.clone(),
rhs: Box::new(differentiate_expression(rhs)),
}
}
OpBinary::Mul(_) => {
let da = differentiate_expression(lhs);
let db = differentiate_expression(rhs);
Expression::Binary {
lhs: Box::new(Expression::Binary {
lhs: Box::new(da),
op: op.clone(),
rhs: rhs.clone(),
}),
op: OpBinary::Add(Token::default()),
rhs: Box::new(Expression::Binary {
lhs: lhs.clone(),
op: op.clone(),
rhs: Box::new(db),
}),
}
}
OpBinary::Div(_) => {
let da = differentiate_expression(lhs);
let db = differentiate_expression(rhs);
Expression::Binary {
lhs: Box::new(Expression::Binary {
lhs: Box::new(Expression::Binary {
lhs: Box::new(da),
op: OpBinary::Mul(Token::default()),
rhs: rhs.clone(),
}),
op: OpBinary::Sub(Token::default()),
rhs: Box::new(Expression::Binary {
lhs: lhs.clone(),
op: OpBinary::Mul(Token::default()),
rhs: Box::new(db),
}),
}),
op: OpBinary::Div(Token::default()),
rhs: Box::new(Expression::Binary {
lhs: rhs.clone(),
op: OpBinary::Mul(Token::default()),
rhs: rhs.clone(),
}),
}
}
_ => {
wrap_in_der(expr)
}
}
}
Expression::Terminal { terminal_type, .. } => {
match terminal_type {
TerminalType::UnsignedInteger | TerminalType::UnsignedReal => {
Expression::Terminal {
terminal_type: TerminalType::UnsignedInteger,
token: Token {
text: "0".to_string(),
..Default::default()
},
}
}
_ => wrap_in_der(expr),
}
}
Expression::FunctionCall { comp, args } => {
if comp.to_string() == "der" {
Expression::FunctionCall {
comp: comp.clone(),
args: args.iter().map(differentiate_expression).collect(),
}
} else {
wrap_in_der(expr)
}
}
Expression::Unary { op, rhs } => {
Expression::Unary {
op: op.clone(),
rhs: Box::new(differentiate_expression(rhs)),
}
}
Expression::Array {
elements,
is_matrix,
} => {
Expression::Array {
elements: elements.iter().map(differentiate_expression).collect(),
is_matrix: *is_matrix,
}
}
Expression::Tuple { elements } => {
Expression::Tuple {
elements: elements.iter().map(differentiate_expression).collect(),
}
}
Expression::Range { .. } | Expression::If { .. } | Expression::Empty => {
wrap_in_der(expr)
}
Expression::Parenthesized { inner } => {
Expression::Parenthesized {
inner: Box::new(differentiate_expression(inner)),
}
}
Expression::ArrayComprehension { expr, indices } => {
Expression::ArrayComprehension {
expr: Box::new(differentiate_expression(expr)),
indices: indices.clone(),
}
}
}
}
/// Returns the unevaluated derivative `der(expr)`.
///
/// The result is a call to `der` (a non-local, unsubscripted, single-part reference)
/// whose only argument is a copy of `expr`.
fn wrap_in_der(expr: &Expression) -> Expression {
    let der_name = Token {
        text: String::from("der"),
        ..Default::default()
    };
    let callee = ComponentReference {
        local: false,
        parts: vec![ComponentRefPart {
            ident: der_name,
            subs: None,
        }],
    };
    Expression::FunctionCall {
        comp: callee,
        args: vec![expr.clone()],
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the symbolic differentiation rules. Expressions are inspected by shape
    // (`matches!`) and via `ComponentReference::to_string()` for names.
    use super::*;
    use crate::ir::ast::OpUnary;

    /// Builds a single-part, non-local component reference named `name`.
    fn make_var(name: &str) -> Expression {
        Expression::ComponentReference(ComponentReference {
            local: false,
            parts: vec![ComponentRefPart {
                ident: Token {
                    text: name.to_string(),
                    ..Default::default()
                },
                subs: None,
            }],
        })
    }

    /// Builds the call `der(var)`.
    fn make_der(var: Expression) -> Expression {
        Expression::FunctionCall {
            comp: ComponentReference {
                local: false,
                parts: vec![ComponentRefPart {
                    ident: Token {
                        text: "der".to_string(),
                        ..Default::default()
                    },
                    subs: None,
                }],
            },
            args: vec![var],
        }
    }

    /// Builds an unsigned-integer literal with the given text.
    fn make_const(val: &str) -> Expression {
        Expression::Terminal {
            terminal_type: TerminalType::UnsignedInteger,
            token: Token {
                text: val.to_string(),
                ..Default::default()
            },
        }
    }

    /// Builds the binary node `lhs op rhs`.
    fn make_binary(lhs: Expression, op: OpBinary, rhs: Expression) -> Expression {
        Expression::Binary {
            lhs: Box::new(lhs),
            op,
            rhs: Box::new(rhs),
        }
    }

    #[test]
    fn test_differentiate_variable() {
        match differentiate_expression(&make_var("x")) {
            Expression::FunctionCall { comp, args } => {
                assert_eq!(comp.to_string(), "der");
                assert_eq!(args.len(), 1);
                assert!(
                    matches!(&args[0], Expression::ComponentReference(cref) if cref.to_string() == "x"),
                    "der() should wrap the original reference"
                );
            }
            _ => panic!("Expected function call"),
        }
    }

    #[test]
    fn test_differentiate_constant() {
        match differentiate_expression(&make_const("5")) {
            Expression::Terminal { token, .. } => assert_eq!(token.text, "0"),
            _ => panic!("Expected terminal"),
        }
    }

    #[test]
    fn test_differentiate_sum() {
        let expr = make_binary(make_var("x"), OpBinary::Add(Token::default()), make_var("y"));
        let diff = differentiate_expression(&expr);
        assert!(
            matches!(
                &diff,
                Expression::Binary { op: OpBinary::Add(_), lhs, rhs }
                    if matches!(lhs.as_ref(), Expression::FunctionCall { .. })
                        && matches!(rhs.as_ref(), Expression::FunctionCall { .. })
            ),
            "Expected der(x) + der(y)"
        );
    }

    #[test]
    fn test_differentiate_product() {
        let expr = make_binary(make_var("x"), OpBinary::Mul(Token::default()), make_var("y"));
        let diff = differentiate_expression(&expr);
        assert!(
            matches!(
                &diff,
                Expression::Binary { op: OpBinary::Add(_), lhs, rhs }
                    if matches!(lhs.as_ref(), Expression::Binary { op: OpBinary::Mul(_), .. })
                        && matches!(rhs.as_ref(), Expression::Binary { op: OpBinary::Mul(_), .. })
            ),
            "Expected der(x)*y + x*der(y) from the product rule"
        );
    }

    #[test]
    fn test_differentiate_quotient() {
        let expr = make_binary(make_var("x"), OpBinary::Div(Token::default()), make_var("y"));
        let diff = differentiate_expression(&expr);
        assert!(
            matches!(diff, Expression::Binary { op: OpBinary::Div(_), .. }),
            "Expected a division from the quotient rule"
        );
    }

    #[test]
    fn test_differentiate_negation() {
        let expr = Expression::Unary {
            op: OpUnary::Minus(Token::default()),
            rhs: Box::new(make_var("x")),
        };
        match differentiate_expression(&expr) {
            Expression::Unary {
                op: OpUnary::Minus(_),
                rhs,
            } => match *rhs {
                Expression::FunctionCall { comp, .. } => assert_eq!(comp.to_string(), "der"),
                _ => panic!("Expected der() inside negation"),
            },
            _ => panic!("Expected unary negation"),
        }
    }

    #[test]
    fn test_differentiate_der() {
        match differentiate_expression(&make_der(make_var("x"))) {
            Expression::FunctionCall { comp, args } => {
                assert_eq!(comp.to_string(), "der");
                assert_eq!(args.len(), 1);
                match &args[0] {
                    Expression::FunctionCall {
                        comp: inner_comp,
                        args: inner_args,
                    } => {
                        assert_eq!(inner_comp.to_string(), "der");
                        assert!(
                            matches!(inner_args.first(), Some(Expression::ComponentReference(cref)) if cref.to_string() == "x"),
                            "Expected der(der(x))"
                        );
                    }
                    _ => panic!("Expected nested der()"),
                }
            }
            _ => panic!("Expected function call"),
        }
    }

    #[test]
    fn test_differentiate_equation() {
        let eq = Equation::Simple {
            lhs: make_var("x"),
            rhs: make_var("y"),
        };
        match differentiate_equation(&eq) {
            Some(Equation::Simple { lhs, rhs }) => {
                assert!(matches!(lhs, Expression::FunctionCall { .. }));
                assert!(matches!(rhs, Expression::FunctionCall { .. }));
            }
            _ => panic!("Expected a differentiated simple equation"),
        }
    }
}