polars_core/chunked_array/ops/
rolling_window.rs1use std::hash::{Hash, Hasher};
2
3use polars_compute::rolling::RollingFnParams;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
10#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
11pub struct RollingOptionsFixedWindow {
12 pub window_size: usize,
14 pub min_periods: usize,
16 pub weights: Option<Vec<f64>>,
19 pub center: bool,
21 #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(default))]
23 pub fn_params: Option<RollingFnParams>,
24}
25
26impl Hash for RollingOptionsFixedWindow {
27 fn hash<H: Hasher>(&self, state: &mut H) {
28 self.window_size.hash(state);
29 self.min_periods.hash(state);
30 self.center.hash(state);
31 self.weights.is_some().hash(state);
32 }
33}
34
35impl Default for RollingOptionsFixedWindow {
36 fn default() -> Self {
37 RollingOptionsFixedWindow {
38 window_size: 3,
39 min_periods: 1,
40 weights: None,
41 center: false,
42 fn_params: None,
43 }
44 }
45}
46
#[cfg(feature = "rolling_window")]
mod inner_mod {
    use std::ops::SubAssign;

    use arrow::bitmap::MutableBitmap;
    use arrow::bitmap::utils::set_bit_unchecked;
    use arrow::legacy::trusted_len::TrustedLenPush;
    use num_traits::pow::Pow;
    use num_traits::{Float, Zero};
    use polars_utils::float::IsFloat;

    use crate::chunked_array::cast::CastOptions;
    use crate::prelude::*;

