scirs2_integrate/symbolic/
conversion.rs1use super::expression::{SymbolicExpression, Variable};
8use crate::common::IntegrateFloat;
9use crate::error::{IntegrateError, IntegrateResult};
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use std::collections::HashMap;
12use SymbolicExpression::{Add, Constant, Cos, Div, Exp, Ln, Mul, Neg, Pow, Sin, Sqrt, Sub, Var};
13
14pub struct HigherOrderODE<F: IntegrateFloat> {
16 pub order: usize,
18 pub dependent_var: String,
20 pub independent_var: String,
22 pub expression: SymbolicExpression<F>,
25}
26
27impl<F: IntegrateFloat> HigherOrderODE<F> {
28 pub fn new(
30 order: usize,
31 dependent_var: impl Into<String>,
32 independent_var: impl Into<String>,
33 expression: SymbolicExpression<F>,
34 ) -> IntegrateResult<Self> {
35 if order == 0 {
36 return Err(IntegrateError::ValueError(
37 "ODE order must be at least 1".to_string(),
38 ));
39 }
40
41 Ok(HigherOrderODE {
42 order,
43 dependent_var: dependent_var.into(),
44 independent_var: independent_var.into(),
45 expression,
46 })
47 }
48
49 pub fn state_variables(&self) -> Vec<Variable> {
51 (0..self.order)
52 .map(|i| Variable::indexed(&self.dependent_var, i))
53 .collect()
54 }
55}
56
57pub struct FirstOrderSystem<F: IntegrateFloat> {
59 pub state_vars: Vec<Variable>,
61 pub expressions: Vec<SymbolicExpression<F>>,
63 pub variable_map: HashMap<String, Variable>,
65}
66
67impl<F: IntegrateFloat> FirstOrderSystem<F> {
68 pub fn to_function(&self) -> impl Fn(F, ArrayView1<F>) -> IntegrateResult<Array1<F>> {
70 let expressions = self.expressions.clone();
71 let state_vars = self.state_vars.clone();
72
73 move |t: F, y: ArrayView1<F>| {
74 if y.len() != state_vars.len() {
75 return Err(IntegrateError::DimensionMismatch(format!(
76 "Expected {} states, got {}",
77 state_vars.len(),
78 y.len()
79 )));
80 }
81
82 let mut values = HashMap::new();
84 for (i, var) in state_vars.iter().enumerate() {
85 values.insert(var.clone(), y[i]);
86 }
87 values.insert(Variable::new("t"), t);
88
89 let mut result = Array1::zeros(expressions.len());
91 for (i, expr) in expressions.iter().enumerate() {
92 result[i] = expr.evaluate(&values)?;
93 }
94
95 Ok(result)
96 }
97 }
98}
99
100#[allow(dead_code)]
116pub fn higher_order_to_first_order<F: IntegrateFloat>(
117 ode: &HigherOrderODE<F>,
118) -> IntegrateResult<FirstOrderSystem<F>> {
119 use SymbolicExpression::*;
120
121 let mut state_vars = Vec::new();
122 let mut expressions = Vec::new();
123 let mut variable_map = HashMap::new();
124
125 for i in 0..ode.order {
127 let var = Variable::indexed(&ode.dependent_var, i);
128 state_vars.push(var.clone());
129
130 let deriv_notation = match i {
132 0 => ode.dependent_var.clone(),
133 1 => format!("{}'", ode.dependent_var),
134 n => format!("{}^({})", ode.dependent_var, n),
135 };
136 variable_map.insert(deriv_notation, var);
137 }
138
139 for i in 0..ode.order - 1 {
142 expressions.push(Var(state_vars[i + 1].clone()));
143 }
144
145 let mut highest_deriv_expr = ode.expression.clone();
147 highest_deriv_expr = substitute_derivatives(&highest_deriv_expr, &variable_map);
148 expressions.push(highest_deriv_expr);
149
150 Ok(FirstOrderSystem {
151 state_vars,
152 expressions,
153 variable_map,
154 })
155}
156
157#[allow(dead_code)]
159fn substitute_derivatives<F: IntegrateFloat>(
160 expr: &SymbolicExpression<F>,
161 variable_map: &HashMap<String, Variable>,
162) -> SymbolicExpression<F> {
163 match expr {
164 Var(v) => {
165 if let Some(state_var) = variable_map.get(&v.name) {
167 Var(state_var.clone())
168 } else {
169 expr.clone()
170 }
171 }
172 Add(a, b) => Add(
173 Box::new(substitute_derivatives(a, variable_map)),
174 Box::new(substitute_derivatives(b, variable_map)),
175 ),
176 Sub(a, b) => Sub(
177 Box::new(substitute_derivatives(a, variable_map)),
178 Box::new(substitute_derivatives(b, variable_map)),
179 ),
180 Mul(a, b) => Mul(
181 Box::new(substitute_derivatives(a, variable_map)),
182 Box::new(substitute_derivatives(b, variable_map)),
183 ),
184 Div(a, b) => Div(
185 Box::new(substitute_derivatives(a, variable_map)),
186 Box::new(substitute_derivatives(b, variable_map)),
187 ),
188 Pow(a, b) => Pow(
189 Box::new(substitute_derivatives(a, variable_map)),
190 Box::new(substitute_derivatives(b, variable_map)),
191 ),
192 Neg(a) => Neg(Box::new(substitute_derivatives(a, variable_map))),
193 Sin(a) => Sin(Box::new(substitute_derivatives(a, variable_map))),
194 Cos(a) => Cos(Box::new(substitute_derivatives(a, variable_map))),
195 Exp(a) => Exp(Box::new(substitute_derivatives(a, variable_map))),
196 Ln(a) => Ln(Box::new(substitute_derivatives(a, variable_map))),
197 Sqrt(a) => Sqrt(Box::new(substitute_derivatives(a, variable_map))),
198 _ => expr.clone(),
199 }
200}
201
202#[allow(dead_code)]
204pub fn example_damped_oscillator<F: IntegrateFloat>(
205 omega: F,
206 damping: F,
207) -> IntegrateResult<FirstOrderSystem<F>> {
208 let x = Var(Variable::new("x"));
212 let x_prime = Var(Variable::new("x'"));
213
214 let expression = Neg(Box::new(Add(
215 Box::new(Mul(
216 Box::new(Mul(
217 Box::new(Constant(F::from(2.0).unwrap())),
218 Box::new(Constant(damping)),
219 )),
220 Box::new(x_prime),
221 )),
222 Box::new(Mul(
223 Box::new(Pow(
224 Box::new(Constant(omega)),
225 Box::new(Constant(F::from(2.0).unwrap())),
226 )),
227 Box::new(x),
228 )),
229 )));
230
231 let ode = HigherOrderODE::new(2, "x", "t", expression)?;
232 higher_order_to_first_order(&ode)
233}
234
235#[allow(dead_code)]
237pub fn example_driven_pendulum<F: IntegrateFloat>(
238 g: F, l: F, gamma: F, a: F, omega: F, ) -> IntegrateResult<FirstOrderSystem<F>> {
244 let theta = SymbolicExpression::var("θ");
248 let theta_prime = SymbolicExpression::var("θ'");
249 let t = SymbolicExpression::var("t");
250
251 let g_over_l = SymbolicExpression::constant(g / l);
252 let gamma_const = SymbolicExpression::constant(gamma);
253 let a_const = SymbolicExpression::constant(a);
254 let omega_const = SymbolicExpression::constant(omega);
255
256 let damping_term = -gamma_const * theta_prime;
258 let gravity_term = -g_over_l * SymbolicExpression::Sin(Box::new(theta));
259 let driving_term = a_const * SymbolicExpression::Cos(Box::new(omega_const * t));
260
261 let expression = damping_term + gravity_term + driving_term;
262
263 let ode = HigherOrderODE::new(2, "θ", "t", expression)?;
264 higher_order_to_first_order(&ode)
265}
266
267#[allow(dead_code)]
269pub fn example_euler_bernoulli_beam<F: IntegrateFloat>(
270 ei: F, _rho_a: F, f: F, ) -> IntegrateResult<FirstOrderSystem<F>> {
274 let f_over_ei = SymbolicExpression::constant(f / ei);
279
280 let ode = HigherOrderODE::new(4, "w", "x", f_over_ei)?;
281 higher_order_to_first_order(&ode)
282}
283
284pub struct SystemConverter<F: IntegrateFloat> {
286 odes: Vec<HigherOrderODE<F>>,
287 total_states: usize,
288}
289
290impl<F: IntegrateFloat> SystemConverter<F> {
291 pub fn new() -> Self {
293 SystemConverter {
294 odes: Vec::new(),
295 total_states: 0,
296 }
297 }
298
299 pub fn add_ode(&mut self, ode: HigherOrderODE<F>) -> &mut Self {
301 self.total_states += ode.order;
302 self.odes.push(ode);
303 self
304 }
305
306 pub fn convert(&self) -> IntegrateResult<FirstOrderSystem<F>> {
308 let mut all_state_vars = Vec::new();
309 let mut all_expressions = Vec::new();
310 let mut all_variable_map = HashMap::new();
311
312 for ode in &self.odes {
313 let system = higher_order_to_first_order(ode)?;
314 all_state_vars.extend(system.state_vars);
315 all_expressions.extend(system.expressions);
316 all_variable_map.extend(system.variable_map);
317 }
318
319 Ok(FirstOrderSystem {
320 state_vars: all_state_vars,
321 expressions: all_expressions,
322 variable_map: all_variable_map,
323 })
324 }
325}
326
327impl<F: IntegrateFloat> Default for SystemConverter<F> {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        higher_order_to_first_order, HigherOrderODE, SymbolicExpression,
        SymbolicExpression::{Neg, Var},
        Variable,
    };

    /// For x'' = -x, the converted system has two states, and its first equation is
    /// just the next state, x_1.
    #[test]
    fn test_second_order_conversion() {
        let rhs: SymbolicExpression<f64> = Neg(Box::new(Var(Variable::new("x"))));
        let ode = HigherOrderODE::new(2, "x", "t", rhs).unwrap();
        let system = higher_order_to_first_order(&ode).unwrap();

        assert_eq!(system.state_vars.len(), 2);
        assert_eq!(system.expressions.len(), 2);

        match &system.expressions[0] {
            Var(v) => {
                assert_eq!(v.name, "x");
                assert_eq!(v.index, Some(1));
            }
            other => panic!("Expected variable expression, got {:?}", other),
        }
    }
}
365}