// proof_engine/solver/pde.rs
//! PDE solvers — finite difference methods for heat, wave, and Laplace equations.
2
3use super::boundary::{BoundaryCondition, BoundaryType};
4
/// A 2D scalar field sampled on a regular grid.
///
/// Samples are stored row-major: the value at column `x`, row `y` lives at
/// `data[y * width + x]`, so `data.len()` is expected to equal `width * height`.
#[derive(Debug, Clone, PartialEq)]
pub struct ScalarField2D {
    /// Row-major samples, `width * height` of them.
    pub data: Vec<f64>,
    /// Number of columns (x direction).
    pub width: usize,
    /// Number of rows (y direction).
    pub height: usize,
    /// Grid spacing along x.
    pub dx: f64,
    /// Grid spacing along y.
    pub dy: f64,
}
14
15impl ScalarField2D {
16    pub fn new(width: usize, height: usize, dx: f64, dy: f64) -> Self {
17        Self { data: vec![0.0; width * height], width, height, dx, dy }
18    }
19
20    pub fn get(&self, x: usize, y: usize) -> f64 { self.data[y * self.width + x] }
21    pub fn set(&mut self, x: usize, y: usize, val: f64) { self.data[y * self.width + x] = val; }
22
23    pub fn fill(&mut self, val: f64) { self.data.fill(val); }
24    pub fn max_value(&self) -> f64 { self.data.iter().copied().fold(f64::MIN, f64::max) }
25    pub fn min_value(&self) -> f64 { self.data.iter().copied().fold(f64::MAX, f64::min) }
26}
27
/// Time-integration scheme requested for a [`PdeSolver`].
///
/// NOTE(review): `PdeSolver::step` in this module never reads this value, so selecting
/// `ImplicitEuler` or `CrankNicolson` currently behaves exactly like `ExplicitEuler`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PdeMethod {
    /// Forward (explicit) Euler.
    ExplicitEuler,
    /// Backward (implicit) Euler.
    ImplicitEuler,
    /// Crank–Nicolson (trapezoidal rule).
    CrankNicolson,
}
35
/// Partial differential equation integrated by a [`PdeSolver`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PdeType {
    /// Heat / diffusion equation `∂u/∂t = α ∇²u`.
    Heat,
    /// Wave equation `∂²u/∂t² = c² ∇²u`.
    Wave,
    /// Laplace's equation `∇²u = 0`, relaxed iteratively toward steady state.
    Laplace,
}
43
44/// PDE solver for 2D scalar fields.
45pub struct PdeSolver {
46    pub pde_type: PdeType,
47    pub method: PdeMethod,
48    pub dt: f64,
49    pub diffusivity: f64,  // α for heat, c² for wave
50    prev: Option<ScalarField2D>,
51}
52
53impl PdeSolver {
54    pub fn heat(dt: f64, alpha: f64) -> Self {
55        Self { pde_type: PdeType::Heat, method: PdeMethod::ExplicitEuler, dt, diffusivity: alpha, prev: None }
56    }
57
58    pub fn wave(dt: f64, c: f64) -> Self {
59        Self { pde_type: PdeType::Wave, method: PdeMethod::ExplicitEuler, dt, diffusivity: c * c, prev: None }
60    }
61
62    pub fn laplace() -> Self {
63        Self { pde_type: PdeType::Laplace, method: PdeMethod::ExplicitEuler, dt: 1.0, diffusivity: 1.0, prev: None }
64    }
65
66    /// Step the PDE forward one time step.
67    pub fn step(&mut self, field: &mut ScalarField2D, bc: &BoundaryCondition) {
68        match self.pde_type {
69            PdeType::Heat => self.heat_step(field, bc),
70            PdeType::Wave => self.wave_step(field, bc),
71            PdeType::Laplace => self.laplace_step(field, bc),
72        }
73    }
74
75    /// Run n iterations.
76    pub fn solve(&mut self, field: &mut ScalarField2D, bc: &BoundaryCondition, steps: u32) {
77        for _ in 0..steps {
78            self.step(field, bc);
79        }
80    }
81
82    fn heat_step(&self, field: &mut ScalarField2D, bc: &BoundaryCondition) {
83        let w = field.width;
84        let h = field.height;
85        let dx2 = field.dx * field.dx;
86        let dy2 = field.dy * field.dy;
87        let r_x = self.diffusivity * self.dt / dx2;
88        let r_y = self.diffusivity * self.dt / dy2;
89
90        let old = field.data.clone();
91
92        for y in 1..h - 1 {
93            for x in 1..w - 1 {
94                let idx = y * w + x;
95                let laplacian = (old[idx - 1] - 2.0 * old[idx] + old[idx + 1]) / dx2
96                              + (old[idx - w] - 2.0 * old[idx] + old[idx + w]) / dy2;
97                field.data[idx] = old[idx] + self.diffusivity * self.dt * laplacian;
98            }
99        }
100
101        bc.apply(field);
102    }
103
104    fn wave_step(&mut self, field: &mut ScalarField2D, bc: &BoundaryCondition) {
105        let w = field.width;
106        let h = field.height;
107        let dx2 = field.dx * field.dx;
108        let dy2 = field.dy * field.dy;
109        let c2 = self.diffusivity;
110        let dt2 = self.dt * self.dt;
111
112        let current = field.data.clone();
113        let prev_data = match &self.prev {
114            Some(p) => p.data.clone(),
115            None => current.clone(),
116        };
117
118        for y in 1..h - 1 {
119            for x in 1..w - 1 {
120                let idx = y * w + x;
121                let laplacian = (current[idx - 1] - 2.0 * current[idx] + current[idx + 1]) / dx2
122                              + (current[idx - w] - 2.0 * current[idx] + current[idx + w]) / dy2;
123                field.data[idx] = 2.0 * current[idx] - prev_data[idx] + c2 * dt2 * laplacian;
124            }
125        }
126
127        self.prev = Some(ScalarField2D { data: current, width: w, height: h, dx: field.dx, dy: field.dy });
128        bc.apply(field);
129    }
130
131    fn laplace_step(&self, field: &mut ScalarField2D, bc: &BoundaryCondition) {
132        // Jacobi iteration for Laplace equation (∇²u = 0)
133        let w = field.width;
134        let h = field.height;
135        let old = field.data.clone();
136
137        for y in 1..h - 1 {
138            for x in 1..w - 1 {
139                let idx = y * w + x;
140                field.data[idx] = 0.25 * (old[idx - 1] + old[idx + 1] + old[idx - w] + old[idx + w]);
141            }
142        }
143
144        bc.apply(field);
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn heat_diffuses() {
        let mut field = ScalarField2D::new(20, 20, 0.1, 0.1);
        // Hot spot in the centre of an otherwise cold plate.
        field.set(10, 10, 100.0);
        let bc = BoundaryCondition::new(BoundaryType::Dirichlet(0.0));
        // r = α·dt/dx² = 0.1 per axis, inside the explicit-Euler stability limit (sum <= 0.5).
        let mut solver = PdeSolver::heat(0.001, 1.0);
        solver.solve(&mut field, &bc, 100);
        // Centre should have decreased, neighbours increased.
        assert!(field.get(10, 10) < 100.0);
        assert!(field.get(10, 11) > 0.0);
    }

    #[test]
    fn laplace_converges() {
        let mut field = ScalarField2D::new(10, 10, 1.0, 1.0);
        // Raise the top edge to 100 before the first sweep.
        // NOTE(review): `Dirichlet(0.0)` may reset every edge when `apply` runs. If so, the
        // top row only feeds the first Jacobi sweep and the interior then decays toward 0.
        // The assertion below would still hold but checks little — confirm the intended setup.
        for x in 0..10 {
            field.set(x, 0, 100.0);
        }
        let bc = BoundaryCondition::new(BoundaryType::Dirichlet(0.0));
        let mut solver = PdeSolver::laplace();
        solver.solve(&mut field, &bc, 500);
        // Interior should be between 0 and 100.
        let mid = field.get(5, 5);
        assert!(mid > 0.0 && mid < 100.0, "mid={mid}");
    }

    #[test]
    fn wave_spreads_disturbance() {
        let mut field = ScalarField2D::new(20, 20, 0.1, 0.1);
        // Unit displacement at the centre, released from rest.
        field.set(10, 10, 1.0);
        let bc = BoundaryCondition::new(BoundaryType::Dirichlet(0.0));
        // Courant number c·dt·sqrt(2)/dx ≈ 0.14, well inside the CFL limit of 1.
        let mut solver = PdeSolver::wave(0.01, 1.0);
        solver.solve(&mut field, &bc, 5);
        // The pulse has started to sag and has reached the adjacent cell.
        assert!(field.get(10, 10) < 1.0);
        assert!(field.get(10, 11) != 0.0);
    }

    #[test]
    fn field_extrema() {
        let mut field = ScalarField2D::new(4, 3, 1.0, 1.0);
        field.set(1, 2, 5.0);
        field.set(3, 0, -2.0);
        assert_eq!(field.max_value(), 5.0);
        assert_eq!(field.min_value(), -2.0);
    }
}