Skip to main content

differential_equations/methods/apc/
apcv4.rs

1//! Adams-Predictor-Corrector 4th Order Variable Step Size Method
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    methods::{Adaptive, Ordinary, h_init::InitialStepSize},
7    ode::{ODE, OrdinaryNumericalMethod},
8    stats::Evals,
9    status::Status,
10    tolerance::Tolerance,
11    traits::{Real, State},
12    utils::{constrain_step_size, validate_step_size_parameters},
13};
14
15use super::AdamsPredictorCorrector;
16
17impl<T: Real, Y: State<T>> AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4> {
18    ///// Adams-Predictor-Corrector 4th Order Variable Step Size Method.
19    ///
20    /// The Adams-Predictor-Corrector method is an explicit method that
21    /// uses the previous states to predict the next state. This implementation
22    /// uses a variable step size to maintain a desired accuracy.
23    /// It is recommended to start with a small step size so that tolerance
24    /// can be quickly met and the algorithm can adjust the step size accordingly.
25    ///
26    /// The First 3 steps are calculated using
27    /// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
28    /// method is used to calculate the remaining steps until the final time./ Create a Adams-Predictor-Corrector 4th Order Variable Step Size Method instance.
29    ///
30    /// # Example
31    ///
32    /// ```rust
33    /// use differential_equations::prelude::*;
34    ///
35    /// struct HarmonicOscillator {
36    ///     k: f64,
37    /// }
38    ///
39    /// impl ODE<f64, [f64; 2]> for HarmonicOscillator {
40    ///     fn diff(&self, _t: f64, y: &[f64; 2], dydt: &mut [f64; 2]) {
41    ///         dydt[0] = y[1];
42    ///         dydt[1] = -self.k * y[0];
43    ///     }
44    /// }
45    /// let apcv4 = AdamsPredictorCorrector::v4();
46    /// let t0 = 0.0;
47    /// let tf = 10.0;
48    /// let y0 = [1.0, 0.0];
49    /// let system = HarmonicOscillator { k: 1.0 };
50    /// let results = IVP::ode(&system, t0, tf, y0).method(apcv4).solve().unwrap();
51    /// let expected = [-0.83907153, 0.54402111];
52    /// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-6);
53    /// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-6);
54    /// ```
55    ///
56    ///
57    /// ## Warning
58    ///
59    /// This method is not suitable for stiff problems and can results in
60    /// extremely small step sizes and long computation times.```
61    pub fn v4() -> Self {
62        Self::default()
63    }
64
65    fn rk4_step<F>(ode: &F, t: &mut T, y: &mut Y, h: T, k: &mut [Y; 4]) -> usize
66    where
67        F: ODE<T, Y> + ?Sized,
68    {
69        let two = T::from_f64(2.0).unwrap();
70        let six = T::from_f64(6.0).unwrap();
71
72        ode.diff(*t, y, &mut k[0]);
73        ode.diff(*t + h / two, &y.plus_scaled(h / two, &k[0]), &mut k[1]);
74        ode.diff(*t + h / two, &y.plus_scaled(h / two, &k[1]), &mut k[2]);
75        ode.diff(*t + h, &y.plus_scaled(h, &k[2]), &mut k[3]);
76
77        y.add_scaled(h / six, &k[0]);
78        y.add_scaled(two * h / six, &k[1]);
79        y.add_scaled(two * h / six, &k[2]);
80        y.add_scaled(h / six, &k[3]);
81        *t += h;
82        4
83    }
84}
85
86// Implement OrdinaryNumericalMethod Trait for APCV4
87impl<T: Real, Y: State<T>> OrdinaryNumericalMethod<T, Y>
88    for AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4>
89{
90    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
91    where
92        F: ODE<T, Y> + ?Sized,
93    {
94        let mut evals = Evals::new();
95
96        self.tf = tf;
97
98        // If h0 is zero, calculate initial step size
99        if self.h0 == T::zero() {
100            // Only use adaptive step size calculation if the method supports it
101            let tol = Tolerance::Scalar(self.tol);
102            self.h0 = InitialStepSize::<Ordinary>::compute(
103                ode, t0, tf, y0, 4, &tol, &tol, self.h_min, self.h_max, &mut evals,
104            );
105            evals.function += 2;
106        }
107
108        // Check that the initial step size is set
109        match validate_step_size_parameters::<T, Y>(self.h0, T::zero(), T::infinity(), t0, tf) {
110            Ok(h0) => self.h = (self.filter)(h0),
111            Err(status) => return Err(status),
112        }
113
114        // Initialize state
115        self.t = t0;
116        self.y = y0.clone();
117        self.dydt = y0.zeros_like();
118        self.dydt_old = y0.zeros_like();
119        self.y_old = y0.clone();
120        self.y_prev = core::array::from_fn(|_| y0.zeros_like());
121        self.k = core::array::from_fn(|_| y0.zeros_like());
122        self.t_prev[0] = t0;
123        self.y_prev[0] = y0.clone();
124
125        // Previous saved steps
126        self.t_old = t0;
127        self.y_old = y0.clone();
128
129        // Perform the first 3 steps using Runge-Kutta 4 method
130        for i in 1..=3 {
131            evals.function += Self::rk4_step(ode, &mut self.t, &mut self.y, self.h, &mut self.k);
132            self.t_prev[i] = self.t;
133            self.y_prev[i] = self.y.clone();
134
135            if i == 1 {
136                self.dydt = self.k[0].clone();
137                self.dydt_old = self.k[0].clone();
138            }
139        }
140
141        self.status = Status::Initialized;
142        Ok(evals)
143    }
144
145    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
146    where
147        F: ODE<T, Y> + ?Sized,
148    {
149        let mut evals = Evals::new();
150
151        // Check if Max Steps Reached
152        if self.steps >= self.max_steps {
153            self.status = Status::Error(Error::MaxSteps {
154                t: self.t,
155                y: self.y.clone(),
156            });
157            return Err(Error::MaxSteps {
158                t: self.t,
159                y: self.y.clone(),
160            });
161        }
162        self.steps += 1;
163
164        // If Step size changed and it takes us to the final time perform a Runge-Kutta 4 step to finish
165        if self.h != self.t_prev[0] - self.t_prev[1] && self.t + self.h == self.tf {
166            evals.function += Self::rk4_step(ode, &mut self.t, &mut self.y, self.h, &mut self.k);
167            return Ok(evals);
168        }
169
170        // Compute derivatives for history
171        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
172        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
173        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
174        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);
175
176        let predictor = self.y_prev[3].plus_linear_combination(&[
177            (
178                &self.k[0],
179                self.h * T::from_f64(55.0).unwrap() / T::from_f64(24.0).unwrap(),
180            ),
181            (
182                &self.k[1],
183                -self.h * T::from_f64(59.0).unwrap() / T::from_f64(24.0).unwrap(),
184            ),
185            (
186                &self.k[2],
187                self.h * T::from_f64(37.0).unwrap() / T::from_f64(24.0).unwrap(),
188            ),
189            (
190                &self.k[3],
191                -self.h * T::from_f64(9.0).unwrap() / T::from_f64(24.0).unwrap(),
192            ),
193        ]);
194
195        // Corrector step:
196        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);
197        let corrector = self.y_prev[3].plus_linear_combination(&[
198            (
199                &self.k[3],
200                self.h * T::from_f64(9.0).unwrap() / T::from_f64(24.0).unwrap(),
201            ),
202            (
203                &self.k[0],
204                self.h * T::from_f64(19.0).unwrap() / T::from_f64(24.0).unwrap(),
205            ),
206            (
207                &self.k[1],
208                -self.h * T::from_f64(5.0).unwrap() / T::from_f64(24.0).unwrap(),
209            ),
210            (&self.k[2], self.h / T::from_f64(24.0).unwrap()),
211        ]);
212
213        // Track number of evaluations
214        evals.function += 5;
215
216        // Calculate sigma for step size adjustment
217        let sigma = T::from_f64(19.0).unwrap() * corrector.diff_norm_squared(&predictor).sqrt()
218            / (T::from_f64(270.0).unwrap() * self.h.abs());
219
220        // Check if Step meets tolerance
221        if sigma <= self.tol {
222            // Update Previous step states
223            self.t_old = self.t;
224            self.y_old = self.y.clone();
225            self.dydt_old = self.dydt.clone();
226
227            // Update state
228            self.t += self.h;
229            self.y = corrector;
230
231            // Check if previous step rejected
232            if let Status::RejectedStep = self.status {
233                self.status = Status::Solving;
234            }
235
236            // Adjust Step Size if needed
237            let two = T::from_f64(2.0).unwrap();
238            let four = T::from_f64(4.0).unwrap();
239            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
240            self.h = if q > four { four * self.h } else { q * self.h };
241
242            // Bound Step Size
243            let tf_t_abs = (self.tf - self.t).abs();
244            let four_div = tf_t_abs / four;
245            let h_max_effective = if self.h_max < four_div {
246                self.h_max
247            } else {
248                four_div
249            };
250
251            self.h = constrain_step_size(self.h, self.h_min, h_max_effective);
252            self.h = (self.filter)(self.h);
253
254            // Calculate Previous Steps with new step size
255            self.t_prev[0] = self.t;
256            self.y_prev[0] = self.y.clone();
257            for i in 1..=3 {
258                self.evals += Self::rk4_step(ode, &mut self.t, &mut self.y, self.h, &mut self.k);
259                self.t_prev[i] = self.t;
260                self.y_prev[i] = self.y.clone();
261
262                if i == 1 {
263                    self.dydt = self.k[0].clone();
264                }
265            }
266        } else {
267            // Step Rejected
268            self.status = Status::RejectedStep;
269
270            // Adjust Step Size
271            let two = T::from_f64(2.0).unwrap();
272            let tenth = T::from_f64(0.1).unwrap();
273            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
274            self.h = if q < tenth {
275                tenth * self.h
276            } else {
277                q * self.h
278            };
279            self.h = (self.filter)(self.h);
280
281            // Calculate Previous Steps with new step size
282            self.t_prev[0] = self.t;
283            self.y_prev[0] = self.y.clone();
284            for i in 1..=3 {
285                self.evals += Self::rk4_step(ode, &mut self.t, &mut self.y, self.h, &mut self.k);
286                self.t_prev[i] = self.t;
287                self.y_prev[i] = self.y.clone();
288            }
289        }
290        Ok(evals)
291    }
292
293    fn t(&self) -> T {
294        self.t
295    }
296
297    fn y(&self) -> &Y {
298        &self.y
299    }
300
301    fn t_prev(&self) -> T {
302        self.t_old
303    }
304
305    fn y_prev(&self) -> &Y {
306        &self.y_old
307    }
308
309    fn h(&self) -> T {
310        // OrdinaryNumericalMethod repeats step size 4 times for each step
311        // so the IVP driver is looking for what the next
312        // state will be thus the step size is multiplied by 4
313        self.h * T::from_f64(4.0).unwrap()
314    }
315
316    fn set_h(&mut self, h: T) {
317        self.h = (self.filter)(h);
318    }
319
320    fn status(&self) -> &Status<T, Y> {
321        &self.status
322    }
323
324    fn set_status(&mut self, status: Status<T, Y>) {
325        self.status = status;
326    }
327}
328
329// Implement the Interpolation trait for APCV4
330impl<T: Real, Y: State<T>> Interpolation<T, Y>
331    for AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4>
332{
333    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
334        // Check if t is within the range of the solver
335        if t_interp < self.t_old || t_interp > self.t {
336            return Err(Error::OutOfBounds {
337                t_interp,
338                t_prev: self.t_old,
339                t_curr: self.t,
340            });
341        }
342
343        // Calculate the interpolated value using cubic Hermite interpolation
344        let y_interp = cubic_hermite_interpolate(
345            self.t_old,
346            self.t,
347            &self.y_old,
348            &self.y,
349            &self.dydt_old,
350            &self.dydt,
351            t_interp,
352        );
353
354        Ok(y_interp)
355    }
356}