tea_map/
vec_map.rs

1use std::ops::Sub;
2
3use tea_core::prelude::*;
4/// Trait for vector-like types that support map operations on valid elements.
5///
6/// This trait provides methods for performing various operations on vectors,
7/// such as calculating differences, percentage changes, rankings, and partitions.
8pub trait MapValidVec<T: IsNone>: Vec1View<T> {
9    /// Calculates the difference between elements in the vector.
10    ///
11    /// # Arguments
12    ///
13    /// * `n` - The lag for the difference calculation. Positive values look forward, negative values look backward.
14    /// * `value` - The value to use for padding when there are not enough elements.
15    ///
16    /// # Returns
17    ///
18    /// A boxed iterator of differences.
19    fn vdiff<'a>(&'a self, n: i32, value: Option<T>) -> Box<dyn TrustedLen<Item = T> + 'a>
20    where
21        T: Clone + Sub<Output = T> + Zero + 'a,
22        Self: 'a,
23    {
24        let len = self.len();
25        let n_abs = n.unsigned_abs() as usize;
26        let value = value.unwrap_or_else(|| T::none());
27        if len <= n_abs {
28            return Box::new(std::iter::repeat_n(value, len));
29        }
30        match n {
31            n if n > 0 => Box::new(
32                std::iter::repeat_n(value, n_abs)
33                    .chain(self.titer().take(len - n_abs))
34                    .zip(self.titer())
35                    .map(|(a, b)| b - a)
36                    .to_trust(len),
37            ),
38            n if n < 0 => Box::new(
39                self.titer()
40                    .skip(n_abs)
41                    .zip(self.titer())
42                    .map(|(a, b)| b - a)
43                    .chain(std::iter::repeat_n(value, n_abs))
44                    .to_trust(len),
45            ),
46            _ => Box::new(std::iter::repeat_n(T::zero(), len).to_trust(len)),
47        }
48    }
49
50    /// Calculates the percentage change between elements in the vector.
51    ///
52    /// # Arguments
53    ///
54    /// * `n` - The lag for the percentage change calculation. Positive values look forward, negative values look backward.
55    ///
56    /// # Returns
57    ///
58    /// A boxed iterator of percentage changes.
59    fn vpct_change<'a>(&'a self, n: i32) -> Box<dyn TrustedLen<Item = f64> + 'a>
60    where
61        T: Clone + Cast<f64> + 'a,
62        Self: 'a,
63    {
64        let len = self.len();
65        let n_abs = n.unsigned_abs() as usize;
66        if len <= n_abs {
67            return Box::new(std::iter::repeat_n(f64::NAN, len));
68        }
69        match n {
70            n if n > 0 => Box::new(
71                std::iter::repeat_n(f64::NAN, n_abs)
72                    .chain(self.titer().take(len - n_abs).map(|v| v.cast()))
73                    .zip(self.titer())
74                    .map(|(a, b)| {
75                        if a.not_none() && b.not_none() && (a != 0.) {
76                            b.cast() / a - 1.
77                        } else {
78                            f64::NAN
79                        }
80                    })
81                    .to_trust(len),
82            ),
83            n if n < 0 => Box::new(
84                self.titer()
85                    .skip(n_abs)
86                    .zip(self.titer())
87                    .map(|(a, b)| {
88                        if a.not_none() && b.not_none() {
89                            let a: f64 = a.cast();
90                            if a != 0. { b.cast() / a - 1. } else { f64::NAN }
91                        } else {
92                            f64::NAN
93                        }
94                    })
95                    .chain(std::iter::repeat_n(f64::NAN, n_abs))
96                    .to_trust(len),
97            ),
98            _ => Box::new(std::iter::repeat_n(0., len).to_trust(len)),
99        }
100    }
101
102    /// Calculates the rank of elements in the vector.
103    ///
104    /// # Arguments
105    ///
106    /// * `pct` - If true, returns percentage ranks.
107    /// * `rev` - If true, ranks in descending order.
108    ///
109    /// # Returns
110    ///
111    /// A vector of ranks.
112    fn vrank<O: Vec1<OT>, OT: IsNone>(&self, pct: bool, rev: bool) -> O
113    where
114        T: IsNone + PartialEq,
115        T::Inner: PartialOrd,
116        f64: Cast<OT>,
117    {
118        let len = self.len();
119        if len == 0 {
120            return O::empty();
121        } else if len == 1 {
122            return O::full(len, (1.).cast());
123        }
124        // argsort at first
125        let mut idx_sorted: Vec<_> = (0..len).collect_trusted_to_vec();
126        if !rev {
127            idx_sorted
128                .sort_unstable_by(|a, b| {
129                    let (va, vb) = unsafe { (self.uget(*a), self.uget(*b)) }; // safety: out不超过self的长度
130                    va.sort_cmp(&vb)
131                })
132                .unwrap();
133        } else {
134            idx_sorted
135                .sort_unstable_by(|a, b| {
136                    let (va, vb) = unsafe { (self.uget(*a), self.uget(*b)) }; // safety: out不超过self的长度
137                    va.sort_cmp_rev(&vb)
138                })
139                .unwrap();
140        }
141        // if the first value is none then all the elements are none
142        if unsafe { self.uget(idx_sorted.uget(0)) }.is_none() {
143            return O::full(len, OT::none());
144        }
145        let mut out = O::uninit(len);
146        let mut repeat_num = 1usize;
147        let mut nan_flag = false;
148        let (mut cur_rank, mut sum_rank) = (1usize, 0usize);
149        let mut idx: usize = 0;
150        let mut idx1: usize;
151        if !pct {
152            unsafe {
153                for i in 0..len - 1 {
154                    // safe because max of i = len-2 and len >= 2
155                    (idx, idx1) = (idx_sorted.uget(i), idx_sorted.uget(i + 1));
156                    let (v, v1) = (self.uget(idx), self.uget(idx1)); // next_value
157                    if v1.is_none() {
158                        // next value is none, so remain values are none
159                        sum_rank += cur_rank;
160                        cur_rank += 1;
161                        for j in 0..repeat_num {
162                            // safe because i >= repeat_num
163                            out.uset(
164                                idx_sorted.uget(i - j),
165                                (sum_rank.f64() / repeat_num.f64()).cast(),
166                            );
167                        }
168                        idx = i + 1;
169                        nan_flag = true;
170                        break;
171                    } else if v == v1 {
172                        // current value is the same with next value, repeating
173                        repeat_num += 1;
174                        sum_rank += cur_rank;
175                        cur_rank += 1;
176                    } else if repeat_num == 1 {
177                        // no repeat, can get the rank directly
178                        out.uset(idx, (cur_rank as f64).cast());
179                        cur_rank += 1;
180                    } else {
181                        // current element is the last repeated value
182                        sum_rank += cur_rank;
183                        cur_rank += 1;
184                        for j in 0..repeat_num {
185                            // safe because i >= repeat_num
186                            out.uset(
187                                idx_sorted.uget(i - j),
188                                (sum_rank.f64() / repeat_num.f64()).cast(),
189                            );
190                        }
191                        sum_rank = 0;
192                        repeat_num = 1;
193                    }
194                }
195                if nan_flag {
196                    for i in idx..len {
197                        out.uset(idx_sorted.uget(i), f64::NAN.cast())
198                    }
199                } else {
200                    sum_rank += cur_rank;
201                    for i in len - repeat_num..len {
202                        // safe because repeat_num <= len
203                        out.uset(
204                            idx_sorted.uget(i),
205                            (sum_rank.f64() / repeat_num.f64()).cast(),
206                        )
207                    }
208                }
209            }
210        } else {
211            let not_none_count = self.titer().count_valid();
212            unsafe {
213                for i in 0..len - 1 {
214                    // safe because max of i = len-2 and len >= 2
215                    (idx, idx1) = (idx_sorted.uget(i), idx_sorted.uget(i + 1));
216                    let (v, v1) = (self.uget(idx), self.uget(idx1)); // next_value
217                    if v1.is_none() {
218                        // next value is none, so remain values are none
219                        sum_rank += cur_rank;
220                        cur_rank += 1;
221                        for j in 0..repeat_num {
222                            // safe because i >= repeat_num
223                            out.uset(
224                                idx_sorted.uget(i - j),
225                                (sum_rank.f64() / (repeat_num * not_none_count).f64()).cast(),
226                            );
227                        }
228                        idx = i + 1;
229                        nan_flag = true;
230                        break;
231                    } else if v == v1 {
232                        // current value is the same with next value, repeating
233                        repeat_num += 1;
234                        sum_rank += cur_rank;
235                        cur_rank += 1;
236                    } else if repeat_num == 1 {
237                        // no repeat, can get the rank directly
238                        out.uset(idx, (cur_rank as f64 / not_none_count as f64).cast());
239                        cur_rank += 1;
240                    } else {
241                        // current element is the last repeated value
242                        sum_rank += cur_rank;
243                        cur_rank += 1;
244                        for j in 0..repeat_num {
245                            // safe because i >= repeat_num
246                            out.uset(
247                                idx_sorted.uget(i - j),
248                                (sum_rank.f64() / (repeat_num * not_none_count).f64()).cast(),
249                            );
250                        }
251                        sum_rank = 0;
252                        repeat_num = 1;
253                    }
254                }
255                if nan_flag {
256                    for i in idx..len {
257                        out.uset(idx_sorted.uget(i), f64::NAN.cast())
258                    }
259                } else {
260                    sum_rank += cur_rank;
261                    for i in len - repeat_num..len {
262                        // safe because repeat_num <= len
263                        out.uset(
264                            idx_sorted.uget(i),
265                            (sum_rank.f64() / (repeat_num * not_none_count).f64()).cast(),
266                        )
267                    }
268                }
269            }
270        }
271        unsafe { out.assume_init() }
272    }
273
274    /// Returns the indices of the kth smallest elements.
275    ///
276    /// # Arguments
277    ///
278    /// * `kth` - The k-th smallest element to find.
279    /// * `sort` - If true, sort the result.
280    /// * `rev` - If true, find the kth largest elements instead.
281    ///
282    /// # Returns
283    ///
284    /// A boxed iterator of indices.
285    fn varg_partition<'a>(
286        &'a self,
287        kth: usize,
288        sort: bool,
289        rev: bool,
290    ) -> Box<dyn TrustedLen<Item = i32> + 'a>
291    where
292        T::Inner: Number,
293        T: 'a,
294    {
295        let n = self.titer().count_valid();
296        // fast path for n <= kth + 1
297        if n <= kth + 1 {
298            if !sort {
299                return Box::new(
300                    self.titer()
301                        .enumerate()
302                        .filter_map(|(i, v)| if v.not_none() { Some(i as i32) } else { None })
303                        .chain(std::iter::repeat(-1))
304                        .take(kth + 1)
305                        .to_trust(kth + 1),
306                );
307            } else {
308                let mut idx_sorted: Vec<_> = Vec1Create::range(None, self.len() as i32, None);
309                if !rev {
310                    idx_sorted
311                        .sort_unstable_by(|a: &i32, b: &i32| {
312                            let (va, vb) =
313                                unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
314                            va.sort_cmp(&vb)
315                        })
316                        .unwrap()
317                } else {
318                    idx_sorted
319                        .sort_unstable_by(|a: &i32, b: &i32| {
320                            let (va, vb) =
321                                unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
322                            va.sort_cmp_rev(&vb)
323                        })
324                        .unwrap()
325                }
326                return Box::new(
327                    idx_sorted
328                        .into_iter()
329                        .take(n)
330                        .chain(std::iter::repeat(-1))
331                        .take(kth + 1)
332                        .to_trust(kth + 1),
333                );
334            }
335        }
336        let mut out_c: Vec<_> = self.titer().collect_trusted_vec1(); // clone the array
337        let slc = out_c.try_as_slice_mut().unwrap();
338        let mut idx_sorted: Vec<_> = Vec1Create::range(None, slc.len() as i32, None);
339        if !rev {
340            let sort_func = |a: &i32, b: &i32| {
341                let (va, vb) = unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
342                va.sort_cmp(&vb)
343            };
344            idx_sorted.select_nth_unstable_by(kth, sort_func);
345            idx_sorted.truncate(kth + 1);
346            if sort {
347                idx_sorted.sort_unstable_by(sort_func).unwrap();
348            }
349            Box::new(idx_sorted.into_iter().to_trust(kth + 1))
350        } else {
351            let sort_func = |a: &i32, b: &i32| {
352                let (va, vb) = unsafe { (self.uget((*a) as usize), self.uget((*b) as usize)) }; // safety: out不超过self的长度
353                va.sort_cmp_rev(&vb)
354            };
355            idx_sorted.select_nth_unstable_by(kth, sort_func);
356            idx_sorted.truncate(kth + 1);
357            if sort {
358                idx_sorted.sort_unstable_by(sort_func).unwrap();
359            }
360            Box::new(idx_sorted.into_iter().to_trust(kth + 1))
361        }
362    }
363    /// sort: whether to sort the result by the size of the element
364    fn vpartition<'a>(
365        &'a self,
366        kth: usize,
367        sort: bool,
368        rev: bool,
369    ) -> Box<dyn TrustedLen<Item = T> + 'a>
370    where
371        T::Inner: PartialOrd,
372        T: 'a,
373    {
374        let n = self.titer().count_valid();
375        if (n == kth + 1) && !sort {
376            return Box::new(self.titer().filter(IsNone::not_none).to_trust(kth + 1));
377        }
378        if n <= kth + 1 {
379            if !sort {
380                return Box::new(
381                    self.titer()
382                        .filter(IsNone::not_none)
383                        .chain(std::iter::repeat(T::none()))
384                        .take(kth + 1)
385                        .to_trust(kth + 1),
386                );
387            } else {
388                let mut vec: Vec<_> = self.titer().collect_trusted_vec1(); // clone the array
389                if !rev {
390                    vec.sort_unstable_by(|a, b| a.sort_cmp(b)).unwrap();
391                } else {
392                    vec.sort_unstable_by(|a, b| a.sort_cmp_rev(b)).unwrap();
393                }
394                return Box::new(vec.into_iter().take(kth + 1));
395            }
396        }
397        let mut out_c: Vec<_> = self.titer().collect_trusted_vec1(); // clone the array
398        let sort_func = if !rev { T::sort_cmp } else { T::sort_cmp_rev };
399        out_c.select_nth_unstable_by(kth, sort_func);
400        out_c.truncate(kth + 1);
401        if sort {
402            out_c.sort_unstable_by(sort_func).unwrap();
403        }
404        Box::new(out_c.into_iter().to_trust(kth + 1))
405    }
406}
407
408impl<T: IsNone, I: Vec1View<T>> MapValidVec<T> for I {}
409
#[cfg(test)]
mod test {
    use tea_core::testing::assert_vec1d_equal_numeric;

