// oxicuda_solver/dense/ode_pde/explicit.rs
use crate::error::{SolverError, SolverResult};

use super::types::{OdeConfig, OdeSolution, OdeSystem};
use super::utils::validate_ode_inputs;
8pub struct EulerSolver;
14
15impl EulerSolver {
16 pub fn solve(
18 system: &dyn OdeSystem,
19 y0: &[f64],
20 config: &OdeConfig,
21 ) -> SolverResult<OdeSolution> {
22 let n = system.dim();
23 validate_ode_inputs(n, y0, config)?;
24
25 let mut t = config.t_start;
26 let dt = config.dt;
27 let mut y = y0.to_vec();
28 let mut k = vec![0.0; n];
29
30 let mut times = vec![t];
31 let mut states = vec![y.clone()];
32 let mut num_steps = 0_usize;
33 let mut num_rhs = 0_usize;
34
35 while t < config.t_end - dt * 1e-10 && num_steps < config.max_steps {
36 let h = dt.min(config.t_end - t);
37 system.rhs(t, &y, &mut k)?;
38 num_rhs += 1;
39
40 for i in 0..n {
41 y[i] += h * k[i];
42 }
43 t += h;
44 num_steps += 1;
45
46 times.push(t);
47 states.push(y.clone());
48 }
49
50 Ok(OdeSolution {
51 times,
52 states,
53 num_steps,
54 num_rejected: 0,
55 num_rhs_evals: num_rhs,
56 })
57 }
58}
59
60pub struct Rk4Solver;
66
67impl Rk4Solver {
68 pub fn solve(
70 system: &dyn OdeSystem,
71 y0: &[f64],
72 config: &OdeConfig,
73 ) -> SolverResult<OdeSolution> {
74 let n = system.dim();
75 validate_ode_inputs(n, y0, config)?;
76
77 let mut t = config.t_start;
78 let dt = config.dt;
79 let mut y = y0.to_vec();
80
81 let mut k1 = vec![0.0; n];
82 let mut k2 = vec![0.0; n];
83 let mut k3 = vec![0.0; n];
84 let mut k4 = vec![0.0; n];
85 let mut tmp = vec![0.0; n];
86
87 let mut times = vec![t];
88 let mut states = vec![y.clone()];
89 let mut num_steps = 0_usize;
90 let mut num_rhs = 0_usize;
91
92 while t < config.t_end - dt * 1e-10 && num_steps < config.max_steps {
93 let h = dt.min(config.t_end - t);
94
95 system.rhs(t, &y, &mut k1)?;
97 num_rhs += 1;
98
99 for i in 0..n {
101 tmp[i] = y[i] + 0.5 * h * k1[i];
102 }
103 system.rhs(t + 0.5 * h, &tmp, &mut k2)?;
104 num_rhs += 1;
105
106 for i in 0..n {
108 tmp[i] = y[i] + 0.5 * h * k2[i];
109 }
110 system.rhs(t + 0.5 * h, &tmp, &mut k3)?;
111 num_rhs += 1;
112
113 for i in 0..n {
115 tmp[i] = y[i] + h * k3[i];
116 }
117 system.rhs(t + h, &tmp, &mut k4)?;
118 num_rhs += 1;
119
120 for i in 0..n {
122 y[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
123 }
124 t += h;
125 num_steps += 1;
126
127 times.push(t);
128 states.push(y.clone());
129 }
130
131 Ok(OdeSolution {
132 times,
133 states,
134 num_steps,
135 num_rejected: 0,
136 num_rhs_evals: num_rhs,
137 })
138 }
139}
140
141pub struct Rk45Solver;
147
148impl Rk45Solver {
149 const A21: f64 = 1.0 / 5.0;
151 const A31: f64 = 3.0 / 40.0;
152 const A32: f64 = 9.0 / 40.0;
153 const A41: f64 = 44.0 / 45.0;
154 const A42: f64 = -56.0 / 15.0;
155 const A43: f64 = 32.0 / 9.0;
156 const A51: f64 = 19372.0 / 6561.0;
157 const A52: f64 = -25360.0 / 2187.0;
158 const A53: f64 = 64448.0 / 6561.0;
159 const A54: f64 = -212.0 / 729.0;
160 const A61: f64 = 9017.0 / 3168.0;
161 const A62: f64 = -355.0 / 33.0;
162 const A63: f64 = 46732.0 / 5247.0;
163 const A64: f64 = 49.0 / 176.0;
164 const A65: f64 = -5103.0 / 18656.0;
165
166 const B1: f64 = 35.0 / 384.0;
168 const B3: f64 = 500.0 / 1113.0;
170 const B4: f64 = 125.0 / 192.0;
171 const B5: f64 = -2187.0 / 6784.0;
172 const B6: f64 = 11.0 / 84.0;
173
174 const E1: f64 = 71.0 / 57600.0;
176 const E3: f64 = -71.0 / 16695.0;
178 const E4: f64 = 71.0 / 1920.0;
179 const E5: f64 = -17253.0 / 339200.0;
180 const E6: f64 = 22.0 / 525.0;
181 const E7: f64 = -1.0 / 40.0;
182
183 pub fn solve(
185 system: &dyn OdeSystem,
186 y0: &[f64],
187 config: &OdeConfig,
188 ) -> SolverResult<OdeSolution> {
189 let n = system.dim();
190 validate_ode_inputs(n, y0, config)?;
191
192 let mut t = config.t_start;
193 let mut h = config.dt;
194 let mut y = y0.to_vec();
195
196 let mut k1 = vec![0.0; n];
197 let mut k2 = vec![0.0; n];
198 let mut k3 = vec![0.0; n];
199 let mut k4 = vec![0.0; n];
200 let mut k5 = vec![0.0; n];
201 let mut k6 = vec![0.0; n];
202 let mut k7 = vec![0.0; n];
203 let mut tmp = vec![0.0; n];
204 let mut y_new = vec![0.0; n];
205
206 let mut times = vec![t];
207 let mut states = vec![y.clone()];
208 let mut num_steps = 0_usize;
209 let mut num_rejected = 0_usize;
210 let mut num_rhs = 0_usize;
211
212 let safety = 0.9;
214 let min_factor = 0.2;
215 let max_factor = 5.0;
216
217 system.rhs(t, &y, &mut k1)?;
218 num_rhs += 1;
219
220 while t < config.t_end - 1e-14 * config.t_end.abs().max(1.0)
221 && num_steps + num_rejected < config.max_steps
222 {
223 h = h.min(config.t_end - t);
224
225 for i in 0..n {
227 tmp[i] = y[i] + h * Self::A21 * k1[i];
228 }
229 system.rhs(t + h / 5.0, &tmp, &mut k2)?;
230
231 for i in 0..n {
233 tmp[i] = y[i] + h * (Self::A31 * k1[i] + Self::A32 * k2[i]);
234 }
235 system.rhs(t + 3.0 / 10.0 * h, &tmp, &mut k3)?;
236
237 for i in 0..n {
239 tmp[i] = y[i] + h * (Self::A41 * k1[i] + Self::A42 * k2[i] + Self::A43 * k3[i]);
240 }
241 system.rhs(t + 4.0 / 5.0 * h, &tmp, &mut k4)?;
242
243 for i in 0..n {
245 tmp[i] = y[i]
246 + h * (Self::A51 * k1[i]
247 + Self::A52 * k2[i]
248 + Self::A53 * k3[i]
249 + Self::A54 * k4[i]);
250 }
251 system.rhs(t + 8.0 / 9.0 * h, &tmp, &mut k5)?;
252
253 for i in 0..n {
255 tmp[i] = y[i]
256 + h * (Self::A61 * k1[i]
257 + Self::A62 * k2[i]
258 + Self::A63 * k3[i]
259 + Self::A64 * k4[i]
260 + Self::A65 * k5[i]);
261 }
262 system.rhs(t + h, &tmp, &mut k6)?;
263
264 num_rhs += 5;
265
266 for i in 0..n {
268 y_new[i] = y[i]
269 + h * (Self::B1 * k1[i]
270 + Self::B3 * k3[i]
271 + Self::B4 * k4[i]
272 + Self::B5 * k5[i]
273 + Self::B6 * k6[i]);
274 }
275
276 system.rhs(t + h, &y_new, &mut k7)?;
279 num_rhs += 1;
280
281 let mut err_norm = 0.0;
282 for i in 0..n {
283 let err_i = h
284 * (Self::E1 * k1[i]
285 + Self::E3 * k3[i]
286 + Self::E4 * k4[i]
287 + Self::E5 * k5[i]
288 + Self::E6 * k6[i]
289 + Self::E7 * k7[i]);
290 let scale = config.atol + config.rtol * y_new[i].abs().max(y[i].abs());
291 err_norm += (err_i / scale).powi(2);
292 }
293 err_norm = (err_norm / n as f64).sqrt();
294
295 if err_norm <= 1.0 {
296 t += h;
298 y.copy_from_slice(&y_new);
299 num_steps += 1;
300
301 times.push(t);
302 states.push(y.clone());
303
304 k1.copy_from_slice(&k7);
306
307 let factor = if err_norm > 1e-15 {
309 (safety / err_norm.powf(0.2)).clamp(min_factor, max_factor)
310 } else {
311 max_factor
312 };
313 h *= factor;
314 } else {
315 num_rejected += 1;
317 let factor = (safety / err_norm.powf(0.2)).clamp(min_factor, 1.0);
318 h *= factor;
319 }
320 }
321
322 if num_steps + num_rejected >= config.max_steps && t < config.t_end - 1e-10 {
323 return Err(SolverError::ConvergenceFailure {
324 iterations: config.max_steps as u32,
325 residual: (config.t_end - t).abs(),
326 });
327 }
328
329 Ok(OdeSolution {
330 times,
331 states,
332 num_steps,
333 num_rejected,
334 num_rhs_evals: num_rhs,
335 })
336 }
337}