Skip to main content

polars_compute/rolling/
mod.rs

1mod mean;
2mod min_max;
3mod moment;
4pub mod no_nulls;
5pub mod nulls;
6pub mod quantile_filter;
7mod rank;
8mod sum;
9
10mod arg_min_max;
11pub(super) mod window;
12use std::hash::Hash;
13use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
14
15pub use arg_min_max::{ArgMaxWindow, ArgMinMaxWindow, ArgMinWindow};
16use arrow::array::{ArrayRef, PrimitiveArray};
17use arrow::bitmap::{Bitmap, MutableBitmap};
18use arrow::types::NativeType;
19pub use mean::MeanWindow;
20use num_traits::{Bounded, Float, NumCast, One, Zero};
21use polars_utils::float::IsFloat;
22#[cfg(feature = "serde")]
23use serde::{Deserialize, Serialize};
24use strum_macros::IntoStaticStr;
25pub use sum::SumWindow;
26use window::*;
27
// Aliases that spell out what each `usize` means in the window-offset helpers.
// Windows are half-open ranges `[Start, End)`, so `End - Start` is the window length.

/// First index (inclusive) covered by a window.
type Start = usize;
/// One past the last index (exclusive) covered by a window.
type End = usize;
/// Position of the output element whose window is being computed.
type Idx = usize;
/// Requested number of elements in a window.
type WindowSize = usize;
/// Length of the whole input.
type Len = usize;
33
34#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
35#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
36#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
37#[strum(serialize_all = "snake_case")]
38pub enum QuantileMethod {
39    #[default]
40    Nearest,
41    Lower,
42    Higher,
43    Midpoint,
44    Linear,
45    Equiprobable,
46}
47
48#[deprecated(note = "use QuantileMethod instead")]
49pub type QuantileInterpolOptions = QuantileMethod;
50
51#[derive(Clone, Copy, Debug, PartialEq, Hash)]
52#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
53#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
54pub enum RollingFnParams {
55    Quantile(RollingQuantileParams),
56    Var(RollingVarParams),
57    Rank {
58        method: RollingRankMethod,
59        seed: Option<u64>,
60    },
61    Skew {
62        bias: bool,
63    },
64    Kurtosis {
65        fisher: bool,
66        bias: bool,
67    },
68}
69
70fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
71    if window_size == 0 {
72        return (i, i);
73    }
74    (i.saturating_sub(window_size - 1), i + 1)
75}
76fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
77    if window_size == 0 {
78        return (i, i);
79    }
80    let right_window = window_size.div_ceil(2);
81    (
82        i.saturating_sub(window_size - right_window),
83        std::cmp::min(len, i + right_window),
84    )
85}
86
87fn create_validity<Fo>(
88    min_periods: usize,
89    len: usize,
90    window_size: usize,
91    det_offsets_fn: Fo,
92) -> Option<MutableBitmap>
93where
94    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
95{
96    if min_periods > 1 {
97        let mut validity = MutableBitmap::with_capacity(len);
98        validity.extend_constant(len, true);
99
100        // Set the null values at the boundaries
101
102        // Head.
103        for i in 0..len {
104            let (start, end) = det_offsets_fn(i, window_size, len);
105            if (end - start) < min_periods {
106                validity.set(i, false)
107            } else {
108                break;
109            }
110        }
111        // Tail.
112        for i in (0..len).rev() {
113            let (start, end) = det_offsets_fn(i, window_size, len);
114            if (end - start) < min_periods {
115                validity.set(i, false)
116            } else {
117                break;
118            }
119        }
120
121        Some(validity)
122    } else {
123        None
124    }
125}
126
/// Parameters for rolling variance-style operations.
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingVarParams {
    /// Delta degrees of freedom. NOTE(review): presumably the divisor is `n - ddof`;
    /// confirm in the variance kernel.
    pub ddof: u8,
}
134
135#[derive(Clone, Copy, Debug, PartialEq)]
136#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
137#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
138pub struct RollingQuantileParams {
139    pub prob: f64,
140    pub method: QuantileMethod,
141}
142
143impl Hash for RollingQuantileParams {
144    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
145        // Will not be NaN, so hash + eq symmetry will hold.
146        self.prob.to_bits().hash(state);
147        self.method.hash(state);
148    }
149}
150
151#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
152#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
153#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
154#[strum(serialize_all = "snake_case")]
155pub enum RollingRankMethod {
156    #[default]
157    Average,
158    Min,
159    Max,
160    Dense,
161    Random,
162}