    use super::*;

    /// `vdiff` on empty input, with default padding, and with explicit padding.
    #[test]
    fn test_diff() {
        let empty: Vec<f64> = Vec::new();
        let diffs: Vec<_> = empty.vdiff(2, None).collect_trusted_vec1();
        assert_eq!(diffs, vec![]);

        let data = vec![4., 1., 12., 4.];
        let lag_pos: Vec<_> = data.vdiff(1, None).collect_trusted_vec1();
        assert_vec1d_equal_numeric(&lag_pos, &vec![f64::NAN, -3., 11., -8.], None);
        let lag_neg: Vec<_> = data.vdiff(-1, Some(0.)).collect_trusted_vec1();
        assert_eq!(lag_neg, vec![3., -11., 8., 0.]);
    }

    /// `vpct_change` on empty input and with a positive and a negative lag.
    #[test]
    fn test_pct_change() {
        let empty: Vec<f64> = Vec::new();
        let changes: Vec<_> = empty.vpct_change(2).collect_trusted_vec1();
        assert_eq!(changes, vec![]);

        let data = vec![1., 2., 3., 4.5];
        let lag_pos: Vec<_> = data.vpct_change(1).collect_trusted_vec1();
        assert_vec1d_equal_numeric(&lag_pos, &vec![f64::NAN, 1., 0.5, 0.5], None);
        let lag_neg: Vec<_> = data.vpct_change(-1).collect_trusted_vec1();
        assert_vec1d_equal_numeric(
            &lag_neg,
            &vec![-0.5, -1. / 3., -1. / 3., f64::NAN],
            None,
        );
    }

