1use crate::{CellIndex, ErfFloat, ErrorFunction, Float, Grid, GridTrait, LgGrid, LinearGrid};
2
3use itertools::Itertools;
4use ndarray::{Array1, Array2, s};
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7use std::fmt::Debug;
8
9#[derive(Clone, Debug)]
11#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
12pub struct DmDt<T>
13where
14 T: Float,
15{
16 pub dt_grid: Grid<T>,
17 pub dm_grid: Grid<T>,
18}
19
20impl<T> DmDt<T>
21where
22 T: Float,
23{
24 pub fn from_grids<Gdt, Gdm>(dt_grid: Gdt, dm_grid: Gdm) -> Self
26 where
27 Gdt: Into<Grid<T>>,
28 Gdm: Into<Grid<T>>,
29 {
30 Self {
31 dt_grid: dt_grid.into(),
32 dm_grid: dm_grid.into(),
33 }
34 }
35
36 pub fn from_lgdt_dm_limits(
41 min_lgdt: T,
42 max_lgdt: T,
43 lgdt_size: usize,
44 max_abs_dm: T,
45 dm_size: usize,
46 ) -> Self {
47 Self::from_grids(
48 LgGrid::from_lg_start_end(min_lgdt, max_lgdt, lgdt_size),
49 LinearGrid::new(-max_abs_dm, max_abs_dm, dm_size),
50 )
51 }
52
53 pub fn shape(&self) -> (usize, usize) {
55 (self.dt_grid.cell_count(), self.dm_grid.cell_count())
56 }
57
58 pub fn points(&self, t: &[T], m: &[T]) -> Array2<u64> {
62 let mut a = Array2::zeros(self.shape());
63 for (i1, (&x1, &y1)) in t.iter().zip(m.iter()).enumerate() {
64 for (&x2, &y2) in t[i1 + 1..].iter().zip(m[i1 + 1..].iter()) {
65 let dt = x2 - x1;
66 let idx_dt = match self.dt_grid.idx(dt) {
67 CellIndex::LowerMin => continue,
68 CellIndex::GreaterMax => break,
69 CellIndex::Value(idx_dt) => idx_dt,
70 };
71 let dm = y2 - y1;
72 let idx_dm = match self.dm_grid.idx(dm) {
73 CellIndex::Value(idx_dm) => idx_dm,
74 CellIndex::LowerMin | CellIndex::GreaterMax => continue,
75 };
76 a[(idx_dt, idx_dm)] += 1;
77 }
78 }
79 a
80 }
81
82 fn update_gausses_helper<Erf>(
83 &self,
84 a: &mut Array2<T>,
85 idx_dt: usize,
86 y1: T,
87 y2: T,
88 d1: T,
89 d2: T,
90 ) where
91 T: ErfFloat,
92 Erf: ErrorFunction<T>,
93 {
94 let dm = y2 - y1;
95 let dm_err = T::sqrt(d1 + d2);
96
97 let min_idx_dm = match self
98 .dm_grid
99 .idx(dm + Erf::min_dx_nonzero_normal_cdf(dm_err))
100 {
101 CellIndex::LowerMin => 0,
102 CellIndex::GreaterMax => return,
103 CellIndex::Value(min_idx_dm) => min_idx_dm,
104 };
105 let max_idx_dm = match self
106 .dm_grid
107 .idx(dm + Erf::max_dx_nonunity_normal_cdf(dm_err))
108 {
109 CellIndex::LowerMin => return,
110 CellIndex::GreaterMax => self.dm_grid.cell_count(),
111 CellIndex::Value(i) => usize::min(i + 1, self.dm_grid.cell_count()),
112 };
113
114 a.slice_mut(s![idx_dt, min_idx_dm..max_idx_dm])
115 .iter_mut()
116 .zip(
117 self.dm_grid
118 .get_borders()
119 .slice(s![min_idx_dm..max_idx_dm + 1])
120 .iter()
121 .map(|&dm_border| Erf::normal_cdf(dm_border, dm, dm_err))
122 .tuple_windows()
123 .map(|(a, b)| b - a),
124 )
125 .for_each(|(cell, value)| *cell += value);
126 }
127
128 pub fn gausses<Erf>(&self, t: &[T], m: &[T], err2: &[T]) -> Array2<T>
138 where
139 T: ErfFloat,
140 Erf: ErrorFunction<T>,
141 {
142 let mut a = Array2::zeros(self.shape());
143 for (i1, ((&x1, &y1), &d1)) in t.iter().zip(m.iter()).zip(err2.iter()).enumerate() {
144 for ((&x2, &y2), &d2) in t[i1 + 1..]
145 .iter()
146 .zip(m[i1 + 1..].iter())
147 .zip(err2[i1 + 1..].iter())
148 {
149 let dt = x2 - x1;
150 let idx_dt = match self.dt_grid.idx(dt) {
151 CellIndex::LowerMin => continue,
152 CellIndex::GreaterMax => break,
153 CellIndex::Value(idx_dt) => idx_dt,
154 };
155 self.update_gausses_helper::<Erf>(&mut a, idx_dt, y1, y2, d1, d2);
156 }
157 }
158 a
159 }
160
161 pub fn dt_points(&self, t: &[T]) -> Array1<u64> {
163 let mut a = Array1::zeros(self.dt_grid.cell_count());
164 for (i1, &x1) in t.iter().enumerate() {
165 for &x2 in t[i1 + 1..].iter() {
166 let dt = x2 - x1;
167 let idx_dt = match self.dt_grid.idx(dt) {
168 CellIndex::LowerMin => continue,
169 CellIndex::GreaterMax => break,
170 CellIndex::Value(idx_dt) => idx_dt,
171 };
172 a[idx_dt] += 1;
173 }
174 }
175 a
176 }
177
178 pub fn cond_prob<Erf>(&self, t: &[T], m: &[T], err2: &[T]) -> Array2<T>
185 where
186 T: ErfFloat,
187 Erf: ErrorFunction<T>,
188 {
189 let mut a: Array2<T> = Array2::zeros(self.shape());
190 let mut dt_points: Array1<u64> = Array1::zeros(self.dt_grid.cell_count());
191 for (i1, ((&x1, &y1), &d1)) in t.iter().zip(m.iter()).zip(err2.iter()).enumerate() {
192 for ((&x2, &y2), &d2) in t[i1 + 1..]
193 .iter()
194 .zip(m[i1 + 1..].iter())
195 .zip(err2[i1 + 1..].iter())
196 {
197 let dt = x2 - x1;
198 let idx_dt = match self.dt_grid.idx(dt) {
199 CellIndex::LowerMin => continue,
200 CellIndex::GreaterMax => break,
201 CellIndex::Value(idx_dt) => idx_dt,
202 };
203
204 dt_points[idx_dt] += 1;
205
206 self.update_gausses_helper::<Erf>(&mut a, idx_dt, y1, y2, d1, d2);
207 }
208 }
209 ndarray::Zip::from(a.rows_mut())
210 .and(&dt_points)
211 .for_each(|mut row, &count| {
212 if count == 0 {
213 return;
214 }
215 row /= T::approx_from(count).unwrap();
216 });
217 a
218 }
219}
220
#[cfg(test)]
mod test {
    use super::*;

