differential_equations/methods/erk/dormandprince/ordinary.rs

//! Dormand-Prince Runge-Kutta methods for ODEs
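//!
//! A minimal driver sketch (not taken from this crate's examples): it only calls the
//! `OrdinaryNumericalMethod` and `Interpolation` methods implemented in this file, and the
//! `solver` and `my_ode` values are placeholders assumed to be constructed elsewhere.
//!
//! ```ignore
//! solver.init(&my_ode, t0, tf, &y0)?;
//! while solver.t() < tf {
//!     solver.step(&my_ode)?;
//!     // Dense output is available anywhere inside the last accepted step
//!     let y_mid = solver.interpolate((solver.t_prev() + solver.t()) / 2.0)?;
//! }
//! ```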

use crate::{
    error::Error,
    interpolate::Interpolation,
    methods::{DormandPrince, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
    ode::{ODE, OrdinaryNumericalMethod},
    stats::Evals,
    status::Status,
    traits::{CallBackData, Real, State},
    utils::{constrain_step_size, validate_step_size_parameters},
};

impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
    OrdinaryNumericalMethod<T, Y, D>
    for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, D, O, S, I>
{
    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y, D>,
    {
        let mut evals = Evals::new();

        // If h0 is zero, calculate initial step size
        if self.h0 == T::zero() {
            // Use adaptive step size calculation for Dormand-Prince methods
            self.h0 = InitialStepSize::<Ordinary>::compute(
                ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max,
                &mut evals,
            );
        }

        // Check bounds
        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
            Ok(h0) => self.h = h0,
            Err(status) => return Err(status),
        }

        // Initialize Statistics
        self.stiffness_counter = 0;

        // Initialize State
        self.t = t0;
        self.y = *y0;
        ode.diff(self.t, &self.y, &mut self.k[0]);
        self.dydt = self.k[0];
        evals.function += 1;
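        // k[0] now holds f(t0, y0); `step` reuses it as its first stage, so the stage loop
        // there only charges `stages - 1` new function evaluations.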

        // Initialize previous state
        self.t_prev = self.t;
        self.y_prev = self.y;
        self.dydt_prev = self.dydt;

        // Initialize Status
        self.status = Status::Initialized;

        Ok(evals)
    }

    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y, D>,
    {
        let mut evals = Evals::new();

        // Check if step-size is becoming too small
        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
            self.status = Status::Error(Error::StepSize {
                t: self.t,
                y: self.y,
            });
            return Err(Error::StepSize {
                t: self.t,
                y: self.y,
            });
        }

        // Check max steps
        if self.steps >= self.max_steps {
            self.status = Status::Error(Error::MaxSteps {
                t: self.t,
                y: self.y,
            });
            return Err(Error::MaxSteps {
                t: self.t,
                y: self.y,
            });
        }
        self.steps += 1;

        // Compute stages
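        // Each stage is k[i] = f(t + c[i] * h, y + h * sum_{j < i} a[i][j] * k[j]);
        // k[0] already holds the derivative at (t, y) from `init` or the previous step.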
        let mut y_stage = Y::zeros();
        for i in 1..self.stages {
            y_stage = Y::zeros();

            for j in 0..i {
                y_stage += self.k[j] * self.a[i][j];
            }
            y_stage = self.y + y_stage * self.h;

            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
        }

        // The last stage value will be used for stiffness detection
        let ysti = y_stage;

        // Weighted sum of the stage derivatives: yseg = sum_i b[i] * k[i]
        let mut yseg = Y::zeros();
        for i in 0..self.stages {
            yseg += self.k[i] * self.b[i];
        }

        // Advance the solution: y_new = y + h * yseg
        let y_new = self.y + yseg * self.h;

        // Time at the end of the step
        let t_new = self.t + self.h;

        // Function evaluations used by the stage loop
        evals.function += self.stages - 1; // We already have k[0]

        // Error estimation
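        // With err5 = sum_i ((sum_j er[j]*k[j])_i / sk_i)^2 and
        // err3 = sum_i ((yseg - sum_j bh[j]*k[j])_i / sk_i)^2, the estimate below is
        //     err = |h| * err5 / sqrt(n * (err5 + 0.01 * err3)),
        // mirroring Hairer's DOP853 estimator; without secondary weights `bh` it
        // reduces to |h| * sqrt(err5 / n).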
        let er = self.er.unwrap();
        let n = self.y.len();
        let mut err = T::zero();
        let mut err2 = T::zero();
        let mut erri;
        for i in 0..n {
            // Calculate the error scale
            let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());

            // Primary error term
            erri = T::zero();
            for j in 0..self.stages {
                erri += er[j] * self.k[j].get(i);
            }
            err += (erri / sk).powi(2);

            // Optional secondary error term
            if let Some(bh) = &self.bh {
                erri = yseg.get(i);
                for j in 0..self.stages {
                    erri -= bh[j] * self.k[j].get(i);
                }
                err2 += (erri / sk).powi(2);
            }
        }
        let mut deno = err + T::from_f64(0.01).unwrap() * err2;
        if deno <= T::zero() {
            deno = T::one();
        }
        err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();

        // Step size scale factor
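        // Standard controller: scale = clamp(safety_factor * err^(-1/order), min_scale, max_scale);
        // the next step size is h * scale (constrained by `constrain_step_size` below).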
        let order = T::from_usize(self.order).unwrap();
        let error_exponent = T::one() / order;
        let mut scale = self.safety_factor * err.powf(-error_exponent);

        // Clamp scale factor to prevent extreme step size changes
        scale = scale.max(self.min_scale).min(self.max_scale);

        // Determine if step is accepted
        if err <= T::one() {
            // Calculate the new derivative at the new point
            ode.diff(t_new, &y_new, &mut self.dydt);
            evals.function += 1;

            // Stiffness detection
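            // h_lamb below estimates |h * lambda| for the dominant eigenvalue as the ratio of
            // the derivative difference to the solution difference between the accepted point
            // and the last internal stage; 15 flags above 6.1 (reset after 6 quiet checks)
            // abort the integration as stiff.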
            let n_stiff_threshold = 100;
            if self.steps % n_stiff_threshold == 0 {
                let mut stdnum = T::zero();
                let mut stden = T::zero();
                let sqr = self.dydt - self.k[S - 1];
                for i in 0..sqr.len() {
                    stdnum += sqr.get(i).powi(2);
                }
                let sqr = y_new - ysti;
                for i in 0..sqr.len() {
                    stden += sqr.get(i).powi(2);
                }

                if stden > T::zero() {
                    let h_lamb = self.h * (stdnum / stden).sqrt();
                    if h_lamb > T::from_f64(6.1).unwrap() {
                        self.non_stiffness_counter = 0;
                        self.stiffness_counter += 1;
                        if self.stiffness_counter == 15 {
                            // Early exit: stiffness detected
                            self.status = Status::Error(Error::Stiffness {
                                t: self.t,
                                y: self.y,
                            });
                            return Err(Error::Stiffness {
                                t: self.t,
                                y: self.y,
                            });
                        }
                    } else {
                        self.non_stiffness_counter += 1;
                        if self.non_stiffness_counter == 6 {
                            self.stiffness_counter = 0;
                        }
                    }
                }
            }

            // Preparation for dense output / interpolation
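            // cont[0..=3] hold the cubic Hermite part of the dense output (built from y, y_new,
            // and the derivatives at both ends); `interpolate` evaluates them in nested
            // Horner-like form together with any higher-order coefficients computed below.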
            self.cont[0] = self.y;
            let ydiff = y_new - self.y;
            self.cont[1] = ydiff;
            let bspl = self.k[0] * self.h - ydiff;
            self.cont[2] = bspl;
            self.cont[3] = ydiff - self.dydt * self.h - bspl;

            // If the method has dense output stages, compute them
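            // (`bi` holds the dense-output weights; I > S means the tableau defines extra
            // stages that are only used for interpolation, e.g. DOP853-style methods.)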
            if let Some(bi) = &self.bi {
                // Compute extra stages for dense output
                if I > S {
                    // The first extra stage, k[self.stages], is the derivative at the new point
                    self.k[self.stages] = self.dydt;

                    for i in S + 1..I {
                        let mut y_stage = Y::zeros();
                        for j in 0..i {
                            y_stage += self.k[j] * self.a[i][j];
                        }
                        y_stage = self.y + y_stage * self.h;

                        ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
                        evals.function += 1;
                    }
                }

                // Compute dense output coefficients
                for i in 4..self.order {
                    self.cont[i] = Y::zeros();
                    for j in 0..self.dense_stages {
                        self.cont[i] += self.k[j] * bi[i][j];
                    }
                    self.cont[i] = self.cont[i] * self.h;
                }
            }

            // Save the previous state for interpolation
            self.t_prev = self.t;
            self.y_prev = self.y;
            self.dydt_prev = self.k[0];
            self.h_prev = self.h;

            // Update the state with new values
            self.t = t_new;
            self.y = y_new;
            self.k[0] = self.dydt;

            // Check if previous step is rejected
            if let Status::RejectedStep = self.status {
                self.status = Status::Solving;

                // Limit step size growth to avoid oscillations between accepted and rejected steps
                scale = scale.min(T::one());
            }
        } else {
            // Step Rejected
            self.status = Status::RejectedStep;
        }

        // Update step size
        self.h *= scale;

        // Ensure step size is within bounds
        self.h = constrain_step_size(self.h, self.h_min, self.h_max);

        Ok(evals)
    }

    fn t(&self) -> T {
        self.t
    }
    fn y(&self) -> &Y {
        &self.y
    }
    fn t_prev(&self) -> T {
        self.t_prev
    }
    fn y_prev(&self) -> &Y {
        &self.y_prev
    }
    fn h(&self) -> T {
        self.h
    }
    fn set_h(&mut self, h: T) {
        self.h = h;
    }
    fn status(&self) -> &Status<T, Y, D> {
        &self.status
    }
    fn set_status(&mut self, status: Status<T, Y, D>) {
        self.status = status;
    }
}

impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
    Interpolation<T, Y> for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, D, O, S, I>
{
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        // Check if interpolation is out of bounds
        if t_interp < self.t_prev || t_interp > self.t {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_prev,
                t_curr: self.t,
            });
        }

        // Evaluate the interpolation polynomial at the requested time
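        // s is the normalized position within the last accepted step: 0 at t_prev, 1 at t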
        let s = (t_interp - self.t_prev) / self.h_prev;
        let s1 = T::one() - s;

        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
        let ilast = self.cont.len() - 1;
        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
            let factor = if i >= 4 {
                // For the higher-order part (conpar), alternate s and s1 based on index parity
                if (ilast - i) % 2 == 1 { s1 } else { s }
            } else {
                // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
                if i % 2 == 1 { s1 } else { s }
            };
            acc * factor + self.cont[i]
        });

        // Final multiplication by s for the outermost level
        let y_interp = self.cont[0] + poly * s;

        Ok(y_interp)
    }
}