1#[allow(
10 clippy::disallowed_types,
11 reason = "HashMap needed for O(1) group-key accumulation; output row order is explicitly unspecified"
12)]
13use std::collections::HashMap;
14
15use crate::column::Column;
16use crate::dataframe::DataFrame;
17use crate::error::DataFrameError;
18use crate::scalar::Scalar;
19
/// An aggregation to evaluate once per group in [`GroupBy::agg`].
///
/// Every variant except [`Agg::Count`] carries the name of the source column it reads.
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a wildcard arm
/// when matching on it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Agg {
    /// Result of `sum()` on the named column restricted to the group's rows;
    /// output column `{col}_sum`.
    Sum(String),
    /// Result of `mean()` on the named column restricted to the group's rows;
    /// output column `{col}_mean`.
    Mean(String),
    /// Result of `min()` on the named column restricted to the group's rows;
    /// output column `{col}_min`.
    Min(String),
    /// Result of `max()` on the named column restricted to the group's rows;
    /// output column `{col}_max`.
    Max(String),
    /// Number of rows in the group, produced as a `Scalar::UInt64`; output column `count`.
    Count,
    /// Result of `first()` on the named column restricted to the group's rows;
    /// output column `{col}_first`.
    First(String),
    /// Result of `last()` on the named column restricted to the group's rows;
    /// output column `{col}_last`.
    Last(String),
    /// Result of `n_unique()` on the named column restricted to the group's rows, produced
    /// as a `Scalar::UInt64`; output column `{col}_nunique`.
    NUnique(String),
}
41
/// A partition of a borrowed [`DataFrame`]'s rows into groups, built by
/// [`DataFrame::group_by`] and consumed by [`GroupBy::agg`].
#[derive(Debug)]
pub struct GroupBy<'a> {
    /// The frame being grouped; aggregations read their source columns from it.
    df: &'a DataFrame,
    /// Names of the key columns, in the order the caller passed them to `group_by`.
    group_cols: Vec<String>,
    /// Maps each distinct key to the indices of the rows that carry it.
    ///
    /// A key holds one string per entry of `group_cols`: the cell rendered with
    /// `to_string()`, or the literal `"null"` when the cell lookup yields `None`.
    /// Row indices are appended while scanning rows `0..height`, so each list is ascending.
    #[allow(
        clippy::disallowed_types,
        reason = "HashMap for O(1) group-key lookup; see module-level allow"
    )]
    groups: HashMap<Vec<String>, Vec<usize>>,
}
54
55impl GroupBy<'_> {
56 #[must_use]
58 pub fn n_groups(&self) -> usize {
59 self.groups.len()
60 }
61
62 pub fn agg(&self, aggs: &[Agg]) -> Result<DataFrame, DataFrameError> {
67 let n_groups = self.groups.len();
68
69 let mut key_vecs: Vec<Vec<Option<String>>> = self
71 .group_cols
72 .iter()
73 .map(|_| Vec::with_capacity(n_groups))
74 .collect();
75
76 let mut agg_results: Vec<Vec<Scalar>> =
78 aggs.iter().map(|_| Vec::with_capacity(n_groups)).collect();
79
80 #[allow(
83 clippy::iter_over_hash_type,
84 reason = "HashMap iteration builds parallel group rows; output row order is explicitly unspecified — callers sort if order matters"
85 )]
86 for (key, indices) in &self.groups {
88 for (i, val) in key.iter().enumerate() {
90 #[allow(
92 clippy::indexing_slicing,
93 reason = "i iterates over key positions; key.len() == group_cols.len() == key_vecs.len() by GroupBy construction"
94 )]
95 key_vecs[i].push(Some(val.clone()));
96 }
97
98 for (agg_idx, agg) in aggs.iter().enumerate() {
100 let result = self.compute_agg(agg, indices)?;
101 #[allow(
103 clippy::indexing_slicing,
104 reason = "agg_idx < aggs.len() == agg_results.len(); index is valid by parallel iteration"
105 )]
106 agg_results[agg_idx].push(result);
107 }
108 }
109
110 #[allow(
112 clippy::indexing_slicing,
113 reason = "i iterates over 0..group_cols.len(); key_vecs has the same length by construction"
114 )]
115 let mut columns: Vec<Column> = key_vecs
116 .into_iter()
117 .enumerate()
118 .map(|(i, data)| Column::new_string(self.group_cols[i].clone(), data))
119 .collect();
120
121 #[allow(
124 clippy::indexing_slicing,
125 reason = "agg_idx < aggs.len() == agg_results.len(); parallel zip ensures valid index"
126 )]
127 for (agg_idx, agg) in aggs.iter().enumerate() {
128 let col_name = agg_column_name(agg);
129 let col = scalars_to_column(&col_name, &agg_results[agg_idx]);
130 columns.push(col);
131 }
132
133 DataFrame::new(columns)
134 }
135
136 fn compute_agg(&self, agg: &Agg, indices: &[usize]) -> Result<Scalar, DataFrameError> {
138 match agg {
139 Agg::Count => {
140 #[allow(
142 clippy::as_conversions,
143 reason = "usize→u64: on all supported platforms usize <= 64 bits, so this cast is lossless"
144 )]
145 Ok(Scalar::UInt64(indices.len() as u64))
146 }
147 Agg::Sum(col_name) => {
148 let col = self.df.column(col_name)?;
149 let sub = col.take(indices);
150 Ok(sub.sum())
151 }
152 Agg::Mean(col_name) => {
153 let col = self.df.column(col_name)?;
154 let sub = col.take(indices);
155 Ok(sub.mean())
156 }
157 Agg::Min(col_name) => {
158 let col = self.df.column(col_name)?;
159 let sub = col.take(indices);
160 Ok(sub.min())
161 }
162 Agg::Max(col_name) => {
163 let col = self.df.column(col_name)?;
164 let sub = col.take(indices);
165 Ok(sub.max())
166 }
167 Agg::First(col_name) => {
168 let col = self.df.column(col_name)?;
169 let sub = col.take(indices);
170 Ok(sub.first())
171 }
172 Agg::Last(col_name) => {
173 let col = self.df.column(col_name)?;
174 let sub = col.take(indices);
175 Ok(sub.last())
176 }
177 Agg::NUnique(col_name) => {
178 let col = self.df.column(col_name)?;
179 let sub = col.take(indices);
180 #[allow(
182 clippy::as_conversions,
183 reason = "usize→u64: n_unique() returns a Vec-length bounded by usize; lossless on all supported 32/64-bit platforms"
184 )]
185 Ok(Scalar::UInt64(sub.n_unique() as u64))
186 }
187 }
188 }
189}
190
191impl DataFrame {
192 pub fn group_by(&self, cols: &[&str]) -> Result<GroupBy<'_>, DataFrameError> {
195 for name in cols {
197 self.column(name)?;
198 }
199
200 let group_cols: Vec<String> = cols.iter().map(|s| (*s).to_string()).collect();
201 #[allow(
202 clippy::disallowed_types,
203 reason = "HashMap::new() for group accumulation; see module-level allow"
204 )]
205 let mut groups: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
206
207 for row_idx in 0..self.height() {
208 let key: Vec<String> = cols
209 .iter()
210 .map(|name| {
211 self.column(name)
212 .ok()
213 .and_then(|col| col.get(row_idx))
214 .map_or_else(|| "null".to_string(), |s| s.to_string())
215 })
216 .collect();
217 groups.entry(key).or_default().push(row_idx);
218 }
219
220 Ok(GroupBy {
221 df: self,
222 group_cols,
223 groups,
224 })
225 }
226}
227
228fn agg_column_name(agg: &Agg) -> String {
230 match agg {
231 Agg::Sum(c) => format!("{c}_sum"),
232 Agg::Mean(c) => format!("{c}_mean"),
233 Agg::Min(c) => format!("{c}_min"),
234 Agg::Max(c) => format!("{c}_max"),
235 Agg::Count => "count".to_string(),
236 Agg::First(c) => format!("{c}_first"),
237 Agg::Last(c) => format!("{c}_last"),
238 Agg::NUnique(c) => format!("{c}_nunique"),
239 }
240}
241
242fn scalars_to_column(name: &str, scalars: &[Scalar]) -> Column {
244 let first_non_null = scalars.iter().find(|s| !s.is_null());
246
247 match first_non_null {
248 Some(Scalar::Int64(_)) => {
249 let data: Vec<Option<i64>> = scalars
250 .iter()
251 .map(|s| match s {
252 Scalar::Int64(v) => Some(*v),
253 Scalar::Null
254 | Scalar::Bool(_)
255 | Scalar::UInt64(_)
256 | Scalar::Float64(_)
257 | Scalar::String(_) => None,
258 })
259 .collect();
260 Column::new_i64(name, data)
261 }
262 Some(Scalar::UInt64(_)) => {
263 let data: Vec<Option<u64>> = scalars
264 .iter()
265 .map(|s| match s {
266 Scalar::UInt64(v) => Some(*v),
267 Scalar::Null
268 | Scalar::Bool(_)
269 | Scalar::Int64(_)
270 | Scalar::Float64(_)
271 | Scalar::String(_) => None,
272 })
273 .collect();
274 Column::new_u64(name, data)
275 }
276 Some(Scalar::Float64(_)) => {
277 let data: Vec<Option<f64>> = scalars.iter().map(|s| s.as_f64()).collect();
278 Column::new_f64(name, data)
279 }
280 Some(Scalar::Bool(_)) => {
281 let data: Vec<Option<bool>> = scalars
282 .iter()
283 .map(|s| match s {
284 Scalar::Bool(v) => Some(*v),
285 Scalar::Null
286 | Scalar::Int64(_)
287 | Scalar::UInt64(_)
288 | Scalar::Float64(_)
289 | Scalar::String(_) => None,
290 })
291 .collect();
292 Column::new_bool(name, data)
293 }
294 Some(Scalar::String(_)) | None => {
295 let data: Vec<Option<String>> = scalars
296 .iter()
297 .map(|s| match s {
298 Scalar::String(v) => Some(v.clone()),
299 Scalar::Null => None,
300 other @ (Scalar::Bool(_)
301 | Scalar::Int64(_)
302 | Scalar::UInt64(_)
303 | Scalar::Float64(_)) => Some(other.to_string()),
304 })
305 .collect();
306 Column::new_string(name, data)
307 }
308 Some(Scalar::Null) => {
309 let data: Vec<Option<String>> = scalars.iter().map(|_| None).collect();
311 Column::new_string(name, data)
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Counting rows per group yields one row per distinct key.
    #[test]
    fn group_by_count() {
        let df = DataFrame::new(vec![
            Column::from_strs("drug", &["asp", "met", "asp", "met", "asp"]),
            Column::from_i64s("val", vec![1, 2, 3, 4, 5]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let gb = df.group_by(&["drug"]).unwrap_or_else(|_| unreachable!());
        assert_eq!(gb.n_groups(), 2);

        let result = gb.agg(&[Agg::Count]).unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        // Key column plus the count column.
        assert_eq!(result.width(), 2);
    }

    /// Summing per group produces the expected totals regardless of output row order.
    #[test]
    fn group_by_sum() {
        let df = DataFrame::new(vec![
            Column::from_strs("cat", &["a", "b", "a"]),
            Column::from_i64s("val", vec![10, 20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = df
            .group_by(&["cat"])
            .unwrap_or_else(|_| unreachable!())
            .agg(&[Agg::Sum("val".into())])
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(result.height(), 2);

        let mut sums: Vec<Option<Scalar>> = Vec::new();
        for i in 0..result.height() {
            let cat = result
                .column("cat")
                .unwrap_or_else(|_| unreachable!())
                .get(i);
            let val = result
                .column("val_sum")
                .unwrap_or_else(|_| unreachable!())
                .get(i);
            if cat == Some(Scalar::String("a".into())) {
                assert_eq!(val, Some(Scalar::Int64(40)));
            }
            sums.push(val);
        }

        // The conditional check above passes trivially if no row matches "a", so also
        // pin the set of totals: 10 + 30 for "a" and 20 for "b", in either row order.
        assert!(sums.contains(&Some(Scalar::Int64(40))));
        assert!(sums.contains(&Some(Scalar::Int64(20))));
    }

    /// Several aggregations in one call each contribute their own output column.
    #[test]
    fn group_by_multiple_aggs() {
        let df = DataFrame::new(vec![
            Column::from_strs("g", &["x", "y", "x"]),
            Column::from_i64s("n", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = df
            .group_by(&["g"])
            .unwrap_or_else(|_| unreachable!())
            .agg(&[
                Agg::Count,
                Agg::Sum("n".into()),
                Agg::Min("n".into()),
                Agg::Max("n".into()),
            ])
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(result.height(), 2);
        // Key column plus count, sum, min and max.
        assert_eq!(result.width(), 5);
    }

    /// Grouping by a column that does not exist is an error.
    #[test]
    fn group_by_missing_column() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1])])
            .unwrap_or_else(|_| unreachable!());
        assert!(df.group_by(&["missing"]).is_err());
    }

    /// Keys made of several columns are distinct when any component differs.
    #[test]
    fn group_by_multi_key() {
        let df = DataFrame::new(vec![
            Column::from_strs("drug", &["asp", "asp", "met", "asp"]),
            Column::from_strs("event", &["ha", "na", "ha", "ha"]),
            Column::from_i64s("n", vec![1, 1, 1, 1]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let gb = df
            .group_by(&["drug", "event"])
            .unwrap_or_else(|_| unreachable!());
        // (asp, ha), (asp, na) and (met, ha).
        assert_eq!(gb.n_groups(), 3);
    }

    /// `First` and `Last` pick the values of the group's first and last rows.
    #[test]
    fn group_by_first_last() {
        let df = DataFrame::new(vec![
            Column::from_strs("g", &["a", "a", "a"]),
            Column::from_i64s("v", vec![10, 20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = df
            .group_by(&["g"])
            .unwrap_or_else(|_| unreachable!())
            .agg(&[Agg::First("v".into()), Agg::Last("v".into())])
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(result.height(), 1);
        assert_eq!(
            result
                .column("v_first")
                .unwrap_or_else(|_| unreachable!())
                .get(0),
            Some(Scalar::Int64(10))
        );
        assert_eq!(
            result
                .column("v_last")
                .unwrap_or_else(|_| unreachable!())
                .get(0),
            Some(Scalar::Int64(30))
        );
    }
}