differential_equations/methods/apc/
apcv4.rs

//! Adams-Predictor-Corrector 4th Order Variable Step Size Method
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    linalg::norm,
7    methods::{Adaptive, Ordinary, h_init::InitialStepSize},
8    ode::{ODE, OrdinaryNumericalMethod},
9    stats::Evals,
10    status::Status,
11    tolerance::Tolerance,
12    traits::{Real, State},
13    utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16use super::AdamsPredictorCorrector;
17
18impl<T: Real, Y: State<T>> AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4> {
19    ///// Adams-Predictor-Corrector 4th Order Variable Step Size Method.
20    ///
21    /// The Adams-Predictor-Corrector method is an explicit method that
22    /// uses the previous states to predict the next state. This implementation
23    /// uses a variable step size to maintain a desired accuracy.
24    /// It is recommended to start with a small step size so that tolerance
25    /// can be quickly met and the algorithm can adjust the step size accordingly.
26    ///
27    /// The First 3 steps are calculated using
28    /// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
29    /// method is used to calculate the remaining steps until the final time./ Create a Adams-Predictor-Corrector 4th Order Variable Step Size Method instance.
30    ///
31    /// # Example
32    ///
33    /// ```
34    /// use differential_equations::prelude::*;
35    /// use nalgebra::{SVector, vector};
36    ///
37    /// struct HarmonicOscillator {
38    ///     k: f64,
39    /// }
40    ///
41    /// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
42    ///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
43    ///         dydt[0] = y[1];
44    ///         dydt[1] = -self.k * y[0];
45    ///     }
46    /// }
47    /// let mut apcv4 = AdamsPredictorCorrector::v4();
48    /// let t0 = 0.0;
49    /// let tf = 10.0;
50    /// let y0 = vector![1.0, 0.0];
51    /// let system = HarmonicOscillator { k: 1.0 };
52    /// let results = ODEProblem::new(system, t0, tf, y0).solve(&mut apcv4).unwrap();
53    /// let expected = vector![-0.83907153, 0.54402111];
54    /// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-6);
55    /// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-6);
56    /// ```
57    ///
58    ///
59    /// ## Warning
60    ///
61    /// This method is not suitable for stiff problems and can results in
62    /// extremely small step sizes and long computation times.```
63    pub fn v4() -> Self {
64        Self::default()
65    }
66}
67
68// Implement OrdinaryNumericalMethod Trait for APCV4
69impl<T: Real, Y: State<T>> OrdinaryNumericalMethod<T, Y>
70    for AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4>
71{
72    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
73    where
74        F: ODE<T, Y>,
75    {
76        let mut evals = Evals::new();
77
78        self.tf = tf;
79
80        // If h0 is zero, calculate initial step size
81        if self.h0 == T::zero() {
82            // Only use adaptive step size calculation if the method supports it
83            let tol = Tolerance::Scalar(self.tol);
84            self.h0 = InitialStepSize::<Ordinary>::compute(
85                ode, t0, tf, y0, 4, &tol, &tol, self.h_min, self.h_max, &mut evals,
86            );
87            evals.function += 2;
88        }
89
90        // Check that the initial step size is set
91        match validate_step_size_parameters::<T, Y>(self.h0, T::zero(), T::infinity(), t0, tf) {
92            Ok(h0) => self.h = h0,
93            Err(status) => return Err(status),
94        }
95
96        // Initialize state
97        self.t = t0;
98        self.y = *y0;
99        self.t_prev[0] = t0;
100        self.y_prev[0] = *y0;
101
102        // Previous saved steps
103        self.t_old = t0;
104        self.y_old = *y0;
105
106        // Perform the first 3 steps using Runge-Kutta 4 method
107        let two = T::from_f64(2.0).unwrap();
108        let six = T::from_f64(6.0).unwrap();
109        for i in 1..=3 {
110            // Compute k1, k2, k3, k4 of Runge-Kutta 4
111            ode.diff(self.t, &self.y, &mut self.k[0]);
112            ode.diff(
113                self.t + self.h / two,
114                &(self.y + self.k[0] * (self.h / two)),
115                &mut self.k[1],
116            );
117            ode.diff(
118                self.t + self.h / two,
119                &(self.y + self.k[1] * (self.h / two)),
120                &mut self.k[2],
121            );
122            ode.diff(
123                self.t + self.h,
124                &(self.y + self.k[2] * self.h),
125                &mut self.k[3],
126            );
127
128            // Update State
129            self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
130            self.t += self.h;
131            self.t_prev[i] = self.t;
132            self.y_prev[i] = self.y;
133            evals.function += 4; // 4 evaluations per Runge-Kutta step
134
135            if i == 1 {
136                self.dydt = self.k[0];
137                self.dydt_old = self.k[0];
138            }
139        }
140
141        self.status = Status::Initialized;
142        Ok(evals)
143    }
144
145    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
146    where
147        F: ODE<T, Y>,
148    {
149        let mut evals = Evals::new();
150
151        // Check if Max Steps Reached
152        if self.steps >= self.max_steps {
153            self.status = Status::Error(Error::MaxSteps {
154                t: self.t,
155                y: self.y,
156            });
157            return Err(Error::MaxSteps {
158                t: self.t,
159                y: self.y,
160            });
161        }
162        self.steps += 1;
163
164        // If Step size changed and it takes us to the final time perform a Runge-Kutta 4 step to finish
165        if self.h != self.t_prev[0] - self.t_prev[1] && self.t + self.h == self.tf {
166            let two = T::from_f64(2.0).unwrap();
167            let six = T::from_f64(6.0).unwrap();
168
169            // Perform a Runge-Kutta 4 step to finish.
170            ode.diff(self.t, &self.y, &mut self.k[0]);
171            ode.diff(
172                self.t + self.h / two,
173                &(self.y + self.k[0] * (self.h / two)),
174                &mut self.k[1],
175            );
176            ode.diff(
177                self.t + self.h / two,
178                &(self.y + self.k[1] * (self.h / two)),
179                &mut self.k[2],
180            );
181            ode.diff(
182                self.t + self.h,
183                &(self.y + self.k[2] * self.h),
184                &mut self.k[3],
185            );
186            evals.function += 4; // 4 evaluations per Runge-Kutta step
187
188            // Update State
189            self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
190            self.t += self.h;
191            return Ok(evals);
192        }
193
194        // Compute derivatives for history
195        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
196        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
197        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
198        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);
199
200        let predictor = self.y_prev[3]
201            + (self.k[0] * T::from_f64(55.0).unwrap() - self.k[1] * T::from_f64(59.0).unwrap()
202                + self.k[2] * T::from_f64(37.0).unwrap()
203                - self.k[3] * T::from_f64(9.0).unwrap())
204                * self.h
205                / T::from_f64(24.0).unwrap();
206
207        // Corrector step:
208        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);
209        let corrector = self.y_prev[3]
210            + (self.k[3] * T::from_f64(9.0).unwrap() + self.k[0] * T::from_f64(19.0).unwrap()
211                - self.k[1] * T::from_f64(5.0).unwrap()
212                + self.k[2] * T::from_f64(1.0).unwrap())
213                * self.h
214                / T::from_f64(24.0).unwrap();
215
216        // Track number of evaluations
217        evals.function += 5;
218
219        // Calculate sigma for step size adjustment
220        let sigma = T::from_f64(19.0).unwrap() * norm(corrector - predictor)
221            / (T::from_f64(270.0).unwrap() * self.h.abs());
222
223        // Check if Step meets tolerance
224        if sigma <= self.tol {
225            // Update Previous step states
226            self.t_old = self.t;
227            self.y_old = self.y;
228            self.dydt_old = self.dydt;
229
230            // Update state
231            self.t += self.h;
232            self.y = corrector;
233
234            // Check if previous step rejected
235            if let Status::RejectedStep = self.status {
236                self.status = Status::Solving;
237            }
238
239            // Adjust Step Size if needed
240            let two = T::from_f64(2.0).unwrap();
241            let four = T::from_f64(4.0).unwrap();
242            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
243            self.h = if q > four { four * self.h } else { q * self.h };
244
245            // Bound Step Size
246            let tf_t_abs = (self.tf - self.t).abs();
247            let four_div = tf_t_abs / four;
248            let h_max_effective = if self.h_max < four_div {
249                self.h_max
250            } else {
251                four_div
252            };
253
254            self.h = constrain_step_size(self.h, self.h_min, h_max_effective);
255
256            // Calculate Previous Steps with new step size
257            self.t_prev[0] = self.t;
258            self.y_prev[0] = self.y;
259            let two = T::from_f64(2.0).unwrap();
260            let six = T::from_f64(6.0).unwrap();
261            for i in 1..=3 {
262                // Compute k1, k2, k3, k4 of Runge-Kutta 4
263                ode.diff(self.t, &self.y, &mut self.k[0]);
264                ode.diff(
265                    self.t + self.h / two,
266                    &(self.y + self.k[0] * (self.h / two)),
267                    &mut self.k[1],
268                );
269                ode.diff(
270                    self.t + self.h / two,
271                    &(self.y + self.k[1] * (self.h / two)),
272                    &mut self.k[2],
273                );
274                ode.diff(
275                    self.t + self.h,
276                    &(self.y + self.k[2] * self.h),
277                    &mut self.k[3],
278                );
279
280                // Update State
281                self.y +=
282                    (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
283                self.t += self.h;
284                self.t_prev[i] = self.t;
285                self.y_prev[i] = self.y;
286                self.evals += 4; // 4 evaluations per Runge-Kutta step
287
288                if i == 1 {
289                    self.dydt = self.k[0];
290                }
291            }
292        } else {
293            // Step Rejected
294            self.status = Status::RejectedStep;
295
296            // Adjust Step Size
297            let two = T::from_f64(2.0).unwrap();
298            let tenth = T::from_f64(0.1).unwrap();
299            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
300            self.h = if q < tenth {
301                tenth * self.h
302            } else {
303                q * self.h
304            };
305
306            // Calculate Previous Steps with new step size
307            self.t_prev[0] = self.t;
308            self.y_prev[0] = self.y;
309            let two = T::from_f64(2.0).unwrap();
310            let six = T::from_f64(6.0).unwrap();
311            for i in 1..=3 {
312                // Compute k1, k2, k3, k4 of Runge-Kutta 4
313                ode.diff(self.t, &self.y, &mut self.k[0]);
314                ode.diff(
315                    self.t + self.h / two,
316                    &(self.y + self.k[0] * (self.h / two)),
317                    &mut self.k[1],
318                );
319                ode.diff(
320                    self.t + self.h / two,
321                    &(self.y + self.k[1] * (self.h / two)),
322                    &mut self.k[2],
323                );
324                ode.diff(
325                    self.t + self.h,
326                    &(self.y + self.k[2] * self.h),
327                    &mut self.k[3],
328                );
329
330                // Update State
331                self.y +=
332                    (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
333                self.t += self.h;
334                self.t_prev[i] = self.t;
335                self.y_prev[i] = self.y;
336                self.evals += 4; // 4 evaluations per Runge-Kutta step
337            }
338        }
339        Ok(evals)
340    }
341
342    fn t(&self) -> T {
343        self.t
344    }
345
346    fn y(&self) -> &Y {
347        &self.y
348    }
349
350    fn t_prev(&self) -> T {
351        self.t_old
352    }
353
354    fn y_prev(&self) -> &Y {
355        &self.y_old
356    }
357
358    fn h(&self) -> T {
359        // OrdinaryNumericalMethod repeats step size 4 times for each step
360        // so the ODEProblem inquiring is looking for what the next
361        // state will be thus the step size is multiplied by 4
362        self.h * T::from_f64(4.0).unwrap()
363    }
364
365    fn set_h(&mut self, h: T) {
366        self.h = h;
367    }
368
369    fn status(&self) -> &Status<T, Y> {
370        &self.status
371    }
372
373    fn set_status(&mut self, status: Status<T, Y>) {
374        self.status = status;
375    }
376}
377
378// Implement the Interpolation trait for APCV4
379impl<T: Real, Y: State<T>> Interpolation<T, Y>
380    for AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4>
381{
382    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
383        // Check if t is within the range of the solver
384        if t_interp < self.t_old || t_interp > self.t {
385            return Err(Error::OutOfBounds {
386                t_interp,
387                t_prev: self.t_old,
388                t_curr: self.t,
389            });
390        }
391
392        // Calculate the interpolated value using cubic Hermite interpolation
393        let y_interp = cubic_hermite_interpolate(
394            self.t_old,
395            self.t,
396            &self.y_old,
397            &self.y,
398            &self.dydt_old,
399            &self.dydt,
400            t_interp,
401        );
402
403        Ok(y_interp)
404    }
405}