differential_equations/ode/methods/runge_kutta/explicit/adaptive_step.rs
//! Adaptive step size explicit Runge-Kutta methods. These methods have no integrated
//! dense output; values between steps are obtained by cubic Hermite interpolation.
2
/// Macro to create an adaptive explicit Runge-Kutta solver with embedded error
/// estimation.
///
/// The generated solver has no integrated dense output: values between steps are
/// computed by cubic Hermite interpolation from the solution and its derivative at
/// both ends of the last accepted step.
///
/// # Arguments
///
/// * `name`: Name of the solver struct to create
/// * `a`: Matrix of coefficients for intermediate stages
/// * `b`: 2D array where the first row is the higher order weights and the second row
///   is the lower order weights
/// * `c`: Time offsets for each stage
/// * `order`: Order of accuracy of the method (the higher order of the embedded pair);
///   it is also used as the exponent in the step size controller
/// * `stages`: Number of stages in the method
///
/// # Example
///
/// ```
/// use differential_equations::adaptive_runge_kutta_method;
///
/// // Define RKF45 method
/// adaptive_runge_kutta_method!(
///     /// Runge-Kutta-Fehlberg 4(5) adaptive step size method
///     name: RKF,
///     a: [
///         [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
///         [1.0/4.0, 0.0, 0.0, 0.0, 0.0, 0.0],
///         [3.0/32.0, 9.0/32.0, 0.0, 0.0, 0.0, 0.0],
///         [1932.0/2197.0, -7200.0/2197.0, 7296.0/2197.0, 0.0, 0.0, 0.0],
///         [439.0/216.0, -8.0, 3680.0/513.0, -845.0/4104.0, 0.0, 0.0],
///         [-8.0/27.0, 2.0, -3544.0/2565.0, 1859.0/4104.0, -11.0/40.0, 0.0]
///     ],
///     b: [
///         [16.0/135.0, 0.0, 6656.0/12825.0, 28561.0/56430.0, -9.0/50.0, 2.0/55.0], // 5th order
///         [25.0/216.0, 0.0, 1408.0/2565.0, 2197.0/4104.0, -1.0/5.0, 0.0] // 4th order
///     ],
///     c: [0.0, 1.0/4.0, 3.0/8.0, 12.0/13.0, 1.0, 1.0/2.0],
///     order: 5,
///     stages: 6
/// );
/// ```
///
/// # Note on Butcher Tableaus
///
/// The `a` matrix of an explicit method is strictly lower triangular (zeros on and above
/// the diagonal). For implementation simplicity it is passed as a full square 2D array
/// with zeros in the upper triangular portion. The array size is known at compile time,
/// so accessing an element is an O(1) operation. When computing the Runge-Kutta stages,
/// only the elements strictly below the diagonal are read, so the unnecessary
/// multiplications by zero are skipped.
///
/// The `b` matrix is a 2D array where the first row is the higher order weights and the
/// second row is the lower order weights. The higher order row advances the solution;
/// the difference between the two rows gives the embedded error estimate.
///
/// # Generated solver behaviour
///
/// * If `h0` is zero when `init` is called, an initial step is estimated with `h_init`.
/// * The error norm is the componentwise maximum of `|err_i| / (atol + rtol * max(|y_i|, |y_new_i|))`.
///   A step is accepted when this norm is `<= 1`.
/// * The new step is `h * clamp(safety_factor * err_norm^(-1/order), min_scale, max_scale)`.
///   It is then passed through `constrain_step_size` with `h_min` / `h_max`.
/// * `step` fails with an error after `max_steps` steps or `max_rejects` consecutive
///   rejected steps.
#[macro_export]
macro_rules! adaptive_runge_kutta_method {
    (
        $(#[$attr:meta])*
        name: $name:ident,
        a: $a:expr,
        b: $b:expr,
        c: $c:expr,
        order: $order:expr,
        stages: $stages:expr
        $(,)? // Optional trailing comma
    ) => {
        $(#[$attr])*
        #[doc = "\n\n"]
        #[doc = "This adaptive solver was automatically generated using the `adaptive_runge_kutta_method` macro."]
        pub struct $name<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> {
            /// Initial step size. If it is zero when `init` is called, an estimate is computed.
            pub h0: T,

            // Current step size
            h: T,

            // Current state: time, solution and derivative f(t, y)
            t: T,
            y: V,
            dydt: V,

            // State at the start of the last accepted step (used for interpolation)
            t_prev: T,
            y_prev: V,
            dydt_prev: V,

            // Stage derivatives (fixed size array of Vs)
            k: [V; $stages],

            // Constants from the Butcher tableau (fixed size arrays)
            a: [[T; $stages]; $stages],
            b_higher: [T; $stages],
            b_lower: [T; $stages],
            c: [T; $stages],

            /// Relative tolerance used in the error norm (default: 1e-6)
            pub rtol: T,
            /// Absolute tolerance used in the error norm (default: 1e-6)
            pub atol: T,
            /// Maximum allowed step size (default: infinity)
            pub h_max: T,
            /// Minimum allowed step size (default: 0)
            pub h_min: T,
            /// Maximum number of steps (default: 10000)
            pub max_steps: usize,
            /// Maximum number of consecutive rejected steps before a stiffness error (default: 100)
            pub max_rejects: usize,
            /// Safety factor of the step size controller (default: 0.9)
            pub safety_factor: T,
            /// Lower bound of the step size scale factor (default: 0.2)
            pub min_scale: T,
            /// Upper bound of the step size scale factor (default: 10.0)
            pub max_scale: T,

            // Iteration tracking
            reject: bool,   // whether the previous step attempt was rejected
            n_stiff: usize, // number of consecutive rejected steps
            steps: usize,   // number of step attempts since `init`

            // Status
            status: $crate::Status<T, V, D>,
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> Default for $name<T, V, D> {
            fn default() -> Self {
                // Initialize k vectors with zeros
                let k: [V; $stages] = [V::zeros(); $stages];

                // Convert Butcher tableau values to type T
                let a_t: [[T; $stages]; $stages] = $a.map(|row| row.map(|x| T::from_f64(x).unwrap()));

                // First row of b: higher order weights, second row: lower order weights
                let b_higher: [T; $stages] = $b[0].map(|x| T::from_f64(x).unwrap());
                let b_lower: [T; $stages] = $b[1].map(|x| T::from_f64(x).unwrap());

                let c_t: [T; $stages] = $c.map(|x| T::from_f64(x).unwrap());

                $name {
                    h0: T::from_f64(0.0).unwrap(),
                    h: T::from_f64(0.0).unwrap(),
                    t: T::from_f64(0.0).unwrap(),
                    y: V::zeros(),
                    dydt: V::zeros(),
                    t_prev: T::from_f64(0.0).unwrap(),
                    y_prev: V::zeros(),
                    dydt_prev: V::zeros(),
                    k,
                    a: a_t,
                    b_higher, // Higher order (b)
                    b_lower,  // Lower order (b_hat)
                    c: c_t,
                    rtol: T::from_f64(1.0e-6).unwrap(),
                    atol: T::from_f64(1.0e-6).unwrap(),
                    h_max: T::infinity(),
                    h_min: T::from_f64(0.0).unwrap(),
                    max_steps: 10000,
                    max_rejects: 100,
                    safety_factor: T::from_f64(0.9).unwrap(),
                    min_scale: T::from_f64(0.2).unwrap(),
                    max_scale: T::from_f64(10.0).unwrap(),
                    reject: false,
                    n_stiff: 0,
                    steps: 0,
                    status: $crate::Status::Uninitialized,
                }
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::ode::ODENumericalMethod<T, V, D> for $name<T, V, D> {
            /// Prepare the solver for integrating `ode` from `t0` towards `tf` starting at `y`.
            ///
            /// # Errors
            ///
            /// Returns the error produced by `validate_step_size_parameters` when the initial
            /// step size is not acceptable.
            fn init<F>(&mut self, ode: &F, t0: T, tf: T, y: &V) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();

                // If no initial step size was supplied, estimate one.
                if self.h0 == T::zero() {
                    self.h0 = $crate::ode::methods::h_init(ode, t0, tf, y, $order, self.rtol, self.atol, self.h_min, self.h_max);
                }
                // NOTE(review): a single evaluation is counted although `ode.diff` is also called
                // below and `h_init` may evaluate `ode` as well — confirm the intended accounting.
                evals.fcn += 1;

                // Check bounds
                self.h = $crate::utils::validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf)?;

                // Reset statistics so a solver instance can be reused for another problem.
                self.reject = false;
                self.n_stiff = 0;
                self.steps = 0;

                // Initialize state; `dydt` is reused as the first stage of the next step.
                self.t = t0;
                self.y = y.clone();
                ode.diff(t0, y, &mut self.dydt);

                // Initialize previous state
                self.t_prev = t0;
                self.y_prev = y.clone();
                self.dydt_prev = self.dydt;

                self.status = $crate::Status::Initialized;

                Ok(evals)
            }

            /// Attempt one step of size `h`, accept or reject it, and update `h`.
            ///
            /// # Errors
            ///
            /// * `StepSize` if `|h|` is below machine epsilon,
            /// * `MaxSteps` if `max_steps` steps have already been attempted,
            /// * `Stiffness` if `max_rejects` consecutive steps were rejected.
            fn step<F>(&mut self, ode: &F) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();

                // Make sure step size isn't too small
                if self.h.abs() < T::default_epsilon() {
                    self.status = $crate::Status::Error($crate::Error::StepSize {
                        t: self.t,
                        y: self.y,
                    });
                    return Err($crate::Error::StepSize {
                        t: self.t,
                        y: self.y,
                    });
                }

                // Check if max steps has been reached
                if self.steps >= self.max_steps {
                    self.status = $crate::Status::Error($crate::Error::MaxSteps {
                        t: self.t,
                        y: self.y,
                    });
                    return Err($crate::Error::MaxSteps {
                        t: self.t,
                        y: self.y,
                    });
                }
                self.steps += 1;

                // The first stage is f(t, y), which `dydt` already holds: it is set by `init`
                // and after every accepted step, and a rejected step leaves (t, y) unchanged.
                self.k[0] = self.dydt;
                for i in 1..$stages {
                    let mut y_stage = self.y;
                    // Only the strictly lower triangular part of `a` is used.
                    for j in 0..i {
                        y_stage += self.k[j] * (self.a[i][j] * self.h);
                    }
                    ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
                }

                // Higher order solution and the embedded error estimate. The error is
                // accumulated from the weight differences directly instead of subtracting
                // two nearly equal solutions, which would lose precision to cancellation.
                let mut y_high = self.y;
                let mut err = V::zeros();
                for i in 0..$stages {
                    y_high += self.k[i] * (self.b_higher[i] * self.h);
                    err += self.k[i] * ((self.b_higher[i] - self.b_lower[i]) * self.h);
                }

                // Max (infinity) norm of the componentwise scaled error.
                let mut err_norm: T = T::zero();
                for n in 0..self.y.len() {
                    let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
                    err_norm = err_norm.max((err.get(n) / tol).abs());
                }

                if err_norm <= T::one() {
                    // Step accepted: keep the start of the step for interpolation.
                    self.t_prev = self.t;
                    self.y_prev = self.y;
                    self.dydt_prev = self.dydt;

                    if self.reject {
                        // Not rejected this time
                        self.n_stiff = 0;
                        self.reject = false;
                        self.status = $crate::Status::Solving;
                    }

                    // Update state with the higher-order solution
                    self.t += self.h;
                    self.y = y_high;
                    ode.diff(self.t, &self.y, &mut self.dydt);

                    // ($stages - 1) stage evaluations + 1 derivative at the new point
                    evals.fcn += $stages;
                } else {
                    // Step rejected
                    self.reject = true;

                    // Only stages 2..=$stages were evaluated; stage 1 was reused from `dydt`.
                    evals.fcn += $stages - 1;
                    self.status = $crate::Status::RejectedStep;
                    self.n_stiff += 1;

                    // Too many consecutive rejections are treated as stiffness.
                    if self.n_stiff >= self.max_rejects {
                        self.status = $crate::Status::Error($crate::Error::Stiffness {
                            t: self.t,
                            y: self.y,
                        });
                        return Err($crate::Error::Stiffness {
                            t: self.t,
                            y: self.y,
                        });
                    }
                }

                // Standard step size controller. It assumes the error estimate scales like h^order.
                let err_order = T::one() / T::from_usize($order).unwrap();
                let scale = (self.safety_factor * err_norm.powf(-err_order))
                    .max(self.min_scale)
                    .min(self.max_scale);
                self.h *= scale;

                // Ensure step size is within bounds
                self.h = $crate::utils::constrain_step_size(self.h, self.h_min, self.h_max);
                Ok(evals)
            }

            fn t(&self) -> T {
                self.t
            }

            fn y(&self) -> &V {
                &self.y
            }

            fn t_prev(&self) -> T {
                self.t_prev
            }

            fn y_prev(&self) -> &V {
                &self.y_prev
            }

            fn h(&self) -> T {
                self.h
            }

            fn set_h(&mut self, h: T) {
                self.h = h;
            }

            fn status(&self) -> &$crate::Status<T, V, D> {
                &self.status
            }

            fn set_status(&mut self, status: $crate::Status<T, V, D>) {
                self.status = status;
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::interpolate::Interpolation<T, V> for $name<T, V, D> {
            /// Interpolate the solution at `t_interp` within the last accepted step.
            ///
            /// # Errors
            ///
            /// Returns `OutOfBounds` if `t_interp` lies outside the interval between `t_prev`
            /// and `t`.
            fn interpolate(&mut self, t_interp: T) -> Result<V, $crate::Error<T, V>> {
                // The step may go in either direction (t_prev > t when integrating
                // backwards), so order the endpoints before the bounds check.
                let (t_lo, t_hi) = if self.t_prev <= self.t {
                    (self.t_prev, self.t)
                } else {
                    (self.t, self.t_prev)
                };
                if t_interp < t_lo || t_interp > t_hi {
                    return Err($crate::Error::OutOfBounds {
                        t_interp,
                        t_prev: self.t_prev,
                        t_curr: self.t,
                    });
                }

                // Degenerate interval (no step accepted yet): t_interp == t, so there is
                // nothing to interpolate and the Hermite formula would divide by zero.
                if self.t_prev == self.t {
                    return Ok(self.y);
                }

                // Compute the interpolated value using cubic Hermite interpolation
                let y_interp = $crate::interpolate::cubic_hermite_interpolate(
                    self.t_prev,
                    self.t,
                    &self.y_prev,
                    &self.y,
                    &self.dydt_prev,
                    &self.dydt,
                    t_interp,
                );

                Ok(y_interp)
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $name<T, V, D> {
            /// Create a new solver with default settings.
            ///
            /// The initial step size `h0` defaults to zero, which makes `init` estimate it.
            pub fn new() -> Self {
                Self::default()
            }

            /// Set initial step size
            pub fn h0(mut self, h0: T) -> Self {
                self.h0 = h0;
                self
            }

            /// Set the relative tolerance for error control
            pub fn rtol(mut self, rtol: T) -> Self {
                self.rtol = rtol;
                self
            }

            /// Set the absolute tolerance for error control
            pub fn atol(mut self, atol: T) -> Self {
                self.atol = atol;
                self
            }

            /// Set the minimum allowed step size
            pub fn h_min(mut self, h_min: T) -> Self {
                self.h_min = h_min;
                self
            }

            /// Set the maximum allowed step size
            pub fn h_max(mut self, h_max: T) -> Self {
                self.h_max = h_max;
                self
            }

            /// Set the maximum number of steps allowed
            pub fn max_steps(mut self, max_steps: usize) -> Self {
                self.max_steps = max_steps;
                self
            }

            /// Set the maximum number of consecutive rejected steps before declaring stiffness
            pub fn max_rejects(mut self, max_rejects: usize) -> Self {
                self.max_rejects = max_rejects;
                self
            }

            /// Set the safety factor for step size control (default: 0.9)
            pub fn safety_factor(mut self, safety_factor: T) -> Self {
                self.safety_factor = safety_factor;
                self
            }

            /// Set the minimum scale factor for step size changes (default: 0.2)
            pub fn min_scale(mut self, min_scale: T) -> Self {
                self.min_scale = min_scale;
                self
            }

            /// Set the maximum scale factor for step size changes (default: 10.0)
            pub fn max_scale(mut self, max_scale: T) -> Self {
                self.max_scale = max_scale;
                self
            }

            /// Get the order of the method
            pub fn order(&self) -> usize {
                $order
            }

            /// Get the number of stages in the method
            pub fn stages(&self) -> usize {
                $stages
            }
        }
    };
}
458
459adaptive_runge_kutta_method!(
460 /// Runge-Kutta-Fehlberg 4(5) adaptive method
461 /// This method uses six function evaluations to calculate a fifth-order accurate
462 /// solution, with an embedded fourth-order method for error estimation.
463 /// The RKF45 method is one of the most widely used adaptive step size methods due to
464 /// its excellent balance of efficiency and accuracy.
465 ///
466 /// The Butcher Tableau is as follows:
467 /// ```text
468 /// 0 |
469 /// 1/4 | 1/4
470 /// 3/8 | 3/32 9/32
471 /// 12/13 | 1932/2197 -7200/2197 7296/2197
472 /// 1 | 439/216 -8 3680/513 -845/4104
473 /// 1/2 | -8/27 2 -3544/2565 1859/4104 -11/40
474 /// -----------------------------------------------------------------------
475 /// | 16/135 0 6656/12825 28561/56430 -9/50 2/55 (5th order)
476 /// | 25/216 0 1408/2565 2197/4104 -1/5 0 (4th order)
477 /// ```
478 ///
479 /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method#CITEREFFehlberg1969)
480 name: RKF,
481 a: [
482 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
483 [1.0/4.0, 0.0, 0.0, 0.0, 0.0, 0.0],
484 [3.0/32.0, 9.0/32.0, 0.0, 0.0, 0.0, 0.0],
485 [1932.0/2197.0, -7200.0/2197.0, 7296.0/2197.0, 0.0, 0.0, 0.0],
486 [439.0/216.0, -8.0, 3680.0/513.0, -845.0/4104.0, 0.0, 0.0],
487 [-8.0/27.0, 2.0, -3544.0/2565.0, 1859.0/4104.0, -11.0/40.0, 0.0]
488 ],
489 b: [
490 [16.0/135.0, 0.0, 6656.0/12825.0, 28561.0/56430.0, -9.0/50.0, 2.0/55.0], // 5th order
491 [25.0/216.0, 0.0, 1408.0/2565.0, 2197.0/4104.0, -1.0/5.0, 0.0] // 4th order
492 ],
493 c: [0.0, 1.0/4.0, 3.0/8.0, 12.0/13.0, 1.0, 1.0/2.0],
494 order: 5,
495 stages: 6
496);
497
498adaptive_runge_kutta_method!(
499 /// Cash-Karp 4(5) adaptive method
500 /// This method uses six function evaluations to calculate a fifth-order accurate
501 /// solution, with an embedded fourth-order method for error estimation.
502 /// The Cash-Karp method is a variant of the Runge-Kutta-Fehlberg method that uses
503 /// different coefficients to achieve a more efficient and accurate solution.
504 ///
505 /// The Butcher Tableau is as follows:
506 /// ```text
507 /// 0 |
508 /// 1/5 | 1/5
509 /// 3/10 | 3/40 9/40
510 /// 3/5 | 3/10 -9/10 6/5
511 /// 1 | -11/54 5/2 -70/27 35/27
512 /// 7/8 | 1631/55296 175/512 575/13824 44275/110592 253/4096
513 /// ------------------------------------------------------------------------------------
514 /// | 37/378 0 250/621 125/594 0 512/1771 (5th order)
515 /// | 2825/27648 0 18575/48384 13525/55296 277/14336 1/4 (4th order)
516 /// ```
517 ///
518 /// Reference: [Wikipedia](https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method)
519 name: CashKarp,
520 a: [
521 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
522 [1.0/5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
523 [3.0/40.0, 9.0/40.0, 0.0, 0.0, 0.0, 0.0],
524 [3.0/10.0, -9.0/10.0, 6.0/5.0, 0.0, 0.0, 0.0],
525 [-11.0/54.0, 5.0/2.0, -70.0/27.0, 35.0/27.0, 0.0, 0.0],
526 [1631.0/55296.0, 175.0/512.0, 575.0/13824.0, 44275.0/110592.0, 253.0/4096.0, 0.0]
527 ],
528 b: [
529 [37.0/378.0, 0.0, 250.0/621.0, 125.0/594.0, 0.0, 512.0/1771.0], // 5th order
530 [2825.0/27648.0, 0.0, 18575.0/48384.0, 13525.0/55296.0, 277.0/14336.0, 1.0/4.0] // 4th order
531 ],
532 c: [0.0, 1.0/5.0, 3.0/10.0, 3.0/5.0, 1.0, 7.0/8.0],
533 order: 5,
534 stages: 6
535);