Skip to main content

differential_equations/methods/erk/fixed/
ordinary.rs

//! Fixed-step explicit Runge-Kutta methods for ordinary differential equations (ODEs).
2use crate::{
3    error::Error,
4    interpolate::{Interpolation, cubic_hermite_interpolate},
5    methods::{ExplicitRungeKutta, Fixed, Ordinary},
6    ode::{ODE, OrdinaryNumericalMethod},
7    stats::Evals,
8    status::Status,
9    traits::{Real, State},
10    utils::validate_step_size_parameters,
11};
12
13impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
14    OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
15{
16    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
17    where
18        F: ODE<T, Y> + ?Sized,
19    {
20        let mut evals = Evals::new();
21
22        // If h0 is zero, calculate initial step size for fixed-step methods
23        if self.h0 == T::zero() {
24            // Simple default step size for fixed-step methods
25            let duration = (tf - t0).abs();
26            let default_steps = T::from_usize(100).unwrap();
27            self.h0 = duration / default_steps;
28        }
29
30        // Check bounds
31        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
32            Ok(h0) => self.h = h0,
33            Err(status) => return Err(status),
34        } // Initialize Statistics
35
36        // Initialize State
37        self.t = t0;
38        self.y = y0.clone();
39        self.dydt = y0.zeros_like();
40        self.y_prev = y0.clone();
41        self.dydt_prev = y0.zeros_like();
42        self.k = core::array::from_fn(|_| y0.zeros_like());
43        self.cont = core::array::from_fn(|_| y0.zeros_like());
44        ode.diff(self.t, &self.y, &mut self.dydt);
45        evals.function += 1;
46
47        // Initialize previous state
48        self.t_prev = self.t;
49        self.y_prev = self.y.clone();
50        self.dydt_prev = self.dydt.clone();
51
52        // Initialize Status
53        self.status = Status::Initialized;
54
55        Ok(evals)
56    }
57
58    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
59    where
60        F: ODE<T, Y> + ?Sized,
61    {
62        let mut evals = Evals::new();
63
64        // Check max steps
65        if self.steps >= self.max_steps {
66            self.status = Status::Error(Error::MaxSteps {
67                t: self.t,
68                y: self.y.clone(),
69            });
70            return Err(Error::MaxSteps {
71                t: self.t,
72                y: self.y.clone(),
73            });
74        }
75        self.steps += 1;
76
77        // Save k[0] as the current derivative
78        self.k[0] = self.dydt.clone();
79
80        // Compute stages
81        for i in 1..self.stages {
82            let mut y_stage = self.y.clone();
83
84            for j in 0..i {
85                y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
86            }
87
88            ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
89        }
90        evals.function += self.stages - 1; // We already have k[0]
91
92        // Store current state before update for interpolation
93        self.t_prev = self.t;
94        self.y_prev = self.y.clone();
95        self.dydt_prev = self.k[0].clone();
96        self.h_prev = self.h;
97
98        // Compute solution
99        let mut y_next = self.y.clone();
100        for i in 0..self.stages {
101            y_next.add_scaled(self.b[i] * self.h, &self.k[i]);
102        }
103
104        // If method has dense output stages, compute them
105        if self.bi.is_some() {
106            // Compute extra stages for dense output
107            for i in 0..(I - S) {
108                let mut y_stage = self.y.clone();
109                for j in 0..self.stages + i {
110                    y_stage.add_scaled(self.a[self.stages + i][j] * self.h, &self.k[j]);
111                }
112
113                ode.diff(
114                    self.t + self.c[self.stages + i] * self.h,
115                    &y_stage,
116                    &mut self.k[self.stages + i],
117                );
118            }
119            evals.function += I - S;
120        }
121
122        // Update state
123        self.t += self.h;
124        self.y = y_next;
125
126        // Calculate new derivative for next step
127        if self.fsal {
128            // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
129            self.dydt = self.k[S - 1].clone();
130        } else {
131            // Otherwise, compute the new derivative
132            ode.diff(self.t, &self.y, &mut self.dydt);
133            evals.function += 1;
134        }
135
136        self.status = Status::Solving;
137        Ok(evals)
138    }
139
140    fn t(&self) -> T {
141        self.t
142    }
143    fn y(&self) -> &Y {
144        &self.y
145    }
146    fn t_prev(&self) -> T {
147        self.t_prev
148    }
149    fn y_prev(&self) -> &Y {
150        &self.y_prev
151    }
152    fn h(&self) -> T {
153        self.h
154    }
155    fn set_h(&mut self, h: T) {
156        self.h = h;
157    }
158    fn status(&self) -> &Status<T, Y> {
159        &self.status
160    }
161    fn set_status(&mut self, status: Status<T, Y>) {
162        self.status = status;
163    }
164}
165
166impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
167    for ExplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
168{
169    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
170        // Check if t is within bounds
171        if t_interp < self.t_prev || t_interp > self.t {
172            return Err(Error::OutOfBounds {
173                t_interp,
174                t_prev: self.t_prev,
175                t_curr: self.t,
176            });
177        }
178
179        // If method has dense output coefficients, use them
180        if let Some(bi) = self.bi.as_ref() {
181            // Calculate the normalized distance within the step [0, 1]
182            let s = (t_interp - self.t_prev) / self.h_prev;
183
184            let mut cont = [T::zero(); I];
185            // Compute the interpolation coefficients using Horner's method
186            for i in 0..self.dense_stages {
187                // Start with the highest-order term
188                cont[i] = bi[i][self.order - 1];
189
190                // Apply Horner's method
191                for j in (0..self.order - 1).rev() {
192                    cont[i] = cont[i] * s + bi[i][j];
193                }
194
195                // Multiply by s
196                cont[i] *= s;
197            }
198
199            // Compute the interpolated value
200            let mut y_interp = self.y_prev.clone();
201            for i in 0..I {
202                y_interp.add_scaled(cont[i] * self.h_prev, &self.k[i]);
203            }
204
205            Ok(y_interp)
206        } else {
207            // Otherwise use cubic Hermite interpolation
208            let y_interp = cubic_hermite_interpolate(
209                self.t_prev,
210                self.t,
211                &self.y_prev,
212                &self.y,
213                &self.dydt_prev,
214                &self.dydt,
215                t_interp,
216            );
217
218            Ok(y_interp)
219        }
220    }
221}