light_curve_dmdt/
dmdt.rs

1use crate::{CellIndex, ErfFloat, ErrorFunction, Float, Grid, GridTrait, LgGrid, LinearGrid};
2
3use itertools::Itertools;
4use ndarray::{Array1, Array2, s};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt::Debug;
8
9/// dm–dt map plotter
10#[derive(Clone, Debug)]
11#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
12pub struct DmDt<T>
13where
14    T: Float,
15{
16    pub dt_grid: Grid<T>,
17    pub dm_grid: Grid<T>,
18}
19
20impl<T> DmDt<T>
21where
22    T: Float,
23{
24    /// Create new [DmDt]
25    pub fn from_grids<Gdt, Gdm>(dt_grid: Gdt, dm_grid: Gdm) -> Self
26    where
27        Gdt: Into<Grid<T>>,
28        Gdm: Into<Grid<T>>,
29    {
30        Self {
31            dt_grid: dt_grid.into(),
32            dm_grid: dm_grid.into(),
33        }
34    }
35
36    /// Create new [DmDt] with logarithmic dt grid and linear dm grid
37    ///
38    /// dt grid will have borders `[10^min_lgdt, 10^max_lgdt)`, dm grid will have borders
39    /// `[-max_abs_dm, max_abs_dm)`
40    pub fn from_lgdt_dm_limits(
41        min_lgdt: T,
42        max_lgdt: T,
43        lgdt_size: usize,
44        max_abs_dm: T,
45        dm_size: usize,
46    ) -> Self {
47        Self::from_grids(
48            LgGrid::from_lg_start_end(min_lgdt, max_lgdt, lgdt_size),
49            LinearGrid::new(-max_abs_dm, max_abs_dm, dm_size),
50        )
51    }
52
53    /// N dt by N dm
54    pub fn shape(&self) -> (usize, usize) {
55        (self.dt_grid.cell_count(), self.dm_grid.cell_count())
56    }
57
58    /// Represents each pair of (t, m) points as a unity value in dm-dt map
59    ///
60    /// `t` must be an ascending slice
61    pub fn points(&self, t: &[T], m: &[T]) -> Array2<u64> {
62        let mut a = Array2::zeros(self.shape());
63        for (i1, (&x1, &y1)) in t.iter().zip(m.iter()).enumerate() {
64            for (&x2, &y2) in t[i1 + 1..].iter().zip(m[i1 + 1..].iter()) {
65                let dt = x2 - x1;
66                let idx_dt = match self.dt_grid.idx(dt) {
67                    CellIndex::LowerMin => continue,
68                    CellIndex::GreaterMax => break,
69                    CellIndex::Value(idx_dt) => idx_dt,
70                };
71                let dm = y2 - y1;
72                let idx_dm = match self.dm_grid.idx(dm) {
73                    CellIndex::Value(idx_dm) => idx_dm,
74                    CellIndex::LowerMin | CellIndex::GreaterMax => continue,
75                };
76                a[(idx_dt, idx_dm)] += 1;
77            }
78        }
79        a
80    }
81
82    fn update_gausses_helper<Erf>(
83        &self,
84        a: &mut Array2<T>,
85        idx_dt: usize,
86        y1: T,
87        y2: T,
88        d1: T,
89        d2: T,
90    ) where
91        T: ErfFloat,
92        Erf: ErrorFunction<T>,
93    {
94        let dm = y2 - y1;
95        let dm_err = T::sqrt(d1 + d2);
96
97        let min_idx_dm = match self
98            .dm_grid
99            .idx(dm + Erf::min_dx_nonzero_normal_cdf(dm_err))
100        {
101            CellIndex::LowerMin => 0,
102            CellIndex::GreaterMax => return,
103            CellIndex::Value(min_idx_dm) => min_idx_dm,
104        };
105        let max_idx_dm = match self
106            .dm_grid
107            .idx(dm + Erf::max_dx_nonunity_normal_cdf(dm_err))
108        {
109            CellIndex::LowerMin => return,
110            CellIndex::GreaterMax => self.dm_grid.cell_count(),
111            CellIndex::Value(i) => usize::min(i + 1, self.dm_grid.cell_count()),
112        };
113
114        a.slice_mut(s![idx_dt, min_idx_dm..max_idx_dm])
115            .iter_mut()
116            .zip(
117                self.dm_grid
118                    .get_borders()
119                    .slice(s![min_idx_dm..max_idx_dm + 1])
120                    .iter()
121                    .map(|&dm_border| Erf::normal_cdf(dm_border, dm, dm_err))
122                    .tuple_windows()
123                    .map(|(a, b)| b - a),
124            )
125            .for_each(|(cell, value)| *cell += value);
126    }
127
128    /// Represents each pair of (t, m, err2) points as a Gaussian distribution in dm-dt map
129    ///
130    /// `t` must be an ascending slice.
131    ///
132    /// Each observation is assumed to happen at time moment `t_i` and have Gaussian distribution of
133    /// its magnitude `N(m_i, err2_i)`. Each pair of observations
134    /// `(t_1, m_1, err2_1), (t_2, m_2, err2_2)` is represented by 1-D Gaussian in the dm-dt space
135    /// having constant `dt` and `dm ~ N(m2-m1, err2_1 + err2_2)`. This distribution is integrated
136    /// over each cell using `Erf` struct implementing [ErrorFunction].
137    pub fn gausses<Erf>(&self, t: &[T], m: &[T], err2: &[T]) -> Array2<T>
138    where
139        T: ErfFloat,
140        Erf: ErrorFunction<T>,
141    {
142        let mut a = Array2::zeros(self.shape());
143        for (i1, ((&x1, &y1), &d1)) in t.iter().zip(m.iter()).zip(err2.iter()).enumerate() {
144            for ((&x2, &y2), &d2) in t[i1 + 1..]
145                .iter()
146                .zip(m[i1 + 1..].iter())
147                .zip(err2[i1 + 1..].iter())
148            {
149                let dt = x2 - x1;
150                let idx_dt = match self.dt_grid.idx(dt) {
151                    CellIndex::LowerMin => continue,
152                    CellIndex::GreaterMax => break,
153                    CellIndex::Value(idx_dt) => idx_dt,
154                };
155                self.update_gausses_helper::<Erf>(&mut a, idx_dt, y1, y2, d1, d2);
156            }
157        }
158        a
159    }
160
161    /// Count dt in the each dt grid cell
162    pub fn dt_points(&self, t: &[T]) -> Array1<u64> {
163        let mut a = Array1::zeros(self.dt_grid.cell_count());
164        for (i1, &x1) in t.iter().enumerate() {
165            for &x2 in t[i1 + 1..].iter() {
166                let dt = x2 - x1;
167                let idx_dt = match self.dt_grid.idx(dt) {
168                    CellIndex::LowerMin => continue,
169                    CellIndex::GreaterMax => break,
170                    CellIndex::Value(idx_dt) => idx_dt,
171                };
172                a[idx_dt] += 1;
173            }
174        }
175        a
176    }
177
178    /// Conditional probability `p(m2-m1|t2-t1)`
179    ///
180    /// Technically it is optimized version of [DmDt::gausses()] normalized by [DmDt::dt_points] but
181    /// with better performance. Mathematically it represents the distribution of conditional
182    /// probability `p(m2-m1|t2-t1)`, see
183    /// [Soraisam et al. 2020](https://doi.org/10.3847/1538-4357/ab7b61) for details.
184    pub fn cond_prob<Erf>(&self, t: &[T], m: &[T], err2: &[T]) -> Array2<T>
185    where
186        T: ErfFloat,
187        Erf: ErrorFunction<T>,
188    {
189        let mut a: Array2<T> = Array2::zeros(self.shape());
190        let mut dt_points: Array1<u64> = Array1::zeros(self.dt_grid.cell_count());
191        for (i1, ((&x1, &y1), &d1)) in t.iter().zip(m.iter()).zip(err2.iter()).enumerate() {
192            for ((&x2, &y2), &d2) in t[i1 + 1..]
193                .iter()
194                .zip(m[i1 + 1..].iter())
195                .zip(err2[i1 + 1..].iter())
196            {
197                let dt = x2 - x1;
198                let idx_dt = match self.dt_grid.idx(dt) {
199                    CellIndex::LowerMin => continue,
200                    CellIndex::GreaterMax => break,
201                    CellIndex::Value(idx_dt) => idx_dt,
202                };
203
204                dt_points[idx_dt] += 1;
205
206                self.update_gausses_helper::<Erf>(&mut a, idx_dt, y1, y2, d1, d2);
207            }
208        }
209        ndarray::Zip::from(a.rows_mut())
210            .and(&dt_points)
211            .for_each(|mut row, &count| {
212                if count == 0 {
213                    return;
214                }
215                row /= T::approx_from(count).unwrap();
216            });
217        a
218    }
219}
220
#[cfg(test)]
mod test {
    use super::*;

