Skip to main content

differential_equations/methods/erk/dormandprince/
ordinary.rs

1//! Dormand-Prince Runge-Kutta methods for ODEs
2use crate::{
3    error::Error,
4    interpolate::Interpolation,
5    methods::{DormandPrince, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
6    ode::{ODE, OrdinaryNumericalMethod},
7    stats::Evals,
8    status::Status,
9    traits::{Real, State},
10    utils::{constrain_step_size, validate_step_size_parameters},
11};
12
13impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
14    OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
15{
16    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
17    where
18        F: ODE<T, Y> + ?Sized,
19    {
20        let mut evals = Evals::new();
21
22        // If h0 is zero, calculate initial step size
23        if self.h0 == T::zero() {
24            // Use adaptive step size calculation for Dormand-Prince methods
25            self.h0 = InitialStepSize::<Ordinary>::compute(
26                ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
27                &mut evals,
28            );
29        }
30
31        // Check bounds
32        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
33            Ok(h0) => self.h = (self.filter)(h0),
34            Err(status) => return Err(status),
35        }
36
37        // Initialize Statistics
38        self.stiffness_counter = 0;
39
40        // Initialize State
41        self.t = t0;
42        self.y = y0.clone();
43        self.dydt = y0.zeros_like();
44        self.y_prev = y0.clone();
45        self.dydt_prev = y0.zeros_like();
46        self.k = core::array::from_fn(|_| y0.zeros_like());
47        self.cont = core::array::from_fn(|_| y0.zeros_like());
48        ode.diff(self.t, &self.y, &mut self.k[0]);
49        self.dydt = self.k[0].clone();
50        evals.function += 1;
51
52        // Initialize previous state
53        self.t_prev = self.t;
54        self.y_prev = self.y.clone();
55        self.dydt_prev = self.dydt.clone();
56
57        // Initialize Status
58        self.status = Status::Initialized;
59
60        Ok(evals)
61    }
62
63    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
64    where
65        F: ODE<T, Y> + ?Sized,
66    {
67        let mut evals = Evals::new();
68
69        // Check if step-size is becoming too small
70        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
71            self.status = Status::Error(Error::StepSize {
72                t: self.t,
73                y: self.y.clone(),
74            });
75            return Err(Error::StepSize {
76                t: self.t,
77                y: self.y.clone(),
78            });
79        }
80
81        // Check max steps
82        if self.steps >= self.max_steps {
83            self.status = Status::Error(Error::MaxSteps {
84                t: self.t,
85                y: self.y.clone(),
86            });
87            return Err(Error::MaxSteps {
88                t: self.t,
89                y: self.y.clone(),
90            });
91        }
92        self.steps += 1;
93
94        // Compute stages
95        let mut y_stage = self.y.zeros_like();
96        for i in 1..self.stages {
97            y_stage = self.y.clone();
98
99            for j in 0..i {
100                y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
101            }
102
103            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
104        }
105
106        // The last stage will be used for stiffness detection
107        let ysti = y_stage.clone();
108
109        // Calculate the line segment for the new y value
110        let mut yseg = self.y.zeros_like();
111        for i in 0..self.stages {
112            yseg.add_scaled(self.b[i], &self.k[i]);
113        }
114
115        // Calculate the new y value using the line segment
116        let y_new = self.y.plus_scaled(self.h, &yseg);
117
118        // Evaluate derivative at new point for error estimation
119        let t_new = self.t + self.h;
120
121        // Number of function evaluations
122        evals.function += self.stages - 1; // We already have k[0]
123
124        // Error estimation
125        let er = self.er.unwrap();
126        let n = self.y.len();
127        let mut err2 = T::zero();
128        let mut err_state = self.y.zeros_like();
129        for (j, coefficient) in er.iter().enumerate().take(self.stages) {
130            err_state.add_scaled(*coefficient, &self.k[j]);
131        }
132        let mut err = self
133            .y
134            .error_norm(&y_new, &err_state, &self.atol, &self.rtol);
135
136        if let Some(bh) = &self.bh {
137            let mut err2_state = yseg.clone();
138            for (j, coefficient) in bh.iter().enumerate().take(self.stages) {
139                err2_state.add_scaled(-*coefficient, &self.k[j]);
140            }
141            err2 = self
142                .y
143                .error_norm(&y_new, &err2_state, &self.atol, &self.rtol);
144        }
145        let mut deno = err + T::from_f64(0.01).unwrap() * err2;
146        if deno <= T::zero() {
147            deno = T::one();
148        }
149        err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
150
151        // Step size scale factor
152        let order = T::from_usize(self.order).unwrap();
153        let error_exponent = T::one() / order;
154        let mut scale = self.safety_factor * err.powf(-error_exponent);
155
156        // Clamp scale factor to prevent extreme step size changes
157        scale = scale.max(self.min_scale).min(self.max_scale);
158
159        // Determine if step is accepted
160        if err <= T::one() {
161            // Calculate the new derivative at the new point
162            ode.diff(t_new, &y_new, &mut self.dydt);
163            evals.function += 1;
164
165            // stiffness detection
166            let n_stiff_threshold = 100;
167            if self.steps.is_multiple_of(n_stiff_threshold) {
168                let stdnum = yseg.diff_norm_squared(&self.k[S - 1]);
169                let stden = self.dydt.diff_norm_squared(&ysti);
170
171                if stden > T::zero() {
172                    let h_lamb = self.h * (stdnum / stden).sqrt();
173                    if h_lamb > T::from_f64(6.1).unwrap() {
174                        self.non_stiffness_counter = 0;
175                        self.stiffness_counter += 1;
176                        if self.stiffness_counter == 15 {
177                            // Early Exit Stiffness Detected
178                            self.status = Status::Error(Error::Stiffness {
179                                t: self.t,
180                                y: self.y.clone(),
181                            });
182                            return Err(Error::Stiffness {
183                                t: self.t,
184                                y: self.y.clone(),
185                            });
186                        }
187                    }
188                } else {
189                    self.non_stiffness_counter += 1;
190                    if self.non_stiffness_counter == 6 {
191                        self.stiffness_counter = 0;
192                    }
193                }
194            }
195
196            // Preparation for dense output / interpolation
197            self.cont[0] = self.y.clone();
198            let ydiff = y_new.minus(&self.y);
199            self.cont[1] = ydiff.clone();
200            let mut bspl = ydiff.zeros_like();
201            bspl.add_scaled(self.h, &self.k[0]);
202            bspl.add_scaled(-T::one(), &ydiff);
203            self.cont[2] = bspl.clone();
204            let mut cont3 = ydiff;
205            cont3.add_scaled(-self.h, &self.dydt);
206            cont3.add_scaled(-T::one(), &bspl);
207            self.cont[3] = cont3;
208
209            // If method has dense output stages, compute them
210            if let Some(bi) = &self.bi {
211                // Compute extra stages for dense output
212                if I > S {
213                    // First dense output coefficient, k{i=order+1}, is the derivative at the new point
214                    self.k[self.stages] = self.dydt.clone();
215
216                    for i in S + 1..I {
217                        let mut y_stage = self.y.clone();
218                        for j in 0..i {
219                            y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
220                        }
221
222                        ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
223                        evals.function += 1;
224                    }
225                }
226
227                // Compute dense output coefficients
228                for i in 4..self.order {
229                    self.cont[i].fill(T::zero());
230                    for j in 0..self.dense_stages {
231                        self.cont[i].add_scaled(bi[i][j], &self.k[j]);
232                    }
233                    self.cont[i].scale_by(self.h);
234                }
235            }
236
237            // For interpolation
238            self.t_prev = self.t;
239            self.y_prev = self.y.clone();
240            self.dydt_prev = self.k[0].clone();
241            self.h_prev = self.h;
242
243            // Update the state with new values
244            self.t = t_new;
245            self.y = y_new;
246            self.k[0] = self.dydt.clone();
247
248            // Check if previous step is rejected
249            if let Status::RejectedStep = self.status {
250                self.status = Status::Solving;
251
252                // Limit step size growth to avoid oscillations between accepted and rejected steps
253                scale = scale.min(T::one());
254            }
255        } else {
256            // Step Rejected
257            self.status = Status::RejectedStep;
258        }
259
260        // Update step size
261        self.h *= scale;
262
263        // Ensure step size is within bounds
264        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
265
266        // Apply step size filter
267        self.h = (self.filter)(self.h);
268
269        Ok(evals)
270    }
271
272    fn t(&self) -> T {
273        self.t
274    }
275    fn y(&self) -> &Y {
276        &self.y
277    }
278    fn t_prev(&self) -> T {
279        self.t_prev
280    }
281    fn y_prev(&self) -> &Y {
282        &self.y_prev
283    }
284    fn h(&self) -> T {
285        self.h
286    }
287    fn set_h(&mut self, h: T) {
288        self.h = (self.filter)(h);
289    }
290    fn status(&self) -> &Status<T, Y> {
291        &self.status
292    }
293    fn set_status(&mut self, status: Status<T, Y>) {
294        self.status = status;
295    }
296}
297
298impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
299    for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
300{
301    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
302        // Check if interpolation is out of bounds
303        if t_interp < self.t_prev || t_interp > self.t {
304            return Err(Error::OutOfBounds {
305                t_interp,
306                t_prev: self.t_prev,
307                t_curr: self.t,
308            });
309        }
310
311        // Evaluate the interpolation polynomial at the requested time
312        let s = (t_interp - self.t_prev) / self.h_prev;
313        let s1 = T::one() - s;
314
315        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
316        let ilast = self.cont.len() - 1;
317        let poly = (1..ilast)
318            .rev()
319            .fold(self.cont[ilast].clone(), |mut acc, i| {
320                let factor = if i >= 4 {
321                    // For the higher-order part (conpar), alternate s and s1 based on index parity
322                    if (ilast - i) % 2 == 1 { s1 } else { s }
323                } else {
324                    // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
325                    if i % 2 == 1 { s1 } else { s }
326                };
327                acc.scale_by(factor);
328                acc.add_scaled(T::one(), &self.cont[i]);
329                acc
330            });
331
332        // Final multiplication by s for the outermost level
333        let y_interp = self.cont[0].plus_scaled(s, &poly);
334
335        Ok(y_interp)
336    }
337}