    /// `vrank` with ties and a NaN, ascending into `f64` and descending into
    /// `Option<f64>`.
    #[test]
    fn test_rank() {
        let data = vec![2., 1., f64::NAN, 3., 1.];

        let ascending: Vec<f64> = data.vrank(false, false);
        let want_ascending = vec![3., 1.5, f64::NAN, 4., 1.5];
        assert_vec1d_equal_numeric(&ascending, &want_ascending, None);

        let descending: Vec<Option<f64>> = data.vrank(false, true);
        let want_descending = vec![Some(2.), Some(3.5), None, Some(1.), Some(3.5)];
        assert_vec1d_equal_numeric(&descending, &want_descending, None);
    }

    /// `varg_partition` / `vpartition` on integers and on floats where fewer
    /// valid elements exist than were requested.
    #[test]
    fn test_partition() {
        let ints = vec![1, 3, 5, 1, 5, 6, 7, 32, 1];
        let smallest_idx: Vec<_> = ints.varg_partition(3, true, false).collect();
        assert_eq!(smallest_idx, vec![0, 3, 8, 1]);
        let smallest: Vec<_> = ints.vpartition(3, true, false).collect();
        assert_eq!(smallest, vec![1, 1, 1, 3]);
        let largest_idx: Vec<_> = ints.varg_partition(3, true, true).collect();
        assert_eq!(largest_idx, vec![7, 6, 5, 2]);
        let largest: Vec<_> = ints.vpartition(3, true, true).collect();
        assert_eq!(largest, vec![32, 7, 6, 5]);

        let sparse = vec![1., f64::NAN, 3., f64::NAN, f64::NAN];
        assert_eq!(
            sparse.varg_partition(2, true, true).collect_trusted_to_vec(),
            vec![2, 0, -1]
        );
        assert_vec1d_equal_numeric(
            &sparse.vpartition(2, true, true).collect_trusted_to_vec(),
            &vec![3., 1., f64::NAN],
            None,
        );
    }
}