1use crate::Float;
2
3use conv::{ConvAsUtil, ConvUtil, RoundToZero};
4use enum_dispatch::enum_dispatch;
5use ndarray::{Array1, ArrayView1};
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use thiserror::Error;
10
11#[enum_dispatch]
13pub trait GridTrait<T>: Clone + Debug + Send + Sync
14where
15 T: Copy,
16{
17 fn get_borders(&self) -> ArrayView1<'_, T>;
19
20 fn cell_count(&self) -> usize {
22 self.get_borders().len() - 1
23 }
24
25 fn get_start(&self) -> T {
27 self.get_borders()[0]
28 }
29
30 fn get_end(&self) -> T {
32 self.get_borders()[self.cell_count()]
33 }
34
35 fn idx(&self, x: T) -> CellIndex;
39}
40
/// A grid of any supported kind.
///
/// `#[enum_dispatch(GridTrait<T>)]` makes every `GridTrait` method on `Grid` forward to the
/// wrapped variant. The enum is `#[non_exhaustive]`, so more grid kinds may be added later.
#[enum_dispatch(GridTrait<T>)]
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum Grid<T>
where
    T: Float,
{
    /// Grid with explicitly given borders.
    Array(ArrayGrid<T>),
    /// Grid with equal-width cells.
    Linear(LinearGrid<T>),
    /// Grid with cells of equal width in `log10` space.
    Lg(LgGrid<T>),
}
54
55impl<T> Grid<T>
56where
57 T: Float,
58{
59 pub fn array(borders: Array1<T>) -> Result<Self, ArrayGridError> {
60 ArrayGrid::new(borders).map(Into::into)
61 }
62
63 pub fn linear(start: T, end: T, n: usize) -> Self {
64 LinearGrid::new(start, end, n).into()
65 }
66
67 pub fn log_from_start_end(start: T, end: T, n: usize) -> Self {
68 LgGrid::from_start_end(start, end, n).into()
69 }
70
71 pub fn log_from_lg_start_end(lg_start: T, lg_end: T, n: usize) -> Self {
72 LgGrid::from_lg_start_end(lg_start, lg_end, n).into()
73 }
74}
75
/// Reasons why [`ArrayGrid::new`] rejects a border array.
#[derive(Error, Debug)]
pub enum ArrayGridError {
    /// The border array has no elements.
    #[error("given grid is empty")]
    ArrayIsEmpty,
    /// The borders were rejected by `crate::util::is_sorted`.
    #[error("given grid is not ascending")]
    ArrayIsNotAscending,
}
84
/// Grid whose cell borders are given explicitly as an array.
///
/// NOTE(review): with the `serde` feature, the derived `Deserialize` builds the struct
/// directly and skips the emptiness/ordering checks done in `ArrayGrid::new`. Confirm that this
/// is acceptable.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrayGrid<T> {
    /// Cell borders; cell `i` spans from `borders[i]` to `borders[i + 1]`.
    borders: Array1<T>,
}
93
94impl<T> ArrayGrid<T>
95where
96 T: Float,
97{
98 pub fn new(borders: Array1<T>) -> Result<Self, ArrayGridError> {
102 if borders.is_empty() {
103 return Err(ArrayGridError::ArrayIsEmpty);
104 }
105 if !crate::util::is_sorted(borders.as_slice().unwrap()) {
106 return Err(ArrayGridError::ArrayIsNotAscending);
107 }
108 Ok(Self { borders })
109 }
110}
111
112impl<T> GridTrait<T> for ArrayGrid<T>
113where
114 T: Float,
115{
116 #[inline]
117 fn get_borders(&self) -> ArrayView1<'_, T> {
118 self.borders.view()
119 }
120
121 fn idx(&self, x: T) -> CellIndex {
122 let i = self
123 .borders
124 .as_slice()
125 .unwrap()
126 .partition_point(|&b| b <= x);
127 match i {
128 0 => CellIndex::LowerMin,
129 _ if i == self.borders.len() => CellIndex::GreaterMax,
130 _ => CellIndex::Value(i - 1),
131 }
132 }
133}
134
/// Grid of `n` equal-width cells spanning `[start, end]`.
///
/// NOTE(review): with the `serde` feature, the derived `Deserialize` builds the struct
/// directly, bypassing the checks in `LinearGrid::new` and the consistency between the
/// fields. Confirm that this is acceptable.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LinearGrid<T> {
    /// Left border of the first cell.
    start: T,
    /// Right border of the last cell.
    end: T,
    /// Number of cells.
    n: usize,
    /// Width of every cell, `(end - start) / n`.
    cell_size: T,
    /// The `n + 1` cell borders, evenly spaced between `start` and `end`.
    borders: Array1<T>,
}
147
148impl<T> LinearGrid<T>
149where
150 T: Float,
151{
152 pub fn new(start: T, end: T, n: usize) -> Self {
158 assert!(end > start);
159 let cell_size = (end - start) / n.value_as::<T>().unwrap();
160 let borders = Array1::linspace(start, end, n + 1);
161 Self {
162 start,
163 end,
164 n,
165 cell_size,
166 borders,
167 }
168 }
169
170 #[inline]
172 pub fn get_cell_size(&self) -> T {
173 self.cell_size
174 }
175}
176
177impl<T> GridTrait<T> for LinearGrid<T>
178where
179 T: Float,
180{
181 #[inline]
182 fn get_borders(&self) -> ArrayView1<'_, T> {
183 self.borders.view()
184 }
185
186 #[inline]
187 fn cell_count(&self) -> usize {
188 self.n
189 }
190
191 #[inline]
192 fn get_start(&self) -> T {
193 self.start
194 }
195
196 #[inline]
197 fn get_end(&self) -> T {
198 self.end
199 }
200
201 fn idx(&self, x: T) -> CellIndex {
202 if x < self.start {
203 return CellIndex::LowerMin;
204 }
205 if x >= self.end {
206 return CellIndex::GreaterMax;
207 }
208 let i = ((x - self.start) / self.cell_size)
209 .approx_by::<RoundToZero>()
210 .unwrap();
211 if i < self.n {
212 CellIndex::Value(i)
213 } else {
214 CellIndex::Value(self.n - 1)
216 }
217 }
218}
219
/// Grid of `n` cells spanning `[start, end]` whose borders are equally spaced in `log10`.
///
/// NOTE(review): with the `serde` feature, the derived `Deserialize` builds the struct
/// directly, bypassing the checks in `LgGrid::from_start_end` and the consistency between the
/// fields. Confirm that this is acceptable.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LgGrid<T> {
    /// Left border of the first cell (positive).
    start: T,
    /// Right border of the last cell.
    end: T,
    /// `start.log10()`.
    lg_start: T,
    /// `end.log10()`.
    lg_end: T,
    /// Number of cells.
    n: usize,
    /// Width of every cell in `log10` space, `(lg_end - lg_start) / n`.
    cell_lg_size: T,
    /// The `n + 1` cell borders; the first and last are exactly `start` and `end`.
    borders: Array1<T>,
}
234
235impl<T> LgGrid<T>
236where
237 T: Float,
238{
239 pub fn from_start_end(start: T, end: T, n: usize) -> Self {
245 assert!(end > start);
246 assert!(start.is_positive());
247 let lg_start = start.log10();
248 let lg_end = end.log10();
249 let cell_lg_size = (lg_end - lg_start) / n.value_as::<T>().unwrap();
250 let mut borders = Array1::logspace(T::ten(), lg_start, lg_end, n + 1);
251 borders[0] = start;
252 borders[n] = end;
253 Self {
254 start,
255 end,
256 lg_start,
257 lg_end,
258 n,
259 cell_lg_size,
260 borders,
261 }
262 }
263
264 pub fn from_lg_start_end(lg_start: T, lg_end: T, n: usize) -> Self {
271 Self::from_start_end(T::powf(T::ten(), lg_start), T::powf(T::ten(), lg_end), n)
272 }
273
274 #[inline]
276 pub fn get_cell_lg_size(&self) -> T {
277 self.cell_lg_size
278 }
279
280 #[inline]
282 pub fn get_lg_start(&self) -> T {
283 self.lg_start
284 }
285
286 #[inline]
288 pub fn get_lg_end(&self) -> T {
289 self.lg_end
290 }
291}
292
293impl<T> GridTrait<T> for LgGrid<T>
294where
295 T: Float,
296{
297 #[inline]
298 fn get_borders(&self) -> ArrayView1<'_, T> {
299 self.borders.view()
300 }
301
302 #[inline]
303 fn cell_count(&self) -> usize {
304 self.n
305 }
306
307 #[inline]
308 fn get_start(&self) -> T {
309 self.start
310 }
311
312 #[inline]
313 fn get_end(&self) -> T {
314 self.end
315 }
316
317 fn idx(&self, x: T) -> CellIndex {
318 if x < self.start {
319 return CellIndex::LowerMin;
320 }
321 if x >= self.end {
322 return CellIndex::GreaterMax;
323 }
324 let i = ((x.log10() - self.lg_start) / self.cell_lg_size)
325 .approx_by::<RoundToZero>()
326 .unwrap();
327 if i < self.n {
328 CellIndex::Value(i)
329 } else {
330 CellIndex::Value(self.n - 1)
332 }
333 }
334}
335
/// Position of a value relative to a grid, as returned by `GridTrait::idx`.
///
/// The type is a small, plain value, so it derives `Copy`, equality and `Debug`. Callers can
/// then compare, store and print lookup results.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CellIndex {
    /// The value lies below the first border.
    LowerMin,
    /// The value lies at or above the last border.
    GreaterMax,
    /// Index of the cell containing the value.
    Value(usize),
}