differential_equations/methods/erk/adaptive/ordinary.rs

1//! Adaptive Runge-Kutta methods for ODEs
2
3use super::{ExplicitRungeKutta, Ordinary, Adaptive};
4use crate::{
5    Error, Status,
6    alias::Evals,
7    methods::h_init::InitialStepSize,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    ode::{OrdinaryNumericalMethod, ODE},
10    traits::{CallBackData, Real, State},
11    utils::{constrain_step_size, validate_step_size_parameters},
12};
13
impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
    /// Prepares the solver for integration from `t0` to `tf` starting at `y0`.
    ///
    /// Chooses an initial step size when `h0` is zero, validates it against
    /// `[h_min, h_max]`, seeds the current and previous state, and evaluates
    /// the initial derivative. Returns the function-evaluation count.
    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
    where
        F: ODE<T, V, D>,
    {
        let mut evals = Evals::new();

        // If h0 is zero, calculate initial step size
        if self.h0 == T::zero() {
            // Only use adaptive step size calculation if the method supports it
            self.h0 = InitialStepSize::<Ordinary>::compute(ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, &mut evals);
            // NOTE(review): `compute` already receives `&mut evals` — confirm this
            // +2 covers evaluations NOT recorded inside it, not double-counting.
            evals.fcn += 2;

        }

        // Check bounds: reject h0 outside [h_min, h_max] or inconsistent with t0→tf direction
        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
            Ok(h0) => self.h = h0,
            Err(status) => return Err(status),
        }

        // Initialize Statistics
        self.stiffness_counter = 0;

        // Initialize State
        self.t = t0;
        self.y = *y0;
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.fcn += 1;

        // Initialize previous state (consumed by `interpolate` over the last step)
        self.t_prev = self.t;
        self.y_prev = self.y;
        self.dydt_prev = self.dydt;

        // Initialize Status
        self.status = Status::Initialized;

        Ok(evals)
    }

    /// Attempts one adaptive step: evaluates the stages, forms the embedded
    /// error estimate, and either accepts the step (advancing `t`/`y`) or
    /// rejects it. In both cases `h` is rescaled for the next attempt.
    ///
    /// # Errors
    /// - `Error::StepSize` when `h` underflows relative to the last accepted step.
    /// - `Error::MaxSteps` when the step budget is exhausted.
    /// - `Error::Stiffness` after `max_rejects` consecutive rejections.
    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
    where
        F: ODE<T, V, D>,
    {
        let mut evals = Evals::new();

        // Check step size: abort when h has collapsed to a negligible fraction
        // of the last accepted step (integration is no longer progressing).
        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
            self.status = Status::Error(Error::StepSize {
                t: self.t, y: self.y
            });
            return Err(Error::StepSize {
                t: self.t, y: self.y
            });
        }

        // Check max steps
        if self.steps >= self.max_steps {
            self.status = Status::Error(Error::MaxSteps {
                t: self.t, y: self.y
            });
            return Err(Error::MaxSteps {
                t: self.t, y: self.y
            });
        }
        self.steps += 1;

        // Save k[0] as the current derivative (already evaluated at (t, y))
        self.k[0] = self.dydt;

        // Compute stages: k[i] = f(t + c_i*h, y + h * sum_{j<i} a_ij * k_j)
        for i in 1..self.stages {
            let mut y_stage = self.y;

            for j in 0..i {
                y_stage += self.k[j] * (self.a[i][j] * self.h);
            }

            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
        }
        evals.fcn += self.stages - 1; // We already have k[0]

        // For adaptive methods with error estimation
        // Compute higher order solution: y + h * sum_i b_i * k_i
        let mut y_high = self.y;
        for i in 0..self.stages {
            y_high += self.k[i] * (self.b[i] * self.h);
        }

        // Compute lower order (embedded) solution for error estimation.
        // NOTE(review): if `bh` is None, y_low remains y and the "error" below
        // becomes the full step increment — adaptive tableaus presumably always
        // carry embedded weights; confirm against the tableau constructors.
        let mut y_low = self.y;
        if let Some(bh) = &self.bh {
            for i in 0..self.stages {
                y_low += self.k[i] * (bh[i] * self.h);
            }
        }

        // Compute error estimate
        let err = y_high - y_low;

        // Weighted max-norm of the error: max_n |err_n| / (atol + rtol * scale_n)
        let mut err_norm: T = T::zero();

        // Iterate through state elements
        for n in 0..self.y.len() {
            let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
            err_norm = err_norm.max((err.get(n) / tol).abs());
        };

        // Step size scale factor: safety * err_norm^(-1/order)
        // (standard I-controller; err_norm = 0 yields +inf, clamped below)
        let order = T::from_usize(self.order).unwrap();
        let error_exponent = T::one() / order;
        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);

        // Clamp scale factor to prevent extreme step size changes
        scale = scale.max(self.min_scale).min(self.max_scale);

        // Determine if step is accepted (err_norm <= 1 means within tolerance)
        if err_norm <= T::one() {
            // Log previous state
            self.t_prev = self.t;
            self.y_prev = self.y;
            self.dydt_prev = self.k[0];
            self.h_prev = self.h;

            if let Status::RejectedStep = self.status {
                self.stiffness_counter = 0;
                self.status = Status::Solving;

                // Limit step size growth to avoid oscillations between accepted and rejected steps
                scale = scale.min(T::one());
            }

            // If method has dense output stages, compute them
            if self.bi.is_some() {
                // Compute extra stages for dense output (k[S..I], using tableau rows S..I)
                for i in 0..(I - S) {
                    let mut y_stage = self.y;
                    for j in 0..self.stages + i {
                        y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
                    }

                    ode.diff(self.t + self.c[self.stages + i] * self.h, &y_stage, &mut self.k[self.stages + i]);
                }
                evals.fcn += I - S;
            }

            // Update state with the higher-order solution
            self.t += self.h;
            self.y = y_high;

            // Compute the derivative for the next step
            if self.fsal {
                // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
                self.dydt = self.k[S - 1];
            } else {
                // Otherwise, compute the new derivative
                ode.diff(self.t, &self.y, &mut self.dydt);
                evals.fcn += 1;
            }
        } else {
            // Step rejected
            self.status = Status::RejectedStep;
            self.stiffness_counter += 1;

            // Check for stiffness: too many consecutive rejections
            if self.stiffness_counter >= self.max_rejects {
                self.status = Status::Error(Error::Stiffness {
                    t: self.t, y: self.y
                });
                return Err(Error::Stiffness {
                    t: self.t, y: self.y
                });
            }
        }

        // Update step size (applied on both accepted and rejected steps)
        self.h *= scale;

        // Ensure step size is within bounds
        self.h = constrain_step_size(self.h, self.h_min, self.h_max);

        Ok(evals)
    }

    // Accessors for the solver's current and previous state.
    fn t(&self) -> T { self.t }
    fn y(&self) -> &V { &self.y }
    fn t_prev(&self) -> T { self.t_prev }
    fn y_prev(&self) -> &V { &self.y_prev }
    fn h(&self) -> T { self.h }
    fn set_h(&mut self, h: T) { self.h = h; }
    fn status(&self) -> &Status<T, V, D> { &self.status }
    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
}
209
210impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
211    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
212        // Check if t is within bounds
213        if t_interp < self.t_prev || t_interp > self.t {
214            return Err(Error::OutOfBounds {
215                t_interp,
216                t_prev: self.t_prev,
217                t_curr: self.t
218            });
219        }
220
221        // If method has dense output coefficients, use them
222        if self.bi.is_some() {
223            // Calculate the normalized distance within the step [0, 1]
224            let s = (t_interp - self.t_prev) / self.h_prev;
225            
226            // Get the interpolation coefficients
227            let bi = self.bi.as_ref().unwrap();
228
229            let mut cont = [T::zero(); I];
230            // Compute the interpolation coefficients using Horner's method
231            for i in 0..self.dense_stages {
232                // Start with the highest-order term
233                cont[i] = bi[i][self.order - 1];
234
235                // Apply Horner's method
236                for j in (0..self.order - 1).rev() {
237                    cont[i] = cont[i] * s + bi[i][j];
238                }
239
240                // Multiply by s
241                cont[i] *= s;
242            }
243
244            // Compute the interpolated value
245            let mut y_interp = self.y_prev;
246            for i in 0..I {
247                y_interp += self.k[i] * cont[i] * self.h_prev;
248            }
249
250            Ok(y_interp)
251        } else {
252            // Otherwise use cubic Hermite interpolation
253            let y_interp = cubic_hermite_interpolate(
254                self.t_prev, 
255                self.t, 
256                &self.y_prev, 
257                &self.y, 
258                &self.dydt_prev, 
259                &self.dydt, 
260                t_interp
261            );
262
263            Ok(y_interp)
264        }
265    }
266}