axion_data/dataframe/
groupby.rs

1use crate::dataframe::DataFrame;
2use crate::error::{AxionError, AxionResult};
3use crate::series::{SeriesTrait, Series};
4use crate::dtype::{DataType, DataTypeTrait};
5use std::collections::HashMap;
6use std::any::Any;
7use num_traits::Float;
8use std::fmt::Debug;
9
10use super::types::{GroupKeyValue, GroupKey};
11
12/// 根据数据类型创建空的 Series
13fn create_empty_series_from_dtype(name: String, dtype: DataType) -> AxionResult<Box<dyn SeriesTrait>> {
14    match dtype {
15        DataType::Int8 => Ok(Box::new(Series::<i8>::new_empty(name, dtype))),
16        DataType::Int16 => Ok(Box::new(Series::<i16>::new_empty(name, dtype))),
17        DataType::Int32 => Ok(Box::new(Series::<i32>::new_empty(name, dtype))),
18        DataType::Int64 => Ok(Box::new(Series::<i64>::new_empty(name, dtype))),
19        DataType::UInt8 => Ok(Box::new(Series::<u8>::new_empty(name, dtype))),
20        DataType::UInt16 => Ok(Box::new(Series::<u16>::new_empty(name, dtype))),
21        DataType::UInt32 => Ok(Box::new(Series::<u32>::new_empty(name, dtype))),
22        DataType::UInt64 => Ok(Box::new(Series::<u64>::new_empty(name, dtype))),
23        DataType::Float32 => Ok(Box::new(Series::<f32>::new_empty(name, dtype))),
24        DataType::Float64 => Ok(Box::new(Series::<f64>::new_empty(name, dtype))),
25        DataType::String => Ok(Box::new(Series::<String>::new_empty(name, dtype))),
26        DataType::Bool => Ok(Box::new(Series::<bool>::new_empty(name, dtype))),
27        _ => Err(AxionError::UnsupportedOperation(format!("无法为数据类型 {:?} 创建空 Series", dtype))),
28    }
29}
30
/// Type-tagged result of a per-group aggregation.
///
/// Each typed variant wraps an `Option` of the corresponding primitive;
/// `AggValue::None` stands for "no value" (every input was null, or the value
/// type is not one of the supported primitives).
#[derive(Debug, Clone, PartialEq)]
enum AggValue {
    /// Aggregated `i8` value.
    Int8(Option<i8>),
    /// Aggregated `i16` value.
    Int16(Option<i16>),
    /// Aggregated `i32` value.
    Int32(Option<i32>),
    /// Aggregated `i64` value.
    Int64(Option<i64>),
    /// Aggregated `u8` value.
    UInt8(Option<u8>),
    /// Aggregated `u16` value.
    UInt16(Option<u16>),
    /// Aggregated `u32` value.
    UInt32(Option<u32>),
    /// Aggregated `u64` value.
    UInt64(Option<u64>),
    /// Aggregated `f32` value.
    Float32(Option<f32>),
    /// Aggregated `f64` value.
    Float64(Option<f64>),
    /// Aggregated `String` value.
    String(Option<String>),
    /// Aggregated `bool` value.
    Bool(Option<bool>),
    /// The group held only nulls, or the value type was not recognised.
    None,
}

impl AggValue {
    /// Wraps an optional value of any supported primitive type in the matching variant.
    ///
    /// `None` input maps to `AggValue::None`. A `Some` whose type is not one of the
    /// twelve supported primitives also maps to `AggValue::None`, after a warning is
    /// written to stderr.
    fn from_option<T: 'static + Clone + Debug>(opt_val: Option<T>) -> Self {
        let val = match opt_val {
            Some(val) => val,
            None => return AggValue::None,
        };
        // Erase the static type so the concrete type can be probed at runtime.
        let erased: &dyn Any = &val;

        // For each `type => Variant` pair, return the variant as soon as the
        // downcast succeeds.
        macro_rules! probe {
            ($($ty:ty => $variant:ident),+ $(,)?) => {
                $(
                    if let Some(inner) = erased.downcast_ref::<$ty>() {
                        return AggValue::$variant(Some(inner.clone()));
                    }
                )+
            };
        }

