// rustframes/dataframe/groupby.rs

1use super::{DataFrame, Series};
2use std::collections::HashMap;
3
4pub struct GroupBy<'a> {
5    df: &'a DataFrame,
6    by_column: String,
7    groups: HashMap<String, Vec<usize>>,
8}
9
10impl<'a> GroupBy<'a> {
11    pub fn new(df: &'a DataFrame, by: &str) -> Self {
12        let by_column = by.to_string();
13        let col_idx = df
14            .columns
15            .iter()
16            .position(|c| c == by)
17            .expect("GroupBy column not found");
18
19        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
20
21        match &df.data[col_idx] {
22            Series::Utf8(values) => {
23                for (idx, value) in values.iter().enumerate() {
24                    groups.entry(value.clone()).or_default().push(idx);
25                }
26            }
27            Series::Int64(values) => {
28                for (idx, &value) in values.iter().enumerate() {
29                    groups.entry(value.to_string()).or_default().push(idx);
30                }
31            }
32            Series::Float64(values) => {
33                for (idx, &value) in values.iter().enumerate() {
34                    groups.entry(value.to_string()).or_default().push(idx);
35                }
36            }
37            Series::Bool(values) => {
38                for (idx, &value) in values.iter().enumerate() {
39                    groups.entry(value.to_string()).or_default().push(idx);
40                }
41            }
42        }
43
44        GroupBy {
45            df,
46            by_column,
47            groups,
48        }
49    }
50
51    /// Count occurrences in each group
52    pub fn count(&self) -> DataFrame {
53        let mut keys = Vec::new();
54        let mut counts = Vec::new();
55
56        for (key, indices) in &self.groups {
57            keys.push(key.clone());
58            counts.push(indices.len() as i64);
59        }
60
61        DataFrame::new(vec![
62            (self.by_column.clone(), Series::Utf8(keys)),
63            ("count".to_string(), Series::Int64(counts)),
64        ])
65    }
66
67    /// Sum numeric columns by group
68    pub fn sum(&self) -> DataFrame {
69        let mut result_columns = vec![(
70            self.by_column.clone(),
71            Series::Utf8(self.groups.keys().cloned().collect()),
72        )];
73
74        for (col_idx, col_name) in self.df.columns.iter().enumerate() {
75            if col_name == &self.by_column {
76                continue; // Skip the groupby column
77            }
78
79            let mut group_sums = Vec::new();
80
81            match &self.df.data[col_idx] {
82                Series::Int64(values) => {
83                    for key in self.groups.keys() {
84                        let indices = &self.groups[key];
85                        let sum: i64 = indices.iter().map(|&i| values[i]).sum();
86                        group_sums.push(sum);
87                    }
88                    result_columns.push((col_name.clone(), Series::Int64(group_sums)));
89                }
90                Series::Float64(values) => {
91                    let mut group_sums = Vec::new();
92                    for key in self.groups.keys() {
93                        let indices = &self.groups[key];
94                        let sum: f64 = indices.iter().map(|&i| values[i]).sum();
95                        group_sums.push(sum);
96                    }
97                    result_columns.push((col_name.clone(), Series::Float64(group_sums)));
98                }
99                _ => {
100                    // Skip non-numeric columns for sum operation
101                    continue;
102                }
103            }
104        }
105
106        DataFrame::new(result_columns)
107    }
108
109    /// Mean of numeric columns by group
110    pub fn mean(&self) -> DataFrame {
111        let mut result_columns = vec![(
112            self.by_column.clone(),
113            Series::Utf8(self.groups.keys().cloned().collect()),
114        )];
115
116        for (col_idx, col_name) in self.df.columns.iter().enumerate() {
117            if col_name == &self.by_column {
118                continue;
119            }
120
121            let mut group_means = Vec::new();
122
123            match &self.df.data[col_idx] {
124                Series::Int64(values) => {
125                    for key in self.groups.keys() {
126                        let indices = &self.groups[key];
127                        let sum: i64 = indices.iter().map(|&i| values[i]).sum();
128                        let mean = sum as f64 / indices.len() as f64;
129                        group_means.push(mean);
130                    }
131                    result_columns.push((col_name.clone(), Series::Float64(group_means)));
132                }
133                Series::Float64(values) => {
134                    for key in self.groups.keys() {
135                        let indices = &self.groups[key];
136                        let sum: f64 = indices.iter().map(|&i| values[i]).sum();
137                        let mean = sum / indices.len() as f64;
138                        group_means.push(mean);
139                    }
140                    result_columns.push((col_name.clone(), Series::Float64(group_means)));
141                }
142                _ => continue,
143            }
144        }
145
146        DataFrame::new(result_columns)
147    }
148
149    /// Standard deviation of numeric columns by group
150    pub fn std(&self) -> DataFrame {
151        let mut result_columns = vec![(
152            self.by_column.clone(),
153            Series::Utf8(self.groups.keys().cloned().collect()),
154        )];
155
156        for (col_idx, col_name) in self.df.columns.iter().enumerate() {
157            if col_name == &self.by_column {
158                continue;
159            }
160
161            let mut group_stds = Vec::new();
162
163            match &self.df.data[col_idx] {
164                Series::Int64(values) => {
165                    for key in self.groups.keys() {
166                        let indices = &self.groups[key];
167                        let values_in_group: Vec<f64> =
168                            indices.iter().map(|&i| values[i] as f64).collect();
169                        let mean: f64 =
170                            values_in_group.iter().sum::<f64>() / values_in_group.len() as f64;
171                        let variance = values_in_group
172                            .iter()
173                            .map(|&x| (x - mean).powi(2))
174                            .sum::<f64>()
175                            / values_in_group.len() as f64;
176                        group_stds.push(variance.sqrt());
177                    }
178                    result_columns.push((col_name.clone(), Series::Float64(group_stds)));
179                }
180                Series::Float64(values) => {
181                    for key in self.groups.keys() {
182                        let indices = &self.groups[key];
183                        let values_in_group: Vec<f64> =
184                            indices.iter().map(|&i| values[i]).collect();
185                        let mean: f64 =
186                            values_in_group.iter().sum::<f64>() / values_in_group.len() as f64;
187                        let variance = values_in_group
188                            .iter()
189                            .map(|&x| (x - mean).powi(2))
190                            .sum::<f64>()
191                            / values_in_group.len() as f64;
192                        group_stds.push(variance.sqrt());
193                    }
194                    result_columns.push((col_name.clone(), Series::Float64(group_stds)));
195                }
196                _ => continue,
197            }
198        }
199
200        DataFrame::new(result_columns)
201    }
202
203    /// Min of numeric columns by group
204    pub fn min(&self) -> DataFrame {
205        let mut result_columns = vec![(
206            self.by_column.clone(),
207            Series::Utf8(self.groups.keys().cloned().collect()),
208        )];
209
210        for (col_idx, col_name) in self.df.columns.iter().enumerate() {
211            if col_name == &self.by_column {
212                continue;
213            }
214
215            match &self.df.data[col_idx] {
216                Series::Int64(values) => {
217                    let mut group_mins = Vec::new();
218                    for key in self.groups.keys() {
219                        let indices = &self.groups[key];
220                        let min_val = indices.iter().map(|&i| values[i]).min().unwrap_or(0);
221                        group_mins.push(min_val);
222                    }
223                    result_columns.push((col_name.clone(), Series::Int64(group_mins)));
224                }
225                Series::Float64(values) => {
226                    let mut group_mins = Vec::new();
227                    for key in self.groups.keys() {
228                        let indices = &self.groups[key];
229                        let min_val = indices
230                            .iter()
231                            .map(|&i| values[i])
232                            .fold(f64::INFINITY, |acc, x| acc.min(x));
233                        group_mins.push(min_val);
234                    }
235                    result_columns.push((col_name.clone(), Series::Float64(group_mins)));
236                }
237                _ => continue,
238            }
239        }
240
241        DataFrame::new(result_columns)
242    }
243
244    /// Max of numeric columns by group
245    pub fn max(&self) -> DataFrame {
246        let mut result_columns = vec![(
247            self.by_column.clone(),
248            Series::Utf8(self.groups.keys().cloned().collect()),
249        )];
250
251        for (col_idx, col_name) in self.df.columns.iter().enumerate() {
252            if col_name == &self.by_column {
253                continue;
254            }
255
256            match &self.df.data[col_idx] {
257                Series::Int64(values) => {
258                    let mut group_maxs = Vec::new();
259                    for key in self.groups.keys() {
260                        let indices = &self.groups[key];
261                        let max_val = indices.iter().map(|&i| values[i]).max().unwrap_or(0);
262                        group_maxs.push(max_val);
263                    }
264                    result_columns.push((col_name.clone(), Series::Int64(group_maxs)));
265                }
266                Series::Float64(values) => {
267                    let mut group_maxs = Vec::new();
268                    for key in self.groups.keys() {
269                        let indices = &self.groups[key];
270                        let max_val = indices
271                            .iter()
272                            .map(|&i| values[i])
273                            .fold(f64::NEG_INFINITY, |acc, x| acc.max(x));
274                        group_maxs.push(max_val);
275                    }
276                    result_columns.push((col_name.clone(), Series::Float64(group_maxs)));
277                }
278                _ => continue,
279            }
280        }
281
282        DataFrame::new(result_columns)
283    }
284
285    /// Apply custom aggregation function
286    pub fn agg<F>(&self, func: F) -> DataFrame
287    where
288        F: Fn(&[usize], &Series) -> f64,
289    {
290        let mut result_columns = vec![(
291            self.by_column.clone(),
292            Series::Utf8(self.groups.keys().cloned().collect()),
293        )];
294
295        for (col_idx, col_name) in self.df.columns.iter().enumerate() {
296            if col_name == &self.by_column {
297                continue;
298            }
299
300            let mut group_results = Vec::new();
301            for key in self.groups.keys() {
302                let indices = &self.groups[key];
303                let result = func(indices, &self.df.data[col_idx]);
304                group_results.push(result);
305            }
306
307            result_columns.push((col_name.clone(), Series::Float64(group_results)));
308        }
309
310        DataFrame::new(result_columns)
311    }
312
313    /// Get the first row of each group
314    pub fn first(&self) -> DataFrame {
315        let mut result_data = vec![Vec::new(); self.df.columns.len()];
316
317        for key in self.groups.keys() {
318            let first_idx = self.groups[key][0]; // Get first index in group
319
320            for (col_idx, series) in self.df.data.iter().enumerate() {
321                let value = match series {
322                    Series::Int64(v) => v[first_idx].to_string(),
323                    Series::Float64(v) => v[first_idx].to_string(),
324                    Series::Bool(v) => v[first_idx].to_string(),
325                    Series::Utf8(v) => v[first_idx].clone(),
326                };
327                result_data[col_idx].push(value);
328            }
329        }
330
331        let result_series: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
332
333        DataFrame {
334            columns: self.df.columns.clone(),
335            data: result_series,
336        }
337    }
338
339    /// Get the last row of each group  
340    pub fn last(&self) -> DataFrame {
341        let mut result_data = vec![Vec::new(); self.df.columns.len()];
342
343        for key in self.groups.keys() {
344            let last_idx = *self.groups[key].last().unwrap(); // Get last index in group
345
346            for (col_idx, series) in self.df.data.iter().enumerate() {
347                let value = match series {
348                    Series::Int64(v) => v[last_idx].to_string(),
349                    Series::Float64(v) => v[last_idx].to_string(),
350                    Series::Bool(v) => v[last_idx].to_string(),
351                    Series::Utf8(v) => v[last_idx].clone(),
352                };
353                result_data[col_idx].push(value);
354            }
355        }
356
357        let result_series: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
358
359        DataFrame {
360            columns: self.df.columns.clone(),
361            data: result_series,
362        }
363    }
364
365    /// Get size of each group
366    pub fn size(&self) -> HashMap<String, usize> {
367        self.groups
368            .iter()
369            .map(|(k, v)| (k.clone(), v.len()))
370            .collect()
371    }
372
373    /// Get groups as separate DataFrames
374    pub fn get_group(&self, key: &str) -> Option<DataFrame> {
375        if let Some(indices) = self.groups.get(key) {
376            let mask: Vec<bool> = (0..self.df.len()).map(|i| indices.contains(&i)).collect();
377            Some(self.df.filter(&mask))
378        } else {
379            None
380        }
381    }
382}
383
384// Add groupby method to DataFrame
385impl DataFrame {
386    /// Group DataFrame by column
387    pub fn groupby<'a>(&'a self, by: &str) -> GroupBy<'a> {
388        GroupBy::new(self, by)
389    }
390
391    /// Convenience method for groupby count (maintains backward compatibility)
392    pub fn groupby_count(&self, by: &str) -> DataFrame {
393        self.groupby(by).count()
394    }
395}