light_curve_dmdt/
grid.rs

1use crate::Float;
2
3use conv::{ConvAsUtil, ConvUtil, RoundToZero};
4use enum_dispatch::enum_dispatch;
5use ndarray::{Array1, ArrayView1};
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use thiserror::Error;
10
11/// Grid trait for dm or dt axis
12#[enum_dispatch]
13pub trait GridTrait<T>: Clone + Debug + Send + Sync
14where
15    T: Copy,
16{
17    /// Cell borders coordinates, [cell_count()](GridTrait::cell_count) + 1 length [ArrayView1]
18    fn get_borders(&self) -> ArrayView1<'_, T>;
19
20    /// Number of cells
21    fn cell_count(&self) -> usize {
22        self.get_borders().len() - 1
23    }
24
25    /// Coordinate of the left border of the leftmost cell
26    fn get_start(&self) -> T {
27        self.get_borders()[0]
28    }
29
30    /// Coordinate of the right border of the rightmost cell
31    fn get_end(&self) -> T {
32        self.get_borders()[self.cell_count()]
33    }
34
35    /// Get index of the cell containing given value
36    ///
37    /// Note that cells include their left borders but doesn't include right borders
38    fn idx(&self, x: T) -> CellIndex;
39}
40
/// Grid for dm or dt axis
///
/// [GridTrait] calls are dispatched to the wrapped concrete grid via `enum_dispatch`.
#[enum_dispatch(GridTrait<T>)]
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum Grid<T>
where
    T: Float,
{
    /// Grid with explicitly given borders, see [ArrayGrid]
    Array(ArrayGrid<T>),
    /// Grid with equally sized cells, see [LinearGrid]
    Linear(LinearGrid<T>),
    /// Grid with cells equally sized in decimal-logarithm space, see [LgGrid]
    Lg(LgGrid<T>),
}
54
55impl<T> Grid<T>
56where
57    T: Float,
58{
59    pub fn array(borders: Array1<T>) -> Result<Self, ArrayGridError> {
60        ArrayGrid::new(borders).map(Into::into)
61    }
62
63    pub fn linear(start: T, end: T, n: usize) -> Self {
64        LinearGrid::new(start, end, n).into()
65    }
66
67    pub fn log_from_start_end(start: T, end: T, n: usize) -> Self {
68        LgGrid::from_start_end(start, end, n).into()
69    }
70
71    pub fn log_from_lg_start_end(lg_start: T, lg_end: T, n: usize) -> Self {
72        LgGrid::from_lg_start_end(lg_start, lg_end, n).into()
73    }
74}
75
/// An error returned by [ArrayGrid::new] (and forwarded by [Grid::array])
#[derive(Error, Debug)]
pub enum ArrayGridError {
    /// The borders array has no elements
    #[error("given grid is empty")]
    ArrayIsEmpty,
    /// The borders array was rejected by the ascending-order check (`crate::util::is_sorted`)
    #[error("given grid is not ascending")]
    ArrayIsNotAscending,
}
84
/// Grid whose cell borders are defined by an ascending array
///
/// Lookup time is O(log n): [GridTrait::idx] binary-searches the borders.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrayGrid<T> {
    // NOTE(review): with the "serde" feature the derived `Deserialize` fills this field directly,
    // bypassing the checks in `ArrayGrid::new` — confirm deserialized input is trusted.
    /// Cell borders; [ArrayGrid::new] rejects empty and non-ascending arrays
    borders: Array1<T>,
}
93
94impl<T> ArrayGrid<T>
95where
96    T: Float,
97{
98    /// Wraps given array into [ArrayGrid] or return an error
99    ///
100    /// Note that array describes cell borders, not center or whatever else
101    pub fn new(borders: Array1<T>) -> Result<Self, ArrayGridError> {
102        if borders.is_empty() {
103            return Err(ArrayGridError::ArrayIsEmpty);
104        }
105        if !crate::util::is_sorted(borders.as_slice().unwrap()) {
106            return Err(ArrayGridError::ArrayIsNotAscending);
107        }
108        Ok(Self { borders })
109    }
110}
111
112impl<T> GridTrait<T> for ArrayGrid<T>
113where
114    T: Float,
115{
116    #[inline]
117    fn get_borders(&self) -> ArrayView1<'_, T> {
118        self.borders.view()
119    }
120
121    fn idx(&self, x: T) -> CellIndex {
122        let i = self
123            .borders
124            .as_slice()
125            .unwrap()
126            .partition_point(|&b| b <= x);
127        match i {
128            0 => CellIndex::LowerMin,
129            _ if i == self.borders.len() => CellIndex::GreaterMax,
130            _ => CellIndex::Value(i - 1),
131        }
132    }
133}
134
/// Linear grid defined by its start, end and number of cells
///
/// All cells have the same size, `(end - start) / n`.
///
/// Lookup time is O(1)
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LinearGrid<T> {
    /// Left border of the leftmost cell
    start: T,
    /// Right border of the rightmost cell
    end: T,
    /// Number of cells
    n: usize,
    /// Size of every cell, `(end - start) / n`
    cell_size: T,
    /// `n + 1` evenly spaced borders from `start` to `end`
    borders: Array1<T>,
}
147
148impl<T> LinearGrid<T>
149where
150    T: Float,
151{
152    /// Create [LinearGrid] from borders and number of cells
153    ///
154    /// `start` is the left border of the leftmost cell, `end` is the right border of the rightmost
155    /// cell, `n` is the number of cells. This means that the number of borders is `n + 1`, `start`
156    /// border has zero index and `end` border has index `n`.
157    pub fn new(start: T, end: T, n: usize) -> Self {
158        assert!(end > start);
159        let cell_size = (end - start) / n.value_as::<T>().unwrap();
160        let borders = Array1::linspace(start, end, n + 1);
161        Self {
162            start,
163            end,
164            n,
165            cell_size,
166            borders,
167        }
168    }
169
170    /// Cell size
171    #[inline]
172    pub fn get_cell_size(&self) -> T {
173        self.cell_size
174    }
175}
176
177impl<T> GridTrait<T> for LinearGrid<T>
178where
179    T: Float,
180{
181    #[inline]
182    fn get_borders(&self) -> ArrayView1<'_, T> {
183        self.borders.view()
184    }
185
186    #[inline]
187    fn cell_count(&self) -> usize {
188        self.n
189    }
190
191    #[inline]
192    fn get_start(&self) -> T {
193        self.start
194    }
195
196    #[inline]
197    fn get_end(&self) -> T {
198        self.end
199    }
200
201    fn idx(&self, x: T) -> CellIndex {
202        if x < self.start {
203            return CellIndex::LowerMin;
204        }
205        if x >= self.end {
206            return CellIndex::GreaterMax;
207        }
208        let i = ((x - self.start) / self.cell_size)
209            .approx_by::<RoundToZero>()
210            .unwrap();
211        if i < self.n {
212            CellIndex::Value(i)
213        } else {
214            // x is a bit smaller self.end + float rounding
215            CellIndex::Value(self.n - 1)
216        }
217    }
218}
219
/// Logarithmic grid defined by its start, end and number of cells
///
/// All cells have the same size in decimal-logarithm space, `(lg_end - lg_start) / n`.
///
/// Lookup time is O(1)
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LgGrid<T> {
    /// Left border of the leftmost cell
    start: T,
    /// Right border of the rightmost cell
    end: T,
    /// Decimal logarithm of `start`
    lg_start: T,
    /// Decimal logarithm of `end`
    lg_end: T,
    /// Number of cells
    n: usize,
    /// Size of every cell in decimal-logarithm space
    cell_lg_size: T,
    /// `n + 1` logarithmically spaced borders; the outermost are set to `start` and `end` exactly
    borders: Array1<T>,
}
234
235impl<T> LgGrid<T>
236where
237    T: Float,
238{
239    /// Create [LinearGrid] from borders and number of cells
240    ///
241    /// `start` is the left border of the leftmost cell, `end` is the right border of the rightmost
242    /// cell, `n` is the number of cells. This means that the number of borders is `n + 1`, `start`
243    /// border has zero index and `end` border has index `n`.
244    pub fn from_start_end(start: T, end: T, n: usize) -> Self {
245        assert!(end > start);
246        assert!(start.is_positive());
247        let lg_start = start.log10();
248        let lg_end = end.log10();
249        let cell_lg_size = (lg_end - lg_start) / n.value_as::<T>().unwrap();
250        let mut borders = Array1::logspace(T::ten(), lg_start, lg_end, n + 1);
251        borders[0] = start;
252        borders[n] = end;
253        Self {
254            start,
255            end,
256            lg_start,
257            lg_end,
258            n,
259            cell_lg_size,
260            borders,
261        }
262    }
263
264    /// Create [LinearGrid] from decimal logarithms of borders and number of cells
265    ///
266    /// `lg_start` is the decimal logarithm of the left border of the leftmost cell, `lg_end` is the
267    /// decimal logarithm of the right border of the rightmost cell, `n` is the number of cells.
268    /// This means that the number of borders is `n + 1`, `lg_start` border has zero index and
269    /// `lg_end` border has index `n`.
270    pub fn from_lg_start_end(lg_start: T, lg_end: T, n: usize) -> Self {
271        Self::from_start_end(T::powf(T::ten(), lg_start), T::powf(T::ten(), lg_end), n)
272    }
273
274    /// Logarithmic size of cell
275    #[inline]
276    pub fn get_cell_lg_size(&self) -> T {
277        self.cell_lg_size
278    }
279
280    /// Logarithm of the leftmost border
281    #[inline]
282    pub fn get_lg_start(&self) -> T {
283        self.lg_start
284    }
285
286    /// Logarithm of the rightmost border
287    #[inline]
288    pub fn get_lg_end(&self) -> T {
289        self.lg_end
290    }
291}
292
293impl<T> GridTrait<T> for LgGrid<T>
294where
295    T: Float,
296{
297    #[inline]
298    fn get_borders(&self) -> ArrayView1<'_, T> {
299        self.borders.view()
300    }
301
302    #[inline]
303    fn cell_count(&self) -> usize {
304        self.n
305    }
306
307    #[inline]
308    fn get_start(&self) -> T {
309        self.start
310    }
311
312    #[inline]
313    fn get_end(&self) -> T {
314        self.end
315    }
316
317    fn idx(&self, x: T) -> CellIndex {
318        if x < self.start {
319            return CellIndex::LowerMin;
320        }
321        if x >= self.end {
322            return CellIndex::GreaterMax;
323        }
324        let i = ((x.log10() - self.lg_start) / self.cell_lg_size)
325            .approx_by::<RoundToZero>()
326            .unwrap();
327        if i < self.n {
328            CellIndex::Value(i)
329        } else {
330            // x is a bit smaller self.end + float rounding
331            CellIndex::Value(self.n - 1)
332        }
333    }
334}
335
/// Value to return from [GridTrait::idx]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CellIndex {
    /// Below the leftmost border
    LowerMin,
    /// Equal to or greater than the rightmost border
    GreaterMax,
    /// Index of the cell containing the value
    Value(usize),
}