rustframes/dataframe/
window.rs

1use super::{DataFrame, Series};
2use std::collections::VecDeque;
3
4pub struct Window<'a> {
5    df: &'a DataFrame,
6    window_size: usize,
7    partition_by: Option<String>,
8    order_by: Option<(String, bool)>, // column, ascending
9}
10
11impl<'a> Window<'a> {
12    pub fn new(df: &'a DataFrame, window_size: usize) -> Self {
13        Window {
14            df,
15            window_size,
16            partition_by: None,
17            order_by: None,
18        }
19    }
20
21    pub fn partition_by(mut self, column: &str) -> Self {
22        self.partition_by = Some(column.to_string());
23        self
24    }
25
26    pub fn order_by(mut self, column: &str, ascending: bool) -> Self {
27        self.order_by = Some((column.to_string(), ascending));
28        self
29    }
30
31    /// Rolling sum over window
32    pub fn rolling_sum(&self, column: &str) -> Series {
33        self.apply_rolling_function(column, |window| window.iter().sum())
34    }
35
36    /// Rolling mean over window
37    pub fn rolling_mean(&self, column: &str) -> Series {
38        self.apply_rolling_function(column, |window| {
39            let sum: f64 = window.iter().sum();
40            sum / window.len() as f64
41        })
42    }
43
44    /// Rolling standard deviation
45    pub fn rolling_std(&self, column: &str) -> Series {
46        self.apply_rolling_function(column, |window| {
47            let mean: f64 = window.iter().sum::<f64>() / window.len() as f64;
48            let variance =
49                window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / window.len() as f64;
50            variance.sqrt()
51        })
52    }
53
54    /// Rolling minimum
55    pub fn rolling_min(&self, column: &str) -> Series {
56        self.apply_rolling_function(column, |window| {
57            window.iter().fold(f64::INFINITY, |acc, &x| acc.min(x))
58        })
59    }
60
61    /// Rolling maximum
62    pub fn rolling_max(&self, column: &str) -> Series {
63        self.apply_rolling_function(column, |window| {
64            window.iter().fold(f64::INFINITY, |acc, &x| acc.max(x))
65        })
66    }
67
68    /// Exponentially weighted moving average
69    pub fn ewm(&self, column: &str, alpha: f64) -> Series {
70        let col_idx = self
71            .df
72            .columns
73            .iter()
74            .position(|c| c == column)
75            .expect("Column not found");
76
77        let mut result = Vec::new();
78
79        if let Series::Float64(values) = &self.df.data[col_idx] {
80            let mut ewm = values[0];
81            result.push(ewm);
82
83            for &value in &values[1..] {
84                ewm = alpha * value + (1.0 - alpha) * ewm;
85                result.push(ewm);
86            }
87        }
88
89        Series::Float64(result)
90    }
91
92    /// Lag operation (shift values)
93    pub fn lag(&self, column: &str, periods: usize) -> Series {
94        let col_idx = self
95            .df
96            .columns
97            .iter()
98            .position(|c| c == column)
99            .expect("Column not found");
100
101        match &self.df.data[col_idx] {
102            Series::Float64(values) => {
103                let mut result = vec![f64::NAN; periods];
104                result.extend_from_slice(&values[..values.len().saturating_sub(periods)]);
105                Series::Float64(result)
106            }
107            Series::Int64(values) => {
108                let mut result = vec![0; periods]; // Use 0 as null for integers
109                result.extend_from_slice(&values[..values.len().saturating_sub(periods)]);
110                Series::Int64(result)
111            }
112            Series::Utf8(values) => {
113                let mut result = vec!["".to_string(); periods];
114                result.extend(
115                    values[..values.len().saturating_sub(periods)]
116                        .iter()
117                        .cloned(),
118                );
119                Series::Utf8(result)
120            }
121            Series::Bool(values) => {
122                let mut result = vec![false; periods];
123                result.extend_from_slice(&values[..values.len().saturating_sub(periods)]);
124                Series::Bool(result)
125            }
126        }
127    }
128
129    /// Lead operation (negative shift)
130    pub fn lead(&self, column: &str, periods: usize) -> Series {
131        let col_idx = self
132            .df
133            .columns
134            .iter()
135            .position(|c| c == column)
136            .expect("Column not found");
137
138        match &self.df.data[col_idx] {
139            Series::Float64(values) => {
140                let mut result = values[periods..].to_vec();
141                result.extend(vec![f64::NAN; periods]);
142                Series::Float64(result)
143            }
144            Series::Int64(values) => {
145                let mut result = values[periods..].to_vec();
146                result.extend(vec![0; periods]);
147                Series::Int64(result)
148            }
149            Series::Utf8(values) => {
150                let mut result = values[periods..].to_vec();
151                result.extend(vec!["".to_string(); periods]);
152                Series::Utf8(result)
153            }
154            Series::Bool(values) => {
155                let mut result = values[periods..].to_vec();
156                result.extend(vec![false; periods]);
157                Series::Bool(result)
158            }
159        }
160    }
161
162    /// Percent change
163    pub fn pct_change(&self, column: &str) -> Series {
164        let col_idx = self
165            .df
166            .columns
167            .iter()
168            .position(|c| c == column)
169            .expect("Column not found");
170
171        match &self.df.data[col_idx] {
172            Series::Float64(values) => {
173                let mut result = vec![f64::NAN]; // First value is NaN
174                for i in 1..values.len() {
175                    let pct = (values[i] - values[i - 1]) / values[i - 1];
176                    result.push(pct);
177                }
178                Series::Float64(result)
179            }
180            Series::Int64(values) => {
181                let mut result = vec![f64::NAN]; // First value is NaN
182                for i in 1..values.len() {
183                    let pct = (values[i] - values[i - 1]) as f64 / values[i - 1] as f64;
184                    result.push(pct);
185                }
186                Series::Float64(result)
187            }
188            _ => panic!("Percent change only supported for numeric columns"),
189        }
190    }
191
192    fn apply_rolling_function<F>(&self, column: &str, func: F) -> Series
193    where
194        F: Fn(&[f64]) -> f64,
195    {
196        let col_idx = self
197            .df
198            .columns
199            .iter()
200            .position(|c| c == column)
201            .expect("Column not found");
202
203        if let Series::Float64(values) = &self.df.data[col_idx] {
204            let mut result = Vec::new();
205            let mut window: VecDeque<f64> = VecDeque::with_capacity(self.window_size);
206
207            for &value in values {
208                window.push_back(value);
209
210                if window.len() > self.window_size {
211                    window.pop_front();
212                }
213
214                if window.len() == self.window_size {
215                    let window_slice: Vec<f64> = window.iter().cloned().collect();
216                    result.push(func(&window_slice));
217                } else {
218                    result.push(f64::NAN);
219                }
220            }
221
222            Series::Float64(result)
223        } else {
224            panic!("Rolling function only supported for Float64 columns");
225        }
226    }
227}