Skip to main content

neco_gridfield/
grid.rs

1use crate::Array2;
2use core::fmt;
3
/// Errors produced while constructing a [`Grid2D`] or deriving an interior mask from it.
#[derive(Debug, Clone)]
pub enum GridError {
    /// The node spacing `dx` was NaN, infinite, zero, or negative.
    InvalidSpacing { dx: f64 },
    /// An extent (`axis` names it, e.g. `"lx"` / `"ly"`) was NaN, infinite, or negative.
    InvalidExtent { axis: &'static str, value: f64 },
    /// `value / dx` along `axis` yields a node count that cannot be represented.
    ResolutionOverflow {
        axis: &'static str,
        value: f64,
        dx: f64,
    },
    /// A user-supplied mask's shape differs from the grid's `(nx, ny)`.
    InvalidMaskShape {
        expected: (usize, usize),
        actual: (usize, usize),
    },
}

/// Field-wise equality in which two NaN payloads compare equal.
///
/// A derived impl would make an error that captured a NaN input unequal to
/// itself; here every `f64` field uses ordinary `==` (so `0.0 == -0.0`) except
/// that NaN matches NaN. Errors of different variants are never equal.
impl PartialEq for GridError {
    fn eq(&self, other: &Self) -> bool {
        // Ordinary float equality, extended so that NaN equals NaN.
        fn same_f64(a: f64, b: f64) -> bool {
            a == b || (a.is_nan() && b.is_nan())
        }

        match (self, other) {
            (Self::InvalidSpacing { dx: lhs }, Self::InvalidSpacing { dx: rhs }) => {
                same_f64(*lhs, *rhs)
            }
            (
                Self::InvalidExtent {
                    axis: lhs_axis,
                    value: lhs_value,
                },
                Self::InvalidExtent {
                    axis: rhs_axis,
                    value: rhs_value,
                },
            ) => lhs_axis == rhs_axis && same_f64(*lhs_value, *rhs_value),
            (
                Self::ResolutionOverflow {
                    axis: lhs_axis,
                    value: lhs_value,
                    dx: lhs_dx,
                },
                Self::ResolutionOverflow {
                    axis: rhs_axis,
                    value: rhs_value,
                    dx: rhs_dx,
                },
            ) => {
                lhs_axis == rhs_axis
                    && same_f64(*lhs_value, *rhs_value)
                    && same_f64(*lhs_dx, *rhs_dx)
            }
            (
                Self::InvalidMaskShape {
                    expected: lhs_expected,
                    actual: lhs_actual,
                },
                Self::InvalidMaskShape {
                    expected: rhs_expected,
                    actual: rhs_actual,
                },
            ) => lhs_expected == rhs_expected && lhs_actual == rhs_actual,
            // Mismatched variants.
            _ => false,
        }
    }
}

/// Human-readable messages; each one echoes the offending value(s).
impl fmt::Display for GridError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so bind by value.
        match *self {
            Self::InvalidSpacing { dx } => {
                write!(f, "grid spacing dx must be finite and positive, got {dx}")
            }
            Self::InvalidExtent { axis, value } => write!(
                f,
                "grid extent {axis} must be finite and non-negative, got {value}"
            ),
            Self::ResolutionOverflow { axis, value, dx } => write!(
                f,
                "grid extent {axis}={value} with dx={dx} produces too many cells"
            ),
            Self::InvalidMaskShape { expected, actual } => write!(
                f,
                "mask shape must match grid shape {:?}, got {:?}",
                expected, actual
            ),
        }
    }
}

