differential_equations/methods/apc/
apcf4.rs

1//! Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
2
3use super::AdamsPredictorCorrector;
4use crate::{
5    Error, Status,
6    alias::Evals,
7    interpolate::{Interpolation, cubic_hermite_interpolate},
8    ode::{OrdinaryNumericalMethod, ODE},
9    traits::{CallBackData, Real, State},
10    utils::{validate_step_size_parameters},
11    methods::{Ordinary, Fixed},
12};
13
14impl<T: Real, V: State<T>, D: CallBackData> AdamsPredictorCorrector<Ordinary, Fixed, T, V, D, 4> {
15    /// Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
16    ///
17    /// The Adams-Predictor-Corrector method is an explicit method that
18    /// uses the previous states to predict the next state.
19    ///
20    /// The First 3 steps, of fixed step size `h`, are calculated using
21    /// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
22    /// method is used to calculate the remaining steps until the final time.
23    ///
24    /// # Example
25    ///
26    /// ```
27    /// use differential_equations::prelude::*;
28    /// use nalgebra::{SVector, vector};
29    ///
30    /// struct HarmonicOscillator {
31    ///     k: f64,
32    /// }
33    ///
34    /// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
35    ///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
36    ///         dydt[0] = y[1];
37    ///         dydt[1] = -self.k * y[0];
38    ///     }
39    /// }
40    /// let mut apcf4 = AdamsPredictorCorrector::f4(0.01);
41    /// let t0 = 0.0;
42    /// let tf = 10.0;
43    /// let y0 = vector![1.0, 0.0];
44    /// let system = HarmonicOscillator { k: 1.0 };
45    /// let results = ODEProblem::new(system, t0, tf, y0).solve(&mut apcf4).unwrap();
46    /// let expected = vector![-0.83907153, 0.54402111];
47    /// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-2);
48    /// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-2);
49    /// ```
50    ///
51    /// # Settings
52    /// * `h` - Step Size
53    ///
54    pub fn f4(h: T) -> Self {
55        Self {
56            h,
57            ..Default::default()
58        }
59    }
60}
61
62// Implement OrdinaryNumericalMethod Trait for APCF4
63impl<T: Real, V: State<T>, D: CallBackData> OrdinaryNumericalMethod<T, V, D> for AdamsPredictorCorrector<Ordinary, Fixed, T, V, D, 4> {
64    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
65    where
66        F: ODE<T, V, D>,
67    {
68        let mut evals = Evals::new();
69
70        // Check Bounds
71        match validate_step_size_parameters::<T, V, D>(self.h, T::zero(), T::infinity(), t0, tf) {
72            Ok(h) => self.h = h,
73            Err(e) => return Err(e),
74        }
75
76        // Initialize state
77        self.t = t0;
78        self.y = *y0;
79        self.t_prev[0] = t0;
80        self.y_prev[0] = *y0;
81
82        // Old state for interpolation
83        self.t_old = self.t;
84        self.y_old = self.y;
85
86        let two = T::from_f64(2.0).unwrap();
87        let six = T::from_f64(6.0).unwrap();
88        for i in 1..=3 {
89            // Compute k1, k2, k3, k4 of Runge-Kutta 4
90            ode.diff(self.t, &self.y, &mut self.k[0]);
91            ode.diff(
92                self.t + self.h / two,
93                &(self.y + self.k[0] * (self.h / two)),
94                &mut self.k[1],
95            );
96            ode.diff(
97                self.t + self.h / two,
98                &(self.y + self.k[1] * (self.h / two)),
99                &mut self.k[2],
100            );
101            ode.diff(self.t + self.h, &(self.y + self.k[2] * self.h), &mut self.k[3]);
102
103            // Update State
104            self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
105            self.t += self.h;
106            self.t_prev[i] = self.t;
107            self.y_prev[i] = self.y;
108            evals.fcn += 4; // 4 evaluations per Runge-Kutta step
109
110            if i == 1 {
111                self.dydt = self.k[0];
112                self.dydt_old = self.dydt;
113            }
114        }
115
116        self.status = Status::Initialized;
117        Ok(evals)
118    }
119
120    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
121    where
122        F: ODE<T, V, D>,
123    {
124        let mut evals = Evals::new();
125
126        // state for interpolation
127        self.t_old = self.t;
128        self.y_old = self.y;
129        self.dydt_old = self.dydt;
130
131        // Compute derivatives for history
132        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
133        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
134        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
135        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);
136
137        let predictor = self.y_prev[3]
138            + (self.k[0] * T::from_f64(55.0).unwrap() - self.k[1] * T::from_f64(59.0).unwrap()
139                + self.k[2] * T::from_f64(37.0).unwrap()
140                - self.k[3] * T::from_f64(9.0).unwrap())
141                * self.h
142                / T::from_f64(24.0).unwrap();
143
144        // Corrector step:
145        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);
146        let corrector = self.y_prev[3]
147            + (self.k[3] * T::from_f64(9.0).unwrap() + self.k[0] * T::from_f64(19.0).unwrap()
148                - self.k[1] * T::from_f64(5.0).unwrap()
149                + self.k[2] * T::from_f64(1.0).unwrap())
150                * (self.h / T::from_f64(24.0).unwrap());
151
152        // Update state
153        self.t += self.h;
154        self.y = corrector;
155        ode.diff(self.t, &self.y, &mut self.dydt);
156        evals.fcn += 6; // 6 evaluations for predictor-corrector step
157
158        // Shift history: drop the oldest and add the new state at the end.
159        self.t_prev.copy_within(1..4, 0);
160        self.y_prev.copy_within(1..4, 0);
161        self.t_prev[3] = self.t;
162        self.y_prev[3] = self.y;
163        Ok(evals)
164    }
165
166    fn t(&self) -> T {
167        self.t
168    }
169
170    fn y(&self) -> &V {
171        &self.y
172    }
173
174    fn t_prev(&self) -> T {
175        self.t_old
176    }
177
178    fn y_prev(&self) -> &V {
179        &self.y_old
180    }
181
182    fn h(&self) -> T {
183        self.h
184    }
185
186    fn set_h(&mut self, h: T) {
187        self.h = h;
188    }
189
190    fn status(&self) -> &Status<T, V, D> {
191        &self.status
192    }
193
194    fn set_status(&mut self, status: Status<T, V, D>) {
195        self.status = status;
196    }
197}
198
199impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for AdamsPredictorCorrector<Ordinary, Fixed, T, V, D, 4> {
200    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
201        // Check if t is within bounds
202        if t_interp < self.t_prev[0] || t_interp > self.t {
203            return Err(Error::OutOfBounds {
204                t_interp,
205                t_prev: self.t_prev[0],
206                t_curr: self.t,
207            });
208        }
209
210        // Calculate the interpolation using cubic hermite interpolation
211        let y_interp = cubic_hermite_interpolate(
212            self.t_old,
213            self.t,
214            &self.y_old,
215            &self.y,
216            &self.dydt_old,
217            &self.dydt,
218            t_interp,
219        );
220
221        Ok(y_interp)
222    }
223}