polars_compute/rolling/
mod.rs

1mod min_max;
2pub mod moment;
3pub mod no_nulls;
4pub mod nulls;
5pub mod quantile_filter;
6pub(super) mod window;
7use std::hash::Hash;
8use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
9
10use arrow::array::{ArrayRef, PrimitiveArray};
11use arrow::bitmap::{Bitmap, MutableBitmap};
12use arrow::types::NativeType;
13use num_traits::{Bounded, Float, NumCast, One, Zero};
14use polars_utils::float::IsFloat;
15#[cfg(feature = "serde")]
16use serde::{Deserialize, Serialize};
17use strum_macros::IntoStaticStr;
18use window::*;
19
// Role-labelling aliases: every one of them is a plain `usize`. They exist so that
// signatures such as `Fn(Idx, WindowSize, Len) -> (Start, End)` read unambiguously.

/// First slot of a window; the first element of the pair an offset function returns.
type Start = usize;
/// End of a window; the second element of the pair an offset function returns.
/// `end - start` is used as the number of slots in the window.
type End = usize;
/// Position of the element a window is being determined for.
type Idx = usize;
/// Requested window size, in slots.
type WindowSize = usize;
/// Length handed to the offset functions (the centered variant clamps the window end to it).
type Len = usize;
25
26#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
27#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
28#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
29#[strum(serialize_all = "snake_case")]
30pub enum QuantileMethod {
31    #[default]
32    Nearest,
33    Lower,
34    Higher,
35    Midpoint,
36    Linear,
37    Equiprobable,
38}
39
40#[deprecated(note = "use QuantileMethod instead")]
41pub type QuantileInterpolOptions = QuantileMethod;
42
43#[derive(Clone, Copy, Debug, PartialEq, Hash)]
44#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
45#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
46pub enum RollingFnParams {
47    Quantile(RollingQuantileParams),
48    Var(RollingVarParams),
49    Skew { bias: bool },
50    Kurtosis { fisher: bool, bias: bool },
51}
52
53fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
54    (i.saturating_sub(window_size - 1), i + 1)
55}
56fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
57    let right_window = window_size.div_ceil(2);
58    (
59        i.saturating_sub(window_size - right_window),
60        std::cmp::min(len, i + right_window),
61    )
62}
63
64fn create_validity<Fo>(
65    min_periods: usize,
66    len: usize,
67    window_size: usize,
68    det_offsets_fn: Fo,
69) -> Option<MutableBitmap>
70where
71    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
72{
73    if min_periods > 1 {
74        let mut validity = MutableBitmap::with_capacity(len);
75        validity.extend_constant(len, true);
76
77        // Set the null values at the boundaries
78
79        // Head.
80        for i in 0..len {
81            let (start, end) = det_offsets_fn(i, window_size, len);
82            if (end - start) < min_periods {
83                validity.set(i, false)
84            } else {
85                break;
86            }
87        }
88        // Tail.
89        for i in (0..len).rev() {
90            let (start, end) = det_offsets_fn(i, window_size, len);
91            if (end - start) < min_periods {
92                validity.set(i, false)
93            } else {
94                break;
95            }
96        }
97
98        Some(validity)
99    } else {
100        None
101    }
102}
103
/// Parameters allowed for rolling variance operations; carried by
/// [`RollingFnParams::Var`].
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingVarParams {
    /// NOTE(review): presumably the "delta degrees of freedom" subtracted from the
    /// sample count in the variance divisor — confirm against the variance kernels.
    pub ddof: u8,
}
111
112#[derive(Clone, Copy, Debug, PartialEq)]
113#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
114#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
115pub struct RollingQuantileParams {
116    pub prob: f64,
117    pub method: QuantileMethod,
118}
119
120impl Hash for RollingQuantileParams {
121    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
122        // Will not be NaN, so hash + eq symmetry will hold.
123        self.prob.to_bits().hash(state);
124        self.method.hash(state);
125    }
126}