        probe!(
            i8 => Int8,
            i16 => Int16,
            i32 => Int32,
            i64 => Int64,
            u8 => UInt8,
            u16 => UInt16,
            u32 => UInt32,
            u64 => UInt64,
            f32 => Float32,
            f64 => Float64,
            String => String,
            bool => Bool,
        );

        eprintln!("警告: AggValue::from_option 遇到了未预期的类型: {:?}", std::any::type_name::<T>());
        AggValue::None
    }
}
76
77/// 计算最小值/最大值的泛型函数
78fn calculate_min_max<T>(
79    series_trait: &dyn SeriesTrait,
80    indices: &[usize],
81    find_min: bool,
82) -> AxionResult<AggValue>
83where
84    T: DataTypeTrait + PartialOrd + Clone + Debug + 'static,
85{
86    let series = series_trait.as_any().downcast_ref::<Series<T>>()
87        .ok_or_else(|| AxionError::InternalError(format!("无法将 Series 向下转型为预期类型 {:?}", std::any::type_name::<T>())))?;
88    
89    let mut current_agg: Option<T> = None;
90    for &idx in indices {
91        if let Some(val_ref) = series.get(idx) {
92             match current_agg.as_ref() {
93                 Some(agg_val_ref) => {
94                     if find_min { 
95                         if val_ref < agg_val_ref { 
96                             current_agg = Some(val_ref.clone()); 
97                         } 
98                     } else if val_ref > agg_val_ref { 
99                         current_agg = Some(val_ref.clone()); 
100                     }
101                 }
102                 None => { 
103                     current_agg = Some(val_ref.clone()); 
104                 }
105             }
106        }
107    }
108    Ok(AggValue::from_option(current_agg))
109}
110
111/// 计算浮点数最小值/最大值的泛型函数(处理 NaN)
112fn calculate_min_max_float<T>(
113    series_trait: &dyn SeriesTrait,
114    indices: &[usize],
115    find_min: bool,
116) -> AxionResult<AggValue>
117where
118    T: DataTypeTrait + Float + Clone + Debug + 'static,
119{
120     let series = series_trait.as_any().downcast_ref::<Series<T>>()
121        .ok_or_else(|| AxionError::InternalError(format!("无法将 Series 向下转型为预期浮点类型 {:?}", std::any::type_name::<T>())))?;
122    
123    let mut current_agg: Option<T> = None;
124    for &idx in indices {
125        if let Some(val_ref) = series.get(idx) {
126            if val_ref.is_nan() { 
127                continue; 
128            }
129            match current_agg.as_ref() {
130                 Some(agg_val_ref) => {
131                     if find_min { 
132                         if val_ref < agg_val_ref { 
133                             current_agg = Some(*val_ref); 
134                         } 
135                     } else if val_ref > agg_val_ref { 
136                         current_agg = Some(*val_ref); 
137                     }
138                 }
139                 None => { 
140                     current_agg = Some(*val_ref); 
141                 }
142             }
143        }
144    }
145    Ok(AggValue::from_option(current_agg))
146}
147
/// Routes a min/max request to the helper instantiated for the column's `DataType`.
///
/// Arguments, in order: the type-erased series, its `DataType`, the slice of row
/// indices, and a `bool` that is `true` for minimum and `false` for maximum.
/// Float columns go through `calculate_min_max_float`; integer, `String` and `Bool`
/// columns go through `calculate_min_max`. Any other `DataType` evaluates to
/// `Err(AxionError::UnsupportedOperation)`.
macro_rules! dispatch_min_max {
    ($series_trait:expr, $dtype:expr, $indices:expr, $find_min:expr) => {
        match $dtype {
            // Floating-point columns use the float-specific helper.
            DataType::Float32 => calculate_min_max_float::<f32>($series_trait, $indices, $find_min),
            DataType::Float64 => calculate_min_max_float::<f64>($series_trait, $indices, $find_min),

            // Signed integers.
            DataType::Int8 => calculate_min_max::<i8>($series_trait, $indices, $find_min),
            DataType::Int16 => calculate_min_max::<i16>($series_trait, $indices, $find_min),
            DataType::Int32 => calculate_min_max::<i32>($series_trait, $indices, $find_min),
            DataType::Int64 => calculate_min_max::<i64>($series_trait, $indices, $find_min),

            // Unsigned integers.
            DataType::UInt8 => calculate_min_max::<u8>($series_trait, $indices, $find_min),
            DataType::UInt16 => calculate_min_max::<u16>($series_trait, $indices, $find_min),
            DataType::UInt32 => calculate_min_max::<u32>($series_trait, $indices, $find_min),
            DataType::UInt64 => calculate_min_max::<u64>($series_trait, $indices, $find_min),

            // Other ordered types.
            DataType::Bool => calculate_min_max::<bool>($series_trait, $indices, $find_min),
            DataType::String => calculate_min_max::<String>($series_trait, $indices, $find_min),

            _ => Err(AxionError::UnsupportedOperation(format!("数据类型 {:?} 不支持 Min/Max 操作", $dtype))),
        }
    };
}
168
169/// 表示分组操作的中间状态。
170/// 
171/// 持有对原始 DataFrame 的引用和计算出的分组索引。
172/// 可以在此基础上执行各种聚合操作,如计数、求和、平均值等。
173/// 
174/// # 示例
175/// 
176/// ```rust
177/// let grouped = df.groupby(&["类别"])?;
178/// let count_result = grouped.count()?;
179/// let sum_result = grouped.sum()?;
180/// let mean_result = grouped.mean()?;
181/// ```
182#[derive(Debug)]
183pub struct GroupBy<'a> {
184    /// 原始 DataFrame 的引用
185    df: &'a DataFrame,
186    /// 用于分组的列名
187    keys: Vec<String>,
188    /// 从分组键值到行索引的映射
189    groups: HashMap<GroupKey, Vec<usize>>,
190}
191
192impl<'a> GroupBy<'a> {
193    /// 创建新的 GroupBy 对象(内部使用,由 DataFrame::groupby 调用)
194    /// 
195    /// 根据提供的键计算分组成员关系。
196    ///
197    /// # 参数
198    /// 
199    /// * `df` - 要分组的 DataFrame 引用
200    /// * `keys` - 用于分组的列名向量
201    ///
202    /// # 返回值
203    /// 
204    /// 返回新创建的 GroupBy 对象
205    ///
206    /// # 错误
207    /// 
208    /// * `AxionError::ColumnNotFound` - 指定的分组列不存在
209    /// * `AxionError::UnsupportedOperation` - 列的数据类型不支持分组
210    pub(crate) fn new(df: &'a DataFrame, keys: Vec<String>) -> AxionResult<Self> {
211        let mut key_cols: Vec<&dyn SeriesTrait> = Vec::with_capacity(keys.len());
212        for key_name in &keys {
213            let col = df.column(key_name)?;
214            key_cols.push(col);
215            match col.dtype() {
216                DataType::Int32 | DataType::String | DataType::Bool => {},
217                unsupported_dtype => {
218                    return Err(AxionError::UnsupportedOperation(format!(
219                        "列 '{}' 的数据类型 {:?} 不支持分组操作",
220                        key_name, unsupported_dtype
221                    )));
222                }
223            }
224        }
225
226        let mut groups: HashMap<GroupKey, Vec<usize>> = HashMap::new();
227        for row_idx in 0..df.height() {
228            let mut current_key: GroupKey = Vec::with_capacity(keys.len());
229            let mut has_null = false;
230
231            for key_col in &key_cols {
232                let key_value = match key_col.dtype() {
233                    DataType::Int32 => {
234                        let series = key_col.as_any().downcast_ref::<Series<i32>>().unwrap();
235                        match series.get(row_idx) {
236                            Some(v) => GroupKeyValue::Int(*v),
237                            None => { has_null = true; break; }
238                        }
239                    }
240                    DataType::String => {
241                        let series = key_col.as_any().downcast_ref::<Series<String>>().unwrap();
242                        match series.get(row_idx) {
243                            Some(v) => GroupKeyValue::Str(v.clone()),
244                            None => { has_null = true; break; }
245                        }
246                    }
247                    DataType::Bool => {
248                        let series = key_col.as_any().downcast_ref::<Series<bool>>().unwrap();
249                        match series.get(row_idx) {
250                            Some(v) => GroupKeyValue::Bool(*v),
251                            None => { has_null = true; break; }
252                        }
253                    }
254                    _ => unreachable!("类型检查后遇到不支持的分组类型"),
255                };
256                current_key.push(key_value);
257            }
258
259            if !has_null {
260                groups.entry(current_key).or_default().push(row_idx);
261            }
262        }
263
264        Ok(Self { df, keys, groups })
265    }
266
267    /// 计算每个组的行数。
268    ///
269    /// # 返回值
270    /// 
271    /// 返回包含分组键和对应计数的新 DataFrame
272    ///
273    /// # 示例
274    /// 
275    /// ```rust
276    /// let grouped = df.groupby(&["类别"])?;
277    /// let count_df = grouped.count()?;
278    /// ```
279    pub fn count(&self) -> AxionResult<DataFrame> {
280        let groups = &self.groups;
281
282        if groups.is_empty() {
283            let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + 1);
284            for key_name in &self.keys {
285                let original_key_col = self.df.column(key_name)?;
286                let dtype = original_key_col.dtype();
287                let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
288                output_columns.push(empty_key_series);
289            }
290            let empty_count_series = Series::<u32>::new_empty("count".into(), DataType::UInt32);
291            output_columns.push(Box::new(empty_count_series));
292            return DataFrame::new(output_columns);
293        }
294
295        let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
296        let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
297        for key_name in &self.keys {
298             let original_key_col = self.df.column(key_name)?;
299             let dtype = original_key_col.dtype();
300             key_dtypes.push(dtype.clone());
301             match dtype {
302                 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
303                 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
304                 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
305                 DataType::UInt32 => key_data_vecs.push(Box::new(Vec::<Option<u32>>::new())),
306                 _ => return Err(AxionError::UnsupportedOperation(format!(
307                     "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
308                 ))),
309             }
310        }
311        let mut count_data_vec = Vec::<u32>::with_capacity(groups.len());
312
313        for (key, indices) in groups.iter() {
314            let key_values = key.iter();
315
316            for (i, group_key_value) in key_values.enumerate() {
317                match key_dtypes[i] {
318                    DataType::Int32 => {
319                        if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() {
320                            if let GroupKeyValue::Int(val) = group_key_value {
321                                vec.push(Some(*val));
322                            } else { vec.push(None); }
323                        }
324                    }
325                    DataType::String => {
326                         if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() {
327                            if let GroupKeyValue::Str(val) = group_key_value {
328                                vec.push(Some(val.clone()));
329                            } else { vec.push(None); }
330                        }
331                    }
332                    DataType::Bool => {
333                         if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() {
334                            if let GroupKeyValue::Bool(val) = group_key_value {
335                                vec.push(Some(*val));
336                            } else { vec.push(None); }
337                        }
338                    }
339                     DataType::UInt32 => {
340                         if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<u32>>>() {
341                             vec.push(None);
342                         }
343                    }
344                    _ => {}
345                }
346            }
347            count_data_vec.push(indices.len() as u32);
348        }
349
350        let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + 1);
351        for (i, key_name) in self.keys.iter().enumerate() {
352            let boxed_any = &key_data_vecs[i];
353
354            let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
355                 DataType::Int32 => {
356                     let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap();
357                     Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
358                 }
359                 DataType::String => {
360                     let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap();
361                     Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
362                 }
363                 DataType::Bool => {
364                     let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap();
365                     Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
366                 }
367                 DataType::UInt32 => {
368                     let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<u32>>>().unwrap();
369                     Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
370                 }
371                 _ => unreachable!(),
372            };
373            final_columns.push(final_key_series);
374        }
375        final_columns.push(Box::new(Series::new("count".into(), count_data_vec)));
376
377        DataFrame::new(final_columns)
378    }
379
380    /// 计算每个组中数值列的和。
381    ///
382    /// 非数值列(不是分组键的列)将被忽略。
383    /// 组内的 null 值在求和时被忽略(空组或全 null 组的和为 0)。
384    ///
385    /// # 返回值
386    /// 
387    /// 返回包含分组键和对应求和结果的新 DataFrame
388    pub fn sum(&self) -> AxionResult<DataFrame> {
389        let groups = &self.groups;
390
391        let value_col_names: Vec<String> = self.df.columns_names()
392            .into_iter()
393            .filter(|name| !self.keys.iter().any(|k| k == *name))
394            .filter(|name| {
395                if let Ok(col) = self.df.column(name) {
396                    matches!(col.dtype(), DataType::Int32 | DataType::UInt32 | DataType::Float32 | DataType::Float64)
397                } else {
398                    false
399                }
400            })
401            .map(|name| name.to_string())
402            .collect();
403
404        if groups.is_empty() {
405            let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
406            for key_name in &self.keys {
407                let original_key_col = self.df.column(key_name)?;
408                let dtype = original_key_col.dtype();
409                let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
410                output_columns.push(empty_key_series);
411            }
412            for value_col_name in &value_col_names {
413                let original_value_col = self.df.column(value_col_name)?;
414                let dtype = original_value_col.dtype();
415                let empty_sum_series = create_empty_series_from_dtype(value_col_name.clone(), dtype)?;
416                output_columns.push(empty_sum_series);
417            }
418            return DataFrame::new(output_columns);
419        }
420
421        let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
422        let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
423        for key_name in &self.keys {
424             let original_key_col = self.df.column(key_name)?;
425             let dtype = original_key_col.dtype();
426             key_dtypes.push(dtype.clone());
427             match dtype {
428                 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
429                 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
430                 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
431                 _ => return Err(AxionError::UnsupportedOperation(format!(
432                     "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
433                 ))),
434             }
435        }
436
437        let mut sum_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(value_col_names.len());
438        let mut sum_dtypes: Vec<DataType> = Vec::with_capacity(value_col_names.len());
439        for value_col_name in &value_col_names {
440            let original_value_col = self.df.column(value_col_name)?;
441            let dtype = original_value_col.dtype();
442            sum_dtypes.push(dtype.clone());
443            match dtype {
444                DataType::Int32 => sum_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
445                DataType::UInt32 => sum_data_vecs.push(Box::new(Vec::<Option<u32>>::new())),
446                DataType::Float32 => sum_data_vecs.push(Box::new(Vec::<Option<f32>>::new())),
447                DataType::Float64 => sum_data_vecs.push(Box::new(Vec::<Option<f64>>::new())),
448                _ => unreachable!(),
449            }
450        }
451
452        for (key, indices) in groups.iter() {
453            let key_values = key.iter();
454            for (i, group_key_value) in key_values.enumerate() {
455                 match key_dtypes[i] {
456                    DataType::Int32 => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() { if let GroupKeyValue::Int(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
457                    DataType::String => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() { if let GroupKeyValue::Str(val) = group_key_value { vec.push(Some(val.clone())); } else { vec.push(None); } },
458                    DataType::Bool => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() { if let GroupKeyValue::Bool(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
459                    _ => {}
460                }
461            }
462
463            for (j, value_col_name) in value_col_names.iter().enumerate() {
464                let value_col = self.df.column(value_col_name)?;
465
466                match sum_dtypes[j] {
467                    DataType::Int32 => {
468                        let series = value_col.as_any().downcast_ref::<Series<i32>>().unwrap();
469                        let mut current_sum: Option<i32> = None;
470                        for &idx in indices {
471                            if let Some(val) = series.get(idx) {
472                                current_sum = Some(current_sum.unwrap_or(0).saturating_add(*val));
473                            }
474                        }
475                        if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<i32>>>() {
476                            vec.push(current_sum);
477                        }
478                    }
479                    DataType::UInt32 => {
480                        let series = value_col.as_any().downcast_ref::<Series<u32>>().unwrap();
481                        let mut current_sum: Option<u32> = None;
482                        for &idx in indices {
483                            if let Some(val) = series.get(idx) {
484                                current_sum = Some(current_sum.unwrap_or(0).saturating_add(*val));
485                            }
486                        }
487                        if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<u32>>>() {
488                            vec.push(current_sum);
489                        }
490                    }
491                    DataType::Float32 => {
492                        let series = value_col.as_any().downcast_ref::<Series<f32>>().unwrap();
493                        let mut current_sum: Option<f32> = None;
494                        for &idx in indices {
495                            if let Some(val) = series.get(idx) {
496                                if val.is_nan() { continue; }
497                                current_sum = Some(current_sum.unwrap_or(0.0) + *val);
498                            }
499                        }
500                        if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<f32>>>() {
501                            vec.push(current_sum);
502                        }
503                    }
504                    DataType::Float64 => {
505                        let series = value_col.as_any().downcast_ref::<Series<f64>>().unwrap();
506                        let mut current_sum: Option<f64> = None;
507                        for &idx in indices {
508                            if let Some(val) = series.get(idx) {
509                                if val.is_nan() { continue; }
510                                current_sum = Some(current_sum.unwrap_or(0.0) + *val);
511                            }
512                        }
513                        if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<f64>>>() {
514                            vec.push(current_sum);
515                        }
516                    }
517                    _ => unreachable!(),
518                }
519            }
520        }
521
522        let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
523        for (i, key_name) in self.keys.iter().enumerate() {
524            let boxed_any = &key_data_vecs[i];
525            let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
526                 DataType::Int32 => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
527                 DataType::String => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
528                 DataType::Bool => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
529                 _ => unreachable!(),
530            };
531            final_columns.push(final_key_series);
532        }
533        for (j, value_col_name) in value_col_names.iter().enumerate() {
534            let boxed_any = &sum_data_vecs[j];
535            let final_sum_series: Box<dyn SeriesTrait> = match sum_dtypes[j] {
536                 DataType::Int32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
537                 DataType::UInt32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u32>>>().unwrap().clone())),
538                 DataType::Float32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f32>>>().unwrap().clone())),
539                 DataType::Float64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f64>>>().unwrap().clone())),
540                 _ => unreachable!(),
541            };
542            final_columns.push(final_sum_series);
543        }
544
545        DataFrame::new(final_columns)
546    }
547
548    /// 计算每个组中数值列的平均值。
549    ///
550    /// 非数值列(不是分组键的列)将被忽略。
551    /// 组内的 null 值在计算时被忽略(空组或全 null 组的平均值为 null)。
552    ///
553    /// # 返回值
554    /// 
555    /// 返回包含分组键和对应平均值的新 DataFrame,平均值列的类型为 f64
556    pub fn mean(&self) -> AxionResult<DataFrame> {
557        let groups = &self.groups;
558
559        let value_col_names: Vec<String> = self.df.columns_names()
560            .into_iter()
561            .filter(|name| !self.keys.iter().any(|k| k == *name))
562            .filter(|name| {
563                if let Ok(col) = self.df.column(name) {
564                    col.dtype().is_numeric()
565                } else {
566                    false
567                }
568            })
569            .map(|name| name.to_string())
570            .collect();
571
572        if groups.is_empty() {
573            let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
574            for key_name in &self.keys {
575                let original_key_col = self.df.column(key_name)?;
576                let dtype = original_key_col.dtype();
577                let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
578                output_columns.push(empty_key_series);
579            }
580            for value_col_name in &value_col_names {
581                 let empty_mean_series = Series::<f64>::new_empty(value_col_name.clone(), DataType::Float64);
582                 output_columns.push(Box::new(empty_mean_series));
583            }
584            return DataFrame::new(output_columns);
585        }
586
587        let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
588        let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
589        for key_name in &self.keys {
590             let original_key_col = self.df.column(key_name)?;
591             let dtype = original_key_col.dtype();
592             key_dtypes.push(dtype.clone());
593             match dtype {
594                 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
595                 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
596                 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
597                 _ => return Err(AxionError::UnsupportedOperation(format!(
598                     "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
599                 ))),
600             }
601        }
602
603        let mut mean_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(value_col_names.len());
604        for _ in &value_col_names {
605            mean_data_vecs.push(Box::new(Vec::<Option<f64>>::new()));
606        }
607
608        for (key, indices) in groups.iter() {
609            let key_values = key.iter();
610            for (i, group_key_value) in key_values.enumerate() {
611                 match key_dtypes[i] {
612                    DataType::Int32 => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() { if let GroupKeyValue::Int(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
613                    DataType::String => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() { if let GroupKeyValue::Str(val) = group_key_value { vec.push(Some(val.clone())); } else { vec.push(None); } },
614                    DataType::Bool => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() { if let GroupKeyValue::Bool(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
615                    _ => {}
616                }
617            }
618
619            for (j, value_col_name) in value_col_names.iter().enumerate() {
620                let value_col = self.df.column(value_col_name)?;
621                let mut current_sum: f64 = 0.0;
622                let mut current_count: u32 = 0;
623
624                for &idx in indices {
625                    if let Some(value_f64) = value_col.get_as_f64(idx)? {
626                        if !value_f64.is_nan() {
627                            current_sum += value_f64;
628                            current_count += 1;
629                        }
630                    }
631                }
632
633                let mean_value = if current_count > 0 {
634                    Some(current_sum / current_count as f64)
635                } else {
636                    None
637                };
638
639                if let Some(vec) = mean_data_vecs[j].downcast_mut::<Vec<Option<f64>>>() {
640                    vec.push(mean_value);
641                }
642            }
643        }
644
645        let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
646        for (i, key_name) in self.keys.iter().enumerate() {
647            let boxed_any = &key_data_vecs[i];
648            let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
649                 DataType::Int32 => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
650                 DataType::String => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
651                 DataType::Bool => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
652                 _ => unreachable!(),
653            };
654            final_columns.push(final_key_series);
655        }
656        for (j, value_col_name) in value_col_names.iter().enumerate() {
657            let boxed_any = &mean_data_vecs[j];
658            let final_mean_series = Box::new(Series::new_from_options(
659                value_col_name.clone(),
660                boxed_any.downcast_ref::<Vec<Option<f64>>>().unwrap().clone()
661            ));
662            final_columns.push(final_mean_series);
663        }
664
665        DataFrame::new(final_columns)
666    }
667
668    /// 计算每个组中可比较列的最小值。
669    ///
670    /// 非可比较列(如 List)和分组键列将被忽略。
671    /// 计算中会忽略 null 值。
672    /// 结果列的类型将与原始列相同。
673    ///
674    /// # 返回值
675    /// 
676    /// 返回包含分组键和对应最小值的新 DataFrame
677    pub fn min(&self) -> AxionResult<DataFrame> {
678        self.aggregate_min_max(true)
679    }
680
681    /// 计算每个组中可比较列的最大值。
682    ///
683    /// 非可比较列(如 List)和分组键列将被忽略。
684    /// 计算中会忽略 null 值。
685    /// 结果列的类型将与原始列相同。
686    ///
687    /// # 返回值
688    /// 
689    /// 返回包含分组键和对应最大值的新 DataFrame
690    pub fn max(&self) -> AxionResult<DataFrame> {
691        self.aggregate_min_max(false)
692    }
693
694    /// 内部辅助函数,处理 min 和 max 的通用逻辑
695    fn aggregate_min_max(&self, find_min: bool) -> AxionResult<DataFrame> {
696        let groups = &self.groups;
697
698        let value_col_names: Vec<String> = self.df.columns_names()
699            .into_iter()
700            .filter(|name| !self.keys.iter().any(|k| k == *name))
701            .filter(|name| {
702                if let Ok(col) = self.df.column(name) {
703                    matches!(col.dtype(),
704                        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 |
705                        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 |
706                        DataType::Float32 | DataType::Float64 |
707                        DataType::String |
708                        DataType::Bool
709                    )
710                } else {
711                    false
712                }
713            })
714            .map(|name| name.to_string())
715            .collect();
716
717        if groups.is_empty() {
718            let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
719            for key_name in &self.keys {
720                let original_key_col = self.df.column(key_name)?;
721                let dtype = original_key_col.dtype();
722                let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
723                output_columns.push(empty_key_series);
724            }
725            for value_col_name in &value_col_names {
726                let original_value_col = self.df.column(value_col_name)?;
727                let dtype = original_value_col.dtype();
728                let empty_agg_series = create_empty_series_from_dtype(value_col_name.clone(), dtype)?;
729                output_columns.push(empty_agg_series);
730            }
731            return DataFrame::new(output_columns);
732        }
733
734        let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
735        let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
736        for key_name in &self.keys {
737             let original_key_col = self.df.column(key_name)?;
738             let dtype = original_key_col.dtype();
739             key_dtypes.push(dtype.clone());
740             match dtype {
741                 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
742                 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
743                 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
744                 _ => return Err(AxionError::UnsupportedOperation(format!(
745                     "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
746                 ))),
747             }
748        }
749
750        let mut agg_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(value_col_names.len());
751        let mut agg_dtypes: Vec<DataType> = Vec::with_capacity(value_col_names.len());
752        for value_col_name in &value_col_names {
753            let original_value_col = self.df.column(value_col_name)?;
754            let dtype = original_value_col.dtype();
755            agg_dtypes.push(dtype.clone());
756            match dtype {
757                DataType::Int8 => agg_data_vecs.push(Box::new(Vec::<Option<i8>>::new())),
758                DataType::Int16 => agg_data_vecs.push(Box::new(Vec::<Option<i16>>::new())),
759                DataType::Int32 => agg_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
760                DataType::Int64 => agg_data_vecs.push(Box::new(Vec::<Option<i64>>::new())),
761                DataType::UInt8 => agg_data_vecs.push(Box::new(Vec::<Option<u8>>::new())),
762                DataType::UInt16 => agg_data_vecs.push(Box::new(Vec::<Option<u16>>::new())),
763                DataType::UInt32 => agg_data_vecs.push(Box::new(Vec::<Option<u32>>::new())),
764                DataType::UInt64 => agg_data_vecs.push(Box::new(Vec::<Option<u64>>::new())),
765                DataType::Float32 => agg_data_vecs.push(Box::new(Vec::<Option<f32>>::new())),
766                DataType::Float64 => agg_data_vecs.push(Box::new(Vec::<Option<f64>>::new())),
767                DataType::String => agg_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
768                DataType::Bool => agg_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
769                _ => unreachable!("应该只包含之前过滤的可比较类型"),
770            }
771        }
772
773        for (key, indices) in groups.iter() {
774            let key_values = key.iter();
775            for (i, group_key_value) in key_values.enumerate() {
776                 match key_dtypes[i] {
777                    DataType::Int32 => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() { if let GroupKeyValue::Int(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
778                    DataType::String => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() { if let GroupKeyValue::Str(val) = group_key_value { vec.push(Some(val.clone())); } else { vec.push(None); } },
779                    DataType::Bool => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() { if let GroupKeyValue::Bool(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
780                    _ => {}
781                }
782            }
783
784            for (j, value_col_name) in value_col_names.iter().enumerate() {
785                let value_col = self.df.column(value_col_name)?;
786
787                let agg_value = dispatch_min_max!(
788                    value_col,
789                    &agg_dtypes[j],
790                    indices,
791                    find_min
792                )?;
793
794                let boxed_any = &mut agg_data_vecs[j];
795                match agg_dtypes[j] {
796                    DataType::Int8 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i8>>>() { if let AggValue::Int8(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
797                    DataType::Int16 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i16>>>() { if let AggValue::Int16(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
798                    DataType::Int32 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i32>>>() { if let AggValue::Int32(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
799                    DataType::Int64 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i64>>>() { if let AggValue::Int64(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
800                    DataType::UInt8 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u8>>>() { if let AggValue::UInt8(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
801                    DataType::UInt16 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u16>>>() { if let AggValue::UInt16(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
802                    DataType::UInt32 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u32>>>() { if let AggValue::UInt32(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
803                    DataType::UInt64 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u64>>>() { if let AggValue::UInt64(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
804                    DataType::Float32 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<f32>>>() { if let AggValue::Float32(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
805                    DataType::Float64 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<f64>>>() { if let AggValue::Float64(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
806                    DataType::String => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<String>>>() { if let AggValue::String(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
807                    DataType::Bool => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<bool>>>() { if let AggValue::Bool(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
808                    _ => unreachable!(),
809                }
810            }
811        }
812
813        let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
814        for (i, key_name) in self.keys.iter().enumerate() {
815            let boxed_any = &key_data_vecs[i];
816            let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
817                 DataType::Int32 => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
818                 DataType::String => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
819                 DataType::Bool => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
820                 _ => unreachable!(),
821            };
822            final_columns.push(final_key_series);
823        }
824        for (j, value_col_name) in value_col_names.iter().enumerate() {
825            let boxed_any = &agg_data_vecs[j];
826            let final_agg_series: Box<dyn SeriesTrait> = match agg_dtypes[j] {
827                 DataType::Int8 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i8>>>().unwrap().clone())),
828                 DataType::Int16 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i16>>>().unwrap().clone())),
829                 DataType::Int32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
830                 DataType::Int64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i64>>>().unwrap().clone())),
831                 DataType::UInt8 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u8>>>().unwrap().clone())),
832                 DataType::UInt16 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u16>>>().unwrap().clone())),
833                 DataType::UInt32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u32>>>().unwrap().clone())),
834                 DataType::UInt64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u64>>>().unwrap().clone())),
835                 DataType::Float32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f32>>>().unwrap().clone())),
836                 DataType::Float64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f64>>>().unwrap().clone())),
837                 DataType::String => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
838                 DataType::Bool => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
839                 _ => unreachable!(),
840            };
841            final_columns.push(final_agg_series);
842        }
843
844        DataFrame::new(final_columns)
845    }
846}