differential_equations/methods/erk/adaptive/ordinary.rs

//! Adaptive Runge-Kutta methods for ODEs
2
3use super::{ExplicitRungeKutta, Ordinary, Adaptive};
4use crate::{
5    Error, Status,
6    alias::Evals,
7    methods::h_init::InitialStepSize,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    ode::{OrdinaryNumericalMethod, ODE},
10    traits::{CallBackData, Real, State},
11    utils::{constrain_step_size, validate_step_size_parameters},
12};
13
impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
    /// Initialize the solver for an integration from `t0` to `tf` starting at `y0`.
    ///
    /// Picks an initial step size when `h0` is zero, validates it against
    /// `[h_min, h_max]` and the integration direction, evaluates the initial
    /// derivative, and seeds the previous-step state used for interpolation.
    ///
    /// Returns the function-evaluation counts, or an error if the step-size
    /// parameters are inconsistent.
    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
    where
        F: ODE<T, V, D>,
    {
        let mut evals = Evals::new();

        // If h0 is zero, calculate initial step size
        if self.h0 == T::zero() {
            // Only use adaptive step size calculation if the method supports it
            self.h0 = InitialStepSize::<Ordinary>::compute(ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, &mut evals);
            // NOTE(review): `compute` already receives `&mut evals`; confirm these
            // two evaluations are not also counted inside it (possible double count).
            evals.fcn += 2;

        }

        // Check bounds: rejects h0 outside [h_min, h_max] or pointing away from tf
        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
            Ok(h0) => self.h = h0,
            Err(status) => return Err(status),
        }

        // Initialize Statistics
        self.stiffness_counter = 0;

        // Initialize State: store t0/y0 and evaluate the initial derivative
        self.t = t0;
        self.y = *y0;
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.fcn += 1;

        // Initialize previous state so interpolation is well-defined before
        // the first accepted step (degenerate zero-width interval).
        self.t_prev = self.t;
        self.y_prev = self.y;
        self.dydt_prev = self.dydt;

        // Initialize Status
        self.status = Status::Initialized;

        Ok(evals)
    }

    /// Attempt one adaptive step of size `self.h`.
    ///
    /// Computes the `S` Runge-Kutta stages, forms the high-order solution from
    /// `b` and (when present) the embedded lower-order solution from `bh`, and
    /// accepts the step when the scaled error norm is <= 1. On acceptance the
    /// state advances (reusing the last stage derivative if the method is
    /// FSAL) and any extra dense-output stages are evaluated; on rejection the
    /// stiffness counter is incremented. In both cases the step size is then
    /// rescaled by the error controller and clamped to `[h_min, h_max]`.
    ///
    /// # Errors
    /// `StepSize` when `h` collapses relative to the last accepted step,
    /// `MaxSteps` when the step budget is exhausted, and `Stiffness` after
    /// `max_rejects` consecutive rejections.
    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
    where
        F: ODE<T, V, D>,
    {
        let mut evals = Evals::new();

        // Check step size: abort if h has shrunk to a negligible fraction of
        // the last accepted step (controller has effectively stalled).
        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
            self.status = Status::Error(Error::StepSize {
                t: self.t, y: self.y
            });
            return Err(Error::StepSize {
                t: self.t, y: self.y
            });
        }

        // Check max steps (counts attempts, accepted or rejected)
        if self.steps >= self.max_steps {
            self.status = Status::Error(Error::MaxSteps {
                t: self.t, y: self.y
            });
            return Err(Error::MaxSteps {
                t: self.t, y: self.y
            });
        }
        self.steps += 1;

        // Save k[0] as the current derivative (already evaluated by init or
        // the previous accepted step, so it costs no extra evaluation here).
        self.k[0] = self.dydt;

        // Compute stages: y_stage = y + h * sum_j a[i][j] * k[j] (explicit,
        // so only previously computed stages j < i are used).
        for i in 1..self.stages {
            let mut y_stage = self.y;

            for j in 0..i {
                y_stage += self.k[j] * (self.a[i][j] * self.h);
            }

            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
        }
        evals.fcn += self.stages - 1; // We already have k[0]

        // For adaptive methods with error estimation
        // Compute higher order solution: y + h * sum_i b[i] * k[i]
        let mut y_high = self.y;
        for i in 0..self.stages {
            y_high += self.k[i] * (self.b[i] * self.h);
        }

        // Compute lower order solution for error estimation using the
        // embedded weights bh. NOTE(review): if `bh` is None, y_low stays at
        // self.y and `err` below becomes the full increment — confirm every
        // Adaptive-tableau construction guarantees `bh` is Some.
        let mut y_low = self.y;
        if let Some(bh) = &self.bh {
            for i in 0..self.stages {
                y_low += self.k[i] * (bh[i] * self.h);
            }
        }

        // Compute error estimate (difference of the two embedded solutions)
        let err = y_high - y_low;

        // Calculate error norm
        let mut err_norm: T = T::zero();

        // Scaled max-norm: max_n |err_n| / (atol + rtol * max(|y_n|, |y_high_n|))
        for n in 0..self.y.len() {
            let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
            err_norm = err_norm.max((err.get(n) / tol).abs());
        };

        // Determine if step is accepted (norm <= 1 means within tolerance)
        if err_norm <= T::one() {
            // Log previous state (anchors interpolation over this step)
            self.t_prev = self.t;
            self.y_prev = self.y;
            self.dydt_prev = self.k[0];
            self.h_prev = self.h;

            // Recovering from a rejection: reset the stiffness counter
            if let Status::RejectedStep = self.status {
                self.stiffness_counter = 0;
                self.status = Status::Solving;
            }

            // If method has dense output stages, compute them
            if self.bi.is_some() {
                // Compute the I - S extra stages for dense output, stored in
                // k[stages..], using the extended rows of `a` and `c`.
                for i in 0..(I - S) {
                    let mut y_stage = self.y;
                    for j in 0..self.stages + i {
                        y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
                    }

                    ode.diff(self.t + self.c[self.stages + i] * self.h, &y_stage, &mut self.k[self.stages + i]);
                }
                evals.fcn += I - S;
            }

            // Update state with the higher-order solution (local extrapolation)
            self.t += self.h;
            self.y = y_high;

            // Compute the derivative for the next step
            if self.fsal {
                // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
                self.dydt = self.k[S - 1];
            } else {
                // Otherwise, compute the new derivative
                ode.diff(self.t, &self.y, &mut self.dydt);
                evals.fcn += 1;
            }
        } else {
            // Step rejected
            self.status = Status::RejectedStep;
            self.stiffness_counter += 1;

            // Check for stiffness: too many consecutive rejections
            if self.stiffness_counter >= self.max_rejects {
                self.status = Status::Error(Error::Stiffness {
                    t: self.t, y: self.y
                });
                return Err(Error::Stiffness {
                    t: self.t, y: self.y
                });
            }
        }

        // Calculate new step size for adaptive methods
        let order = T::from_usize(self.order).unwrap();
        let err_order = T::one() / order;

        // Step size controller: h *= clamp(safety * err_norm^(-1/order),
        // min_scale, max_scale). Note err_norm == 0 yields +inf, which the
        // clamp reduces to max_scale.
        let scale = self.safety_factor * err_norm.powf(-err_order);
        let scale = scale.max(self.min_scale).min(self.max_scale);
        self.h *= scale;

        // Ensure step size is within bounds
        self.h = constrain_step_size(self.h, self.h_min, self.h_max);

        Ok(evals)
    }

    // Trivial accessors required by the OrdinaryNumericalMethod trait.
    fn t(&self) -> T { self.t }
    fn y(&self) -> &V { &self.y }
    fn t_prev(&self) -> T { self.t_prev }
    fn y_prev(&self) -> &V { &self.y_prev }
    fn h(&self) -> T { self.h }
    fn set_h(&mut self, h: T) { self.h = h; }
    fn status(&self) -> &Status<T, V, D> { &self.status }
    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
}
204
205impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
206    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
207        // Check if t is within bounds
208        if t_interp < self.t_prev || t_interp > self.t {
209            return Err(Error::OutOfBounds {
210                t_interp,
211                t_prev: self.t_prev,
212                t_curr: self.t
213            });
214        }
215
216        // If method has dense output coefficients, use them
217        if self.bi.is_some() {
218            // Calculate the normalized distance within the step [0, 1]
219            let s = (t_interp - self.t_prev) / self.h_prev;
220            
221            // Get the interpolation coefficients
222            let bi = self.bi.as_ref().unwrap();
223
224            let mut cont = [T::zero(); I];
225            // Compute the interpolation coefficients using Horner's method
226            for i in 0..self.dense_stages {
227                // Start with the highest-order term
228                cont[i] = bi[i][self.order - 1];
229
230                // Apply Horner's method
231                for j in (0..self.order - 1).rev() {
232                    cont[i] = cont[i] * s + bi[i][j];
233                }
234
235                // Multiply by s
236                cont[i] *= s;
237            }
238
239            // Compute the interpolated value
240            let mut y_interp = self.y_prev;
241            for i in 0..I {
242                y_interp += self.k[i] * cont[i] * self.h_prev;
243            }
244
245            Ok(y_interp)
246        } else {
247            // Otherwise use cubic Hermite interpolation
248            let y_interp = cubic_hermite_interpolate(
249                self.t_prev, 
250                self.t, 
251                &self.y_prev, 
252                &self.y, 
253                &self.dydt_prev, 
254                &self.dydt, 
255                t_interp
256            );
257
258            Ok(y_interp)
259        }
260    }
261}