robin-sparkless 0.11.9

PySpark-like DataFrame API in Rust on Polars; no JVM.
Documentation
//! DataFrame statistical methods: stat (cov, corr), summary.
//! PySpark: df.stat().cov("a", "b"), df.stat().corr("a", "b"), df.corr() (matrix), df.summary(...).

use super::DataFrame;
use polars::datatypes::DataType;
use polars::prelude::{DataFrame as PlDataFrame, NamedFrom, PolarsError, Series};

/// Helper for DataFrame statistical methods (PySpark-style `df.stat().cov/corr`).
///
/// Borrows the [`DataFrame`] it wraps. The statistics (`cov`, `corr`, `corr_matrix`) are
/// computed by collecting that frame, so any pending lazy work runs when a method is called.
pub struct DataFrameStat<'a> {
    /// The frame the statistics are computed over.
    pub(crate) df: &'a DataFrame,
}

impl<'a> DataFrameStat<'a> {
    /// Sample covariance between two columns. PySpark `stat.cov`.
    ///
    /// Both columns are cast to `Float64`; rows where either value is null are ignored.
    /// The result uses the sample normalisation (divides by `n - 1`, i.e. ddof = 1) and is
    /// `NaN` when fewer than two complete pairs remain.
    ///
    /// # Errors
    ///
    /// Propagates failures from column-name resolution, collecting the frame and the
    /// `Float64` cast; returns `ComputeError("cov: need float column")` if a cast column
    /// cannot be viewed as `Float64`.
    pub fn cov(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let moments = self.paired_co_moments(col1, col2, "cov: need float column")?;
        Ok(moments.sample_cov())
    }

    /// Pearson correlation between two columns. PySpark `stat.corr`.
    ///
    /// Both columns are cast to `Float64`; rows where either value is null are ignored.
    /// Returns `NaN` when fewer than two complete pairs remain or when either column has
    /// zero variance over those pairs. The result is clamped to `[-1, 1]` so floating-point
    /// rounding can never push it out of range.
    ///
    /// # Errors
    ///
    /// Propagates failures from column-name resolution, collecting the frame and the
    /// `Float64` cast; returns `ComputeError("corr: need float column")` if a cast column
    /// cannot be viewed as `Float64`.
    pub fn corr(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let moments = self.paired_co_moments(col1, col2, "corr: need float column")?;
        Ok(moments.pearson())
    }

    /// Pairwise Pearson correlation matrix of all numeric columns (the `df.corr()` matrix form).
    ///
    /// Numeric means the integer (`Int8`..`Int64`, `UInt8`..`UInt64`) and float
    /// (`Float32`, `Float64`) dtypes; other columns are skipped. The output has one `Float64`
    /// column per numeric input column, with the same name and in the same order. Row `i` of
    /// column `c` holds the correlation between `c` and the `i`-th numeric column; no
    /// row-label column is added. The diagonal is fixed at `1.0`. Off-diagonal entries follow
    /// the same rules as [`Self::corr`] (nulls skipped pairwise; `NaN` for fewer than two
    /// pairs or zero variance).
    ///
    /// Returns an empty DataFrame when there are no numeric columns.
    ///
    /// # Errors
    ///
    /// Propagates failures from collecting the frame, the `Float64` casts, and building the
    /// output frame.
    pub fn corr_matrix(&self) -> Result<DataFrame, PolarsError> {
        // Collect once; every pair below reuses this materialised frame.
        let collected = self.df.collect_inner()?;
        let pl_df = collected.as_ref();
        // (name, column cast to Float64) for every numeric column, in frame order.
        let numeric: Vec<_> = pl_df
            .get_columns()
            .iter()
            .filter(|c| {
                matches!(
                    c.dtype(),
                    DataType::Int8
                        | DataType::Int16
                        | DataType::Int32
                        | DataType::Int64
                        | DataType::UInt8
                        | DataType::UInt16
                        | DataType::UInt32
                        | DataType::UInt64
                        | DataType::Float32
                        | DataType::Float64
                )
            })
            .map(|c| {
                c.cast(&DataType::Float64)
                    .map(|f| (c.name().to_string(), f))
            })
            .collect::<Result<_, _>>()?;
        if numeric.is_empty() {
            return Ok(DataFrame::from_polars_with_options(
                PlDataFrame::default(),
                self.df.case_sensitive,
            ));
        }
        let floats = numeric
            .iter()
            .map(|(_, col)| {
                col.f64()
                    .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))
            })
            .collect::<Result<Vec<_>, _>>()?;

        // The matrix is symmetric: compute each unordered pair once and mirror it.
        let k = floats.len();
        let mut matrix = vec![vec![1.0_f64; k]; k];
        for (i, &a) in floats.iter().enumerate() {
            for (j, &b) in floats.iter().enumerate().skip(i + 1) {
                let r = CoMoments::from_pairs(a.into_iter().zip(b)).pearson();
                matrix[i][j] = r;
                matrix[j][i] = r;
            }
        }

        let out_pl = PlDataFrame::new(
            numeric
                .iter()
                .zip(matrix)
                .map(|((name, _), row)| {
                    let values: Vec<Option<f64>> = row.into_iter().map(Some).collect();
                    Series::new(name.as_str().into(), values).into()
                })
                .collect(),
        )?;
        Ok(DataFrame::from_polars_with_options(
            out_pl,
            self.df.case_sensitive,
        ))
    }

    /// Shared front end of [`Self::cov`] and [`Self::corr`].
    ///
    /// Resolves both names, collects the frame once, casts both columns to `Float64` and
    /// accumulates their co-moments over the rows where both values are non-null.
    /// `err_msg` is the `ComputeError` text used if a cast column cannot be viewed as `Float64`.
    fn paired_co_moments(
        &self,
        col1: &str,
        col2: &str,
        err_msg: &'static str,
    ) -> Result<CoMoments, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        let pl = self.df.collect_inner()?;
        let s1 = pl.column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = pl.column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError(err_msg.into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError(err_msg.into()))?;
        Ok(CoMoments::from_pairs(a.into_iter().zip(b)))
    }
}

