differential_equations/ode/methods/adams/
apcf4.rs

1//! Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
2
3use super::*;
4
5///
6/// Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
7///
8/// The Adams-Predictor-Corrector method is an explicit method that
9/// uses the previous states to predict the next state.
10///
11/// The First 3 steps, of fixed step size `h`, are calculated using
12/// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
13/// method is used to calculate the remaining steps tell the final time.
14///
15/// # Example
16///
17/// ```
18/// use differential_equations::prelude::*;
19/// use differential_equations::ode::methods::adams::APCF4;
20/// use nalgebra::{SVector, vector};
21///
22/// struct HarmonicOscillator {
23///     k: f64,
24/// }
25///
26/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
27///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
28///         dydt[0] = y[1];
29///         dydt[1] = -self.k * y[0];
30///     }
31/// }
32/// let mut apcf4 = APCF4::new(0.01);
33/// let t0 = 0.0;
34/// let tf = 10.0;
35/// let y0 = vector![1.0, 0.0];
36/// let system = HarmonicOscillator { k: 1.0 };
37/// let results = ODEProblem::new(system, t0, tf, y0).solve(&mut apcf4).unwrap();
38/// let expected = vector![-0.83907153, 0.54402111];
39/// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-2);
40/// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-2);
41/// ```
42///
43/// # Settings
44/// * `h` - Step Size
45///
46pub struct APCF4<T: Real, V: State<T>, D: CallBackData> {
47    // Step Size
48    pub h: T,
49    // Current State
50    t: T,
51    y: V,
52    dydt: V,
53    // Previous State for Cubic Hermite Interpolation
54    t_old: T,
55    y_old: V,
56    dydt_old: V,
57    // Previous States for Predictor-Corrector
58    t_prev: [T; 4],
59    y_prev: [V; 4],
60    // Predictor Correct Derivatives
61    k1: V, // Also the current derivative
62    k2: V,
63    k3: V,
64    k4: V,
65    // Number of evaluations
66    pub evals: usize,
67    // Status
68    status: Status<T, V, D>,
69}
70
71// Implement ODENumericalMethod Trait for APCF4
72impl<T: Real, V: State<T>, D: CallBackData> ODENumericalMethod<T, V, D> for APCF4<T, V, D> {
73    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
74    where
75        F: ODE<T, V, D>,
76    {
77        let mut evals = Evals::new();
78
79        // Check Bounds
80        match validate_step_size_parameters::<T, V, D>(self.h, T::zero(), T::infinity(), t0, tf) {
81            Ok(h) => self.h = h,
82            Err(e) => return Err(e),
83        }
84
85        // Initialize state
86        self.t = t0;
87        self.y = *y0;
88        self.t_prev[0] = t0;
89        self.y_prev[0] = *y0;
90
91        // Old state for interpolation
92        self.t_old = self.t;
93        self.y_old = self.y;
94
95        let two = T::from_f64(2.0).unwrap();
96        let six = T::from_f64(6.0).unwrap();
97        for i in 1..=3 {
98            // Compute k1, k2, k3, k4 of Runge-Kutta 4
99            ode.diff(self.t, &self.y, &mut self.k1);
100            ode.diff(
101                self.t + self.h / two,
102                &(self.y + self.k1 * (self.h / two)),
103                &mut self.k2,
104            );
105            ode.diff(
106                self.t + self.h / two,
107                &(self.y + self.k2 * (self.h / two)),
108                &mut self.k3,
109            );
110            ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
111
112            // Update State
113            self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
114            self.t += self.h;
115            self.t_prev[i] = self.t;
116            self.y_prev[i] = self.y;
117            evals.fcn += 4; // 4 evaluations per Runge-Kutta step
118
119            if i == 1 {
120                self.dydt = self.k1;
121                self.dydt_old = self.dydt;
122            }
123        }
124
125        self.status = Status::Initialized;
126        Ok(evals)
127    }
128
129    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
130    where
131        F: ODE<T, V, D>,
132    {
133        let mut evals = Evals::new();
134
135        // state for interpolation
136        self.t_old = self.t;
137        self.y_old = self.y;
138        self.dydt_old = self.dydt;
139
140        // Compute derivatives for history
141        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k1);
142        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k2);
143        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k3);
144        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k4);
145
146        let predictor = self.y_prev[3]
147            + (self.k1 * T::from_f64(55.0).unwrap() - self.k2 * T::from_f64(59.0).unwrap()
148                + self.k3 * T::from_f64(37.0).unwrap()
149                - self.k4 * T::from_f64(9.0).unwrap())
150                * self.h
151                / T::from_f64(24.0).unwrap();
152
153        // Corrector step:
154        ode.diff(self.t + self.h, &predictor, &mut self.k4);
155        let corrector = self.y_prev[3]
156            + (self.k4 * T::from_f64(9.0).unwrap() + self.k1 * T::from_f64(19.0).unwrap()
157                - self.k2 * T::from_f64(5.0).unwrap()
158                + self.k3 * T::from_f64(1.0).unwrap())
159                * (self.h / T::from_f64(24.0).unwrap());
160
161        // Update state
162        self.t += self.h;
163        self.y = corrector;
164        ode.diff(self.t, &self.y, &mut self.dydt);
165        evals.fcn += 6; // 6 evaluations for predictor-corrector step
166
167        // Shift history: drop the oldest and add the new state at the end.
168        self.t_prev.copy_within(1..4, 0);
169        self.y_prev.copy_within(1..4, 0);
170        self.t_prev[3] = self.t;
171        self.y_prev[3] = self.y;
172        Ok(evals)
173    }
174
175    fn t(&self) -> T {
176        self.t
177    }
178
179    fn y(&self) -> &V {
180        &self.y
181    }
182
183    fn t_prev(&self) -> T {
184        self.t_old
185    }
186
187    fn y_prev(&self) -> &V {
188        &self.y_old
189    }
190
191    fn h(&self) -> T {
192        self.h
193    }
194
195    fn set_h(&mut self, h: T) {
196        self.h = h;
197    }
198
199    fn status(&self) -> &Status<T, V, D> {
200        &self.status
201    }
202
203    fn set_status(&mut self, status: Status<T, V, D>) {
204        self.status = status;
205    }
206}
207
208impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for APCF4<T, V, D> {
209    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
210        // Check if t is within bounds
211        if t_interp < self.t_prev[0] || t_interp > self.t {
212            return Err(Error::OutOfBounds {
213                t_interp,
214                t_prev: self.t_prev[0],
215                t_curr: self.t,
216            });
217        }
218
219        // Calculate the interpolation using cubic hermite interpolation
220        let y_interp = cubic_hermite_interpolate(
221            self.t_old,
222            self.t,
223            &self.y_old,
224            &self.y,
225            &self.dydt_old,
226            &self.dydt,
227            t_interp,
228        );
229
230        Ok(y_interp)
231    }
232}
233
234impl<T: Real, V: State<T>, D: CallBackData> APCF4<T, V, D> {
235    pub fn new(h: T) -> Self {
236        APCF4 {
237            h,
238            ..Default::default()
239        }
240    }
241}
242
243impl<T: Real, V: State<T>, D: CallBackData> Default for APCF4<T, V, D> {
244    fn default() -> Self {
245        APCF4 {
246            h: T::zero(),
247            t: T::zero(),
248            y: V::zeros(),
249            dydt: V::zeros(),
250            t_prev: [T::zero(); 4],
251            y_prev: [V::zeros(), V::zeros(), V::zeros(), V::zeros()],
252            t_old: T::zero(),
253            y_old: V::zeros(),
254            dydt_old: V::zeros(),
255            k1: V::zeros(),
256            k2: V::zeros(),
257            k3: V::zeros(),
258            k4: V::zeros(),
259            evals: 0,
260            status: Status::Uninitialized,
261        }
262    }
263}