differential_equations/methods/apc/
apcv4.rs

1//! Adams-Predictor-Corrector 4th Order Variable Step Size Method
2
3use super::AdamsPredictorCorrector;
4use crate::{
5    Error, Status,
6    alias::Evals,
7    interpolate::{Interpolation, cubic_hermite_interpolate},
8    linalg::norm,
9    ode::{OrdinaryNumericalMethod, ODE},
10    traits::{CallBackData, Real, State},
11    utils::{constrain_step_size, validate_step_size_parameters},
12    methods::{Ordinary, Adaptive, h_init::InitialStepSize},
13};
14
15impl<T: Real, V: State<T>, D: CallBackData> AdamsPredictorCorrector<Ordinary, Adaptive, T, V, D, 4> {
16    ///// Adams-Predictor-Corrector 4th Order Variable Step Size Method.
17    ///
18    /// The Adams-Predictor-Corrector method is an explicit method that
19    /// uses the previous states to predict the next state. This implementation
20    /// uses a variable step size to maintain a desired accuracy.
21    /// It is recommended to start with a small step size so that tolerance
22    /// can be quickly met and the algorithm can adjust the step size accordingly.
23    ///
24    /// The First 3 steps are calculated using
25    /// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
26    /// method is used to calculate the remaining steps until the final time./ Create a Adams-Predictor-Corrector 4th Order Variable Step Size Method instance.
27    /// 
28    /// # Example
29    ///
30    /// ```
31    /// use differential_equations::prelude::*;
32    /// use nalgebra::{SVector, vector};
33    ///
34    /// struct HarmonicOscillator {
35    ///     k: f64,
36    /// }
37    ///
38    /// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
39    ///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
40    ///         dydt[0] = y[1];
41    ///         dydt[1] = -self.k * y[0];
42    ///     }
43    /// }
44    /// let mut apcv4 = AdamsPredictorCorrector::v4();
45    /// let t0 = 0.0;
46    /// let tf = 10.0;
47    /// let y0 = vector![1.0, 0.0];
48    /// let system = HarmonicOscillator { k: 1.0 };
49    /// let results = ODEProblem::new(system, t0, tf, y0).solve(&mut apcv4).unwrap();
50    /// let expected = vector![-0.83907153, 0.54402111];
51    /// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-6);
52    /// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-6);
53    /// ```
54    /// 
55    ///
56    /// ## Warning
57    ///
58    /// This method is not suitable for stiff problems and can results in
59    /// extremely small step sizes and long computation times.```
60    pub fn v4() -> Self {
61        Self::default()
62    }
63}
64
65// Implement OrdinaryNumericalMethod Trait for APCV4
66impl<T: Real, V: State<T>, D: CallBackData> OrdinaryNumericalMethod<T, V, D> for AdamsPredictorCorrector<Ordinary, Adaptive, T, V, D, 4> {
67    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
68    where
69        F: ODE<T, V, D>,
70    {
71        let mut evals = Evals::new();
72
73        self.tf = tf;
74        
75        // If h0 is zero, calculate initial step size
76        if self.h0 == T::zero() {
77            // Only use adaptive step size calculation if the method supports it
78            self.h0 = InitialStepSize::<Ordinary>::compute(ode, t0, tf, y0, 4, self.tol, self.tol, self.h_min, self.h_max, &mut evals);
79            evals.fcn += 2;
80
81        }
82
83        // Check that the initial step size is set
84        match validate_step_size_parameters::<T, V, D>(self.h0, T::zero(), T::infinity(), t0, tf) {
85            Ok(h0) => self.h = h0,
86            Err(status) => return Err(status),
87        }
88
89        // Initialize state
90        self.t = t0;
91        self.y = *y0;
92        self.t_prev[0] = t0;
93        self.y_prev[0] = *y0;
94
95        // Previous saved steps
96        self.t_old = t0;
97        self.y_old = *y0;
98
99        // Perform the first 3 steps using Runge-Kutta 4 method
100        let two = T::from_f64(2.0).unwrap();
101        let six = T::from_f64(6.0).unwrap();
102        for i in 1..=3 {
103            // Compute k1, k2, k3, k4 of Runge-Kutta 4
104            ode.diff(self.t, &self.y, &mut self.k[0]);
105            ode.diff(
106                self.t + self.h / two,
107                &(self.y + self.k[0] * (self.h / two)),
108                &mut self.k[1],
109            );
110            ode.diff(
111                self.t + self.h / two,
112                &(self.y + self.k[1] * (self.h / two)),
113                &mut self.k[2],
114            );
115            ode.diff(self.t + self.h, &(self.y + self.k[2] * self.h), &mut self.k[3]);
116
117            // Update State
118            self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
119            self.t += self.h;
120            self.t_prev[i] = self.t;
121            self.y_prev[i] = self.y;
122            evals.fcn += 4; // 4 evaluations per Runge-Kutta step
123
124            if i == 1 {
125                self.dydt = self.k[0];
126                self.dydt_old = self.k[0];
127            }
128        }
129
130        self.status = Status::Initialized;
131        Ok(evals)
132    }
133
134    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
135    where
136        F: ODE<T, V, D>,
137    {
138        let mut evals = Evals::new();
139
140        // Check if Max Steps Reached
141        if self.steps >= self.max_steps {
142            self.status = Status::Error(Error::MaxSteps {
143                t: self.t,
144                y: self.y,
145            });
146            return Err(Error::MaxSteps {
147                t: self.t,
148                y: self.y,
149            });
150        }
151        self.steps += 1;
152
153        // If Step size changed and it takes us to the final time perform a Runge-Kutta 4 step to finish
154        if self.h != self.t_prev[0] - self.t_prev[1] && self.t + self.h == self.tf {
155            let two = T::from_f64(2.0).unwrap();
156            let six = T::from_f64(6.0).unwrap();
157
158            // Perform a Runge-Kutta 4 step to finish.
159            ode.diff(self.t, &self.y, &mut self.k[0]);
160            ode.diff(
161                self.t + self.h / two,
162                &(self.y + self.k[0] * (self.h / two)),
163                &mut self.k[1],
164            );
165            ode.diff(
166                self.t + self.h / two,
167                &(self.y + self.k[1] * (self.h / two)),
168                &mut self.k[2],
169            );
170            ode.diff(self.t + self.h, &(self.y + self.k[2] * self.h), &mut self.k[3]);
171            evals.fcn += 4; // 4 evaluations per Runge-Kutta step
172
173            // Update State
174            self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
175            self.t += self.h;
176            return Ok(evals);
177        }
178
179        // Compute derivatives for history
180        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
181        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
182        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
183        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);
184
185        let predictor = self.y_prev[3]
186            + (self.k[0] * T::from_f64(55.0).unwrap() - self.k[1] * T::from_f64(59.0).unwrap()
187                + self.k[2] * T::from_f64(37.0).unwrap()
188                - self.k[3] * T::from_f64(9.0).unwrap())
189                * self.h
190                / T::from_f64(24.0).unwrap();
191
192        // Corrector step:
193        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);
194        let corrector = self.y_prev[3]
195            + (self.k[3] * T::from_f64(9.0).unwrap() + self.k[0] * T::from_f64(19.0).unwrap()
196                - self.k[1] * T::from_f64(5.0).unwrap()
197                + self.k[2] * T::from_f64(1.0).unwrap())
198                * self.h
199                / T::from_f64(24.0).unwrap();
200
201        // Track number of evaluations
202        evals.fcn += 5;
203
204        // Calculate sigma for step size adjustment
205        let sigma = T::from_f64(19.0).unwrap() * norm(corrector - predictor)
206            / (T::from_f64(270.0).unwrap() * self.h.abs());
207
208        // Check if Step meets tolerance
209        if sigma <= self.tol {
210            // Update Previous step states
211            self.t_old = self.t;
212            self.y_old = self.y;
213            self.dydt_old = self.dydt;
214
215            // Update state
216            self.t += self.h;
217            self.y = corrector;
218
219            // Check if previous step rejected
220            if let Status::RejectedStep = self.status {
221                self.status = Status::Solving;
222            }
223
224            // Adjust Step Size if needed
225            let two = T::from_f64(2.0).unwrap();
226            let four = T::from_f64(4.0).unwrap();
227            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
228            self.h = if q > four { four * self.h } else { q * self.h };
229
230            // Bound Step Size
231            let tf_t_abs = (self.tf - self.t).abs();
232            let four_div = tf_t_abs / four;
233            let h_max_effective = if self.h_max < four_div {
234                self.h_max
235            } else {
236                four_div
237            };
238
239            self.h = constrain_step_size(self.h, self.h_min, h_max_effective);
240
241            // Calculate Previous Steps with new step size
242            self.t_prev[0] = self.t;
243            self.y_prev[0] = self.y;
244            let two = T::from_f64(2.0).unwrap();
245            let six = T::from_f64(6.0).unwrap();
246            for i in 1..=3 {
247                // Compute k1, k2, k3, k4 of Runge-Kutta 4
248                ode.diff(self.t, &self.y, &mut self.k[0]);
249                ode.diff(
250                    self.t + self.h / two,
251                    &(self.y + self.k[0] * (self.h / two)),
252                    &mut self.k[1],
253                );
254                ode.diff(
255                    self.t + self.h / two,
256                    &(self.y + self.k[1] * (self.h / two)),
257                    &mut self.k[2],
258                );
259                ode.diff(self.t + self.h, &(self.y + self.k[2] * self.h), &mut self.k[3]);
260
261                // Update State
262                self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
263                self.t += self.h;
264                self.t_prev[i] = self.t;
265                self.y_prev[i] = self.y;
266                self.evals += 4; // 4 evaluations per Runge-Kutta step
267
268                if i == 1 {
269                    self.dydt = self.k[0];
270                }
271            }
272        } else {
273            // Step Rejected
274            self.status = Status::RejectedStep;
275
276            // Adjust Step Size
277            let two = T::from_f64(2.0).unwrap();
278            let tenth = T::from_f64(0.1).unwrap();
279            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
280            self.h = if q < tenth {
281                tenth * self.h
282            } else {
283                q * self.h
284            };
285
286            // Calculate Previous Steps with new step size
287            self.t_prev[0] = self.t;
288            self.y_prev[0] = self.y;
289            let two = T::from_f64(2.0).unwrap();
290            let six = T::from_f64(6.0).unwrap();
291            for i in 1..=3 {
292                // Compute k1, k2, k3, k4 of Runge-Kutta 4
293                ode.diff(self.t, &self.y, &mut self.k[0]);
294                ode.diff(
295                    self.t + self.h / two,
296                    &(self.y + self.k[0] * (self.h / two)),
297                    &mut self.k[1],
298                );
299                ode.diff(
300                    self.t + self.h / two,
301                    &(self.y + self.k[1] * (self.h / two)),
302                    &mut self.k[2],
303                );
304                ode.diff(self.t + self.h, &(self.y + self.k[2] * self.h), &mut self.k[3]);
305
306                // Update State
307                self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
308                self.t += self.h;
309                self.t_prev[i] = self.t;
310                self.y_prev[i] = self.y;
311                self.evals += 4; // 4 evaluations per Runge-Kutta step
312            }
313        }
314        Ok(evals)
315    }
316
317    fn t(&self) -> T {
318        self.t
319    }
320
321    fn y(&self) -> &V {
322        &self.y
323    }
324
325    fn t_prev(&self) -> T {
326        self.t_old
327    }
328
329    fn y_prev(&self) -> &V {
330        &self.y_old
331    }
332
333    fn h(&self) -> T {
334        // OrdinaryNumericalMethod repeats step size 4 times for each step
335        // so the ODEProblem inquiring is looking for what the next
336        // state will be thus the step size is multiplied by 4
337        self.h * T::from_f64(4.0).unwrap()
338    }
339
340    fn set_h(&mut self, h: T) {
341        self.h = h;
342    }
343
344    fn status(&self) -> &Status<T, V, D> {
345        &self.status
346    }
347
348    fn set_status(&mut self, status: Status<T, V, D>) {
349        self.status = status;
350    }
351}
352
353// Implement the Interpolation trait for APCV4
354impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for AdamsPredictorCorrector<Ordinary, Adaptive, T, V, D, 4> {
355    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
356        // Check if t is within the range of the solver
357        if t_interp < self.t_old || t_interp > self.t {
358            return Err(Error::OutOfBounds {
359                t_interp,
360                t_prev: self.t_old,
361                t_curr: self.t,
362            });
363        }
364
365        // Calculate the interpolated value using cubic Hermite interpolation
366        let y_interp = cubic_hermite_interpolate(
367            self.t_old,
368            self.t,
369            &self.y_old,
370            &self.y,
371            &self.dydt_old,
372            &self.dydt,
373            t_interp,
374        );
375
376        Ok(y_interp)
377    }
378}