// polars_compute/rolling/nulls/rank.rs

1use core::panic;
2use std::marker::PhantomData;
3
4use polars_utils::IdxSize;
5use polars_utils::order_statistic_tree::OrderStatisticTree;
6
7use super::super::rank::*;
8use super::*;
9
10pub struct RankWindow<'a, T, Out, P> {
11    slice: &'a [T],
12    validity: &'a Bitmap,
13    start: usize,
14    end: usize,
15    ost: OrderStatisticTree<&'a T>,
16    policy: P,
17    _out: PhantomData<Out>,
18}
19
20impl<T, Out, P> RollingAggWindowNulls<T, Out> for RankWindow<'_, T, Out, P>
21where
22    T: NativeType,
23    Out: NativeType,
24    P: RankPolicy<T, Out>,
25{
26    type This<'a> = RankWindow<'a, T, Out, P>;
27
28    fn new<'a>(
29        slice: &'a [T],
30        validity: &'a Bitmap,
31        start: usize,
32        end: usize,
33        params: Option<RollingFnParams>,
34        window_size: Option<usize>,
35    ) -> Self::This<'a> {
36        assert!(start <= slice.len() && end <= slice.len() && start <= end);
37
38        let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);
39        let ost: OrderStatisticTree<&T> = match window_size {
40            Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),
41            None => OrderStatisticTree::new(cmp),
42        };
43        let mut this = RankWindow {
44            slice,
45            validity,
46            start: 0,
47            end: 0,
48            ost,
49            policy: P::new(&params.unwrap()),
50            _out: PhantomData,
51        };
52        // SAFETY: We bounds checked `start` and `end`.
53        unsafe {
54            this.update(start, end);
55        }
56        this
57    }
58
59    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
60        debug_assert!(self.start <= self.end);
61        debug_assert!(self.end <= self.slice.len());
62        debug_assert!(new_start <= new_end);
63        debug_assert!(new_end <= self.slice.len());
64        debug_assert!(self.start <= new_start);
65        debug_assert!(self.end <= new_end);
66
67        for i in self.end..new_end {
68            if !self.validity.get(i).unwrap() {
69                continue;
70            }
71            self.ost.insert(unsafe { self.slice.get_unchecked(i) });
72        }
73        for i in self.start..new_start {
74            if !self.validity.get(i).unwrap() {
75                continue;
76            }
77            self.ost
78                .remove(&unsafe { self.slice.get_unchecked(i) })
79                .expect("previously added value is missing");
80        }
81        self.start = new_start;
82        self.end = new_end;
83    }
84
85    fn get_agg(&self, idx: usize) -> Option<Out> {
86        if !(self.start..self.end).contains(&idx) {
87            panic!("index out of bounds");
88        }
89        self.policy.rank(&self.ost, &self.slice[idx])
90    }
91
92    fn is_valid(&self, _min_periods: usize) -> bool {
93        self.validity.get(self.end - 1).unwrap()
94    }
95
96    fn slice_len(&self) -> usize {
97        self.slice.len()
98    }
99}
100
101pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;
102pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;
103pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;
104pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;
105pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;
106
107pub fn rolling_rank<T>(
108    arr: &PrimitiveArray<T>,
109    window_size: usize,
110    min_periods: usize,
111    center: bool,
112    weights: Option<&[f64]>,
113    params: Option<RollingFnParams>,
114) -> ArrayRef
115where
116    T: NativeType,
117{
118    assert!(weights.is_none(), "weights are not supported for rank");
119
120    let offset_fn = match center {
121        true => det_offsets_center,
122        false => det_offsets,
123    };
124    let method = if let Some(RollingFnParams::Rank { method, .. }) = params {
125        method
126    } else {
127        unreachable!("expected RollingFnParams::Rank");
128    };
129
130    match method {
131        RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(
132            arr.values().as_slice(),
133            arr.validity().as_ref().unwrap(),
134            window_size,
135            min_periods,
136            offset_fn,
137            params,
138        ),
139        RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(
140            arr.values().as_slice(),
141            arr.validity().as_ref().unwrap(),
142            window_size,
143            min_periods,
144            offset_fn,
145            params,
146        ),
147        RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(
148            arr.values().as_slice(),
149            arr.validity().as_ref().unwrap(),
150            window_size,
151            min_periods,
152            offset_fn,
153            params,
154        ),
155        RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(
156            arr.values().as_slice(),
157            arr.validity().as_ref().unwrap(),
158            window_size,
159            min_periods,
160            offset_fn,
161            params,
162        ),
163        RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(
164            arr.values().as_slice(),
165            arr.validity().as_ref().unwrap(),
166            window_size,
167            min_periods,
168            offset_fn,
169            params,
170        ),
171    }
172}