    /// Validate the fixed-window parameters used by `rolling_map`.
    ///
    /// # Errors
    /// Returns a `ComputeError` when `window_size` is zero or smaller than
    /// `min_periods`.
    fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
        // A zero-sized window would underflow `window_size - 1` in `window_edges`.
        polars_ensure!(window_size > 0, ComputeError: "`window_size` should be > 0");
        polars_ensure!(
            min_periods <= window_size,
            ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
            window_size, min_periods
        );
        Ok(())
    }

    /// Compute `(start, length)` of the window that belongs to row `idx`.
    ///
    /// Trailing windows cover rows `idx + 1 - window_size ..= idx`. Centered
    /// windows put `ceil(window_size / 2)` rows at and after `idx` and the rest
    /// before it. Windows are clipped to `0..len`, so the returned length is
    /// shorter than `window_size` near the edges.
    ///
    /// Requires `window_size > 0` and `idx < len`.
    fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
        debug_assert!(window_size > 0 && idx < len);
        let (start, end) = if center {
            let right_window = window_size.div_ceil(2);
            (
                idx.saturating_sub(window_size - right_window),
                len.min(idx + right_window),
            )
        } else {
            (idx.saturating_sub(window_size - 1), idx + 1)
        };

        (start, end - start)
    }

    impl<T: PolarsNumericType> ChunkRollApply for ChunkedArray<T> {
        /// Apply `f` to every window and collect the first value of each result.
        ///
        /// A row is null when its (edge-clipped) window is shorter than
        /// `min_periods` or holds fewer than `min_periods` non-null values. With
        /// weights, the window is multiplied element-wise by the weights first;
        /// clipped windows use the leading weights. Non-float inputs are cast to
        /// `Float64` when weights are given.
        ///
        /// # Errors
        /// Fails on invalid window parameters, when the number of weights differs
        /// from `window_size`, or when `f` returns a series of another dtype.
        fn rolling_map(
            &self,
            f: &dyn Fn(&Series) -> Series,
            mut options: RollingOptionsFixedWindow,
        ) -> PolarsResult<Series> {
            check_input(options.window_size, options.min_periods)?;
            if let Some(weights) = &options.weights {
                polars_ensure!(
                    weights.len() == options.window_size,
                    ComputeError: "length of `weights`: {} should equal `window_size`: {}",
                    weights.len(), options.window_size
                );
            }

            // Weighted windows are multiplied element-wise, so integer inputs are
            // promoted to Float64 and handled by the float instantiation.
            if options.weights.is_some()
                && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
            {
                let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
                return s.rolling_map(f, options);
            }

            let len = self.len();
            // Windows can never be longer than the data; `window_size > 0` was
            // checked above, so this stays >= 1 whenever the loop below runs.
            options.window_size = std::cmp::min(len, options.window_size);

            let ca = self.rechunk();
            let arr = ca.downcast_as_array();

            // Scratch series reused for every window. `ptr` aliases its single boxed
            // chunk so each window can be swapped in without a new allocation.
            let mut container =
                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
            let ptr =
                container.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray<T::Native>;
            let mut series_container = container.into_series();

            // Weights are cast to the (float) input dtype so the multiplication
            // keeps the window's dtype.
            let weights_series = options
                .weights
                .as_deref()
                .map(|weights| {
                    Float64Chunked::new(PlSmallStr::from_static("weights"), weights)
                        .into_series()
                        .cast(self.dtype())
                })
                .transpose()?;

            let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), len);

            for idx in 0..len {
                let (start, size) = window_edges(idx, len, options.window_size, options.center);
                if size < options.min_periods {
                    builder.append_null();
                    continue;
                }

                // SAFETY: `window_edges` clips the window to `0..len` and
                // `arr.len() == len` after the rechunk.
                let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
                if size - arr_window.null_count() < options.min_periods {
                    builder.append_null();
                    continue;
                }

                // SAFETY: `ptr` points into the boxed chunk owned by
                // `series_container`, which lives for the whole loop, and both sides
                // are `PrimitiveArray<T::Native>`.
                unsafe {
                    *ptr = arr_window;
                }
                // Flags, length and null count are cached, so refresh them for the
                // new window.
                series_container.clear_flags();
                series_container._get_inner_mut().compute_len();

                let s = match &weights_series {
                    // Slicing to `size` handles both edge-clipped windows and windows
                    // clamped to a series shorter than the requested window size.
                    Some(weights) => f(&series_container.multiply(&weights.slice(0, size))?),
                    None => f(&series_container),
                };

                let out = self.unpack_series_matching_type(&s)?;
                builder.append_option(out.get(0));
            }

            Ok(builder.finish().into_series())
        }
    }

    impl<T> ChunkedArray<T>
    where
        T: PolarsFloatType,
        T::Native: Float + IsFloat + SubAssign + Pow<T::Native, Output = T::Native>,
    {
        /// Apply `f` to every full trailing window of `window_size` rows.
        ///
        /// The first `window_size - 1` rows are null. Every other row holds the
        /// value `f` returned for the window ending at that row, or null when `f`
        /// returns `None`. If `window_size` exceeds the length, the whole output is
        /// null.
        ///
        /// `f` receives a reused scratch array whose single chunk is overwritten
        /// for every window. It must not keep it across calls or replace its
        /// chunks.
        ///
        /// # Errors
        /// Returns a `ComputeError` if `window_size` is zero.
        pub fn rolling_map_float<F>(&self, window_size: usize, mut f: F) -> PolarsResult<Self>
        where
            F: FnMut(&mut ChunkedArray<T>) -> Option<T::Native>,
        {
            // `window_size - 1` below would underflow for an empty window.
            polars_ensure!(window_size > 0, ComputeError: "`window_size` should be > 0");
            if window_size > self.len() {
                return Ok(Self::full_null(self.name().clone(), self.len()));
            }
            let ca = self.rechunk();
            let arr = ca.downcast_as_array();

            // Scratch array reused for every window. `ptr` aliases its single boxed
            // chunk so each window can be swapped in without reallocating.
            let mut heap_container =
                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
            let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array
                as *mut PrimitiveArray<T::Native>;

            // Leading rows without a full window are null; the rest start valid and
            // are cleared when `f` returns `None`.
            let mut validity = MutableBitmap::with_capacity(ca.len());
            validity.extend_constant(window_size - 1, false);
            validity.extend_constant(ca.len() - (window_size - 1), true);
            let validity_slice = validity.as_mut_slice();

            let mut values = Vec::with_capacity(ca.len());
            values.extend(std::iter::repeat_n(T::Native::default(), window_size - 1));

            for offset in 0..self.len() + 1 - window_size {
                debug_assert!(offset + window_size <= arr.len());
                // SAFETY: the loop bound guarantees `offset + window_size <= len`.
                let arr_window = unsafe { arr.slice_typed_unchecked(offset, window_size) };

                // SAFETY: `ptr` points into the boxed chunk owned by
                // `heap_container`, which lives for the whole loop, and both sides
                // are `PrimitiveArray<T::Native>`.
                unsafe {
                    *ptr = arr_window;
                }
                // Length AND null count are cached on the ChunkedArray. Recompute
                // both; otherwise `f` would see windows containing nulls as
                // null-free.
                heap_container.compute_len();

                match f(&mut heap_container) {
                    // SAFETY: `values` has capacity `ca.len()` and receives exactly one
                    // value per row.
                    Some(v) => unsafe { values.push_unchecked(v) },
                    // SAFETY: as above for `values`; the validity bitmap has
                    // `ca.len()` bits and `offset + window_size - 1 < ca.len()`.
                    None => unsafe {
                        values.push_unchecked(T::Native::default());
                        set_bit_unchecked(validity_slice, offset + window_size - 1, false);
                    },
                }
            }

            let arr = PrimitiveArray::new(
                T::get_static_dtype().to_arrow(CompatLevel::newest()),
                values.into(),
                Some(validity.into()),
            );
            Ok(Self::with_chunk(self.name().clone(), arr))
        }
    }
}