differential_equations/methods/erk/adaptive/ordinary.rs

//! Adaptive Runge-Kutta methods for ODEs
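//!
//! These solvers advance with an embedded Runge-Kutta pair: each step produces
//! two solutions of different order, their difference drives an error-based
//! step size controller, and dense-output coefficients (when the tableau
//! provides them) give continuous interpolation inside accepted steps.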

use crate::{
    error::Error,
    interpolate::{Interpolation, cubic_hermite_interpolate},
    methods::{Adaptive, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
    ode::{ODE, OrdinaryNumericalMethod},
    stats::Evals,
    status::Status,
    traits::{CallBackData, Real, State},
    utils::{constrain_step_size, validate_step_size_parameters},
};

impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
    OrdinaryNumericalMethod<T, Y, D> for ExplicitRungeKutta<Ordinary, Adaptive, T, Y, D, O, S, I>
{
    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y, D>,
    {
        let mut evals = Evals::new();

        // If h0 is zero, calculate initial step size
        if self.h0 == T::zero() {
            // Estimate a suitable initial step size from the problem data
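            // (This presumably follows the standard starting-step heuristic of
            // Hairer, Norsett & Wanner, balancing the size of y0 and its first
            // derivative against the tolerances.)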
            self.h0 = InitialStepSize::<Ordinary>::compute(
                ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max,
                &mut evals,
            );
            evals.function += 2;
        }

        // Check bounds
        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
            Ok(h0) => self.h = h0,
            Err(status) => return Err(status),
        }

        // Initialize statistics
        self.stiffness_counter = 0;

        // Initialize state
        self.t = t0;
        self.y = *y0;
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.function += 1;

        // Initialize previous state
        self.t_prev = self.t;
        self.y_prev = self.y;
        self.dydt_prev = self.dydt;

        // Initialize status
        self.status = Status::Initialized;

        Ok(evals)
    }

    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y, D>,
    {
        let mut evals = Evals::new();

        // Guard against step size underflow relative to the previous step
        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
            self.status = Status::Error(Error::StepSize {
                t: self.t,
                y: self.y,
            });
            return Err(Error::StepSize {
                t: self.t,
                y: self.y,
            });
        }

        // Check max steps
        if self.steps >= self.max_steps {
            self.status = Status::Error(Error::MaxSteps {
                t: self.t,
                y: self.y,
            });
            return Err(Error::MaxSteps {
                t: self.t,
                y: self.y,
            });
        }
        self.steps += 1;

        // Save k[0] as the current derivative
        self.k[0] = self.dydt;

        // Compute stages
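        // Each stage samples the derivative at an intermediate point,
        //   k[i] = f(t + c[i]*h, y + h * sum_{j<i} a[i][j] * k[j]),
        // with a and c taken from the method's Butcher tableau.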
        for i in 1..self.stages {
            let mut y_stage = self.y;

            for j in 0..i {
                y_stage += self.k[j] * (self.a[i][j] * self.h);
            }

            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
        }
        evals.function += self.stages - 1; // We already have k[0]

        // Combine the stages with the higher-order weights b to form the
        // candidate solution
        let mut y_high = self.y;
        for i in 0..self.stages {
            y_high += self.k[i] * (self.b[i] * self.h);
        }

        // Combine the same stages with the embedded lower-order weights bh for
        // error estimation
        let mut y_low = self.y;
        let bh = self.bh.as_ref().unwrap();
        for i in 0..self.stages {
            y_low += self.k[i] * (bh[i] * self.h);
        }

        // Compute error estimate
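        //   err = y_high - y_low = h * sum_i (b[i] - bh[i]) * k[i],
        // an estimate of the local truncation error of the lower-order solution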
        let err = y_high - y_low;

        // Calculate error norm
        let mut err_norm: T = T::zero();

        // Iterate through state elements
        for n in 0..self.y.len() {
            let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
            err_norm = err_norm.max((err.get(n) / tol).abs());
        }
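        // err_norm is a weighted max norm:
        //   max_n |err[n]| / (atol + rtol * max(|y[n]|, |y_high[n]|));
        // err_norm <= 1 means every component meets its tolerance.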

        // Step size scale factor
        let order = T::from_usize(self.order).unwrap();
        let error_exponent = T::one() / order;
        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);

        // Clamp scale factor to prevent extreme step size changes
        scale = scale.max(self.min_scale).min(self.max_scale);
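        // Net effect: h_new = h * clamp(safety * err_norm^(-1/order), min_scale, max_scale).
        // The clamp keeps a single bad error estimate from collapsing or exploding
        // the step size.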

        // Determine if the step is accepted
        if err_norm <= T::one() {
            // Record the previous state
            self.t_prev = self.t;
            self.y_prev = self.y;
            self.dydt_prev = self.k[0];
            self.h_prev = self.h;

            if let Status::RejectedStep = self.status {
                self.stiffness_counter = 0;
                self.status = Status::Solving;

                // Limit step size growth to avoid oscillating between accepted and rejected steps
                scale = scale.min(T::one());
            }

            // If the method has dense output stages, compute them
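            // (Methods with a continuous extension, e.g. DOP853, define I - S extra
            // stages that are needed only for interpolation, so they are evaluated
            // only once the step is accepted.)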
            if self.bi.is_some() {
                // Compute extra stages for dense output
                for i in 0..(I - S) {
                    let mut y_stage = self.y;
                    for j in 0..self.stages + i {
                        y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
                    }

                    ode.diff(
                        self.t + self.c[self.stages + i] * self.h,
                        &y_stage,
                        &mut self.k[self.stages + i],
                    );
                }
                evals.function += I - S;
            }

            // Update state with the higher-order solution
            self.t += self.h;
            self.y = y_high;

            // Compute the derivative for the next step
            if self.fsal {
                // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
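                // (In an FSAL tableau the last stage has c[S - 1] = 1 and its row of
                // a equals the weights b, so k[S - 1] already equals f(t + h, y_high);
                // reusing it saves one evaluation per accepted step.)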
                self.dydt = self.k[S - 1];
            } else {
                // Otherwise, compute the new derivative
                ode.diff(self.t, &self.y, &mut self.dydt);
                evals.function += 1;
            }
        } else {
            // Step rejected
            self.status = Status::RejectedStep;
            self.stiffness_counter += 1;

            // Check for stiffness
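            // Many consecutive rejections usually indicate stiffness: the error
            // control keeps shrinking h because an explicit method is stability-
            // limited rather than accuracy-limited on stiff problems.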
            if self.stiffness_counter >= self.max_rejects {
                self.status = Status::Error(Error::Stiffness {
                    t: self.t,
                    y: self.y,
                });
                return Err(Error::Stiffness {
                    t: self.t,
                    y: self.y,
                });
            }
        }

        // Update step size
        self.h *= scale;

        // Ensure step size is within bounds
        self.h = constrain_step_size(self.h, self.h_min, self.h_max);

        Ok(evals)
    }

    fn t(&self) -> T {
        self.t
    }
    fn y(&self) -> &Y {
        &self.y
    }
    fn t_prev(&self) -> T {
        self.t_prev
    }
    fn y_prev(&self) -> &Y {
        &self.y_prev
    }
    fn h(&self) -> T {
        self.h
    }
    fn set_h(&mut self, h: T) {
        self.h = h;
    }
    fn status(&self) -> &Status<T, Y, D> {
        &self.status
    }
    fn set_status(&mut self, status: Status<T, Y, D>) {
        self.status = status;
    }
}

impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
    Interpolation<T, Y> for ExplicitRungeKutta<Ordinary, Adaptive, T, Y, D, O, S, I>
{
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        // Check that t_interp lies within the current step
        if t_interp < self.t_prev || t_interp > self.t {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_prev,
                t_curr: self.t,
            });
        }

        // If the method has dense output coefficients, use them
        if self.bi.is_some() {
            // Calculate the normalized position within the step, s in [0, 1]
            let s = (t_interp - self.t_prev) / self.h_prev;

            // Get the interpolation coefficients
            let bi = self.bi.as_ref().unwrap();

            let mut cont = [T::zero(); I];
            // Compute the interpolation coefficients using Horner's method
            for i in 0..self.dense_stages {
                // Start with the highest-order term
                cont[i] = bi[i][self.order - 1];

                // Apply Horner's method
                for j in (0..self.order - 1).rev() {
                    cont[i] = cont[i] * s + bi[i][j];
                }

                // Multiply by s so the interpolant reduces to y_prev at s = 0
                cont[i] *= s;
            }
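            // Each cont[i] is the polynomial
            //   b_i(s) = s * (bi[i][0] + s * (bi[i][1] + ... + s * bi[i][order - 1])),
            // so the interpolant below is y_prev + h_prev * sum_i b_i(s) * k[i].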

            // Compute the interpolated value
            let mut y_interp = self.y_prev;
            for i in 0..I {
                y_interp += self.k[i] * cont[i] * self.h_prev;
            }

            Ok(y_interp)
        } else {
            // Otherwise use cubic Hermite interpolation
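            // A cubic Hermite matches y and dy/dt at both ends of the step, so it
            // needs no extra stages, at the cost of lower accuracy than a
            // method-specific dense output.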
            let y_interp = cubic_hermite_interpolate(
                self.t_prev,
                self.t,
                &self.y_prev,
                &self.y,
                &self.dydt_prev,
                &self.dydt,
                t_interp,
            );

            Ok(y_interp)
        }
    }
}