// differential_equations/ode/methods/adams/apcv4.rs

//! Adams-Predictor-Corrector 4th Order Variable Step Size Method

3use super::*;
4use crate::{linalg::norm, ode::methods::h_init};
5
6///
7/// Adams-Predictor-Corrector 4th Order Variable Step Size Method.
8///
9/// The Adams-Predictor-Corrector method is an explicit method that
10/// uses the previous states to predict the next state. This implementation
11/// uses a variable step size to maintain a desired accuracy.
12/// It is recommended to start with a small step size so that tolerance
13/// can be quickly met and the algorithm can adjust the step size accordingly.
14///
15/// The First 3 steps are calculated using
16/// the Runge-Kutta method of order 4(5) and then the Adams-Predictor-Corrector
17/// method is used to calculate the remaining steps until the final time.
18///
19/// # Example
20///
21/// ```
22/// use differential_equations::prelude::*;
23/// use differential_equations::ode::methods::adams::APCV4;
24/// use nalgebra::{SVector, vector};
25///
26/// struct HarmonicOscillator {
27///     k: f64,
28/// }
29///
30/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
31///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
32///         dydt[0] = y[1];
33///         dydt[1] = -self.k * y[0];
34///     }
35/// }
36/// let mut apcv4 = APCV4::new();
37/// let t0 = 0.0;
38/// let tf = 10.0;
39/// let y0 = vector![1.0, 0.0];
40/// let system = HarmonicOscillator { k: 1.0 };
41/// let results = ODEProblem::new(system, t0, tf, y0).solve(&mut apcv4).unwrap();
42/// let expected = vector![-0.83907153, 0.54402111];
43/// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-6);
44/// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-6);
45/// ```
46///
47/// # Settings
48/// * `tol` - Tolerance
49/// * `h_max` - Maximum step size
50/// * `max_steps` - Maximum number of steps
51///
52/// ## Warning
53///
54/// This method is not suitable for stiff problems and can results in
55/// extremely small step sizes and long computation times.
56///
57pub struct APCV4<T: Real, V: State<T>, D: CallBackData> {
58    // Initial Step Size
59    pub h0: T,
60
61    // Current Step Size
62    h: T,
63
64    // Current State
65    t: T,
66    y: V,
67    dydt: V,
68
69    // Final Time
70    tf: T,
71
72    // Previous States
73    t_prev: [T; 4],
74    y_prev: [V; 4],
75
76    // Previous step states
77    t_old: T,
78    y_old: V,
79    dydt_old: V,
80
81    // Predictor Correct Derivatives
82    k1: V,
83    k2: V,
84    k3: V,
85    k4: V,
86
87    // Statistic Tracking
88    evals: usize,
89    steps: usize,
90
91    // Status
92    status: Status<T, V, D>,
93
94    // Settings
95    pub tol: T,
96    pub h_max: T,
97    pub h_min: T,
98    pub max_steps: usize,
99}
100
101// Implement ODENumericalMethod Trait for APCV4
102impl<T: Real, V: State<T>, D: CallBackData> ODENumericalMethod<T, V, D> for APCV4<T, V, D> {
103    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
104    where
105        F: ODE<T, V, D>,
106    {
107        let mut evals = Evals::new();
108
109        self.tf = tf;
110
111        // Initialize initial step size if it is zero
112        if self.h0 == T::zero() {
113            self.h0 = h_init(
114                ode, t0, tf, y0, 4, self.tol, self.tol, self.h_min, self.h_max,
115            );
116        }
117
118        // Check that the initial step size is set
119        match validate_step_size_parameters::<T, V, D>(self.h0, T::zero(), T::infinity(), t0, tf) {
120            Ok(h0) => self.h = h0,
121            Err(status) => return Err(status),
122        }
123
124        // Initialize state
125        self.t = t0;
126        self.y = *y0;
127        self.t_prev[0] = t0;
128        self.y_prev[0] = *y0;
129
130        // Previous saved steps
131        self.t_old = t0;
132        self.y_old = *y0;
133
134        // Perform the first 3 steps using Runge-Kutta 4 method
135        let two = T::from_f64(2.0).unwrap();
136        let six = T::from_f64(6.0).unwrap();
137        for i in 1..=3 {
138            // Compute k1, k2, k3, k4 of Runge-Kutta 4
139            ode.diff(self.t, &self.y, &mut self.k1);
140            ode.diff(
141                self.t + self.h / two,
142                &(self.y + self.k1 * (self.h / two)),
143                &mut self.k2,
144            );
145            ode.diff(
146                self.t + self.h / two,
147                &(self.y + self.k2 * (self.h / two)),
148                &mut self.k3,
149            );
150            ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
151
152            // Update State
153            self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
154            self.t += self.h;
155            self.t_prev[i] = self.t;
156            self.y_prev[i] = self.y;
157            evals.fcn += 4; // 4 evaluations per Runge-Kutta step
158
159            if i == 1 {
160                self.dydt = self.k1;
161                self.dydt_old = self.k1;
162            }
163        }
164
165        self.status = Status::Initialized;
166        Ok(evals)
167    }
168
169    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
170    where
171        F: ODE<T, V, D>,
172    {
173        let mut evals = Evals::new();
174
175        // Check if Max Steps Reached
176        if self.steps >= self.max_steps {
177            self.status = Status::Error(Error::MaxSteps {
178                t: self.t,
179                y: self.y,
180            });
181            return Err(Error::MaxSteps {
182                t: self.t,
183                y: self.y,
184            });
185        }
186        self.steps += 1;
187
188        // If Step size changed and it takes us to the final time perform a Runge-Kutta 4 step to finish
189        if self.h != self.t_prev[0] - self.t_prev[1] && self.t + self.h == self.tf {
190            let two = T::from_f64(2.0).unwrap();
191            let six = T::from_f64(6.0).unwrap();
192
193            // Perform a Runge-Kutta 4 step to finish.
194            ode.diff(self.t, &self.y, &mut self.k1);
195            ode.diff(
196                self.t + self.h / two,
197                &(self.y + self.k1 * (self.h / two)),
198                &mut self.k2,
199            );
200            ode.diff(
201                self.t + self.h / two,
202                &(self.y + self.k2 * (self.h / two)),
203                &mut self.k3,
204            );
205            ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
206            evals.fcn += 4; // 4 evaluations per Runge-Kutta step
207
208            // Update State
209            self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
210            self.t += self.h;
211            return Ok(evals);
212        }
213
214        // Compute derivatives for history
215        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k1);
216        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k2);
217        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k3);
218        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k4);
219
220        let predictor = self.y_prev[3]
221            + (self.k1 * T::from_f64(55.0).unwrap() - self.k2 * T::from_f64(59.0).unwrap()
222                + self.k3 * T::from_f64(37.0).unwrap()
223                - self.k4 * T::from_f64(9.0).unwrap())
224                * self.h
225                / T::from_f64(24.0).unwrap();
226
227        // Corrector step:
228        ode.diff(self.t + self.h, &predictor, &mut self.k4);
229        let corrector = self.y_prev[3]
230            + (self.k4 * T::from_f64(9.0).unwrap() + self.k1 * T::from_f64(19.0).unwrap()
231                - self.k2 * T::from_f64(5.0).unwrap()
232                + self.k3 * T::from_f64(1.0).unwrap())
233                * self.h
234                / T::from_f64(24.0).unwrap();
235
236        // Track number of evaluations
237        evals.fcn += 5;
238
239        // Calculate sigma for step size adjustment
240        let sigma = T::from_f64(19.0).unwrap() * norm(corrector - predictor)
241            / (T::from_f64(270.0).unwrap() * self.h.abs());
242
243        // Check if Step meets tolerance
244        if sigma <= self.tol {
245            // Update Previous step states
246            self.t_old = self.t;
247            self.y_old = self.y;
248            self.dydt_old = self.dydt;
249
250            // Update state
251            self.t += self.h;
252            self.y = corrector;
253
254            // Check if previous step rejected
255            if let Status::RejectedStep = self.status {
256                self.status = Status::Solving;
257            }
258
259            // Adjust Step Size if needed
260            let two = T::from_f64(2.0).unwrap();
261            let four = T::from_f64(4.0).unwrap();
262            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
263            self.h = if q > four { four * self.h } else { q * self.h };
264
265            // Bound Step Size
266            let tf_t_abs = (self.tf - self.t).abs();
267            let four_div = tf_t_abs / four;
268            let h_max_effective = if self.h_max < four_div {
269                self.h_max
270            } else {
271                four_div
272            };
273
274            self.h = constrain_step_size(self.h, self.h_min, h_max_effective);
275
276            // Calculate Previous Steps with new step size
277            self.t_prev[0] = self.t;
278            self.y_prev[0] = self.y;
279            let two = T::from_f64(2.0).unwrap();
280            let six = T::from_f64(6.0).unwrap();
281            for i in 1..=3 {
282                // Compute k1, k2, k3, k4 of Runge-Kutta 4
283                ode.diff(self.t, &self.y, &mut self.k1);
284                ode.diff(
285                    self.t + self.h / two,
286                    &(self.y + self.k1 * (self.h / two)),
287                    &mut self.k2,
288                );
289                ode.diff(
290                    self.t + self.h / two,
291                    &(self.y + self.k2 * (self.h / two)),
292                    &mut self.k3,
293                );
294                ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
295
296                // Update State
297                self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
298                self.t += self.h;
299                self.t_prev[i] = self.t;
300                self.y_prev[i] = self.y;
301                self.evals += 4; // 4 evaluations per Runge-Kutta step
302
303                if i == 1 {
304                    self.dydt = self.k1;
305                }
306            }
307        } else {
308            // Step Rejected
309            self.status = Status::RejectedStep;
310
311            // Adjust Step Size
312            let two = T::from_f64(2.0).unwrap();
313            let tenth = T::from_f64(0.1).unwrap();
314            let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
315            self.h = if q < tenth {
316                tenth * self.h
317            } else {
318                q * self.h
319            };
320
321            // Calculate Previous Steps with new step size
322            self.t_prev[0] = self.t;
323            self.y_prev[0] = self.y;
324            let two = T::from_f64(2.0).unwrap();
325            let six = T::from_f64(6.0).unwrap();
326            for i in 1..=3 {
327                // Compute k1, k2, k3, k4 of Runge-Kutta 4
328                ode.diff(self.t, &self.y, &mut self.k1);
329                ode.diff(
330                    self.t + self.h / two,
331                    &(self.y + self.k1 * (self.h / two)),
332                    &mut self.k2,
333                );
334                ode.diff(
335                    self.t + self.h / two,
336                    &(self.y + self.k2 * (self.h / two)),
337                    &mut self.k3,
338                );
339                ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
340
341                // Update State
342                self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
343                self.t += self.h;
344                self.t_prev[i] = self.t;
345                self.y_prev[i] = self.y;
346                self.evals += 4; // 4 evaluations per Runge-Kutta step
347            }
348        }
349        Ok(evals)
350    }
351
352    fn t(&self) -> T {
353        self.t
354    }
355
356    fn y(&self) -> &V {
357        &self.y
358    }
359
360    fn t_prev(&self) -> T {
361        self.t_old
362    }
363
364    fn y_prev(&self) -> &V {
365        &self.y_old
366    }
367
368    fn h(&self) -> T {
369        // ODENumericalMethod repeats step size 4 times for each step
370        // so the ODEProblem inquiring is looking for what the next
371        // state will be thus the step size is multiplied by 4
372        self.h * T::from_f64(4.0).unwrap()
373    }
374
375    fn set_h(&mut self, h: T) {
376        self.h = h;
377    }
378
379    fn status(&self) -> &Status<T, V, D> {
380        &self.status
381    }
382
383    fn set_status(&mut self, status: Status<T, V, D>) {
384        self.status = status;
385    }
386}
387
388// Implement the Interpolation trait for APCV4
389impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for APCV4<T, V, D> {
390    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
391        // Check if t is within the range of the solver
392        if t_interp < self.t_old || t_interp > self.t {
393            return Err(Error::OutOfBounds {
394                t_interp,
395                t_prev: self.t_old,
396                t_curr: self.t,
397            });
398        }
399
400        // Calculate the interpolated value using cubic Hermite interpolation
401        let y_interp = cubic_hermite_interpolate(
402            self.t_old,
403            self.t,
404            &self.y_old,
405            &self.y,
406            &self.dydt_old,
407            &self.dydt,
408            t_interp,
409        );
410
411        Ok(y_interp)
412    }
413}
414
415// Initialize APCV4 with default values and chainable settings
416impl<T: Real, V: State<T>, D: CallBackData> APCV4<T, V, D> {
417    pub fn new() -> Self {
418        APCV4 {
419            ..Default::default()
420        }
421    }
422
423    pub fn h0(mut self, h0: T) -> Self {
424        self.h0 = h0;
425        self
426    }
427
428    pub fn tol(mut self, tol: T) -> Self {
429        self.tol = tol;
430        self
431    }
432
433    pub fn h_min(mut self, h_min: T) -> Self {
434        self.h_min = h_min;
435        self
436    }
437
438    pub fn h_max(mut self, h_max: T) -> Self {
439        self.h_max = h_max;
440        self
441    }
442
443    pub fn max_steps(mut self, max_steps: usize) -> Self {
444        self.max_steps = max_steps;
445        self
446    }
447}
448
449impl<T: Real, V: State<T>, D: CallBackData> Default for APCV4<T, V, D> {
450    fn default() -> Self {
451        APCV4 {
452            h0: T::zero(),
453            h: T::zero(),
454            t: T::zero(),
455            y: V::zeros(),
456            dydt: V::zeros(),
457            t_prev: [T::zero(); 4],
458            y_prev: [V::zeros(), V::zeros(), V::zeros(), V::zeros()],
459            t_old: T::zero(),
460            y_old: V::zeros(),
461            dydt_old: V::zeros(),
462            k1: V::zeros(),
463            k2: V::zeros(),
464            k3: V::zeros(),
465            k4: V::zeros(),
466            tf: T::zero(),
467            evals: 0,
468            steps: 0,
469            status: Status::Uninitialized,
470            tol: T::from_f64(1.0e-6).unwrap(),
471            h_max: T::infinity(),
472            h_min: T::zero(),
473            max_steps: 1_000_000,
474        }
475    }
476}