// Every variant is a leaf error, so the default `source()` (`None`) is kept.
impl std::error::Error for GridError {}
97
/// 2D uniform square grid.
///
/// Both axes share the node spacing `dx`, so every cell is square. `nx`/`ny`
/// are node counts computed by [`Grid2D::new`] as `round(extent / dx) + 1`.
/// `lx`/`ly` keep the *requested* extents, which can differ slightly from the
/// realised span `(n - 1) * dx` because the cell count is rounded.
///
/// The type is a small plain value, so it is `Clone` and `PartialEq` like the
/// other value types in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct Grid2D {
    /// Requested extent along the first array axis (the `nx` axis).
    lx: f64,
    /// Requested extent along the second array axis (the `ny` axis).
    ly: f64,
    /// Node spacing shared by both axes (finite and `> 0` when built via `new`).
    dx: f64,
    /// Number of nodes along the first array axis.
    nx: usize,
    /// Number of nodes along the second array axis.
    ny: usize,
}
107
108#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
109#[cfg_attr(feature = "serde", serde(tag = "type", rename_all = "snake_case"))]
110#[derive(Debug, Clone)]
111pub enum BoundaryGeometry {
112    Circular {
113        r_outer: f64,
114        r_hole: f64,
115    },
116    Rectangular,
117    #[cfg_attr(feature = "serde", serde(skip))]
118    Mask(Array2<bool>),
119}
120
121impl Grid2D {
122    pub fn new(lx: f64, ly: f64, dx: f64) -> Result<Self, GridError> {
123        if !dx.is_finite() || dx <= 0.0 {
124            return Err(GridError::InvalidSpacing { dx });
125        }
126        validate_extent("lx", lx)?;
127        validate_extent("ly", ly)?;
128        let nx = cells_for_extent("lx", lx, dx)?;
129        let ny = cells_for_extent("ly", ly, dx)?;
130        Ok(Self { lx, ly, dx, nx, ny })
131    }
132
133    pub fn nx(&self) -> usize {
134        self.nx
135    }
136
137    pub fn ny(&self) -> usize {
138        self.ny
139    }
140
141    pub fn dx(&self) -> f64 {
142        self.dx
143    }
144
145    pub fn lx(&self) -> f64 {
146        self.lx
147    }
148
149    pub fn ly(&self) -> f64 {
150        self.ly
151    }
152
153    /// Coordinate arrays (X, Y) with origin at grid center.
154    pub fn coords(&self) -> (Array2<f64>, Array2<f64>) {
155        let cx = (self.nx / 2) as f64 * self.dx;
156        let cy = (self.ny / 2) as f64 * self.dx;
157        let mut x = Array2::zeros((self.nx, self.ny));
158        let mut y = Array2::zeros((self.nx, self.ny));
159        for i in 0..self.nx {
160            for j in 0..self.ny {
161                x[[i, j]] = i as f64 * self.dx - cx;
162                y[[i, j]] = j as f64 * self.dx - cy;
163            }
164        }
165        (x, y)
166    }
167
168    /// Radius map from grid center.
169    pub fn radius_map(&self) -> Array2<f64> {
170        let (x, y) = self.coords();
171        let mut r = Array2::zeros((self.nx, self.ny));
172        for i in 0..self.nx {
173            for j in 0..self.ny {
174                r[[i, j]] = (x[[i, j]].powi(2) + y[[i, j]].powi(2)).sqrt();
175            }
176        }
177        r
178    }
179
180    /// Interior mask (true = active computational cell).
181    pub fn interior_mask(&self, geom: &BoundaryGeometry) -> Result<Array2<bool>, GridError> {
182        match geom {
183            BoundaryGeometry::Circular { r_outer, r_hole } => {
184                let r = self.radius_map();
185                let mut mask = Array2::from_elem((self.nx, self.ny), false);
186                for i in 0..self.nx {
187                    for j in 0..self.ny {
188                        let rv = r[[i, j]];
189                        mask[[i, j]] = rv < *r_outer && rv > *r_hole;
190                    }
191                }
192                Ok(mask)
193            }
194            BoundaryGeometry::Mask(mask) => {
195                let actual = (mask.nrows(), mask.ncols());
196                let expected = (self.nx, self.ny);
197                if actual != expected {
198                    return Err(GridError::InvalidMaskShape { expected, actual });
199                }
200                Ok(mask.clone())
201            }
202            BoundaryGeometry::Rectangular => {
203                let mut mask = Array2::from_elem((self.nx, self.ny), true);
204                let x_border = self.nx.saturating_sub(2);
205                let y_border = self.ny.saturating_sub(2);
206                for i in 0..self.nx {
207                    for j in 0..self.ny {
208                        if i < 2 || i >= x_border || j < 2 || j >= y_border {
209                            mask[[i, j]] = false;
210                        }
211                    }
212                }
213                Ok(mask)
214            }
215        }
216    }
217}
218
219fn validate_extent(axis: &'static str, value: f64) -> Result<(), GridError> {
220    if !value.is_finite() || value < 0.0 {
221        return Err(GridError::InvalidExtent { axis, value });
222    }
223    Ok(())
224}
225
226fn cells_for_extent(axis: &'static str, value: f64, dx: f64) -> Result<usize, GridError> {
227    let cells = (value / dx).round();
228    if !cells.is_finite() || cells < 0.0 || cells > (usize::MAX - 1) as f64 {
229        return Err(GridError::ResolutionOverflow { axis, value, dx });
230    }
231    Ok(cells as usize + 1)
232}