differential_equations/ode/ode.rs
//! Defines a system of differential equations for the numerical solvers.
//! The NumericalMethods use this trait to take an input system from the user and solve it.
//! The trait includes the differential equation itself. Event handling is provided via the separate `Event` trait.
4
5use crate::{
6 linalg::Matrix,
7 traits::{DefaultState, Real, State},
8};
9
10/// ODE Trait for Differential Equations
11///
12/// ODE trait defines the differential equation dydt = f(t, y) for the solver.
13/// The differential equation is used to solve the ordinary differential equation.
14///
15/// # Impl
16/// * `diff` - Differential Equation dydt = f(t, y) in form f(t, &y, &mut dydt).
17/// * `jacobian` - Jacobian matrix J = df/dy for the system of equations.
18///
19/// Note that the jacobian function is optional and can be left out when implementing.
20pub trait ODE<T = f64, Y = DefaultState<T>>
21where
22 T: Real,
23 Y: State<T>,
24{
25 /// Differential Equation dydt = f(t, y)
26 ///
27 /// An ordinary differential equation (ODE) takes a independent variable
28 /// which in this case is 't' as it is typically time and a dependent variable
29 /// which is a vector of values 'y'. The ODE returns the derivative of the
30 /// dependent variable 'y' with respect to the independent variable 't' as
31 /// dydt = f(t, y).
32 ///
33 /// For efficiency and ergonomics the derivative is calculated from an argument
34 /// of a mutable reference to the derivative vector dydt. This allows for a
35 /// derivatives to be calculated in place which is more efficient as iterative
36 /// ODE solvers require the derivative to be calculated at each step without
37 /// regard to the previous value.
38 ///
39 /// # Arguments
40 /// * `t` - Independent variable point.
41 /// * `y` - Dependent variable point.
42 /// * `dydt` - Derivative point.
43 ///
44 fn diff(&self, t: T, y: &Y, dydt: &mut Y);
45
46 /// jacobian matrix J = df/dy
47 ///
48 /// The jacobian matrix is a matrix of partial derivatives of a vector-valued function.
49 /// It describes the local behavior of the system of equations and can be used to improve
50 /// the efficiency of certain solvers by providing information about the local behavior
51 /// of the system of equations.
52 ///
53 /// By default, this method uses a finite difference approximation.
54 /// Users can override this with an analytical implementation for better efficiency.
55 ///
56 /// # Arguments
57 /// * `t` - Independent variable grid point.
58 /// * `y` - Dependent variable vector.
59 /// * `j` - jacobian matrix. This matrix should be pre-sized by the caller to `dim x dim` where `dim = y.len()`.
60 ///
61 fn jacobian(&self, t: T, y: &Y, j: &mut Matrix<T>) {
62 // Default implementation using forward finite differences
63 let dim = y.len();
64 let mut y_perturbed = y.clone();
65 let mut f_perturbed = y.zeros_like();
66 let mut f_origin = y.zeros_like();
67
68 // Compute the unperturbed derivative
69 self.diff(t, y, &mut f_origin);
70
71 // Use sqrt of machine epsilon for finite differences
72 let eps = T::default_epsilon().sqrt();
73
74 // For each column of the jacobian
75 for j_col in 0..dim {
76 // Get the original value
77 let y_original_j = y.get_component(j_col);
78
79 // Calculate perturbation size (max of component magnitude or 1.0)
80 let perturbation = eps * y_original_j.abs().max(T::one());
81
82 // Perturb the component
83 y_perturbed.copy_from_state(y);
84 y_perturbed.set_component(j_col, y_original_j + perturbation);
85
86 // Evaluate function with perturbed value
87 self.diff(t, &y_perturbed, &mut f_perturbed);
88
89 // Compute finite difference approximation for this column
90 for i_row in 0..dim {
91 j[(i_row, j_col)] = (f_perturbed.get_component(i_row)
92 - f_origin.get_component(i_row))
93 / perturbation;
94 }
95 }
96 }
97}
98
99impl<EqType, T: Real, Y: State<T>> ODE<T, Y> for &EqType
100where
101 EqType: ODE<T, Y> + ?Sized,
102{
103 fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
104 (*self).diff(t, y, dydt);
105 }
106
107 fn jacobian(&self, t: T, y: &Y, j: &mut Matrix<T>) {
108 (*self).jacobian(t, y, j);
109 }
110}