Skip to main content

differential_equations/methods/apc/
apcf4.rs

1//! Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    methods::{Fixed, Ordinary},
7    ode::{ODE, OrdinaryNumericalMethod},
8    stats::Evals,
9    status::Status,
10    traits::{Real, State},
11    utils::validate_step_size_parameters,
12};
13
14use super::AdamsPredictorCorrector;
15
16impl<T: Real, Y: State<T>> AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4> {
17    /// Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
18    ///
19    /// The Adams-Predictor-Corrector method is an explicit method that
20    /// uses the previous states to predict the next state.
21    ///
22    /// The First 3 steps, of fixed step size `h`, are calculated using
23    /// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
24    /// method is used to calculate the remaining steps until the final time.
25    ///
26    /// # Example
27    ///
28    /// ```rust
29    /// use differential_equations::prelude::*;
30    ///
31    /// struct HarmonicOscillator {
32    ///     k: f64,
33    /// }
34    ///
35    /// impl ODE<f64, [f64; 2]> for HarmonicOscillator {
36    ///     fn diff(&self, _t: f64, y: &[f64; 2], dydt: &mut [f64; 2]) {
37    ///         dydt[0] = y[1];
38    ///         dydt[1] = -self.k * y[0];
39    ///     }
40    /// }
41    /// let apcf4 = AdamsPredictorCorrector::f4(0.01);
42    /// let t0 = 0.0;
43    /// let tf = 10.0;
44    /// let y0 = [1.0, 0.0];
45    /// let system = HarmonicOscillator { k: 1.0 };
46    /// let results = IVP::ode(&system, t0, tf, y0).method(apcf4).solve().unwrap();
47    /// let expected = [-0.83907153, 0.54402111];
48    /// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-2);
49    /// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-2);
50    /// ```
51    ///
52    /// # Settings
53    /// * `h` - Step Size
54    ///
55    pub fn f4(h: T) -> Self {
56        Self {
57            h,
58            ..Default::default()
59        }
60    }
61}
62
63// Implement OrdinaryNumericalMethod Trait for APCF4
64impl<T: Real, Y: State<T>> OrdinaryNumericalMethod<T, Y>
65    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
66{
67    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
68    where
69        F: ODE<T, Y> + ?Sized,
70    {
71        let mut evals = Evals::new();
72
73        // Check Bounds
74        match validate_step_size_parameters::<T, Y>(self.h, T::zero(), T::infinity(), t0, tf) {
75            Ok(h) => self.h = h,
76            Err(e) => return Err(e),
77        }
78
79        // Initialize state
80        self.t = t0;
81        self.y = y0.clone();
82        self.dydt = y0.zeros_like();
83        self.dydt_old = y0.zeros_like();
84        self.y_old = y0.clone();
85        self.y_prev = core::array::from_fn(|_| y0.zeros_like());
86        self.k = core::array::from_fn(|_| y0.zeros_like());
87        self.t_prev[0] = t0;
88        self.y_prev[0] = y0.clone();
89
90        // Old state for interpolation
91        self.t_old = self.t;
92        self.y_old = self.y.clone();
93
94        let two = T::from_f64(2.0).unwrap();
95        let six = T::from_f64(6.0).unwrap();
96        for i in 1..=3 {
97            // Compute k1, k2, k3, k4 of Runge-Kutta 4
98            ode.diff(self.t, &self.y, &mut self.k[0]);
99            ode.diff(
100                self.t + self.h / two,
101                &self.y.plus_scaled(self.h / two, &self.k[0]),
102                &mut self.k[1],
103            );
104            ode.diff(
105                self.t + self.h / two,
106                &self.y.plus_scaled(self.h / two, &self.k[1]),
107                &mut self.k[2],
108            );
109            ode.diff(
110                self.t + self.h,
111                &self.y.plus_scaled(self.h, &self.k[2]),
112                &mut self.k[3],
113            );
114
115            // Update State
116            self.y.add_scaled(self.h / six, &self.k[0]);
117            self.y.add_scaled(two * self.h / six, &self.k[1]);
118            self.y.add_scaled(two * self.h / six, &self.k[2]);
119            self.y.add_scaled(self.h / six, &self.k[3]);
120            self.t += self.h;
121            self.t_prev[i] = self.t;
122            self.y_prev[i] = self.y.clone();
123            evals.function += 4; // 4 evaluations per Runge-Kutta step
124
125            if i == 1 {
126                self.dydt = self.k[0].clone();
127                self.dydt_old = self.dydt.clone();
128            }
129        }
130
131        self.status = Status::Initialized;
132        Ok(evals)
133    }
134
135    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
136    where
137        F: ODE<T, Y> + ?Sized,
138    {
139        let mut evals = Evals::new();
140
141        // state for interpolation
142        self.t_old = self.t;
143        self.y_old = self.y.clone();
144        self.dydt_old = self.dydt.clone();
145
146        // Compute derivatives for history
147        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
148        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
149        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
150        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);
151
152        let predictor = self.y_prev[3].plus_linear_combination(&[
153            (
154                &self.k[0],
155                self.h * T::from_f64(55.0).unwrap() / T::from_f64(24.0).unwrap(),
156            ),
157            (
158                &self.k[1],
159                -self.h * T::from_f64(59.0).unwrap() / T::from_f64(24.0).unwrap(),
160            ),
161            (
162                &self.k[2],
163                self.h * T::from_f64(37.0).unwrap() / T::from_f64(24.0).unwrap(),
164            ),
165            (
166                &self.k[3],
167                -self.h * T::from_f64(9.0).unwrap() / T::from_f64(24.0).unwrap(),
168            ),
169        ]);
170
171        // Corrector step:
172        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);
173        let corrector = self.y_prev[3].plus_linear_combination(&[
174            (
175                &self.k[3],
176                self.h * T::from_f64(9.0).unwrap() / T::from_f64(24.0).unwrap(),
177            ),
178            (
179                &self.k[0],
180                self.h * T::from_f64(19.0).unwrap() / T::from_f64(24.0).unwrap(),
181            ),
182            (
183                &self.k[1],
184                -self.h * T::from_f64(5.0).unwrap() / T::from_f64(24.0).unwrap(),
185            ),
186            (&self.k[2], self.h / T::from_f64(24.0).unwrap()),
187        ]);
188
189        // Update state
190        self.t += self.h;
191        self.y = corrector;
192        ode.diff(self.t, &self.y, &mut self.dydt);
193        evals.function += 6; // 6 evaluations for predictor-corrector step
194
195        // Shift history: drop the oldest and add the new state at the end.
196        self.t_prev.copy_within(1..4, 0);
197        self.y_prev.rotate_left(1);
198        self.t_prev[3] = self.t;
199        self.y_prev[3] = self.y.clone();
200        Ok(evals)
201    }
202
203    fn t(&self) -> T {
204        self.t
205    }
206
207    fn y(&self) -> &Y {
208        &self.y
209    }
210
211    fn t_prev(&self) -> T {
212        self.t_old
213    }
214
215    fn y_prev(&self) -> &Y {
216        &self.y_old
217    }
218
219    fn h(&self) -> T {
220        self.h
221    }
222
223    fn set_h(&mut self, h: T) {
224        self.h = h;
225    }
226
227    fn status(&self) -> &Status<T, Y> {
228        &self.status
229    }
230
231    fn set_status(&mut self, status: Status<T, Y>) {
232        self.status = status;
233    }
234}
235
236impl<T: Real, Y: State<T>> Interpolation<T, Y>
237    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
238{
239    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
240        // Check if t is within bounds
241        if t_interp < self.t_prev[0] || t_interp > self.t {
242            return Err(Error::OutOfBounds {
243                t_interp,
244                t_prev: self.t_prev[0],
245                t_curr: self.t,
246            });
247        }
248
249        // Calculate the interpolation using cubic hermite interpolation
250        let y_interp = cubic_hermite_interpolate(
251            self.t_old,
252            self.t,
253            &self.y_old,
254            &self.y,
255            &self.dydt_old,
256            &self.dydt,
257            t_interp,
258        );
259
260        Ok(y_interp)
261    }
262}