polars_compute/rolling/nulls/
rank.rs1use core::panic;
2use std::marker::PhantomData;
3
4use polars_utils::IdxSize;
5use polars_utils::order_statistic_tree::OrderStatisticTree;
6
7use super::super::rank::*;
8use super::*;
9
10pub struct RankWindow<'a, T, Out, P> {
11 slice: &'a [T],
12 validity: &'a Bitmap,
13 start: usize,
14 end: usize,
15 ost: OrderStatisticTree<&'a T>,
16 policy: P,
17 _out: PhantomData<Out>,
18}
19
20impl<T, Out, P> RollingAggWindowNulls<T, Out> for RankWindow<'_, T, Out, P>
21where
22 T: NativeType,
23 Out: NativeType,
24 P: RankPolicy<T, Out>,
25{
26 type This<'a> = RankWindow<'a, T, Out, P>;
27
28 fn new<'a>(
29 slice: &'a [T],
30 validity: &'a Bitmap,
31 start: usize,
32 end: usize,
33 params: Option<RollingFnParams>,
34 window_size: Option<usize>,
35 ) -> Self::This<'a> {
36 assert!(start <= slice.len() && end <= slice.len() && start <= end);
37
38 let cmp = |a: &&T, b: &&T| T::tot_cmp(*a, *b);
39 let ost: OrderStatisticTree<&T> = match window_size {
40 Some(ws) => OrderStatisticTree::with_capacity(ws, cmp),
41 None => OrderStatisticTree::new(cmp),
42 };
43 let mut this = RankWindow {
44 slice,
45 validity,
46 start: 0,
47 end: 0,
48 ost,
49 policy: P::new(¶ms.unwrap()),
50 _out: PhantomData,
51 };
52 unsafe {
54 this.update(start, end);
55 }
56 this
57 }
58
59 unsafe fn update(&mut self, new_start: usize, new_end: usize) {
60 debug_assert!(self.start <= self.end);
61 debug_assert!(self.end <= self.slice.len());
62 debug_assert!(new_start <= new_end);
63 debug_assert!(new_end <= self.slice.len());
64 debug_assert!(self.start <= new_start);
65 debug_assert!(self.end <= new_end);
66
67 for i in self.end..new_end {
68 if !self.validity.get(i).unwrap() {
69 continue;
70 }
71 self.ost.insert(unsafe { self.slice.get_unchecked(i) });
72 }
73 for i in self.start..new_start {
74 if !self.validity.get(i).unwrap() {
75 continue;
76 }
77 self.ost
78 .remove(&unsafe { self.slice.get_unchecked(i) })
79 .expect("previously added value is missing");
80 }
81 self.start = new_start;
82 self.end = new_end;
83 }
84
85 fn get_agg(&self, idx: usize) -> Option<Out> {
86 if !(self.start..self.end).contains(&idx) {
87 panic!("index out of bounds");
88 }
89 self.policy.rank(&self.ost, &self.slice[idx])
90 }
91
92 fn is_valid(&self, _min_periods: usize) -> bool {
93 self.validity.get(self.end - 1).unwrap()
94 }
95
96 fn slice_len(&self) -> usize {
97 self.slice.len()
98 }
99}
100
101pub type RankWindowAvg<'a, T> = RankWindow<'a, T, f64, RankPolicyAverage>;
102pub type RankWindowMin<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMin>;
103pub type RankWindowMax<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyMax>;
104pub type RankWindowDense<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyDense>;
105pub type RankWindowRandom<'a, T> = RankWindow<'a, T, IdxSize, RankPolicyRandom>;
106
107pub fn rolling_rank<T>(
108 arr: &PrimitiveArray<T>,
109 window_size: usize,
110 min_periods: usize,
111 center: bool,
112 weights: Option<&[f64]>,
113 params: Option<RollingFnParams>,
114) -> ArrayRef
115where
116 T: NativeType,
117{
118 assert!(weights.is_none(), "weights are not supported for rank");
119
120 let offset_fn = match center {
121 true => det_offsets_center,
122 false => det_offsets,
123 };
124 let method = if let Some(RollingFnParams::Rank { method, .. }) = params {
125 method
126 } else {
127 unreachable!("expected RollingFnParams::Rank");
128 };
129
130 match method {
131 RollingRankMethod::Average => rolling_apply_agg_window::<RankWindowAvg<T>, _, _, _>(
132 arr.values().as_slice(),
133 arr.validity().as_ref().unwrap(),
134 window_size,
135 min_periods,
136 offset_fn,
137 params,
138 ),
139 RollingRankMethod::Min => rolling_apply_agg_window::<RankWindowMin<T>, _, _, _>(
140 arr.values().as_slice(),
141 arr.validity().as_ref().unwrap(),
142 window_size,
143 min_periods,
144 offset_fn,
145 params,
146 ),
147 RollingRankMethod::Max => rolling_apply_agg_window::<RankWindowMax<T>, _, _, _>(
148 arr.values().as_slice(),
149 arr.validity().as_ref().unwrap(),
150 window_size,
151 min_periods,
152 offset_fn,
153 params,
154 ),
155 RollingRankMethod::Dense => rolling_apply_agg_window::<RankWindowDense<T>, _, _, _>(
156 arr.values().as_slice(),
157 arr.validity().as_ref().unwrap(),
158 window_size,
159 min_periods,
160 offset_fn,
161 params,
162 ),
163 RollingRankMethod::Random => rolling_apply_agg_window::<RankWindowRandom<T>, _, _, _>(
164 arr.values().as_slice(),
165 arr.validity().as_ref().unwrap(),
166 window_size,
167 min_periods,
168 offset_fn,
169 params,
170 ),
171 }
172}