Skip to main content

nexcore_dataframe/
group.rs

//! GroupBy: hash-based grouping with aggregation.
//!
//! Supports the polars pattern `df.lazy().group_by([cols]).agg([exprs])`
//! via `df.group_by(&["col1", "col2"])?.agg(&[Agg::Sum("val".into()), Agg::Count])`.

6// HashMap used for group-key accumulation during group_by. O(1) lookup is essential
7// for large DataFrames; BTreeMap O(log n) is not appropriate here.
8// Output row order is explicitly unspecified — callers sort if order matters.
9#[allow(
10    clippy::disallowed_types,
11    reason = "HashMap needed for O(1) group-key accumulation; output row order is explicitly unspecified"
12)]
13use std::collections::HashMap;
14
15use crate::column::Column;
16use crate::dataframe::DataFrame;
17use crate::error::DataFrameError;
18use crate::scalar::Scalar;
19
/// Aggregation operation to apply within each group.
///
/// Every variant except [`Agg::Count`] carries the name of the source column
/// it reads. In the aggregated output the result column is named
/// `<column>_<op>` (for example `val_sum` or `val_nunique`), or simply
/// `count` for [`Agg::Count`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Agg {
    /// Sum of a numeric column.
    Sum(String),
    /// Mean of a numeric column.
    Mean(String),
    /// Minimum value.
    Min(String),
    /// Maximum value.
    Max(String),
    /// Count of rows in each group (no column needed).
    Count,
    /// First value of a column within the group.
    First(String),
    /// Last value of a column within the group.
    Last(String),
    /// Count of unique values.
    NUnique(String),
}
41
42/// Intermediate grouping result. Call `.agg()` to produce a DataFrame.
43#[derive(Debug)]
44pub struct GroupBy<'a> {
45    df: &'a DataFrame,
46    group_cols: Vec<String>,
47    /// Maps group key → row indices belonging to that group.
48    #[allow(
49        clippy::disallowed_types,
50        reason = "HashMap for O(1) group-key lookup; see module-level allow"
51    )]
52    groups: HashMap<Vec<String>, Vec<usize>>,
53}
54
55impl GroupBy<'_> {
56    /// Number of unique groups.
57    #[must_use]
58    pub fn n_groups(&self) -> usize {
59        self.groups.len()
60    }
61
62    /// Apply aggregations and produce a result DataFrame.
63    ///
64    /// The result has one row per group. Group key columns come first,
65    /// followed by one column per aggregation.
66    pub fn agg(&self, aggs: &[Agg]) -> Result<DataFrame, DataFrameError> {
67        let n_groups = self.groups.len();
68
69        // Pre-allocate group key columns
70        let mut key_vecs: Vec<Vec<Option<String>>> = self
71            .group_cols
72            .iter()
73            .map(|_| Vec::with_capacity(n_groups))
74            .collect();
75
76        // Pre-allocate agg result columns (as Scalar vecs)
77        let mut agg_results: Vec<Vec<Scalar>> =
78            aggs.iter().map(|_| Vec::with_capacity(n_groups)).collect();
79
80        // HashMap iteration: row order is arbitrary but consistent within this call.
81        // The output DataFrame rows are not required to be ordered.
82        #[allow(
83            clippy::iter_over_hash_type,
84            reason = "HashMap iteration builds parallel group rows; output row order is explicitly unspecified — callers sort if order matters"
85        )]
86        // Process each group
87        for (key, indices) in &self.groups {
88            // Fill key columns
89            for (i, val) in key.iter().enumerate() {
90                // i < key.len() == group_cols.len() == key_vecs.len() by construction
91                #[allow(
92                    clippy::indexing_slicing,
93                    reason = "i iterates over key positions; key.len() == group_cols.len() == key_vecs.len() by GroupBy construction"
94                )]
95                key_vecs[i].push(Some(val.clone()));
96            }
97
98            // Compute each aggregation on the group's rows
99            for (agg_idx, agg) in aggs.iter().enumerate() {
100                let result = self.compute_agg(agg, indices)?;
101                // agg_idx < aggs.len() == agg_results.len() by construction
102                #[allow(
103                    clippy::indexing_slicing,
104                    reason = "agg_idx < aggs.len() == agg_results.len(); index is valid by parallel iteration"
105                )]
106                agg_results[agg_idx].push(result);
107            }
108        }
109
110        // Build columns — i < group_cols.len() == key_vecs.len()
111        #[allow(
112            clippy::indexing_slicing,
113            reason = "i iterates over 0..group_cols.len(); key_vecs has the same length by construction"
114        )]
115        let mut columns: Vec<Column> = key_vecs
116            .into_iter()
117            .enumerate()
118            .map(|(i, data)| Column::new_string(self.group_cols[i].clone(), data))
119            .collect();
120
121        // Convert agg result scalars to typed columns
122        // agg_idx < aggs.len() == agg_results.len()
123        #[allow(
124            clippy::indexing_slicing,
125            reason = "agg_idx < aggs.len() == agg_results.len(); parallel zip ensures valid index"
126        )]
127        for (agg_idx, agg) in aggs.iter().enumerate() {
128            let col_name = agg_column_name(agg);
129            let col = scalars_to_column(&col_name, &agg_results[agg_idx]);
130            columns.push(col);
131        }
132
133        DataFrame::new(columns)
134    }
135
136    /// Compute a single aggregation over the rows at `indices`.
137    fn compute_agg(&self, agg: &Agg, indices: &[usize]) -> Result<Scalar, DataFrameError> {
138        match agg {
139            Agg::Count => {
140                // indices.len() fits u64: slice length is bounded by usize which is <= u64 on all targets
141                #[allow(
142                    clippy::as_conversions,
143                    reason = "usize→u64: on all supported platforms usize <= 64 bits, so this cast is lossless"
144                )]
145                Ok(Scalar::UInt64(indices.len() as u64))
146            }
147            Agg::Sum(col_name) => {
148                let col = self.df.column(col_name)?;
149                let sub = col.take(indices);
150                Ok(sub.sum())
151            }
152            Agg::Mean(col_name) => {
153                let col = self.df.column(col_name)?;
154                let sub = col.take(indices);
155                Ok(sub.mean())
156            }
157            Agg::Min(col_name) => {
158                let col = self.df.column(col_name)?;
159                let sub = col.take(indices);
160                Ok(sub.min())
161            }
162            Agg::Max(col_name) => {
163                let col = self.df.column(col_name)?;
164                let sub = col.take(indices);
165                Ok(sub.max())
166            }
167            Agg::First(col_name) => {
168                let col = self.df.column(col_name)?;
169                let sub = col.take(indices);
170                Ok(sub.first())
171            }
172            Agg::Last(col_name) => {
173                let col = self.df.column(col_name)?;
174                let sub = col.take(indices);
175                Ok(sub.last())
176            }
177            Agg::NUnique(col_name) => {
178                let col = self.df.column(col_name)?;
179                let sub = col.take(indices);
180                // usize→u64: lossless on all platforms where usize <= 64 bits
181                #[allow(
182                    clippy::as_conversions,
183                    reason = "usize→u64: n_unique() returns a Vec-length bounded by usize; lossless on all supported 32/64-bit platforms"
184                )]
185                Ok(Scalar::UInt64(sub.n_unique() as u64))
186            }
187        }
188    }
189}
190
191impl DataFrame {
192    /// Group the DataFrame by specified columns. Returns a `GroupBy` that
193    /// can be aggregated.
194    pub fn group_by(&self, cols: &[&str]) -> Result<GroupBy<'_>, DataFrameError> {
195        // Validate all group columns exist
196        for name in cols {
197            self.column(name)?;
198        }
199
200        let group_cols: Vec<String> = cols.iter().map(|s| (*s).to_string()).collect();
201        #[allow(
202            clippy::disallowed_types,
203            reason = "HashMap::new() for group accumulation; see module-level allow"
204        )]
205        let mut groups: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
206
207        for row_idx in 0..self.height() {
208            let key: Vec<String> = cols
209                .iter()
210                .map(|name| {
211                    self.column(name)
212                        .ok()
213                        .and_then(|col| col.get(row_idx))
214                        .map_or_else(|| "null".to_string(), |s| s.to_string())
215                })
216                .collect();
217            groups.entry(key).or_default().push(row_idx);
218        }
219
220        Ok(GroupBy {
221            df: self,
222            group_cols,
223            groups,
224        })
225    }
226}
227
228/// Generate a descriptive column name for an aggregation.
229fn agg_column_name(agg: &Agg) -> String {
230    match agg {
231        Agg::Sum(c) => format!("{c}_sum"),
232        Agg::Mean(c) => format!("{c}_mean"),
233        Agg::Min(c) => format!("{c}_min"),
234        Agg::Max(c) => format!("{c}_max"),
235        Agg::Count => "count".to_string(),
236        Agg::First(c) => format!("{c}_first"),
237        Agg::Last(c) => format!("{c}_last"),
238        Agg::NUnique(c) => format!("{c}_nunique"),
239    }
240}
241
242/// Convert a vec of mixed Scalars into a Column, inferring the best type.
243fn scalars_to_column(name: &str, scalars: &[Scalar]) -> Column {
244    // Determine dominant type (first non-null)
245    let first_non_null = scalars.iter().find(|s| !s.is_null());
246
247    match first_non_null {
248        Some(Scalar::Int64(_)) => {
249            let data: Vec<Option<i64>> = scalars
250                .iter()
251                .map(|s| match s {
252                    Scalar::Int64(v) => Some(*v),
253                    Scalar::Null
254                    | Scalar::Bool(_)
255                    | Scalar::UInt64(_)
256                    | Scalar::Float64(_)
257                    | Scalar::String(_) => None,
258                })
259                .collect();
260            Column::new_i64(name, data)
261        }
262        Some(Scalar::UInt64(_)) => {
263            let data: Vec<Option<u64>> = scalars
264                .iter()
265                .map(|s| match s {
266                    Scalar::UInt64(v) => Some(*v),
267                    Scalar::Null
268                    | Scalar::Bool(_)
269                    | Scalar::Int64(_)
270                    | Scalar::Float64(_)
271                    | Scalar::String(_) => None,
272                })
273                .collect();
274            Column::new_u64(name, data)
275        }
276        Some(Scalar::Float64(_)) => {
277            let data: Vec<Option<f64>> = scalars.iter().map(|s| s.as_f64()).collect();
278            Column::new_f64(name, data)
279        }
280        Some(Scalar::Bool(_)) => {
281            let data: Vec<Option<bool>> = scalars
282                .iter()
283                .map(|s| match s {
284                    Scalar::Bool(v) => Some(*v),
285                    Scalar::Null
286                    | Scalar::Int64(_)
287                    | Scalar::UInt64(_)
288                    | Scalar::Float64(_)
289                    | Scalar::String(_) => None,
290                })
291                .collect();
292            Column::new_bool(name, data)
293        }
294        Some(Scalar::String(_)) | None => {
295            let data: Vec<Option<String>> = scalars
296                .iter()
297                .map(|s| match s {
298                    Scalar::String(v) => Some(v.clone()),
299                    Scalar::Null => None,
300                    other @ (Scalar::Bool(_)
301                    | Scalar::Int64(_)
302                    | Scalar::UInt64(_)
303                    | Scalar::Float64(_)) => Some(other.to_string()),
304                })
305                .collect();
306            Column::new_string(name, data)
307        }
308        Some(Scalar::Null) => {
309            // All nulls — default to string column
310            let data: Vec<Option<String>> = scalars.iter().map(|_| None).collect();
311            Column::new_string(name, data)
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Look up `value_col` on the row whose `key_col` equals the string `key`.
    ///
    /// Tests address groups by key rather than by row position, so they do
    /// not depend on the order in which groups are emitted. Returns `None`
    /// when no row has that key, which makes a missing group fail the
    /// surrounding `assert_eq!` instead of passing vacuously.
    fn value_for_key(df: &DataFrame, key_col: &str, key: &str, value_col: &str) -> Option<Scalar> {
        let keys = df.column(key_col).unwrap_or_else(|_| unreachable!());
        let values = df.column(value_col).unwrap_or_else(|_| unreachable!());
        let wanted = Some(Scalar::String(key.into()));
        (0..df.height())
            .find(|&row| keys.get(row) == wanted)
            .and_then(|row| values.get(row))
    }

    #[test]
    fn group_by_count() {
        let df = DataFrame::new(vec![
            Column::from_strs("drug", &["asp", "met", "asp", "met", "asp"]),
            Column::from_i64s("val", vec![1, 2, 3, 4, 5]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let gb = df.group_by(&["drug"]).unwrap_or_else(|_| unreachable!());
        assert_eq!(gb.n_groups(), 2);

        let result = gb.agg(&[Agg::Count]).unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        assert_eq!(result.width(), 2); // drug + count
        assert_eq!(
            value_for_key(&result, "drug", "asp", "count"),
            Some(Scalar::UInt64(3))
        );
        assert_eq!(
            value_for_key(&result, "drug", "met", "count"),
            Some(Scalar::UInt64(2))
        );
    }

    #[test]
    fn group_by_sum() {
        let df = DataFrame::new(vec![
            Column::from_strs("cat", &["a", "b", "a"]),
            Column::from_i64s("val", vec![10, 20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = df
            .group_by(&["cat"])
            .unwrap_or_else(|_| unreachable!())
            .agg(&[Agg::Sum("val".into())])
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(result.height(), 2);
        assert_eq!(
            value_for_key(&result, "cat", "a", "val_sum"),
            Some(Scalar::Int64(40))
        );
        assert_eq!(
            value_for_key(&result, "cat", "b", "val_sum"),
            Some(Scalar::Int64(20))
        );
    }

    #[test]
    fn group_by_multiple_aggs() {
        let df = DataFrame::new(vec![
            Column::from_strs("g", &["x", "y", "x"]),
            Column::from_i64s("n", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = df
            .group_by(&["g"])
            .unwrap_or_else(|_| unreachable!())
            .agg(&[
                Agg::Count,
                Agg::Sum("n".into()),
                Agg::Min("n".into()),
                Agg::Max("n".into()),
            ])
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(result.height(), 2);
        assert_eq!(result.width(), 5); // g + count + n_sum + n_min + n_max
        assert_eq!(value_for_key(&result, "g", "x", "count"), Some(Scalar::UInt64(2)));
        assert_eq!(value_for_key(&result, "g", "x", "n_sum"), Some(Scalar::Int64(4)));
        assert_eq!(value_for_key(&result, "g", "x", "n_min"), Some(Scalar::Int64(1)));
        assert_eq!(value_for_key(&result, "g", "x", "n_max"), Some(Scalar::Int64(3)));
        assert_eq!(value_for_key(&result, "g", "y", "n_sum"), Some(Scalar::Int64(2)));
    }

    #[test]
    fn group_by_missing_column() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1])])
            .unwrap_or_else(|_| unreachable!());
        assert!(df.group_by(&["missing"]).is_err());
    }

    #[test]
    fn agg_missing_column_is_error() {
        let df = DataFrame::new(vec![
            Column::from_strs("g", &["a", "b"]),
            Column::from_i64s("v", vec![1, 2]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let gb = df.group_by(&["g"]).unwrap_or_else(|_| unreachable!());
        assert!(gb.agg(&[Agg::Sum("missing".into())]).is_err());
    }

    #[test]
    fn group_by_multi_key() {
        let df = DataFrame::new(vec![
            Column::from_strs("drug", &["asp", "asp", "met", "asp"]),
            Column::from_strs("event", &["ha", "na", "ha", "ha"]),
            Column::from_i64s("n", vec![1, 1, 1, 1]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let gb = df
            .group_by(&["drug", "event"])
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(gb.n_groups(), 3); // asp+ha, asp+na, met+ha
    }

    #[test]
    fn group_by_first_last() {
        let df = DataFrame::new(vec![
            Column::from_strs("g", &["a", "a", "a"]),
            Column::from_i64s("v", vec![10, 20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = df
            .group_by(&["g"])
            .unwrap_or_else(|_| unreachable!())
            .agg(&[Agg::First("v".into()), Agg::Last("v".into())])
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(result.height(), 1);
        assert_eq!(
            result
                .column("v_first")
                .unwrap_or_else(|_| unreachable!())
                .get(0),
            Some(Scalar::Int64(10))
        );
        assert_eq!(
            result
                .column("v_last")
                .unwrap_or_else(|_| unreachable!())
                .get(0),
            Some(Scalar::Int64(30))
        );
    }

    #[test]
    fn group_by_nunique() {
        let df = DataFrame::new(vec![
            Column::from_strs("g", &["a", "a", "a", "b"]),
            Column::from_i64s("v", vec![1, 1, 2, 5]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = df
            .group_by(&["g"])
            .unwrap_or_else(|_| unreachable!())
            .agg(&[Agg::NUnique("v".into())])
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(
            value_for_key(&result, "g", "a", "v_nunique"),
            Some(Scalar::UInt64(2))
        );
        assert_eq!(
            value_for_key(&result, "g", "b", "v_nunique"),
            Some(Scalar::UInt64(1))
        );
    }
}