1use crate::dataframe::DataFrame;
2use crate::error::{AxionError, AxionResult};
3use crate::series::{SeriesTrait, Series};
4use crate::dtype::{DataType, DataTypeTrait};
5use std::collections::HashMap;
6use std::any::Any;
7use num_traits::Float;
8use std::fmt::Debug;
9
10use super::types::{GroupKeyValue, GroupKey};
11
12fn create_empty_series_from_dtype(name: String, dtype: DataType) -> AxionResult<Box<dyn SeriesTrait>> {
14 match dtype {
15 DataType::Int8 => Ok(Box::new(Series::<i8>::new_empty(name, dtype))),
16 DataType::Int16 => Ok(Box::new(Series::<i16>::new_empty(name, dtype))),
17 DataType::Int32 => Ok(Box::new(Series::<i32>::new_empty(name, dtype))),
18 DataType::Int64 => Ok(Box::new(Series::<i64>::new_empty(name, dtype))),
19 DataType::UInt8 => Ok(Box::new(Series::<u8>::new_empty(name, dtype))),
20 DataType::UInt16 => Ok(Box::new(Series::<u16>::new_empty(name, dtype))),
21 DataType::UInt32 => Ok(Box::new(Series::<u32>::new_empty(name, dtype))),
22 DataType::UInt64 => Ok(Box::new(Series::<u64>::new_empty(name, dtype))),
23 DataType::Float32 => Ok(Box::new(Series::<f32>::new_empty(name, dtype))),
24 DataType::Float64 => Ok(Box::new(Series::<f64>::new_empty(name, dtype))),
25 DataType::String => Ok(Box::new(Series::<String>::new_empty(name, dtype))),
26 DataType::Bool => Ok(Box::new(Series::<bool>::new_empty(name, dtype))),
27 _ => Err(AxionError::UnsupportedOperation(format!("无法为数据类型 {:?} 创建空 Series", dtype))),
28 }
29}
30
/// Result of a min/max aggregation, tagged with the element type it was computed from.
#[derive(Debug, Clone, PartialEq)]
enum AggValue {
    Int8(Option<i8>),
    Int16(Option<i16>),
    Int32(Option<i32>),
    Int64(Option<i64>),
    UInt8(Option<u8>),
    UInt16(Option<u16>),
    UInt32(Option<u32>),
    UInt64(Option<u64>),
    Float32(Option<f32>),
    Float64(Option<f64>),
    String(Option<String>),
    Bool(Option<bool>),
    /// No value at all; see [`AggValue::from_option`] for when this is produced.
    None,
}

impl AggValue {
    /// Wraps `opt_val` in the variant that corresponds to the concrete type of `T`.
    ///
    /// A `None` input always becomes `AggValue::None`, whatever `T` is. A `Some` value
    /// whose type is not one of the twelve supported element types also becomes
    /// `AggValue::None`, after a warning is written to standard error.
    fn from_option<T: 'static + Clone + Debug>(opt_val: Option<T>) -> Self {
        let inner = match opt_val {
            Some(inner) => inner,
            None => return AggValue::None,
        };
        let erased: &dyn Any = &inner;

        // Tries each `type => variant` pair in order and returns on the first match.
        macro_rules! wrap_first_match {
            ($($ty:ty => $variant:ident),+ $(,)?) => {
                $(
                    if let Some(hit) = erased.downcast_ref::<$ty>() {
                        return AggValue::$variant(Some(hit.clone()));
                    }
                )+
            };
        }

        wrap_first_match!(
            i8 => Int8,
            i16 => Int16,
            i32 => Int32,
            i64 => Int64,
            u8 => UInt8,
            u16 => UInt16,
            u32 => UInt32,
            u64 => UInt64,
            f32 => Float32,
            f64 => Float64,
            String => String,
            bool => Bool,
        );

