oxicuda_solver/dense/ode_pde/
pde.rs1use crate::error::{SolverError, SolverResult};
4
5use super::utils::{apply_bc_1d, solve_tridiagonal};
6
/// Uniformly spaced one-dimensional grid.
#[derive(Debug, Clone)]
pub struct Grid1D {
    /// Coordinate of the first grid point.
    pub x_min: f64,
    /// Coordinate passed as the upper end of the interval.
    pub x_max: f64,
    /// Number of grid points.
    pub nx: usize,
    /// Distance between neighbouring points; `0.0` when `nx <= 1`.
    pub dx: f64,
}

impl Grid1D {
    /// Builds a grid of `nx` points spanning `x_min..=x_max`.
    ///
    /// The spacing is `(x_max - x_min) / (nx - 1)`. A grid with zero or one
    /// point has no neighbouring points, so `dx` is stored as `0.0`.
    pub fn new(x_min: f64, x_max: f64, nx: usize) -> Self {
        let dx = match nx {
            0 | 1 => 0.0,
            points => (x_max - x_min) / (points - 1) as f64,
        };
        Self { x_min, x_max, nx, dx }
    }

    /// Returns the coordinate of grid point `i`, computed as `x_min + i * dx`.
    ///
    /// `i` is not checked against `nx`.
    pub fn point(&self, i: usize) -> f64 {
        let offset = i as f64 * self.dx;
        self.x_min + offset
    }
}
45
/// Uniformly spaced two-dimensional grid.
#[derive(Debug, Clone)]
pub struct Grid2D {
    /// Lower end of the x interval.
    pub x_min: f64,
    /// Upper end of the x interval.
    pub x_max: f64,
    /// Lower end of the y interval.
    pub y_min: f64,
    /// Upper end of the y interval.
    pub y_max: f64,
    /// Number of grid points along x.
    pub nx: usize,
    /// Number of grid points along y.
    pub ny: usize,
    /// Spacing along x; `0.0` when `nx <= 1`.
    pub dx: f64,
    /// Spacing along y; `0.0` when `ny <= 1`.
    pub dy: f64,
}

impl Grid2D {
    /// Builds a grid with `nx` points over `x_min..=x_max` and `ny` points
    /// over `y_min..=y_max`.
    ///
    /// Each spacing is `(max - min) / (n - 1)`, or `0.0` when that axis has
    /// fewer than two points.
    pub fn new(x_min: f64, x_max: f64, nx: usize, y_min: f64, y_max: f64, ny: usize) -> Self {
        Self {
            x_min,
            x_max,
            y_min,
            y_max,
            nx,
            ny,
            dx: Self::axis_step(x_min, x_max, nx),
            dy: Self::axis_step(y_min, y_max, ny),
        }
    }

