rustframes/dataframe/
advanced_ops.rs1use wide::f64x4;
2
3use crate::{dataframe::window::Window, DataFrame, Series};
4
5impl DataFrame {
6 pub fn window<'a>(&'a self, window_size: usize) -> Window<'a> {
8 Window::new(self, window_size)
9 }
10
11 pub fn pivot_table(
13 &self,
14 index: &str,
15 columns: &str,
16 values: &str,
17 aggfunc: PivotAggFunc,
18 ) -> DataFrame {
19 use std::collections::{BTreeSet, HashMap};
20
21 let index_idx = self
22 .columns
23 .iter()
24 .position(|c| c == index)
25 .expect("Index column not found");
26 let columns_idx = self
27 .columns
28 .iter()
29 .position(|c| c == columns)
30 .expect("Columns column not found");
31 let values_idx = self
32 .columns
33 .iter()
34 .position(|c| c == values)
35 .expect("Values column not found");
36
37 let mut unique_columns: BTreeSet<String> = BTreeSet::new();
38 let mut unique_indices: BTreeSet<String> = BTreeSet::new();
39
40 match (&self.data[index_idx], &self.data[columns_idx]) {
41 (Series::Utf8(idx_vals), Series::Utf8(col_vals)) => {
42 unique_indices.extend(idx_vals.iter().cloned());
43 unique_columns.extend(col_vals.iter().cloned());
44 }
45 _ => panic!("Pivot currently only supports string indices and columns"),
46 }
47
48 let mut result_columns = vec![index.to_string()];
49 result_columns.extend(unique_columns.iter().cloned());
50
51 let mut result_data: Vec<Vec<String>> = vec![Vec::new(); result_columns.len()];
52
53 for idx_val in &unique_indices {
54 result_data[0].push(idx_val.clone());
55 }
56
57 let mut agg_map: HashMap<(String, String), Vec<f64>> = HashMap::new();
58
59 if let (Series::Utf8(idx_vals), Series::Utf8(col_vals), Series::Float64(val_vals)) = (
60 &self.data[index_idx],
61 &self.data[columns_idx],
62 &self.data[values_idx],
63 ) {
64 for i in 0..idx_vals.len() {
65 let key = (idx_vals[i].clone(), col_vals[i].clone());
66 agg_map.entry(key).or_default().push(val_vals[i]);
67 }
68 }
69
70 for (row_idx, idx_val) in unique_indices.iter().enumerate() {
71 for (col_idx, col_val) in unique_columns.iter().enumerate() {
72 let key = (idx_val.clone(), col_val.clone());
73 let value = if let Some(values) = agg_map.get(&key) {
74 match aggfunc {
75 PivotAggFunc::Sum => simd_sum(values).to_string(),
76 PivotAggFunc::Mean => (simd_sum(values) / values.len() as f64).to_string(),
77 PivotAggFunc::Count => values.len().to_string(),
78 PivotAggFunc::Min => values
79 .iter()
80 .fold(f64::INFINITY, |acc, &x| acc.min(x))
81 .to_string(),
82 PivotAggFunc::Max => values
83 .iter()
84 .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
85 .to_string(),
86 }
87 } else {
88 "0".to_string()
89 };
90
91 while result_data[col_idx + 1].len() <= row_idx {
92 result_data[col_idx + 1].push("0".to_string());
93 }
94 result_data[col_idx + 1][row_idx] = value;
95 }
96 }
97
98 let series_data: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
99
100 DataFrame {
101 columns: result_columns,
102 data: series_data,
103 }
104 }
105
106 pub fn melt(
108 &self,
109 id_vars: &[&str],
110 value_vars: &[&str],
111 var_name: Option<&str>,
112 value_name: Option<&str>,
113 ) -> DataFrame {
114 let var_name = var_name.unwrap_or("variable");
115 let value_name = value_name.unwrap_or("value");
116
117 let id_indices: Vec<usize> = id_vars
118 .iter()
119 .map(|col| {
120 self.columns
121 .iter()
122 .position(|c| c == col)
123 .expect("ID column not found")
124 })
125 .collect();
126
127 let value_indices: Vec<usize> = value_vars
128 .iter()
129 .map(|col| {
130 self.columns
131 .iter()
132 .position(|c| c == col)
133 .expect("Value column not found")
134 })
135 .collect();
136
137 let mut result_columns = Vec::new();
138 let mut result_data: Vec<Vec<String>> = Vec::new();
139
140 for &col in id_vars {
141 result_columns.push(col.to_string());
142 result_data.push(Vec::new());
143 }
144
145 result_columns.push(var_name.to_string());
146 result_columns.push(value_name.to_string());
147 result_data.push(Vec::new());
148 result_data.push(Vec::new());
149
150 for row_idx in 0..self.len() {
151 for (value_col_idx, &value_idx) in value_indices.iter().enumerate() {
152 for (id_col_idx, &id_idx) in id_indices.iter().enumerate() {
153 let value = match &self.data[id_idx] {
154 Series::Int64(v) => v[row_idx].to_string(),
155 Series::Float64(v) => v[row_idx].to_string(),
156 Series::Bool(v) => v[row_idx].to_string(),
157 Series::Utf8(v) => v[row_idx].clone(),
158 };
159 result_data[id_col_idx].push(value);
160 }
161
162 result_data[id_vars.len()].push(value_vars[value_col_idx].to_string());
163
164 let value = match &self.data[value_idx] {
165 Series::Int64(v) => v[row_idx].to_string(),
166 Series::Float64(v) => v[row_idx].to_string(),
167 Series::Bool(v) => v[row_idx].to_string(),
168 Series::Utf8(v) => v[row_idx].clone(),
169 };
170 result_data[id_vars.len() + 1].push(value);
171 }
172 }
173
174 let series_data: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
175
176 DataFrame {
177 columns: result_columns,
178 data: series_data,
179 }
180 }
181
182 pub fn cv_split(&self, n_folds: usize, shuffle: bool) -> Vec<(DataFrame, DataFrame)> {
184 use rand::seq::SliceRandom;
185 let mut indices: Vec<usize> = (0..self.len()).collect();
186
187 if shuffle {
188 let mut rng = rand::rng();
189 indices.shuffle(&mut rng);
190 }
191
192 let fold_size = self.len() / n_folds;
193 let mut folds = Vec::new();
194
195 for fold in 0..n_folds {
196 let start = fold * fold_size;
197 let end = if fold == n_folds - 1 {
198 self.len()
199 } else {
200 start + fold_size
201 };
202
203 let test_indices: std::collections::HashSet<usize> =
204 indices[start..end].iter().cloned().collect();
205
206 let mask: Vec<bool> = (0..self.len()).map(|i| test_indices.contains(&i)).collect();
207
208 let test_df = self.filter(&mask);
209 let train_df = self.filter(&mask.iter().map(|b| !b).collect::<Vec<_>>());
210
211 folds.push((train_df, test_df));
212 }
213
214 folds
215 }
216
217 pub fn sample(&self, n: usize, replace: bool) -> DataFrame {
219 use rand::seq::SliceRandom;
220 use rand::Rng;
221
222 let mut rng = rand::rng();
223 let indices: Vec<usize> = if replace {
224 (0..n).map(|_| rng.random_range(0..self.len())).collect()
225 } else {
226 let mut all_indices: Vec<usize> = (0..self.len()).collect();
227 all_indices.shuffle(&mut rng);
228 all_indices.into_iter().take(n.min(self.len())).collect()
229 };
230
231 let mask: Vec<bool> = (0..self.len()).map(|i| indices.contains(&i)).collect();
232
233 self.filter(&mask)
234 }
235}
236
/// Aggregation applied to the values that fall into a single pivot-table cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PivotAggFunc {
    /// Sum of the cell's values.
    Sum,
    /// Sum of the cell's values divided by how many there are.
    Mean,
    /// Number of values in the cell.
    Count,
    /// Smallest value in the cell.
    Min,
    /// Largest value in the cell.
    Max,
}
245
246fn simd_sum(values: &[f64]) -> f64 {
248 let mut acc = f64x4::from([0.0; 4]);
249 let chunks = values.chunks_exact(4);
250 let remainder = chunks.remainder();
251
252 for chunk in chunks {
253 acc += f64x4::from([chunk[0], chunk[1], chunk[2], chunk[3]]);
254 }
255
256 let arr: [f64; 4] = acc.into();
257 let mut total: f64 = arr.iter().sum();
258 for &r in remainder {
259 total += r;
260 }
261
262 total
263}