1mod mean;
2mod min_max;
3mod moment;
4pub mod no_nulls;
5pub mod nulls;
6pub mod quantile_filter;
7mod rank;
8mod sum;
9
10mod arg_min_max;
11pub(super) mod window;
12use std::hash::Hash;
13use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
14
15pub use arg_min_max::{ArgMaxWindow, ArgMinMaxWindow, ArgMinWindow};
16use arrow::array::{ArrayRef, PrimitiveArray};
17use arrow::bitmap::{Bitmap, MutableBitmap};
18use arrow::types::NativeType;
19pub use mean::MeanWindow;
20use num_traits::{Bounded, Float, NumCast, One, Zero};
21use polars_utils::float::IsFloat;
22#[cfg(feature = "serde")]
23use serde::{Deserialize, Serialize};
24use strum_macros::IntoStaticStr;
25pub use sum::SumWindow;
26use window::*;
27
// Descriptive aliases for the `usize` values passed to and returned by the
// window offset functions, whose shape is `Fn(Idx, WindowSize, Len) -> (Start, End)`.
/// First index covered by a window (inclusive).
type Start = usize;
/// One past the last index covered by a window (exclusive), so `end - start`
/// is the number of elements in the window.
type End = usize;
/// Position in the input whose window is being computed.
type Idx = usize;
/// Requested number of elements per window.
type WindowSize = usize;
/// Total length of the input being rolled over.
type Len = usize;
33
// Method used to compute a quantile; carried in `RollingQuantileParams::method`.
// `Nearest` is the `Default`. Variant names convert to snake_case static strings
// via strum's `IntoStaticStr` (e.g. "nearest", "midpoint").
// NOTE(review): the exact selection/interpolation rule of each variant lives in
// the quantile kernels, not here. Confirm there before relying on it.
// (Plain `//` comments so the `dsl-schema` JsonSchema output is unaffected.)
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum QuantileMethod {
    #[default]
    Nearest,
    Lower,
    Higher,
    Midpoint,
    Linear,
    Equiprobable,
}
47
// Deprecated alias of `QuantileMethod`. It is kept so code still using this name
// compiles (with a deprecation warning); new code should name `QuantileMethod`.
#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;
50
// Function-specific options for rolling computations. There is one variant per
// rolling function kind that takes extra parameters (quantile, variance, rank,
// skew, kurtosis). `Hash` is derived. The `Quantile` payload supplies its own
// `Hash` impl because it contains an `f64`.
// (Plain `//` comments so the `dsl-schema` JsonSchema output is unaffected.)
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum RollingFnParams {
    Quantile(RollingQuantileParams),
    Var(RollingVarParams),
    Rank {
        method: RollingRankMethod,
        // Presumably seeds the RNG used by `RollingRankMethod::Random`; confirm.
        seed: Option<u64>,
    },
    Skew {
        // Presumably selects the biased (uncorrected) estimator; confirm.
        bias: bool,
    },
    Kurtosis {
        // Presumably selects Fisher's (excess) definition over Pearson's; confirm.
        fisher: bool,
        // Presumably selects the biased (uncorrected) estimator; confirm.
        bias: bool,
    },
}
69
70fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
71 (i.saturating_sub(window_size - 1), i + 1)
72}
73fn det_offsets_center(i: Idx, window_size: WindowSize, len: Len) -> (usize, usize) {
74 let right_window = window_size.div_ceil(2);
75 (
76 i.saturating_sub(window_size - right_window),
77 std::cmp::min(len, i + right_window),
78 )
79}
80
81fn create_validity<Fo>(
82 min_periods: usize,
83 len: usize,
84 window_size: usize,
85 det_offsets_fn: Fo,
86) -> Option<MutableBitmap>
87where
88 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
89{
90 if min_periods > 1 {
91 let mut validity = MutableBitmap::with_capacity(len);
92 validity.extend_constant(len, true);
93
94 for i in 0..len {
98 let (start, end) = det_offsets_fn(i, window_size, len);
99 if (end - start) < min_periods {
100 validity.set(i, false)
101 } else {
102 break;
103 }
104 }
105 for i in (0..len).rev() {
107 let (start, end) = det_offsets_fn(i, window_size, len);
108 if (end - start) < min_periods {
109 validity.set(i, false)
110 } else {
111 break;
112 }
113 }
114
115 Some(validity)
116 } else {
117 None
118 }
119}
120
// Options for rolling variance; carried by `RollingFnParams::Var`.
// (Plain `//` comments so the `dsl-schema` JsonSchema output is unaffected.)
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingVarParams {
    // Presumably "delta degrees of freedom", i.e. the denominator is
    // `count - ddof`. Confirm against the variance kernel.
    pub ddof: u8,
}
128
// Options for a rolling quantile; carried by `RollingFnParams::Quantile`.
// `Hash` is implemented by hand (not derived) because `f64` has no `Hash` impl.
// (Plain `//` comments so the `dsl-schema` JsonSchema output is unaffected.)
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct RollingQuantileParams {
    // Requested quantile. Presumably expected in [0, 1]; it is not validated here.
    pub prob: f64,
    // How the quantile value is chosen or interpolated.
    pub method: QuantileMethod,
}
136
137impl Hash for RollingQuantileParams {
138 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139 self.prob.to_bits().hash(state);
141 self.method.hash(state);
142 }
143}
144
// Ranking method for rolling rank; carried by `RollingFnParams::Rank`.
// `Average` is the `Default`. Variant names convert to snake_case static
// strings via strum's `IntoStaticStr`.
// NOTE(review): the variants presumably decide how tied values are ranked, with
// `Random` using the `seed` from `RollingFnParams::Rank`. Confirm in the rank
// kernel.
// (Plain `//` comments so the `dsl-schema` JsonSchema output is unaffected.)
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum RollingRankMethod {
    #[default]
    Average,
    Min,
    Max,
    Dense,
    Random,
}