differential_equations/ode/
ode.rs

//! Defines the system of differential equations for the numerical solvers.
//! The NumericalMethods use this trait to take an input system from the user and solve it.
//! The trait consists of the differential equation itself and an optional Jacobian,
//! which defaults to a finite-difference approximation built from the differential equation.
5
6use crate::{
7    linalg::Matrix,
8    traits::{Real, State},
9};
10
11/// ODE Trait for Differential Equations
12///
13/// ODE trait defines the differential equation dydt = f(t, y) for the solver.
14/// The differential equation is used to solve the ordinary differential equation.
15/// The trait also includes a solout function to interupt the solver when a condition
16/// is met or event occurs.
17///
18/// # Impl
19/// * `diff`    - Differential Equation dydt = f(t, y) in form f(t, &y, &mut dydt).
20/// * `event`   - Event function to interupt solver when condition is met or event occurs.
21/// * `jacobian` - Jacobian matrix J = df/dy for the system of equations.
22///
23/// Note that the event and jacobian functions are optional and can be left out when implementing.
24///
25#[allow(unused_variables)]
26pub trait ODE<T = f64, Y = f64>
27where
28    T: Real,
29    Y: State<T>,
30{
31    /// Differential Equation dydt = f(t, y)
32    ///
33    /// An ordinary differential equation (ODE) takes a independent variable
34    /// which in this case is 't' as it is typically time and a dependent variable
35    /// which is a vector of values 'y'. The ODE returns the derivative of the
36    /// dependent variable 'y' with respect to the independent variable 't' as
37    /// dydt = f(t, y).
38    ///
39    /// For efficiency and ergonomics the derivative is calculated from an argument
40    /// of a mutable reference to the derivative vector dydt. This allows for a
41    /// derivatives to be calculated in place which is more efficient as iterative
42    /// ODE solvers require the derivative to be calculated at each step without
43    /// regard to the previous value.
44    ///
45    /// # Arguments
46    /// * `t`    - Independent variable point.
47    /// * `y`    - Dependent variable point.
48    /// * `dydt` - Derivative point.
49    ///
50    fn diff(&self, t: T, y: &Y, dydt: &mut Y);
51
52    /// jacobian matrix J = df/dy
53    ///
54    /// The jacobian matrix is a matrix of partial derivatives of a vector-valued function.
55    /// It describes the local behavior of the system of equations and can be used to improve
56    /// the efficiency of certain solvers by providing information about the local behavior
57    /// of the system of equations.
58    ///
59    /// By default, this method uses a finite difference approximation.
60    /// Users can override this with an analytical implementation for better efficiency.
61    ///
62    /// # Arguments
63    /// * `t` - Independent variable grid point.
64    /// * `y` - Dependent variable vector.
65    /// * `j` - jacobian matrix. This matrix should be pre-sized by the caller to `dim x dim` where `dim = y.len()`.
66    ///
67    fn jacobian(&self, t: T, y: &Y, j: &mut Matrix<T>) {
68        // Default implementation using forward finite differences
69        let dim = y.len();
70        let mut y_perturbed = *y;
71        let mut f_perturbed = Y::zeros();
72        let mut f_origin = Y::zeros();
73
74        // Compute the unperturbed derivative
75        self.diff(t, y, &mut f_origin);
76
77        // Use sqrt of machine epsilon for finite differences
78        let eps = T::default_epsilon().sqrt();
79
80        // For each column of the jacobian
81        for j_col in 0..dim {
82            // Get the original value
83            let y_original_j = y.get(j_col);
84
85            // Calculate perturbation size (max of component magnitude or 1.0)
86            let perturbation = eps * y_original_j.abs().max(T::one());
87
88            // Perturb the component
89            y_perturbed.set(j_col, y_original_j + perturbation);
90
91            // Evaluate function with perturbed value
92            self.diff(t, &y_perturbed, &mut f_perturbed);
93
94            // Restore original value
95            y_perturbed.set(j_col, y_original_j);
96
97            // Compute finite difference approximation for this column
98            for i_row in 0..dim {
99                j[(i_row, j_col)] = (f_perturbed.get(i_row) - f_origin.get(i_row)) / perturbation;
100            }
101        }
102    }
103}