        eprintln!("警告: AggValue::from_option 遇到了未预期的类型: {:?}", std::any::type_name::<T>());
        AggValue::None
    }
}
76
77fn calculate_min_max<T>(
79 series_trait: &dyn SeriesTrait,
80 indices: &[usize],
81 find_min: bool,
82) -> AxionResult<AggValue>
83where
84 T: DataTypeTrait + PartialOrd + Clone + Debug + 'static,
85{
86 let series = series_trait.as_any().downcast_ref::<Series<T>>()
87 .ok_or_else(|| AxionError::InternalError(format!("无法将 Series 向下转型为预期类型 {:?}", std::any::type_name::<T>())))?;
88
89 let mut current_agg: Option<T> = None;
90 for &idx in indices {
91 if let Some(val_ref) = series.get(idx) {
92 match current_agg.as_ref() {
93 Some(agg_val_ref) => {
94 if find_min {
95 if val_ref < agg_val_ref {
96 current_agg = Some(val_ref.clone());
97 }
98 } else if val_ref > agg_val_ref {
99 current_agg = Some(val_ref.clone());
100 }
101 }
102 None => {
103 current_agg = Some(val_ref.clone());
104 }
105 }
106 }
107 }
108 Ok(AggValue::from_option(current_agg))
109}
110
111fn calculate_min_max_float<T>(
113 series_trait: &dyn SeriesTrait,
114 indices: &[usize],
115 find_min: bool,
116) -> AxionResult<AggValue>
117where
118 T: DataTypeTrait + Float + Clone + Debug + 'static,
119{
120 let series = series_trait.as_any().downcast_ref::<Series<T>>()
121 .ok_or_else(|| AxionError::InternalError(format!("无法将 Series 向下转型为预期浮点类型 {:?}", std::any::type_name::<T>())))?;
122
123 let mut current_agg: Option<T> = None;
124 for &idx in indices {
125 if let Some(val_ref) = series.get(idx) {
126 if val_ref.is_nan() {
127 continue;
128 }
129 match current_agg.as_ref() {
130 Some(agg_val_ref) => {
131 if find_min {
132 if val_ref < agg_val_ref {
133 current_agg = Some(*val_ref);
134 }
135 } else if val_ref > agg_val_ref {
136 current_agg = Some(*val_ref);
137 }
138 }
139 None => {
140 current_agg = Some(*val_ref);
141 }
142 }
143 }
144 }
145 Ok(AggValue::from_option(current_agg))
146}
147
/// Routes a min/max aggregation to the scan helper instantiated for the column's
/// element type.
///
/// Arguments, in order: the type-erased column, its data type, the row indices of the
/// group, and `true` for minimum / `false` for maximum. Float columns go through
/// `calculate_min_max_float`; every other supported type goes through
/// `calculate_min_max`. The expansion evaluates to an `AxionResult<AggValue>`, which is
/// an `AxionError::UnsupportedOperation` error for any other data type.
macro_rules! dispatch_min_max {
    ($column:expr, $data_type:expr, $rows:expr, $want_min:expr) => {
        match $data_type {
            DataType::Int8 => calculate_min_max::<i8>($column, $rows, $want_min),
            DataType::Int16 => calculate_min_max::<i16>($column, $rows, $want_min),
            DataType::Int32 => calculate_min_max::<i32>($column, $rows, $want_min),
            DataType::Int64 => calculate_min_max::<i64>($column, $rows, $want_min),
            DataType::UInt8 => calculate_min_max::<u8>($column, $rows, $want_min),
            DataType::UInt16 => calculate_min_max::<u16>($column, $rows, $want_min),
            DataType::UInt32 => calculate_min_max::<u32>($column, $rows, $want_min),
            DataType::UInt64 => calculate_min_max::<u64>($column, $rows, $want_min),
            DataType::Float32 => calculate_min_max_float::<f32>($column, $rows, $want_min),
            DataType::Float64 => calculate_min_max_float::<f64>($column, $rows, $want_min),
            DataType::String => calculate_min_max::<String>($column, $rows, $want_min),
            DataType::Bool => calculate_min_max::<bool>($column, $rows, $want_min),
            // Bound by reference so the data type expression is neither moved nor re-evaluated.
            ref unsupported => Err(AxionError::UnsupportedOperation(format!(
                "数据类型 {:?} 不支持 Min/Max 操作",
                unsupported
            ))),
        }
    };
}
168
169#[derive(Debug)]
183pub struct GroupBy<'a> {
184 df: &'a DataFrame,
186 keys: Vec<String>,
188 groups: HashMap<GroupKey, Vec<usize>>,
190}
191
192impl<'a> GroupBy<'a> {
193 pub(crate) fn new(df: &'a DataFrame, keys: Vec<String>) -> AxionResult<Self> {
211 let mut key_cols: Vec<&dyn SeriesTrait> = Vec::with_capacity(keys.len());
212 for key_name in &keys {
213 let col = df.column(key_name)?;
214 key_cols.push(col);
215 match col.dtype() {
216 DataType::Int32 | DataType::String | DataType::Bool => {},
217 unsupported_dtype => {
218 return Err(AxionError::UnsupportedOperation(format!(
219 "列 '{}' 的数据类型 {:?} 不支持分组操作",
220 key_name, unsupported_dtype
221 )));
222 }
223 }
224 }
225
226 let mut groups: HashMap<GroupKey, Vec<usize>> = HashMap::new();
227 for row_idx in 0..df.height() {
228 let mut current_key: GroupKey = Vec::with_capacity(keys.len());
229 let mut has_null = false;
230
231 for key_col in &key_cols {
232 let key_value = match key_col.dtype() {
233 DataType::Int32 => {
234 let series = key_col.as_any().downcast_ref::<Series<i32>>().unwrap();
235 match series.get(row_idx) {
236 Some(v) => GroupKeyValue::Int(*v),
237 None => { has_null = true; break; }
238 }
239 }
240 DataType::String => {
241 let series = key_col.as_any().downcast_ref::<Series<String>>().unwrap();
242 match series.get(row_idx) {
243 Some(v) => GroupKeyValue::Str(v.clone()),
244 None => { has_null = true; break; }
245 }
246 }
247 DataType::Bool => {
248 let series = key_col.as_any().downcast_ref::<Series<bool>>().unwrap();
249 match series.get(row_idx) {
250 Some(v) => GroupKeyValue::Bool(*v),
251 None => { has_null = true; break; }
252 }
253 }
254 _ => unreachable!("类型检查后遇到不支持的分组类型"),
255 };
256 current_key.push(key_value);
257 }
258
259 if !has_null {
260 groups.entry(current_key).or_default().push(row_idx);
261 }
262 }
263
264 Ok(Self { df, keys, groups })
265 }
266
267 pub fn count(&self) -> AxionResult<DataFrame> {
280 let groups = &self.groups;
281
282 if groups.is_empty() {
283 let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + 1);
284 for key_name in &self.keys {
285 let original_key_col = self.df.column(key_name)?;
286 let dtype = original_key_col.dtype();
287 let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
288 output_columns.push(empty_key_series);
289 }
290 let empty_count_series = Series::<u32>::new_empty("count".into(), DataType::UInt32);
291 output_columns.push(Box::new(empty_count_series));
292 return DataFrame::new(output_columns);
293 }
294
295 let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
296 let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
297 for key_name in &self.keys {
298 let original_key_col = self.df.column(key_name)?;
299 let dtype = original_key_col.dtype();
300 key_dtypes.push(dtype.clone());
301 match dtype {
302 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
303 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
304 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
305 DataType::UInt32 => key_data_vecs.push(Box::new(Vec::<Option<u32>>::new())),
306 _ => return Err(AxionError::UnsupportedOperation(format!(
307 "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
308 ))),
309 }
310 }
311 let mut count_data_vec = Vec::<u32>::with_capacity(groups.len());
312
313 for (key, indices) in groups.iter() {
314 let key_values = key.iter();
315
316 for (i, group_key_value) in key_values.enumerate() {
317 match key_dtypes[i] {
318 DataType::Int32 => {
319 if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() {
320 if let GroupKeyValue::Int(val) = group_key_value {
321 vec.push(Some(*val));
322 } else { vec.push(None); }
323 }
324 }
325 DataType::String => {
326 if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() {
327 if let GroupKeyValue::Str(val) = group_key_value {
328 vec.push(Some(val.clone()));
329 } else { vec.push(None); }
330 }
331 }
332 DataType::Bool => {
333 if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() {
334 if let GroupKeyValue::Bool(val) = group_key_value {
335 vec.push(Some(*val));
336 } else { vec.push(None); }
337 }
338 }
339 DataType::UInt32 => {
340 if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<u32>>>() {
341 vec.push(None);
342 }
343 }
344 _ => {}
345 }
346 }
347 count_data_vec.push(indices.len() as u32);
348 }
349
350 let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + 1);
351 for (i, key_name) in self.keys.iter().enumerate() {
352 let boxed_any = &key_data_vecs[i];
353
354 let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
355 DataType::Int32 => {
356 let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap();
357 Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
358 }
359 DataType::String => {
360 let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap();
361 Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
362 }
363 DataType::Bool => {
364 let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap();
365 Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
366 }
367 DataType::UInt32 => {
368 let data_vec_ref = boxed_any.downcast_ref::<Vec<Option<u32>>>().unwrap();
369 Box::new(Series::new_from_options(key_name.clone(), data_vec_ref.clone()))
370 }
371 _ => unreachable!(),
372 };
373 final_columns.push(final_key_series);
374 }
375 final_columns.push(Box::new(Series::new("count".into(), count_data_vec)));
376
377 DataFrame::new(final_columns)
378 }
379
380 pub fn sum(&self) -> AxionResult<DataFrame> {
389 let groups = &self.groups;
390
391 let value_col_names: Vec<String> = self.df.columns_names()
392 .into_iter()
393 .filter(|name| !self.keys.iter().any(|k| k == *name))
394 .filter(|name| {
395 if let Ok(col) = self.df.column(name) {
396 matches!(col.dtype(), DataType::Int32 | DataType::UInt32 | DataType::Float32 | DataType::Float64)
397 } else {
398 false
399 }
400 })
401 .map(|name| name.to_string())
402 .collect();
403
404 if groups.is_empty() {
405 let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
406 for key_name in &self.keys {
407 let original_key_col = self.df.column(key_name)?;
408 let dtype = original_key_col.dtype();
409 let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
410 output_columns.push(empty_key_series);
411 }
412 for value_col_name in &value_col_names {
413 let original_value_col = self.df.column(value_col_name)?;
414 let dtype = original_value_col.dtype();
415 let empty_sum_series = create_empty_series_from_dtype(value_col_name.clone(), dtype)?;
416 output_columns.push(empty_sum_series);
417 }
418 return DataFrame::new(output_columns);
419 }
420
421 let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
422 let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
423 for key_name in &self.keys {
424 let original_key_col = self.df.column(key_name)?;
425 let dtype = original_key_col.dtype();
426 key_dtypes.push(dtype.clone());
427 match dtype {
428 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
429 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
430 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
431 _ => return Err(AxionError::UnsupportedOperation(format!(
432 "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
433 ))),
434 }
435 }
436
437 let mut sum_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(value_col_names.len());
438 let mut sum_dtypes: Vec<DataType> = Vec::with_capacity(value_col_names.len());
439 for value_col_name in &value_col_names {
440 let original_value_col = self.df.column(value_col_name)?;
441 let dtype = original_value_col.dtype();
442 sum_dtypes.push(dtype.clone());
443 match dtype {
444 DataType::Int32 => sum_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
445 DataType::UInt32 => sum_data_vecs.push(Box::new(Vec::<Option<u32>>::new())),
446 DataType::Float32 => sum_data_vecs.push(Box::new(Vec::<Option<f32>>::new())),
447 DataType::Float64 => sum_data_vecs.push(Box::new(Vec::<Option<f64>>::new())),
448 _ => unreachable!(),
449 }
450 }
451
452 for (key, indices) in groups.iter() {
453 let key_values = key.iter();
454 for (i, group_key_value) in key_values.enumerate() {
455 match key_dtypes[i] {
456 DataType::Int32 => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() { if let GroupKeyValue::Int(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
457 DataType::String => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() { if let GroupKeyValue::Str(val) = group_key_value { vec.push(Some(val.clone())); } else { vec.push(None); } },
458 DataType::Bool => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() { if let GroupKeyValue::Bool(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
459 _ => {}
460 }
461 }
462
463 for (j, value_col_name) in value_col_names.iter().enumerate() {
464 let value_col = self.df.column(value_col_name)?;
465
466 match sum_dtypes[j] {
467 DataType::Int32 => {
468 let series = value_col.as_any().downcast_ref::<Series<i32>>().unwrap();
469 let mut current_sum: Option<i32> = None;
470 for &idx in indices {
471 if let Some(val) = series.get(idx) {
472 current_sum = Some(current_sum.unwrap_or(0).saturating_add(*val));
473 }
474 }
475 if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<i32>>>() {
476 vec.push(current_sum);
477 }
478 }
479 DataType::UInt32 => {
480 let series = value_col.as_any().downcast_ref::<Series<u32>>().unwrap();
481 let mut current_sum: Option<u32> = None;
482 for &idx in indices {
483 if let Some(val) = series.get(idx) {
484 current_sum = Some(current_sum.unwrap_or(0).saturating_add(*val));
485 }
486 }
487 if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<u32>>>() {
488 vec.push(current_sum);
489 }
490 }
491 DataType::Float32 => {
492 let series = value_col.as_any().downcast_ref::<Series<f32>>().unwrap();
493 let mut current_sum: Option<f32> = None;
494 for &idx in indices {
495 if let Some(val) = series.get(idx) {
496 if val.is_nan() { continue; }
497 current_sum = Some(current_sum.unwrap_or(0.0) + *val);
498 }
499 }
500 if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<f32>>>() {
501 vec.push(current_sum);
502 }
503 }
504 DataType::Float64 => {
505 let series = value_col.as_any().downcast_ref::<Series<f64>>().unwrap();
506 let mut current_sum: Option<f64> = None;
507 for &idx in indices {
508 if let Some(val) = series.get(idx) {
509 if val.is_nan() { continue; }
510 current_sum = Some(current_sum.unwrap_or(0.0) + *val);
511 }
512 }
513 if let Some(vec) = sum_data_vecs[j].downcast_mut::<Vec<Option<f64>>>() {
514 vec.push(current_sum);
515 }
516 }
517 _ => unreachable!(),
518 }
519 }
520 }
521
522 let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
523 for (i, key_name) in self.keys.iter().enumerate() {
524 let boxed_any = &key_data_vecs[i];
525 let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
526 DataType::Int32 => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
527 DataType::String => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
528 DataType::Bool => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
529 _ => unreachable!(),
530 };
531 final_columns.push(final_key_series);
532 }
533 for (j, value_col_name) in value_col_names.iter().enumerate() {
534 let boxed_any = &sum_data_vecs[j];
535 let final_sum_series: Box<dyn SeriesTrait> = match sum_dtypes[j] {
536 DataType::Int32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
537 DataType::UInt32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u32>>>().unwrap().clone())),
538 DataType::Float32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f32>>>().unwrap().clone())),
539 DataType::Float64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f64>>>().unwrap().clone())),
540 _ => unreachable!(),
541 };
542 final_columns.push(final_sum_series);
543 }
544
545 DataFrame::new(final_columns)
546 }
547
548 pub fn mean(&self) -> AxionResult<DataFrame> {
557 let groups = &self.groups;
558
559 let value_col_names: Vec<String> = self.df.columns_names()
560 .into_iter()
561 .filter(|name| !self.keys.iter().any(|k| k == *name))
562 .filter(|name| {
563 if let Ok(col) = self.df.column(name) {
564 col.dtype().is_numeric()
565 } else {
566 false
567 }
568 })
569 .map(|name| name.to_string())
570 .collect();
571
572 if groups.is_empty() {
573 let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
574 for key_name in &self.keys {
575 let original_key_col = self.df.column(key_name)?;
576 let dtype = original_key_col.dtype();
577 let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
578 output_columns.push(empty_key_series);
579 }
580 for value_col_name in &value_col_names {
581 let empty_mean_series = Series::<f64>::new_empty(value_col_name.clone(), DataType::Float64);
582 output_columns.push(Box::new(empty_mean_series));
583 }
584 return DataFrame::new(output_columns);
585 }
586
587 let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
588 let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
589 for key_name in &self.keys {
590 let original_key_col = self.df.column(key_name)?;
591 let dtype = original_key_col.dtype();
592 key_dtypes.push(dtype.clone());
593 match dtype {
594 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
595 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
596 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
597 _ => return Err(AxionError::UnsupportedOperation(format!(
598 "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
599 ))),
600 }
601 }
602
603 let mut mean_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(value_col_names.len());
604 for _ in &value_col_names {
605 mean_data_vecs.push(Box::new(Vec::<Option<f64>>::new()));
606 }
607
608 for (key, indices) in groups.iter() {
609 let key_values = key.iter();
610 for (i, group_key_value) in key_values.enumerate() {
611 match key_dtypes[i] {
612 DataType::Int32 => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() { if let GroupKeyValue::Int(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
613 DataType::String => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() { if let GroupKeyValue::Str(val) = group_key_value { vec.push(Some(val.clone())); } else { vec.push(None); } },
614 DataType::Bool => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() { if let GroupKeyValue::Bool(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
615 _ => {}
616 }
617 }
618
619 for (j, value_col_name) in value_col_names.iter().enumerate() {
620 let value_col = self.df.column(value_col_name)?;
621 let mut current_sum: f64 = 0.0;
622 let mut current_count: u32 = 0;
623
624 for &idx in indices {
625 if let Some(value_f64) = value_col.get_as_f64(idx)? {
626 if !value_f64.is_nan() {
627 current_sum += value_f64;
628 current_count += 1;
629 }
630 }
631 }
632
633 let mean_value = if current_count > 0 {
634 Some(current_sum / current_count as f64)
635 } else {
636 None
637 };
638
639 if let Some(vec) = mean_data_vecs[j].downcast_mut::<Vec<Option<f64>>>() {
640 vec.push(mean_value);
641 }
642 }
643 }
644
645 let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
646 for (i, key_name) in self.keys.iter().enumerate() {
647 let boxed_any = &key_data_vecs[i];
648 let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
649 DataType::Int32 => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
650 DataType::String => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
651 DataType::Bool => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
652 _ => unreachable!(),
653 };
654 final_columns.push(final_key_series);
655 }
656 for (j, value_col_name) in value_col_names.iter().enumerate() {
657 let boxed_any = &mean_data_vecs[j];
658 let final_mean_series = Box::new(Series::new_from_options(
659 value_col_name.clone(),
660 boxed_any.downcast_ref::<Vec<Option<f64>>>().unwrap().clone()
661 ));
662 final_columns.push(final_mean_series);
663 }
664
665 DataFrame::new(final_columns)
666 }
667
668 pub fn min(&self) -> AxionResult<DataFrame> {
678 self.aggregate_min_max(true)
679 }
680
681 pub fn max(&self) -> AxionResult<DataFrame> {
691 self.aggregate_min_max(false)
692 }
693
694 fn aggregate_min_max(&self, find_min: bool) -> AxionResult<DataFrame> {
696 let groups = &self.groups;
697
698 let value_col_names: Vec<String> = self.df.columns_names()
699 .into_iter()
700 .filter(|name| !self.keys.iter().any(|k| k == *name))
701 .filter(|name| {
702 if let Ok(col) = self.df.column(name) {
703 matches!(col.dtype(),
704 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 |
705 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 |
706 DataType::Float32 | DataType::Float64 |
707 DataType::String |
708 DataType::Bool
709 )
710 } else {
711 false
712 }
713 })
714 .map(|name| name.to_string())
715 .collect();
716
717 if groups.is_empty() {
718 let mut output_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
719 for key_name in &self.keys {
720 let original_key_col = self.df.column(key_name)?;
721 let dtype = original_key_col.dtype();
722 let empty_key_series = create_empty_series_from_dtype(key_name.clone(), dtype)?;
723 output_columns.push(empty_key_series);
724 }
725 for value_col_name in &value_col_names {
726 let original_value_col = self.df.column(value_col_name)?;
727 let dtype = original_value_col.dtype();
728 let empty_agg_series = create_empty_series_from_dtype(value_col_name.clone(), dtype)?;
729 output_columns.push(empty_agg_series);
730 }
731 return DataFrame::new(output_columns);
732 }
733
734 let mut key_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(self.keys.len());
735 let mut key_dtypes: Vec<DataType> = Vec::with_capacity(self.keys.len());
736 for key_name in &self.keys {
737 let original_key_col = self.df.column(key_name)?;
738 let dtype = original_key_col.dtype();
739 key_dtypes.push(dtype.clone());
740 match dtype {
741 DataType::Int32 => key_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
742 DataType::String => key_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
743 DataType::Bool => key_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
744 _ => return Err(AxionError::UnsupportedOperation(format!(
745 "列 '{}' 的数据类型 {:?} 不支持分组操作", key_name, dtype
746 ))),
747 }
748 }
749
750 let mut agg_data_vecs: Vec<Box<dyn std::any::Any>> = Vec::with_capacity(value_col_names.len());
751 let mut agg_dtypes: Vec<DataType> = Vec::with_capacity(value_col_names.len());
752 for value_col_name in &value_col_names {
753 let original_value_col = self.df.column(value_col_name)?;
754 let dtype = original_value_col.dtype();
755 agg_dtypes.push(dtype.clone());
756 match dtype {
757 DataType::Int8 => agg_data_vecs.push(Box::new(Vec::<Option<i8>>::new())),
758 DataType::Int16 => agg_data_vecs.push(Box::new(Vec::<Option<i16>>::new())),
759 DataType::Int32 => agg_data_vecs.push(Box::new(Vec::<Option<i32>>::new())),
760 DataType::Int64 => agg_data_vecs.push(Box::new(Vec::<Option<i64>>::new())),
761 DataType::UInt8 => agg_data_vecs.push(Box::new(Vec::<Option<u8>>::new())),
762 DataType::UInt16 => agg_data_vecs.push(Box::new(Vec::<Option<u16>>::new())),
763 DataType::UInt32 => agg_data_vecs.push(Box::new(Vec::<Option<u32>>::new())),
764 DataType::UInt64 => agg_data_vecs.push(Box::new(Vec::<Option<u64>>::new())),
765 DataType::Float32 => agg_data_vecs.push(Box::new(Vec::<Option<f32>>::new())),
766 DataType::Float64 => agg_data_vecs.push(Box::new(Vec::<Option<f64>>::new())),
767 DataType::String => agg_data_vecs.push(Box::new(Vec::<Option<String>>::new())),
768 DataType::Bool => agg_data_vecs.push(Box::new(Vec::<Option<bool>>::new())),
769 _ => unreachable!("应该只包含之前过滤的可比较类型"),
770 }
771 }
772
773 for (key, indices) in groups.iter() {
774 let key_values = key.iter();
775 for (i, group_key_value) in key_values.enumerate() {
776 match key_dtypes[i] {
777 DataType::Int32 => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<i32>>>() { if let GroupKeyValue::Int(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
778 DataType::String => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<String>>>() { if let GroupKeyValue::Str(val) = group_key_value { vec.push(Some(val.clone())); } else { vec.push(None); } },
779 DataType::Bool => if let Some(vec) = key_data_vecs[i].downcast_mut::<Vec<Option<bool>>>() { if let GroupKeyValue::Bool(val) = group_key_value { vec.push(Some(*val)); } else { vec.push(None); } },
780 _ => {}
781 }
782 }
783
784 for (j, value_col_name) in value_col_names.iter().enumerate() {
785 let value_col = self.df.column(value_col_name)?;
786
787 let agg_value = dispatch_min_max!(
788 value_col,
789 &agg_dtypes[j],
790 indices,
791 find_min
792 )?;
793
794 let boxed_any = &mut agg_data_vecs[j];
795 match agg_dtypes[j] {
796 DataType::Int8 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i8>>>() { if let AggValue::Int8(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
797 DataType::Int16 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i16>>>() { if let AggValue::Int16(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
798 DataType::Int32 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i32>>>() { if let AggValue::Int32(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
799 DataType::Int64 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<i64>>>() { if let AggValue::Int64(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
800 DataType::UInt8 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u8>>>() { if let AggValue::UInt8(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
801 DataType::UInt16 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u16>>>() { if let AggValue::UInt16(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
802 DataType::UInt32 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u32>>>() { if let AggValue::UInt32(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
803 DataType::UInt64 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<u64>>>() { if let AggValue::UInt64(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
804 DataType::Float32 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<f32>>>() { if let AggValue::Float32(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
805 DataType::Float64 => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<f64>>>() { if let AggValue::Float64(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
806 DataType::String => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<String>>>() { if let AggValue::String(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
807 DataType::Bool => if let Some(vec) = boxed_any.downcast_mut::<Vec<Option<bool>>>() { if let AggValue::Bool(opt_val) = agg_value { vec.push(opt_val); } else { vec.push(None); } },
808 _ => unreachable!(),
809 }
810 }
811 }
812
813 let mut final_columns: Vec<Box<dyn SeriesTrait>> = Vec::with_capacity(self.keys.len() + value_col_names.len());
814 for (i, key_name) in self.keys.iter().enumerate() {
815 let boxed_any = &key_data_vecs[i];
816 let final_key_series: Box<dyn SeriesTrait> = match key_dtypes[i] {
817 DataType::Int32 => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
818 DataType::String => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
819 DataType::Bool => Box::new(Series::new_from_options(key_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
820 _ => unreachable!(),
821 };
822 final_columns.push(final_key_series);
823 }
824 for (j, value_col_name) in value_col_names.iter().enumerate() {
825 let boxed_any = &agg_data_vecs[j];
826 let final_agg_series: Box<dyn SeriesTrait> = match agg_dtypes[j] {
827 DataType::Int8 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i8>>>().unwrap().clone())),
828 DataType::Int16 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i16>>>().unwrap().clone())),
829 DataType::Int32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i32>>>().unwrap().clone())),
830 DataType::Int64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<i64>>>().unwrap().clone())),
831 DataType::UInt8 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u8>>>().unwrap().clone())),
832 DataType::UInt16 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u16>>>().unwrap().clone())),
833 DataType::UInt32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u32>>>().unwrap().clone())),
834 DataType::UInt64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<u64>>>().unwrap().clone())),
835 DataType::Float32 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f32>>>().unwrap().clone())),
836 DataType::Float64 => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<f64>>>().unwrap().clone())),
837 DataType::String => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<String>>>().unwrap().clone())),
838 DataType::Bool => Box::new(Series::new_from_options(value_col_name.clone(), boxed_any.downcast_ref::<Vec<Option<bool>>>().unwrap().clone())),
839 _ => unreachable!(),
840 };
841 final_columns.push(final_agg_series);
842 }
843
844 DataFrame::new(final_columns)
845 }
846}