// polars_compute/rolling/arg_min_max.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3
4use arrow::bitmap::Bitmap;
5use arrow::types::NativeType;
6use polars_utils::IdxSize;
7use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan};
8
9use super::RollingFnParams;
10use super::no_nulls::RollingAggWindowNoNulls;
11use super::nulls::RollingAggWindowNulls;
12
13// Algorithm: https://cs.stackexchange.com/questions/120915/interview-question-with-arrays-and-consecutive-subintervals/120936#120936
14// Modified to return the argmin/argmax instead of the value:
15pub struct ArgMinMaxWindow<'a, T, P> {
16    pub(crate) values: &'a [T],
17    validity: Option<&'a Bitmap>,
18    // values[monotonic_idxs[i]] is better than values[monotonic_idxs[i+1]] for
19    // all i, as per the policy.
20    monotonic_idxs: VecDeque<usize>,
21    nonnulls_in_window: usize,
22    pub(super) start: usize,
23    pub(super) end: usize,
24    policy: PhantomData<P>,
25}
26
27impl<T: NativeType, P: MinMaxPolicy> ArgMinMaxWindow<'_, T, P> {
28    /// # Safety
29    /// The index must be in-bounds.
30    unsafe fn insert_nonnull_value(&mut self, idx: usize) {
31        unsafe {
32            let value = self.values.get_unchecked(idx);
33
34            // Remove values which are older and worse.
35            while let Some(&tail_idx) = self.monotonic_idxs.back() {
36                let tail_value = self.values.get_unchecked(tail_idx);
37                if !P::is_better(value, tail_value) {
38                    break;
39                }
40                self.monotonic_idxs.pop_back();
41            }
42
43            self.monotonic_idxs.push_back(idx);
44            self.nonnulls_in_window += 1;
45        }
46    }
47
48    fn remove_old_values(&mut self, window_start: usize) {
49        // Remove values which have fallen outside the window start.
50        while let Some(&head_idx) = self.monotonic_idxs.front() {
51            if head_idx >= window_start {
52                break;
53            }
54            self.monotonic_idxs.pop_front();
55        }
56    }
57}
58
59impl<T: NativeType, P: MinMaxPolicy> RollingAggWindowNulls<T, IdxSize>
60    for ArgMinMaxWindow<'_, T, P>
61{
62    type This<'a> = ArgMinMaxWindow<'a, T, P>;
63
64    fn new<'a>(
65        slice: &'a [T],
66        validity: &'a Bitmap,
67        start: usize,
68        end: usize,
69        params: Option<RollingFnParams>,
70        _window_size: Option<usize>,
71    ) -> Self::This<'a> {
72        assert!(params.is_none());
73        assert!(start <= slice.len() && end <= slice.len() && start <= end);
74
75        let mut this = ArgMinMaxWindow {
76            values: slice,
77            validity: Some(validity),
78            monotonic_idxs: VecDeque::new(),
79            nonnulls_in_window: 0,
80            start: 0,
81            end: 0,
82            policy: PhantomData,
83        };
84        // SAFETY: We bounds checked `start` and `end`.
85        unsafe { RollingAggWindowNulls::update(&mut this, start, end) };
86        this
87    }
88
89    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
90        unsafe {
91            let v = self.validity.unwrap_unchecked();
92            self.remove_old_values(new_start);
93            for i in self.start..new_start.min(self.end) {
94                self.nonnulls_in_window -= v.get_bit_unchecked(i) as usize;
95            }
96            for i in new_start.max(self.end)..new_end {
97                if v.get_bit_unchecked(i) {
98                    self.insert_nonnull_value(i);
99                }
100            }
101        };
102        self.start = new_start;
103        self.end = new_end;
104    }
105
106    fn get_agg(&self, _idx: usize) -> Option<IdxSize> {
107        self.monotonic_idxs
108            .front()
109            .map(|&best_abs| (best_abs - self.start) as IdxSize)
110    }
111
112    fn is_valid(&self, min_periods: usize) -> bool {
113        self.nonnulls_in_window >= min_periods
114    }
115
116    fn slice_len(&self) -> usize {
117        self.values.len()
118    }
119}
120
121impl<T: NativeType, P: MinMaxPolicy> RollingAggWindowNoNulls<T, IdxSize>
122    for ArgMinMaxWindow<'_, T, P>
123{
124    type This<'a> = ArgMinMaxWindow<'a, T, P>;
125
126    fn new<'a>(
127        slice: &'a [T],
128        start: usize,
129        end: usize,
130        params: Option<RollingFnParams>,
131        _window_size: Option<usize>,
132    ) -> Self::This<'a> {
133        assert!(params.is_none());
134        assert!(start <= slice.len() && end <= slice.len() && start <= end);
135
136        let mut this = ArgMinMaxWindow {
137            values: slice,
138            validity: None,
139            monotonic_idxs: VecDeque::new(),
140            nonnulls_in_window: 0,
141            start: 0,
142            end: 0,
143            policy: PhantomData,
144        };
145
146        // SAFETY: We bounds checked `start` and `end`.
147        unsafe { RollingAggWindowNoNulls::update(&mut this, start, end) };
148        this
149    }
150
151    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
152        unsafe {
153            self.remove_old_values(new_start);
154
155            for i in new_start.max(self.end)..new_end {
156                self.insert_nonnull_value(i);
157            }
158        };
159        self.start = new_start;
160        self.end = new_end;
161    }
162
163    fn get_agg(&self, _idx: usize) -> Option<IdxSize> {
164        self.monotonic_idxs
165            .front()
166            .map(|&best_abs| (best_abs - self.start) as IdxSize)
167    }
168
169    fn slice_len(&self) -> usize {
170        self.values.len()
171    }
172}
173
174pub type ArgMinWindow<'a, T> = ArgMinMaxWindow<'a, T, MinPropagateNan>;
175pub type ArgMaxWindow<'a, T> = ArgMinMaxWindow<'a, T, MaxPropagateNan>;