polars_plan/dsl/function_expr/rolling.rs

1#[cfg(feature = "cov")]
2use std::ops::BitAnd;
3
4use polars_core::utils::Container;
5use polars_time::chunkedarray::*;
6
7use super::*;
8#[cfg(feature = "cov")]
9use crate::dsl::pow::pow;
10
/// Fixed-window rolling functions handled by this module.
///
/// Every variant carries the [`RollingOptionsFixedWindow`] that configures its window.
/// The `Display` impl renders each variant as `rolling_<op>`.
#[derive(Clone, PartialEq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RollingFunction {
    /// Rolling minimum.
    Min(RollingOptionsFixedWindow),
    /// Rolling maximum.
    Max(RollingOptionsFixedWindow),
    /// Rolling mean.
    Mean(RollingOptionsFixedWindow),
    /// Rolling sum.
    Sum(RollingOptionsFixedWindow),
    /// Rolling quantile (the quantile parameters are presumably carried inside the
    /// options — not visible from this file, TODO confirm).
    Quantile(RollingOptionsFixedWindow),
    /// Rolling variance.
    Var(RollingOptionsFixedWindow),
    /// Rolling standard deviation.
    Std(RollingOptionsFixedWindow),
    /// Rolling skewness.
    #[cfg(feature = "moment")]
    Skew(RollingOptionsFixedWindow),
    /// Rolling kurtosis.
    #[cfg(feature = "moment")]
    Kurtosis(RollingOptionsFixedWindow),
    /// Rolling correlation or covariance of two inputs.
    #[cfg(feature = "cov")]
    CorrCov {
        /// Window configuration shared by all intermediate rolling aggregations.
        rolling_options: RollingOptionsFixedWindow,
        /// Covariance options; `rolling_corr_cov` reads its `ddof`.
        corr_cov_options: RollingCovOptions,
        /// `true` selects rolling correlation, `false` rolling covariance.
        is_corr: bool,
    },
}
33
34impl Display for RollingFunction {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        use RollingFunction::*;
37
38        let name = match self {
39            Min(_) => "min",
40            Max(_) => "max",
41            Mean(_) => "mean",
42            Sum(_) => "rsum",
43            Quantile(_) => "quantile",
44            Var(_) => "var",
45            Std(_) => "std",
46            #[cfg(feature = "moment")]
47            Skew(..) => "skew",
48            #[cfg(feature = "moment")]
49            Kurtosis(..) => "kurtosis",
50            #[cfg(feature = "cov")]
51            CorrCov { is_corr, .. } => {
52                if *is_corr {
53                    "corr"
54                } else {
55                    "cov"
56                }
57            },
58        };
59
60        write!(f, "rolling_{name}")
61    }
62}
63
64pub(super) fn rolling_min(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
65    // @scalar-opt
66    s.as_materialized_series()
67        .rolling_min(options)
68        .map(Column::from)
69}
70
71pub(super) fn rolling_max(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
72    // @scalar-opt
73    s.as_materialized_series()
74        .rolling_max(options)
75        .map(Column::from)
76}
77
78pub(super) fn rolling_mean(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
79    // @scalar-opt
80    s.as_materialized_series()
81        .rolling_mean(options)
82        .map(Column::from)
83}
84
85pub(super) fn rolling_sum(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
86    // @scalar-opt
87    s.as_materialized_series()
88        .rolling_sum(options)
89        .map(Column::from)
90}
91
92pub(super) fn rolling_quantile(
93    s: &Column,
94    options: RollingOptionsFixedWindow,
95) -> PolarsResult<Column> {
96    // @scalar-opt
97    s.as_materialized_series()
98        .rolling_quantile(options)
99        .map(Column::from)
100}
101
102pub(super) fn rolling_var(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
103    // @scalar-opt
104    s.as_materialized_series()
105        .rolling_var(options)
106        .map(Column::from)
107}
108
109pub(super) fn rolling_std(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
110    // @scalar-opt
111    s.as_materialized_series()
112        .rolling_std(options)
113        .map(Column::from)
114}
115
/// Fixed-window rolling skewness of `s`.
///
/// Delegates to `polars_ops::series::rolling_skew` on the materialized series and wraps
/// the result into a [`Column`]. Errors are propagated unchanged.
#[cfg(feature = "moment")]
pub(super) fn rolling_skew(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
    // @scalar-opt
    let skewness = polars_ops::series::rolling_skew(s.as_materialized_series(), options)?;
    Ok(skewness.into())
}
122
/// Fixed-window rolling kurtosis of `s`.
///
/// Delegates to `polars_ops::series::rolling_kurtosis` on the materialized series and
/// wraps the result into a [`Column`]. Errors are propagated unchanged.
#[cfg(feature = "moment")]
pub(super) fn rolling_kurtosis(
    s: &Column,
    options: RollingOptionsFixedWindow,
) -> PolarsResult<Column> {
    // @scalar-opt
    let kurt = polars_ops::series::rolling_kurtosis(s.as_materialized_series(), options)?;
    Ok(Column::from(kurt))
}
132
/// Per-row observation counts used by `rolling_corr_cov` when neither input has nulls.
///
/// Row `i` (0-based) holds `min(window_size, i + 1)`, i.e. the window size capped by the
/// number of rows seen so far. The result has `len` rows, an empty name and the float
/// type selected by `dtype`.
///
/// # Panics
/// Panics (via `unreachable!`) if `dtype` is neither `Float32` nor `Float64`.
#[cfg(feature = "cov")]
fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series {
    let window_counts = (1..=len).map(|rows_seen| rows_seen.min(window_size));
    match dtype {
        DataType::Float32 => {
            let counts: Vec<f32> = window_counts.map(|c| c as f32).collect();
            Series::new(PlSmallStr::EMPTY, counts)
        },
        DataType::Float64 => {
            let counts: Vec<f64> = window_counts.map(|c| c as f64).collect();
            Series::new(PlSmallStr::EMPTY, counts)
        },
        _ => unreachable!(),
    }
}
151
/// Fixed-window rolling covariance (or correlation) of two columns.
///
/// `s[0]` is `x` and `s[1]` is `y`. Non-float inputs are cast to `Float64`. If the two
/// inputs still have different float dtypes, both are cast to `Float64`, so every
/// intermediate series (counts, `ddof`, the `pow` exponent) shares one dtype.
///
/// The covariance is
/// `(mean(x * y) - mean(x) * mean(y)) * n / (n - ddof)`,
/// where `n` is the per-row count of positions in the window at which both `x` and `y`
/// are non-null. When `is_corr` is `true`, that value is divided by
/// `(var(x) * var(y)) ^ 0.5` to yield the correlation.
///
/// # Errors
/// Propagates errors from casting, the rolling kernels, series arithmetic and `pow`.
///
/// # Panics
/// Panics if `s` holds fewer than two columns.
#[cfg(feature = "cov")]
pub(super) fn rolling_corr_cov(
    s: &[Column],
    rolling_options: RollingOptionsFixedWindow,
    cov_options: RollingCovOptions,
    is_corr: bool,
) -> PolarsResult<Column> {
    let mut x = s[0].as_materialized_series().rechunk();
    let mut y = s[1].as_materialized_series().rechunk();

    if !x.dtype().is_float() {
        x = x.cast(&DataType::Float64)?;
    }
    if !y.dtype().is_float() {
        y = y.cast(&DataType::Float64)?;
    }
    // With mixed Float32/Float64 inputs, `dtype` would follow `x` alone while the series
    // arithmetic below upcasts to Float64. For example, the `pow` exponent would get a
    // Float64 dtype holding a Float32 value. Compute everything in one dtype instead.
    if x.dtype() != y.dtype() {
        x = x.cast(&DataType::Float64)?;
        y = y.cast(&DataType::Float64)?;
    }
    let dtype = x.dtype().clone();

    // Taken before masking: the product is already null wherever either side is null.
    let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?;

    let count_x_y = if x.null_count() + y.null_count() > 0 {
        // Mask out nulls on both sides so the means/variances below only use rows where
        // both `x` and `y` are valid.
        let valids = x.is_not_null().bitand(y.is_not_null());
        let valids_arr = valids.downcast_as_array();
        let valids_bitmap = valids_arr.values();

        // SAFETY: `x` and `y` were rechunked (casts keep one chunk per chunk), so index 0 is
        // their only chunk, and it exists because a null was found. Only the validity
        // bitmap is swapped, so the dtype is unchanged. The length is unchanged provided
        // `valids` is as long as each input; this assumes `x.len() == y.len()` (TODO
        // confirm callers guarantee it). `compute_len` then refreshes the cached length
        // and null count that `chunks_mut` leaves stale.
        unsafe {
            let xarr = &mut x.chunks_mut()[0];
            *xarr = xarr.with_validity(Some(valids_bitmap.clone()));
            let yarr = &mut y.chunks_mut()[0];
            *yarr = yarr.with_validity(Some(valids_bitmap.clone()));
            x.compute_len();
            y.compute_len();
        }

        // `min_periods: 0` so every window yields a count instead of null.
        let rolling_options_count = RollingOptionsFixedWindow {
            window_size: rolling_options.window_size,
            min_periods: 0,
            ..Default::default()
        };
        valids.cast(&dtype)?.rolling_sum(rolling_options_count)?
    } else {
        det_count_x_y(rolling_options.window_size, x.len(), &dtype)
    };

    let mean_x = x.rolling_mean(rolling_options.clone())?;
    let mean_y = y.rolling_mean(rolling_options.clone())?;
    let ddof = Series::new(
        PlSmallStr::EMPTY,
        &[AnyValue::from(cov_options.ddof).cast(&dtype)],
    );

    // Rescale the biased covariance by n / (n - ddof).
    let biased_cov = (mean_x_y - (mean_x * mean_y)?)?;
    let ddof_correction = (count_x_y.clone() / (count_x_y - ddof)?)?;
    let numerator = (biased_cov * ddof_correction)?;

    if !is_corr {
        return Ok(numerator.into_column());
    }

    let var_x = x.rolling_var(rolling_options.clone())?;
    let var_y = y.rolling_var(rolling_options)?;
    let base = (var_x * var_y)?;
    // The exponent's value is cast to the same dtype the scalar declares.
    let exponent = Scalar::new(
        base.dtype().clone(),
        AnyValue::Float64(0.5).cast(base.dtype()).into_static(),
    );
    let denominator = pow(&mut [base.into_column(), exponent.into_column(PlSmallStr::EMPTY)])?
        .expect("`pow` returns a column for float base and exponent")
        .take_materialized_series();

    Ok((numerator / denominator)?.into_column())
}