differential_equations/methods/apc/
apcf4.rs

1//! Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    methods::{Fixed, Ordinary},
7    ode::{ODE, OrdinaryNumericalMethod},
8    stats::Evals,
9    status::Status,
10    traits::{Real, State},
11    utils::validate_step_size_parameters,
12};
13
14use super::AdamsPredictorCorrector;
15
16impl<T: Real, Y: State<T>> AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4> {
17    /// Adams-Predictor-Corrector 4th Order Fixed Step Size Method.
18    ///
19    /// The Adams-Predictor-Corrector method is an explicit method that
20    /// uses the previous states to predict the next state.
21    ///
22    /// The First 3 steps, of fixed step size `h`, are calculated using
23    /// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
24    /// method is used to calculate the remaining steps until the final time.
25    ///
26    /// # Example
27    ///
28    /// ```
29    /// use differential_equations::prelude::*;
30    /// use nalgebra::{SVector, vector};
31    ///
32    /// struct HarmonicOscillator {
33    ///     k: f64,
34    /// }
35    ///
36    /// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
37    ///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
38    ///         dydt[0] = y[1];
39    ///         dydt[1] = -self.k * y[0];
40    ///     }
41    /// }
42    /// let mut apcf4 = AdamsPredictorCorrector::f4(0.01);
43    /// let t0 = 0.0;
44    /// let tf = 10.0;
45    /// let y0 = vector![1.0, 0.0];
46    /// let system = HarmonicOscillator { k: 1.0 };
47    /// let results = ODEProblem::new(system, t0, tf, y0).solve(&mut apcf4).unwrap();
48    /// let expected = vector![-0.83907153, 0.54402111];
49    /// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-2);
50    /// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-2);
51    /// ```
52    ///
53    /// # Settings
54    /// * `h` - Step Size
55    ///
56    pub fn f4(h: T) -> Self {
57        Self {
58            h,
59            ..Default::default()
60        }
61    }
62}
63
64// Implement OrdinaryNumericalMethod Trait for APCF4
65impl<T: Real, Y: State<T>> OrdinaryNumericalMethod<T, Y>
66    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
67{
68    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
69    where
70        F: ODE<T, Y>,
71    {
72        let mut evals = Evals::new();
73
74        // Check Bounds
75        match validate_step_size_parameters::<T, Y>(self.h, T::zero(), T::infinity(), t0, tf) {
76            Ok(h) => self.h = h,
77            Err(e) => return Err(e),
78        }
79
80        // Initialize state
81        self.t = t0;
82        self.y = *y0;
83        self.t_prev[0] = t0;
84        self.y_prev[0] = *y0;
85
86        // Old state for interpolation
87        self.t_old = self.t;
88        self.y_old = self.y;
89
90        let two = T::from_f64(2.0).unwrap();
91        let six = T::from_f64(6.0).unwrap();
92        for i in 1..=3 {
93            // Compute k1, k2, k3, k4 of Runge-Kutta 4
94            ode.diff(self.t, &self.y, &mut self.k[0]);
95            ode.diff(
96                self.t + self.h / two,
97                &(self.y + self.k[0] * (self.h / two)),
98                &mut self.k[1],
99            );
100            ode.diff(
101                self.t + self.h / two,
102                &(self.y + self.k[1] * (self.h / two)),
103                &mut self.k[2],
104            );
105            ode.diff(
106                self.t + self.h,
107                &(self.y + self.k[2] * self.h),
108                &mut self.k[3],
109            );
110
111            // Update State
112            self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
113            self.t += self.h;
114            self.t_prev[i] = self.t;
115            self.y_prev[i] = self.y;
116            evals.function += 4; // 4 evaluations per Runge-Kutta step
117
118            if i == 1 {
119                self.dydt = self.k[0];
120                self.dydt_old = self.dydt;
121            }
122        }
123
124        self.status = Status::Initialized;
125        Ok(evals)
126    }
127
128    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
129    where
130        F: ODE<T, Y>,
131    {
132        let mut evals = Evals::new();
133
134        // state for interpolation
135        self.t_old = self.t;
136        self.y_old = self.y;
137        self.dydt_old = self.dydt;
138
139        // Compute derivatives for history
140        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
141        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
142        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
143        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);
144
145        let predictor = self.y_prev[3]
146            + (self.k[0] * T::from_f64(55.0).unwrap() - self.k[1] * T::from_f64(59.0).unwrap()
147                + self.k[2] * T::from_f64(37.0).unwrap()
148                - self.k[3] * T::from_f64(9.0).unwrap())
149                * self.h
150                / T::from_f64(24.0).unwrap();
151
152        // Corrector step:
153        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);
154        let corrector = self.y_prev[3]
155            + (self.k[3] * T::from_f64(9.0).unwrap() + self.k[0] * T::from_f64(19.0).unwrap()
156                - self.k[1] * T::from_f64(5.0).unwrap()
157                + self.k[2] * T::from_f64(1.0).unwrap())
158                * (self.h / T::from_f64(24.0).unwrap());
159
160        // Update state
161        self.t += self.h;
162        self.y = corrector;
163        ode.diff(self.t, &self.y, &mut self.dydt);
164        evals.function += 6; // 6 evaluations for predictor-corrector step
165
166        // Shift history: drop the oldest and add the new state at the end.
167        self.t_prev.copy_within(1..4, 0);
168        self.y_prev.copy_within(1..4, 0);
169        self.t_prev[3] = self.t;
170        self.y_prev[3] = self.y;
171        Ok(evals)
172    }
173
174    fn t(&self) -> T {
175        self.t
176    }
177
178    fn y(&self) -> &Y {
179        &self.y
180    }
181
182    fn t_prev(&self) -> T {
183        self.t_old
184    }
185
186    fn y_prev(&self) -> &Y {
187        &self.y_old
188    }
189
190    fn h(&self) -> T {
191        self.h
192    }
193
194    fn set_h(&mut self, h: T) {
195        self.h = h;
196    }
197
198    fn status(&self) -> &Status<T, Y> {
199        &self.status
200    }
201
202    fn set_status(&mut self, status: Status<T, Y>) {
203        self.status = status;
204    }
205}
206
207impl<T: Real, Y: State<T>> Interpolation<T, Y>
208    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
209{
210    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
211        // Check if t is within bounds
212        if t_interp < self.t_prev[0] || t_interp > self.t {
213            return Err(Error::OutOfBounds {
214                t_interp,
215                t_prev: self.t_prev[0],
216                t_curr: self.t,
217            });
218        }
219
220        // Calculate the interpolation using cubic hermite interpolation
221        let y_interp = cubic_hermite_interpolate(
222            self.t_old,
223            self.t,
224            &self.y_old,
225            &self.y,
226            &self.dydt_old,
227            &self.dydt,
228            t_interp,
229        );
230
231        Ok(y_interp)
232    }
233}