Skip to main content

polars_compute/rolling/
mod.rs

1mod mean;
2mod min_max;
3mod moment;
4pub mod no_nulls;
5pub mod nulls;
6pub mod quantile_filter;
7mod rank;
8mod sum;
9
10mod arg_min_max;
11pub(super) mod window;
12use std::hash::Hash;
13use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
14
15pub use arg_min_max::{ArgMaxWindow, ArgMinMaxWindow, ArgMinWindow};
16use arrow::array::{ArrayRef, PrimitiveArray};
17use arrow::bitmap::{Bitmap, MutableBitmap};
18use arrow::types::NativeType;
19pub use mean::MeanWindow;
20use num_traits::{Bounded, Float, NumCast, One, Zero};
21use polars_utils::float::IsFloat;
22#[cfg(feature = "serde")]
23use serde::{Deserialize, Serialize};
24use strum_macros::IntoStaticStr;
25pub use sum::SumWindow;
26use window::*;
27
// Role-labelling aliases for `usize`. They are all the same type and exist only to make
// the window-offset signatures in this module read more clearly.
type Start = usize;
type End = usize;
type Idx = usize;
type WindowSize = usize;
type Len = usize;
33
34#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
35#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
36#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
37#[strum(serialize_all = "snake_case")]
38pub enum QuantileMethod {
39    #[default]
40    Nearest,
41    Lower,
42    Higher,
43    Midpoint,
44    Linear,
45    Equiprobable,
46}
47
48#[deprecated(note = "use QuantileMethod instead")]
49pub type QuantileInterpolOptions = QuantileMethod;
50
51#[derive(Clone, Copy, Debug, PartialEq, Hash)]
52#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
53#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
54pub enum RollingFnParams {
55    Quantile(RollingQuantileParams),
56    Var(RollingVarParams),
57    Rank {
58        method: RollingRankMethod,
59        seed: Option<u64>,
60    },
61    Skew {
62        bias: bool,
63    },
64    Kurtosis {
65        fisher: bool,
66        bias: bool,
67    },
68}
69
70fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
71    (i.saturating_sub(window_size - 1), i + 1)
72}
73fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
74    let right_window = window_size.div_ceil(2);
75    (
76        i.saturating_sub(window_size - right_window),
77        std::cmp::min(len, i + right_window),
78    )
79}
80
81fn create_validity<Fo>(
82    min_periods: usize,
83    len: usize,
84    window_size: usize,
85    det_offsets_fn: Fo,
86) -> Option<MutableBitmap>
87where
88    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
89{
90    if min_periods > 1 {
91        let mut validity = MutableBitmap::with_capacity(len);
92        validity.extend_constant(len, true);
93
94        // Set the null values at the boundaries
95
96        // Head.
97        for i in 0..len {
98            let (start, end) = det_offsets_fn(i, window_size, len);
99            if (end - start) < min_periods {
100                validity.set(i, false)
101            } else {
102                break;
103            }
104        }
105        // Tail.
106        for i in (0..len).rev() {
107            let (start, end) = det_offsets_fn(i, window_size, len);
108            if (end - start) < min_periods {
109                validity.set(i, false)
110            } else {
111                break;
112            }
113        }
114
115        Some(validity)
116    } else {
117        None
118    }
119}
120
// Parameters for the rolling variance (carried by `RollingFnParams::Var`).
// NOTE(review): `ddof` is presumably the delta degrees of freedom — confirm against the
// variance kernel.
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingVarParams {
    pub ddof: u8,
}
128
129#[derive(Clone, Copy, Debug, PartialEq)]
130#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
131#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
132pub struct RollingQuantileParams {
133    pub prob: f64,
134    pub method: QuantileMethod,
135}
136
137impl Hash for RollingQuantileParams {
138    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139        // Will not be NaN, so hash + eq symmetry will hold.
140        self.prob.to_bits().hash(state);
141        self.method.hash(state);
142    }
143}
144
145#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
146#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
147#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
148#[strum(serialize_all = "snake_case")]
149pub enum RollingRankMethod {
150    #[default]
151    Average,
152    Min,
153    Max,
154    Dense,
155    Random,
156}