polars_plan/dsl/function_expr/
correlation.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
4use super::*;
5
/// Selects which two-column statistic the correlation function expression computes.
///
/// `Eq` is derived alongside `PartialEq` and `Hash` (every payload is a plain
/// `bool`/`u8`) so the enum can be used directly as a hash-map or hash-set key.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum CorrelationMethod {
    /// Pearson correlation.
    Pearson,
    /// Spearman rank correlation; the flag is forwarded as `propagate_nans`.
    ///
    /// Only available when both the `rank` and `propagate_nans` features are enabled.
    #[cfg(all(feature = "rank", feature = "propagate_nans"))]
    SpearmanRank(bool),
    /// Covariance; the payload is forwarded as `ddof`.
    Covariance(u8),
}
14
15impl Display for CorrelationMethod {
16    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
17        use CorrelationMethod::*;
18        let s = match self {
19            Pearson => "pearson",
20            #[cfg(all(feature = "rank", feature = "propagate_nans"))]
21            SpearmanRank(_) => "spearman_rank",
22            Covariance(_) => return write!(f, "covariance"),
23        };
24        write!(f, "{}_correlation", s)
25    }
26}
27
28pub(super) fn corr(s: &[Column], method: CorrelationMethod) -> PolarsResult<Column> {
29    polars_ensure!(
30        s[0].len() == s[1].len() || s[0].len() == 1 || s[1].len() == 1,
31        length_mismatch = "corr",
32        s[0].len(),
33        s[1].len()
34    );
35
36    match method {
37        CorrelationMethod::Pearson => pearson_corr(s),
38        #[cfg(all(feature = "rank", feature = "propagate_nans"))]
39        CorrelationMethod::SpearmanRank(propagate_nans) => spearman_rank_corr(s, propagate_nans),
40        CorrelationMethod::Covariance(ddof) => covariance(s, ddof),
41    }
42}
43
44fn covariance(s: &[Column], ddof: u8) -> PolarsResult<Column> {
45    let a = &s[0];
46    let b = &s[1];
47    let name = PlSmallStr::from_static("cov");
48
49    use polars_ops::chunked_array::cov::cov;
50    let ret = match a.dtype() {
51        DataType::Float32 => {
52            let ret = cov(a.f32().unwrap(), b.f32().unwrap(), ddof).map(|v| v as f32);
53            return Ok(Column::new(name, &[ret]));
54        },
55        DataType::Float64 => cov(a.f64().unwrap(), b.f64().unwrap(), ddof),
56        DataType::Int32 => cov(a.i32().unwrap(), b.i32().unwrap(), ddof),
57        DataType::Int64 => cov(a.i64().unwrap(), b.i64().unwrap(), ddof),
58        DataType::UInt32 => cov(a.u32().unwrap(), b.u32().unwrap(), ddof),
59        DataType::UInt64 => cov(a.u64().unwrap(), b.u64().unwrap(), ddof),
60        _ => {
61            let a = a.cast(&DataType::Float64)?;
62            let b = b.cast(&DataType::Float64)?;
63            cov(a.f64().unwrap(), b.f64().unwrap(), ddof)
64        },
65    };
66    Ok(Column::new(name, &[ret]))
67}
68
69fn pearson_corr(s: &[Column]) -> PolarsResult<Column> {
70    let a = &s[0];
71    let b = &s[1];
72    let name = PlSmallStr::from_static("pearson_corr");
73
74    use polars_ops::chunked_array::cov::pearson_corr;
75    let ret = match a.dtype() {
76        DataType::Float32 => {
77            let ret = pearson_corr(a.f32().unwrap(), b.f32().unwrap()).map(|v| v as f32);
78            return Ok(Column::new(name.clone(), &[ret]));
79        },
80        DataType::Float64 => pearson_corr(a.f64().unwrap(), b.f64().unwrap()),
81        DataType::Int32 => pearson_corr(a.i32().unwrap(), b.i32().unwrap()),
82        DataType::Int64 => pearson_corr(a.i64().unwrap(), b.i64().unwrap()),
83        DataType::UInt32 => pearson_corr(a.u32().unwrap(), b.u32().unwrap()),
84        _ => {
85            let a = a.cast(&DataType::Float64)?;
86            let b = b.cast(&DataType::Float64)?;
87            pearson_corr(a.f64().unwrap(), b.f64().unwrap())
88        },
89    };
90    Ok(Column::new(name, &[ret]))
91}
92
/// Computes the Spearman rank correlation of `s[0]` and `s[1]`.
///
/// The inputs are passed through `coalesce_nulls_columns`, nulls are dropped,
/// both columns are ranked with the average method, and the ranks are handed to
/// [`pearson_corr`].
///
/// When `propagate_nans` is set and a float input contains a NaN, a single-value
/// column named `"spearman_rank_correlation"` holding NaN is returned instead.
///
/// # Errors
///
/// Propagates any error returned by [`pearson_corr`].
///
/// # Panics
///
/// Panics if `s` holds fewer than two columns (direct indexing).
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
fn spearman_rank_corr(s: &[Column], propagate_nans: bool) -> PolarsResult<Column> {
    use polars_core::utils::coalesce_nulls_columns;
    use polars_ops::chunked_array::nan_propagating_aggregate::nan_max_s;

    let (a, b) = coalesce_nulls_columns(&s[0], &s[1]);

    let name = PlSmallStr::from_static("spearman_rank_correlation");
    if propagate_nans {
        // Only float columns are probed, and each column is judged by its own
        // dtype. A missing maximum (e.g. nothing to aggregate) is treated as
        // "no NaN present" rather than unwrapped.
        let contains_nan = |col: &Column| {
            col.dtype().is_float()
                && nan_max_s(col.as_materialized_series(), PlSmallStr::EMPTY)
                    .get(0)
                    .ok()
                    .and_then(|max| max.extract::<f64>())
                    .is_some_and(f64::is_nan)
        };
        if contains_nan(&a) || contains_nan(&b) {
            return Ok(Column::new(name, &[f64::NAN]));
        }
    }

    // Drop nulls so that they are excluded from the ranking.
    let a = a.drop_nulls();
    let b = b.drop_nulls();

    // Both inputs are ranked identically: average rank, all other options default.
    let average_rank = |col: &Column| -> Column {
        col.as_materialized_series()
            .rank(
                RankOptions {
                    method: RankMethod::Average,
                    ..Default::default()
                },
                None,
            )
            .into()
    };

    pearson_corr(&[average_rank(&a), average_rank(&b)])
}
141
142    pearson_corr(&[a_rank, b_rank])
143}