polars_plan/dsl/function_expr/
correlation.rs1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
4use super::*;
5
/// Selects which pairwise statistic `corr` computes for its two input columns.
///
/// Serde support is derived only when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, PartialEq, Debug, Hash)]
pub enum CorrelationMethod {
    /// Pearson correlation.
    Pearson,
    /// Spearman rank correlation. The flag is forwarded as `propagate_nans`.
    ///
    /// Only available when both the `rank` and `propagate_nans` features are enabled.
    #[cfg(all(feature = "rank", feature = "propagate_nans"))]
    SpearmanRank(bool),
    /// Covariance. The payload is forwarded as the `ddof` argument.
    Covariance(u8),
}
14
15impl Display for CorrelationMethod {
16 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
17 use CorrelationMethod::*;
18 let s = match self {
19 Pearson => "pearson",
20 #[cfg(all(feature = "rank", feature = "propagate_nans"))]
21 SpearmanRank(_) => "spearman_rank",
22 Covariance(_) => return write!(f, "covariance"),
23 };
24 write!(f, "{}_correlation", s)
25 }
26}
27
28pub(super) fn corr(s: &[Column], method: CorrelationMethod) -> PolarsResult<Column> {
29 polars_ensure!(
30 s[0].len() == s[1].len() || s[0].len() == 1 || s[1].len() == 1,
31 length_mismatch = "corr",
32 s[0].len(),
33 s[1].len()
34 );
35
36 match method {
37 CorrelationMethod::Pearson => pearson_corr(s),
38 #[cfg(all(feature = "rank", feature = "propagate_nans"))]
39 CorrelationMethod::SpearmanRank(propagate_nans) => spearman_rank_corr(s, propagate_nans),
40 CorrelationMethod::Covariance(ddof) => covariance(s, ddof),
41 }
42}
43
44fn covariance(s: &[Column], ddof: u8) -> PolarsResult<Column> {
45 let a = &s[0];
46 let b = &s[1];
47 let name = PlSmallStr::from_static("cov");
48
49 use polars_ops::chunked_array::cov::cov;
50 let ret = match a.dtype() {
51 DataType::Float32 => {
52 let ret = cov(a.f32().unwrap(), b.f32().unwrap(), ddof).map(|v| v as f32);
53 return Ok(Column::new(name, &[ret]));
54 },
55 DataType::Float64 => cov(a.f64().unwrap(), b.f64().unwrap(), ddof),
56 DataType::Int32 => cov(a.i32().unwrap(), b.i32().unwrap(), ddof),
57 DataType::Int64 => cov(a.i64().unwrap(), b.i64().unwrap(), ddof),
58 DataType::UInt32 => cov(a.u32().unwrap(), b.u32().unwrap(), ddof),
59 DataType::UInt64 => cov(a.u64().unwrap(), b.u64().unwrap(), ddof),
60 _ => {
61 let a = a.cast(&DataType::Float64)?;
62 let b = b.cast(&DataType::Float64)?;
63 cov(a.f64().unwrap(), b.f64().unwrap(), ddof)
64 },
65 };
66 Ok(Column::new(name, &[ret]))
67}
68
69fn pearson_corr(s: &[Column]) -> PolarsResult<Column> {
70 let a = &s[0];
71 let b = &s[1];
72 let name = PlSmallStr::from_static("pearson_corr");
73
74 use polars_ops::chunked_array::cov::pearson_corr;
75 let ret = match a.dtype() {
76 DataType::Float32 => {
77 let ret = pearson_corr(a.f32().unwrap(), b.f32().unwrap()).map(|v| v as f32);
78 return Ok(Column::new(name.clone(), &[ret]));
79 },
80 DataType::Float64 => pearson_corr(a.f64().unwrap(), b.f64().unwrap()),
81 DataType::Int32 => pearson_corr(a.i32().unwrap(), b.i32().unwrap()),
82 DataType::Int64 => pearson_corr(a.i64().unwrap(), b.i64().unwrap()),
83 DataType::UInt32 => pearson_corr(a.u32().unwrap(), b.u32().unwrap()),
84 _ => {
85 let a = a.cast(&DataType::Float64)?;
86 let b = b.cast(&DataType::Float64)?;
87 pearson_corr(a.f64().unwrap(), b.f64().unwrap())
88 },
89 };
90 Ok(Column::new(name, &[ret]))
91}
92
/// Computes the Spearman rank correlation of `s[0]` and `s[1]`.
///
/// Both columns are passed through `coalesce_nulls_columns`, have their nulls
/// dropped, and are ranked with the average rank method. The two rank columns
/// are then handed to [`pearson_corr`], whose result is returned as is.
///
/// When `propagate_nans` is set and the first column is a float type, a NaN
/// maximum in either column short-circuits to a one-element `f64::NAN` column
/// named `"spearman_rank_correlation"`.
///
/// # Errors
///
/// Propagates any error from [`pearson_corr`].
///
/// # Panics
///
/// Panics if `s` has fewer than two columns, or if the NaN probe cannot read
/// the first value of the `nan_max_s` result as `f64`.
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
fn spearman_rank_corr(s: &[Column], propagate_nans: bool) -> PolarsResult<Column> {
    use polars_core::utils::coalesce_nulls_columns;
    use polars_ops::chunked_array::nan_propagating_aggregate::nan_max_s;

    // NOTE(review): presumably this aligns null positions across both
    // columns so the later `drop_nulls` calls remove the same rows — confirm.
    let (lhs, rhs) = coalesce_nulls_columns(&s[0], &s[1]);

    // True when the NaN-propagating maximum of the column is NaN.
    let contains_nan = |col: &Column| -> bool {
        nan_max_s(col.as_materialized_series(), PlSmallStr::EMPTY)
            .get(0)
            .unwrap()
            .extract::<f64>()
            .unwrap()
            .is_nan()
    };

    if propagate_nans && lhs.dtype().is_float() && (contains_nan(&lhs) || contains_nan(&rhs)) {
        let name = PlSmallStr::from_static("spearman_rank_correlation");
        return Ok(Column::new(name, &[f64::NAN]));
    }

    let lhs = lhs.drop_nulls();
    let rhs = rhs.drop_nulls();

    // Ranks a column with the average method and no seed.
    let average_rank = |col: &Column| -> Column {
        col.as_materialized_series()
            .rank(
                RankOptions {
                    method: RankMethod::Average,
                    ..Default::default()
                },
                None,
            )
            .into()
    };

    let lhs_rank = average_rank(&lhs);
    let rhs_rank = average_rank(&rhs);

    pearson_corr(&[lhs_rank, rhs_rank])
}