use super::expression::{SymbolicExpression, Variable};
use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::collections::HashMap;
use SymbolicExpression::{Add, Constant, Cos, Div, Exp, Ln, Mul, Neg, Pow, Sin, Sqrt, Sub, Var};
/// A scalar ODE of arbitrary order, `y^(n) = expression`.
///
/// `expression` is the right-hand side for the highest derivative. Inside it,
/// lower derivatives of the dependent variable are referred to by name: `y`
/// for the value, `y'` for the first derivative and `y^(k)` for the `k`-th
/// derivative (`k >= 2`). [`higher_order_to_first_order`] rewrites those names
/// into state variables.
#[derive(Debug, Clone)]
pub struct HigherOrderODE<F: IntegrateFloat> {
    /// Order `n` of the highest derivative. [`HigherOrderODE::new`] rejects 0,
    /// but the field is public, so consumers should not assume it is non-zero.
    pub order: usize,
    /// Name of the unknown function (e.g. `"x"`).
    pub dependent_var: String,
    /// Name of the independent variable (e.g. `"t"`).
    pub independent_var: String,
    /// Right-hand side giving the `order`-th derivative.
    pub expression: SymbolicExpression<F>,
}

impl<F: IntegrateFloat> HigherOrderODE<F> {
    /// Builds an ODE `dependent_var^(order) = expression`.
    ///
    /// # Errors
    ///
    /// Returns [`IntegrateError::ValueError`] when `order` is zero.
    pub fn new(
        order: usize,
        dependent_var: impl Into<String>,
        independent_var: impl Into<String>,
        expression: SymbolicExpression<F>,
    ) -> IntegrateResult<Self> {
        if order == 0 {
            return Err(IntegrateError::ValueError(
                "ODE order must be at least 1".to_string(),
            ));
        }
        let dependent_var = dependent_var.into();
        let independent_var = independent_var.into();
        Ok(Self {
            order,
            dependent_var,
            independent_var,
            expression,
        })
    }

    /// Returns the state variables `y[0], …, y[order - 1]`.
    ///
    /// `y[i]` is `Variable::indexed(dependent_var, i)` and stands for the
    /// `i`-th derivative of the dependent variable.
    pub fn state_variables(&self) -> Vec<Variable> {
        let mut vars = Vec::with_capacity(self.order);
        for derivative in 0..self.order {
            vars.push(Variable::indexed(&self.dependent_var, derivative));
        }
        vars
    }
}
/// A system of first-order ODEs in symbolic form.
///
/// `expressions[i]` is the right-hand side for the derivative of
/// `state_vars[i]`.
#[derive(Debug, Clone)]
pub struct FirstOrderSystem<F: IntegrateFloat> {
    /// State variables, in the order used for the numeric state vector.
    pub state_vars: Vec<Variable>,
    /// Right-hand sides. The converters in this module produce exactly one per
    /// state variable.
    pub expressions: Vec<SymbolicExpression<F>>,
    /// Maps derivative notation (`"y"`, `"y'"`, `"y^(k)"`) to the state
    /// variable that holds that derivative.
    pub variable_map: HashMap<String, Variable>,
}

