Skip to main content

oxicuda_solver/dense/ode_pde/
pde.rs

1//! PDE types and solvers: Heat, Wave, Poisson, Advection (1-D).
2
3use crate::error::{SolverError, SolverResult};
4
5use super::utils::{apply_bc_1d, solve_tridiagonal};
6
7// =========================================================================
8// PDE types
9// =========================================================================
10
/// A uniformly spaced set of points covering the closed interval `[x_min, x_max]`.
#[derive(Debug, Clone)]
pub struct Grid1D {
    /// Coordinate of the first (leftmost) grid point.
    pub x_min: f64,
    /// Coordinate of the last (rightmost) grid point.
    pub x_max: f64,
    /// Total number of grid points, both endpoints included.
    pub nx: usize,
    /// Distance between neighbouring points; `0.0` when `nx <= 1`.
    pub dx: f64,
}

impl Grid1D {
    /// Build a grid of `nx` evenly spaced points spanning `x_min..=x_max`.
    ///
    /// With fewer than two points there is no interval to divide, so `dx` is `0.0`.
    pub fn new(x_min: f64, x_max: f64, nx: usize) -> Self {
        let intervals = nx.saturating_sub(1);
        let dx = match intervals {
            0 => 0.0,
            k => (x_max - x_min) / k as f64,
        };
        Self { x_min, x_max, nx, dx }
    }

    /// Coordinate of the `i`-th grid point (`i = 0` is `x_min`).
    pub fn point(&self, i: usize) -> f64 {
        let offset = self.dx * i as f64;
        self.x_min + offset
    }
}
45
/// A tensor-product grid of uniformly spaced points on `[x_min, x_max] × [y_min, y_max]`.
#[derive(Debug, Clone)]
pub struct Grid2D {
    /// Lower end of the x interval.
    pub x_min: f64,
    /// Upper end of the x interval.
    pub x_max: f64,
    /// Lower end of the y interval.
    pub y_min: f64,
    /// Upper end of the y interval.
    pub y_max: f64,
    /// Point count along x.
    pub nx: usize,
    /// Point count along y.
    pub ny: usize,
    /// Spacing along x (`0.0` when `nx <= 1`).
    pub dx: f64,
    /// Spacing along y (`0.0` when `ny <= 1`).
    pub dy: f64,
}

