differential_equations/ode/ode.rs
//! Defines the system of differential equations consumed by the numerical solvers.
//! The NumericalMethods use this trait to take an input system from the user and solve it.
//! It includes the differential equation and an optional event function that can
//! interrupt the solver when a given condition is met or an event occurs.
5
6use crate::{
7 ControlFlag,
8 traits::{CallBackData, Real, State},
9};
10use nalgebra::DMatrix;
11
12/// ODE Trait for Differential Equations
13///
14/// ODE trait defines the differential equation dydt = f(t, y) for the solver.
15/// The differential equation is used to solve the ordinary differential equation.
16/// The trait also includes a solout function to interupt the solver when a condition
17/// is met or event occurs.
18///
19/// # Impl
20/// * `diff` - Differential Equation dydt = f(t, y) in form f(t, &y, &mut dydt).
21/// * `event` - Solout function to interupt solver when condition is met or event occurs.
22///
23/// Note that the solout function is optional and can be left out when implementing
24/// in which can it will be set to return false by default.
25///
26#[allow(unused_variables)]
27pub trait ODE<T = f64, V = f64, D = String>
28where
29 T: Real,
30 V: State<T>,
31 D: CallBackData,
32{
33 /// Differential Equation dydt = f(t, y)
34 ///
35 /// An ordinary differential equation (ODE) takes a independent variable
36 /// which in this case is 't' as it is typically time and a dependent variable
37 /// which is a vector of values 'y'. The ODE returns the derivative of the
38 /// dependent variable 'y' with respect to the independent variable 't' as
39 /// dydt = f(t, y).
40 ///
41 /// For efficiency and ergonomics the derivative is calculated from an argument
42 /// of a mutable reference to the derivative vector dydt. This allows for a
43 /// derivatives to be calculated in place which is more efficient as iterative
44 /// ODE solvers require the derivative to be calculated at each step without
45 /// regard to the previous value.
46 ///
47 /// # Arguments
48 /// * `t` - Independent variable point.
49 /// * `y` - Dependent variable point.
50 /// * `dydt` - Derivative point.
51 ///
52 fn diff(&self, t: T, y: &V, dydt: &mut V);
53
54 /// Event function to detect significant conditions during integration.
55 ///
56 /// Called after each step to detect events like threshold crossings,
57 /// singularities, or other mathematically/physically significant conditions.
58 /// Can be used to terminate integration when conditions occur.
59 ///
60 /// # Arguments
61 /// * `t` - Current independent variable point.
62 /// * `y` - Current dependent variable point.
63 ///
64 /// # Returns
65 /// * `ControlFlag` - Command to continue or stop solver.
66 ///
67 fn event(&self, t: T, y: &V) -> ControlFlag<T, V, D> {
68 ControlFlag::Continue
69 }
70
71 /// Jacobian matrix J = df/dy
72 ///
73 /// The Jacobian matrix is a matrix of partial derivatives of a vector-valued function.
74 /// It describes the local behavior of the system of equations and can be used to improve
75 /// the efficiency of certain solvers by providing information about the local behavior
76 /// of the system of equations.
77 ///
78 /// By default, this method uses a finite difference approximation.
79 /// Users can override this with an analytical implementation for better efficiency.
80 ///
81 /// # Arguments
82 /// * `t` - Independent variable grid point.
83 /// * `y` - Dependent variable vector.
84 /// * `j` - Jacobian matrix. This matrix should be pre-sized by the caller to `dim x dim` where `dim = y.len()`.
85 ///
86 fn jacobian(&self, t: T, y: &V, j: &mut DMatrix<T>) {
87 // Default implementation using forward finite differences
88 let dim = y.len();
89 let mut y_perturbed = y.clone();
90 let mut f_perturbed = V::zeros();
91 let mut f_origin = V::zeros();
92
93 // Compute the unperturbed derivative
94 self.diff(t, y, &mut f_origin);
95
96 // Use sqrt of machine epsilon for finite differences
97 let eps = T::default_epsilon().sqrt();
98
99 // For each column of the Jacobian
100 for j_col in 0..dim {
101 // Get the original value
102 let y_original_j = y.get(j_col);
103
104 // Calculate perturbation size (max of component magnitude or 1.0)
105 let perturbation = eps * y_original_j.abs().max(T::one());
106
107 // Perturb the component
108 y_perturbed.set(j_col, y_original_j + perturbation);
109
110 // Evaluate function with perturbed value
111 self.diff(t, &y_perturbed, &mut f_perturbed);
112
113 // Restore original value
114 y_perturbed.set(j_col, y_original_j);
115
116 // Compute finite difference approximation for this column
117 for i_row in 0..dim {
118 j[(i_row, j_col)] = (f_perturbed.get(i_row) - f_origin.get(i_row)) / perturbation;
119 }
120 }
121 }
122}