differential_equations/methods/erk/dormandprince/ordinary.rs

1//! Dormand-Prince Runge-Kutta methods for ODEs
2
3use super::{ExplicitRungeKutta, Ordinary, DormandPrince};
4use crate::{
5    Error, Status,
6    alias::Evals,
7    methods::h_init::InitialStepSize,
8    interpolate::Interpolation,
9    ode::{OrdinaryNumericalMethod, ODE},
10    traits::{CallBackData, Real, State},
11    utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ExplicitRungeKutta<Ordinary, DormandPrince, T, V, D, O, S, I> {
15    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
16    where
17        F: ODE<T, V, D>,
18    {
19        let mut evals = Evals::new();
20
21        // If h0 is zero, calculate initial step size
22        if self.h0 == T::zero() {
23            // Use adaptive step size calculation for Dormand-Prince methods
24            self.h0 = InitialStepSize::<Ordinary>::compute(ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, &mut evals);
25        }
26
27        // Check bounds
28        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
29            Ok(h0) => self.h = h0,
30            Err(status) => return Err(status),
31        }
32
33        // Initialize Statistics
34        self.stiffness_counter = 0;
35
36        // Initialize State
37        self.t = t0;
38        self.y = *y0;
39        ode.diff(self.t, &self.y, &mut self.k[0]);
40        self.dydt = self.k[0];
41        evals.fcn += 1;
42
43        // Initialize previous state
44        self.t_prev = self.t;
45        self.y_prev = self.y;
46        self.dydt_prev = self.dydt;
47
48        // Initialize Status
49        self.status = Status::Initialized;
50
51        Ok(evals)
52    }
53
54    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
55    where
56        F: ODE<T, V, D>,
57    {
58        let mut evals = Evals::new();
59
60        // Check if step-size is becoming too small
61        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
62            self.status = Status::Error(Error::StepSize {
63                t: self.t, y: self.y
64            });
65            return Err(Error::StepSize {
66                t: self.t, y: self.y
67            });
68        }
69
70        // Check max steps
71        if self.steps >= self.max_steps {
72            self.status = Status::Error(Error::MaxSteps {
73                t: self.t, y: self.y
74            });
75            return Err(Error::MaxSteps {
76                t: self.t, y: self.y
77            });
78        }
79        self.steps += 1;
80
81        // Compute stages
82        let mut y_stage = V::zeros();
83        for i in 1..self.stages {
84            y_stage = V::zeros();
85
86            for j in 0..i {
87                y_stage += self.k[j] * self.a[i][j];
88            }
89            y_stage = self.y + y_stage * self.h;
90
91            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);        
92        }
93
94        // The last stage will be used for stiffness detection
95        let ysti = y_stage;
96
97        // Calculate the line segment for the new y value
98        let mut yseg = V::zeros();
99        for i in 0..self.stages {
100            yseg += self.k[i] * self.b[i];
101        }
102
103        // Calculate the new y value using the line segment
104        let y_new = self.y + yseg * self.h;
105
106        // Evaluate derivative at new point for error estimation
107        let t_new = self.t + self.h;
108
109        // Number of function evaluations
110        evals.fcn += self.stages - 1; // We already have k[0]
111
112        // Error estimation
113        let er = self.er.unwrap();
114        let n = self.y.len();
115        let mut err = T::zero();
116        let mut err2 = T::zero();
117        let mut erri;
118        for i in 0..n {
119            // Calculate the error scale
120            let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
121
122            // Primary error term
123            erri = T::zero();
124            for j in 0..self.stages {
125                erri += er[j] * self.k[j].get(i);
126            }
127            err += (erri / sk).powi(2);
128
129            // Optional secondary error term
130            if let Some(bh) = &self.bh {
131                erri = yseg.get(i);
132                for j in 0..self.stages {
133                    erri -= bh[j] * self.k[j].get(i);
134                }
135                err2 += (erri / sk).powi(2);
136            }
137        }
138        let mut deno = err + T::from_f64(0.01).unwrap() * err2;
139        if deno <= T::zero() {
140            deno = T::one();
141        }
142        err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
143
144        // Step size scale factor
145        let order = T::from_usize(self.order).unwrap();
146        let error_exponent = T::one() / order;
147        let mut scale = self.safety_factor * err.powf(-error_exponent);
148        
149        // Clamp scale factor to prevent extreme step size changes
150        scale = scale.max(self.min_scale).min(self.max_scale);
151
152        // Determine if step is accepted
153        if err <= T::one() {
154            // Calculate the new derivative at the new point
155            ode.diff(t_new, &y_new, &mut self.dydt);
156            evals.fcn += 1;
157
158            // stiffness detection
159            let n_stiff_threshold = 100;
160            if self.steps % n_stiff_threshold == 0 {
161                let mut stdnum = T::zero();
162                let mut stden = T::zero();
163                let sqr = yseg - self.k[S-1];
164                for i in 0..sqr.len() {
165                    stdnum += sqr.get(i).powi(2);
166                }
167                let sqr = self.dydt - ysti;
168                for i in 0..sqr.len() {
169                    stden += sqr.get(i).powi(2);
170                }
171
172                if stden > T::zero() {
173                    let h_lamb = self.h * (stdnum / stden).sqrt();
174                    if h_lamb > T::from_f64(6.1).unwrap() {
175                        self.non_stiffness_counter = 0;
176                        self.stiffness_counter += 1;
177                        if self.stiffness_counter == 15 {
178                            // Early Exit Stiffness Detected
179                            self.status = Status::Error(Error::Stiffness {
180                                t: self.t,
181                                y: self.y,
182                            });
183                            return Err(Error::Stiffness {
184                                t: self.t,
185                                y: self.y,
186                            });
187                        }
188                    }
189                } else {
190                    self.non_stiffness_counter += 1;
191                    if self.non_stiffness_counter == 6 {
192                        self.stiffness_counter = 0;
193                    }
194                }
195            }
196
197            // Preparation for dense output / interpolation
198            self.cont[0] = self.y;
199            let ydiff = y_new - self.y;
200            self.cont[1] = ydiff;
201            let bspl = self.k[0] * self.h - ydiff;
202            self.cont[2] = bspl;
203            self.cont[3] = ydiff - self.dydt * self.h - bspl;
204
205            // If method has dense output stages, compute them
206            if let Some(bi) = &self.bi {
207                // Compute extra stages for dense output
208                if I > S {
209                    // First dense output coefficient, k{i=order+1}, is the derivative at the new point
210                    self.k[self.stages] = self.dydt;
211
212                    for i in S+1..I {
213                        let mut y_stage = V::zeros();
214                        for j in 0..i {
215                            y_stage += self.k[j] * self.a[i][j];
216                        }
217                        y_stage = self.y + y_stage * self.h;
218
219                        ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
220                        evals.fcn += 1;
221                    }
222                }
223
224                // Compute dense output coefficients
225                for i in 4..self.order {
226                    self.cont[i] = V::zeros();
227                    for j in 0..self.dense_stages {
228                        self.cont[i] += self.k[j] * bi[i][j];
229                    }
230                    self.cont[i] = self.cont[i] * self.h;
231                }
232            }
233
234            // For interpolation
235            self.t_prev = self.t;
236            self.y_prev = self.y;
237            self.dydt_prev = self.k[0];
238            self.h_prev = self.h;
239
240            // Update the state with new values
241            self.t = t_new;
242            self.y = y_new;
243            self.k[0] = self.dydt;
244
245            // Check if previous step is rejected
246            if let Status::RejectedStep = self.status {
247                self.status = Status::Solving;
248
249                // Limit step size growth to avoid oscillations between accepted and rejected steps
250                scale = scale.min(T::one());
251            }
252        } else {
253            // Step Rejected
254            self.status = Status::RejectedStep;
255        }
256
257        // Update step size
258        self.h *= scale;
259
260        // Ensure step size is within bounds
261        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
262        
263        Ok(evals)
264    }
265
266    fn t(&self) -> T { self.t }
267    fn y(&self) -> &V { &self.y }
268    fn t_prev(&self) -> T { self.t_prev }
269    fn y_prev(&self) -> &V { &self.y_prev }
270    fn h(&self) -> T { self.h }
271    fn set_h(&mut self, h: T) { self.h = h; }
272    fn status(&self) -> &Status<T, V, D> { &self.status }
273    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
274}
275
276impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Ordinary, DormandPrince, T, V, D, O, S, I> {
277    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {        
278        // Check if interpolation is out of bounds
279        if t_interp < self.t_prev || t_interp > self.t {
280            return Err(Error::OutOfBounds {
281                t_interp,
282                t_prev: self.t_prev,
283                t_curr: self.t,
284            });
285        }        
286        
287        // Evaluate the interpolation polynomial at the requested time
288        let s = (t_interp - self.t_prev) / self.h_prev;
289        let s1 = T::one() - s;        
290        
291        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
292        let ilast = self.cont.len() - 1;
293        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {            
294            let factor = if i >= 4 {
295                // For the higher-order part (conpar), alternate s and s1 based on index parity
296                if (ilast - i) % 2 == 1 { s1 } else { s }
297            } else {
298                // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
299                if i % 2 == 1 { s1 } else { s }
300            };
301            acc * factor + self.cont[i]
302        });
303        
304        // Final multiplication by s for the outermost level
305        let y_interp = self.cont[0] + poly * s;
306
307        Ok(y_interp)
308    }
309}