impl<F: IntegrateFloat> FirstOrderSystem<F> {
    /// Compiles the system into a closure `(t, y) -> dy/dt` for numeric solvers.
    ///
    /// The closure owns clones of the expressions and state variables, so it
    /// does not borrow `self`. On each call, `y[i]` is bound to `state_vars[i]`,
    /// and `t` is bound to the unindexed variable named `"t"`, whatever the
    /// original ODE called its independent variable. The result has one entry
    /// per expression, in order.
    ///
    /// # Errors
    ///
    /// The closure returns [`IntegrateError::DimensionMismatch`] when `y` does
    /// not have exactly one entry per state variable. It also propagates the
    /// first error raised while evaluating an expression.
    pub fn to_function(&self) -> impl Fn(F, ArrayView1<F>) -> IntegrateResult<Array1<F>> {
        let expressions = self.expressions.clone();
        let state_vars = self.state_vars.clone();
        move |t: F, y: ArrayView1<F>| {
            if y.len() != state_vars.len() {
                return Err(IntegrateError::DimensionMismatch(format!(
                    "Expected {} states, got {}",
                    state_vars.len(),
                    y.len()
                )));
            }

            // Bind every state variable to its current value, then time.
            let mut bindings: HashMap<Variable, F> = HashMap::with_capacity(state_vars.len() + 1);
            bindings.extend(state_vars.iter().cloned().zip(y.iter().copied()));
            bindings.insert(Variable::new("t"), t);

            let mut derivatives = Array1::zeros(expressions.len());
            for (slot, expr) in derivatives.iter_mut().zip(&expressions) {
                *slot = expr.evaluate(&bindings)?;
            }
            Ok(derivatives)
        }
    }
}
#[allow(dead_code)]
pub fn higher_order_to_first_order<F: IntegrateFloat>(
ode: &HigherOrderODE<F>,
) -> IntegrateResult<FirstOrderSystem<F>> {
use SymbolicExpression::*;
let mut state_vars = Vec::new();
let mut expressions = Vec::new();
let mut variable_map = HashMap::new();
for i in 0..ode.order {
let var = Variable::indexed(&ode.dependent_var, i);
state_vars.push(var.clone());
let deriv_notation = match i {
0 => ode.dependent_var.clone(),
1 => format!("{}'", ode.dependent_var),
n => format!("{}^({})", ode.dependent_var, n),
};
variable_map.insert(deriv_notation, var);
}
for i in 0..ode.order - 1 {
expressions.push(Var(state_vars[i + 1].clone()));
}
let mut highest_deriv_expr = ode.expression.clone();
highest_deriv_expr = substitute_derivatives(&highest_deriv_expr, &variable_map);
expressions.push(highest_deriv_expr);
Ok(FirstOrderSystem {
state_vars,
expressions,
variable_map,
})
}
/// Recursively replaces variables in `expr` with the state variables they map
/// to in `variable_map`.
///
/// Only *unindexed* variables are looked up, by name. Indexed variables, such
/// as the state variables `Variable::indexed(name, i)`, are left untouched.
/// They share their `name` with the plain dependent variable, so a name-only
/// lookup would silently rewrite e.g. `x[1]` into `x[0]`.
///
/// Constants, and any expression kinds not matched below, are cloned unchanged.
// NOTE(review): if `SymbolicExpression` has variants beyond those matched here,
// confirm that variables nested inside them need no substitution.
#[allow(dead_code)]
fn substitute_derivatives<F: IntegrateFloat>(
    expr: &SymbolicExpression<F>,
    variable_map: &HashMap<String, Variable>,
) -> SymbolicExpression<F> {
    // Substitute into a child and re-box it for the rebuilt parent node.
    let sub =
        |child: &SymbolicExpression<F>| Box::new(substitute_derivatives(child, variable_map));
    match expr {
        Var(v) if v.index.is_none() => match variable_map.get(&v.name) {
            Some(state_var) => Var(state_var.clone()),
            None => expr.clone(),
        },
        Add(a, b) => Add(sub(a), sub(b)),
        Sub(a, b) => Sub(sub(a), sub(b)),
        Mul(a, b) => Mul(sub(a), sub(b)),
        Div(a, b) => Div(sub(a), sub(b)),
        Pow(a, b) => Pow(sub(a), sub(b)),
        Neg(a) => Neg(sub(a)),
        Sin(a) => Sin(sub(a)),
        Cos(a) => Cos(sub(a)),
        Exp(a) => Exp(sub(a)),
        Ln(a) => Ln(sub(a)),
        Sqrt(a) => Sqrt(sub(a)),
        // Constants, indexed variables and any other variants are kept as-is.
        _ => expr.clone(),
    }
}
/// Builds the first-order form of the damped oscillator
/// `x'' = -(2·damping·x' + omega²·x)` in the variable `t`.
///
/// # Errors
///
/// Propagates errors from [`HigherOrderODE::new`] and
/// [`higher_order_to_first_order`].
#[allow(dead_code)]
pub fn example_damped_oscillator<F: IntegrateFloat>(
    omega: F,
    damping: F,
) -> IntegrateResult<FirstOrderSystem<F>> {
    let two = F::from(2.0).expect("Failed to convert constant to float");

    // Friction term: (2 · damping) · x'
    let friction = Mul(
        Box::new(Mul(Box::new(Constant(two)), Box::new(Constant(damping)))),
        Box::new(Var(Variable::new("x'"))),
    );
    // Restoring term: omega^2 · x
    let restoring = Mul(
        Box::new(Pow(Box::new(Constant(omega)), Box::new(Constant(two)))),
        Box::new(Var(Variable::new("x"))),
    );
    let rhs = Neg(Box::new(Add(Box::new(friction), Box::new(restoring))));

    HigherOrderODE::new(2, "x", "t", rhs).and_then(|ode| higher_order_to_first_order(&ode))
}
#[allow(dead_code)]
pub fn example_driven_pendulum<F: IntegrateFloat>(
g: F, l: F, gamma: F, a: F, omega: F, ) -> IntegrateResult<FirstOrderSystem<F>> {
let theta = SymbolicExpression::var("θ");
let theta_prime = SymbolicExpression::var("θ'");
let t = SymbolicExpression::var("t");
let g_over_l = SymbolicExpression::constant(g / l);
let gamma_const = SymbolicExpression::constant(gamma);
let a_const = SymbolicExpression::constant(a);
let omega_const = SymbolicExpression::constant(omega);
let damping_term = -gamma_const * theta_prime;
let gravity_term = -g_over_l * SymbolicExpression::Sin(Box::new(theta));
let driving_term = a_const * SymbolicExpression::Cos(Box::new(omega_const * t));
let expression = damping_term + gravity_term + driving_term;
let ode = HigherOrderODE::new(2, "θ", "t", expression)?;
higher_order_to_first_order(&ode)
}
/// Builds the first-order form of the beam equation `w'''' = f / EI` in the
/// variable `x`.
///
/// `_rho_a` is accepted but currently ignored: the right-hand side is the
/// constant `f / ei`, with no inertia term.
///
/// # Errors
///
/// Propagates errors from [`HigherOrderODE::new`] and
/// [`higher_order_to_first_order`].
#[allow(dead_code)]
pub fn example_euler_bernoulli_beam<F: IntegrateFloat>(
    ei: F,
    _rho_a: F,
    f: F,
) -> IntegrateResult<FirstOrderSystem<F>> {
    let load_over_stiffness = f / ei;
    HigherOrderODE::new(4, "w", "x", SymbolicExpression::constant(load_over_stiffness))
        .and_then(|ode| higher_order_to_first_order(&ode))
}
/// Collects several higher-order ODEs and stacks their first-order forms into
/// a single [`FirstOrderSystem`].
///
/// States and expressions appear in the order the ODEs were added.
///
/// Each ODE is converted on its own. Derivative names in one ODE's expression
/// (`y`, `y'`, `y^(k)`) are therefore resolved only against that ODE's own
/// dependent variable. Names belonging to another ODE are left unsubstituted.
// NOTE(review): coupled systems presumably need to reference other ODEs'
// states via indexed variables directly; confirm intended usage.
pub struct SystemConverter<F: IntegrateFloat> {
    /// ODEs in insertion order.
    odes: Vec<HigherOrderODE<F>>,
    /// Running sum of `order` over `odes`.
    total_states: usize,
}

