Skip to main content

differential_equations/ode/
ode.rs

//! Defines the system of differential equations used by the numerical solvers.
//! The numerical methods use the `ODE` trait to take an input system from the user and solve it;
//! the trait describes the differential equation itself. Event handling is provided via the
//! separate `Event` trait.
4
5use crate::{
6    linalg::Matrix,
7    traits::{DefaultState, Real, State},
8};
9
10/// ODE Trait for Differential Equations
11///
12/// ODE trait defines the differential equation dydt = f(t, y) for the solver.
13/// The differential equation is used to solve the ordinary differential equation.
14///
15/// # Impl
16/// * `diff`    - Differential Equation dydt = f(t, y) in form f(t, &y, &mut dydt).
17/// * `jacobian` - Jacobian matrix J = df/dy for the system of equations.
18///
19/// Note that the jacobian function is optional and can be left out when implementing.
20pub trait ODE<T = f64, Y = DefaultState<T>>
21where
22    T: Real,
23    Y: State<T>,
24{
25    /// Differential Equation dydt = f(t, y)
26    ///
27    /// An ordinary differential equation (ODE) takes a independent variable
28    /// which in this case is 't' as it is typically time and a dependent variable
29    /// which is a vector of values 'y'. The ODE returns the derivative of the
30    /// dependent variable 'y' with respect to the independent variable 't' as
31    /// dydt = f(t, y).
32    ///
33    /// For efficiency and ergonomics the derivative is calculated from an argument
34    /// of a mutable reference to the derivative vector dydt. This allows for a
35    /// derivatives to be calculated in place which is more efficient as iterative
36    /// ODE solvers require the derivative to be calculated at each step without
37    /// regard to the previous value.
38    ///
39    /// # Arguments
40    /// * `t`    - Independent variable point.
41    /// * `y`    - Dependent variable point.
42    /// * `dydt` - Derivative point.
43    ///
44    fn diff(&self, t: T, y: &Y, dydt: &mut Y);
45
46    /// jacobian matrix J = df/dy
47    ///
48    /// The jacobian matrix is a matrix of partial derivatives of a vector-valued function.
49    /// It describes the local behavior of the system of equations and can be used to improve
50    /// the efficiency of certain solvers by providing information about the local behavior
51    /// of the system of equations.
52    ///
53    /// By default, this method uses a finite difference approximation.
54    /// Users can override this with an analytical implementation for better efficiency.
55    ///
56    /// # Arguments
57    /// * `t` - Independent variable grid point.
58    /// * `y` - Dependent variable vector.
59    /// * `j` - jacobian matrix. This matrix should be pre-sized by the caller to `dim x dim` where `dim = y.len()`.
60    ///
61    fn jacobian(&self, t: T, y: &Y, j: &mut Matrix<T>) {
62        // Default implementation using forward finite differences
63        let dim = y.len();
64        let mut y_perturbed = y.clone();
65        let mut f_perturbed = y.zeros_like();
66        let mut f_origin = y.zeros_like();
67
68        // Compute the unperturbed derivative
69        self.diff(t, y, &mut f_origin);
70
71        // Use sqrt of machine epsilon for finite differences
72        let eps = T::default_epsilon().sqrt();
73
74        // For each column of the jacobian
75        for j_col in 0..dim {
76            // Get the original value
77            let y_original_j = y.get_component(j_col);
78
79            // Calculate perturbation size (max of component magnitude or 1.0)
80            let perturbation = eps * y_original_j.abs().max(T::one());
81
82            // Perturb the component
83            y_perturbed.copy_from_state(y);
84            y_perturbed.set_component(j_col, y_original_j + perturbation);
85
86            // Evaluate function with perturbed value
87            self.diff(t, &y_perturbed, &mut f_perturbed);
88
89            // Compute finite difference approximation for this column
90            for i_row in 0..dim {
91                j[(i_row, j_col)] = (f_perturbed.get_component(i_row)
92                    - f_origin.get_component(i_row))
93                    / perturbation;
94            }
95        }
96    }
97}
98
99impl<EqType, T: Real, Y: State<T>> ODE<T, Y> for &EqType
100where
101    EqType: ODE<T, Y> + ?Sized,
102{
103    fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
104        (*self).diff(t, y, dydt);
105    }
106
107    fn jacobian(&self, t: T, y: &Y, j: &mut Matrix<T>) {
108        (*self).jacobian(t, y, j);
109    }
110}