polars_python/expr/
rolling.rs

1use polars::prelude::*;
2use pyo3::prelude::*;
3use pyo3::types::PyFloat;
4
5use crate::PyExpr;
6use crate::conversion::Wrap;
7use crate::error::PyPolarsErr;
8use crate::map::lazy::{ToSeries, call_lambda_with_series};
9use crate::py_modules::polars;
10
11#[pymethods]
12impl PyExpr {
13    #[pyo3(signature = (window_size, weights, min_periods, center))]
14    fn rolling_sum(
15        &self,
16        window_size: usize,
17        weights: Option<Vec<f64>>,
18        min_periods: Option<usize>,
19        center: bool,
20    ) -> Self {
21        let min_periods = min_periods.unwrap_or(window_size);
22        let options = RollingOptionsFixedWindow {
23            window_size,
24            weights,
25            min_periods,
26            center,
27            ..Default::default()
28        };
29        self.inner.clone().rolling_sum(options).into()
30    }
31
32    #[pyo3(signature = (by, window_size, min_periods, closed))]
33    fn rolling_sum_by(
34        &self,
35        by: PyExpr,
36        window_size: &str,
37        min_periods: usize,
38        closed: Wrap<ClosedWindow>,
39    ) -> PyResult<Self> {
40        let options = RollingOptionsDynamicWindow {
41            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
42            min_periods,
43            closed_window: closed.0,
44            fn_params: None,
45        };
46        Ok(self.inner.clone().rolling_sum_by(by.inner, options).into())
47    }
48
49    #[pyo3(signature = (window_size, weights, min_periods, center))]
50    fn rolling_min(
51        &self,
52        window_size: usize,
53        weights: Option<Vec<f64>>,
54        min_periods: Option<usize>,
55        center: bool,
56    ) -> Self {
57        let min_periods = min_periods.unwrap_or(window_size);
58        let options = RollingOptionsFixedWindow {
59            window_size,
60            weights,
61            min_periods,
62            center,
63            ..Default::default()
64        };
65        self.inner.clone().rolling_min(options).into()
66    }
67
68    #[pyo3(signature = (by, window_size, min_periods, closed))]
69    fn rolling_min_by(
70        &self,
71        by: PyExpr,
72        window_size: &str,
73        min_periods: usize,
74        closed: Wrap<ClosedWindow>,
75    ) -> PyResult<Self> {
76        let options = RollingOptionsDynamicWindow {
77            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
78            min_periods,
79            closed_window: closed.0,
80            fn_params: None,
81        };
82        Ok(self.inner.clone().rolling_min_by(by.inner, options).into())
83    }
84
85    #[pyo3(signature = (window_size, weights, min_periods, center))]
86    fn rolling_max(
87        &self,
88        window_size: usize,
89        weights: Option<Vec<f64>>,
90        min_periods: Option<usize>,
91        center: bool,
92    ) -> Self {
93        let min_periods = min_periods.unwrap_or(window_size);
94        let options = RollingOptionsFixedWindow {
95            window_size,
96            weights,
97            min_periods,
98            center,
99            ..Default::default()
100        };
101        self.inner.clone().rolling_max(options).into()
102    }
103    #[pyo3(signature = (by, window_size, min_periods, closed))]
104    fn rolling_max_by(
105        &self,
106        by: PyExpr,
107        window_size: &str,
108        min_periods: usize,
109        closed: Wrap<ClosedWindow>,
110    ) -> PyResult<Self> {
111        let options = RollingOptionsDynamicWindow {
112            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
113            min_periods,
114            closed_window: closed.0,
115            fn_params: None,
116        };
117        Ok(self.inner.clone().rolling_max_by(by.inner, options).into())
118    }
119
120    #[pyo3(signature = (window_size, weights, min_periods, center))]
121    fn rolling_mean(
122        &self,
123        window_size: usize,
124        weights: Option<Vec<f64>>,
125        min_periods: Option<usize>,
126        center: bool,
127    ) -> Self {
128        let min_periods = min_periods.unwrap_or(window_size);
129        let options = RollingOptionsFixedWindow {
130            window_size,
131            weights,
132            min_periods,
133            center,
134            ..Default::default()
135        };
136
137        self.inner.clone().rolling_mean(options).into()
138    }
139
140    #[pyo3(signature = (by, window_size, min_periods, closed))]
141    fn rolling_mean_by(
142        &self,
143        by: PyExpr,
144        window_size: &str,
145        min_periods: usize,
146        closed: Wrap<ClosedWindow>,
147    ) -> PyResult<Self> {
148        let options = RollingOptionsDynamicWindow {
149            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
150            min_periods,
151            closed_window: closed.0,
152            fn_params: None,
153        };
154
155        Ok(self.inner.clone().rolling_mean_by(by.inner, options).into())
156    }
157
158    #[pyo3(signature = (window_size, weights, min_periods, center, ddof))]
159    fn rolling_std(
160        &self,
161        window_size: usize,
162        weights: Option<Vec<f64>>,
163        min_periods: Option<usize>,
164        center: bool,
165        ddof: u8,
166    ) -> Self {
167        let min_periods = min_periods.unwrap_or(window_size);
168        let options = RollingOptionsFixedWindow {
169            window_size,
170            weights,
171            min_periods,
172            center,
173            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
174        };
175
176        self.inner.clone().rolling_std(options).into()
177    }
178
179    #[pyo3(signature = (by, window_size, min_periods, closed, ddof))]
180    fn rolling_std_by(
181        &self,
182        by: PyExpr,
183        window_size: &str,
184        min_periods: usize,
185        closed: Wrap<ClosedWindow>,
186        ddof: u8,
187    ) -> PyResult<Self> {
188        let options = RollingOptionsDynamicWindow {
189            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
190            min_periods,
191            closed_window: closed.0,
192            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
193        };
194
195        Ok(self.inner.clone().rolling_std_by(by.inner, options).into())
196    }
197
198    #[pyo3(signature = (window_size, weights, min_periods, center, ddof))]
199    fn rolling_var(
200        &self,
201        window_size: usize,
202        weights: Option<Vec<f64>>,
203        min_periods: Option<usize>,
204        center: bool,
205        ddof: u8,
206    ) -> Self {
207        let min_periods = min_periods.unwrap_or(window_size);
208        let options = RollingOptionsFixedWindow {
209            window_size,
210            weights,
211            min_periods,
212            center,
213            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
214        };
215
216        self.inner.clone().rolling_var(options).into()
217    }
218
219    #[pyo3(signature = (by, window_size, min_periods, closed, ddof))]
220    fn rolling_var_by(
221        &self,
222        by: PyExpr,
223        window_size: &str,
224        min_periods: usize,
225        closed: Wrap<ClosedWindow>,
226        ddof: u8,
227    ) -> PyResult<Self> {
228        let options = RollingOptionsDynamicWindow {
229            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
230            min_periods,
231            closed_window: closed.0,
232            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
233        };
234
235        Ok(self.inner.clone().rolling_var_by(by.inner, options).into())
236    }
237
238    #[pyo3(signature = (window_size, weights, min_periods, center))]
239    fn rolling_median(
240        &self,
241        window_size: usize,
242        weights: Option<Vec<f64>>,
243        min_periods: Option<usize>,
244        center: bool,
245    ) -> Self {
246        let min_periods = min_periods.unwrap_or(window_size);
247        let options = RollingOptionsFixedWindow {
248            window_size,
249            min_periods,
250            weights,
251            center,
252            fn_params: None,
253        };
254        self.inner.clone().rolling_median(options).into()
255    }
256
257    #[pyo3(signature = (by, window_size, min_periods, closed))]
258    fn rolling_median_by(
259        &self,
260        by: PyExpr,
261        window_size: &str,
262        min_periods: usize,
263        closed: Wrap<ClosedWindow>,
264    ) -> PyResult<Self> {
265        let options = RollingOptionsDynamicWindow {
266            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
267            min_periods,
268            closed_window: closed.0,
269            fn_params: None,
270        };
271        Ok(self
272            .inner
273            .clone()
274            .rolling_median_by(by.inner, options)
275            .into())
276    }
277
278    #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center))]
279    fn rolling_quantile(
280        &self,
281        quantile: f64,
282        interpolation: Wrap<QuantileMethod>,
283        window_size: usize,
284        weights: Option<Vec<f64>>,
285        min_periods: Option<usize>,
286        center: bool,
287    ) -> Self {
288        let min_periods = min_periods.unwrap_or(window_size);
289        let options = RollingOptionsFixedWindow {
290            window_size,
291            weights,
292            min_periods,
293            center,
294            fn_params: None,
295        };
296
297        self.inner
298            .clone()
299            .rolling_quantile(interpolation.0, quantile, options)
300            .into()
301    }
302
303    #[pyo3(signature = (by, quantile, interpolation, window_size, min_periods, closed))]
304    fn rolling_quantile_by(
305        &self,
306        by: PyExpr,
307        quantile: f64,
308        interpolation: Wrap<QuantileMethod>,
309        window_size: &str,
310        min_periods: usize,
311        closed: Wrap<ClosedWindow>,
312    ) -> PyResult<Self> {
313        let options = RollingOptionsDynamicWindow {
314            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
315            min_periods,
316            closed_window: closed.0,
317            fn_params: None,
318        };
319
320        Ok(self
321            .inner
322            .clone()
323            .rolling_quantile_by(by.inner, interpolation.0, quantile, options)
324            .into())
325    }
326
327    #[pyo3(signature = (window_size, bias, min_periods, center))]
328    fn rolling_skew(
329        &self,
330        window_size: usize,
331        bias: bool,
332        min_periods: Option<usize>,
333        center: bool,
334    ) -> Self {
335        let min_periods = min_periods.unwrap_or(window_size);
336        let options = RollingOptionsFixedWindow {
337            window_size,
338            weights: None,
339            min_periods,
340            center,
341            fn_params: Some(RollingFnParams::Skew { bias }),
342        };
343
344        self.inner.clone().rolling_skew(options).into()
345    }
346
347    #[pyo3(signature = (window_size, fisher, bias, min_periods, center))]
348    fn rolling_kurtosis(
349        &self,
350        window_size: usize,
351        fisher: bool,
352        bias: bool,
353        min_periods: Option<usize>,
354        center: bool,
355    ) -> Self {
356        let min_periods = min_periods.unwrap_or(window_size);
357        let options = RollingOptionsFixedWindow {
358            window_size,
359            weights: None,
360            min_periods,
361            center,
362            fn_params: Some(RollingFnParams::Kurtosis { fisher, bias }),
363        };
364
365        self.inner.clone().rolling_kurtosis(options).into()
366    }
367
368    #[pyo3(signature = (lambda, window_size, weights, min_periods, center))]
369    fn rolling_map(
370        &self,
371        lambda: PyObject,
372        window_size: usize,
373        weights: Option<Vec<f64>>,
374        min_periods: Option<usize>,
375        center: bool,
376    ) -> Self {
377        let min_periods = min_periods.unwrap_or(window_size);
378        let options = RollingOptionsFixedWindow {
379            window_size,
380            weights,
381            min_periods,
382            center,
383            ..Default::default()
384        };
385        let function = move |s: &Series| {
386            Python::with_gil(|py| {
387                let out =
388                    call_lambda_with_series(py, s, None, &lambda).expect("python function failed");
389                match out.getattr(py, "_s") {
390                    Ok(pyseries) => {
391                        let Ok(s) = pyseries
392                            .to_series(py, polars(py), s.name())
393                            .map_err(|e| panic!("{e:?}"));
394                        s
395                    },
396                    Err(_) => {
397                        let obj = out;
398                        let is_float = obj.bind(py).is_instance_of::<PyFloat>();
399
400                        let dtype = s.dtype();
401
402                        use DataType::*;
403                        let Ok(s) = match dtype {
404                            UInt8 => {
405                                if is_float {
406                                    let v = obj.extract::<f64>(py).unwrap();
407                                    Ok(UInt8Chunked::from_slice(PlSmallStr::EMPTY, &[v as u8])
408                                        .into_series())
409                                } else {
410                                    obj.extract::<u8>(py).map(|v| {
411                                        UInt8Chunked::from_slice(PlSmallStr::EMPTY, &[v])
412                                            .into_series()
413                                    })
414                                }
415                            },
416                            UInt16 => {
417                                if is_float {
418                                    let v = obj.extract::<f64>(py).unwrap();
419                                    Ok(UInt16Chunked::from_slice(PlSmallStr::EMPTY, &[v as u16])
420                                        .into_series())
421                                } else {
422                                    obj.extract::<u16>(py).map(|v| {
423                                        UInt16Chunked::from_slice(PlSmallStr::EMPTY, &[v])
424                                            .into_series()
425                                    })
426                                }
427                            },
428                            UInt32 => {
429                                if is_float {
430                                    let v = obj.extract::<f64>(py).unwrap();
431                                    Ok(UInt32Chunked::from_slice(PlSmallStr::EMPTY, &[v as u32])
432                                        .into_series())
433                                } else {
434                                    obj.extract::<u32>(py).map(|v| {
435                                        UInt32Chunked::from_slice(PlSmallStr::EMPTY, &[v])
436                                            .into_series()
437                                    })
438                                }
439                            },
440                            UInt64 => {
441                                if is_float {
442                                    let v = obj.extract::<f64>(py).unwrap();
443                                    Ok(UInt64Chunked::from_slice(PlSmallStr::EMPTY, &[v as u64])
444                                        .into_series())
445                                } else {
446                                    obj.extract::<u64>(py).map(|v| {
447                                        UInt64Chunked::from_slice(PlSmallStr::EMPTY, &[v])
448                                            .into_series()
449                                    })
450                                }
451                            },
452                            Int8 => {
453                                if is_float {
454                                    let v = obj.extract::<f64>(py).unwrap();
455                                    Ok(Int8Chunked::from_slice(PlSmallStr::EMPTY, &[v as i8])
456                                        .into_series())
457                                } else {
458                                    obj.extract::<i8>(py).map(|v| {
459                                        Int8Chunked::from_slice(PlSmallStr::EMPTY, &[v])
460                                            .into_series()
461                                    })
462                                }
463                            },
464                            Int16 => {
465                                if is_float {
466                                    let v = obj.extract::<f64>(py).unwrap();
467                                    Ok(Int16Chunked::from_slice(PlSmallStr::EMPTY, &[v as i16])
468                                        .into_series())
469                                } else {
470                                    obj.extract::<i16>(py).map(|v| {
471                                        Int16Chunked::from_slice(PlSmallStr::EMPTY, &[v])
472                                            .into_series()
473                                    })
474                                }
475                            },
476                            Int32 => {
477                                if is_float {
478                                    let v = obj.extract::<f64>(py).unwrap();
479                                    Ok(Int32Chunked::from_slice(PlSmallStr::EMPTY, &[v as i32])
480                                        .into_series())
481                                } else {
482                                    obj.extract::<i32>(py).map(|v| {
483                                        Int32Chunked::from_slice(PlSmallStr::EMPTY, &[v])
484                                            .into_series()
485                                    })
486                                }
487                            },
488                            Int64 => {
489                                if is_float {
490                                    let v = obj.extract::<f64>(py).unwrap();
491                                    Ok(Int64Chunked::from_slice(PlSmallStr::EMPTY, &[v as i64])
492                                        .into_series())
493                                } else {
494                                    obj.extract::<i64>(py).map(|v| {
495                                        Int64Chunked::from_slice(PlSmallStr::EMPTY, &[v])
496                                            .into_series()
497                                    })
498                                }
499                            },
500                            Int128 => {
501                                if is_float {
502                                    let v = obj.extract::<f64>(py).unwrap();
503                                    Ok(Int128Chunked::from_slice(PlSmallStr::EMPTY, &[v as i128])
504                                        .into_series())
505                                } else {
506                                    obj.extract::<i128>(py).map(|v| {
507                                        Int128Chunked::from_slice(PlSmallStr::EMPTY, &[v])
508                                            .into_series()
509                                    })
510                                }
511                            },
512                            Float32 => obj.extract::<f32>(py).map(|v| {
513                                Float32Chunked::from_slice(PlSmallStr::EMPTY, &[v]).into_series()
514                            }),
515                            Float64 => obj.extract::<f64>(py).map(|v| {
516                                Float64Chunked::from_slice(PlSmallStr::EMPTY, &[v]).into_series()
517                            }),
518                            dt => panic!("{dt:?} not implemented"),
519                        }
520                        .map_err(|e| panic!("{e:?}"));
521                        s
522                    },
523                }
524            })
525        };
526        self.inner
527            .clone()
528            .rolling_map(Arc::new(function), GetOutput::same_type(), options)
529            .into()
530    }
531}