differential_equations/ode/ode.rs
//! Defines systems of differential equations for numerical solvers.
//! The numerical methods use this trait to take an input system from the user and solve it.
//! Includes the differential equation itself and an optional Jacobian, which falls back
//! to a finite-difference approximation when it is not provided.
5
6use crate::{
7 linalg::Matrix,
8 traits::{Real, State},
9};
10
11/// ODE Trait for Differential Equations
12///
13/// ODE trait defines the differential equation dydt = f(t, y) for the solver.
14/// The differential equation is used to solve the ordinary differential equation.
15/// The trait also includes a solout function to interupt the solver when a condition
16/// is met or event occurs.
17///
18/// # Impl
19/// * `diff` - Differential Equation dydt = f(t, y) in form f(t, &y, &mut dydt).
20/// * `event` - Event function to interupt solver when condition is met or event occurs.
21/// * `jacobian` - Jacobian matrix J = df/dy for the system of equations.
22///
23/// Note that the event and jacobian functions are optional and can be left out when implementing.
24///
25#[allow(unused_variables)]
26pub trait ODE<T = f64, Y = f64>
27where
28 T: Real,
29 Y: State<T>,
30{
31 /// Differential Equation dydt = f(t, y)
32 ///
33 /// An ordinary differential equation (ODE) takes a independent variable
34 /// which in this case is 't' as it is typically time and a dependent variable
35 /// which is a vector of values 'y'. The ODE returns the derivative of the
36 /// dependent variable 'y' with respect to the independent variable 't' as
37 /// dydt = f(t, y).
38 ///
39 /// For efficiency and ergonomics the derivative is calculated from an argument
40 /// of a mutable reference to the derivative vector dydt. This allows for a
41 /// derivatives to be calculated in place which is more efficient as iterative
42 /// ODE solvers require the derivative to be calculated at each step without
43 /// regard to the previous value.
44 ///
45 /// # Arguments
46 /// * `t` - Independent variable point.
47 /// * `y` - Dependent variable point.
48 /// * `dydt` - Derivative point.
49 ///
50 fn diff(&self, t: T, y: &Y, dydt: &mut Y);
51
52 /// jacobian matrix J = df/dy
53 ///
54 /// The jacobian matrix is a matrix of partial derivatives of a vector-valued function.
55 /// It describes the local behavior of the system of equations and can be used to improve
56 /// the efficiency of certain solvers by providing information about the local behavior
57 /// of the system of equations.
58 ///
59 /// By default, this method uses a finite difference approximation.
60 /// Users can override this with an analytical implementation for better efficiency.
61 ///
62 /// # Arguments
63 /// * `t` - Independent variable grid point.
64 /// * `y` - Dependent variable vector.
65 /// * `j` - jacobian matrix. This matrix should be pre-sized by the caller to `dim x dim` where `dim = y.len()`.
66 ///
67 fn jacobian(&self, t: T, y: &Y, j: &mut Matrix<T>) {
68 // Default implementation using forward finite differences
69 let dim = y.len();
70 let mut y_perturbed = *y;
71 let mut f_perturbed = Y::zeros();
72 let mut f_origin = Y::zeros();
73
74 // Compute the unperturbed derivative
75 self.diff(t, y, &mut f_origin);
76
77 // Use sqrt of machine epsilon for finite differences
78 let eps = T::default_epsilon().sqrt();
79
80 // For each column of the jacobian
81 for j_col in 0..dim {
82 // Get the original value
83 let y_original_j = y.get(j_col);
84
85 // Calculate perturbation size (max of component magnitude or 1.0)
86 let perturbation = eps * y_original_j.abs().max(T::one());
87
88 // Perturb the component
89 y_perturbed.set(j_col, y_original_j + perturbation);
90
91 // Evaluate function with perturbed value
92 self.diff(t, &y_perturbed, &mut f_perturbed);
93
94 // Restore original value
95 y_perturbed.set(j_col, y_original_j);
96
97 // Compute finite difference approximation for this column
98 for i_row in 0..dim {
99 j[(i_row, j_col)] = (f_perturbed.get(i_row) - f_origin.get(i_row)) / perturbation;
100 }
101 }
102 }
103}