Skip to main content

pandrs/optimized/split_dataframe/
stats.rs

1//! Statistical functions module for OptimizedDataFrame
2//!
3//! This module provides statistical functionality for data analysis.
4//! It supports ANOVA, t-tests, chi-square tests, Mann-Whitney U tests, and more.
5
6use crate::column::{Column, ColumnTrait};
7use crate::error::Result;
8use crate::optimized::split_dataframe::OptimizedDataFrame;
9use crate::stats::{
10    self, AnovaResult, ChiSquareResult, DescriptiveStats, LinearRegressionResult,
11    MannWhitneyResult, TTestResult,
12};
13use std::collections::HashMap;
14
15/// Statistical result type for OptimizedDataFrame
16#[derive(Debug, Clone)]
17pub enum StatResult {
18    /// Descriptive statistics results
19    Descriptive(DescriptiveStats),
20    /// t-test results
21    TTest(TTestResult),
22    /// Analysis of variance (ANOVA) results
23    Anova(AnovaResult),
24    /// Mann-Whitney U test results
25    MannWhitneyU(MannWhitneyResult),
26    /// Chi-square test results
27    ChiSquare(ChiSquareResult),
28    /// Linear regression results
29    LinearRegression(LinearRegressionResult),
30}
31
/// Container for the descriptive statistics of a single column.
///
/// The same statistics are exposed twice: once keyed by name for direct
/// lookup, and once as a sequence for callers that need a stable order.
#[derive(Clone, Debug)]
pub struct StatDescribe {
    /// Statistic name -> value, for lookup by name.
    pub stats: HashMap<String, f64>,
    /// `(name, value)` pairs in presentation order.
    pub stats_list: Vec<(String, f64)>,
}
40
41/// Statistical functionality extension for OptimizedDataFrame
42impl OptimizedDataFrame {
43    /// Calculate basic statistics for a specific column
44    ///
45    /// # Arguments
46    /// * `column_name` - Name of the column to calculate statistics for
47    ///
48    /// # Returns
49    /// A structure containing descriptive statistics
50    pub fn describe(&self, column_name: &str) -> Result<StatDescribe> {
51        let col = self.column(column_name)?;
52
53        if let Some(float_col) = col.as_float64() {
54            // For floating-point columns
55            let values: Vec<f64> = (0..self.row_count())
56                .filter_map(|i| float_col.get(i).ok().flatten())
57                .collect();
58
59            // Use stats module
60            let stats = stats::describe(&values)?;
61
62            // Store results in HashMap
63            let mut result = HashMap::new();
64            result.insert("count".to_string(), stats.count as f64);
65            result.insert("mean".to_string(), stats.mean);
66            result.insert("std".to_string(), stats.std);
67            result.insert("min".to_string(), stats.min);
68            result.insert("25%".to_string(), stats.q1);
69            result.insert("50%".to_string(), stats.median);
70            result.insert("75%".to_string(), stats.q3);
71            result.insert("max".to_string(), stats.max);
72
73            // Also provide ordered list
74            let stats_list = vec![
75                ("count".to_string(), stats.count as f64),
76                ("mean".to_string(), stats.mean),
77                ("std".to_string(), stats.std),
78                ("min".to_string(), stats.min),
79                ("25%".to_string(), stats.q1),
80                ("50%".to_string(), stats.median),
81                ("75%".to_string(), stats.q3),
82                ("max".to_string(), stats.max),
83            ];
84
85            let mut result = HashMap::new();
86            result.insert("count".to_string(), stats.count as f64);
87            result.insert("mean".to_string(), stats.mean);
88            result.insert("std".to_string(), stats.std);
89            result.insert("min".to_string(), stats.min);
90            result.insert("25%".to_string(), stats.q1);
91            result.insert("50%".to_string(), stats.median);
92            result.insert("75%".to_string(), stats.q3);
93            result.insert("max".to_string(), stats.max);
94
95            Ok(StatDescribe {
96                stats: result,
97                stats_list,
98            })
99        } else if let Some(int_col) = col.as_int64() {
100            // For integer columns, convert to floating-point for calculation
101            let values: Vec<f64> = (0..self.row_count())
102                .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
103                .collect();
104
105            // Use stats module
106            let stats = stats::describe(&values)?;
107
108            // Store results in HashMap
109            let mut result = HashMap::new();
110            result.insert("count".to_string(), stats.count as f64);
111            result.insert("mean".to_string(), stats.mean);
112            result.insert("std".to_string(), stats.std);
113            result.insert("min".to_string(), stats.min);
114            result.insert("25%".to_string(), stats.q1);
115            result.insert("50%".to_string(), stats.median);
116            result.insert("75%".to_string(), stats.q3);
117            result.insert("max".to_string(), stats.max);
118
119            // Also provide ordered list
120            let stats_list = vec![
121                ("count".to_string(), stats.count as f64),
122                ("mean".to_string(), stats.mean),
123                ("std".to_string(), stats.std),
124                ("min".to_string(), stats.min),
125                ("25%".to_string(), stats.q1),
126                ("50%".to_string(), stats.median),
127                ("75%".to_string(), stats.q3),
128                ("max".to_string(), stats.max),
129            ];
130
131            let mut result = HashMap::new();
132            result.insert("count".to_string(), stats.count as f64);
133            result.insert("mean".to_string(), stats.mean);
134            result.insert("std".to_string(), stats.std);
135            result.insert("min".to_string(), stats.min);
136            result.insert("25%".to_string(), stats.q1);
137            result.insert("50%".to_string(), stats.median);
138            result.insert("75%".to_string(), stats.q3);
139            result.insert("max".to_string(), stats.max);
140
141            Ok(StatDescribe {
142                stats: result,
143                stats_list,
144            })
145        } else {
146            Err(crate::error::Error::Type(format!(
147                "Column '{}' is not a numeric type",
148                column_name
149            )))
150        }
151    }
152
153    /// Calculate descriptive statistics for multiple columns at once
154    ///
155    /// # Returns
156    /// A mapping from column names to statistical results
157    pub fn describe_all(&self) -> Result<HashMap<String, StatDescribe>> {
158        let mut results = HashMap::new();
159
160        for col_name in self.column_names() {
161            // Target only numeric columns
162            let col = self.column(col_name)?;
163            if col.as_float64().is_some() || col.as_int64().is_some() {
164                if let Ok(desc) = self.describe(col_name) {
165                    results.insert(col_name.to_string(), desc);
166                }
167            }
168        }
169
170        Ok(results)
171    }
172
173    /// Perform t-test on two columns
174    ///
175    /// # Arguments
176    /// * `col1` - First column name
177    /// * `col2` - Second column name
178    /// * `alpha` - Significance level (default: 0.05)
179    /// * `equal_var` - Whether to assume equal variance (default: true)
180    ///
181    /// # Returns
182    /// Results of the t-test
183    pub fn ttest(
184        &self,
185        col1: &str,
186        col2: &str,
187        alpha: Option<f64>,
188        equal_var: Option<bool>,
189    ) -> Result<TTestResult> {
190        let alpha = alpha.unwrap_or(0.05);
191        let equal_var = equal_var.unwrap_or(true);
192
193        // Get column data
194        let column1 = self.column(col1)?;
195        let column2 = self.column(col2)?;
196
197        // Convert to floating-point vectors
198        let values1: Vec<f64> = match column1 {
199            col if col.as_float64().is_some() => {
200                let float_col = col.as_float64().ok_or_else(|| {
201                    crate::error::Error::TypeMismatch("column type check failed for Float64".into())
202                })?;
203                (0..self.row_count())
204                    .filter_map(|i| float_col.get(i).ok().flatten())
205                    .collect()
206            }
207            col if col.as_int64().is_some() => {
208                let int_col = col.as_int64().ok_or_else(|| {
209                    crate::error::Error::TypeMismatch("column type check failed for Int64".into())
210                })?;
211                (0..self.row_count())
212                    .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
213                    .collect()
214            }
215            _ => {
216                return Err(crate::error::Error::Type(format!(
217                    "Column '{}' is not a numeric type",
218                    col1
219                )))
220            }
221        };
222
223        let values2: Vec<f64> = match column2 {
224            col if col.as_float64().is_some() => {
225                let float_col = col.as_float64().ok_or_else(|| {
226                    crate::error::Error::TypeMismatch("column type check failed for Float64".into())
227                })?;
228                (0..self.row_count())
229                    .filter_map(|i| float_col.get(i).ok().flatten())
230                    .collect()
231            }
232            col if col.as_int64().is_some() => {
233                let int_col = col.as_int64().ok_or_else(|| {
234                    crate::error::Error::TypeMismatch("column type check failed for Int64".into())
235                })?;
236                (0..self.row_count())
237                    .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
238                    .collect()
239            }
240            _ => {
241                return Err(crate::error::Error::Type(format!(
242                    "Column '{}' is not a numeric type",
243                    col2
244                )))
245            }
246        };
247
248        // Perform t-test
249        stats::ttest(&values1, &values2, alpha, equal_var)
250    }
251
252    /// Perform one-way analysis of variance (ANOVA)
253    ///
254    /// # Arguments
255    /// * `value_col` - Column name containing the measured values
256    /// * `group_col` - Column name for grouping
257    /// * `alpha` - Significance level (default: 0.05)
258    ///
259    /// # Returns
260    /// Results of the ANOVA
261    pub fn anova(
262        &self,
263        value_col: &str,
264        group_col: &str,
265        alpha: Option<f64>,
266    ) -> Result<AnovaResult> {
267        let alpha = alpha.unwrap_or(0.05);
268
269        // Get the value column
270        let value_column = self.column(value_col)?;
271
272        // Get the group column
273        let group_column = self.column(group_col)?;
274        let group_col_string = group_column.as_string().ok_or_else(|| {
275            crate::error::Error::Type(format!("Column '{}' must be a string type", group_col))
276        })?;
277
278        // Convert values to floating-point
279        let values: Vec<(f64, String)> = match value_column {
280            col if col.as_float64().is_some() => {
281                let float_col = col.as_float64().ok_or_else(|| {
282                    crate::error::Error::TypeMismatch("column type check failed for Float64".into())
283                })?;
284                (0..self.row_count())
285                    .filter_map(|i| {
286                        let val = float_col.get(i).ok().flatten()?;
287                        let group = group_col_string.get(i).ok().flatten()?;
288                        Some((val, group.to_string()))
289                    })
290                    .collect()
291            }
292            col if col.as_int64().is_some() => {
293                let int_col = col.as_int64().ok_or_else(|| {
294                    crate::error::Error::TypeMismatch("column type check failed for Int64".into())
295                })?;
296                (0..self.row_count())
297                    .filter_map(|i| {
298                        let val = int_col.get(i).ok().flatten()? as f64;
299                        let group = group_col_string.get(i).ok().flatten()?;
300                        Some((val, group.to_string()))
301                    })
302                    .collect()
303            }
304            _ => {
305                return Err(crate::error::Error::Type(format!(
306                    "Column '{}' is not a numeric type",
307                    value_col
308                )))
309            }
310        };
311
312        // Organize data by group
313        let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
314        for (val, group) in values {
315            groups.entry(group).or_insert_with(Vec::new).push(val);
316        }
317
318        // Ensure there are at least 2 groups
319        if groups.len() < 2 {
320            return Err(crate::error::Error::InsufficientData(
321                "ANOVA requires at least 2 groups".to_string(),
322            ));
323        }
324
325        // Convert to &str group map
326        let str_groups: HashMap<&str, Vec<f64>> = groups
327            .iter()
328            .map(|(k, v)| (k.as_str(), v.clone()))
329            .collect();
330
331        // Perform ANOVA
332        stats::anova(&str_groups, alpha)
333    }
334
335    /// Perform Mann-Whitney U test (non-parametric test)
336    ///
337    /// # Arguments
338    /// * `col1` - First column name
339    /// * `col2` - Second column name
340    /// * `alpha` - Significance level (default: 0.05)
341    ///
342    /// # Returns
343    /// Results of the Mann-Whitney U test
344    pub fn mann_whitney_u(
345        &self,
346        col1: &str,
347        col2: &str,
348        alpha: Option<f64>,
349    ) -> Result<MannWhitneyResult> {
350        let alpha = alpha.unwrap_or(0.05);
351
352        // Get column data
353        let column1 = self.column(col1)?;
354        let column2 = self.column(col2)?;
355
356        // Convert to floating-point vectors
357        let values1: Vec<f64> = match column1 {
358            col if col.as_float64().is_some() => {
359                let float_col = col.as_float64().ok_or_else(|| {
360                    crate::error::Error::TypeMismatch("column type check failed for Float64".into())
361                })?;
362                (0..self.row_count())
363                    .filter_map(|i| float_col.get(i).ok().flatten())
364                    .collect()
365            }
366            col if col.as_int64().is_some() => {
367                let int_col = col.as_int64().ok_or_else(|| {
368                    crate::error::Error::TypeMismatch("column type check failed for Int64".into())
369                })?;
370                (0..self.row_count())
371                    .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
372                    .collect()
373            }
374            _ => {
375                return Err(crate::error::Error::Type(format!(
376                    "Column '{}' is not a numeric type",
377                    col1
378                )))
379            }
380        };
381
382        let values2: Vec<f64> = match column2 {
383            col if col.as_float64().is_some() => {
384                let float_col = col.as_float64().ok_or_else(|| {
385                    crate::error::Error::TypeMismatch("column type check failed for Float64".into())
386                })?;
387                (0..self.row_count())
388                    .filter_map(|i| float_col.get(i).ok().flatten())
389                    .collect()
390            }
391            col if col.as_int64().is_some() => {
392                let int_col = col.as_int64().ok_or_else(|| {
393                    crate::error::Error::TypeMismatch("column type check failed for Int64".into())
394                })?;
395                (0..self.row_count())
396                    .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
397                    .collect()
398            }
399            _ => {
400                return Err(crate::error::Error::Type(format!(
401                    "Column '{}' is not a numeric type",
402                    col2
403                )))
404            }
405        };
406
407        // Perform Mann-Whitney U test
408        stats::mann_whitney_u(&values1, &values2, alpha)
409    }
410
411    /// Perform chi-square test
412    ///
413    /// # Arguments
414    /// * `row_col` - Column name determining rows
415    /// * `col_col` - Column name determining columns
416    /// * `count_col` - Column name containing counts/frequencies
417    /// * `alpha` - Significance level (default: 0.05)
418    ///
419    /// # Returns
420    /// Results of the chi-square test
421    pub fn chi_square_test(
422        &self,
423        row_col: &str,
424        col_col: &str,
425        count_col: &str,
426        alpha: Option<f64>,
427    ) -> Result<ChiSquareResult> {
428        let alpha = alpha.unwrap_or(0.05);
429
430        // Get column data
431        let row_column = self.column(row_col)?;
432        let col_column = self.column(col_col)?;
433        let count_column = self.column(count_col)?;
434
435        // Get string columns
436        let row_strings = row_column.as_string().ok_or_else(|| {
437            crate::error::Error::Type(format!("Column '{}' must be a string type", row_col))
438        })?;
439
440        let col_strings = col_column.as_string().ok_or_else(|| {
441            crate::error::Error::Type(format!("Column '{}' must be a string type", col_col))
442        })?;
443
444        // Get count values
445        let count_values: Vec<f64> = match count_column {
446            col if col.as_float64().is_some() => {
447                let float_col = col.as_float64().ok_or_else(|| {
448                    crate::error::Error::TypeMismatch("column type check failed for Float64".into())
449                })?;
450                (0..self.row_count())
451                    .filter_map(|i| float_col.get(i).ok().flatten())
452                    .collect()
453            }
454            col if col.as_int64().is_some() => {
455                let int_col = col.as_int64().ok_or_else(|| {
456                    crate::error::Error::TypeMismatch("column type check failed for Int64".into())
457                })?;
458                (0..self.row_count())
459                    .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
460                    .collect()
461            }
462            _ => {
463                return Err(crate::error::Error::Type(format!(
464                    "Column '{}' is not a numeric type",
465                    count_col
466                )))
467            }
468        };
469
470        // Generate contingency table
471        // Extract unique row and column values
472        let mut unique_rows = vec![];
473        let mut unique_cols = vec![];
474
475        for i in 0..self.row_count() {
476            if let Ok(Some(row_val)) = row_strings.get(i) {
477                if !unique_rows.contains(&row_val) {
478                    unique_rows.push(row_val);
479                }
480            }
481
482            if let Ok(Some(col_val)) = col_strings.get(i) {
483                if !unique_cols.contains(&col_val) {
484                    unique_cols.push(col_val);
485                }
486            }
487        }
488
489        // Build observed data matrix
490        let mut observed = vec![vec![0.0; unique_cols.len()]; unique_rows.len()];
491
492        for i in 0..self.row_count() {
493            if let (Ok(Some(row_val)), Ok(Some(col_val)), count) =
494                (row_strings.get(i), col_strings.get(i), count_values.get(i))
495            {
496                if let (Some(row_idx), Some(col_idx)) = (
497                    unique_rows.iter().position(|r| r == &row_val),
498                    unique_cols.iter().position(|c| c == &col_val),
499                ) {
500                    // Add count value if available, otherwise add 1.0
501                    if let Some(cnt) = count {
502                        observed[row_idx][col_idx] += *cnt;
503                    } else {
504                        observed[row_idx][col_idx] += 1.0;
505                    }
506                }
507            }
508        }
509
510        // Perform chi-square test
511        stats::chi_square_test(&observed, alpha)
512    }
513
514    /// Perform linear regression analysis
515    ///
516    /// # Arguments
517    /// * `y_col` - Name of the target (dependent) variable column
518    /// * `x_cols` - List of explanatory (independent) variable column names
519    ///
520    /// # Returns
521    /// Results of the linear regression analysis
522    pub fn linear_regression(
523        &self,
524        y_col: &str,
525        x_cols: &[&str],
526    ) -> Result<LinearRegressionResult> {
527        // Convert to DataFrame format
528        let mut df = crate::dataframe::DataFrame::new();
529
530        // Add the target variable
531        let y_column = self.column(y_col)?;
532        if let Some(float_col) = y_column.as_float64() {
533            let values: Vec<f64> = (0..self.row_count())
534                .filter_map(|i| float_col.get(i).ok().flatten())
535                .collect();
536
537            let series = crate::series::Series::new(values, Some(y_col.to_string()))?;
538            df.add_column(y_col.to_string(), series)?;
539        } else if let Some(int_col) = y_column.as_int64() {
540            // Convert integer column to floating-point
541            let values: Vec<f64> = (0..self.row_count())
542                .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
543                .collect();
544
545            let series = crate::series::Series::new(values, Some(y_col.to_string()))?;
546            df.add_column(y_col.to_string(), series)?;
547        } else {
548            return Err(crate::error::Error::Type(format!(
549                "Column '{}' must be a numeric type",
550                y_col
551            )));
552        }
553
554        // Add explanatory variable columns
555        for &x_col in x_cols {
556            let x_column = self.column(x_col)?;
557            if let Some(float_col) = x_column.as_float64() {
558                let values: Vec<f64> = (0..self.row_count())
559                    .filter_map(|i| float_col.get(i).ok().flatten())
560                    .collect();
561
562                let series = crate::series::Series::new(values, Some(x_col.to_string()))?;
563                df.add_column(x_col.to_string(), series)?;
564            } else if let Some(int_col) = x_column.as_int64() {
565                // Convert integer column to floating-point
566                let values: Vec<f64> = (0..self.row_count())
567                    .filter_map(|i| int_col.get(i).ok().flatten().map(|v| v as f64))
568                    .collect();
569
570                let series = crate::series::Series::new(values, Some(x_col.to_string()))?;
571                df.add_column(x_col.to_string(), series)?;
572            } else {
573                return Err(crate::error::Error::Type(format!(
574                    "Column '{}' must be a numeric type",
575                    x_col
576                )));
577            }
578        }
579
580        // Build linear regression model
581        stats::linear_regression(&df, y_col, x_cols)
582    }
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::column::{Column, Float64Column, StringColumn};
    use crate::optimized::split_dataframe::OptimizedDataFrame;

    /// `describe` reports count, mean, min and max for a small float column.
    #[test]
    fn test_describe() {
        let mut frame = OptimizedDataFrame::new();
        let column = Float64Column::with_name(vec![1.0, 2.0, 3.0, 4.0, 5.0], "values");
        frame
            .add_column("values", Column::Float64(column))
            .expect("operation should succeed");

        let summary = frame.describe("values").expect("operation should succeed");

        // Look a statistic up by name, failing the test if it is absent.
        let stat = |key: &str| -> f64 {
            *summary
                .stats
                .get(key)
                .expect("operation should succeed")
        };

        assert_eq!(stat("count") as usize, 5);
        assert!((stat("mean") - 3.0).abs() < 1e-10);
        assert!((stat("min") - 1.0).abs() < 1e-10);
        assert!((stat("max") - 5.0).abs() < 1e-10);
    }

    /// `ttest` on two equally sized samples, the second shifted up by one.
    #[test]
    fn test_ttest() {
        let mut frame = OptimizedDataFrame::new();

        let samples = [
            ("sample1", vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            ("sample2", vec![2.0, 3.0, 4.0, 5.0, 6.0]),
        ];
        for (name, data) in samples {
            let column = Float64Column::with_name(data, name);
            frame
                .add_column(name, Column::Float64(column))
                .expect("operation should succeed");
        }

        let outcome = frame
            .ttest("sample1", "sample2", Some(0.05), Some(true))
            .expect("operation should succeed");

        // sample2 holds the larger values, so the statistic is negative.
        assert!(outcome.statistic < 0.0);
        // Degrees of freedom: total sample size - 2.
        assert_eq!(outcome.df, 8);
    }

    /// `anova` on three groups of five observations each.
    #[test]
    fn test_anova() {
        let mut frame = OptimizedDataFrame::new();

        let measurements = Float64Column::with_name(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, // group A
                2.0, 3.0, 4.0, 5.0, 6.0, // group B
                3.0, 4.0, 5.0, 6.0, 7.0, // group C
            ],
            "values",
        );

        // Five consecutive rows per label: A x5, B x5, C x5.
        let labels: Vec<String> = ["A", "B", "C"]
            .iter()
            .flat_map(|label| std::iter::repeat(label.to_string()).take(5))
            .collect();
        let group_labels = StringColumn::with_name(labels, "group");

        frame
            .add_column("values", Column::Float64(measurements))
            .expect("operation should succeed");
        frame
            .add_column("group", Column::String(group_labels))
            .expect("operation should succeed");

        let outcome = frame
            .anova("values", "group", Some(0.05))
            .expect("operation should succeed");

        assert!(outcome.f_statistic > 0.0);
        // Number of groups - 1.
        assert_eq!(outcome.df_between, 2);
        // Total sample size - number of groups.
        assert_eq!(outcome.df_within, 12);
        // Total sample size - 1.
        assert_eq!(outcome.df_total, 14);
    }
}
688}