1use crate::column::{Column, ColumnTrait};
7use crate::error::Result;
8use crate::optimized::split_dataframe::OptimizedDataFrame;
9use crate::stats::{
10 self, AnovaResult, ChiSquareResult, DescriptiveStats, LinearRegressionResult,
11 MannWhitneyResult, TTestResult,
12};
13use std::collections::HashMap;
14
15#[derive(Debug, Clone)]
17pub enum StatResult {
18 Descriptive(DescriptiveStats),
20 TTest(TTestResult),
22 Anova(AnovaResult),
24 MannWhitneyU(MannWhitneyResult),
26 ChiSquare(ChiSquareResult),
28 LinearRegression(LinearRegressionResult),
30}
31
/// Summary statistics for a single numeric column.
///
/// The same figures are exposed twice: keyed by statistic name for direct
/// lookup, and as an ordered list of `(name, value)` pairs.
#[derive(Clone, Debug)]
pub struct StatDescribe {
    /// Statistic name (for example `"mean"` or `"25%"`) mapped to its value.
    pub stats: HashMap<String, f64>,
    /// The same statistics as ordered `(name, value)` pairs.
    pub stats_list: Vec<(String, f64)>,
}
40
41impl OptimizedDataFrame {
43 pub fn describe(&self, column_name: &str) -> Result<StatDescribe> {
51 let col = self.column(column_name)?;
52
53 if let Some(float_col) = col.as_float64() {
54 let values: Vec<f64> = (0..self.row_count())
56 .filter_map(|i| float_col.get(i).ok().flatten())
57 .collect();
58
59 let stats = stats::describe(&values)?;
61
62 let mut result = HashMap::new();
64 result.insert("count".to_string(), stats.count as f64);
65 result.insert("mean".to_string(), stats.mean);
66 result.insert("std".to_string(), stats.std);
67 result.insert("min".to_string(), stats.min);
68 result.insert("25%".to_string(), stats.q1);
69 result.insert("50%".to_string(), stats.median);
70 result.insert("75%".to_string(), stats.q3);
71 result.insert("max".to_string(), stats.max);
72
73 let stats_list = vec![
75 ("count".to_string(), stats.count as f64),
76 ("mean".to_string(), stats.mean),
77 ("std".to_string(), stats.std),
78 ("min".to_string(), stats.min),
79 ("25%".to_string(), stats.q1),
80 ("50%".to_string(), stats.median),
81 ("75%".to_string(), stats.q3),
82 ("max".to_string(), stats.max),
83 ];
84
85 let mut result = HashMap::new();
86 result.insert("count".to_string(), stats.count as f64);
87 result.insert("mean".to_string(), stats.mean);
88 result.insert("std".to_string(), stats.std);
89 result.insert("min".to_string(), stats.min);
90 result.insert("25%".to_string(), stats.q1);
91 result.insert("50%".to_string(), stats.median);
92 result.insert("75%".to_string(), stats.q3);
93 result.insert("max".to_string(), stats.max);
94
95 Ok(StatDescribe {
96 stats: result,
97 stats_list,
98 })
99 } else if let Some(int_col) = col.as_int64() {
100 let values: Vec<f64> = (0..self.row_count())
102 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
103 .collect();
104
105 let stats = stats::describe(&values)?;
107
108 let mut result = HashMap::new();
110 result.insert("count".to_string(), stats.count as f64);
111 result.insert("mean".to_string(), stats.mean);
112 result.insert("std".to_string(), stats.std);
113 result.insert("min".to_string(), stats.min);
114 result.insert("25%".to_string(), stats.q1);
115 result.insert("50%".to_string(), stats.median);
116 result.insert("75%".to_string(), stats.q3);
117 result.insert("max".to_string(), stats.max);
118
119 let stats_list = vec![
121 ("count".to_string(), stats.count as f64),
122 ("mean".to_string(), stats.mean),
123 ("std".to_string(), stats.std),
124 ("min".to_string(), stats.min),
125 ("25%".to_string(), stats.q1),
126 ("50%".to_string(), stats.median),
127 ("75%".to_string(), stats.q3),
128 ("max".to_string(), stats.max),
129 ];
130
131 let mut result = HashMap::new();
132 result.insert("count".to_string(), stats.count as f64);
133 result.insert("mean".to_string(), stats.mean);
134 result.insert("std".to_string(), stats.std);
135 result.insert("min".to_string(), stats.min);
136 result.insert("25%".to_string(), stats.q1);
137 result.insert("50%".to_string(), stats.median);
138 result.insert("75%".to_string(), stats.q3);
139 result.insert("max".to_string(), stats.max);
140
141 Ok(StatDescribe {
142 stats: result,
143 stats_list,
144 })
145 } else {
146 Err(crate::error::Error::Type(format!(
147 "Column '{}' is not a numeric type",
148 column_name
149 )))
150 }
151 }
152
153 pub fn describe_all(&self) -> Result<HashMap<String, StatDescribe>> {
158 let mut results = HashMap::new();
159
160 for col_name in self.column_names() {
161 let col = self.column(col_name)?;
163 if col.as_float64().is_some() || col.as_int64().is_some() {
164 if let Ok(desc) = self.describe(col_name) {
165 results.insert(col_name.to_string(), desc);
166 }
167 }
168 }
169
170 Ok(results)
171 }
172
173 pub fn ttest(
184 &self,
185 col1: &str,
186 col2: &str,
187 alpha: Option<f64>,
188 equal_var: Option<bool>,
189 ) -> Result<TTestResult> {
190 let alpha = alpha.unwrap_or(0.05);
191 let equal_var = equal_var.unwrap_or(true);
192
193 let column1 = self.column(col1)?;
195 let column2 = self.column(col2)?;
196
197 let values1: Vec<f64> = match column1 {
199 col if col.as_float64().is_some() => {
200 let float_col = col.as_float64().ok_or_else(|| {
201 crate::error::Error::TypeMismatch("column type check failed for Float64".into())
202 })?;
203 (0..self.row_count())
204 .filter_map(|i| float_col.get(i).ok().flatten())
205 .collect()
206 }
207 col if col.as_int64().is_some() => {
208 let int_col = col.as_int64().ok_or_else(|| {
209 crate::error::Error::TypeMismatch("column type check failed for Int64".into())
210 })?;
211 (0..self.row_count())
212 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
213 .collect()
214 }
215 _ => {
216 return Err(crate::error::Error::Type(format!(
217 "Column '{}' is not a numeric type",
218 col1
219 )))
220 }
221 };
222
223 let values2: Vec<f64> = match column2 {
224 col if col.as_float64().is_some() => {
225 let float_col = col.as_float64().ok_or_else(|| {
226 crate::error::Error::TypeMismatch("column type check failed for Float64".into())
227 })?;
228 (0..self.row_count())
229 .filter_map(|i| float_col.get(i).ok().flatten())
230 .collect()
231 }
232 col if col.as_int64().is_some() => {
233 let int_col = col.as_int64().ok_or_else(|| {
234 crate::error::Error::TypeMismatch("column type check failed for Int64".into())
235 })?;
236 (0..self.row_count())
237 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
238 .collect()
239 }
240 _ => {
241 return Err(crate::error::Error::Type(format!(
242 "Column '{}' is not a numeric type",
243 col2
244 )))
245 }
246 };
247
248 stats::ttest(&values1, &values2, alpha, equal_var)
250 }
251
252 pub fn anova(
262 &self,
263 value_col: &str,
264 group_col: &str,
265 alpha: Option<f64>,
266 ) -> Result<AnovaResult> {
267 let alpha = alpha.unwrap_or(0.05);
268
269 let value_column = self.column(value_col)?;
271
272 let group_column = self.column(group_col)?;
274 let group_col_string = group_column.as_string().ok_or_else(|| {
275 crate::error::Error::Type(format!("Column '{}' must be a string type", group_col))
276 })?;
277
278 let values: Vec<(f64, String)> = match value_column {
280 col if col.as_float64().is_some() => {
281 let float_col = col.as_float64().ok_or_else(|| {
282 crate::error::Error::TypeMismatch("column type check failed for Float64".into())
283 })?;
284 (0..self.row_count())
285 .filter_map(|i| {
286 let val = float_col.get(i).ok().flatten()?;
287 let group = group_col_string.get(i).ok().flatten()?;
288 Some((val, group.to_string()))
289 })
290 .collect()
291 }
292 col if col.as_int64().is_some() => {
293 let int_col = col.as_int64().ok_or_else(|| {
294 crate::error::Error::TypeMismatch("column type check failed for Int64".into())
295 })?;
296 (0..self.row_count())
297 .filter_map(|i| {
298 let val = int_col.get(i).ok().flatten()? as f64;
299 let group = group_col_string.get(i).ok().flatten()?;
300 Some((val, group.to_string()))
301 })
302 .collect()
303 }
304 _ => {
305 return Err(crate::error::Error::Type(format!(
306 "Column '{}' is not a numeric type",
307 value_col
308 )))
309 }
310 };
311
312 let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
314 for (val, group) in values {
315 groups.entry(group).or_insert_with(Vec::new).push(val);
316 }
317
318 if groups.len() < 2 {
320 return Err(crate::error::Error::InsufficientData(
321 "ANOVA requires at least 2 groups".to_string(),
322 ));
323 }
324
325 let str_groups: HashMap<&str, Vec<f64>> = groups
327 .iter()
328 .map(|(k, v)| (k.as_str(), v.clone()))
329 .collect();
330
331 stats::anova(&str_groups, alpha)
333 }
334
335 pub fn mann_whitney_u(
345 &self,
346 col1: &str,
347 col2: &str,
348 alpha: Option<f64>,
349 ) -> Result<MannWhitneyResult> {
350 let alpha = alpha.unwrap_or(0.05);
351
352 let column1 = self.column(col1)?;
354 let column2 = self.column(col2)?;
355
356 let values1: Vec<f64> = match column1 {
358 col if col.as_float64().is_some() => {
359 let float_col = col.as_float64().ok_or_else(|| {
360 crate::error::Error::TypeMismatch("column type check failed for Float64".into())
361 })?;
362 (0..self.row_count())
363 .filter_map(|i| float_col.get(i).ok().flatten())
364 .collect()
365 }
366 col if col.as_int64().is_some() => {
367 let int_col = col.as_int64().ok_or_else(|| {
368 crate::error::Error::TypeMismatch("column type check failed for Int64".into())
369 })?;
370 (0..self.row_count())
371 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
372 .collect()
373 }
374 _ => {
375 return Err(crate::error::Error::Type(format!(
376 "Column '{}' is not a numeric type",
377 col1
378 )))
379 }
380 };
381
382 let values2: Vec<f64> = match column2 {
383 col if col.as_float64().is_some() => {
384 let float_col = col.as_float64().ok_or_else(|| {
385 crate::error::Error::TypeMismatch("column type check failed for Float64".into())
386 })?;
387 (0..self.row_count())
388 .filter_map(|i| float_col.get(i).ok().flatten())
389 .collect()
390 }
391 col if col.as_int64().is_some() => {
392 let int_col = col.as_int64().ok_or_else(|| {
393 crate::error::Error::TypeMismatch("column type check failed for Int64".into())
394 })?;
395 (0..self.row_count())
396 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
397 .collect()
398 }
399 _ => {
400 return Err(crate::error::Error::Type(format!(
401 "Column '{}' is not a numeric type",
402 col2
403 )))
404 }
405 };
406
407 stats::mann_whitney_u(&values1, &values2, alpha)
409 }
410
411 pub fn chi_square_test(
422 &self,
423 row_col: &str,
424 col_col: &str,
425 count_col: &str,
426 alpha: Option<f64>,
427 ) -> Result<ChiSquareResult> {
428 let alpha = alpha.unwrap_or(0.05);
429
430 let row_column = self.column(row_col)?;
432 let col_column = self.column(col_col)?;
433 let count_column = self.column(count_col)?;
434
435 let row_strings = row_column.as_string().ok_or_else(|| {
437 crate::error::Error::Type(format!("Column '{}' must be a string type", row_col))
438 })?;
439
440 let col_strings = col_column.as_string().ok_or_else(|| {
441 crate::error::Error::Type(format!("Column '{}' must be a string type", col_col))
442 })?;
443
444 let count_values: Vec<f64> = match count_column {
446 col if col.as_float64().is_some() => {
447 let float_col = col.as_float64().ok_or_else(|| {
448 crate::error::Error::TypeMismatch("column type check failed for Float64".into())
449 })?;
450 (0..self.row_count())
451 .filter_map(|i| float_col.get(i).ok().flatten())
452 .collect()
453 }
454 col if col.as_int64().is_some() => {
455 let int_col = col.as_int64().ok_or_else(|| {
456 crate::error::Error::TypeMismatch("column type check failed for Int64".into())
457 })?;
458 (0..self.row_count())
459 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
460 .collect()
461 }
462 _ => {
463 return Err(crate::error::Error::Type(format!(
464 "Column '{}' is not a numeric type",
465 count_col
466 )))
467 }
468 };
469
470 let mut unique_rows = vec![];
473 let mut unique_cols = vec![];
474
475 for i in 0..self.row_count() {
476 if let Ok(Some(row_val)) = row_strings.get(i) {
477 if !unique_rows.contains(&row_val) {
478 unique_rows.push(row_val);
479 }
480 }
481
482 if let Ok(Some(col_val)) = col_strings.get(i) {
483 if !unique_cols.contains(&col_val) {
484 unique_cols.push(col_val);
485 }
486 }
487 }
488
489 let mut observed = vec![vec![0.0; unique_cols.len()]; unique_rows.len()];
491
492 for i in 0..self.row_count() {
493 if let (Ok(Some(row_val)), Ok(Some(col_val)), count) =
494 (row_strings.get(i), col_strings.get(i), count_values.get(i))
495 {
496 if let (Some(row_idx), Some(col_idx)) = (
497 unique_rows.iter().position(|r| r == &row_val),
498 unique_cols.iter().position(|c| c == &col_val),
499 ) {
500 if let Some(cnt) = count {
502 observed[row_idx][col_idx] += *cnt;
503 } else {
504 observed[row_idx][col_idx] += 1.0;
505 }
506 }
507 }
508 }
509
510 stats::chi_square_test(&observed, alpha)
512 }
513
514 pub fn linear_regression(
523 &self,
524 y_col: &str,
525 x_cols: &[&str],
526 ) -> Result<LinearRegressionResult> {
527 let mut df = crate::dataframe::DataFrame::new();
529
530 let y_column = self.column(y_col)?;
532 if let Some(float_col) = y_column.as_float64() {
533 let values: Vec<f64> = (0..self.row_count())
534 .filter_map(|i| float_col.get(i).ok().flatten())
535 .collect();
536
537 let series = crate::series::Series::new(values, Some(y_col.to_string()))?;
538 df.add_column(y_col.to_string(), series)?;
539 } else if let Some(int_col) = y_column.as_int64() {
540 let values: Vec<f64> = (0..self.row_count())
542 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
543 .collect();
544
545 let series = crate::series::Series::new(values, Some(y_col.to_string()))?;
546 df.add_column(y_col.to_string(), series)?;
547 } else {
548 return Err(crate::error::Error::Type(format!(
549 "Column '{}' must be a numeric type",
550 y_col
551 )));
552 }
553
554 for &x_col in x_cols {
556 let x_column = self.column(x_col)?;
557 if let Some(float_col) = x_column.as_float64() {
558 let values: Vec<f64> = (0..self.row_count())
559 .filter_map(|i| float_col.get(i).ok().flatten())
560 .collect();
561
562 let series = crate::series::Series::new(values, Some(x_col.to_string()))?;
563 df.add_column(x_col.to_string(), series)?;
564 } else if let Some(int_col) = x_column.as_int64() {
565 let values: Vec<f64> = (0..self.row_count())
567 .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
568 .collect();
569
570 let series = crate::series::Series::new(values, Some(x_col.to_string()))?;
571 df.add_column(x_col.to_string(), series)?;
572 } else {
573 return Err(crate::error::Error::Type(format!(
574 "Column '{}' must be a numeric type",
575 x_col
576 )));
577 }
578 }
579
580 stats::linear_regression(&df, y_col, x_cols)
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::column::{Column, Float64Column, StringColumn};
    use crate::optimized::split_dataframe::OptimizedDataFrame;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOLERANCE: f64 = 1e-10;

    /// `describe` over the values 1..=5 reports count 5, mean 3, min 1 and max 5.
    #[test]
    fn test_describe() {
        let mut df = OptimizedDataFrame::new();
        let column = Float64Column::with_name(vec![1.0, 2.0, 3.0, 4.0, 5.0], "values");
        df.add_column("values", Column::Float64(column))
            .expect("operation should succeed");

        let desc = df.describe("values").expect("operation should succeed");
        let lookup =
            |key: &str| -> f64 { *desc.stats.get(key).expect("operation should succeed") };

        assert_eq!(lookup("count") as usize, 5);
        for &(key, expected) in [("mean", 3.0), ("min", 1.0), ("max", 5.0)].iter() {
            assert!((lookup(key) - expected).abs() < TOLERANCE);
        }
    }

    /// `sample1` sits one unit below `sample2`, so the statistic must be negative.
    /// The test also expects 8 degrees of freedom.
    #[test]
    fn test_ttest() {
        let mut df = OptimizedDataFrame::new();
        let first = Float64Column::with_name(vec![1.0, 2.0, 3.0, 4.0, 5.0], "sample1");
        let second = Float64Column::with_name(vec![2.0, 3.0, 4.0, 5.0, 6.0], "sample2");
        df.add_column("sample1", Column::Float64(first))
            .expect("operation should succeed");
        df.add_column("sample2", Column::Float64(second))
            .expect("operation should succeed");

        let outcome = df
            .ttest("sample1", "sample2", Some(0.05), Some(true))
            .expect("operation should succeed");

        assert!(outcome.statistic < 0.0);
        assert_eq!(outcome.df, 8);
    }

    /// Three groups of five shifted values: checks the F statistic sign and the
    /// between/within/total degrees of freedom.
    #[test]
    fn test_anova() {
        let mut df = OptimizedDataFrame::new();

        let values = Float64Column::with_name(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, // group A
                2.0, 3.0, 4.0, 5.0, 6.0, // group B
                3.0, 4.0, 5.0, 6.0, 7.0, // group C
            ],
            "values",
        );
        let labels: Vec<String> = ["A", "B", "C"]
            .iter()
            .flat_map(|&label| std::iter::repeat(label.to_string()).take(5))
            .collect();
        let groups = StringColumn::with_name(labels, "group");

        df.add_column("values", Column::Float64(values))
            .expect("operation should succeed");
        df.add_column("group", Column::String(groups))
            .expect("operation should succeed");

        let outcome = df
            .anova("values", "group", Some(0.05))
            .expect("operation should succeed");

        assert!(outcome.f_statistic > 0.0);
        assert_eq!(outcome.df_between, 2);
        assert_eq!(outcome.df_within, 12);
        assert_eq!(outcome.df_total, 14);
    }
}