Skip to main content

differential_equations/methods/erk/adaptive/
ordinary.rs

1//! Adaptive Runge-Kutta methods for ODEs
2use crate::{
3    error::Error,
4    interpolate::{Interpolation, cubic_hermite_interpolate},
5    methods::{Adaptive, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
6    ode::{ODE, OrdinaryNumericalMethod},
7    stats::Evals,
8    status::Status,
9    traits::{Real, State},
10    utils::{constrain_step_size, validate_step_size_parameters},
11};
12
13impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
14    OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
15{
16    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
17    where
18        F: ODE<T, Y> + ?Sized,
19    {
20        let mut evals = Evals::new();
21
22        // If h0 is zero, calculate initial step size
23        if self.h0 == T::zero() {
24            // Only use adaptive step size calculation if the method supports it
25            self.h0 = InitialStepSize::<Ordinary>::compute(
26                ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
27                &mut evals,
28            );
29            evals.function += 2;
30        }
31
32        // Check bounds
33        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
34            Ok(h0) => self.h = (self.filter)(h0),
35            Err(status) => return Err(status),
36        }
37
38        // Initialize Statistics
39        self.stiffness_counter = 0;
40
41        // Initialize State
42        self.t = t0;
43        self.y = y0.clone();
44        self.dydt = y0.zeros_like();
45        self.y_prev = y0.clone();
46        self.dydt_prev = y0.zeros_like();
47        self.k = core::array::from_fn(|_| y0.zeros_like());
48        self.cont = core::array::from_fn(|_| y0.zeros_like());
49        ode.diff(self.t, &self.y, &mut self.dydt);
50        evals.function += 1;
51
52        // Initialize previous state
53        self.t_prev = self.t;
54        self.y_prev = self.y.clone();
55        self.dydt_prev = self.dydt.clone();
56
57        // Initialize Status
58        self.status = Status::Initialized;
59
60        Ok(evals)
61    }
62
63    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
64    where
65        F: ODE<T, Y> + ?Sized,
66    {
67        let mut evals = Evals::new();
68
69        // Check step size
70        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
71            self.status = Status::Error(Error::StepSize {
72                t: self.t,
73                y: self.y.clone(),
74            });
75            return Err(Error::StepSize {
76                t: self.t,
77                y: self.y.clone(),
78            });
79        }
80
81        // Check max steps
82        if self.steps >= self.max_steps {
83            self.status = Status::Error(Error::MaxSteps {
84                t: self.t,
85                y: self.y.clone(),
86            });
87            return Err(Error::MaxSteps {
88                t: self.t,
89                y: self.y.clone(),
90            });
91        }
92        self.steps += 1;
93
94        // Save k[0] as the current derivative
95        self.k[0] = self.dydt.clone();
96
97        // Compute stages
98        for i in 1..self.stages {
99            let mut y_stage = self.y.clone();
100
101            for j in 0..i {
102                y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
103            }
104
105            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
106        }
107        evals.function += self.stages - 1; // We already have k[0]
108
109        // For adaptive methods with error estimation
110        // Compute higher order solution
111        let mut y_high = self.y.clone();
112        for i in 0..self.stages {
113            y_high.add_scaled(self.b[i] * self.h, &self.k[i]);
114        }
115
116        // Compute lower order solution for error estimation
117        let mut y_low = self.y.clone();
118        let bh = &self.bh.unwrap();
119        for i in 0..self.stages {
120            y_low.add_scaled(bh[i] * self.h, &self.k[i]);
121        }
122
123        // Calculate error norm
124        let err = y_high.minus(&y_low);
125        let err_norm = self.y.error_norm_inf(&y_high, &err, &self.atol, &self.rtol);
126
127        // Step size scale factor
128        let order = T::from_usize(self.order).unwrap();
129        let error_exponent = T::one() / order;
130        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
131
132        // Clamp scale factor to prevent extreme step size changes
133        scale = scale.max(self.min_scale).min(self.max_scale);
134
135        // Determine if step is accepted
136        if err_norm <= T::one() {
137            // Log previous state
138            self.t_prev = self.t;
139            self.y_prev = self.y.clone();
140            self.dydt_prev = self.k[0].clone();
141            self.h_prev = self.h;
142
143            if let Status::RejectedStep = self.status {
144                self.stiffness_counter = 0;
145                self.status = Status::Solving;
146
147                // Limit step size growth to avoid oscillations between accepted and rejected steps
148                scale = scale.min(T::one());
149            }
150
151            // If method has dense output stages, compute them
152            if self.bi.is_some() {
153                // Compute extra stages for dense output
154                for i in 0..(I - S) {
155                    let mut y_stage = self.y.clone();
156                    for j in 0..self.stages + i {
157                        y_stage.add_scaled(self.a[self.stages + i][j] * self.h, &self.k[j]);
158                    }
159
160                    ode.diff(
161                        self.t + self.c[self.stages + i] * self.h,
162                        &y_stage,
163                        &mut self.k[self.stages + i],
164                    );
165                }
166                evals.function += I - S;
167            }
168
169            // Update state with the higher-order solution
170            self.t += self.h;
171            self.y = y_high;
172
173            // Compute the derivative for the next step
174            if self.fsal {
175                // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
176                self.dydt = self.k[S - 1].clone();
177            } else {
178                // Otherwise, compute the new derivative
179                ode.diff(self.t, &self.y, &mut self.dydt);
180                evals.function += 1;
181            }
182        } else {
183            // Step rejected
184            self.status = Status::RejectedStep;
185            self.stiffness_counter += 1;
186
187            // Check for stiffness
188            if self.stiffness_counter >= self.max_rejects {
189                self.status = Status::Error(Error::Stiffness {
190                    t: self.t,
191                    y: self.y.clone(),
192                });
193                return Err(Error::Stiffness {
194                    t: self.t,
195                    y: self.y.clone(),
196                });
197            }
198        }
199
200        // Update step size
201        self.h *= scale;
202
203        // Ensure step size is within bounds
204        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
205
206        // Apply step size filter
207        self.h = (self.filter)(self.h);
208
209        Ok(evals)
210    }
211
212    fn t(&self) -> T {
213        self.t
214    }
215    fn y(&self) -> &Y {
216        &self.y
217    }
218    fn t_prev(&self) -> T {
219        self.t_prev
220    }
221    fn y_prev(&self) -> &Y {
222        &self.y_prev
223    }
224    fn h(&self) -> T {
225        self.h
226    }
227    fn set_h(&mut self, h: T) {
228        self.h = (self.filter)(h);
229    }
230    fn status(&self) -> &Status<T, Y> {
231        &self.status
232    }
233    fn set_status(&mut self, status: Status<T, Y>) {
234        self.status = status;
235    }
236}
237
238impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
239    for ExplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
240{
241    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
242        // Check if t is within bounds
243        if t_interp < self.t_prev || t_interp > self.t {
244            return Err(Error::OutOfBounds {
245                t_interp,
246                t_prev: self.t_prev,
247                t_curr: self.t,
248            });
249        }
250
251        // If method has dense output coefficients, use them
252        if let Some(bi) = self.bi.as_ref() {
253            // Calculate the normalized distance within the step [0, 1]
254            let s = (t_interp - self.t_prev) / self.h_prev;
255
256            let mut cont = [T::zero(); I];
257            // Compute the interpolation coefficients using Horner's method
258            for i in 0..self.dense_stages {
259                // Start with the highest-order term
260                cont[i] = bi[i][self.order - 1];
261
262                // Apply Horner's method
263                for j in (0..self.order - 1).rev() {
264                    cont[i] = cont[i] * s + bi[i][j];
265                }
266
267                // Multiply by s
268                cont[i] *= s;
269            }
270
271            // Compute the interpolated value
272            let mut y_interp = self.y_prev.clone();
273            for i in 0..I {
274                y_interp.add_scaled(cont[i] * self.h_prev, &self.k[i]);
275            }
276
277            Ok(y_interp)
278        } else {
279            // Otherwise use cubic Hermite interpolation
280            let y_interp = cubic_hermite_interpolate(
281                self.t_prev,
282                self.t,
283                &self.y_prev,
284                &self.y,
285                &self.dydt_prev,
286                &self.dydt,
287                t_interp,
288            );
289
290            Ok(y_interp)
291        }
292    }
293}