// rustframes/dataframe/advanced_ops.rs

1use wide::f64x4;
2
3use crate::{dataframe::window::Window, DataFrame, Series};
4
5impl DataFrame {
6    /// Create window for rolling operations
7    pub fn window<'a>(&'a self, window_size: usize) -> Window<'a> {
8        Window::new(self, window_size)
9    }
10
11    /// Pivot table functionality
12    pub fn pivot_table(
13        &self,
14        index: &str,
15        columns: &str,
16        values: &str,
17        aggfunc: PivotAggFunc,
18    ) -> DataFrame {
19        use std::collections::{BTreeSet, HashMap};
20
21        let index_idx = self
22            .columns
23            .iter()
24            .position(|c| c == index)
25            .expect("Index column not found");
26        let columns_idx = self
27            .columns
28            .iter()
29            .position(|c| c == columns)
30            .expect("Columns column not found");
31        let values_idx = self
32            .columns
33            .iter()
34            .position(|c| c == values)
35            .expect("Values column not found");
36
37        let mut unique_columns: BTreeSet<String> = BTreeSet::new();
38        let mut unique_indices: BTreeSet<String> = BTreeSet::new();
39
40        match (&self.data[index_idx], &self.data[columns_idx]) {
41            (Series::Utf8(idx_vals), Series::Utf8(col_vals)) => {
42                unique_indices.extend(idx_vals.iter().cloned());
43                unique_columns.extend(col_vals.iter().cloned());
44            }
45            _ => panic!("Pivot currently only supports string indices and columns"),
46        }
47
48        let mut result_columns = vec![index.to_string()];
49        result_columns.extend(unique_columns.iter().cloned());
50
51        let mut result_data: Vec<Vec<String>> = vec![Vec::new(); result_columns.len()];
52
53        for idx_val in &unique_indices {
54            result_data[0].push(idx_val.clone());
55        }
56
57        let mut agg_map: HashMap<(String, String), Vec<f64>> = HashMap::new();
58
59        if let (Series::Utf8(idx_vals), Series::Utf8(col_vals), Series::Float64(val_vals)) = (
60            &self.data[index_idx],
61            &self.data[columns_idx],
62            &self.data[values_idx],
63        ) {
64            for i in 0..idx_vals.len() {
65                let key = (idx_vals[i].clone(), col_vals[i].clone());
66                agg_map.entry(key).or_default().push(val_vals[i]);
67            }
68        }
69
70        for (row_idx, idx_val) in unique_indices.iter().enumerate() {
71            for (col_idx, col_val) in unique_columns.iter().enumerate() {
72                let key = (idx_val.clone(), col_val.clone());
73                let value = if let Some(values) = agg_map.get(&key) {
74                    match aggfunc {
75                        PivotAggFunc::Sum => simd_sum(values).to_string(),
76                        PivotAggFunc::Mean => (simd_sum(values) / values.len() as f64).to_string(),
77                        PivotAggFunc::Count => values.len().to_string(),
78                        PivotAggFunc::Min => values
79                            .iter()
80                            .fold(f64::INFINITY, |acc, &x| acc.min(x))
81                            .to_string(),
82                        PivotAggFunc::Max => values
83                            .iter()
84                            .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
85                            .to_string(),
86                    }
87                } else {
88                    "0".to_string()
89                };
90
91                while result_data[col_idx + 1].len() <= row_idx {
92                    result_data[col_idx + 1].push("0".to_string());
93                }
94                result_data[col_idx + 1][row_idx] = value;
95            }
96        }
97
98        let series_data: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
99
100        DataFrame {
101            columns: result_columns,
102            data: series_data,
103        }
104    }
105
106    /// Melt operation (unpivot)
107    pub fn melt(
108        &self,
109        id_vars: &[&str],
110        value_vars: &[&str],
111        var_name: Option<&str>,
112        value_name: Option<&str>,
113    ) -> DataFrame {
114        let var_name = var_name.unwrap_or("variable");
115        let value_name = value_name.unwrap_or("value");
116
117        let id_indices: Vec<usize> = id_vars
118            .iter()
119            .map(|col| {
120                self.columns
121                    .iter()
122                    .position(|c| c == col)
123                    .expect("ID column not found")
124            })
125            .collect();
126
127        let value_indices: Vec<usize> = value_vars
128            .iter()
129            .map(|col| {
130                self.columns
131                    .iter()
132                    .position(|c| c == col)
133                    .expect("Value column not found")
134            })
135            .collect();
136
137        let mut result_columns = Vec::new();
138        let mut result_data: Vec<Vec<String>> = Vec::new();
139
140        for &col in id_vars {
141            result_columns.push(col.to_string());
142            result_data.push(Vec::new());
143        }
144
145        result_columns.push(var_name.to_string());
146        result_columns.push(value_name.to_string());
147        result_data.push(Vec::new());
148        result_data.push(Vec::new());
149
150        for row_idx in 0..self.len() {
151            for (value_col_idx, &value_idx) in value_indices.iter().enumerate() {
152                for (id_col_idx, &id_idx) in id_indices.iter().enumerate() {
153                    let value = match &self.data[id_idx] {
154                        Series::Int64(v) => v[row_idx].to_string(),
155                        Series::Float64(v) => v[row_idx].to_string(),
156                        Series::Bool(v) => v[row_idx].to_string(),
157                        Series::Utf8(v) => v[row_idx].clone(),
158                    };
159                    result_data[id_col_idx].push(value);
160                }
161
162                result_data[id_vars.len()].push(value_vars[value_col_idx].to_string());
163
164                let value = match &self.data[value_idx] {
165                    Series::Int64(v) => v[row_idx].to_string(),
166                    Series::Float64(v) => v[row_idx].to_string(),
167                    Series::Bool(v) => v[row_idx].to_string(),
168                    Series::Utf8(v) => v[row_idx].clone(),
169                };
170                result_data[id_vars.len() + 1].push(value);
171            }
172        }
173
174        let series_data: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
175
176        DataFrame {
177            columns: result_columns,
178            data: series_data,
179        }
180    }
181
182    /// Cross-validation split
183    pub fn cv_split(&self, n_folds: usize, shuffle: bool) -> Vec<(DataFrame, DataFrame)> {
184        use rand::seq::SliceRandom;
185        let mut indices: Vec<usize> = (0..self.len()).collect();
186
187        if shuffle {
188            let mut rng = rand::rng();
189            indices.shuffle(&mut rng);
190        }
191
192        let fold_size = self.len() / n_folds;
193        let mut folds = Vec::new();
194
195        for fold in 0..n_folds {
196            let start = fold * fold_size;
197            let end = if fold == n_folds - 1 {
198                self.len()
199            } else {
200                start + fold_size
201            };
202
203            let test_indices: std::collections::HashSet<usize> =
204                indices[start..end].iter().cloned().collect();
205
206            let mask: Vec<bool> = (0..self.len()).map(|i| test_indices.contains(&i)).collect();
207
208            let test_df = self.filter(&mask);
209            let train_df = self.filter(&mask.iter().map(|b| !b).collect::<Vec<_>>());
210
211            folds.push((train_df, test_df));
212        }
213
214        folds
215    }
216
217    /// Sample rows randomly
218    pub fn sample(&self, n: usize, replace: bool) -> DataFrame {
219        use rand::seq::SliceRandom;
220        use rand::Rng;
221
222        let mut rng = rand::rng();
223        let indices: Vec<usize> = if replace {
224            (0..n).map(|_| rng.random_range(0..self.len())).collect()
225        } else {
226            let mut all_indices: Vec<usize> = (0..self.len()).collect();
227            all_indices.shuffle(&mut rng);
228            all_indices.into_iter().take(n.min(self.len())).collect()
229        };
230
231        let mask: Vec<bool> = (0..self.len()).map(|i| indices.contains(&i)).collect();
232
233        self.filter(&mask)
234    }
235}
236
/// Aggregation applied to each (row label, column label) group of values in
/// `DataFrame::pivot_table`.
///
/// The enum is fieldless, so it is `Copy` and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PivotAggFunc {
    /// Sum of the group's values.
    Sum,
    /// Arithmetic mean: the sum divided by the number of values.
    Mean,
    /// Number of values in the group.
    Count,
    /// Smallest value in the group.
    Min,
    /// Largest value in the group.
    Max,
}
245
246/// SIMD accelerated sum using `wide::f64x4`
247fn simd_sum(values: &[f64]) -> f64 {
248    let mut acc = f64x4::from([0.0; 4]);
249    let chunks = values.chunks_exact(4);
250    let remainder = chunks.remainder();
251
252    for chunk in chunks {
253        acc += f64x4::from([chunk[0], chunk[1], chunk[2], chunk[3]]);
254    }
255
256    let arr: [f64; 4] = acc.into();
257    let mut total: f64 = arr.iter().sum();
258    for &r in remainder {
259        total += r;
260    }
261
262    total
263}