// polars_python/expr/rolling.rs

1use polars::prelude::*;
2use pyo3::prelude::*;
3use pyo3::types::PyFloat;
4
5use crate::conversion::Wrap;
6use crate::error::PyPolarsErr;
7use crate::map::lazy::call_lambda_with_series;
8use crate::{PyExpr, PySeries};
9
10#[pymethods]
11impl PyExpr {
12    #[pyo3(signature = (window_size, weights, min_periods, center))]
13    fn rolling_sum(
14        &self,
15        window_size: usize,
16        weights: Option<Vec<f64>>,
17        min_periods: Option<usize>,
18        center: bool,
19    ) -> Self {
20        let min_periods = min_periods.unwrap_or(window_size);
21        let options = RollingOptionsFixedWindow {
22            window_size,
23            weights,
24            min_periods,
25            center,
26            ..Default::default()
27        };
28        self.inner.clone().rolling_sum(options).into()
29    }
30
31    #[pyo3(signature = (by, window_size, min_periods, closed))]
32    fn rolling_sum_by(
33        &self,
34        by: PyExpr,
35        window_size: &str,
36        min_periods: usize,
37        closed: Wrap<ClosedWindow>,
38    ) -> PyResult<Self> {
39        let options = RollingOptionsDynamicWindow {
40            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
41            min_periods,
42            closed_window: closed.0,
43            fn_params: None,
44        };
45        Ok(self.inner.clone().rolling_sum_by(by.inner, options).into())
46    }
47
48    #[pyo3(signature = (window_size, weights, min_periods, center))]
49    fn rolling_min(
50        &self,
51        window_size: usize,
52        weights: Option<Vec<f64>>,
53        min_periods: Option<usize>,
54        center: bool,
55    ) -> Self {
56        let min_periods = min_periods.unwrap_or(window_size);
57        let options = RollingOptionsFixedWindow {
58            window_size,
59            weights,
60            min_periods,
61            center,
62            ..Default::default()
63        };
64        self.inner.clone().rolling_min(options).into()
65    }
66
67    #[pyo3(signature = (by, window_size, min_periods, closed))]
68    fn rolling_min_by(
69        &self,
70        by: PyExpr,
71        window_size: &str,
72        min_periods: usize,
73        closed: Wrap<ClosedWindow>,
74    ) -> PyResult<Self> {
75        let options = RollingOptionsDynamicWindow {
76            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
77            min_periods,
78            closed_window: closed.0,
79            fn_params: None,
80        };
81        Ok(self.inner.clone().rolling_min_by(by.inner, options).into())
82    }
83
84    #[pyo3(signature = (window_size, weights, min_periods, center))]
85    fn rolling_max(
86        &self,
87        window_size: usize,
88        weights: Option<Vec<f64>>,
89        min_periods: Option<usize>,
90        center: bool,
91    ) -> Self {
92        let min_periods = min_periods.unwrap_or(window_size);
93        let options = RollingOptionsFixedWindow {
94            window_size,
95            weights,
96            min_periods,
97            center,
98            ..Default::default()
99        };
100        self.inner.clone().rolling_max(options).into()
101    }
102    #[pyo3(signature = (by, window_size, min_periods, closed))]
103    fn rolling_max_by(
104        &self,
105        by: PyExpr,
106        window_size: &str,
107        min_periods: usize,
108        closed: Wrap<ClosedWindow>,
109    ) -> PyResult<Self> {
110        let options = RollingOptionsDynamicWindow {
111            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
112            min_periods,
113            closed_window: closed.0,
114            fn_params: None,
115        };
116        Ok(self.inner.clone().rolling_max_by(by.inner, options).into())
117    }
118
119    #[pyo3(signature = (window_size, weights, min_periods, center))]
120    fn rolling_mean(
121        &self,
122        window_size: usize,
123        weights: Option<Vec<f64>>,
124        min_periods: Option<usize>,
125        center: bool,
126    ) -> Self {
127        let min_periods = min_periods.unwrap_or(window_size);
128        let options = RollingOptionsFixedWindow {
129            window_size,
130            weights,
131            min_periods,
132            center,
133            ..Default::default()
134        };
135
136        self.inner.clone().rolling_mean(options).into()
137    }
138
139    #[pyo3(signature = (by, window_size, min_periods, closed))]
140    fn rolling_mean_by(
141        &self,
142        by: PyExpr,
143        window_size: &str,
144        min_periods: usize,
145        closed: Wrap<ClosedWindow>,
146    ) -> PyResult<Self> {
147        let options = RollingOptionsDynamicWindow {
148            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
149            min_periods,
150            closed_window: closed.0,
151            fn_params: None,
152        };
153
154        Ok(self.inner.clone().rolling_mean_by(by.inner, options).into())
155    }
156
157    #[pyo3(signature = (window_size, weights, min_periods, center, ddof))]
158    fn rolling_std(
159        &self,
160        window_size: usize,
161        weights: Option<Vec<f64>>,
162        min_periods: Option<usize>,
163        center: bool,
164        ddof: u8,
165    ) -> Self {
166        let min_periods = min_periods.unwrap_or(window_size);
167        let options = RollingOptionsFixedWindow {
168            window_size,
169            weights,
170            min_periods,
171            center,
172            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
173        };
174
175        self.inner.clone().rolling_std(options).into()
176    }
177
178    #[pyo3(signature = (by, window_size, min_periods, closed, ddof))]
179    fn rolling_std_by(
180        &self,
181        by: PyExpr,
182        window_size: &str,
183        min_periods: usize,
184        closed: Wrap<ClosedWindow>,
185        ddof: u8,
186    ) -> PyResult<Self> {
187        let options = RollingOptionsDynamicWindow {
188            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
189            min_periods,
190            closed_window: closed.0,
191            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
192        };
193
194        Ok(self.inner.clone().rolling_std_by(by.inner, options).into())
195    }
196
197    #[pyo3(signature = (window_size, weights, min_periods, center, ddof))]
198    fn rolling_var(
199        &self,
200        window_size: usize,
201        weights: Option<Vec<f64>>,
202        min_periods: Option<usize>,
203        center: bool,
204        ddof: u8,
205    ) -> Self {
206        let min_periods = min_periods.unwrap_or(window_size);
207        let options = RollingOptionsFixedWindow {
208            window_size,
209            weights,
210            min_periods,
211            center,
212            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
213        };
214
215        self.inner.clone().rolling_var(options).into()
216    }
217
218    #[pyo3(signature = (by, window_size, min_periods, closed, ddof))]
219    fn rolling_var_by(
220        &self,
221        by: PyExpr,
222        window_size: &str,
223        min_periods: usize,
224        closed: Wrap<ClosedWindow>,
225        ddof: u8,
226    ) -> PyResult<Self> {
227        let options = RollingOptionsDynamicWindow {
228            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
229            min_periods,
230            closed_window: closed.0,
231            fn_params: Some(RollingFnParams::Var(RollingVarParams { ddof })),
232        };
233
234        Ok(self.inner.clone().rolling_var_by(by.inner, options).into())
235    }
236
237    #[pyo3(signature = (window_size, weights, min_periods, center))]
238    fn rolling_median(
239        &self,
240        window_size: usize,
241        weights: Option<Vec<f64>>,
242        min_periods: Option<usize>,
243        center: bool,
244    ) -> Self {
245        let min_periods = min_periods.unwrap_or(window_size);
246        let options = RollingOptionsFixedWindow {
247            window_size,
248            min_periods,
249            weights,
250            center,
251            fn_params: None,
252        };
253        self.inner.clone().rolling_median(options).into()
254    }
255
256    #[pyo3(signature = (by, window_size, min_periods, closed))]
257    fn rolling_median_by(
258        &self,
259        by: PyExpr,
260        window_size: &str,
261        min_periods: usize,
262        closed: Wrap<ClosedWindow>,
263    ) -> PyResult<Self> {
264        let options = RollingOptionsDynamicWindow {
265            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
266            min_periods,
267            closed_window: closed.0,
268            fn_params: None,
269        };
270        Ok(self
271            .inner
272            .clone()
273            .rolling_median_by(by.inner, options)
274            .into())
275    }
276
277    #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center))]
278    fn rolling_quantile(
279        &self,
280        quantile: f64,
281        interpolation: Wrap<QuantileMethod>,
282        window_size: usize,
283        weights: Option<Vec<f64>>,
284        min_periods: Option<usize>,
285        center: bool,
286    ) -> Self {
287        let min_periods = min_periods.unwrap_or(window_size);
288        let options = RollingOptionsFixedWindow {
289            window_size,
290            weights,
291            min_periods,
292            center,
293            fn_params: None,
294        };
295
296        self.inner
297            .clone()
298            .rolling_quantile(interpolation.0, quantile, options)
299            .into()
300    }
301
302    #[pyo3(signature = (by, quantile, interpolation, window_size, min_periods, closed))]
303    fn rolling_quantile_by(
304        &self,
305        by: PyExpr,
306        quantile: f64,
307        interpolation: Wrap<QuantileMethod>,
308        window_size: &str,
309        min_periods: usize,
310        closed: Wrap<ClosedWindow>,
311    ) -> PyResult<Self> {
312        let options = RollingOptionsDynamicWindow {
313            window_size: Duration::try_parse(window_size).map_err(PyPolarsErr::from)?,
314            min_periods,
315            closed_window: closed.0,
316            fn_params: None,
317        };
318
319        Ok(self
320            .inner
321            .clone()
322            .rolling_quantile_by(by.inner, interpolation.0, quantile, options)
323            .into())
324    }
325
326    fn rolling_skew(&self, window_size: usize, bias: bool) -> Self {
327        self.inner.clone().rolling_skew(window_size, bias).into()
328    }
329
330    #[pyo3(signature = (lambda, window_size, weights, min_periods, center))]
331    fn rolling_map(
332        &self,
333        lambda: PyObject,
334        window_size: usize,
335        weights: Option<Vec<f64>>,
336        min_periods: Option<usize>,
337        center: bool,
338    ) -> Self {
339        let min_periods = min_periods.unwrap_or(window_size);
340        let options = RollingOptionsFixedWindow {
341            window_size,
342            weights,
343            min_periods,
344            center,
345            ..Default::default()
346        };
347        let function = move |s: &Series| {
348            Python::with_gil(|py| {
349                let out = call_lambda_with_series(py, s.clone(), &lambda)
350                    .expect("python function failed");
351                match out.getattr(py, "_s") {
352                    Ok(pyseries) => {
353                        let pyseries = pyseries.extract::<PySeries>(py).unwrap();
354                        pyseries.series
355                    },
356                    Err(_) => {
357                        let obj = out;
358                        let is_float = obj.bind(py).is_instance_of::<PyFloat>();
359
360                        let dtype = s.dtype();
361
362                        use DataType::*;
363                        let result = match dtype {
364                            UInt8 => {
365                                if is_float {
366                                    let v = obj.extract::<f64>(py).unwrap();
367                                    Ok(UInt8Chunked::from_slice(PlSmallStr::EMPTY, &[v as u8])
368                                        .into_series())
369                                } else {
370                                    obj.extract::<u8>(py).map(|v| {
371                                        UInt8Chunked::from_slice(PlSmallStr::EMPTY, &[v])
372                                            .into_series()
373                                    })
374                                }
375                            },
376                            UInt16 => {
377                                if is_float {
378                                    let v = obj.extract::<f64>(py).unwrap();
379                                    Ok(UInt16Chunked::from_slice(PlSmallStr::EMPTY, &[v as u16])
380                                        .into_series())
381                                } else {
382                                    obj.extract::<u16>(py).map(|v| {
383                                        UInt16Chunked::from_slice(PlSmallStr::EMPTY, &[v])
384                                            .into_series()
385                                    })
386                                }
387                            },
388                            UInt32 => {
389                                if is_float {
390                                    let v = obj.extract::<f64>(py).unwrap();
391                                    Ok(UInt32Chunked::from_slice(PlSmallStr::EMPTY, &[v as u32])
392                                        .into_series())
393                                } else {
394                                    obj.extract::<u32>(py).map(|v| {
395                                        UInt32Chunked::from_slice(PlSmallStr::EMPTY, &[v])
396                                            .into_series()
397                                    })
398                                }
399                            },
400                            UInt64 => {
401                                if is_float {
402                                    let v = obj.extract::<f64>(py).unwrap();
403                                    Ok(UInt64Chunked::from_slice(PlSmallStr::EMPTY, &[v as u64])
404                                        .into_series())
405                                } else {
406                                    obj.extract::<u64>(py).map(|v| {
407                                        UInt64Chunked::from_slice(PlSmallStr::EMPTY, &[v])
408                                            .into_series()
409                                    })
410                                }
411                            },
412                            Int8 => {
413                                if is_float {
414                                    let v = obj.extract::<f64>(py).unwrap();
415                                    Ok(Int8Chunked::from_slice(PlSmallStr::EMPTY, &[v as i8])
416                                        .into_series())
417                                } else {
418                                    obj.extract::<i8>(py).map(|v| {
419                                        Int8Chunked::from_slice(PlSmallStr::EMPTY, &[v])
420                                            .into_series()
421                                    })
422                                }
423                            },
424                            Int16 => {
425                                if is_float {
426                                    let v = obj.extract::<f64>(py).unwrap();
427                                    Ok(Int16Chunked::from_slice(PlSmallStr::EMPTY, &[v as i16])
428                                        .into_series())
429                                } else {
430                                    obj.extract::<i16>(py).map(|v| {
431                                        Int16Chunked::from_slice(PlSmallStr::EMPTY, &[v])
432                                            .into_series()
433                                    })
434                                }
435                            },
436                            Int32 => {
437                                if is_float {
438                                    let v = obj.extract::<f64>(py).unwrap();
439                                    Ok(Int32Chunked::from_slice(PlSmallStr::EMPTY, &[v as i32])
440                                        .into_series())
441                                } else {
442                                    obj.extract::<i32>(py).map(|v| {
443                                        Int32Chunked::from_slice(PlSmallStr::EMPTY, &[v])
444                                            .into_series()
445                                    })
446                                }
447                            },
448                            Int64 => {
449                                if is_float {
450                                    let v = obj.extract::<f64>(py).unwrap();
451                                    Ok(Int64Chunked::from_slice(PlSmallStr::EMPTY, &[v as i64])
452                                        .into_series())
453                                } else {
454                                    obj.extract::<i64>(py).map(|v| {
455                                        Int64Chunked::from_slice(PlSmallStr::EMPTY, &[v])
456                                            .into_series()
457                                    })
458                                }
459                            },
460                            Int128 => {
461                                if is_float {
462                                    let v = obj.extract::<f64>(py).unwrap();
463                                    Ok(Int128Chunked::from_slice(PlSmallStr::EMPTY, &[v as i128])
464                                        .into_series())
465                                } else {
466                                    obj.extract::<i128>(py).map(|v| {
467                                        Int128Chunked::from_slice(PlSmallStr::EMPTY, &[v])
468                                            .into_series()
469                                    })
470                                }
471                            },
472                            Float32 => obj.extract::<f32>(py).map(|v| {
473                                Float32Chunked::from_slice(PlSmallStr::EMPTY, &[v]).into_series()
474                            }),
475                            Float64 => obj.extract::<f64>(py).map(|v| {
476                                Float64Chunked::from_slice(PlSmallStr::EMPTY, &[v]).into_series()
477                            }),
478                            dt => panic!("{dt:?} not implemented"),
479                        };
480
481                        match result {
482                            Ok(s) => s,
483                            Err(e) => {
484                                panic!("{e:?}")
485                            },
486                        }
487                    },
488                }
489            })
490        };
491        self.inner
492            .clone()
493            .rolling_map(Arc::new(function), GetOutput::same_type(), options)
494            .with_fmt("rolling_map")
495            .into()
496    }
497}