robin_sparkless/dataframe/stats.rs

//! DataFrame statistical methods: stat (cov, corr), summary.
//! PySpark: df.stat().cov("a", "b"), df.stat().corr("a", "b"), df.corr() (matrix), df.summary(...).

use super::DataFrame;
use polars::datatypes::DataType;
use polars::prelude::{DataFrame as PlDataFrame, NamedFrom, PolarsError, Series};

/// Helper for DataFrame statistical methods (PySpark-style df.stat().cov/corr).
pub struct DataFrameStat<'a> {
    pub(crate) df: &'a DataFrame,
}
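
// Usage sketch (illustrative; it mirrors the tests at the bottom of this file): the
// helper is obtained through `DataFrame::stat()` and borrows the frame, so it is
// typically used inline:
//
//     let stat = df.stat();
//     let covariance = stat.cov("id", "age")?;
//     let correlation = stat.corr("id", "age")?;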

impl<'a> DataFrameStat<'a> {
    /// Sample covariance (ddof = 1) between two columns, computed over rows where
    /// both values are non-null. PySpark stat.cov.
    pub fn cov(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        let s1 = self.df.df.as_ref().column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = self.df.df.as_ref().column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut n = 0_usize;
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
            }
        }
        if n < 2 {
            return Ok(f64::NAN);
        }
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov)
    }
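
    // Worked example for the formula above (illustrative numbers, matching the test
    // fixture below): with a = [1, 2, 3] and b = [25, 30, 35],
    //   n = 3, sum_ab = 190, mean_a = 2, mean_b = 30,
    //   cov = (190 - 3 * 2 * 30) / (3 - 1) = 10 / 2 = 5.0
    // Rows where either value is null are skipped (pairwise deletion), so n counts
    // complete pairs only.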

    /// Pearson correlation between two columns. PySpark stat.corr.
    pub fn corr(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        let s1 = self.df.df.as_ref().column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = self.df.df.as_ref().column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut sum_a2 = 0.0_f64;
        let mut sum_b2 = 0.0_f64;
        let mut n = 0_usize;
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
                sum_a2 += xv * xv;
                sum_b2 += yv * yv;
            }
        }
        if n < 2 {
            return Ok(f64::NAN);
        }
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        let std_a = ((sum_a2 / n as f64 - mean_a * mean_a) * (n as f64 / (n as f64 - 1.0))).sqrt();
        let std_b = ((sum_b2 / n as f64 - mean_b * mean_b) * (n as f64 / (n as f64 - 1.0))).sqrt();
        if std_a == 0.0 || std_b == 0.0 {
            return Ok(f64::NAN);
        }
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov / (std_a * std_b))
    }
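
    // Note on the formula above: corr is cov / (std_a * std_b), where the covariance
    // and both standard deviations use the sample (n - 1) normalization, so the
    // Bessel factors cancel and this is the ordinary Pearson r. Worked example with
    // the fixture below: a = [1, 2, 3], b = [25, 30, 35] (b = 5a + 20, exactly
    // linear) gives std_a = 1, std_b = 5, cov = 5, hence r = 1.0.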

    /// Correlation matrix of all numeric columns. PySpark df.corr() returns a DataFrame of pairwise correlations.
    /// Returns a DataFrame with one Float64 column per numeric input column; entry j of
    /// column i is corr(col_i, col_j), with 1.0 on the diagonal. Non-numeric columns
    /// are skipped.
    pub fn corr_matrix(&self) -> Result<DataFrame, PolarsError> {
        let pl_df = self.df.df.as_ref();
        let numeric_cols: Vec<String> = pl_df
            .get_columns()
            .iter()
            .filter(|s| {
                matches!(
                    s.dtype(),
                    DataType::Int8
                        | DataType::Int16
                        | DataType::Int32
                        | DataType::Int64
                        | DataType::UInt8
                        | DataType::UInt16
                        | DataType::UInt32
                        | DataType::UInt64
                        | DataType::Float32
                        | DataType::Float64
                )
            })
            .map(|s| s.name().to_string())
            .collect();
        if numeric_cols.is_empty() {
            return Ok(DataFrame::from_polars_with_options(
                PlDataFrame::default(),
                self.df.case_sensitive,
            ));
        }
        let mut columns: Vec<Series> = Vec::with_capacity(numeric_cols.len());
        for (i, name_i) in numeric_cols.iter().enumerate() {
            let mut row_vals = Vec::with_capacity(numeric_cols.len());
            for (j, name_j) in numeric_cols.iter().enumerate() {
                let r = if i == j {
                    1.0_f64
                } else {
                    self.corr(name_i, name_j)?
                };
                row_vals.push(Some(r));
            }
            let series = Series::new(name_i.as_str().into(), row_vals);
            columns.push(series);
        }
        let out_pl = PlDataFrame::new(columns.into_iter().map(|s| s.into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            out_pl,
            self.df.case_sensitive,
        ))
    }
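
    // Output sketch (illustrative; column names match the test fixture below): for a
    // frame with numeric columns "id" and "age" plus a string column "name",
    // `corr_matrix()` yields a 2 x 2 frame with columns "id" and "age", 1.0 on the
    // diagonal and corr("id", "age") in the off-diagonal cells:
    //
    //     let m = df.stat().corr_matrix()?;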
}

#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
            (3i64, 35i64, "c".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["id", "age", "name"])
            .unwrap()
    }

    #[test]
    fn stat_corr_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let r = stat.corr("id", "age").unwrap();
        assert!(
            r.is_nan() || (-1.0 - 1e-10..=1.0 + 1e-10).contains(&r),
            "corr should be in [-1,1] or NaN, got {r}"
        );
    }

    #[test]
    fn stat_cov_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let c = stat.cov("id", "age").unwrap();
        assert!(c.is_finite() || c.is_nan());
    }

    #[test]
    fn stat_corr_less_than_two_rows_returns_nan() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 10i64, "x".to_string())];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let stat = df.stat();
        let r = stat.corr("a", "b").unwrap();
        assert!(r.is_nan());
    }

    #[test]
    fn stat_cov_constant_column() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 5i64, "a".to_string()), (1i64, 5i64, "b".to_string())];
        let df = spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap();
        let stat = df.stat();
        let c = stat.cov("k", "v").unwrap();
        assert!(c.is_nan() || c == 0.0);
    }
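
    // Added test sketches (not in the original file): they exercise corr() and
    // corr_matrix() against the fixture above, where age = 5 * id + 20, an exact
    // linear relationship.
    #[test]
    fn stat_corr_perfect_linear_relationship() {
        let df = test_df();
        let r = df.stat().corr("id", "age").unwrap();
        // Perfectly linear data should give a correlation of (numerically) 1.0.
        assert!((r - 1.0).abs() < 1e-9, "expected r ~ 1.0, got {r}");
    }

    #[test]
    fn stat_corr_matrix_uses_numeric_columns_only() {
        let df = test_df();
        let m = df.stat().corr_matrix().unwrap();
        // "id" and "age" are numeric and "name" is skipped, so the matrix is 2 x 2.
        // (Reads the crate-internal `df` field directly, as the methods above do.)
        assert_eq!(m.df.as_ref().shape(), (2, 2));
    }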
}