differential_equations/methods/erk/fixed/ordinary.rs

1//! Fixed Runge-Kutta methods for ODEs
2
3use super::{ExplicitRungeKutta, Ordinary, Fixed};
4use crate::{
5    Error, Status,
6    alias::Evals,
7    interpolate::{Interpolation, cubic_hermite_interpolate},
8    ode::{OrdinaryNumericalMethod, ODE},
9    traits::{CallBackData, Real, State},
10    utils::validate_step_size_parameters,
11};
12
13impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ExplicitRungeKutta<Ordinary, Fixed, T, V, D, O, S, I> {    
14    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
15    where
16        F: ODE<T, V, D>,
17    {
18        let mut evals = Evals::new();
19
20        // If h0 is zero, calculate initial step size for fixed-step methods
21        if self.h0 == T::zero() {
22            // Simple default step size for fixed-step methods
23            let duration = (tf - t0).abs();
24            let default_steps = T::from_usize(100).unwrap();
25            self.h0 = duration / default_steps;
26        }
27
28        // Check bounds
29        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
30            Ok(h0) => self.h = h0,
31            Err(status) => return Err(status),
32        }        // Initialize Statistics
33
34        // Initialize State
35        self.t = t0;
36        self.y = *y0;
37        ode.diff(self.t, &self.y, &mut self.dydt);
38        evals.fcn += 1;
39
40        // Initialize previous state
41        self.t_prev = self.t;
42        self.y_prev = self.y;
43        self.dydt_prev = self.dydt;
44
45        // Initialize Status
46        self.status = Status::Initialized;
47
48        Ok(evals)
49    }
50
51    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
52    where
53        F: ODE<T, V, D>,
54    {
55        let mut evals = Evals::new();
56
57        // Check max steps
58        if self.steps >= self.max_steps {
59            self.status = Status::Error(Error::MaxSteps {
60                t: self.t, y: self.y
61            });
62            return Err(Error::MaxSteps {
63                t: self.t, y: self.y
64            });
65        }
66        self.steps += 1;
67
68        // Save k[0] as the current derivative
69        self.k[0] = self.dydt;
70
71        // Compute stages
72        for i in 1..self.stages {
73            let mut y_stage = self.y;
74
75            for j in 0..i {
76                y_stage += self.k[j] * (self.a[i][j] * self.h);
77            }
78
79            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
80        }
81        evals.fcn += self.stages - 1; // We already have k[0]
82
83        // Store current state before update for interpolation
84        self.t_prev = self.t;
85        self.y_prev = self.y;
86        self.dydt_prev = self.k[0];
87
88        // Compute solution
89        let mut y_next = self.y;
90        for i in 0..self.stages {
91            y_next += self.k[i] * (self.b[i] * self.h);
92        }
93
94        // If method has dense output stages, compute them
95        if self.bi.is_some() {
96            // Compute extra stages for dense output
97            for i in 0..(I - S) {
98                let mut y_stage = self.y;
99                for j in 0..self.stages + i {
100                    y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
101                }
102
103                ode.diff(self.t + self.c[self.stages + i] * self.h, &y_stage, &mut self.k[self.stages + i]);
104            }
105            evals.fcn += I - S;
106        }
107
108        // Update state
109        self.t += self.h;
110        self.y = y_next;
111        
112        // Calculate new derivative for next step
113        if self.fsal {
114            // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
115            self.dydt = self.k[S - 1];
116        } else {
117            // Otherwise, compute the new derivative
118            ode.diff(self.t, &self.y, &mut self.dydt);
119            evals.fcn += 1;
120        }
121        
122        self.status = Status::Solving;        
123        Ok(evals)
124    }
125
126    fn t(&self) -> T { self.t }
127    fn y(&self) -> &V { &self.y }
128    fn t_prev(&self) -> T { self.t_prev }
129    fn y_prev(&self) -> &V { &self.y_prev }
130    fn h(&self) -> T { self.h }
131    fn set_h(&mut self, h: T) { self.h = h; }
132    fn status(&self) -> &Status<T, V, D> { &self.status }
133    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
134}
135
136impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Ordinary, Fixed, T, V, D, O, S, I> {
137    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
138        // Check if t is within bounds
139        if t_interp < self.t_prev || t_interp > self.t {
140            return Err(Error::OutOfBounds {
141                t_interp,
142                t_prev: self.t_prev,
143                t_curr: self.t
144            });
145        }       
146        
147        // If method has dense output coefficients, use them
148        if self.bi.is_some() {
149            // Calculate the normalized distance within the step [0, 1]
150            let s = (t_interp - self.t_prev) / self.h_prev;
151            
152            // Get the interpolation coefficients
153            let bi = self.bi.as_ref().unwrap();
154
155            let mut cont = [T::zero(); I];
156            // Compute the interpolation coefficients using Horner's method
157            for i in 0..self.dense_stages {
158                // Start with the highest-order term
159                cont[i] = bi[i][self.order - 1];
160
161                // Apply Horner's method
162                for j in (0..self.order - 1).rev() {
163                    cont[i] = cont[i] * s + bi[i][j];
164                }
165
166                // Multiply by s
167                cont[i] *= s;
168            }
169
170            // Compute the interpolated value
171            let mut y_interp = self.y_prev;
172            for i in 0..I {
173                y_interp += self.k[i] * cont[i] * self.h_prev;
174            }
175
176            Ok(y_interp)
177        } else {
178            // Otherwise use cubic Hermite interpolation
179            let y_interp = cubic_hermite_interpolate(
180                self.t_prev, 
181                self.t, 
182                &self.y_prev, 
183                &self.y, 
184                &self.dydt_prev, 
185                &self.dydt, 
186                t_interp
187            );
188
189            Ok(y_interp)
190        }
191
192    }
193}