Skip to main content

oxicuda_solver/dense/ode_pde/
explicit.rs

1//! Explicit ODE solvers: Euler, RK4, RK45.
2
3use crate::error::{SolverError, SolverResult};
4
5use super::types::{OdeConfig, OdeSolution, OdeSystem};
6use super::utils::validate_ode_inputs;
7
8// =========================================================================
9// Forward Euler
10// =========================================================================
11
12/// Forward Euler solver.
13pub struct EulerSolver;
14
15impl EulerSolver {
16    /// Integrate the ODE system using the forward Euler method.
17    pub fn solve(
18        system: &dyn OdeSystem,
19        y0: &[f64],
20        config: &OdeConfig,
21    ) -> SolverResult<OdeSolution> {
22        let n = system.dim();
23        validate_ode_inputs(n, y0, config)?;
24
25        let mut t = config.t_start;
26        let dt = config.dt;
27        let mut y = y0.to_vec();
28        let mut k = vec![0.0; n];
29
30        let mut times = vec![t];
31        let mut states = vec![y.clone()];
32        let mut num_steps = 0_usize;
33        let mut num_rhs = 0_usize;
34
35        while t < config.t_end - dt * 1e-10 && num_steps < config.max_steps {
36            let h = dt.min(config.t_end - t);
37            system.rhs(t, &y, &mut k)?;
38            num_rhs += 1;
39
40            for i in 0..n {
41                y[i] += h * k[i];
42            }
43            t += h;
44            num_steps += 1;
45
46            times.push(t);
47            states.push(y.clone());
48        }
49
50        Ok(OdeSolution {
51            times,
52            states,
53            num_steps,
54            num_rejected: 0,
55            num_rhs_evals: num_rhs,
56        })
57    }
58}
59
60// =========================================================================
61// Classical RK4
62// =========================================================================
63
64/// Classical fourth-order Runge-Kutta solver.
65pub struct Rk4Solver;
66
67impl Rk4Solver {
68    /// Integrate the ODE system using classical RK4.
69    pub fn solve(
70        system: &dyn OdeSystem,
71        y0: &[f64],
72        config: &OdeConfig,
73    ) -> SolverResult<OdeSolution> {
74        let n = system.dim();
75        validate_ode_inputs(n, y0, config)?;
76
77        let mut t = config.t_start;
78        let dt = config.dt;
79        let mut y = y0.to_vec();
80
81        let mut k1 = vec![0.0; n];
82        let mut k2 = vec![0.0; n];
83        let mut k3 = vec![0.0; n];
84        let mut k4 = vec![0.0; n];
85        let mut tmp = vec![0.0; n];
86
87        let mut times = vec![t];
88        let mut states = vec![y.clone()];
89        let mut num_steps = 0_usize;
90        let mut num_rhs = 0_usize;
91
92        while t < config.t_end - dt * 1e-10 && num_steps < config.max_steps {
93            let h = dt.min(config.t_end - t);
94
95            // k1
96            system.rhs(t, &y, &mut k1)?;
97            num_rhs += 1;
98
99            // k2
100            for i in 0..n {
101                tmp[i] = y[i] + 0.5 * h * k1[i];
102            }
103            system.rhs(t + 0.5 * h, &tmp, &mut k2)?;
104            num_rhs += 1;
105
106            // k3
107            for i in 0..n {
108                tmp[i] = y[i] + 0.5 * h * k2[i];
109            }
110            system.rhs(t + 0.5 * h, &tmp, &mut k3)?;
111            num_rhs += 1;
112
113            // k4
114            for i in 0..n {
115                tmp[i] = y[i] + h * k3[i];
116            }
117            system.rhs(t + h, &tmp, &mut k4)?;
118            num_rhs += 1;
119
120            // Combine
121            for i in 0..n {
122                y[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
123            }
124            t += h;
125            num_steps += 1;
126
127            times.push(t);
128            states.push(y.clone());
129        }
130
131        Ok(OdeSolution {
132            times,
133            states,
134            num_steps,
135            num_rejected: 0,
136            num_rhs_evals: num_rhs,
137        })
138    }
139}
140
141// =========================================================================
142// Dormand-Prince RK45 (adaptive)
143// =========================================================================
144
145/// Dormand-Prince 4(5) adaptive solver (RK45).
146pub struct Rk45Solver;
147
148impl Rk45Solver {
149    // Dormand-Prince Butcher tableau (DOPRI5) coefficients.
150    const A21: f64 = 1.0 / 5.0;
151    const A31: f64 = 3.0 / 40.0;
152    const A32: f64 = 9.0 / 40.0;
153    const A41: f64 = 44.0 / 45.0;
154    const A42: f64 = -56.0 / 15.0;
155    const A43: f64 = 32.0 / 9.0;
156    const A51: f64 = 19372.0 / 6561.0;
157    const A52: f64 = -25360.0 / 2187.0;
158    const A53: f64 = 64448.0 / 6561.0;
159    const A54: f64 = -212.0 / 729.0;
160    const A61: f64 = 9017.0 / 3168.0;
161    const A62: f64 = -355.0 / 33.0;
162    const A63: f64 = 46732.0 / 5247.0;
163    const A64: f64 = 49.0 / 176.0;
164    const A65: f64 = -5103.0 / 18656.0;
165
166    // 5th-order weights (for the solution)
167    const B1: f64 = 35.0 / 384.0;
168    // B2 = 0
169    const B3: f64 = 500.0 / 1113.0;
170    const B4: f64 = 125.0 / 192.0;
171    const B5: f64 = -2187.0 / 6784.0;
172    const B6: f64 = 11.0 / 84.0;
173
174    // 4th-order weights (for error estimation)
175    const E1: f64 = 71.0 / 57600.0;
176    // E2 = 0
177    const E3: f64 = -71.0 / 16695.0;
178    const E4: f64 = 71.0 / 1920.0;
179    const E5: f64 = -17253.0 / 339200.0;
180    const E6: f64 = 22.0 / 525.0;
181    const E7: f64 = -1.0 / 40.0;
182
183    /// Integrate with adaptive step-size control.
184    pub fn solve(
185        system: &dyn OdeSystem,
186        y0: &[f64],
187        config: &OdeConfig,
188    ) -> SolverResult<OdeSolution> {
189        let n = system.dim();
190        validate_ode_inputs(n, y0, config)?;
191
192        let mut t = config.t_start;
193        let mut h = config.dt;
194        let mut y = y0.to_vec();
195
196        let mut k1 = vec![0.0; n];
197        let mut k2 = vec![0.0; n];
198        let mut k3 = vec![0.0; n];
199        let mut k4 = vec![0.0; n];
200        let mut k5 = vec![0.0; n];
201        let mut k6 = vec![0.0; n];
202        let mut k7 = vec![0.0; n];
203        let mut tmp = vec![0.0; n];
204        let mut y_new = vec![0.0; n];
205
206        let mut times = vec![t];
207        let mut states = vec![y.clone()];
208        let mut num_steps = 0_usize;
209        let mut num_rejected = 0_usize;
210        let mut num_rhs = 0_usize;
211
212        // Safety factor and step size bounds
213        let safety = 0.9;
214        let min_factor = 0.2;
215        let max_factor = 5.0;
216
217        system.rhs(t, &y, &mut k1)?;
218        num_rhs += 1;
219
220        while t < config.t_end - 1e-14 * config.t_end.abs().max(1.0)
221            && num_steps + num_rejected < config.max_steps
222        {
223            h = h.min(config.t_end - t);
224
225            // Stage 2
226            for i in 0..n {
227                tmp[i] = y[i] + h * Self::A21 * k1[i];
228            }
229            system.rhs(t + h / 5.0, &tmp, &mut k2)?;
230
231            // Stage 3
232            for i in 0..n {
233                tmp[i] = y[i] + h * (Self::A31 * k1[i] + Self::A32 * k2[i]);
234            }
235            system.rhs(t + 3.0 / 10.0 * h, &tmp, &mut k3)?;
236
237            // Stage 4
238            for i in 0..n {
239                tmp[i] = y[i] + h * (Self::A41 * k1[i] + Self::A42 * k2[i] + Self::A43 * k3[i]);
240            }
241            system.rhs(t + 4.0 / 5.0 * h, &tmp, &mut k4)?;
242
243            // Stage 5
244            for i in 0..n {
245                tmp[i] = y[i]
246                    + h * (Self::A51 * k1[i]
247                        + Self::A52 * k2[i]
248                        + Self::A53 * k3[i]
249                        + Self::A54 * k4[i]);
250            }
251            system.rhs(t + 8.0 / 9.0 * h, &tmp, &mut k5)?;
252
253            // Stage 6
254            for i in 0..n {
255                tmp[i] = y[i]
256                    + h * (Self::A61 * k1[i]
257                        + Self::A62 * k2[i]
258                        + Self::A63 * k3[i]
259                        + Self::A64 * k4[i]
260                        + Self::A65 * k5[i]);
261            }
262            system.rhs(t + h, &tmp, &mut k6)?;
263
264            num_rhs += 5;
265
266            // 5th-order solution
267            for i in 0..n {
268                y_new[i] = y[i]
269                    + h * (Self::B1 * k1[i]
270                        + Self::B3 * k3[i]
271                        + Self::B4 * k4[i]
272                        + Self::B5 * k5[i]
273                        + Self::B6 * k6[i]);
274            }
275
276            // Error estimate (difference between 5th and 4th order)
277            // We need k7 for the error estimate
278            system.rhs(t + h, &y_new, &mut k7)?;
279            num_rhs += 1;
280
281            let mut err_norm = 0.0;
282            for i in 0..n {
283                let err_i = h
284                    * (Self::E1 * k1[i]
285                        + Self::E3 * k3[i]
286                        + Self::E4 * k4[i]
287                        + Self::E5 * k5[i]
288                        + Self::E6 * k6[i]
289                        + Self::E7 * k7[i]);
290                let scale = config.atol + config.rtol * y_new[i].abs().max(y[i].abs());
291                err_norm += (err_i / scale).powi(2);
292            }
293            err_norm = (err_norm / n as f64).sqrt();
294
295            if err_norm <= 1.0 {
296                // Accept step
297                t += h;
298                y.copy_from_slice(&y_new);
299                num_steps += 1;
300
301                times.push(t);
302                states.push(y.clone());
303
304                // FSAL: reuse k7 as k1 for next step
305                k1.copy_from_slice(&k7);
306
307                // Increase step size
308                let factor = if err_norm > 1e-15 {
309                    (safety / err_norm.powf(0.2)).clamp(min_factor, max_factor)
310                } else {
311                    max_factor
312                };
313                h *= factor;
314            } else {
315                // Reject step
316                num_rejected += 1;
317                let factor = (safety / err_norm.powf(0.2)).clamp(min_factor, 1.0);
318                h *= factor;
319            }
320        }
321
322        if num_steps + num_rejected >= config.max_steps && t < config.t_end - 1e-10 {
323            return Err(SolverError::ConvergenceFailure {
324                iterations: config.max_steps as u32,
325                residual: (config.t_end - t).abs(),
326            });
327        }
328
329        Ok(OdeSolution {
330            times,
331            states,
332            num_steps,
333            num_rejected,
334            num_rhs_evals: num_rhs,
335        })
336    }
337}