tea_map/
vec_map.rs

1use std::ops::Sub;
2
3use tea_core::prelude::*;
4/// Trait for vector-like types that support map operations on valid elements.
5///
6/// This trait provides methods for performing various operations on vectors,
7/// such as calculating differences, percentage changes, rankings, and partitions.
8pub trait MapValidVec<T: IsNone>: Vec1View<T> {
9    /// Calculates the difference between elements in the vector.
10    ///
11    /// # Arguments
12    ///
13    /// * `n` - The lag for the difference calculation. Positive values look forward, negative values look backward.
14    /// * `value` - The value to use for padding when there are not enough elements.
15    ///
16    /// # Returns
17    ///
18    /// A boxed iterator of differences.
19    fn vdiff<'a>(&'a self, n: i32, value: Option<T>) -> Box<dyn TrustedLen<Item = T> + 'a>
20    where
21        T: Clone + Sub<Output = T> + Zero + 'a,
22        Self: 'a,
23    {
24        let len = self.len();
25        let n_abs = n.unsigned_abs() as usize;
26        let value = value.unwrap_or_else(|| T::none());
27        if len <= n_abs {
28            return Box::new(std::iter::repeat(value).take(len));
29        }
30        match n {
31            n if n > 0 => Box::new(
32                std::iter::repeat(value)
33                    .take(n_abs)
34                    .chain(self.titer().take(len - n_abs))
35                    .zip(self.titer())
36                    .map(|(a, b)| b - a)
37                    .to_trust(len),
38            ),
39            n if n < 0 => Box::new(
40                self.titer()
41                    .skip(n_abs)
42                    .zip(self.titer())
43                    .map(|(a, b)| b - a)
44                    .chain(std::iter::repeat(value).take(n_abs))
45                    .to_trust(len),
46            ),
47            _ => Box::new(std::iter::repeat(T::zero()).take(len).to_trust(len)),
48        }
49    }
50
51    /// Calculates the percentage change between elements in the vector.
52    ///
53    /// # Arguments
54    ///
55    /// * `n` - The lag for the percentage change calculation. Positive values look forward, negative values look backward.
56    ///
57    /// # Returns
58    ///
59    /// A boxed iterator of percentage changes.
60    fn vpct_change<'a>(&'a self, n: i32) -> Box<dyn TrustedLen<Item = f64> + 'a>
61    where
62        T: Clone + Cast<f64> + 'a,
63        Self: 'a,
64    {
65        let len = self.len();
66        let n_abs = n.unsigned_abs() as usize;
67        if len <= n_abs {
68            return Box::new(std::iter::repeat(f64::NAN).take(len));
69        }
70        match n {
71            n if n > 0 => Box::new(
72                std::iter::repeat(f64::NAN)
73                    .take(n_abs)
74                    .chain(self.titer().take(len - n_abs).map(|v| v.cast()))
75                    .zip(self.titer())
76                    .map(|(a, b)| {
77                        if a.not_none() && b.not_none() && (a != 0.) {
78                            b.cast() / a - 1.
79                        } else {
80                            f64::NAN
81                        }
82                    })
83                    .to_trust(len),
84            ),
85            n if n < 0 => Box::new(
86                self.titer()
87                    .skip(n_abs)
88                    .zip(self.titer())
89                    .map(|(a, b)| {
90                        if a.not_none() && b.not_none() {
91                            let a: f64 = a.cast();
92                            if a != 0. {
93                                b.cast() / a - 1.
94                            } else {
95                                f64::NAN
96                            }
97                        } else {
98                            f64::NAN
99                        }
100                    })
101                    .chain(std::iter::repeat(f64::NAN).take(n_abs))
102                    .to_trust(len),
103            ),
104            _ => Box::new(std::iter::repeat(0.).take(len).to_trust(len)),
105        }
106    }
107
108    /// Calculates the rank of elements in the vector.
109    ///
110    /// # Arguments
111    ///
112    /// * `pct` - If true, returns percentage ranks.
113    /// * `rev` - If true, ranks in descending order.
114    ///
115    /// # Returns
116    ///
117    /// A vector of ranks.
118    fn vrank<O: Vec1<OT>, OT: IsNone>(&self, pct: bool, rev: bool) -> O
119    where
120        T: IsNone + PartialEq,
121        T::Inner: PartialOrd,
122        f64: Cast<OT>,
123    {
124        let len = self.len();
125        if len == 0 {
126            return O::empty();
127        } else if len == 1 {
128            return O::full(len, (1.).cast());
129        }
130        // argsort at first
131        let mut idx_sorted: Vec<_> = (0..len).collect_trusted_to_vec();
132        if !rev {
133            idx_sorted
134                .sort_unstable_by(|a, b| {
135                    let (va, vb) = unsafe { (self.uget(*a), self.uget(*b)) }; // safety: out不超过self的长度
136                    va.sort_cmp(&vb)
137                })
138                .unwrap();
139        } else {
140            idx_sorted
141                .sort_unstable_by(|a, b| {
142                    let (va, vb) = unsafe { (self.uget(*a), self.uget(*b)) }; // safety: out不超过self的长度
143                    va.sort_cmp_rev(&vb)
144                })
145                .unwrap();
146        }
147        // if the first value is none then all the elements are none
148        if unsafe { self.uget(idx_sorted.uget(0)) }.is_none() {
149            return O::full(len, OT::none());
150        }
151        let mut out = O::uninit(len);
152        let mut repeat_num = 1usize;
153        let mut nan_flag = false;
154        let (mut cur_rank, mut sum_rank) = (1usize, 0usize);
155        let mut idx: usize = 0;
156        let mut idx1: usize;
157        if !pct {
158            unsafe {
159                for i in 0..len - 1 {
160                    // safe because max of i = len-2 and len >= 2
161                    (idx, idx1) = (idx_sorted.uget(i), idx_sorted.uget(i + 1));
162                    let (v, v1) = (self.uget(idx), self.uget(idx1)); // next_value
163                    if v1.is_none() {
164                        // next value is none, so remain values are none
165                        sum_rank += cur_rank;
166                        cur_rank += 1;
167                        for j in 0..repeat_num {
168                            // safe because i >= repeat_num
169                            out.uset(
170                                idx_sorted.uget(i - j),
171                                (sum_rank.f64() / repeat_num.f64()).cast(),
172                            );
173                        }
174                        idx = i + 1;
175                        nan_flag = true;
176                        break;
177                    } else if v == v1 {
178                        // current value is the same with next value, repeating
179                        repeat_num += 1;
180                        sum_rank += cur_rank;
181                        cur_rank += 1;
182                    } else if repeat_num == 1 {
183                        // no repeat, can get the rank directly
184                        out.uset(idx, (cur_rank as f64).cast());
185                        cur_rank += 1;
186                    } else {
187                        // current element is the last repeated value
188                        sum_rank += cur_rank;
189                        cur_rank += 1;
190                        for j in 0..repeat_num {
191                            // safe because i >= repeat_num
192                            out.uset(
193                                idx_sorted.uget(i - j),
194                                (sum_rank.f64() / repeat_num.f64()).cast(),
195                            );
196                        }
197                        sum_rank = 0;
198                        repeat_num = 1;
199                    }
200                }
201                if nan_flag {
202                    for i in idx..len {
203                        out.uset(idx_sorted.uget(i), f64::NAN.cast())
204                    }
205                } else {
206                    sum_rank += cur_rank;
207                    for i in len - repeat_num..len {
208                        // safe because repeat_num <= len
209                        out.uset(
210                            idx_sorted.uget(i),
211                            (sum_rank.f64() / repeat_num.f64()).cast(),
212                        )
213                    }
214                }
215            }
216        } else {
217            let not_none_count = self.titer().count_valid();
218            unsafe {
219                for i in 0..len - 1 {
220                    // safe because max of i = len-2 and len >= 2
221                    (idx, idx1) = (idx_sorted.uget(i), idx_sorted.uget(i + 1));
222                    let (v, v1) = (self.uget(idx), self.uget(idx1)); // next_value
223                    if v1.is_none() {
224                        // next value is none, so remain values are none
225                        sum_rank += cur_rank;
226                        cur_rank += 1;
227                        for j in 0..repeat_num {
228                            // safe because i >= repeat_num
229                            out.uset(
230                                idx_sorted.uget(i - j),
231                                (sum_rank.f64() / (repeat_num * not_none_count).f64()).cast(),
232                            );
233                        }
234                        idx = i + 1;
235                        nan_flag = true;
236                        break;
237                    } else if v == v1 {
238                        // current value is the same with next value, repeating
239                        repeat_num += 1;
240                        sum_rank += cur_rank;
241                        cur_rank += 1;
242                    } else if repeat_num == 1 {
243                        // no repeat, can get the rank directly
244                        out.uset(idx, (cur_rank as f64 / not_none_count as f64).cast());
245                        cur_rank += 1;
246                    } else {
247                        // current element is the last repeated value
248                        sum_rank += cur_rank;
249                        cur_rank += 1;
250                        for j in 0..repeat_num {
251                            // safe because i >= repeat_num
252                            out.uset(
253                                idx_sorted.uget(i - j),
254                                (sum_rank.f64() / (repeat_num * not_none_count).f64()).cast(),
255                            );
256                        }
257                        sum_rank = 0;
258                        repeat_num = 1;
259                    }
260                }
261                if nan_flag {
262                    for i in idx..len {
263                        out.uset(idx_sorted.uget(i), f64::NAN.cast())
264                    }
265                } else {
266                    sum_rank += cur_rank;
267                    for i in len - repeat_num..len {
268                        // safe because repeat_num <= len
269                        out.uset(
270                            idx_sorted.uget(i),
271                            (sum_rank.f64() / (repeat_num * not_none_count).f64()).cast(),
272                        )
273                    }
274                }
275            }
276        }
277        unsafe { out.assume_init() }
278    }
279
280    /// Returns the indices of the kth smallest elements.
281    ///
282    /// # Arguments
283    ///
284    /// * `kth` - The k-th smallest element to find.
285    /// * `sort` - If true, sort the result.
286    /// * `rev` - If true, find the kth largest elements instead.
287    ///
288    /// # Returns
289    ///
290    /// A boxed iterator of indices.
291    fn varg_partition<'a>(
292        &'a self,
293        kth: usize,
294        sort: bool,
295        rev: bool,
296    ) -> Box<dyn TrustedLen<Item = i32> + 'a>
297    where
298        T::Inner: Number,
299        T: 'a,
300    {
301        let n = self.titer().count_valid();
302        // fast path for n <= kth + 1
303        if n <= kth + 1 {
304            if !sort {
305                return Box::new(
306                    self.titer()
307                        .enumerate()
308                        .filter_map(|(i, v)| if v.not_none() { Some(i as i32) } else { None })
309                        .chain(std::iter::repeat(-1))
310                        .take(kth + 1)
311                        .to_trust(kth + 1),
312                );
313            } else {
314                let mut idx_sorted: Vec<_> = Vec1Create::range(None, self.len() as i32, None);
315                if !rev {
316                    idx_sorted
317                        .sort_unstable_by(|a: &i32, b: &i32| {
318                            let (va, vb) =
319                                unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
320                            va.sort_cmp(&vb)
321                        })
322                        .unwrap()
323                } else {
324                    idx_sorted
325                        .sort_unstable_by(|a: &i32, b: &i32| {
326                            let (va, vb) =
327                                unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
328                            va.sort_cmp_rev(&vb)
329                        })
330                        .unwrap()
331                }
332                return Box::new(
333                    idx_sorted
334                        .into_iter()
335                        .take(n)
336                        .chain(std::iter::repeat(-1))
337                        .take(kth + 1)
338                        .to_trust(kth + 1),
339                );
340            }
341        }
342        let mut out_c: Vec<_> = self.titer().collect_trusted_vec1(); // clone the array
343        let slc = out_c.try_as_slice_mut().unwrap();
344        let mut idx_sorted: Vec<_> = Vec1Create::range(None, slc.len() as i32, None);
345        if !rev {
346            let sort_func = |a: &i32, b: &i32| {
347                let (va, vb) = unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
348                va.sort_cmp(&vb)
349            };
350            idx_sorted.select_nth_unstable_by(kth, sort_func);
351            idx_sorted.truncate(kth + 1);
352            if sort {
353                idx_sorted.sort_unstable_by(sort_func).unwrap();
354            }
355            Box::new(idx_sorted.into_iter().to_trust(kth + 1))
356        } else {
357            let sort_func = |a: &i32, b: &i32| {
358                let (va, vb) = unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
359                va.sort_cmp_rev(&vb)
360            };
361            idx_sorted.select_nth_unstable_by(kth, sort_func);
362            idx_sorted.truncate(kth + 1);
363            if sort {
364                idx_sorted.sort_unstable_by(sort_func).unwrap();
365            }
366            Box::new(idx_sorted.into_iter().to_trust(kth + 1))
367        }
368    }
369    /// sort: whether to sort the result by the size of the element
370    fn vpartition<'a>(
371        &'a self,
372        kth: usize,
373        sort: bool,
374        rev: bool,
375    ) -> Box<dyn TrustedLen<Item = T> + 'a>
376    where
377        T::Inner: PartialOrd,
378        T: 'a,
379    {
380        let n = self.titer().count_valid();
381        if (n == kth + 1) && !sort {
382            return Box::new(self.titer().filter(IsNone::not_none).to_trust(kth + 1));
383        }
384        if n <= kth + 1 {
385            if !sort {
386                return Box::new(
387                    self.titer()
388                        .filter(IsNone::not_none)
389                        .chain(std::iter::repeat(T::none()))
390                        .take(kth + 1)
391                        .to_trust(kth + 1),
392                );
393            } else {
394                let mut vec: Vec<_> = self.titer().collect_trusted_vec1(); // clone the array
395                if !rev {
396                    vec.sort_unstable_by(|a, b| a.sort_cmp(b)).unwrap();
397                } else {
398                    vec.sort_unstable_by(|a, b| a.sort_cmp_rev(b)).unwrap();
399                }
400                return Box::new(vec.into_iter().take(kth + 1));
401            }
402        }
403        let mut out_c: Vec<_> = self.titer().collect_trusted_vec1(); // clone the array
404        let sort_func = if !rev { T::sort_cmp } else { T::sort_cmp_rev };
405        out_c.select_nth_unstable_by(kth, sort_func);
406        out_c.truncate(kth + 1);
407        if sort {
408            out_c.sort_unstable_by(sort_func).unwrap();
409        }
410        Box::new(out_c.into_iter().to_trust(kth + 1))
411    }
412}
413
414impl<T: IsNone, I: Vec1View<T>> MapValidVec<T> for I {}
415
#[cfg(test)]
mod test {
    use tea_core::testing::assert_vec1d_equal_numeric;