    use crate::erf::{Eps1Over1e3Erf, ExactErf};

    use approx::assert_abs_diff_eq;
    use ndarray::Axis;
    use static_assertions::assert_impl_all;

    assert_impl_all!(DmDt<f32>: Clone, Debug, Send, Sync);
    assert_impl_all!(DmDt<f64>: Clone, Debug, Send, Sync);

    // `Serialize`/`Deserialize` are imported and derived only with the `serde` feature, so
    // these checks are gated the same way; otherwise the tests do not build without it.
    #[cfg(feature = "serde")]
    assert_impl_all!(DmDt<f32>: Serialize, Deserialize<'static>);
    #[cfg(feature = "serde")]
    assert_impl_all!(DmDt<f64>: Serialize, Deserialize<'static>);

    /// Summing pair counts over `dm` must give the per-`dt` pair counts.
    #[test]
    fn dt_points_vs_points() {
        let dmdt = DmDt::from_lgdt_dm_limits(0.0_f32, 2.0_f32, 32, 3.0_f32, 32);
        let t = Array1::linspace(0.0, 100.0, 101);
        let m = t.mapv(f32::sin);

        let points = dmdt.points(t.as_slice().unwrap(), m.as_slice().unwrap());
        let dt_points = dmdt.dt_points(t.as_slice().unwrap());

        assert_eq!(points.sum_axis(Axis(1)), dt_points);
    }

    /// With small errors and a wide `dm` grid, every pair's smeared mass sums to ~1.
    #[test]
    fn dt_points_vs_gausses() {
        let dmdt = DmDt::from_lgdt_dm_limits(0.0_f32, 2.0_f32, 32, 3.0_f32, 32);
        let t = Array1::linspace(0.0, 100.0, 101);
        let m = t.mapv(f32::sin);
        let err2 = Array1::from_elem(101, 0.001_f32);

        let gausses = dmdt.gausses::<ExactErf>(
            t.as_slice().unwrap(),
            m.as_slice().unwrap(),
            err2.as_slice().unwrap(),
        );
        let sum_gausses = gausses.sum_axis(Axis(1));
        let dt_points = dmdt.dt_points(t.as_slice().unwrap()).mapv(|x| x as f32);

        assert_abs_diff_eq!(
            sum_gausses.as_slice().unwrap(),
            dt_points.as_slice().unwrap(),
            epsilon = 1e-4,
        );
    }

    /// `cond_prob` must equal `gausses` normalised row-wise by `dt_points`.
    #[test]
    fn cond_prob() {
        let dmdt = DmDt::from_lgdt_dm_limits(0.0_f32, 2.0_f32, 32, 1.25_f32, 32);

        let t = Array1::linspace(0.0, 100.0, 101);
        let m = t.mapv(f32::sin);
        let err2 = Array1::from_elem(101, 0.001);

        let from_gausses_dt_points = {
            let mut map = dmdt.gausses::<Eps1Over1e3Erf>(
                t.as_slice().unwrap(),
                m.as_slice().unwrap(),
                err2.as_slice().unwrap(),
            );
            let dt_points = dmdt.dt_points(t.as_slice().unwrap());
            // Empty rows are divided by one so they stay zero, matching `cond_prob`.
            let dt_non_zero_points = dt_points.mapv(|x| if x == 0 { 1.0 } else { x as f32 });
            map /= &dt_non_zero_points.to_shape((map.nrows(), 1)).unwrap();
            map
        };

        let from_cond_prob = dmdt.cond_prob::<Eps1Over1e3Erf>(
            t.as_slice().unwrap(),
            m.as_slice().unwrap(),
            err2.as_slice().unwrap(),
        );

        assert_abs_diff_eq!(
            from_gausses_dt_points.as_slice().unwrap(),
            from_cond_prob.as_slice().unwrap(),
            epsilon = f32::EPSILON,
        );
    }
}