polars_plan/dsl/function_expr/
rolling.rs1#[cfg(feature = "cov")]
2use std::ops::BitAnd;
3
4use polars_core::utils::Container;
5use polars_time::chunkedarray::*;
6
7use super::*;
8#[cfg(feature = "cov")]
9use crate::dsl::pow::pow;
10
/// Fixed-window rolling aggregations available as expression functions.
///
/// Every variant carries the [`RollingOptionsFixedWindow`] describing its window. The
/// `Display` impl renders each variant as `rolling_<name>` (e.g. `rolling_min`).
#[derive(Clone, PartialEq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RollingFunction {
    Min(RollingOptionsFixedWindow),
    Max(RollingOptionsFixedWindow),
    Mean(RollingOptionsFixedWindow),
    Sum(RollingOptionsFixedWindow),
    Quantile(RollingOptionsFixedWindow),
    Var(RollingOptionsFixedWindow),
    Std(RollingOptionsFixedWindow),
    #[cfg(feature = "moment")]
    Skew(RollingOptionsFixedWindow),
    #[cfg(feature = "moment")]
    Kurtosis(RollingOptionsFixedWindow),
    /// Rolling covariance or correlation between two inputs.
    #[cfg(feature = "cov")]
    CorrCov {
        /// Window used for the rolling means/variances in `rolling_corr_cov`.
        rolling_options: RollingOptionsFixedWindow,
        /// Covariance options; `rolling_corr_cov` reads its `ddof` for the bias correction.
        corr_cov_options: RollingCovOptions,
        /// `true` computes the correlation (`rolling_corr`), `false` the covariance
        /// (`rolling_cov`).
        is_corr: bool,
    },
}
33
34impl Display for RollingFunction {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 use RollingFunction::*;
37
38 let name = match self {
39 Min(_) => "min",
40 Max(_) => "max",
41 Mean(_) => "mean",
42 Sum(_) => "rsum",
43 Quantile(_) => "quantile",
44 Var(_) => "var",
45 Std(_) => "std",
46 #[cfg(feature = "moment")]
47 Skew(..) => "skew",
48 #[cfg(feature = "moment")]
49 Kurtosis(..) => "kurtosis",
50 #[cfg(feature = "cov")]
51 CorrCov { is_corr, .. } => {
52 if *is_corr {
53 "corr"
54 } else {
55 "cov"
56 }
57 },
58 };
59
60 write!(f, "rolling_{name}")
61 }
62}
63
64pub(super) fn rolling_min(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
65 s.as_materialized_series()
67 .rolling_min(options)
68 .map(Column::from)
69}
70
71pub(super) fn rolling_max(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
72 s.as_materialized_series()
74 .rolling_max(options)
75 .map(Column::from)
76}
77
78pub(super) fn rolling_mean(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
79 s.as_materialized_series()
81 .rolling_mean(options)
82 .map(Column::from)
83}
84
85pub(super) fn rolling_sum(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
86 s.as_materialized_series()
88 .rolling_sum(options)
89 .map(Column::from)
90}
91
92pub(super) fn rolling_quantile(
93 s: &Column,
94 options: RollingOptionsFixedWindow,
95) -> PolarsResult<Column> {
96 s.as_materialized_series()
98 .rolling_quantile(options)
99 .map(Column::from)
100}
101
102pub(super) fn rolling_var(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
103 s.as_materialized_series()
105 .rolling_var(options)
106 .map(Column::from)
107}
108
109pub(super) fn rolling_std(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
110 s.as_materialized_series()
112 .rolling_std(options)
113 .map(Column::from)
114}
115
/// Rolling skewness of `s`, delegated to `polars_ops::series::rolling_skew`.
#[cfg(feature = "moment")]
pub(super) fn rolling_skew(s: &Column, options: RollingOptionsFixedWindow) -> PolarsResult<Column> {
    let skew = polars_ops::series::rolling_skew(s.as_materialized_series(), options)?;
    Ok(Column::from(skew))
}
122
/// Rolling kurtosis of `s`, delegated to `polars_ops::series::rolling_kurtosis`.
#[cfg(feature = "moment")]
pub(super) fn rolling_kurtosis(
    s: &Column,
    options: RollingOptionsFixedWindow,
) -> PolarsResult<Column> {
    let kurtosis = polars_ops::series::rolling_kurtosis(s.as_materialized_series(), options)?;
    Ok(Column::from(kurtosis))
}
132
/// Number of rows covered by each trailing window when no value is null.
///
/// Element `i` (0-based) sees `i + 1` rows until the window fills, then `window_size`.
/// The counts are emitted as `dtype`, which must be `Float64` or `Float32`; any other
/// dtype is unreachable for the caller in this file.
#[cfg(feature = "cov")]
fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series {
    let counts = (1..=len).map(|seen| seen.min(window_size));
    match dtype {
        DataType::Float64 => {
            let values: Vec<f64> = counts.map(|c| c as f64).collect();
            Series::new(PlSmallStr::EMPTY, values)
        },
        DataType::Float32 => {
            let values: Vec<f32> = counts.map(|c| c as f32).collect();
            Series::new(PlSmallStr::EMPTY, values)
        },
        _ => unreachable!(),
    }
}
151
/// Rolling covariance (`is_corr == false`) or correlation (`is_corr == true`) between the
/// first two columns of `s`.
///
/// Dtype handling:
/// - Non-float inputs are cast to `Float64`; float inputs keep their dtype.
/// - The (possibly cast) dtype of the first input is used for the pair counts and `ddof`.
///
/// Null handling: rows where either input is null are masked out of *both* inputs before the
/// per-input rolling means/variances are taken. Every statistic therefore covers the same rows.
///
/// The covariance is
/// `(mean(x * y) - mean(x) * mean(y)) * n / (n - ddof)`,
/// where `n` is the number of rows in the window with both inputs valid and `ddof` comes
/// from `cov_options`. For correlation this is divided by `(var(x) * var(y))^0.5`.
///
/// # Errors
/// Propagates errors from casting, `Series` arithmetic and the rolling kernels.
///
/// # Panics
/// Panics if `s` holds fewer than two columns.
#[cfg(feature = "cov")]
pub(super) fn rolling_corr_cov(
    s: &[Column],
    rolling_options: RollingOptionsFixedWindow,
    cov_options: RollingCovOptions,
    is_corr: bool,
) -> PolarsResult<Column> {
    let mut x = s[0].as_materialized_series().rechunk();
    let mut y = s[1].as_materialized_series().rechunk();

    if !x.dtype().is_float() {
        x = x.cast(&DataType::Float64)?;
    }
    if !y.dtype().is_float() {
        y = y.cast(&DataType::Float64)?;
    }
    let dtype = x.dtype().clone();

    // `x * y` is already null wherever either side is null, so this mean only covers rows
    // where both inputs are valid; it is fine to compute it before masking `x`/`y`.
    let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?;
    // The pair count must be defined for every window, hence `min_periods: 0`.
    let rolling_options_count = RollingOptionsFixedWindow {
        window_size: rolling_options.window_size,
        min_periods: 0,
        ..Default::default()
    };

    let count_x_y = if (x.null_count() + y.null_count()) > 0 {
        // Rows valid on both sides; reused as the validity of both inputs.
        let valids = x.is_not_null().bitand(y.is_not_null());
        let valids_arr = valids.downcast_as_array();
        let valids_bitmap = valids_arr.values();

        // SAFETY: `x` and `y` were rechunked above, so chunk 0 is their only chunk. The bitmap
        // is derived from `x`/`y` themselves, so (assuming equal input lengths) it matches the
        // chunk length. Only validity changes, never the dtype, and `compute_len` refreshes the
        // cached metadata afterwards.
        unsafe {
            let xarr = &mut x.chunks_mut()[0];
            *xarr = xarr.with_validity(Some(valids_bitmap.clone()));
            let yarr = &mut y.chunks_mut()[0];
            *yarr = yarr.with_validity(Some(valids_bitmap.clone()));
            x.compute_len();
            y.compute_len();
        }
        valids.cast(&dtype)?.rolling_sum(rolling_options_count)?
    } else {
        // No nulls: the count is just the (partially filled) window length.
        det_count_x_y(rolling_options.window_size, x.len(), &dtype)
    };

    let mean_x = x.rolling_mean(rolling_options.clone())?;
    let mean_y = y.rolling_mean(rolling_options.clone())?;
    // Length-1 series; broadcasts against `count_x_y` in the subtraction below.
    let ddof = Series::new(
        PlSmallStr::EMPTY,
        &[AnyValue::from(cov_options.ddof).cast(&dtype)],
    );

    // Biased rolling covariance, rescaled by n / (n - ddof).
    let biased_cov = (mean_x_y - (mean_x * mean_y)?)?;
    let correction = (count_x_y.clone() / (count_x_y - ddof)?)?;
    let numerator = (biased_cov * correction)?;

    if !is_corr {
        return Ok(numerator.into_column());
    }

    // NOTE(review): the variances below use `rolling_options` as-is and never see
    // `cov_options.ddof`. pandas' rolling corr passes the same ddof to both the covariance
    // and the variances. If `rolling_options` does not carry a matching variance ddof, the
    // correlation is presumably off by a constant factor when `ddof != 1`. Confirm against
    // the kernel defaults.
    let var_x = x.rolling_var(rolling_options.clone())?;
    let var_y = y.rolling_var(rolling_options)?;

    let base = (var_x * var_y)?;
    let exponent = Scalar::new(
        base.dtype().clone(),
        AnyValue::Float64(0.5).cast(&dtype).into_static(),
    );
    let denominator = pow(&mut [base.into_column(), exponent.into_column(PlSmallStr::EMPTY)])?
        .expect("pow of a base column and an exponent scalar always yields an output column")
        .take_materialized_series();

    Ok((numerator / denominator)?.into_column())
}