    use super::*;

    #[test]
    fn test_diff() {
        // an empty input yields an empty output
        let empty: Vec<f64> = Vec::new();
        let out: Vec<_> = empty.vdiff(2, None).collect_trusted_vec1();
        assert_eq!(out, vec![]);

        let data = vec![4., 1., 12., 4.];
        let lag1: Vec<_> = data.vdiff(1, None).collect_trusted_vec1();
        assert_vec1d_equal_numeric(&lag1, &vec![f64::NAN, -3., 11., -8.], None);
        let lead1: Vec<_> = data.vdiff(-1, Some(0.)).collect_trusted_vec1();
        assert_eq!(lead1, vec![3., -11., 8., 0.]);
    }

    #[test]
    fn test_pct_change() {
        // an empty input yields an empty output
        let empty: Vec<f64> = Vec::new();
        let out: Vec<_> = empty.vpct_change(2).collect_trusted_vec1();
        assert_eq!(out, vec![]);

        let data = vec![1., 2., 3., 4.5];
        let lag1: Vec<_> = data.vpct_change(1).collect_trusted_vec1();
        assert_vec1d_equal_numeric(&lag1, &vec![f64::NAN, 1., 0.5, 0.5], None);
        let lead1: Vec<_> = data.vpct_change(-1).collect_trusted_vec1();
        assert_vec1d_equal_numeric(&lead1, &vec![-0.5, -1. / 3., -1. / 3., f64::NAN], None);
    }