    /// Spacing of `count` equally spaced points over `lo..=hi`.
    fn axis_step(lo: f64, hi: f64, count: usize) -> f64 {
        if count > 1 {
            (hi - lo) / (count - 1) as f64
        } else {
            0.0
        }
    }
}
92
/// Boundary condition attached to one end of a 1D grid.
///
/// The solvers in this file read the payload when assembling their linear
/// systems and hand the final enforcement on the solution vector to
/// `apply_bc_1d` (defined in `super::utils`, not shown here).
#[derive(Debug, Clone, Copy)]
pub enum BoundaryCondition {
    /// Dirichlet condition; the payload is used by the solvers as the value
    /// of the unknown at the boundary node.
    Dirichlet(f64),
    /// Neumann condition; the payload is presumably the prescribed boundary
    /// derivative (the Poisson solver scales it by `dx`) — confirm against
    /// `apply_bc_1d`.
    Neumann(f64),
    /// Periodic condition; carries no payload.
    Periodic,
}
103
/// Shared configuration for the 1D PDE solvers in this file.
#[derive(Debug, Clone)]
pub struct PdeConfig {
    /// Spatial grid; solvers read `grid.nx` and `grid.dx`.
    pub grid: Grid1D,
    /// Time-step size used by the time-dependent solvers
    /// (`Poisson1D::solve` does not read it).
    pub dt: f64,
    /// Number of time steps to advance (`Poisson1D::solve` does not read it).
    pub num_steps: usize,
    /// Boundary condition at the first grid point.
    pub bc_left: BoundaryCondition,
    /// Boundary condition at the last grid point.
    pub bc_right: BoundaryCondition,
}
118
119pub struct HeatEquation1D {
125 pub alpha: f64,
127}
128
129impl HeatEquation1D {
130 pub fn stability_limit(&self, dx: f64) -> f64 {
134 dx * dx / (2.0 * self.alpha)
135 }
136
137 pub fn solve_explicit(&self, u0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
139 let nx = config.grid.nx;
140 if u0.len() != nx {
141 return Err(SolverError::DimensionMismatch(format!(
142 "heat_explicit: u0 length ({}) != nx ({nx})",
143 u0.len()
144 )));
145 }
146
147 let dx = config.grid.dx;
148 let dt = config.dt;
149 let r = self.alpha * dt / (dx * dx);
150
151 let mut u = u0.to_vec();
152 let mut results = vec![u.clone()];
153
154 for _ in 0..config.num_steps {
155 let mut u_new = u.clone();
156
157 for i in 1..nx - 1 {
159 u_new[i] = u[i] + r * (u[i + 1] - 2.0 * u[i] + u[i - 1]);
160 }
161
162 apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
164
165 u = u_new;
166 results.push(u.clone());
167 }
168
169 Ok(results)
170 }
171
172 pub fn solve_implicit(&self, u0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
177 let nx = config.grid.nx;
178 if u0.len() != nx {
179 return Err(SolverError::DimensionMismatch(format!(
180 "heat_implicit: u0 length ({}) != nx ({nx})",
181 u0.len()
182 )));
183 }
184 if nx < 3 {
185 return Err(SolverError::DimensionMismatch(
186 "heat_implicit: need at least 3 grid points".to_string(),
187 ));
188 }
189
190 let dx = config.grid.dx;
191 let dt = config.dt;
192 let r = self.alpha * dt / (dx * dx);
193
194 let mut u = u0.to_vec();
195 let mut results = vec![u.clone()];
196
197 let m = nx - 2;
199
200 for _ in 0..config.num_steps {
201 let mut rhs = vec![0.0; m];
203 for (i, rhs_i) in rhs.iter_mut().enumerate() {
204 let idx = i + 1; *rhs_i = u[idx] + 0.5 * r * (u[idx + 1] - 2.0 * u[idx] + u[idx - 1]);
206 }
207
208 match config.bc_left {
210 BoundaryCondition::Dirichlet(val) => {
211 rhs[0] += 0.5 * r * val;
212 }
213 BoundaryCondition::Neumann(_) | BoundaryCondition::Periodic => {}
214 }
215 match config.bc_right {
216 BoundaryCondition::Dirichlet(val) => {
217 if m > 0 {
218 rhs[m - 1] += 0.5 * r * val;
219 }
220 }
221 BoundaryCondition::Neumann(_) | BoundaryCondition::Periodic => {}
222 }
223
224 let sub = vec![-0.5 * r; m.saturating_sub(1)];
227 let main = vec![1.0 + r; m];
228 let sup = vec![-0.5 * r; m.saturating_sub(1)];
229
230 let interior = solve_tridiagonal(&sub, &main, &sup, &rhs)?;
231
232 let mut u_new = vec![0.0; nx];
234 u_new[1..(m + 1)].copy_from_slice(&interior[..m]);
235
236 apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
237
238 u = u_new;
239 results.push(u.clone());
240 }
241
242 Ok(results)
243 }
244}
245
246pub struct WaveEquation1D {
248 pub c: f64,
250}
251
252impl WaveEquation1D {
253 pub fn courant_number(&self, dx: f64, dt: f64) -> f64 {
255 self.c * dt / dx
256 }
257
258 pub fn solve(&self, u0: &[f64], v0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
263 let nx = config.grid.nx;
264 if u0.len() != nx || v0.len() != nx {
265 return Err(SolverError::DimensionMismatch(format!(
266 "wave_solve: u0/v0 length mismatch with nx ({nx})"
267 )));
268 }
269 if nx < 3 {
270 return Err(SolverError::DimensionMismatch(
271 "wave_solve: need at least 3 grid points".to_string(),
272 ));
273 }
274
275 let dx = config.grid.dx;
276 let dt = config.dt;
277 let cfl = self.c * dt / dx;
278 let cfl2 = cfl * cfl;
279
280 let mut u_prev = u0.to_vec();
282 let mut u_cur = vec![0.0; nx];
283
284 for i in 1..nx - 1 {
286 let d2u = (u0[i + 1] - 2.0 * u0[i] + u0[i - 1]) / (dx * dx);
287 u_cur[i] = u0[i] + dt * v0[i] + 0.5 * dt * dt * self.c * self.c * d2u;
288 }
289 apply_bc_1d(&mut u_cur, &config.bc_left, &config.bc_right, nx);
290
291 let mut results = vec![u_prev.clone(), u_cur.clone()];
292
293 for _ in 1..config.num_steps {
295 let mut u_next = vec![0.0; nx];
296 for i in 1..nx - 1 {
297 u_next[i] = 2.0 * u_cur[i] - u_prev[i]
298 + cfl2 * (u_cur[i + 1] - 2.0 * u_cur[i] + u_cur[i - 1]);
299 }
300 apply_bc_1d(&mut u_next, &config.bc_left, &config.bc_right, nx);
301
302 u_prev = u_cur;
303 u_cur = u_next;
304 results.push(u_cur.clone());
305 }
306
307 Ok(results)
308 }
309}
310
311pub struct Poisson1D;
313
314impl Poisson1D {
315 pub fn solve(&self, f: &[f64], config: &PdeConfig) -> SolverResult<Vec<f64>> {
319 let nx = config.grid.nx;
320 if f.len() != nx {
321 return Err(SolverError::DimensionMismatch(format!(
322 "poisson: f length ({}) != nx ({nx})",
323 f.len()
324 )));
325 }
326 if nx < 3 {
327 return Err(SolverError::DimensionMismatch(
328 "poisson: need at least 3 grid points".to_string(),
329 ));
330 }
331
332 let dx = config.grid.dx;
333 let dx2 = dx * dx;
334 let m = nx - 2; let sub = vec![-1.0; m.saturating_sub(1)];
338 let main = vec![2.0; m];
339 let sup = vec![-1.0; m.saturating_sub(1)];
340
341 let mut rhs = vec![0.0; m];
342 for i in 0..m {
343 rhs[i] = dx2 * f[i + 1];
344 }
345
346 match config.bc_left {
348 BoundaryCondition::Dirichlet(val) => {
349 rhs[0] += val;
350 }
351 BoundaryCondition::Neumann(val) => {
352 rhs[0] += -2.0 * dx * val; }
355 BoundaryCondition::Periodic => {}
356 }
357 match config.bc_right {
358 BoundaryCondition::Dirichlet(val) => {
359 if m > 0 {
360 rhs[m - 1] += val;
361 }
362 }
363 BoundaryCondition::Neumann(val) => {
364 if m > 0 {
365 rhs[m - 1] += 2.0 * dx * val;
366 }
367 }
368 BoundaryCondition::Periodic => {}
369 }
370
371 let interior = solve_tridiagonal(&sub, &main, &sup, &rhs)?;
372
373 let mut u = vec![0.0; nx];
375 u[1..(m + 1)].copy_from_slice(&interior[..m]);
376 apply_bc_1d(&mut u, &config.bc_left, &config.bc_right, nx);
377
378 Ok(u)
379 }
380}
381
382pub struct AdvectionEquation1D {
384 pub a: f64,
386}
387
388impl AdvectionEquation1D {
389 pub fn solve_upwind(&self, u0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
391 let nx = config.grid.nx;
392 if u0.len() != nx {
393 return Err(SolverError::DimensionMismatch(format!(
394 "advection_upwind: u0 length ({}) != nx ({nx})",
395 u0.len()
396 )));
397 }
398
399 let dx = config.grid.dx;
400 let dt = config.dt;
401 let cfl = self.a * dt / dx;
402
403 let mut u = u0.to_vec();
404 let mut results = vec![u.clone()];
405
406 for _ in 0..config.num_steps {
407 let mut u_new = u.clone();
408
409 for i in 1..nx - 1 {
410 if self.a >= 0.0 {
411 u_new[i] = u[i] - cfl * (u[i] - u[i - 1]);
413 } else {
414 u_new[i] = u[i] - cfl * (u[i + 1] - u[i]);
416 }
417 }
418
419 apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
420 u = u_new;
421 results.push(u.clone());
422 }
423
424 Ok(results)
425 }
426
427 pub fn solve_lax_wendroff(
429 &self,
430 u0: &[f64],
431 config: &PdeConfig,
432 ) -> SolverResult<Vec<Vec<f64>>> {
433 let nx = config.grid.nx;
434 if u0.len() != nx {
435 return Err(SolverError::DimensionMismatch(format!(
436 "advection_lw: u0 length ({}) != nx ({nx})",
437 u0.len()
438 )));
439 }
440
441 let dx = config.grid.dx;
442 let dt = config.dt;
443 let cfl = self.a * dt / dx;
444 let cfl2 = cfl * cfl;
445
446 let mut u = u0.to_vec();
447 let mut results = vec![u.clone()];
448
449 for _ in 0..config.num_steps {
450 let mut u_new = u.clone();
451
452 for i in 1..nx - 1 {
453 u_new[i] = u[i] - 0.5 * cfl * (u[i + 1] - u[i - 1])
454 + 0.5 * cfl2 * (u[i + 1] - 2.0 * u[i] + u[i - 1]);
455 }
456
457 apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
458 u = u_new;
459 results.push(u.clone());
460 }
461
462 Ok(results)
463 }
464}