impl Grid2D {
    /// Build a uniform grid with `nx` points across `x_min..=x_max` and `ny` points
    /// across `y_min..=y_max`.
    pub fn new(x_min: f64, x_max: f64, nx: usize, y_min: f64, y_max: f64, ny: usize) -> Self {
        // Spacing for `n` points covering `[lo, hi]`; degenerate axes get zero spacing.
        let spacing = |lo: f64, hi: f64, n: usize| -> f64 {
            if n <= 1 {
                0.0
            } else {
                (hi - lo) / (n - 1) as f64
            }
        };
        Self {
            x_min,
            x_max,
            y_min,
            y_max,
            nx,
            ny,
            dx: spacing(x_min, x_max, nx),
            dy: spacing(y_min, y_max, ny),
        }
    }
}
92
/// Condition imposed at one end of a 1-D domain.
#[derive(Debug, Clone, Copy)]
pub enum BoundaryCondition {
    /// The solution takes the given value at the boundary.
    Dirichlet(f64),
    /// The solution's spatial derivative takes the given value at the boundary.
    Neumann(f64),
    /// The domain wraps around, so the left and right ends are identified.
    Periodic,
}
103
104/// Configuration for a 1-D PDE solve.
105#[derive(Debug, Clone)]
106pub struct PdeConfig {
107    /// Spatial grid.
108    pub grid: Grid1D,
109    /// Time step.
110    pub dt: f64,
111    /// Number of time steps.
112    pub num_steps: usize,
113    /// Left boundary condition.
114    pub bc_left: BoundaryCondition,
115    /// Right boundary condition.
116    pub bc_right: BoundaryCondition,
117}
118
119// =========================================================================
120// PDE solvers
121// =========================================================================
122
/// Solvers for the one-dimensional heat equation `du/dt = alpha * d²u/dx²`.
pub struct HeatEquation1D {
    /// Diffusivity `alpha` multiplying the second spatial derivative.
    pub alpha: f64,
}
128
129impl HeatEquation1D {
130    /// Maximum stable time step for the explicit (FTCS) scheme.
131    ///
132    /// For stability we need dt <= dx² / (2 * alpha).
133    pub fn stability_limit(&self, dx: f64) -> f64 {
134        dx * dx / (2.0 * self.alpha)
135    }
136
137    /// Solve using Forward-Time Central-Space (FTCS) explicit scheme.
138    pub fn solve_explicit(&self, u0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
139        let nx = config.grid.nx;
140        if u0.len() != nx {
141            return Err(SolverError::DimensionMismatch(format!(
142                "heat_explicit: u0 length ({}) != nx ({nx})",
143                u0.len()
144            )));
145        }
146
147        let dx = config.grid.dx;
148        let dt = config.dt;
149        let r = self.alpha * dt / (dx * dx);
150
151        let mut u = u0.to_vec();
152        let mut results = vec![u.clone()];
153
154        for _ in 0..config.num_steps {
155            let mut u_new = u.clone();
156
157            // Interior points
158            for i in 1..nx - 1 {
159                u_new[i] = u[i] + r * (u[i + 1] - 2.0 * u[i] + u[i - 1]);
160            }
161
162            // Boundary conditions
163            apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
164
165            u = u_new;
166            results.push(u.clone());
167        }
168
169        Ok(results)
170    }
171
172    /// Solve using Crank-Nicolson (implicit) scheme.
173    ///
174    /// Unconditionally stable, second-order in both time and space.
175    /// Reduces to a tridiagonal system at each time step.
176    pub fn solve_implicit(&self, u0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
177        let nx = config.grid.nx;
178        if u0.len() != nx {
179            return Err(SolverError::DimensionMismatch(format!(
180                "heat_implicit: u0 length ({}) != nx ({nx})",
181                u0.len()
182            )));
183        }
184        if nx < 3 {
185            return Err(SolverError::DimensionMismatch(
186                "heat_implicit: need at least 3 grid points".to_string(),
187            ));
188        }
189
190        let dx = config.grid.dx;
191        let dt = config.dt;
192        let r = self.alpha * dt / (dx * dx);
193
194        let mut u = u0.to_vec();
195        let mut results = vec![u.clone()];
196
197        // Interior system size
198        let m = nx - 2;
199
200        for _ in 0..config.num_steps {
201            // Build RHS from explicit half: (I + r/2 * A) * u_interior
202            let mut rhs = vec![0.0; m];
203            for (i, rhs_i) in rhs.iter_mut().enumerate() {
204                let idx = i + 1; // grid index
205                *rhs_i = u[idx] + 0.5 * r * (u[idx + 1] - 2.0 * u[idx] + u[idx - 1]);
206            }
207
208            // Add boundary contributions
209            match config.bc_left {
210                BoundaryCondition::Dirichlet(val) => {
211                    rhs[0] += 0.5 * r * val;
212                }
213                BoundaryCondition::Neumann(_) | BoundaryCondition::Periodic => {}
214            }
215            match config.bc_right {
216                BoundaryCondition::Dirichlet(val) => {
217                    if m > 0 {
218                        rhs[m - 1] += 0.5 * r * val;
219                    }
220                }
221                BoundaryCondition::Neumann(_) | BoundaryCondition::Periodic => {}
222            }
223
224            // Tridiagonal system: (I - r/2 * A) * u_new = rhs
225            // sub-diag: -r/2, main: 1+r, super-diag: -r/2
226            let sub = vec![-0.5 * r; m.saturating_sub(1)];
227            let main = vec![1.0 + r; m];
228            let sup = vec![-0.5 * r; m.saturating_sub(1)];
229
230            let interior = solve_tridiagonal(&sub, &main, &sup, &rhs)?;
231
232            // Assemble full solution
233            let mut u_new = vec![0.0; nx];
234            u_new[1..(m + 1)].copy_from_slice(&interior[..m]);
235
236            apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
237
238            u = u_new;
239            results.push(u.clone());
240        }
241
242        Ok(results)
243    }
244}
245
/// Solver for the one-dimensional wave equation `d²u/dt² = c² * d²u/dx²`.
pub struct WaveEquation1D {
    /// Propagation speed `c`.
    pub c: f64,
}
251
252impl WaveEquation1D {
253    /// Compute the Courant number: c * dt / dx.
254    pub fn courant_number(&self, dx: f64, dt: f64) -> f64 {
255        self.c * dt / dx
256    }
257
258    /// Solve using the leapfrog / Störmer-Verlet scheme.
259    ///
260    /// `u0` is the initial displacement, `v0` the initial velocity.
261    /// Stability requires Courant number <= 1.
262    pub fn solve(&self, u0: &[f64], v0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
263        let nx = config.grid.nx;
264        if u0.len() != nx || v0.len() != nx {
265            return Err(SolverError::DimensionMismatch(format!(
266                "wave_solve: u0/v0 length mismatch with nx ({nx})"
267            )));
268        }
269        if nx < 3 {
270            return Err(SolverError::DimensionMismatch(
271                "wave_solve: need at least 3 grid points".to_string(),
272            ));
273        }
274
275        let dx = config.grid.dx;
276        let dt = config.dt;
277        let cfl = self.c * dt / dx;
278        let cfl2 = cfl * cfl;
279
280        // u^{n-1} and u^{n}
281        let mut u_prev = u0.to_vec();
282        let mut u_cur = vec![0.0; nx];
283
284        // First step uses Taylor expansion: u^1 = u^0 + dt*v0 + 0.5*dt²*c²*u''
285        for i in 1..nx - 1 {
286            let d2u = (u0[i + 1] - 2.0 * u0[i] + u0[i - 1]) / (dx * dx);
287            u_cur[i] = u0[i] + dt * v0[i] + 0.5 * dt * dt * self.c * self.c * d2u;
288        }
289        apply_bc_1d(&mut u_cur, &config.bc_left, &config.bc_right, nx);
290
291        let mut results = vec![u_prev.clone(), u_cur.clone()];
292
293        // Leapfrog: u^{n+1} = 2*u^n - u^{n-1} + cfl² * (u_{i+1} - 2*u_i + u_{i-1})
294        for _ in 1..config.num_steps {
295            let mut u_next = vec![0.0; nx];
296            for i in 1..nx - 1 {
297                u_next[i] = 2.0 * u_cur[i] - u_prev[i]
298                    + cfl2 * (u_cur[i + 1] - 2.0 * u_cur[i] + u_cur[i - 1]);
299            }
300            apply_bc_1d(&mut u_next, &config.bc_left, &config.bc_right, nx);
301
302            u_prev = u_cur;
303            u_cur = u_next;
304            results.push(u_cur.clone());
305        }
306
307        Ok(results)
308    }
309}
310
/// Direct solver for the one-dimensional Poisson problem `-u'' = f` with boundary conditions.
pub struct Poisson1D;
313
314impl Poisson1D {
315    /// Solve -u'' = f on the grid with specified boundary conditions.
316    ///
317    /// Uses a tridiagonal direct solve (Thomas algorithm).
318    pub fn solve(&self, f: &[f64], config: &PdeConfig) -> SolverResult<Vec<f64>> {
319        let nx = config.grid.nx;
320        if f.len() != nx {
321            return Err(SolverError::DimensionMismatch(format!(
322                "poisson: f length ({}) != nx ({nx})",
323                f.len()
324            )));
325        }
326        if nx < 3 {
327            return Err(SolverError::DimensionMismatch(
328                "poisson: need at least 3 grid points".to_string(),
329            ));
330        }
331
332        let dx = config.grid.dx;
333        let dx2 = dx * dx;
334        let m = nx - 2; // interior points
335
336        // Build tridiagonal system: -u_{i-1} + 2*u_i - u_{i+1} = dx²*f_i
337        let sub = vec![-1.0; m.saturating_sub(1)];
338        let main = vec![2.0; m];
339        let sup = vec![-1.0; m.saturating_sub(1)];
340
341        let mut rhs = vec![0.0; m];
342        for i in 0..m {
343            rhs[i] = dx2 * f[i + 1];
344        }
345
346        // Add boundary contributions
347        match config.bc_left {
348            BoundaryCondition::Dirichlet(val) => {
349                rhs[0] += val;
350            }
351            BoundaryCondition::Neumann(val) => {
352                // Ghost point approach: u_{-1} = u_1 - 2*dx*val
353                rhs[0] += -2.0 * dx * val; // approximate
354            }
355            BoundaryCondition::Periodic => {}
356        }
357        match config.bc_right {
358            BoundaryCondition::Dirichlet(val) => {
359                if m > 0 {
360                    rhs[m - 1] += val;
361                }
362            }
363            BoundaryCondition::Neumann(val) => {
364                if m > 0 {
365                    rhs[m - 1] += 2.0 * dx * val;
366                }
367            }
368            BoundaryCondition::Periodic => {}
369        }
370
371        let interior = solve_tridiagonal(&sub, &main, &sup, &rhs)?;
372
373        // Assemble full solution
374        let mut u = vec![0.0; nx];
375        u[1..(m + 1)].copy_from_slice(&interior[..m]);
376        apply_bc_1d(&mut u, &config.bc_left, &config.bc_right, nx);
377
378        Ok(u)
379    }
380}
381
/// Solvers for the one-dimensional linear advection equation `du/dt + a * du/dx = 0`.
pub struct AdvectionEquation1D {
    /// Transport velocity `a`; its sign gives the direction of propagation.
    pub a: f64,
}
387
388impl AdvectionEquation1D {
389    /// Solve using the first-order upwind scheme.
390    pub fn solve_upwind(&self, u0: &[f64], config: &PdeConfig) -> SolverResult<Vec<Vec<f64>>> {
391        let nx = config.grid.nx;
392        if u0.len() != nx {
393            return Err(SolverError::DimensionMismatch(format!(
394                "advection_upwind: u0 length ({}) != nx ({nx})",
395                u0.len()
396            )));
397        }
398
399        let dx = config.grid.dx;
400        let dt = config.dt;
401        let cfl = self.a * dt / dx;
402
403        let mut u = u0.to_vec();
404        let mut results = vec![u.clone()];
405
406        for _ in 0..config.num_steps {
407            let mut u_new = u.clone();
408
409            for i in 1..nx - 1 {
410                if self.a >= 0.0 {
411                    // Upwind from left
412                    u_new[i] = u[i] - cfl * (u[i] - u[i - 1]);
413                } else {
414                    // Upwind from right
415                    u_new[i] = u[i] - cfl * (u[i + 1] - u[i]);
416                }
417            }
418
419            apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
420            u = u_new;
421            results.push(u.clone());
422        }
423
424        Ok(results)
425    }
426
427    /// Solve using the Lax-Wendroff scheme (second-order).
428    pub fn solve_lax_wendroff(
429        &self,
430        u0: &[f64],
431        config: &PdeConfig,
432    ) -> SolverResult<Vec<Vec<f64>>> {
433        let nx = config.grid.nx;
434        if u0.len() != nx {
435            return Err(SolverError::DimensionMismatch(format!(
436                "advection_lw: u0 length ({}) != nx ({nx})",
437                u0.len()
438            )));
439        }
440
441        let dx = config.grid.dx;
442        let dt = config.dt;
443        let cfl = self.a * dt / dx;
444        let cfl2 = cfl * cfl;
445
446        let mut u = u0.to_vec();
447        let mut results = vec![u.clone()];
448
449        for _ in 0..config.num_steps {
450            let mut u_new = u.clone();
451
452            for i in 1..nx - 1 {
453                u_new[i] = u[i] - 0.5 * cfl * (u[i + 1] - u[i - 1])
454                    + 0.5 * cfl2 * (u[i + 1] - 2.0 * u[i] + u[i - 1]);
455            }
456
457            apply_bc_1d(&mut u_new, &config.bc_left, &config.bc_right, nx);
458            u = u_new;
459            results.push(u.clone());
460        }
461
462        Ok(results)
463    }
464}