polars_core/chunked_array/ops/
rolling_window.rs

1use std::hash::{Hash, Hasher};
2
3use polars_compute::rolling::RollingFnParams;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
10#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
11pub struct RollingOptionsFixedWindow {
12    /// The length of the window.
13    pub window_size: usize,
14    /// Amount of elements in the window that should be filled before computing a result.
15    pub min_periods: usize,
16    /// An optional slice with the same length as the window that will be multiplied
17    ///              elementwise with the values in the window.
18    pub weights: Option<Vec<f64>>,
19    /// Set the labels at the center of the window.
20    pub center: bool,
21    /// Optional parameters for the rolling
22    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(default))]
23    pub fn_params: Option<RollingFnParams>,
24}
25
26impl Hash for RollingOptionsFixedWindow {
27    fn hash<H: Hasher>(&self, state: &mut H) {
28        self.window_size.hash(state);
29        self.min_periods.hash(state);
30        self.center.hash(state);
31        self.weights.is_some().hash(state);
32    }
33}
34
35impl Default for RollingOptionsFixedWindow {
36    fn default() -> Self {
37        RollingOptionsFixedWindow {
38            window_size: 3,
39            min_periods: 1,
40            weights: None,
41            center: false,
42            fn_params: None,
43        }
44    }
45}
46
47#[cfg(feature = "rolling_window")]
48mod inner_mod {
49    use std::ops::SubAssign;
50
51    use arrow::bitmap::MutableBitmap;
52    use arrow::bitmap::utils::set_bit_unchecked;
53    use arrow::legacy::trusted_len::TrustedLenPush;
54    use num_traits::pow::Pow;
55    use num_traits::{Float, Zero};
56    use polars_utils::float::IsFloat;
57
58    use crate::chunked_array::cast::CastOptions;
59    use crate::prelude::*;
60
61    /// utility
62    fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
63        polars_ensure!(
64            min_periods <= window_size,
65            ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
66            window_size, min_periods
67        );
68        Ok(())
69    }
70
71    /// utility
72    fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
73        let (start, end) = if center {
74            let right_window = window_size.div_ceil(2);
75            (
76                idx.saturating_sub(window_size - right_window),
77                len.min(idx + right_window),
78            )
79        } else {
80            (idx.saturating_sub(window_size - 1), idx + 1)
81        };
82
83        (start, end - start)
84    }
85
86    impl<T: PolarsNumericType> ChunkRollApply for ChunkedArray<T> {
87        /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
88        fn rolling_map(
89            &self,
90            f: &dyn Fn(&Series) -> Series,
91            mut options: RollingOptionsFixedWindow,
92        ) -> PolarsResult<Series> {
93            check_input(options.window_size, options.min_periods)?;
94
95            let ca = self.rechunk();
96            if options.weights.is_some()
97                && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
98            {
99                let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
100                return s.rolling_map(f, options);
101            }
102
103            options.window_size = std::cmp::min(self.len(), options.window_size);
104
105            let len = self.len();
106            let arr = ca.downcast_as_array();
107            let mut ca = ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
108            let ptr = ca.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray<T::Native>;
109            let mut series_container = ca.into_series();
110
111            let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
112
113            if let Some(weights) = options.weights {
114                let weights_series =
115                    Float64Chunked::new(PlSmallStr::from_static("weights"), &weights).into_series();
116
117                let weights_series = weights_series.cast(self.dtype()).unwrap();
118
119                for idx in 0..len {
120                    let (start, size) = window_edges(idx, len, options.window_size, options.center);
121
122                    if size < options.min_periods {
123                        builder.append_null();
124                    } else {
125                        // SAFETY:
126                        // we are in bounds
127                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
128
129                        // ensure we still meet window size criteria after removing null values
130                        if size - arr_window.null_count() < options.min_periods {
131                            builder.append_null();
132                            continue;
133                        }
134
135                        // SAFETY.
136                        // ptr is not dropped as we are in scope
137                        // We are also the only owner of the contents of the Arc
138                        // we do this to reduce heap allocs.
139                        unsafe {
140                            *ptr = arr_window;
141                        }
142                        // reset flags as we reuse this container
143                        series_container.clear_flags();
144                        // ensure the length is correct
145                        series_container._get_inner_mut().compute_len();
146                        let s = if size == options.window_size {
147                            f(&series_container.multiply(&weights_series).unwrap())
148                        } else {
149                            let weights_cutoff: Series = match self.dtype() {
150                                DataType::Float64 => weights_series
151                                    .f64()
152                                    .unwrap()
153                                    .into_iter()
154                                    .take(series_container.len())
155                                    .collect(),
156                                _ => weights_series // Float32 case
157                                    .f32()
158                                    .unwrap()
159                                    .into_iter()
160                                    .take(series_container.len())
161                                    .collect(),
162                            };
163                            f(&series_container.multiply(&weights_cutoff).unwrap())
164                        };
165
166                        let out = self.unpack_series_matching_type(&s)?;
167                        builder.append_option(out.get(0));
168                    }
169                }
170
171                Ok(builder.finish().into_series())
172            } else {
173                for idx in 0..len {
174                    let (start, size) = window_edges(idx, len, options.window_size, options.center);
175
176                    if size < options.min_periods {
177                        builder.append_null();
178                    } else {
179                        // SAFETY:
180                        // we are in bounds
181                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
182
183                        // ensure we still meet window size criteria after removing null values
184                        if size - arr_window.null_count() < options.min_periods {
185                            builder.append_null();
186                            continue;
187                        }
188
189                        // SAFETY.
190                        // ptr is not dropped as we are in scope
191                        // We are also the only owner of the contents of the Arc
192                        // we do this to reduce heap allocs.
193                        unsafe {
194                            *ptr = arr_window;
195                        }
196                        // reset flags as we reuse this container
197                        series_container.clear_flags();
198                        // ensure the length is correct
199                        series_container._get_inner_mut().compute_len();
200                        let s = f(&series_container);
201                        let out = self.unpack_series_matching_type(&s)?;
202                        builder.append_option(out.get(0));
203                    }
204                }
205
206                Ok(builder.finish().into_series())
207            }
208        }
209    }
210
211    impl<T> ChunkedArray<T>
212    where
213        T: PolarsFloatType,
214        T::Native: Float + IsFloat + SubAssign + Pow<T::Native, Output = T::Native>,
215    {
216        /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
217        pub fn rolling_map_float<F>(&self, window_size: usize, mut f: F) -> PolarsResult<Self>
218        where
219            F: FnMut(&mut ChunkedArray<T>) -> Option<T::Native>,
220        {
221            if window_size > self.len() {
222                return Ok(Self::full_null(self.name().clone(), self.len()));
223            }
224            let ca = self.rechunk();
225            let arr = ca.downcast_as_array();
226
227            // We create a temporary dummy ChunkedArray. This will be a
228            // container where we swap the window contents every iteration doing
229            // so will save a lot of heap allocations.
230            let mut heap_container =
231                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
232            let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array
233                as *mut PrimitiveArray<T::Native>;
234
235            let mut validity = MutableBitmap::with_capacity(ca.len());
236            validity.extend_constant(window_size - 1, false);
237            validity.extend_constant(ca.len() - (window_size - 1), true);
238            let validity_slice = validity.as_mut_slice();
239
240            let mut values = Vec::with_capacity(ca.len());
241            values.extend(std::iter::repeat_n(T::Native::default(), window_size - 1));
242
243            for offset in 0..self.len() + 1 - window_size {
244                debug_assert!(offset + window_size <= arr.len());
245                let arr_window = unsafe { arr.slice_typed_unchecked(offset, window_size) };
246                // The lengths are cached, so we must update them.
247                heap_container.length = arr_window.len();
248
249                // SAFETY: ptr is not dropped as we are in scope. We are also the only
250                // owner of the contents of the Arc (we do this to reduce heap allocs).
251                unsafe {
252                    *ptr = arr_window;
253                }
254
255                let out = f(&mut heap_container);
256                match out {
257                    Some(v) => {
258                        // SAFETY: we have pre-allocated.
259                        unsafe { values.push_unchecked(v) }
260                    },
261                    None => {
262                        // SAFETY: we allocated enough for both the `values` vec
263                        // and the `validity_ptr`.
264                        unsafe {
265                            values.push_unchecked(T::Native::default());
266                            set_bit_unchecked(validity_slice, offset + window_size - 1, false);
267                        }
268                    },
269                }
270            }
271            let arr = PrimitiveArray::new(
272                T::get_static_dtype().to_arrow(CompatLevel::newest()),
273                values.into(),
274                Some(validity.into()),
275            );
276            Ok(Self::with_chunk(self.name().clone(), arr))
277        }
278    }
279}