// differential_equations/ode/methods/runge_kutta/explicit/fixed_step.rs

//! Fixed-step explicit Runge-Kutta methods for solving ordinary differential equations.

/// Macro to create a fixed-step explicit Runge-Kutta solver from a Butcher tableau stored in
/// fixed-size arrays.
///
/// # Arguments
///
/// * attributes: Any outer attributes (typically `///` doc comments) written before `name:`
///   are forwarded to the generated struct.
/// * `name`: Name of the solver struct to create
/// * `a`: Matrix of coefficients for intermediate stages (`[[f64; stages]; stages]`)
/// * `b`: Weights for final summation (`[f64; stages]`)
/// * `c`: Time offsets for each stage (`[f64; stages]`)
/// * `order`: Order of accuracy of the method
/// * `stages`: Number of stages in the method
///
/// A trailing comma after `stages` is accepted.
///
/// # Example
///
/// ```
/// use differential_equations::runge_kutta_method;
///
/// // Define classical RK4 method
/// runge_kutta_method!(
///     /// Classical 4th Order Runge-Kutta Method
///     name: RK4,
///     a: [[0.0, 0.0, 0.0, 0.0],
///         [0.5, 0.0, 0.0, 0.0],
///         [0.0, 0.5, 0.0, 0.0],
///         [0.0, 0.0, 1.0, 0.0]],
///     b: [1.0/6.0, 2.0/6.0, 2.0/6.0, 1.0/6.0],
///     c: [0.0, 0.5, 0.5, 1.0],
///     order: 4,
///     stages: 4
/// );
/// ```
///
/// # Note on Butcher Tableaus
///
/// For an explicit method the `a` matrix is strictly lower triangular (zeros on and above the
/// diagonal). For implementation simplicity it is stored as a full fixed-size 2D array with
/// zeros in the upper triangular portion. The array size is known at compile time, so element
/// access is O(1). When computing the stages only the strictly lower triangular entries are
/// read, so unnecessary multiplications by zero are avoided.
///
/// # Dense output
///
/// After every step the generated solver evaluates the derivative at the new point. That value
/// is the first stage of the next step and the end-point slope handed to the cubic Hermite
/// interpolant. The interpolant is therefore given the correct slope at both ends of the step.
#[macro_export]
macro_rules! runge_kutta_method {
    (
        $(#[$attr:meta])*
        name: $name:ident,
        a: $a:expr,
        b: $b:expr,
        c: $c:expr,
        order: $order:expr,
        stages: $stages:expr
        $(,)? // Optional trailing comma
    ) => {
        $(#[$attr])*
        #[doc = "\n\n"]
        #[doc = "This solver was automatically generated using the `runge_kutta_method` macro."]
        pub struct $name<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> {
            /// Step size used for every step.
            pub h: T,

            // Current state
            t: T,
            y: V,

            // Previous state and the derivative evaluated there
            t_prev: T,
            y_prev: V,
            dydt_prev: V,

            // Stage derivatives. Between steps `k[0]` holds f(t, y) at the current point.
            k: [V; $stages],

            // Butcher tableau converted to `T`
            a: [[T; $stages]; $stages],
            b: [T; $stages],
            c: [T; $stages],

            // Solver status
            status: $crate::Status<T, V, D>,
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> Default for $name<T, V, D> {
            fn default() -> Self {
                // Convert the f64 Butcher tableau literals to the working precision `T`.
                let a: [[T; $stages]; $stages] = $a.map(|row| row.map(|x| T::from_f64(x).unwrap()));
                let b: [T; $stages] = $b.map(|x| T::from_f64(x).unwrap());
                let c: [T; $stages] = $c.map(|x| T::from_f64(x).unwrap());

                $name {
                    h: T::from_f64(0.1).unwrap(),
                    t: T::zero(),
                    y: V::zeros(),
                    t_prev: T::zero(),
                    y_prev: V::zeros(),
                    dydt_prev: V::zeros(),
                    k: [V::zeros(); $stages],
                    a,
                    b,
                    c,
                    status: $crate::Status::Uninitialized,
                }
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::ode::ODENumericalMethod<T, V, D> for $name<T, V, D> {
            /// Validates the step size, stores the initial state and evaluates the derivative
            /// there (one function evaluation).
            fn init<F>(&mut self, ode: &F, t0: T, tf: T, y: &V) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>
            {
                let mut evals = $crate::alias::Evals::new();

                // Reject invalid step size / interval combinations before touching any state.
                $crate::utils::validate_step_size_parameters::<T, V, D>(self.h, T::zero(), T::infinity(), t0, tf)?;

                // Current state and its derivative; `k[0]` doubles as the first stage of step 1.
                self.t = t0;
                self.y = *y;
                ode.diff(t0, y, &mut self.k[0]);
                evals.fcn += 1;

                // Until the first step, the previous state coincides with the initial state.
                self.t_prev = t0;
                self.y_prev = *y;
                self.dydt_prev = self.k[0];

                self.status = $crate::Status::Initialized;

                Ok(evals)
            }

            /// Advances the solution by one step of size `h` using `$stages` function evaluations.
            fn step<F>(&mut self, ode: &F) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>
            {
                let mut evals = $crate::alias::Evals::new();

                // `k[0]` already holds f(t, y) at the current point (computed by `init` or at the
                // end of the previous step), so it is snapshotted rather than re-evaluated.
                self.t_prev = self.t;
                self.y_prev = self.y;
                self.dydt_prev = self.k[0];

                // Remaining stages: k_i = f(t + c_i h, y + h * sum_{j<i} a_ij k_j)
                for i in 1..$stages {
                    let mut stage_y = self.y;
                    for j in 0..i {
                        stage_y += self.k[j] * (self.a[i][j] * self.h);
                    }
                    ode.diff(self.t + self.c[i] * self.h, &stage_y, &mut self.k[i]);
                }

                // y_{n+1} = y_n + h * sum_i b_i k_i
                let mut delta_y = V::zeros();
                for i in 0..$stages {
                    delta_y += self.k[i] * (self.b[i] * self.h);
                }
                self.y += delta_y;
                self.t += self.h;

                // Derivative at the new point: end-point slope for interpolation and the
                // first stage of the next step.
                ode.diff(self.t, &self.y, &mut self.k[0]);

                // ($stages - 1) intermediate stages plus one evaluation at the new point.
                evals.fcn += $stages;

                Ok(evals)
            }

            fn t(&self) -> T {
                self.t
            }

            fn y(&self) -> &V {
                &self.y
            }

            fn t_prev(&self) -> T {
                self.t_prev
            }

            fn y_prev(&self) -> &V {
                &self.y_prev
            }

            fn h(&self) -> T {
                self.h
            }

            fn set_h(&mut self, h: T) {
                self.h = h;
            }

            fn status(&self) -> &$crate::Status<T, V, D> {
                &self.status
            }

            fn set_status(&mut self, status: $crate::Status<T, V, D>) {
                self.status = status;
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::interpolate::Interpolation<T, V> for $name<T, V, D> {
            /// Evaluates the cubic Hermite interpolant of the last step at `t_interp`.
            ///
            /// Returns `Error::OutOfBounds` when `t_interp` lies outside the last step.
            fn interpolate(&mut self, t_interp: T) -> Result<V, $crate::Error<T, V>> {
                // The step may run forward or backward in time, so order the endpoints first.
                let (t_lo, t_hi) = if self.t_prev <= self.t {
                    (self.t_prev, self.t)
                } else {
                    (self.t, self.t_prev)
                };
                if t_interp < t_lo || t_interp > t_hi {
                    return Err($crate::Error::OutOfBounds {
                        t_interp,
                        t_prev: self.t_prev,
                        t_curr: self.t });
                }

                // `dydt_prev` is f(t_prev, y_prev) and `k[0]` is f(t, y): the slopes at both ends.
                let y_interp = $crate::interpolate::cubic_hermite_interpolate(self.t_prev, self.t, &self.y_prev, &self.y, &self.dydt_prev, &self.k[0], t_interp);

                Ok(y_interp)
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $name<T, V, D> {
            /// Create a new solver with the specified step size
            ///
            /// # Arguments
            /// * `h` - Step size
            ///
            /// # Returns
            /// * A new solver instance
            pub fn new(h: T) -> Self {
                $name {
                    h,
                    ..Default::default()
                }
            }

            /// Get the order of accuracy of this method
            pub fn order(&self) -> usize {
                $order
            }

            /// Get the number of stages in this method
            pub fn stages(&self) -> usize {
                $stages
            }
        }
    };
}

261runge_kutta_method!(
262    /// Euler's Method (1st Order Runge-Kutta) for solving ordinary differential equations.
263    ///
264    /// Euler's method is the simplest form of Runge-Kutta methods, and is a first-order method also known as RK1.
265    ///
266    /// The Butcher Tableau is as follows:
267    /// ```text
268    /// 0 | 0
269    /// -----
270    ///   | 1
271    /// ```
272    ///
273    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Euler_method)
274    name: Euler,
275    a: [[0.0]],
276    b: [1.0],
277    c: [0.0],
278    order: 1,
279    stages: 1
280);
281
282runge_kutta_method!(
283    /// Midpoint Method (2nd Order Runge-Kutta) for solving ordinary differential equations.
284    ///
285    /// The Butcher Tableau is as follows:
286    /// ```text
287    /// 0   |
288    /// 1/2 | 1/2
289    /// ------------
290    ///     | 0   1
291    /// ```
292    ///
293    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Midpoint_method)
294    name: Midpoint,
295    a: [[0.0, 0.0],
296        [0.5, 0.0]],
297    b: [0.0, 1.0],
298    c: [0.0, 0.5],
299    order: 2,
300    stages: 2
301);
302
303runge_kutta_method!(
304    /// Heun's Method (2nd Order Runge-Kutta) for solving ordinary differential equations.
305    ///
306    /// The Butcher Tableau is as follows:
307    /// ```text
308    /// 0   |
309    /// 1   | 1
310    /// ------------
311    ///     | 1/2 1/2
312    /// ```
313    ///
314    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Heun%27s_method)
315    name: Heun,
316    a: [[0.0, 0.0],
317        [1.0, 0.0]],
318    b: [0.5, 0.5],
319    c: [0.0, 1.0],
320    order: 2,
321    stages: 2
322);
323
324runge_kutta_method!(
325    /// Ralston's Method (2nd Order Runge-Kutta) for solving ordinary differential equations.
326    ///
327    /// The Butcher Tableau is as follows:
328    /// ```text
329    /// 0   |
330    /// 2/3 | 2/3
331    /// ------------
332    ///     | 1/4 3/4
333    /// ```
334    ///
335    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Second-order_methods_with_two_stages)
336    name: Ralston,
337    a: [[0.0, 0.0],
338        [2.0/3.0, 0.0]],
339    b: [1.0/4.0, 3.0/4.0],
340    c: [0.0, 2.0/3.0],
341    order: 2,
342    stages: 2
343);
344
345runge_kutta_method!(
346    /// Classic Runge-Kutta 4 method for solving ordinary differential equations.
347    ///
348    /// The Butcher Tableau is as follows:
349    /// ```text
350    /// 0   |
351    /// 0.5 | 0.5
352    /// 0.5 | 0   0.5
353    /// 1   | 0   0   1
354    /// ---------------------
355    ///    | 1/6 1/3 1/3 1/6
356    /// ```
357    ///
358    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Examples)
359    name: RK4,
360    a: [[0.0, 0.0, 0.0, 0.0],
361        [0.5, 0.0, 0.0, 0.0],
362        [0.0, 0.5, 0.0, 0.0],
363        [0.0, 0.0, 1.0, 0.0]],
364    b: [1.0/6.0, 1.0/3.0, 1.0/3.0, 1.0/6.0],
365    c: [0.0, 0.5, 0.5, 1.0],
366    order: 4,
367    stages: 4
368);
369
370runge_kutta_method!(
371    /// Three-Eighths Rule (4th Order Runge-Kutta) for solving ordinary differential equations.
372    /// The primary advantage this method has is that almost all of the error coefficients
373    /// are smaller than in the popular method, but it requires slightly more FLOPs
374    /// (floating-point operations) per time step.
375    ///
376    /// The Butcher Tableau is as follows:
377    /// ```text
378    /// 0   |
379    /// 1/3 | 1/3
380    /// 2/3 | -1/3 1
381    /// 1   | 1   -1   1
382    /// ---------------------
383    ///   | 1/8 3/8 3/8 1/8
384    /// ```
385    ///
386    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Examples)
387    ///
388    name: ThreeEights,
389    a: [[0.0, 0.0, 0.0, 0.0],
390        [1.0/3.0, 0.0, 0.0, 0.0],
391        [-1.0/3.0, 1.0, 0.0, 0.0],
392        [1.0, -1.0, 1.0, 0.0]],
393    b: [1.0/8.0, 3.0/8.0, 3.0/8.0, 1.0/8.0],
394    c: [0.0, 1.0/3.0, 2.0/3.0, 1.0],
395    order: 4,
396    stages: 4
397);