rustframes/dataframe/
groupby.rs1use super::{DataFrame, Series};
2use std::collections::HashMap;
3
4pub struct GroupBy<'a> {
5 df: &'a DataFrame,
6 by_column: String,
7 groups: HashMap<String, Vec<usize>>,
8}
9
10impl<'a> GroupBy<'a> {
11 pub fn new(df: &'a DataFrame, by: &str) -> Self {
12 let by_column = by.to_string();
13 let col_idx = df
14 .columns
15 .iter()
16 .position(|c| c == by)
17 .expect("GroupBy column not found");
18
19 let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
20
21 match &df.data[col_idx] {
22 Series::Utf8(values) => {
23 for (idx, value) in values.iter().enumerate() {
24 groups.entry(value.clone()).or_default().push(idx);
25 }
26 }
27 Series::Int64(values) => {
28 for (idx, &value) in values.iter().enumerate() {
29 groups.entry(value.to_string()).or_default().push(idx);
30 }
31 }
32 Series::Float64(values) => {
33 for (idx, &value) in values.iter().enumerate() {
34 groups.entry(value.to_string()).or_default().push(idx);
35 }
36 }
37 Series::Bool(values) => {
38 for (idx, &value) in values.iter().enumerate() {
39 groups.entry(value.to_string()).or_default().push(idx);
40 }
41 }
42 }
43
44 GroupBy {
45 df,
46 by_column,
47 groups,
48 }
49 }
50
51 pub fn count(&self) -> DataFrame {
53 let mut keys = Vec::new();
54 let mut counts = Vec::new();
55
56 for (key, indices) in &self.groups {
57 keys.push(key.clone());
58 counts.push(indices.len() as i64);
59 }
60
61 DataFrame::new(vec![
62 (self.by_column.clone(), Series::Utf8(keys)),
63 ("count".to_string(), Series::Int64(counts)),
64 ])
65 }
66
67 pub fn sum(&self) -> DataFrame {
69 let mut result_columns = vec![(
70 self.by_column.clone(),
71 Series::Utf8(self.groups.keys().cloned().collect()),
72 )];
73
74 for (col_idx, col_name) in self.df.columns.iter().enumerate() {
75 if col_name == &self.by_column {
76 continue; }
78
79 let mut group_sums = Vec::new();
80
81 match &self.df.data[col_idx] {
82 Series::Int64(values) => {
83 for key in self.groups.keys() {
84 let indices = &self.groups[key];
85 let sum: i64 = indices.iter().map(|&i| values[i]).sum();
86 group_sums.push(sum);
87 }
88 result_columns.push((col_name.clone(), Series::Int64(group_sums)));
89 }
90 Series::Float64(values) => {
91 let mut group_sums = Vec::new();
92 for key in self.groups.keys() {
93 let indices = &self.groups[key];
94 let sum: f64 = indices.iter().map(|&i| values[i]).sum();
95 group_sums.push(sum);
96 }
97 result_columns.push((col_name.clone(), Series::Float64(group_sums)));
98 }
99 _ => {
100 continue;
102 }
103 }
104 }
105
106 DataFrame::new(result_columns)
107 }
108
109 pub fn mean(&self) -> DataFrame {
111 let mut result_columns = vec![(
112 self.by_column.clone(),
113 Series::Utf8(self.groups.keys().cloned().collect()),
114 )];
115
116 for (col_idx, col_name) in self.df.columns.iter().enumerate() {
117 if col_name == &self.by_column {
118 continue;
119 }
120
121 let mut group_means = Vec::new();
122
123 match &self.df.data[col_idx] {
124 Series::Int64(values) => {
125 for key in self.groups.keys() {
126 let indices = &self.groups[key];
127 let sum: i64 = indices.iter().map(|&i| values[i]).sum();
128 let mean = sum as f64 / indices.len() as f64;
129 group_means.push(mean);
130 }
131 result_columns.push((col_name.clone(), Series::Float64(group_means)));
132 }
133 Series::Float64(values) => {
134 for key in self.groups.keys() {
135 let indices = &self.groups[key];
136 let sum: f64 = indices.iter().map(|&i| values[i]).sum();
137 let mean = sum / indices.len() as f64;
138 group_means.push(mean);
139 }
140 result_columns.push((col_name.clone(), Series::Float64(group_means)));
141 }
142 _ => continue,
143 }
144 }
145
146 DataFrame::new(result_columns)
147 }
148
149 pub fn std(&self) -> DataFrame {
151 let mut result_columns = vec![(
152 self.by_column.clone(),
153 Series::Utf8(self.groups.keys().cloned().collect()),
154 )];
155
156 for (col_idx, col_name) in self.df.columns.iter().enumerate() {
157 if col_name == &self.by_column {
158 continue;
159 }
160
161 let mut group_stds = Vec::new();
162
163 match &self.df.data[col_idx] {
164 Series::Int64(values) => {
165 for key in self.groups.keys() {
166 let indices = &self.groups[key];
167 let values_in_group: Vec<f64> =
168 indices.iter().map(|&i| values[i] as f64).collect();
169 let mean: f64 =
170 values_in_group.iter().sum::<f64>() / values_in_group.len() as f64;
171 let variance = values_in_group
172 .iter()
173 .map(|&x| (x - mean).powi(2))
174 .sum::<f64>()
175 / values_in_group.len() as f64;
176 group_stds.push(variance.sqrt());
177 }
178 result_columns.push((col_name.clone(), Series::Float64(group_stds)));
179 }
180 Series::Float64(values) => {
181 for key in self.groups.keys() {
182 let indices = &self.groups[key];
183 let values_in_group: Vec<f64> =
184 indices.iter().map(|&i| values[i]).collect();
185 let mean: f64 =
186 values_in_group.iter().sum::<f64>() / values_in_group.len() as f64;
187 let variance = values_in_group
188 .iter()
189 .map(|&x| (x - mean).powi(2))
190 .sum::<f64>()
191 / values_in_group.len() as f64;
192 group_stds.push(variance.sqrt());
193 }
194 result_columns.push((col_name.clone(), Series::Float64(group_stds)));
195 }
196 _ => continue,
197 }
198 }
199
200 DataFrame::new(result_columns)
201 }
202
203 pub fn min(&self) -> DataFrame {
205 let mut result_columns = vec![(
206 self.by_column.clone(),
207 Series::Utf8(self.groups.keys().cloned().collect()),
208 )];
209
210 for (col_idx, col_name) in self.df.columns.iter().enumerate() {
211 if col_name == &self.by_column {
212 continue;
213 }
214
215 match &self.df.data[col_idx] {
216 Series::Int64(values) => {
217 let mut group_mins = Vec::new();
218 for key in self.groups.keys() {
219 let indices = &self.groups[key];
220 let min_val = indices.iter().map(|&i| values[i]).min().unwrap_or(0);
221 group_mins.push(min_val);
222 }
223 result_columns.push((col_name.clone(), Series::Int64(group_mins)));
224 }
225 Series::Float64(values) => {
226 let mut group_mins = Vec::new();
227 for key in self.groups.keys() {
228 let indices = &self.groups[key];
229 let min_val = indices
230 .iter()
231 .map(|&i| values[i])
232 .fold(f64::INFINITY, |acc, x| acc.min(x));
233 group_mins.push(min_val);
234 }
235 result_columns.push((col_name.clone(), Series::Float64(group_mins)));
236 }
237 _ => continue,
238 }
239 }
240
241 DataFrame::new(result_columns)
242 }
243
244 pub fn max(&self) -> DataFrame {
246 let mut result_columns = vec![(
247 self.by_column.clone(),
248 Series::Utf8(self.groups.keys().cloned().collect()),
249 )];
250
251 for (col_idx, col_name) in self.df.columns.iter().enumerate() {
252 if col_name == &self.by_column {
253 continue;
254 }
255
256 match &self.df.data[col_idx] {
257 Series::Int64(values) => {
258 let mut group_maxs = Vec::new();
259 for key in self.groups.keys() {
260 let indices = &self.groups[key];
261 let max_val = indices.iter().map(|&i| values[i]).max().unwrap_or(0);
262 group_maxs.push(max_val);
263 }
264 result_columns.push((col_name.clone(), Series::Int64(group_maxs)));
265 }
266 Series::Float64(values) => {
267 let mut group_maxs = Vec::new();
268 for key in self.groups.keys() {
269 let indices = &self.groups[key];
270 let max_val = indices
271 .iter()
272 .map(|&i| values[i])
273 .fold(f64::NEG_INFINITY, |acc, x| acc.max(x));
274 group_maxs.push(max_val);
275 }
276 result_columns.push((col_name.clone(), Series::Float64(group_maxs)));
277 }
278 _ => continue,
279 }
280 }
281
282 DataFrame::new(result_columns)
283 }
284
285 pub fn agg<F>(&self, func: F) -> DataFrame
287 where
288 F: Fn(&[usize], &Series) -> f64,
289 {
290 let mut result_columns = vec![(
291 self.by_column.clone(),
292 Series::Utf8(self.groups.keys().cloned().collect()),
293 )];
294
295 for (col_idx, col_name) in self.df.columns.iter().enumerate() {
296 if col_name == &self.by_column {
297 continue;
298 }
299
300 let mut group_results = Vec::new();
301 for key in self.groups.keys() {
302 let indices = &self.groups[key];
303 let result = func(indices, &self.df.data[col_idx]);
304 group_results.push(result);
305 }
306
307 result_columns.push((col_name.clone(), Series::Float64(group_results)));
308 }
309
310 DataFrame::new(result_columns)
311 }
312
313 pub fn first(&self) -> DataFrame {
315 let mut result_data = vec![Vec::new(); self.df.columns.len()];
316
317 for key in self.groups.keys() {
318 let first_idx = self.groups[key][0]; for (col_idx, series) in self.df.data.iter().enumerate() {
321 let value = match series {
322 Series::Int64(v) => v[first_idx].to_string(),
323 Series::Float64(v) => v[first_idx].to_string(),
324 Series::Bool(v) => v[first_idx].to_string(),
325 Series::Utf8(v) => v[first_idx].clone(),
326 };
327 result_data[col_idx].push(value);
328 }
329 }
330
331 let result_series: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
332
333 DataFrame {
334 columns: self.df.columns.clone(),
335 data: result_series,
336 }
337 }
338
339 pub fn last(&self) -> DataFrame {
341 let mut result_data = vec![Vec::new(); self.df.columns.len()];
342
343 for key in self.groups.keys() {
344 let last_idx = *self.groups[key].last().unwrap(); for (col_idx, series) in self.df.data.iter().enumerate() {
347 let value = match series {
348 Series::Int64(v) => v[last_idx].to_string(),
349 Series::Float64(v) => v[last_idx].to_string(),
350 Series::Bool(v) => v[last_idx].to_string(),
351 Series::Utf8(v) => v[last_idx].clone(),
352 };
353 result_data[col_idx].push(value);
354 }
355 }
356
357 let result_series: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
358
359 DataFrame {
360 columns: self.df.columns.clone(),
361 data: result_series,
362 }
363 }
364
365 pub fn size(&self) -> HashMap<String, usize> {
367 self.groups
368 .iter()
369 .map(|(k, v)| (k.clone(), v.len()))
370 .collect()
371 }
372
373 pub fn get_group(&self, key: &str) -> Option<DataFrame> {
375 if let Some(indices) = self.groups.get(key) {
376 let mask: Vec<bool> = (0..self.df.len()).map(|i| indices.contains(&i)).collect();
377 Some(self.df.filter(&mask))
378 } else {
379 None
380 }
381 }
382}
383
384impl DataFrame {
386 pub fn groupby<'a>(&'a self, by: &str) -> GroupBy<'a> {
388 GroupBy::new(self, by)
389 }
390
391 pub fn groupby_count(&self, by: &str) -> DataFrame {
393 self.groupby(by).count()
394 }
395}