differential_equations/methods/erk/fixed/
ordinary.rs

1//! Fixed Runge-Kutta methods for ODEs
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    methods::{ExplicitRungeKutta, Fixed, Ordinary},
7    ode::{ODE, OrdinaryNumericalMethod},
8    stats::Evals,
9    status::Status,
10    traits::{Real, State},
11    utils::validate_step_size_parameters,
12};
13
14impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
15    OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
16{
17    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
18    where
19        F: ODE<T, Y>,
20    {
21        let mut evals = Evals::new();
22
23        // If h0 is zero, calculate initial step size for fixed-step methods
24        if self.h0 == T::zero() {
25            // Simple default step size for fixed-step methods
26            let duration = (tf - t0).abs();
27            let default_steps = T::from_usize(100).unwrap();
28            self.h0 = duration / default_steps;
29        }
30
31        // Check bounds
32        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
33            Ok(h0) => self.h = h0,
34            Err(status) => return Err(status),
35        } // Initialize Statistics
36
37        // Initialize State
38        self.t = t0;
39        self.y = *y0;
40        ode.diff(self.t, &self.y, &mut self.dydt);
41        evals.function += 1;
42
43        // Initialize previous state
44        self.t_prev = self.t;
45        self.y_prev = self.y;
46        self.dydt_prev = self.dydt;
47
48        // Initialize Status
49        self.status = Status::Initialized;
50
51        Ok(evals)
52    }
53
54    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
55    where
56        F: ODE<T, Y>,
57    {
58        let mut evals = Evals::new();
59
60        // Check max steps
61        if self.steps >= self.max_steps {
62            self.status = Status::Error(Error::MaxSteps {
63                t: self.t,
64                y: self.y,
65            });
66            return Err(Error::MaxSteps {
67                t: self.t,
68                y: self.y,
69            });
70        }
71        self.steps += 1;
72
73        // Save k[0] as the current derivative
74        self.k[0] = self.dydt;
75
76        // Compute stages
77        for i in 1..self.stages {
78            let mut y_stage = self.y;
79
80            for j in 0..i {
81                y_stage += self.k[j] * (self.a[i][j] * self.h);
82            }
83
84            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
85        }
86        evals.function += self.stages - 1; // We already have k[0]
87
88        // Store current state before update for interpolation
89        self.t_prev = self.t;
90        self.y_prev = self.y;
91        self.dydt_prev = self.k[0];
92
93        // Compute solution
94        let mut y_next = self.y;
95        for i in 0..self.stages {
96            y_next += self.k[i] * (self.b[i] * self.h);
97        }
98
99        // If method has dense output stages, compute them
100        if self.bi.is_some() {
101            // Compute extra stages for dense output
102            for i in 0..(I - S) {
103                let mut y_stage = self.y;
104                for j in 0..self.stages + i {
105                    y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
106                }
107
108                ode.diff(
109                    self.t + self.c[self.stages + i] * self.h,
110                    &y_stage,
111                    &mut self.k[self.stages + i],
112                );
113            }
114            evals.function += I - S;
115        }
116
117        // Update state
118        self.t += self.h;
119        self.y = y_next;
120
121        // Calculate new derivative for next step
122        if self.fsal {
123            // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
124            self.dydt = self.k[S - 1];
125        } else {
126            // Otherwise, compute the new derivative
127            ode.diff(self.t, &self.y, &mut self.dydt);
128            evals.function += 1;
129        }
130
131        self.status = Status::Solving;
132        Ok(evals)
133    }
134
135    fn t(&self) -> T {
136        self.t
137    }
138    fn y(&self) -> &Y {
139        &self.y
140    }
141    fn t_prev(&self) -> T {
142        self.t_prev
143    }
144    fn y_prev(&self) -> &Y {
145        &self.y_prev
146    }
147    fn h(&self) -> T {
148        self.h
149    }
150    fn set_h(&mut self, h: T) {
151        self.h = h;
152    }
153    fn status(&self) -> &Status<T, Y> {
154        &self.status
155    }
156    fn set_status(&mut self, status: Status<T, Y>) {
157        self.status = status;
158    }
159}
160
161impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
162    for ExplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
163{
164    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
165        // Check if t is within bounds
166        if t_interp < self.t_prev || t_interp > self.t {
167            return Err(Error::OutOfBounds {
168                t_interp,
169                t_prev: self.t_prev,
170                t_curr: self.t,
171            });
172        }
173
174        // If method has dense output coefficients, use them
175        if self.bi.is_some() {
176            // Calculate the normalized distance within the step [0, 1]
177            let s = (t_interp - self.t_prev) / self.h_prev;
178
179            // Get the interpolation coefficients
180            let bi = self.bi.as_ref().unwrap();
181
182            let mut cont = [T::zero(); I];
183            // Compute the interpolation coefficients using Horner's method
184            for i in 0..self.dense_stages {
185                // Start with the highest-order term
186                cont[i] = bi[i][self.order - 1];
187
188                // Apply Horner's method
189                for j in (0..self.order - 1).rev() {
190                    cont[i] = cont[i] * s + bi[i][j];
191                }
192
193                // Multiply by s
194                cont[i] *= s;
195            }
196
197            // Compute the interpolated value
198            let mut y_interp = self.y_prev;
199            for i in 0..I {
200                y_interp += self.k[i] * cont[i] * self.h_prev;
201            }
202
203            Ok(y_interp)
204        } else {
205            // Otherwise use cubic Hermite interpolation
206            let y_interp = cubic_hermite_interpolate(
207                self.t_prev,
208                self.t,
209                &self.y_prev,
210                &self.y,
211                &self.dydt_prev,
212                &self.dydt,
213                t_interp,
214            );
215
216            Ok(y_interp)
217        }
218    }
219}