1use crate::Array2;
2use core::fmt;
3
/// Errors reported while constructing a [`Grid2D`] or deriving masks from it.
#[derive(Debug, Clone)]
pub enum GridError {
    /// The grid spacing `dx` was NaN, infinite, zero, or negative.
    InvalidSpacing { dx: f64 },
    /// A physical extent was NaN, infinite, or negative.
    ///
    /// `axis` names the offending extent (`Grid2D::new` passes `"lx"` or `"ly"`).
    InvalidExtent { axis: &'static str, value: f64 },
    /// `value / dx` rounded to a point count that cannot be represented as a
    /// `usize` (or was not finite / was negative).
    ResolutionOverflow {
        axis: &'static str,
        value: f64,
        dx: f64,
    },
    /// A caller-supplied mask's `(rows, cols)` did not equal the grid's `(nx, ny)`.
    InvalidMaskShape {
        expected: (usize, usize),
        actual: (usize, usize),
    },
}
23
24impl PartialEq for GridError {
25 fn eq(&self, other: &Self) -> bool {
26 match (self, other) {
27 (Self::InvalidSpacing { dx: a }, Self::InvalidSpacing { dx: b }) => {
28 (a.is_nan() && b.is_nan()) || a == b
29 }
30 (
31 Self::InvalidExtent {
32 axis: axis_a,
33 value: value_a,
34 },
35 Self::InvalidExtent {
36 axis: axis_b,
37 value: value_b,
38 },
39 ) => axis_a == axis_b && ((value_a.is_nan() && value_b.is_nan()) || value_a == value_b),
40 (
41 Self::ResolutionOverflow {
42 axis: axis_a,
43 value: value_a,
44 dx: dx_a,
45 },
46 Self::ResolutionOverflow {
47 axis: axis_b,
48 value: value_b,
49 dx: dx_b,
50 },
51 ) => {
52 axis_a == axis_b
53 && ((value_a.is_nan() && value_b.is_nan()) || value_a == value_b)
54 && ((dx_a.is_nan() && dx_b.is_nan()) || dx_a == dx_b)
55 }
56 (
57 Self::InvalidMaskShape {
58 expected: expected_a,
59 actual: actual_a,
60 },
61 Self::InvalidMaskShape {
62 expected: expected_b,
63 actual: actual_b,
64 },
65 ) => expected_a == expected_b && actual_a == actual_b,
66 _ => false,
67 }
68 }
69}
70
71impl fmt::Display for GridError {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 Self::InvalidSpacing { dx } => {
75 write!(f, "grid spacing dx must be finite and positive, got {dx}")
76 }
77 Self::InvalidExtent { axis, value } => {
78 write!(
79 f,
80 "grid extent {axis} must be finite and non-negative, got {value}"
81 )
82 }
83 Self::ResolutionOverflow { axis, value, dx } => write!(
84 f,
85 "grid extent {axis}={value} with dx={dx} produces too many cells"
86 ),
87 Self::InvalidMaskShape { expected, actual } => write!(
88 f,
89 "mask shape must match grid shape {:?}, got {:?}",
90 expected, actual
91 ),
92 }
93 }
94}
95
96impl std::error::Error for GridError {}
97
/// A uniform two-dimensional grid that uses the same spacing `dx` on both axes.
///
/// Construct it with [`Grid2D::new`], which validates the inputs and derives the
/// point counts.
#[derive(Debug)]
pub struct Grid2D {
    /// Physical extent along x, exactly as requested at construction.
    lx: f64,
    /// Physical extent along y, exactly as requested at construction.
    ly: f64,
    /// Distance between neighbouring points; shared by both axes.
    dx: f64,
    /// Number of points along x: `round(lx / dx) + 1`.
    nx: usize,
    /// Number of points along y: `round(ly / dx) + 1`.
    ny: usize,
}
107
108#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
109#[cfg_attr(feature = "serde", serde(tag = "type", rename_all = "snake_case"))]
110#[derive(Debug, Clone)]
111pub enum BoundaryGeometry {
112 Circular {
113 r_outer: f64,
114 r_hole: f64,
115 },
116 Rectangular,
117 #[cfg_attr(feature = "serde", serde(skip))]
118 Mask(Array2<bool>),
119}
120
121impl Grid2D {
122 pub fn new(lx: f64, ly: f64, dx: f64) -> Result<Self, GridError> {
123 if !dx.is_finite() || dx <= 0.0 {
124 return Err(GridError::InvalidSpacing { dx });
125 }
126 validate_extent("lx", lx)?;
127 validate_extent("ly", ly)?;
128 let nx = cells_for_extent("lx", lx, dx)?;
129 let ny = cells_for_extent("ly", ly, dx)?;
130 Ok(Self { lx, ly, dx, nx, ny })
131 }
132
133 pub fn nx(&self) -> usize {
134 self.nx
135 }
136
137 pub fn ny(&self) -> usize {
138 self.ny
139 }
140
141 pub fn dx(&self) -> f64 {
142 self.dx
143 }
144
145 pub fn lx(&self) -> f64 {
146 self.lx
147 }
148
149 pub fn ly(&self) -> f64 {
150 self.ly
151 }
152
153 pub fn coords(&self) -> (Array2<f64>, Array2<f64>) {
155 let cx = (self.nx / 2) as f64 * self.dx;
156 let cy = (self.ny / 2) as f64 * self.dx;
157 let mut x = Array2::zeros((self.nx, self.ny));
158 let mut y = Array2::zeros((self.nx, self.ny));
159 for i in 0..self.nx {
160 for j in 0..self.ny {
161 x[[i, j]] = i as f64 * self.dx - cx;
162 y[[i, j]] = j as f64 * self.dx - cy;
163 }
164 }
165 (x, y)
166 }
167
168 pub fn radius_map(&self) -> Array2<f64> {
170 let (x, y) = self.coords();
171 let mut r = Array2::zeros((self.nx, self.ny));
172 for i in 0..self.nx {
173 for j in 0..self.ny {
174 r[[i, j]] = (x[[i, j]].powi(2) + y[[i, j]].powi(2)).sqrt();
175 }
176 }
177 r
178 }
179
180 pub fn interior_mask(&self, geom: &BoundaryGeometry) -> Result<Array2<bool>, GridError> {
182 match geom {
183 BoundaryGeometry::Circular { r_outer, r_hole } => {
184 let r = self.radius_map();
185 let mut mask = Array2::from_elem((self.nx, self.ny), false);
186 for i in 0..self.nx {
187 for j in 0..self.ny {
188 let rv = r[[i, j]];
189 mask[[i, j]] = rv < *r_outer && rv > *r_hole;
190 }
191 }
192 Ok(mask)
193 }
194 BoundaryGeometry::Mask(mask) => {
195 let actual = (mask.nrows(), mask.ncols());
196 let expected = (self.nx, self.ny);
197 if actual != expected {
198 return Err(GridError::InvalidMaskShape { expected, actual });
199 }
200 Ok(mask.clone())
201 }
202 BoundaryGeometry::Rectangular => {
203 let mut mask = Array2::from_elem((self.nx, self.ny), true);
204 let x_border = self.nx.saturating_sub(2);
205 let y_border = self.ny.saturating_sub(2);
206 for i in 0..self.nx {
207 for j in 0..self.ny {
208 if i < 2 || i >= x_border || j < 2 || j >= y_border {
209 mask[[i, j]] = false;
210 }
211 }
212 }
213 Ok(mask)
214 }
215 }
216 }
217}
218
219fn validate_extent(axis: &'static str, value: f64) -> Result<(), GridError> {
220 if !value.is_finite() || value < 0.0 {
221 return Err(GridError::InvalidExtent { axis, value });
222 }
223 Ok(())
224}
225
226fn cells_for_extent(axis: &'static str, value: f64, dx: f64) -> Result<usize, GridError> {
227 let cells = (value / dx).round();
228 if !cells.is_finite() || cells < 0.0 || cells > (usize::MAX - 1) as f64 {
229 return Err(GridError::ResolutionOverflow { axis, value, dx });
230 }
231 Ok(cells as usize + 1)
232}