    use crate::erf::{Eps1Over1e3Erf, ExactErf};

    use approx::assert_abs_diff_eq;
    use ndarray::Axis;
    use static_assertions::assert_impl_all;

    assert_impl_all!(DmDt<f32>: Clone, Debug, Send, Sync);
    assert_impl_all!(DmDt<f64>: Clone, Debug, Send, Sync);
    // Serialize/Deserialize are imported and derived only with the "serde" feature
    #[cfg(feature = "serde")]
    assert_impl_all!(DmDt<f32>: Serialize, Deserialize<'static>);
    #[cfg(feature = "serde")]
    assert_impl_all!(DmDt<f64>: Serialize, Deserialize<'static>);

    #[test]
    fn dt_points_vs_points() {
        let dmdt = DmDt::from_lgdt_dm_limits(0.0_f32, 2.0_f32, 32, 3.0_f32, 32);
        let t = Array1::linspace(0.0, 100.0, 101);
        // dm is within map borders
        let m = t.mapv(f32::sin);

        let points = dmdt.points(t.as_slice().unwrap(), m.as_slice().unwrap());
        let dt_points = dmdt.dt_points(t.as_slice().unwrap());

        assert_eq!(points.sum_axis(Axis(1)), dt_points);
    }

    #[test]
    fn points_ignores_extra_elements() {
        let dmdt = DmDt::from_lgdt_dm_limits(0.0_f32, 2.0_f32, 32, 3.0_f32, 32);
        let t = Array1::linspace(0.0, 100.0, 101);
        let m = t.mapv(f32::sin);
        let t = t.as_slice().unwrap();
        let m = m.as_slice().unwrap();

        assert_eq!(
            dmdt.points(t, &m[..50]),
            dmdt.points(&t[..50], &m[..50]),
        );
    }

    #[test]
    fn dt_points_vs_gausses() {
        let dmdt = DmDt::from_lgdt_dm_limits(0.0_f32, 2.0_f32, 32, 3.0_f32, 32);
        let t = Array1::linspace(0.0, 100.0, 101);
        // dm is within map borders
        let m = t.mapv(f32::sin);
        // err is ~0.03
        let err2 = Array1::from_elem(101, 0.001_f32);

        let gausses = dmdt.gausses::<ExactErf>(
            t.as_slice().unwrap(),
            m.as_slice().unwrap(),
            err2.as_slice().unwrap(),
        );
        let sum_gausses = gausses.sum_axis(Axis(1));
        let dt_points = dmdt.dt_points(t.as_slice().unwrap()).mapv(|x| x as f32);

        assert_abs_diff_eq!(
            sum_gausses.as_slice().unwrap(),
            dt_points.as_slice().unwrap(),
            epsilon = 1e-4,
        );
    }

    #[test]
    fn cond_prob() {
        let dmdt = DmDt::from_lgdt_dm_limits(0.0_f32, 2.0_f32, 32, 1.25_f32, 32);

        let t = Array1::linspace(0.0, 100.0, 101);
        let m = t.mapv(f32::sin);
        // err is ~0.03
        let err2 = Array1::from_elem(101, 0.001_f32);

        let from_gausses_dt_points = {
            let mut map = dmdt.gausses::<Eps1Over1e3Erf>(
                t.as_slice().unwrap(),
                m.as_slice().unwrap(),
                err2.as_slice().unwrap(),
            );
            let dt_points = dmdt.dt_points(t.as_slice().unwrap());
            // Rows without pairs are all-zero, so dividing them by one keeps them zero
            let dt_non_zero_points = dt_points.mapv(|x| if x == 0 { 1.0 } else { x as f32 });
            map /= &dt_non_zero_points.to_shape((map.nrows(), 1)).unwrap();
            map
        };

        let from_cond_prob = dmdt.cond_prob::<Eps1Over1e3Erf>(
            t.as_slice().unwrap(),
            m.as_slice().unwrap(),
            err2.as_slice().unwrap(),
        );

        assert_abs_diff_eq!(
            from_gausses_dt_points.as_slice().unwrap(),
            from_cond_prob.as_slice().unwrap(),
            epsilon = f32::EPSILON,
        );
    }
}