1mod vec_valid;
2
3use tea_core::prelude::*;
4pub use vec_valid::*;
5
/// Selects how `vpercentile_of` treats elements that are exactly equal to
/// the score when computing its percentile rank.
///
/// The default is [`PercentileOfMethod::Rank`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PercentileOfMethod {
    /// Use the average rank of the elements equal to the score: with several
    /// ties the result is the mean of their ranks divided by the number of
    /// valid elements; otherwise it is the fraction of valid elements that
    /// are less than or equal to the score.
    #[default]
    Rank,
    /// Fraction of valid elements that are less than or equal to the score.
    Weak,
    /// Fraction of valid elements that are strictly less than the score.
    Strict,
}
13
14pub trait AggValidExt<T: IsNone>: IntoIterator<Item = T> + Sized {
16    #[inline]
26    fn n_vsum_filter<U, I>(self, mask: I) -> (usize, T::Inner)
27    where
28        I: IntoIterator<Item = U>,
29        U: IsNone,
30        U::Inner: Cast<bool>,
31        T::Inner: Number,
32    {
33        self.into_iter()
34            .zip(mask)
35            .filter_map(|(v, flag)| {
36                if flag.not_none() {
37                    if flag.unwrap().cast() { Some(v) } else { None }
38                } else {
39                    None
40                }
41            })
42            .vfold_n(T::Inner::zero(), |acc, x| acc + x)
43    }
44
45    #[inline]
55    fn n_sum_filter<U, I>(self, mask: I) -> Option<T::Inner>
56    where
57        I: IntoIterator<Item = U>,
58        U: IsNone,
59        U::Inner: Cast<bool>,
60        T::Inner: Number,
61    {
62        let (n, sum) = self.n_vsum_filter(mask);
63        if n > 0 { Some(sum) } else { None }
64    }
65
66    #[inline]
77    fn vmean_filter<U, I>(self, mask: I, min_periods: usize) -> f64
78    where
79        I: IntoIterator<Item = U>,
80        U: IsNone,
81        U::Inner: Cast<bool>,
82        T::Inner: Number,
83    {
84        let (n, sum) = self.n_vsum_filter(mask);
85        if n >= min_periods {
86            sum.f64() / n.f64()
87        } else {
88            f64::NAN
89        }
90    }
91
92    fn vkurt(self, min_periods: usize) -> f64
102    where
103        T::Inner: Number,
104    {
105        let (mut m1, mut m2, mut m3, mut m4) = (0., 0., 0., 0.);
106        let n = self.vapply_n(|v| {
107            let v = v.f64();
108            m1 += v;
109            let v2 = v * v;
110            m2 += v2;
111            m3 += v2 * v;
112            m4 += v2 * v2;
113        });
114        if n < min_periods {
115            return f64::NAN;
116        }
117        let mut res = if n >= 4 {
118            let n_f64 = n.f64();
119            m1 /= n_f64; m2 /= n_f64; let var = m2 - m1.powi(2);
122            if var <= EPS {
123                0.
124            } else {
125                let var2 = var.powi(2); m4 /= n_f64; m3 /= n_f64; let mean2_var = m1.powi(2) / var; (m4 - 4. * m1 * m3) / var2 + 6. * mean2_var + 3. * mean2_var.powi(2)
130            }
131        } else {
132            f64::NAN
133        };
134        if res.not_none() && res != 0. {
135            res = 1. / ((n - 2) * (n - 3)).f64()
136                * ((n.pow(2) - 1).f64() * res - (3 * (n - 1).pow(2)).f64())
137        }
138        res
139    }
140
141    fn vpercentile_of(self, score: T, method: PercentileOfMethod) -> f64
155    where
156        T::Inner: Number + PartialOrd,
157        T: IsNone,
158    {
159        let (mut less_than_count, mut exact_match_count, mut total_count) = (0, 0, 0);
160        let score = if score.is_none() {
161            return f64::NAN;
162        } else {
163            score.unwrap()
164        };
165        self.into_iter().for_each(|v| {
166            if let Some(value) = v.to_opt() {
167                total_count += 1;
168                if value < score {
169                    less_than_count += 1;
170                } else if value == score {
171                    exact_match_count += 1;
172                }
173            }
174        });
175
176        if total_count == 0 {
177            return f64::NAN;
178        }
179
180        let less_equal_count = less_than_count + exact_match_count;
181
182        match method {
183            PercentileOfMethod::Rank => {
184                if exact_match_count > 1 {
185                    let rank_start = less_than_count + 1;
186                    let rank_end = rank_start + (exact_match_count - 1);
187                    ((rank_start + rank_end).f64() * 0.5) / total_count.f64()
188                } else {
189                    (less_than_count + exact_match_count).f64() / total_count.f64()
190                }
191            },
192            PercentileOfMethod::Weak => less_equal_count.f64() / total_count.f64(),
193            PercentileOfMethod::Strict => less_than_count.f64() / total_count.f64(),
194        }
195    }
196}
197
198impl<I: IntoIterator<Item = T>, T: IsNone> AggValidExt<T> for I {}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `vpercentile_of` for every method, including the empty-input
    /// and missing-value (NaN) cases.
    #[test]
    fn test_vpercentile_of() {
        // No valid observations at all: the result is NaN.
        let empty: [i32; 0] = [];
        assert!(empty.vpercentile_of(2, Default::default()).is_nan());

        // Default method (`Rank`) with a single match: 3 of 4 values are <= 3.
        let distinct = vec![1, 2, 3, 4];
        assert_eq!(distinct.vpercentile_of(3, Default::default()), 0.75);

        // Two ties at the score: ranks 3 and 4 average to 3.5, and 3.5 / 5 = 0.7.
        let with_ties = [1, 2, 3, 3, 4];
        assert_eq!(with_ties.vpercentile_of(3, Default::default()), 0.7);
        // Strict: only the two values below the score count.
        assert_eq!(with_ties.vpercentile_of(3, PercentileOfMethod::Strict), 0.4);
        // Weak: the two values below plus both ties count.
        assert_eq!(with_ties.vpercentile_of(3, PercentileOfMethod::Weak), 0.8);

        // NaN entries are skipped: 7 valid values, tied ranks 3..=5 average to 4.
        let with_nan = [1., f64::NAN, 2., f64::NAN, 3., 3., 3., 4., 5.];
        assert_eq!(with_nan.vpercentile_of(3., Default::default()), 4. / 7.);
    }
}
221}