Skip to main content

numra_pde/
grid.rs

1//! Spatial grids for PDE discretization.
2//!
3//! Author: Moussa Leblouba
4//! Date: 9 February 2026
5//! Modified: 2 May 2026
6
7use numra_core::Scalar;
8
9/// 1D uniform or non-uniform grid.
10#[derive(Clone, Debug)]
11pub struct Grid1D<S: Scalar> {
12    /// Grid points
13    points: Vec<S>,
14    /// Grid spacings (for non-uniform grids)
15    spacings: Vec<S>,
16}
17
18impl<S: Scalar> Grid1D<S> {
19    /// Create a uniform grid with n points from x_min to x_max.
20    pub fn uniform(x_min: S, x_max: S, n: usize) -> Self {
21        assert!(n >= 2, "Grid must have at least 2 points");
22        let dx = (x_max - x_min) / S::from_usize(n - 1);
23
24        let points: Vec<S> = (0..n).map(|i| x_min + S::from_usize(i) * dx).collect();
25
26        let spacings = vec![dx; n - 1];
27
28        Self { points, spacings }
29    }
30
31    /// Create a grid from explicit points.
32    pub fn from_points(points: Vec<S>) -> Self {
33        assert!(points.len() >= 2, "Grid must have at least 2 points");
34
35        let spacings: Vec<S> = points.windows(2).map(|w| w[1] - w[0]).collect();
36
37        Self { points, spacings }
38    }
39
40    /// Create a grid with refinement near boundaries.
41    ///
42    /// Uses a tanh-based clustering.
43    pub fn clustered(x_min: S, x_max: S, n: usize, cluster_factor: S) -> Self {
44        assert!(n >= 2, "Grid must have at least 2 points");
45
46        let mut points = Vec::with_capacity(n);
47        let length = x_max - x_min;
48
49        for i in 0..n {
50            let xi = S::from_f64(-1.0) + S::from_f64(2.0) * S::from_usize(i) / S::from_usize(n - 1);
51            // tanh clustering
52            let clustered = (cluster_factor * xi).tanh() / cluster_factor.tanh();
53            let x = x_min + length * (S::ONE + clustered) / S::from_f64(2.0);
54            points.push(x);
55        }
56
57        let spacings: Vec<S> = points.windows(2).map(|w| w[1] - w[0]).collect();
58
59        Self { points, spacings }
60    }
61
62    /// Number of grid points.
63    pub fn len(&self) -> usize {
64        self.points.len()
65    }
66
67    /// Check if grid is empty.
68    pub fn is_empty(&self) -> bool {
69        self.points.is_empty()
70    }
71
72    /// Number of interior points (excluding boundaries).
73    pub fn n_interior(&self) -> usize {
74        self.points.len().saturating_sub(2)
75    }
76
77    /// Get all grid points.
78    pub fn points(&self) -> &[S] {
79        &self.points
80    }
81
82    /// Get interior points (excluding boundaries).
83    pub fn interior_points(&self) -> &[S] {
84        &self.points[1..self.points.len() - 1]
85    }
86
87    /// Get grid spacing at index i (between points i and i+1).
88    pub fn dx(&self, i: usize) -> S {
89        self.spacings[i]
90    }
91
92    /// Get uniform spacing (panics if non-uniform).
93    pub fn dx_uniform(&self) -> S {
94        let first = self.spacings[0];
95        debug_assert!(
96            self.spacings
97                .iter()
98                .all(|&dx| (dx - first).abs() < S::from_f64(1e-10) * first),
99            "Grid is not uniform"
100        );
101        first
102    }
103
104    /// Check if grid is uniform.
105    pub fn is_uniform(&self) -> bool {
106        if self.spacings.is_empty() {
107            return true;
108        }
109        let first = self.spacings[0];
110        let tol = S::from_f64(1e-10) * first;
111        self.spacings.iter().all(|&dx| (dx - first).abs() < tol)
112    }
113
114    /// Get domain bounds.
115    pub fn bounds(&self) -> (S, S) {
116        (*self.points.first().unwrap(), *self.points.last().unwrap())
117    }
118
119    /// Domain length.
120    pub fn length(&self) -> S {
121        let (a, b) = self.bounds();
122        b - a
123    }
124
125    /// Find index of point closest to x.
126    pub fn find_index(&self, x: S) -> usize {
127        let (lo, hi) = self.bounds();
128        if x <= lo {
129            return 0;
130        }
131        if x >= hi {
132            return self.len() - 1;
133        }
134
135        // Binary search for non-uniform grids
136        let mut left = 0;
137        let mut right = self.len() - 1;
138        while left < right - 1 {
139            let mid = (left + right) / 2;
140            if x < self.points[mid] {
141                right = mid;
142            } else {
143                left = mid;
144            }
145        }
146
147        // Return closest point
148        if (x - self.points[left]).abs() < (x - self.points[right]).abs() {
149            left
150        } else {
151            right
152        }
153    }
154}
155
156/// 2D grid (tensor product of 1D grids).
157#[derive(Clone, Debug)]
158pub struct Grid2D<S: Scalar> {
159    /// Grid in x direction
160    pub x_grid: Grid1D<S>,
161    /// Grid in y direction
162    pub y_grid: Grid1D<S>,
163}
164
165impl<S: Scalar> Grid2D<S> {
166    /// Create a uniform 2D grid.
167    pub fn uniform(x_min: S, x_max: S, nx: usize, y_min: S, y_max: S, ny: usize) -> Self {
168        Self {
169            x_grid: Grid1D::uniform(x_min, x_max, nx),
170            y_grid: Grid1D::uniform(y_min, y_max, ny),
171        }
172    }
173
174    /// Total number of grid points.
175    pub fn len(&self) -> usize {
176        self.x_grid.len() * self.y_grid.len()
177    }
178
179    /// Check if grid is empty.
180    pub fn is_empty(&self) -> bool {
181        self.x_grid.is_empty() || self.y_grid.is_empty()
182    }
183
184    /// Number of interior points.
185    pub fn n_interior(&self) -> usize {
186        self.x_grid.n_interior() * self.y_grid.n_interior()
187    }
188
189    /// Convert (i, j) grid index to linear index.
190    pub fn linear_index(&self, i: usize, j: usize) -> usize {
191        j * self.x_grid.len() + i
192    }
193
194    /// Convert linear index to (i, j) grid index.
195    pub fn grid_index(&self, idx: usize) -> (usize, usize) {
196        let i = idx % self.x_grid.len();
197        let j = idx / self.x_grid.len();
198        (i, j)
199    }
200
201    /// Get (x, y) coordinates at grid index (i, j).
202    pub fn point(&self, i: usize, j: usize) -> (S, S) {
203        (self.x_grid.points()[i], self.y_grid.points()[j])
204    }
205
206    /// Number of points in x direction.
207    pub fn nx(&self) -> usize {
208        self.x_grid.len()
209    }
210
211    /// Number of points in y direction.
212    pub fn ny(&self) -> usize {
213        self.y_grid.len()
214    }
215
216    /// Number of interior points in x direction.
217    pub fn nx_interior(&self) -> usize {
218        self.x_grid.n_interior()
219    }
220
221    /// Number of interior points in y direction.
222    pub fn ny_interior(&self) -> usize {
223        self.y_grid.n_interior()
224    }
225}
226
227/// 3D grid (tensor product of 1D grids).
228#[derive(Clone, Debug)]
229pub struct Grid3D<S: Scalar> {
230    /// Grid in x direction
231    pub x_grid: Grid1D<S>,
232    /// Grid in y direction
233    pub y_grid: Grid1D<S>,
234    /// Grid in z direction
235    pub z_grid: Grid1D<S>,
236}
237
238impl<S: Scalar> Grid3D<S> {
239    /// Create a uniform 3D grid.
240    #[allow(clippy::too_many_arguments)]
241    pub fn uniform(
242        x_min: S,
243        x_max: S,
244        nx: usize,
245        y_min: S,
246        y_max: S,
247        ny: usize,
248        z_min: S,
249        z_max: S,
250        nz: usize,
251    ) -> Self {
252        Self {
253            x_grid: Grid1D::uniform(x_min, x_max, nx),
254            y_grid: Grid1D::uniform(y_min, y_max, ny),
255            z_grid: Grid1D::uniform(z_min, z_max, nz),
256        }
257    }
258
259    /// Total number of grid points.
260    pub fn len(&self) -> usize {
261        self.x_grid.len() * self.y_grid.len() * self.z_grid.len()
262    }
263
264    /// Check if grid is empty.
265    pub fn is_empty(&self) -> bool {
266        self.x_grid.is_empty() || self.y_grid.is_empty() || self.z_grid.is_empty()
267    }
268
269    /// Number of interior points.
270    pub fn n_interior(&self) -> usize {
271        self.x_grid.n_interior() * self.y_grid.n_interior() * self.z_grid.n_interior()
272    }
273
274    /// Convert (i, j, k) grid index to linear index (column-major: x varies fastest).
275    pub fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
276        k * (self.x_grid.len() * self.y_grid.len()) + j * self.x_grid.len() + i
277    }
278
279    /// Convert linear index to (i, j, k) grid index.
280    pub fn grid_index(&self, idx: usize) -> (usize, usize, usize) {
281        let nx = self.x_grid.len();
282        let ny = self.y_grid.len();
283        let i = idx % nx;
284        let j = (idx / nx) % ny;
285        let k = idx / (nx * ny);
286        (i, j, k)
287    }
288
289    /// Get (x, y, z) coordinates at grid index (i, j, k).
290    pub fn point(&self, i: usize, j: usize, k: usize) -> (S, S, S) {
291        (
292            self.x_grid.points()[i],
293            self.y_grid.points()[j],
294            self.z_grid.points()[k],
295        )
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uniform_grid() {
        let grid = Grid1D::uniform(0.0, 1.0, 11);
        assert_eq!(grid.len(), 11);
        assert_eq!(grid.n_interior(), 9);
        assert!(grid.is_uniform());
        assert!((grid.dx_uniform() - 0.1).abs() < 1e-10);
        assert!((grid.points()[5] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_grid_bounds() {
        let grid = Grid1D::uniform(-1.0, 2.0, 31);
        let (lo, hi) = grid.bounds();
        assert!((lo - (-1.0)).abs() < 1e-10);
        assert!((hi - 2.0).abs() < 1e-10);
        assert!((grid.length() - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_from_points() {
        let points = vec![0.0, 0.1, 0.3, 0.6, 1.0];
        let grid = Grid1D::from_points(points);
        assert_eq!(grid.len(), 5);
        assert!(!grid.is_uniform());
        assert!((grid.dx(0) - 0.1).abs() < 1e-10);
        assert!((grid.dx(1) - 0.2).abs() < 1e-10);
    }

    #[test]
    fn test_clustered_grid() {
        let grid = Grid1D::clustered(0.0, 1.0, 21, 2.0);
        assert_eq!(grid.len(), 21);

        // Should have smaller spacing near boundaries than in the middle.
        let dx_boundary = grid.dx(0);
        let dx_middle = grid.dx(9);
        assert!(dx_boundary < dx_middle);
    }

    #[test]
    fn test_interior_points() {
        let grid = Grid1D::uniform(0.0, 1.0, 5);
        let interior = grid.interior_points();
        assert_eq!(interior.len(), grid.n_interior());
        assert!((interior[0] - 0.25).abs() < 1e-12);
        assert!((interior[2] - 0.75).abs() < 1e-12);
    }

    #[test]
    fn test_find_index() {
        let grid = Grid1D::uniform(0.0, 1.0, 11);
        assert_eq!(grid.find_index(0.0), 0);
        assert_eq!(grid.find_index(0.5), 5);
        assert_eq!(grid.find_index(1.0), 10);
        assert_eq!(grid.find_index(0.49), 5);
    }

    #[test]
    fn test_find_index_non_uniform() {
        let grid = Grid1D::from_points(vec![0.0, 0.1, 0.3, 0.6, 1.0]);
        // Outside the domain: clamp to the nearest end point.
        assert_eq!(grid.find_index(-5.0), 0);
        assert_eq!(grid.find_index(5.0), 4);
        // Inside the domain: the closer of the two bracketing points.
        assert_eq!(grid.find_index(0.12), 1);
        assert_eq!(grid.find_index(0.25), 2);
        assert_eq!(grid.find_index(0.7), 3);
    }

    #[test]
    fn test_grid_2d() {
        let grid = Grid2D::uniform(0.0, 1.0, 11, 0.0, 2.0, 21);
        assert_eq!(grid.len(), 11 * 21);
        assert_eq!(grid.n_interior(), 9 * 19);

        let (x, y) = grid.point(5, 10);
        assert!((x - 0.5).abs() < 1e-10);
        assert!((y - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_grid_2d_indexing() {
        let grid = Grid2D::uniform(0.0, 1.0, 4, 0.0, 1.0, 3);
        // x varies fastest: (i, j) -> j * nx + i.
        assert_eq!(grid.linear_index(1, 2), 2 * 4 + 1);
        // Round-trip test.
        for j in 0..grid.ny() {
            for i in 0..grid.nx() {
                let idx = grid.linear_index(i, j);
                assert_eq!(grid.grid_index(idx), (i, j));
            }
        }
    }

    #[test]
    fn test_grid_3d() {
        let grid = Grid3D::uniform(0.0, 1.0, 5, 0.0, 2.0, 11, 0.0, 3.0, 7);
        assert_eq!(grid.len(), 5 * 11 * 7);
        assert_eq!(grid.n_interior(), 3 * 9 * 5);

        // dx = 1/4, dy = 2/10 and dz = 3/6, so point (2, 5, 3) is
        // (2 * 0.25, 5 * 0.2, 3 * 0.5) = (0.5, 1.0, 1.5).
        let (x, y, z) = grid.point(2, 5, 3);
        assert!((x - 0.5).abs() < 1e-10);
        assert!((y - 1.0).abs() < 1e-10);
        assert!((z - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_grid_3d_indexing() {
        let grid = Grid3D::uniform(0.0, 1.0, 4, 0.0, 1.0, 5, 0.0, 1.0, 6);
        // Round-trip test.
        for k in 0..6 {
            for j in 0..5 {
                for i in 0..4 {
                    let idx = grid.linear_index(i, j, k);
                    let (ri, rj, rk) = grid.grid_index(idx);
                    assert_eq!((ri, rj, rk), (i, j, k));
                }
            }
        }
    }
}