differential_equations/ode/methods/runge_kutta/explicit/
adaptive_step.rs

//! Adaptive step size explicit Runge-Kutta methods that have no built-in dense output and instead interpolate between steps using cubic Hermite interpolation.
2
/// Macro to create an adaptive step size explicit Runge-Kutta solver with embedded error
/// estimation and cubic Hermite interpolation between steps.
///
/// # Arguments
///
/// * `name`: Name of the solver struct to create
/// * `a`: Matrix of coefficients for intermediate stages
/// * `b`: 2D array where first row is higher order weights, second row is lower order weights
/// * `c`: Time offsets for each stage
/// * `order`: Order of accuracy of the method. It is passed to the initial step size
///   estimate, and its reciprocal is the exponent used by the step size controller.
/// * `stages`: Number of stages in the method
///
/// # Example
///
/// ```
/// use differential_equations::adaptive_runge_kutta_method;
///
/// // Define RKF45 method
/// adaptive_runge_kutta_method!(
///     /// Runge-Kutta-Fehlberg 4(5) adaptive step size method
///     name: RKF,
///     a: [
///         [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
///         [1.0/4.0, 0.0, 0.0, 0.0, 0.0, 0.0],
///         [3.0/32.0, 9.0/32.0, 0.0, 0.0, 0.0, 0.0],
///         [1932.0/2197.0, -7200.0/2197.0, 7296.0/2197.0, 0.0, 0.0, 0.0],
///         [439.0/216.0, -8.0, 3680.0/513.0, -845.0/4104.0, 0.0, 0.0],
///         [-8.0/27.0, 2.0, -3544.0/2565.0, 1859.0/4104.0, -11.0/40.0, 0.0]
///     ],
///     b: [
///         [16.0/135.0, 0.0, 6656.0/12825.0, 28561.0/56430.0, -9.0/50.0, 2.0/55.0], // 5th order
///         [25.0/216.0, 0.0, 1408.0/2565.0, 2197.0/4104.0, -1.0/5.0, 0.0]           // 4th order
///     ],
///     c: [0.0, 1.0/4.0, 3.0/8.0, 12.0/13.0, 1.0, 1.0/2.0],
///     order: 5,
///     stages: 6
/// );
/// ```
///
/// # Note on Butcher Tableaus
///
/// The `a` matrix of an explicit method is lower triangular with zeros on the diagonal.
/// For implementation simplicity it is stored as a full, fixed-size 2D array with zeros in
/// the upper triangular portion. Because the array size is known at compile time, element
/// access is an O(1) operation. When computing the Runge-Kutta stages only the entries
/// below the diagonal are read, so unnecessary multiplications by zero are avoided.
///
/// The `b` matrix is a 2D array where the first row is the higher order weights and the
/// second row is the lower order weights. The difference between the two resulting
/// solutions is the embedded local error estimate; the higher order solution is the one
/// used to advance the integration.
///
#[macro_export]
macro_rules! adaptive_runge_kutta_method {
    (
        $(#[$attr:meta])*
        name: $name:ident,
        a: $a:expr,
        b: $b:expr,
        c: $c:expr,
        order: $order:expr,
        stages: $stages:expr
        $(,)? // Optional trailing comma
    ) => {
        $(#[$attr])*
        #[doc = "\n\n"]
        #[doc = "This adaptive solver was automatically generated using the `adaptive_runge_kutta_method` macro."]
        pub struct $name<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> {
            /// Initial step size. When left at zero it is estimated in `init`.
            pub h0: T,

            // Current step size
            h: T,

            // Current state; `dydt` always holds f(t, y)
            t: T,
            y: V,
            dydt: V,

            // State at the start of the last accepted step (used for interpolation)
            t_prev: T,
            y_prev: V,
            dydt_prev: V,

            // Stage derivatives (fixed size array of Vs)
            k: [V; $stages],

            // Butcher tableau converted to T (fixed size arrays)
            a: [[T; $stages]; $stages],
            b_higher: [T; $stages],
            b_lower: [T; $stages],
            c: [T; $stages],

            /// Relative tolerance for error control
            pub rtol: T,
            /// Absolute tolerance for error control
            pub atol: T,
            /// Maximum allowed step size
            pub h_max: T,
            /// Minimum allowed step size
            pub h_min: T,
            /// Maximum number of step attempts (accepted or rejected) per integration
            pub max_steps: usize,
            /// Maximum number of consecutive rejected steps before reporting stiffness
            pub max_rejects: usize,
            /// Safety factor applied to the step size controller
            pub safety_factor: T,
            /// Lower bound on the per-step step size change factor
            pub min_scale: T,
            /// Upper bound on the per-step step size change factor
            pub max_scale: T,

            // Iteration tracking
            reject: bool,   // Whether the previous attempt was rejected
            n_stiff: usize, // Consecutive rejected attempts
            steps: usize,   // Step attempts since the last `init`

            // Status
            status: $crate::Status<T, V, D>,
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> Default for $name<T, V, D> {
            fn default() -> Self {
                // Stage derivatives start as zero vectors
                let k: [V; $stages] = [V::zeros(); $stages];

                // Convert the f64 Butcher tableau into the solver's scalar type
                let a_t: [[T; $stages]; $stages] = $a.map(|row| row.map(|x| T::from_f64(x).unwrap()));

                // Row 0 of `b` holds the higher order weights, row 1 the lower order weights
                let b_higher: [T; $stages] = $b[0].map(|x| T::from_f64(x).unwrap());
                let b_lower: [T; $stages] = $b[1].map(|x| T::from_f64(x).unwrap());

                let c_t: [T; $stages] = $c.map(|x| T::from_f64(x).unwrap());

                $name {
                    h0: T::zero(),
                    h: T::zero(),
                    t: T::zero(),
                    y: V::zeros(),
                    dydt: V::zeros(),
                    t_prev: T::zero(),
                    y_prev: V::zeros(),
                    dydt_prev: V::zeros(),
                    k,
                    a: a_t,
                    b_higher, // Higher order (b)
                    b_lower,  // Lower order (b_hat)
                    c: c_t,
                    rtol: T::from_f64(1.0e-6).unwrap(),
                    atol: T::from_f64(1.0e-6).unwrap(),
                    h_max: T::infinity(),
                    h_min: T::zero(),
                    max_steps: 10000,
                    max_rejects: 100,
                    safety_factor: T::from_f64(0.9).unwrap(),
                    min_scale: T::from_f64(0.2).unwrap(),
                    max_scale: T::from_f64(10.0).unwrap(),
                    reject: false,
                    n_stiff: 0,
                    steps: 0,
                    status: $crate::Status::Uninitialized,
                }
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::ode::ODENumericalMethod<T, V, D> for $name<T, V, D> {
            /// Prepares the solver to integrate `ode` from `t0` towards `tf` starting at `y`.
            ///
            /// Estimates `h0` if it is zero, validates it, resets iteration tracking and
            /// evaluates the derivative at the initial point.
            ///
            /// # Errors
            ///
            /// Returns the error produced by step size validation.
            fn init<F>(&mut self, ode: &F, t0: T, tf: T, y: &V) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();

                // If h0 is zero, estimate a starting step size
                // NOTE(review): any evaluations made inside `h_init` are not added to `evals`
                // here; confirm against `h_init` whether they should be.
                if self.h0 == T::zero() {
                    self.h0 = $crate::ode::methods::h_init(ode, t0, tf, y, $order, self.rtol, self.atol, self.h_min, self.h_max);
                }

                // Check bounds
                self.h = $crate::utils::validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf)?;

                // Reset iteration tracking so a reused solver starts fresh
                self.reject = false;
                self.n_stiff = 0;
                self.steps = 0;

                // Initialize State
                self.t = t0;
                self.y = y.clone();
                ode.diff(t0, y, &mut self.dydt);
                evals.fcn += 1;

                // Initialize previous state
                self.t_prev = t0;
                self.y_prev = y.clone();
                self.dydt_prev = self.dydt;

                // Initialize Status
                self.status = $crate::Status::Initialized;

                Ok(evals)
            }

            /// Attempts one adaptive step and updates the step size for the next attempt.
            ///
            /// On acceptance the state advances with the higher order solution; on
            /// rejection the state is unchanged and only the step size shrinks.
            ///
            /// # Errors
            ///
            /// `StepSize` if `|h|` is below machine epsilon, `MaxSteps` once `max_steps`
            /// attempts were made, `Stiffness` after `max_rejects` consecutive rejections.
            fn step<F>(&mut self, ode: &F) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();

                // Make sure step size isn't too small
                if self.h.abs() < T::default_epsilon() {
                    self.status = $crate::Status::Error($crate::Error::StepSize {
                        t: self.t,
                        y: self.y
                    });
                    return Err($crate::Error::StepSize {
                        t: self.t,
                        y: self.y
                    });
                }

                // Check if max steps has been reached
                if self.steps >= self.max_steps {
                    self.status = $crate::Status::Error($crate::Error::MaxSteps {
                        t: self.t,
                        y: self.y
                    });
                    return Err($crate::Error::MaxSteps {
                        t: self.t,
                        y: self.y
                    });
                }
                self.steps += 1;

                // First stage: f(t, y) is already stored in `dydt` (set in `init` and after
                // every accepted step; a rejected step leaves t and y unchanged), so it is
                // reused instead of evaluating the ODE again.
                self.k[0] = self.dydt;

                // Remaining stages use only the strictly lower triangular part of `a`
                for i in 1..$stages {
                    let mut y_stage = self.y;

                    for j in 0..i {
                        y_stage += self.k[j] * (self.a[i][j] * self.h);
                    }

                    ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
                }

                // Compute higher order solution
                let mut y_high = self.y;
                for i in 0..$stages {
                    y_high += self.k[i] * (self.b_higher[i] * self.h);
                }

                // Compute lower order solution for error estimation
                let mut y_low = self.y;
                for i in 0..$stages {
                    y_low += self.k[i] * (self.b_lower[i] * self.h);
                }

                // Embedded local error estimate
                let err = y_high - y_low;

                // Weighted max norm: the largest component of |err| / (atol + rtol * max(|y|, |y_high|))
                let mut err_norm: T = T::zero();
                for n in 0..self.y.len() {
                    let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
                    err_norm = err_norm.max((err.get(n) / tol).abs());
                }

                // Determine if step is accepted
                if err_norm <= T::one() {
                    // Log previous state
                    self.t_prev = self.t;
                    self.y_prev = self.y;
                    self.dydt_prev = self.dydt;

                    if self.reject {
                        // Not rejected this time
                        self.n_stiff = 0;
                        self.reject = false;
                        self.status = $crate::Status::Solving;
                    }

                    // Update state with the higher-order solution
                    self.t += self.h;
                    self.y = y_high;
                    ode.diff(self.t, &self.y, &mut self.dydt);

                    // ($stages - 1) new stage evaluations plus one at the new point
                    evals.fcn += $stages;
                } else {
                    // Step rejected
                    self.reject = true;

                    // Only the ($stages - 1) new stages were evaluated
                    evals.fcn += $stages - 1;
                    self.status = $crate::Status::RejectedStep;
                    self.n_stiff += 1;

                    // Too many consecutive rejections are reported as stiffness
                    if self.n_stiff >= self.max_rejects {
                        self.status = $crate::Status::Error($crate::Error::Stiffness {
                            t: self.t, y: self.y
                        });
                        return Err($crate::Error::Stiffness {
                            t: self.t, y: self.y
                        });
                    }
                }

                // Standard step size controller: scale = safety * err^(-1/order)
                let order = T::from_usize($order).unwrap();
                let err_order = T::one() / order;
                let scale = self.safety_factor * err_norm.powf(-err_order);

                // Limit how fast the step size may change
                let scale = scale.max(self.min_scale).min(self.max_scale);

                // Update step size and keep it within [h_min, h_max]
                self.h *= scale;
                self.h = $crate::utils::constrain_step_size(self.h, self.h_min, self.h_max);
                Ok(evals)
            }

            fn t(&self) -> T {
                self.t
            }

            fn y(&self) -> &V {
                &self.y
            }

            fn t_prev(&self) -> T {
                self.t_prev
            }

            fn y_prev(&self) -> &V {
                &self.y_prev
            }

            fn h(&self) -> T {
                self.h
            }

            fn set_h(&mut self, h: T) {
                self.h = h;
            }

            fn status(&self) -> &$crate::Status<T, V, D> {
                &self.status
            }

            fn set_status(&mut self, status: $crate::Status<T, V, D>) {
                self.status = status;
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::interpolate::Interpolation<T, V> for $name<T, V, D> {
            /// Interpolates the solution at `t_interp` within the last accepted step using
            /// cubic Hermite interpolation.
            ///
            /// # Errors
            ///
            /// `OutOfBounds` if `t_interp` lies outside the interval spanned by `t_prev`
            /// and `t`. Either ordering of the endpoints is allowed, so backward
            /// integration (negative step size) is handled.
            fn interpolate(&mut self, t_interp: T) -> Result<V, $crate::Error<T, V>> {
                // Order the endpoints so the check works for both integration directions
                let (t_lo, t_hi) = if self.t_prev <= self.t {
                    (self.t_prev, self.t)
                } else {
                    (self.t, self.t_prev)
                };

                if t_interp < t_lo || t_interp > t_hi {
                    return Err($crate::Error::OutOfBounds {
                        t_interp,
                        t_prev: self.t_prev,
                        t_curr: self.t
                    });
                }

                // Compute the interpolated value using cubic Hermite interpolation
                let y_interp = $crate::interpolate::cubic_hermite_interpolate(self.t_prev, self.t, &self.y_prev, &self.y, &self.dydt_prev, &self.dydt, t_interp);

                Ok(y_interp)
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $name<T, V, D> {
            /// Create a new solver with default settings (use the builder methods to customize)
            pub fn new() -> Self {
                Self::default()
            }

            /// Set initial step size (zero means it is estimated during initialization)
            pub fn h0(mut self, h0: T) -> Self {
                self.h0 = h0;
                self
            }

            /// Set the relative tolerance for error control
            pub fn rtol(mut self, rtol: T) -> Self {
                self.rtol = rtol;
                self
            }

            /// Set the absolute tolerance for error control
            pub fn atol(mut self, atol: T) -> Self {
                self.atol = atol;
                self
            }

            /// Set the minimum allowed step size
            pub fn h_min(mut self, h_min: T) -> Self {
                self.h_min = h_min;
                self
            }

            /// Set the maximum allowed step size
            pub fn h_max(mut self, h_max: T) -> Self {
                self.h_max = h_max;
                self
            }

            /// Set the maximum number of steps allowed
            pub fn max_steps(mut self, max_steps: usize) -> Self {
                self.max_steps = max_steps;
                self
            }

            /// Set the maximum number of consecutive rejected steps before declaring stiffness
            pub fn max_rejects(mut self, max_rejects: usize) -> Self {
                self.max_rejects = max_rejects;
                self
            }

            /// Set the safety factor for step size control (default: 0.9)
            pub fn safety_factor(mut self, safety_factor: T) -> Self {
                self.safety_factor = safety_factor;
                self
            }

            /// Set the minimum scale factor for step size changes (default: 0.2)
            pub fn min_scale(mut self, min_scale: T) -> Self {
                self.min_scale = min_scale;
                self
            }

            /// Set the maximum scale factor for step size changes (default: 10.0)
            pub fn max_scale(mut self, max_scale: T) -> Self {
                self.max_scale = max_scale;
                self
            }

            /// Get the order of the method
            pub fn order(&self) -> usize {
                $order
            }

            /// Get the number of stages in the method
            pub fn stages(&self) -> usize {
                $stages
            }
        }
    };
}
458
459adaptive_runge_kutta_method!(
460    /// Runge-Kutta-Fehlberg 4(5) adaptive method
461    /// This method uses six function evaluations to calculate a fifth-order accurate
462    /// solution, with an embedded fourth-order method for error estimation.
463    /// The RKF45 method is one of the most widely used adaptive step size methods due to
464    /// its excellent balance of efficiency and accuracy.
465    ///
466    /// The Butcher Tableau is as follows:
467    /// ```text
468    /// 0      |
469    /// 1/4    | 1/4
470    /// 3/8    | 3/32         9/32
471    /// 12/13  | 1932/2197    -7200/2197  7296/2197
472    /// 1      | 439/216      -8          3680/513    -845/4104
473    /// 1/2    | -8/27        2           -3544/2565  1859/4104   -11/40
474    /// -----------------------------------------------------------------------
475    ///        | 16/135       0           6656/12825  28561/56430 -9/50       2/55    (5th order)
476    ///        | 25/216       0           1408/2565   2197/4104   -1/5        0       (4th order)
477    /// ```
478    ///
479    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method#CITEREFFehlberg1969)
480    name: RKF,
481    a: [
482        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
483        [1.0/4.0, 0.0, 0.0, 0.0, 0.0, 0.0],
484        [3.0/32.0, 9.0/32.0, 0.0, 0.0, 0.0, 0.0],
485        [1932.0/2197.0, -7200.0/2197.0, 7296.0/2197.0, 0.0, 0.0, 0.0],
486        [439.0/216.0, -8.0, 3680.0/513.0, -845.0/4104.0, 0.0, 0.0],
487        [-8.0/27.0, 2.0, -3544.0/2565.0, 1859.0/4104.0, -11.0/40.0, 0.0]
488    ],
489    b: [
490        [16.0/135.0, 0.0, 6656.0/12825.0, 28561.0/56430.0, -9.0/50.0, 2.0/55.0], // 5th order
491        [25.0/216.0, 0.0, 1408.0/2565.0, 2197.0/4104.0, -1.0/5.0, 0.0]           // 4th order
492    ],
493    c: [0.0, 1.0/4.0, 3.0/8.0, 12.0/13.0, 1.0, 1.0/2.0],
494    order: 5,
495    stages: 6
496);
497
498adaptive_runge_kutta_method!(
499    /// Cash-Karp 4(5) adaptive method
500    /// This method uses six function evaluations to calculate a fifth-order accurate
501    /// solution, with an embedded fourth-order method for error estimation.
502    /// The Cash-Karp method is a variant of the Runge-Kutta-Fehlberg method that uses
503    /// different coefficients to achieve a more efficient and accurate solution.
504    ///
505    /// The Butcher Tableau is as follows:
506    /// ```text
507    /// 0      |
508    /// 1/5    | 1/5
509    /// 3/10   | 3/40         9/40
510    /// 3/5    | 3/10         -9/10       6/5
511    /// 1      | -11/54       5/2         -70/27      35/27
512    /// 7/8    | 1631/55296   175/512     575/13824   44275/110592 253/4096
513    /// ------------------------------------------------------------------------------------
514    ///        | 37/378       0           250/621     125/594     0           512/1771  (5th order)
515    ///        | 2825/27648   0           18575/48384 13525/55296 277/14336   1/4       (4th order)
516    /// ```
517    ///
518    /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method)
519    name: CashKarp,
520    a: [
521        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
522        [1.0/5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
523        [3.0/40.0, 9.0/40.0, 0.0, 0.0, 0.0, 0.0],
524        [3.0/10.0, -9.0/10.0, 6.0/5.0, 0.0, 0.0, 0.0],
525        [-11.0/54.0, 5.0/2.0, -70.0/27.0, 35.0/27.0, 0.0, 0.0],
526        [1631.0/55296.0, 175.0/512.0, 575.0/13824.0, 44275.0/110592.0, 253.0/4096.0, 0.0]
527    ],
528    b: [
529        [37.0/378.0, 0.0, 250.0/621.0, 125.0/594.0, 0.0, 512.0/1771.0], // 5th order
530        [2825.0/27648.0, 0.0, 18575.0/48384.0, 13525.0/55296.0, 277.0/14336.0, 1.0/4.0] // 4th order
531    ],
532    c: [0.0, 1.0/5.0, 3.0/10.0, 3.0/5.0, 1.0, 7.0/8.0],
533    order: 5,
534    stages: 6
535);