impl<F: IntegrateFloat> SystemConverter<F> {
    /// Creates a converter with no ODEs.
    pub fn new() -> Self {
        Self {
            odes: Vec::new(),
            total_states: 0,
        }
    }

    /// Appends `ode` and returns `self` for chaining.
    pub fn add_ode(&mut self, ode: HigherOrderODE<F>) -> &mut Self {
        let added_states = ode.order;
        self.total_states += added_states;
        self.odes.push(ode);
        self
    }

    /// Converts every ODE with [`higher_order_to_first_order`] and
    /// concatenates the results.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while converting an individual ODE.
    pub fn convert(&self) -> IntegrateResult<FirstOrderSystem<F>> {
        let mut combined = FirstOrderSystem {
            state_vars: Vec::new(),
            expressions: Vec::new(),
            variable_map: HashMap::new(),
        };
        for ode in self.odes.iter() {
            let FirstOrderSystem {
                state_vars,
                expressions,
                variable_map,
            } = higher_order_to_first_order(ode)?;
            combined.state_vars.extend(state_vars);
            combined.expressions.extend(expressions);
            combined.variable_map.extend(variable_map);
        }
        Ok(combined)
    }
}

impl<F: IntegrateFloat> Default for SystemConverter<F> {
    /// Equivalent to [`SystemConverter::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    // The parent module already brings `SymbolicExpression`, `Variable` and
    // the `Neg`/`Var` variants into scope, so the glob import covers them.
    use super::*;

    /// `x'' = -x` must become a two-state system whose first equation is the
    /// plain chain `x[0]' = x[1]`.
    #[test]
    fn test_second_order_conversion() {
        let rhs: SymbolicExpression<f64> = Neg(Box::new(Var(Variable::new("x"))));
        let ode = HigherOrderODE::new(2, "x", "t", rhs).expect("Operation failed");
        let system = higher_order_to_first_order(&ode).expect("Operation failed");

        assert_eq!(system.state_vars.len(), 2);
        assert_eq!(system.expressions.len(), 2);

        match &system.expressions[0] {
            Var(v) => {
                assert_eq!(v.name, "x");
                assert_eq!(v.index, Some(1));
            }
            other => panic!("Expected variable expression, got {:?}", other),
        }
    }
}