Skip to main content

oxicuda_solver/dense/ode_pde/
mod.rs

//! ODE and PDE solver kernels.
//!
//! Provides CPU-side implementations of numerical methods for ordinary and
//! partial differential equations. In production these algorithms would generate
//! PTX kernels for GPU execution; the CPU reference implementations here ensure
//! correctness and serve as the algorithmic specification.
//!
//! ## ODE solvers
//!
//! - **Euler** — first-order forward Euler
//! - **RK4** — classical fourth-order Runge-Kutta
//! - **RK45** — Dormand-Prince 4(5) with adaptive step-size control
//! - **Implicit Euler** — backward Euler for stiff systems (Newton iteration)
//! - **BDF2** — second-order backward differentiation formula
//!
//! ## PDE solvers
//!
//! - **Heat equation (1-D)** — explicit FTCS and implicit Crank-Nicolson
//! - **Wave equation (1-D)** — leapfrog / Störmer-Verlet
//! - **Poisson equation (1-D)** — direct tridiagonal solve
//! - **Advection equation (1-D)** — first-order upwind and Lax-Wendroff
22
23mod explicit;
24mod implicit;
25mod pde;
26mod types;
27mod utils;
28
29// Re-export public API
30pub use explicit::{EulerSolver, Rk4Solver, Rk45Solver};
31pub use implicit::{Bdf2Solver, ImplicitEulerSolver};
32pub use pde::{
33    AdvectionEquation1D, BoundaryCondition, Grid1D, Grid2D, HeatEquation1D, PdeConfig, Poisson1D,
34    WaveEquation1D,
35};
36pub use types::{OdeConfig, OdeMethod, OdeSolution, OdeSystem, StepResult};
37pub use utils::{numerical_jacobian, solve_tridiagonal};
38
39#[cfg(test)]
40mod tests {
41    use super::utils::apply_bc_1d;
42    use super::*;
43    use std::f64::consts::PI;
44
45    // --- Test ODE systems ---
46
47    /// Exponential decay: y' = -y, solution y(t) = y0 * exp(-t).
48    struct ExponentialDecay;
49
50    impl OdeSystem for ExponentialDecay {
51        fn rhs(&self, _t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
52            dydt[0] = -y[0];
53            Ok(())
54        }
55        fn dim(&self) -> usize {
56            1
57        }
58    }
59
60    /// Harmonic oscillator: y'' + y = 0, as system y' = v, v' = -y.
61    /// Solution: y(t) = cos(t), v(t) = -sin(t) with y(0)=1, v(0)=0.
62    struct HarmonicOscillator;
63
64    impl OdeSystem for HarmonicOscillator {
65        fn rhs(&self, _t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
66            dydt[0] = y[1]; // dy/dt = v
67            dydt[1] = -y[0]; // dv/dt = -y
68            Ok(())
69        }
70        fn dim(&self) -> usize {
71            2
72        }
73    }
74
75    /// Stiff system: y' = -1000*(y - sin(t)) + cos(t).
76    struct StiffSystem;
77
78    impl OdeSystem for StiffSystem {
79        fn rhs(&self, t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
80            dydt[0] = -1000.0 * (y[0] - t.sin()) + t.cos();
81            Ok(())
82        }
83        fn dim(&self) -> usize {
84            1
85        }
86    }
87
88    /// Van der Pol oscillator: y'' - mu*(1-y²)*y' + y = 0.
89    struct VanDerPol {
90        mu: f64,
91    }
92
93    impl OdeSystem for VanDerPol {
94        fn rhs(&self, _t: f64, y: &[f64], dydt: &mut [f64]) -> crate::error::SolverResult<()> {
95            dydt[0] = y[1];
96            dydt[1] = self.mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
97            Ok(())
98        }
99        fn dim(&self) -> usize {
100            2
101        }
102    }
103
104    // --- ODE tests ---
105
106    #[test]
107    fn euler_exponential_decay() {
108        let sys = ExponentialDecay;
109        let config = OdeConfig {
110            t_start: 0.0,
111            t_end: 1.0,
112            dt: 0.001,
113            method: OdeMethod::Euler,
114            ..OdeConfig::default()
115        };
116        let sol = EulerSolver::solve(&sys, &[1.0], &config);
117        assert!(sol.is_ok());
118        let sol = sol.ok().filter(|s| !s.states.is_empty());
119        assert!(sol.is_some());
120        let sol = sol.as_ref().and_then(|s| s.states.last());
121        assert!(sol.is_some());
122        let y_final = sol.map(|s| s[0]).unwrap_or(0.0);
123        let expected = (-1.0_f64).exp();
124        // Euler is only first-order, so ~0.001 accuracy with dt=0.001
125        assert!(
126            (y_final - expected).abs() < 0.01,
127            "Euler: y(1) = {y_final}, expected {expected}"
128        );
129    }
130
131    #[test]
132    fn rk4_harmonic_oscillator() {
133        let sys = HarmonicOscillator;
134        let config = OdeConfig {
135            t_start: 0.0,
136            t_end: 2.0 * PI,
137            dt: 0.01,
138            method: OdeMethod::Rk4,
139            ..OdeConfig::default()
140        };
141        let sol = Rk4Solver::solve(&sys, &[1.0, 0.0], &config);
142        assert!(sol.is_ok());
143        let sol = sol.ok().filter(|s| !s.states.is_empty());
144        assert!(sol.is_some());
145        let last = sol.as_ref().and_then(|s| s.states.last());
146        assert!(last.is_some());
147        let y = last.map(|s| s[0]).unwrap_or(0.0);
148        let v = last.map(|s| s[1]).unwrap_or(0.0);
149        // After one full period, should return to (1, 0)
150        assert!((y - 1.0).abs() < 1e-6, "RK4: y(2pi) = {y}, expected 1.0");
151        assert!(v.abs() < 1e-6, "RK4: v(2pi) = {v}, expected 0.0");
152    }
153
154    #[test]
155    fn rk45_adaptive_step() {
156        let sys = HarmonicOscillator;
157        let config = OdeConfig {
158            t_start: 0.0,
159            t_end: 2.0 * PI,
160            dt: 0.1,
161            rtol: 1e-8,
162            atol: 1e-10,
163            max_steps: 10_000,
164            method: OdeMethod::Rk45,
165        };
166        let sol = Rk45Solver::solve(&sys, &[1.0, 0.0], &config);
167        assert!(sol.is_ok());
168        let sol_data = sol.ok().filter(|s| !s.states.is_empty());
169        assert!(sol_data.is_some());
170        let sd = sol_data.as_ref();
171        let last = sd.and_then(|s| s.states.last());
172        let y = last.map(|s| s[0]).unwrap_or(0.0);
173        let v = last.map(|s| s[1]).unwrap_or(0.0);
174        assert!((y - 1.0).abs() < 1e-6, "RK45: y(2pi) = {y}, expected 1.0");
175        assert!(v.abs() < 1e-6, "RK45: v(2pi) = {v}, expected 0.0");
176        // Adaptive should take fewer steps than fixed RK4 with same accuracy
177        let num = sd.map(|s| s.num_steps).unwrap_or(0);
178        assert!(
179            num < 1000,
180            "RK45 should take fewer than 1000 steps, took {num}"
181        );
182    }
183
184    #[test]
185    fn implicit_euler_stiff() {
186        let sys = StiffSystem;
187        let config = OdeConfig {
188            t_start: 0.0,
189            t_end: 0.1,
190            dt: 0.01,
191            method: OdeMethod::ImplicitEuler,
192            ..OdeConfig::default()
193        };
194        // Start near the exact solution sin(0) = 0
195        let sol = ImplicitEulerSolver::solve(&sys, &[0.0], &config);
196        assert!(sol.is_ok());
197        let sol_data = sol.ok().filter(|s| !s.states.is_empty());
198        assert!(sol_data.is_some());
199        let last = sol_data.as_ref().and_then(|s| s.states.last());
200        let y = last.map(|s| s[0]).unwrap_or(f64::NAN);
201        let expected = 0.1_f64.sin();
202        // Implicit Euler is first-order, but it should handle the stiffness
203        assert!(
204            (y - expected).abs() < 0.05,
205            "ImplicitEuler: y(0.1) = {y}, expected {expected}"
206        );
207    }
208
209    #[test]
210    fn bdf2_van_der_pol() {
211        let sys = VanDerPol { mu: 1.0 };
212        let config = OdeConfig {
213            t_start: 0.0,
214            t_end: 1.0,
215            dt: 0.01,
216            method: OdeMethod::Bdf2,
217            ..OdeConfig::default()
218        };
219        let sol = Bdf2Solver::solve(&sys, &[2.0, 0.0], &config);
220        assert!(sol.is_ok());
221        let sol_data = sol.ok().filter(|s| !s.states.is_empty());
222        assert!(sol_data.is_some());
223        // Just verify it completed and states are finite
224        let sd = sol_data.as_ref();
225        let all_finite = sd
226            .map(|s| s.states.iter().all(|st| st.iter().all(|v| v.is_finite())))
227            .unwrap_or(false);
228        assert!(all_finite, "BDF2: all states should be finite");
229    }
230
231    #[test]
232    fn ode_convergence_order() {
233        // Test that RK4 achieves ~4th order convergence
234        let sys = ExponentialDecay;
235        let t_end = 1.0;
236        let exact = (-1.0_f64).exp();
237
238        let mut errors = Vec::new();
239        for &dt in &[0.1, 0.05, 0.025] {
240            let config = OdeConfig {
241                t_start: 0.0,
242                t_end,
243                dt,
244                method: OdeMethod::Rk4,
245                ..OdeConfig::default()
246            };
247            let sol = Rk4Solver::solve(&sys, &[1.0], &config);
248            let sol_data = sol.ok().filter(|s| !s.states.is_empty());
249            let y = sol_data
250                .as_ref()
251                .and_then(|s| s.states.last())
252                .map(|s| s[0])
253                .unwrap_or(0.0);
254            errors.push((y - exact).abs());
255        }
256
257        // When dt halves, error should decrease by ~16x for 4th order
258        if errors[0] > 1e-15 && errors[1] > 1e-15 {
259            let ratio = errors[0] / errors[1];
260            assert!(
261                ratio > 10.0,
262                "RK4 convergence ratio should be ~16, got {ratio}"
263            );
264        }
265    }
266
267    // --- PDE tests ---
268
269    #[test]
270    fn heat_explicit_gaussian() {
271        // Gaussian diffusion: u(x,0) = exp(-x²), alpha = 0.01
272        let alpha = 0.01;
273        let nx = 101;
274        let grid = Grid1D::new(-5.0, 5.0, nx);
275        let heat = HeatEquation1D { alpha };
276
277        let dt = 0.4 * heat.stability_limit(grid.dx); // below stability limit
278        let config = PdeConfig {
279            grid: grid.clone(),
280            dt,
281            num_steps: 10,
282            bc_left: BoundaryCondition::Dirichlet(0.0),
283            bc_right: BoundaryCondition::Dirichlet(0.0),
284        };
285
286        let u0: Vec<f64> = (0..nx)
287            .map(|i| {
288                let x = grid.point(i);
289                (-x * x).exp()
290            })
291            .collect();
292
293        let result = heat.solve_explicit(&u0, &config);
294        assert!(result.is_ok());
295        let data = result.ok().filter(|d| !d.is_empty());
296        assert!(data.is_some());
297        let last = data.as_ref().and_then(|d| d.last());
298        assert!(last.is_some());
299
300        // Solution should still be peaked at center but broader
301        let u_final = last.as_ref().map(|u| u.as_slice()).unwrap_or(&[]);
302        if u_final.len() == nx {
303            let mid = nx / 2;
304            // Peak should be lower than initial (diffusion spreads it)
305            assert!(u_final[mid] < u0[mid], "Heat diffusion should reduce peak");
306            // Solution should remain positive
307            assert!(
308                u_final.iter().all(|&v| v >= -1e-10),
309                "Heat solution should remain non-negative"
310            );
311        }
312    }
313
314    #[test]
315    fn heat_crank_nicolson_stability() {
316        // Crank-Nicolson should be stable even with large dt
317        let alpha = 1.0;
318        let nx = 51;
319        let grid = Grid1D::new(0.0, 1.0, nx);
320        let heat = HeatEquation1D { alpha };
321
322        // Use dt >> stability limit for explicit
323        let dt = 10.0 * heat.stability_limit(grid.dx);
324        let config = PdeConfig {
325            grid: grid.clone(),
326            dt,
327            num_steps: 20,
328            bc_left: BoundaryCondition::Dirichlet(0.0),
329            bc_right: BoundaryCondition::Dirichlet(0.0),
330        };
331
332        let u0: Vec<f64> = (0..nx)
333            .map(|i| {
334                let x = grid.point(i);
335                (PI * x).sin()
336            })
337            .collect();
338
339        let result = heat.solve_implicit(&u0, &config);
340        assert!(result.is_ok());
341        let data = result.ok().filter(|d| !d.is_empty());
342        assert!(data.is_some());
343        let last = data.as_ref().and_then(|d| d.last());
344
345        // Solution should be bounded (no blow-up)
346        let u_final = last.as_ref().map(|u| u.as_slice()).unwrap_or(&[]);
347        let max_val = u_final.iter().copied().fold(0.0_f64, f64::max);
348        assert!(
349            max_val < 2.0,
350            "Crank-Nicolson should be stable; max = {max_val}"
351        );
352    }
353
354    #[test]
355    fn wave_energy_conservation() {
356        let c = 1.0;
357        let nx = 101;
358        let grid = Grid1D::new(0.0, 1.0, nx);
359        let dx = grid.dx;
360        let wave = WaveEquation1D { c };
361
362        let dt = 0.5 * dx / c; // CFL < 1
363        let config = PdeConfig {
364            grid: grid.clone(),
365            dt,
366            num_steps: 50,
367            bc_left: BoundaryCondition::Dirichlet(0.0),
368            bc_right: BoundaryCondition::Dirichlet(0.0),
369        };
370
371        // Initial displacement: sin(pi*x), zero velocity
372        let u0: Vec<f64> = (0..nx)
373            .map(|i| {
374                let x = grid.point(i);
375                (PI * x).sin()
376            })
377            .collect();
378        let v0 = vec![0.0; nx];
379
380        let result = wave.solve(&u0, &v0, &config);
381        assert!(result.is_ok());
382        let data = result.ok().filter(|d| d.len() > 2);
383        assert!(data.is_some());
384
385        // Compute energy at first and last time step
386        let states = data.as_deref().unwrap_or(&[]);
387        if states.len() >= 3 {
388            let energy = |u: &[f64], u_prev: &[f64]| -> f64 {
389                let mut ke = 0.0;
390                let mut pe = 0.0;
391                for i in 1..nx - 1 {
392                    let v = (u[i] - u_prev[i]) / dt;
393                    ke += 0.5 * v * v * dx;
394                    let ux = (u[i + 1] - u[i - 1]) / (2.0 * dx);
395                    pe += 0.5 * c * c * ux * ux * dx;
396                }
397                ke + pe
398            };
399
400            let e_initial = energy(&states[1], &states[0]);
401            let e_final = energy(&states[states.len() - 1], &states[states.len() - 2]);
402
403            // Energy should be approximately conserved
404            if e_initial > 1e-10 {
405                let rel_change = (e_final - e_initial).abs() / e_initial;
406                assert!(
407                    rel_change < 0.1,
408                    "Wave energy changed by {:.2}%",
409                    rel_change * 100.0
410                );
411            }
412        }
413    }
414
415    #[test]
416    fn poisson_analytical() {
417        // Solve -u'' = 2 on [0,1] with u(0) = 0, u(1) = 0
418        // Exact: u(x) = x*(1-x)
419        let nx = 101;
420        let grid = Grid1D::new(0.0, 1.0, nx);
421        let poisson = Poisson1D;
422
423        let f_rhs: Vec<f64> = vec![2.0; nx];
424        let config = PdeConfig {
425            grid: grid.clone(),
426            dt: 0.0, // not used
427            num_steps: 0,
428            bc_left: BoundaryCondition::Dirichlet(0.0),
429            bc_right: BoundaryCondition::Dirichlet(0.0),
430        };
431
432        let result = poisson.solve(&f_rhs, &config);
433        assert!(result.is_ok());
434        let u = result.ok().unwrap_or_default();
435
436        let mut max_err = 0.0_f64;
437        for (i, u_val) in u.iter().enumerate().take(nx) {
438            let x = grid.point(i);
439            let exact = x * (1.0 - x);
440            max_err = max_err.max((u_val - exact).abs());
441        }
442
443        // Second-order scheme: error ~ O(dx²) ~ O(1e-4)
444        assert!(
445            max_err < 1e-3,
446            "Poisson max error = {max_err}, expected < 1e-3"
447        );
448    }
449
450    #[test]
451    fn advection_upwind() {
452        let a = 1.0;
453        let nx = 101;
454        let grid = Grid1D::new(0.0, 2.0, nx);
455        let adv = AdvectionEquation1D { a };
456
457        let dt = 0.5 * grid.dx / a.abs(); // CFL = 0.5
458        let config = PdeConfig {
459            grid: grid.clone(),
460            dt,
461            num_steps: 10,
462            bc_left: BoundaryCondition::Dirichlet(0.0),
463            bc_right: BoundaryCondition::Dirichlet(0.0),
464        };
465
466        // Step function initial condition
467        let u0: Vec<f64> = (0..nx)
468            .map(|i| {
469                let x = grid.point(i);
470                if (0.5..=1.0).contains(&x) { 1.0 } else { 0.0 }
471            })
472            .collect();
473
474        let result = adv.solve_upwind(&u0, &config);
475        assert!(result.is_ok());
476        let data = result.ok().filter(|d| !d.is_empty());
477        assert!(data.is_some());
478
479        // After advection, the pulse should have moved to the right
480        let last = data
481            .as_ref()
482            .and_then(|d| d.last())
483            .cloned()
484            .unwrap_or_default();
485        // Solution should remain bounded
486        let max_val = last.iter().copied().fold(0.0_f64, f64::max);
487        assert!(
488            max_val <= 1.0 + 1e-10,
489            "Upwind should be monotone, max = {max_val}"
490        );
491    }
492
493    #[test]
494    fn lax_wendroff_accuracy() {
495        let a = 1.0;
496        let nx = 201;
497        let grid = Grid1D::new(0.0, 4.0, nx);
498        let adv = AdvectionEquation1D { a };
499
500        let dt = 0.5 * grid.dx / a.abs();
501        let num_steps = 20;
502        let config = PdeConfig {
503            grid: grid.clone(),
504            dt,
505            num_steps,
506            bc_left: BoundaryCondition::Dirichlet(0.0),
507            bc_right: BoundaryCondition::Dirichlet(0.0),
508        };
509
510        // Smooth Gaussian initial condition
511        let u0: Vec<f64> = (0..nx)
512            .map(|i| {
513                let x = grid.point(i);
514                (-(x - 1.0).powi(2) / 0.1).exp()
515            })
516            .collect();
517
518        let result_lw = adv.solve_lax_wendroff(&u0, &config);
519        let result_up = adv.solve_upwind(&u0, &config);
520        assert!(result_lw.is_ok());
521        assert!(result_up.is_ok());
522
523        let lw_last = result_lw
524            .ok()
525            .and_then(|d| d.last().cloned())
526            .unwrap_or_default();
527        let up_last = result_up
528            .ok()
529            .and_then(|d| d.last().cloned())
530            .unwrap_or_default();
531
532        // Lax-Wendroff should preserve the pulse shape better (less diffusion)
533        let lw_max = lw_last.iter().copied().fold(0.0_f64, f64::max);
534        let up_max = up_last.iter().copied().fold(0.0_f64, f64::max);
535
536        assert!(
537            lw_max >= up_max - 1e-10,
538            "Lax-Wendroff should have less diffusion: LW max = {lw_max}, upwind max = {up_max}"
539        );
540    }
541
542    #[test]
543    fn boundary_condition_enforcement() {
544        let nx = 11;
545        let mut u = vec![1.0_f64; nx];
546
547        apply_bc_1d(
548            &mut u,
549            &BoundaryCondition::Dirichlet(0.0),
550            &BoundaryCondition::Dirichlet(2.0),
551            nx,
552        );
553        assert!((u[0] - 0.0_f64).abs() < 1e-15);
554        assert!((u[nx - 1] - 2.0_f64).abs() < 1e-15);
555
556        // Periodic
557        let mut u2 = vec![0.0_f64; nx];
558        u2[1] = 3.0;
559        u2[nx - 2] = 5.0;
560        apply_bc_1d(
561            &mut u2,
562            &BoundaryCondition::Periodic,
563            &BoundaryCondition::Periodic,
564            nx,
565        );
566        assert!((u2[0] - 5.0_f64).abs() < 1e-15);
567        assert!((u2[nx - 1] - 3.0_f64).abs() < 1e-15);
568    }
569
570    #[test]
571    fn stability_limit_calculation() {
572        let heat = HeatEquation1D { alpha: 0.5 };
573        let dx = 0.1;
574        let limit = heat.stability_limit(dx);
575        // dt <= dx² / (2*alpha) = 0.01 / 1.0 = 0.01
576        assert!((limit - 0.01).abs() < 1e-15, "stability limit = {limit}");
577    }
578
579    #[test]
580    fn grid_construction() {
581        let g = Grid1D::new(0.0, 1.0, 11);
582        assert_eq!(g.nx, 11);
583        assert!((g.dx - 0.1).abs() < 1e-15);
584        assert!((g.point(0) - 0.0).abs() < 1e-15);
585        assert!((g.point(5) - 0.5).abs() < 1e-15);
586        assert!((g.point(10) - 1.0).abs() < 1e-15);
587
588        let g2 = Grid2D::new(0.0, 1.0, 11, -1.0, 1.0, 21);
589        assert_eq!(g2.nx, 11);
590        assert_eq!(g2.ny, 21);
591        assert!((g2.dx - 0.1).abs() < 1e-15);
592        assert!((g2.dy - 0.1).abs() < 1e-15);
593    }
594
595    #[test]
596    fn numerical_jacobian_accuracy() {
597        // For y' = -y, the Jacobian is J = [-1]
598        let sys = ExponentialDecay;
599        let jac = numerical_jacobian(&sys, 0.0, &[1.0], 1e-8);
600        assert!(jac.is_ok());
601        let j = jac.ok().unwrap_or_default();
602        if !j.is_empty() && !j[0].is_empty() {
603            assert!(
604                (j[0][0] - (-1.0)).abs() < 1e-5,
605                "Jacobian J[0][0] = {}, expected -1.0",
606                j[0][0]
607            );
608        }
609
610        // Harmonic oscillator: J = [[0, 1], [-1, 0]]
611        let sys2 = HarmonicOscillator;
612        let jac2 = numerical_jacobian(&sys2, 0.0, &[1.0, 0.0], 1e-8);
613        assert!(jac2.is_ok());
614        let j2 = jac2.ok().unwrap_or_default();
615        if j2.len() >= 2 && j2[0].len() >= 2 {
616            assert!((j2[0][0]).abs() < 1e-5, "J[0][0] should be ~0");
617            assert!((j2[0][1] - 1.0).abs() < 1e-5, "J[0][1] should be ~1");
618            assert!((j2[1][0] - (-1.0)).abs() < 1e-5, "J[1][0] should be ~-1");
619            assert!((j2[1][1]).abs() < 1e-5, "J[1][1] should be ~0");
620        }
621    }
622}