/// Centered co-moments of paired samples, accumulated over the rows where both values are
/// non-null.
#[derive(Debug, Clone, Copy, Default)]
struct CoMoments {
    /// Number of complete (both non-null) pairs.
    n: usize,
    /// Sum of squared deviations of `a` from its mean.
    m2_a: f64,
    /// Sum of squared deviations of `b` from its mean.
    m2_b: f64,
    /// Sum of products of the deviations of `a` and `b` from their means.
    c_ab: f64,
}

impl CoMoments {
    /// Accumulates co-moments in one pass with Welford's update.
    ///
    /// Unlike the textbook `Σab − n·mean_a·mean_b` form, this never subtracts two large,
    /// nearly equal sums. It therefore stays accurate when the values share a large common
    /// offset (e.g. IDs or epoch timestamps). Pairs with a null on either side are skipped.
    fn from_pairs<I>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (Option<f64>, Option<f64>)>,
    {
        let mut acc = Self::default();
        let (mut mean_a, mut mean_b) = (0.0_f64, 0.0_f64);
        for (x, y) in pairs.into_iter().filter_map(|(x, y)| Some((x?, y?))) {
            acc.n += 1;
            let n = acc.n as f64;
            // Deviations from the means *before* this sample is folded in.
            let dx = x - mean_a;
            let dy = y - mean_b;
            mean_a += dx / n;
            mean_b += dy / n;
            // Welford: pre-update deviation times post-update deviation.
            acc.m2_a += dx * (x - mean_a);
            acc.m2_b += dy * (y - mean_b);
            acc.c_ab += dx * (y - mean_b);
        }
        acc
    }

    /// Sample covariance (ddof = 1); `NaN` for fewer than two pairs.
    fn sample_cov(self) -> f64 {
        if self.n < 2 {
            f64::NAN
        } else {
            self.c_ab / (self.n as f64 - 1.0)
        }
    }

    /// Pearson correlation; `NaN` for fewer than two pairs or zero variance on either side.
    fn pearson(self) -> f64 {
        if self.n < 2 || self.m2_a == 0.0 || self.m2_b == 0.0 {
            return f64::NAN;
        }
        // The (n - 1) normalisations of cov and both stds cancel. Taking the square roots
        // separately avoids overflowing `m2_a * m2_b` for very large magnitudes.
        (self.c_ab / (self.m2_a.sqrt() * self.m2_b.sqrt())).clamp(-1.0, 1.0)
    }
}

#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    /// Tolerance for floating-point comparisons.
    const EPS: f64 = 1e-9;

    /// Three rows where `age = 5 * id + 20` (a perfect linear relation), plus a string column.
    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
            (3i64, 35i64, "c".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["id", "age", "name"])
            .unwrap()
    }

    #[test]
    fn stat_corr_two_columns() {
        let df = test_df();
        let r = df.stat().corr("id", "age").unwrap();
        // age is an exact positive linear function of id.
        assert!((r - 1.0).abs() < EPS, "expected corr = 1, got {r}");
    }

    #[test]
    fn stat_cov_two_columns() {
        let df = test_df();
        let c = df.stat().cov("id", "age").unwrap();
        // Deviations: id -> [-1, 0, 1], age -> [-5, 0, 5]; sum of products 10, / (n - 1) = 5.
        assert!((c - 5.0).abs() < EPS, "expected cov = 5, got {c}");
    }

    #[test]
    fn stat_corr_less_than_two_rows_returns_nan() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 10i64, "x".to_string())];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let stat = df.stat();
        assert!(stat.corr("a", "b").unwrap().is_nan());
        assert!(stat.cov("a", "b").unwrap().is_nan());
    }

    #[test]
    fn stat_cov_constant_column() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 5i64, "a".to_string()), (1i64, 5i64, "b".to_string())];
        let df = spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap();
        let stat = df.stat();
        let c = stat.cov("k", "v").unwrap();
        assert!(c.abs() < EPS, "expected cov = 0, got {c}");
        // Zero variance makes the correlation undefined.
        assert!(stat.corr("k", "v").unwrap().is_nan());
    }

    #[test]
    fn stat_corr_matrix_numeric_columns_only() {
        let df = test_df();
        let m = df.stat().corr_matrix().unwrap();
        let pl = m.collect_inner().unwrap();
        // Only the two Int64 columns (id, age) take part; the string column is skipped.
        assert_eq!(pl.width(), 2);
        assert_eq!(pl.height(), 2);
    }
}