differential_equations/methods/erk/dormandprince/ordinary.rs

1//! Dormand-Prince Runge-Kutta methods for ODEs
2
3use crate::{
4    error::Error,
5    interpolate::Interpolation,
6    methods::{DormandPrince, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
7    ode::{ODE, OrdinaryNumericalMethod},
8    stats::Evals,
9    status::Status,
10    traits::{Real, State},
11    utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
15    OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
16{
17    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
18    where
19        F: ODE<T, Y>,
20    {
21        let mut evals = Evals::new();
22
23        // If h0 is zero, calculate initial step size
24        if self.h0 == T::zero() {
25            // Use adaptive step size calculation for Dormand-Prince methods
26            self.h0 = InitialStepSize::<Ordinary>::compute(
27                ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
28                &mut evals,
29            );
30        }
31
32        // Check bounds
33        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
34            Ok(h0) => self.h = h0,
35            Err(status) => return Err(status),
36        }
37
38        // Initialize Statistics
39        self.stiffness_counter = 0;
40
41        // Initialize State
42        self.t = t0;
43        self.y = *y0;
44        ode.diff(self.t, &self.y, &mut self.k[0]);
45        self.dydt = self.k[0];
46        evals.function += 1;
47
48        // Initialize previous state
49        self.t_prev = self.t;
50        self.y_prev = self.y;
51        self.dydt_prev = self.dydt;
52
53        // Initialize Status
54        self.status = Status::Initialized;
55
56        Ok(evals)
57    }
58
59    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
60    where
61        F: ODE<T, Y>,
62    {
63        let mut evals = Evals::new();
64
65        // Check if step-size is becoming too small
66        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
67            self.status = Status::Error(Error::StepSize {
68                t: self.t,
69                y: self.y,
70            });
71            return Err(Error::StepSize {
72                t: self.t,
73                y: self.y,
74            });
75        }
76
77        // Check max steps
78        if self.steps >= self.max_steps {
79            self.status = Status::Error(Error::MaxSteps {
80                t: self.t,
81                y: self.y,
82            });
83            return Err(Error::MaxSteps {
84                t: self.t,
85                y: self.y,
86            });
87        }
88        self.steps += 1;
89
90        // Compute stages
91        let mut y_stage = Y::zeros();
92        for i in 1..self.stages {
93            y_stage = Y::zeros();
94
95            for j in 0..i {
96                y_stage += self.k[j] * self.a[i][j];
97            }
98            y_stage = self.y + y_stage * self.h;
99
100            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
101        }
102
103        // The last stage will be used for stiffness detection
104        let ysti = y_stage;
105
106        // Calculate the line segment for the new y value
107        let mut yseg = Y::zeros();
108        for i in 0..self.stages {
109            yseg += self.k[i] * self.b[i];
110        }
111
112        // Calculate the new y value using the line segment
113        let y_new = self.y + yseg * self.h;
114
115        // Evaluate derivative at new point for error estimation
116        let t_new = self.t + self.h;
117
118        // Number of function evaluations
119        evals.function += self.stages - 1; // We already have k[0]
120
121        // Error estimation
122        let er = self.er.unwrap();
123        let n = self.y.len();
124        let mut err = T::zero();
125        let mut err2 = T::zero();
126        let mut erri;
127        for i in 0..n {
128            // Calculate the error scale
129            let sk = self.atol[i] + self.rtol[i] * self.y.get(i).abs().max(y_new.get(i).abs());
130
131            // Primary error term
132            erri = T::zero();
133            for j in 0..self.stages {
134                erri += er[j] * self.k[j].get(i);
135            }
136            err += (erri / sk).powi(2);
137
138            // Optional secondary error term
139            if let Some(bh) = &self.bh {
140                erri = yseg.get(i);
141                for j in 0..self.stages {
142                    erri -= bh[j] * self.k[j].get(i);
143                }
144                err2 += (erri / sk).powi(2);
145            }
146        }
147        let mut deno = err + T::from_f64(0.01).unwrap() * err2;
148        if deno <= T::zero() {
149            deno = T::one();
150        }
151        err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
152
153        // Step size scale factor
154        let order = T::from_usize(self.order).unwrap();
155        let error_exponent = T::one() / order;
156        let mut scale = self.safety_factor * err.powf(-error_exponent);
157
158        // Clamp scale factor to prevent extreme step size changes
159        scale = scale.max(self.min_scale).min(self.max_scale);
160
161        // Determine if step is accepted
162        if err <= T::one() {
163            // Calculate the new derivative at the new point
164            ode.diff(t_new, &y_new, &mut self.dydt);
165            evals.function += 1;
166
167            // stiffness detection
168            let n_stiff_threshold = 100;
169            if self.steps % n_stiff_threshold == 0 {
170                let mut stdnum = T::zero();
171                let mut stden = T::zero();
172                let sqr = yseg - self.k[S - 1];
173                for i in 0..sqr.len() {
174                    stdnum += sqr.get(i).powi(2);
175                }
176                let sqr = self.dydt - ysti;
177                for i in 0..sqr.len() {
178                    stden += sqr.get(i).powi(2);
179                }
180
181                if stden > T::zero() {
182                    let h_lamb = self.h * (stdnum / stden).sqrt();
183                    if h_lamb > T::from_f64(6.1).unwrap() {
184                        self.non_stiffness_counter = 0;
185                        self.stiffness_counter += 1;
186                        if self.stiffness_counter == 15 {
187                            // Early Exit Stiffness Detected
188                            self.status = Status::Error(Error::Stiffness {
189                                t: self.t,
190                                y: self.y,
191                            });
192                            return Err(Error::Stiffness {
193                                t: self.t,
194                                y: self.y,
195                            });
196                        }
197                    }
198                } else {
199                    self.non_stiffness_counter += 1;
200                    if self.non_stiffness_counter == 6 {
201                        self.stiffness_counter = 0;
202                    }
203                }
204            }
205
206            // Preparation for dense output / interpolation
207            self.cont[0] = self.y;
208            let ydiff = y_new - self.y;
209            self.cont[1] = ydiff;
210            let bspl = self.k[0] * self.h - ydiff;
211            self.cont[2] = bspl;
212            self.cont[3] = ydiff - self.dydt * self.h - bspl;
213
214            // If method has dense output stages, compute them
215            if let Some(bi) = &self.bi {
216                // Compute extra stages for dense output
217                if I > S {
218                    // First dense output coefficient, k{i=order+1}, is the derivative at the new point
219                    self.k[self.stages] = self.dydt;
220
221                    for i in S + 1..I {
222                        let mut y_stage = Y::zeros();
223                        for j in 0..i {
224                            y_stage += self.k[j] * self.a[i][j];
225                        }
226                        y_stage = self.y + y_stage * self.h;
227
228                        ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
229                        evals.function += 1;
230                    }
231                }
232
233                // Compute dense output coefficients
234                for i in 4..self.order {
235                    self.cont[i] = Y::zeros();
236                    for j in 0..self.dense_stages {
237                        self.cont[i] += self.k[j] * bi[i][j];
238                    }
239                    self.cont[i] = self.cont[i] * self.h;
240                }
241            }
242
243            // For interpolation
244            self.t_prev = self.t;
245            self.y_prev = self.y;
246            self.dydt_prev = self.k[0];
247            self.h_prev = self.h;
248
249            // Update the state with new values
250            self.t = t_new;
251            self.y = y_new;
252            self.k[0] = self.dydt;
253
254            // Check if previous step is rejected
255            if let Status::RejectedStep = self.status {
256                self.status = Status::Solving;
257
258                // Limit step size growth to avoid oscillations between accepted and rejected steps
259                scale = scale.min(T::one());
260            }
261        } else {
262            // Step Rejected
263            self.status = Status::RejectedStep;
264        }
265
266        // Update step size
267        self.h *= scale;
268
269        // Ensure step size is within bounds
270        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
271
272        Ok(evals)
273    }
274
275    fn t(&self) -> T {
276        self.t
277    }
278    fn y(&self) -> &Y {
279        &self.y
280    }
281    fn t_prev(&self) -> T {
282        self.t_prev
283    }
284    fn y_prev(&self) -> &Y {
285        &self.y_prev
286    }
287    fn h(&self) -> T {
288        self.h
289    }
290    fn set_h(&mut self, h: T) {
291        self.h = h;
292    }
293    fn status(&self) -> &Status<T, Y> {
294        &self.status
295    }
296    fn set_status(&mut self, status: Status<T, Y>) {
297        self.status = status;
298    }
299}
300
301impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
302    for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
303{
304    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
305        // Check if interpolation is out of bounds
306        if t_interp < self.t_prev || t_interp > self.t {
307            return Err(Error::OutOfBounds {
308                t_interp,
309                t_prev: self.t_prev,
310                t_curr: self.t,
311            });
312        }
313
314        // Evaluate the interpolation polynomial at the requested time
315        let s = (t_interp - self.t_prev) / self.h_prev;
316        let s1 = T::one() - s;
317
318        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
319        let ilast = self.cont.len() - 1;
320        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
321            let factor = if i >= 4 {
322                // For the higher-order part (conpar), alternate s and s1 based on index parity
323                if (ilast - i) % 2 == 1 { s1 } else { s }
324            } else {
325                // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
326                if i % 2 == 1 { s1 } else { s }
327            };
328            acc * factor + self.cont[i]
329        });
330
331        // Final multiplication by s for the outermost level
332        let y_interp = self.cont[0] + poly * s;
333
334        Ok(y_interp)
335    }
336}