    #[test]
    fn test_rank() {
        let data = vec![2., 1., f64::NAN, 3., 1.];

        let ascending: Vec<f64> = data.vrank(false, false);
        assert_vec1d_equal_numeric(&ascending, &vec![3., 1.5, f64::NAN, 4., 1.5], None);

        let descending: Vec<Option<f64>> = data.vrank(false, true);
        assert_vec1d_equal_numeric(
            &descending,
            &vec![Some(2.), Some(3.5), None, Some(1.), Some(3.5)],
            None,
        );
    }

    #[test]
    fn test_partition() {
        let ints = vec![1, 3, 5, 1, 5, 6, 7, 32, 1];

        let smallest_idx: Vec<_> = ints.varg_partition(3, true, false).collect();
        assert_eq!(smallest_idx, vec![0, 3, 8, 1]);
        let smallest: Vec<_> = ints.vpartition(3, true, false).collect();
        assert_eq!(smallest, vec![1, 1, 1, 3]);

        let largest_idx: Vec<_> = ints.varg_partition(3, true, true).collect();
        assert_eq!(largest_idx, vec![7, 6, 5, 2]);
        let largest: Vec<_> = ints.vpartition(3, true, true).collect();
        assert_eq!(largest, vec![32, 7, 6, 5]);

        // fewer valid values than requested: results are padded
        let floats = vec![1., f64::NAN, 3., f64::NAN, f64::NAN];
        let padded_idx = floats.varg_partition(2, true, true).collect_trusted_to_vec();
        assert_eq!(padded_idx, vec![2, 0, -1]);
        let padded = floats.vpartition(2, true, true).collect_trusted_to_vec();
        assert_vec1d_equal_numeric(&padded, &vec![3., 1., f64::NAN], None)
    }
}
479}