// polars_compute/rolling/mod.rs

1mod mean;
2mod min_max;
3mod moment;
4pub mod no_nulls;
5pub mod nulls;
6pub mod quantile_filter;
7mod rank;
8mod sum;
9
10pub(super) mod window;
11use std::hash::Hash;
12use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
13
14use arrow::array::{ArrayRef, PrimitiveArray};
15use arrow::bitmap::{Bitmap, MutableBitmap};
16use arrow::types::NativeType;
17pub use mean::MeanWindow;
18use num_traits::{Bounded, Float, NumCast, One, Zero};
19use polars_utils::float::IsFloat;
20#[cfg(feature = "serde")]
21use serde::{Deserialize, Serialize};
22use strum_macros::IntoStaticStr;
23pub use sum::SumWindow;
24use window::*;
25
/// Index of a single row of the input (the row a window is evaluated for).
type Idx = usize;
/// Inclusive first index of a window.
type Start = Idx;
/// Exclusive (one-past-the-last) index of a window; `end - start` is the window length.
type End = Idx;
/// Requested number of rows per window.
type WindowSize = usize;
/// Total number of rows in the input.
type Len = usize;
31
/// Strategy used when computing a quantile (stored e.g. in
/// `RollingQuantileParams::method`).
///
/// `Nearest` is the `Default`. The `IntoStaticStr` derive with
/// `serialize_all = "snake_case"` maps each variant to its snake_case name
/// (e.g. `"midpoint"`, `"equiprobable"`).
///
/// NOTE(review): the exact interpolation rule behind each variant is implemented
/// elsewhere (presumably the quantile kernels) — confirm there before documenting
/// per-variant semantics. Variant order feeds the derived `Hash`, so keep it stable.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum QuantileMethod {
    #[default]
    Nearest,
    Lower,
    Higher,
    Midpoint,
    Linear,
    Equiprobable,
}

/// Former name of [`QuantileMethod`]; retained only as a deprecated alias.
#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;
48
/// Operation-specific options for a rolling window function.
///
/// Each variant carries the extra parameters of one family of rolling operations.
/// Only `PartialEq` (not `Eq`) is derived because `RollingQuantileParams` holds an
/// `f64`; `Hash` relies on the manual `Hash` impl of `RollingQuantileParams`.
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum RollingFnParams {
    /// Rolling quantile: probability and quantile method.
    Quantile(RollingQuantileParams),
    /// Rolling variance (and presumably standard deviation — confirm at call sites).
    Var(RollingVarParams),
    /// Rolling rank.
    Rank {
        /// How ranks are assigned (see `RollingRankMethod`).
        method: RollingRankMethod,
        /// Optional RNG seed; presumably only consulted for `RollingRankMethod::Random`
        /// — TODO confirm.
        seed: Option<u64>,
    },
    /// Rolling skewness.
    Skew {
        /// NOTE(review): presumably selects the biased vs. bias-corrected estimator — confirm.
        bias: bool,
    },
    /// Rolling kurtosis.
    Kurtosis {
        /// NOTE(review): presumably Fisher (excess) vs. Pearson definition — confirm.
        fisher: bool,
        /// NOTE(review): presumably selects the biased vs. bias-corrected estimator — confirm.
        bias: bool,
    },
}
67
68fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
69    (i.saturating_sub(window_size - 1), i + 1)
70}
71fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
72    let right_window = window_size.div_ceil(2);
73    (
74        i.saturating_sub(window_size - right_window),
75        std::cmp::min(len, i + right_window),
76    )
77}
78
79fn create_validity<Fo>(
80    min_periods: usize,
81    len: usize,
82    window_size: usize,
83    det_offsets_fn: Fo,
84) -> Option<MutableBitmap>
85where
86    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
87{
88    if min_periods > 1 {
89        let mut validity = MutableBitmap::with_capacity(len);
90        validity.extend_constant(len, true);
91
92        // Set the null values at the boundaries
93
94        // Head.
95        for i in 0..len {
96            let (start, end) = det_offsets_fn(i, window_size, len);
97            if (end - start) < min_periods {
98                validity.set(i, false)
99            } else {
100                break;
101            }
102        }
103        // Tail.
104        for i in (0..len).rev() {
105            let (start, end) = det_offsets_fn(i, window_size, len);
106            if (end - start) < min_periods {
107                validity.set(i, false)
108            } else {
109                break;
110            }
111        }
112
113        Some(validity)
114    } else {
115        None
116    }
117}
118
// Parameters allowed for rolling operations.

/// Parameters of a rolling variance, carried by `RollingFnParams::Var`.
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingVarParams {
    /// NOTE(review): presumably "delta degrees of freedom" (divisor `n - ddof`) — confirm
    /// against the variance kernel.
    pub ddof: u8,
}
126
/// Parameters of a rolling quantile, carried by `RollingFnParams::Quantile`.
///
/// `Hash` is implemented by hand (not derived) because `f64` does not implement `Hash`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingQuantileParams {
    /// Requested quantile; presumably expected in `[0, 1]` — TODO confirm where validated.
    pub prob: f64,
    /// How the quantile value is selected or interpolated.
    pub method: QuantileMethod,
}
134
135impl Hash for RollingQuantileParams {
136    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
137        // Will not be NaN, so hash + eq symmetry will hold.
138        self.prob.to_bits().hash(state);
139        self.method.hash(state);
140    }
141}
142
/// Ranking strategy for `RollingFnParams::Rank`.
///
/// `Average` is the `Default`. The `IntoStaticStr` derive with
/// `serialize_all = "snake_case"` maps each variant to its snake_case name (e.g. `"dense"`).
///
/// NOTE(review): the variant names suggest the usual tie-breaking rules (average / min /
/// max / dense / random); confirm against the rank kernel before relying on this.
/// `Random` presumably uses the `seed` in `RollingFnParams::Rank` — TODO confirm.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum RollingRankMethod {
    #[default]
    Average,
    Min,
    Max,
    Dense,
    Random,
}