
robin_sparkless/dataframe/stats.rs

//! DataFrame statistical methods: stat (cov, corr), summary.
//! PySpark: df.stat().cov("a", "b"), df.stat().corr("a", "b"), df.corr() (matrix), df.summary(...).

use super::DataFrame;
use polars::datatypes::DataType;
use polars::prelude::{DataFrame as PlDataFrame, NamedFrom, PolarsError, Series};

/// Helper for DataFrame statistical methods (PySpark-style df.stat().cov/corr).
pub struct DataFrameStat<'a> {
    pub(crate) df: &'a DataFrame,
}

impl<'a> DataFrameStat<'a> {
    /// Sample covariance between two columns. PySpark stat.cov. ddof=1 for sample covariance.
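    ///
    /// A minimal usage sketch (mirrors the session and constructor helpers
    /// exercised in the tests below; `no_run` since it is illustrative only):
    ///
    /// ```no_run
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("doc").get_or_create();
    /// let df = spark
    ///     .create_dataframe(
    ///         vec![(1i64, 25i64, "a".to_string()), (2i64, 30i64, "b".to_string())],
    ///         vec!["id", "age", "name"],
    ///     )
    ///     .unwrap();
    /// let c = df.stat().cov("id", "age").unwrap(); // sample covariance, ddof = 1
    /// ```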
    pub fn cov(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        let pl = self.df.collect_inner()?;
        let s1 = pl.column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = pl.column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut n = 0_usize;
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
            }
        }
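        // One-pass sample covariance over rows where both values are non-null:
        // cov = (sum_ab - n * mean_a * mean_b) / (n - 1), i.e. ddof = 1,
        // matching PySpark's stat.cov. Fewer than two pairs is undefined => NaN.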
        if n < 2 {
            return Ok(f64::NAN);
        }
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov)
    }

    /// Pearson correlation between two columns. PySpark stat.corr.
    pub fn corr(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        let pl = self.df.collect_inner()?;
        let s1 = pl.column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = pl.column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut sum_a2 = 0.0_f64;
        let mut sum_b2 = 0.0_f64;
        let mut n = 0_usize;
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
                sum_a2 += xv * xv;
                sum_b2 += yv * yv;
            }
        }
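        // Pearson r from the same accumulated sums: r = cov / (std_a * std_b).
        // The standard deviations rescale the population variance by
        // n / (n - 1), so numerator and denominator both use ddof = 1.
        // A constant column (zero std) leaves r undefined => NaN.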
        if n < 2 {
            return Ok(f64::NAN);
        }
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        let std_a = ((sum_a2 / n as f64 - mean_a * mean_a) * (n as f64 / (n as f64 - 1.0))).sqrt();
        let std_b = ((sum_b2 / n as f64 - mean_b * mean_b) * (n as f64 / (n as f64 - 1.0))).sqrt();
        if std_a == 0.0 || std_b == 0.0 {
            return Ok(f64::NAN);
        }
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov / (std_a * std_b))
    }

    /// Correlation matrix of all numeric columns. PySpark df.corr() returns a DataFrame of pairwise correlations.
    /// Returns a DataFrame with one column per numeric column; row j of the column for col_i
    /// holds corr(col_i, col_j), following the same column order.
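    ///
    /// Usage sketch (building on the same helpers as [`Self::cov`]'s example):
    ///
    /// ```no_run
    /// # use robin_sparkless::SparkSession;
    /// # let spark = SparkSession::builder().app_name("doc").get_or_create();
    /// let df = spark
    ///     .create_dataframe(
    ///         vec![(1i64, 25i64, "a".to_string()), (2i64, 30i64, "b".to_string())],
    ///         vec!["id", "age", "name"],
    ///     )
    ///     .unwrap();
    /// // Numeric columns "id" and "age" => a 2x2 matrix with 1.0 on the diagonal.
    /// let m = df.stat().corr_matrix().unwrap();
    /// ```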
    pub fn corr_matrix(&self) -> Result<DataFrame, PolarsError> {
        let collected = self.df.collect_inner()?;
        let pl_df = collected.as_ref();
        let numeric_cols: Vec<String> = pl_df
            .columns()
            .iter()
            .filter(|s| {
                matches!(
                    s.dtype(),
                    DataType::Int8
                        | DataType::Int16
                        | DataType::Int32
                        | DataType::Int64
                        | DataType::UInt8
                        | DataType::UInt16
                        | DataType::UInt32
                        | DataType::UInt64
                        | DataType::Float32
                        | DataType::Float64
                )
            })
            .map(|s| s.name().to_string())
            .collect();
        if numeric_cols.is_empty() {
            return Ok(DataFrame::from_polars_with_options(
                PlDataFrame::default(),
                self.df.case_sensitive,
            ));
        }
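        // Build the symmetric matrix column by column: the diagonal is 1.0 by
        // definition, and each off-diagonal cell calls corr(), so every
        // unordered pair is computed twice (acceptable for modest schemas).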
        let mut columns: Vec<Series> = Vec::with_capacity(numeric_cols.len());
        for (i, name_i) in numeric_cols.iter().enumerate() {
            let mut row_vals = Vec::with_capacity(numeric_cols.len());
            for (j, name_j) in numeric_cols.iter().enumerate() {
                let r = if i == j {
                    1.0_f64
                } else {
                    self.corr(name_i, name_j)?
                };
                row_vals.push(Some(r));
            }
            let series = Series::new(name_i.as_str().into(), row_vals);
            columns.push(series);
        }
        let out_pl =
            PlDataFrame::new_infer_height(columns.into_iter().map(|s| s.into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            out_pl,
            self.df.case_sensitive,
        ))
    }
}

#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
            (3i64, 35i64, "c".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["id", "age", "name"])
            .unwrap()
    }

    #[test]
    fn stat_corr_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let r = stat.corr("id", "age").unwrap();
        assert!(
            r.is_nan() || (-1.0 - 1e-10..=1.0 + 1e-10).contains(&r),
            "corr should be in [-1,1] or NaN, got {r}"
        );
    }

    #[test]
    fn stat_cov_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let c = stat.cov("id", "age").unwrap();
        assert!(c.is_finite() || c.is_nan());
    }

    #[test]
    fn stat_corr_less_than_two_rows_returns_nan() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 10i64, "x".to_string())];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let stat = df.stat();
        let r = stat.corr("a", "b").unwrap();
        assert!(r.is_nan());
    }

    #[test]
    fn stat_cov_constant_column() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 5i64, "a".to_string()), (1i64, 5i64, "b".to_string())];
        let df = spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap();
        let stat = df.stat();
        let c = stat.cov("k", "v").unwrap();
        assert!(c.is_nan() || c == 0.0);
    }
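
    // A minimal smoke-test sketch for corr_matrix: it only asserts that the
    // matrix over the numeric columns ("id", "age") builds without error,
    // using the same helpers as the tests above.
    #[test]
    fn stat_corr_matrix_smoke() {
        let df = test_df();
        let m = df.stat().corr_matrix();
        assert!(m.is_ok());
    }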
}