// robin_sparkless/dataframe/stats.rs

use super::DataFrame;
use polars::datatypes::DataType;
use polars::prelude::{DataFrame as PlDataFrame, NamedFrom, PolarsError, Series};

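/// Column-statistics helper returned by `DataFrame::stat()`, holding a
/// borrowed reference to the parent frame.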
pub struct DataFrameStat<'a> {
    pub(crate) df: &'a DataFrame,
}

impl<'a> DataFrameStat<'a> {
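    /// Sample covariance between two numeric columns.
    ///
    /// Both columns are cast to `Float64`; rows where either value is null are
    /// skipped, and `NaN` is returned when fewer than two paired values remain.
    /// Illustrative call (column names are placeholders):
    /// `let c = df.stat().cov("x", "y")?;`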
    pub fn cov(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        let pl = self.df.collect_inner()?;
        let s1 = pl.column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = pl.column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        // Accumulate sums over rows where both values are non-null.
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut n = 0_usize;
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
            }
        }
        if n < 2 {
            return Ok(f64::NAN);
        }
        // Sample covariance: (sum(x*y) - n * mean(x) * mean(y)) / (n - 1).
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov)
    }

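    /// Pearson correlation coefficient between two numeric columns.
    ///
    /// Uses the same null handling as `cov`: rows where either value is null
    /// are skipped. Returns `NaN` when fewer than two paired values remain or
    /// when either column has zero standard deviation. Illustrative call
    /// (column names are placeholders): `let r = df.stat().corr("x", "y")?;`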
    pub fn corr(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        let pl = self.df.collect_inner()?;
        let s1 = pl.column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = pl.column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        // Accumulate sums over rows where both values are non-null.
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut sum_a2 = 0.0_f64;
        let mut sum_b2 = 0.0_f64;
        let mut n = 0_usize;
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
                sum_a2 += xv * xv;
                sum_b2 += yv * yv;
            }
        }
        if n < 2 {
            return Ok(f64::NAN);
        }
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        // Sample standard deviations (population variance rescaled by n / (n - 1)).
        let std_a = ((sum_a2 / n as f64 - mean_a * mean_a) * (n as f64 / (n as f64 - 1.0))).sqrt();
        let std_b = ((sum_b2 / n as f64 - mean_b * mean_b) * (n as f64 / (n as f64 - 1.0))).sqrt();
        if std_a == 0.0 || std_b == 0.0 {
            return Ok(f64::NAN);
        }
        // Pearson correlation: sample covariance divided by the product of the
        // sample standard deviations.
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov / (std_a * std_b))
    }

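    /// Pairwise Pearson correlation matrix over all numeric columns.
    ///
    /// Produces one `Float64` column per numeric input column, with ones on
    /// the diagonal. Non-numeric columns are ignored; an empty frame is
    /// returned when there are no numeric columns.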
    pub fn corr_matrix(&self) -> Result<DataFrame, PolarsError> {
        let collected = self.df.collect_inner()?;
        let pl_df = collected.as_ref();
        // Keep only numeric columns; correlation is undefined for the rest.
        let numeric_cols: Vec<String> = pl_df
            .columns()
            .iter()
            .filter(|s| {
                matches!(
                    s.dtype(),
                    DataType::Int8
                        | DataType::Int16
                        | DataType::Int32
                        | DataType::Int64
                        | DataType::UInt8
                        | DataType::UInt16
                        | DataType::UInt32
                        | DataType::UInt64
                        | DataType::Float32
                        | DataType::Float64
                )
            })
            .map(|s| s.name().to_string())
            .collect();
        if numeric_cols.is_empty() {
            return Ok(DataFrame::from_polars_with_options(
                PlDataFrame::default(),
                self.df.case_sensitive,
            ));
        }
        // Build the matrix column by column; the diagonal is 1.0 by definition.
        let mut columns: Vec<Series> = Vec::with_capacity(numeric_cols.len());
        for (i, name_i) in numeric_cols.iter().enumerate() {
            let mut row_vals = Vec::with_capacity(numeric_cols.len());
            for (j, name_j) in numeric_cols.iter().enumerate() {
                let r = if i == j {
                    1.0_f64
                } else {
                    self.corr(name_i, name_j)?
                };
                row_vals.push(Some(r));
            }
            let series = Series::new(name_i.as_str().into(), row_vals);
            columns.push(series);
        }
        let out_pl =
            PlDataFrame::new_infer_height(columns.into_iter().map(|s| s.into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            out_pl,
            self.df.case_sensitive,
        ))
    }
}

#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
            (3i64, 35i64, "c".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["id", "age", "name"])
            .unwrap()
    }

    #[test]
    fn stat_corr_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let r = stat.corr("id", "age").unwrap();
        assert!(
            r.is_nan() || (-1.0 - 1e-10..=1.0 + 1e-10).contains(&r),
            "corr should be in [-1,1] or NaN, got {r}"
        );
    }

    #[test]
    fn stat_cov_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let c = stat.cov("id", "age").unwrap();
        assert!(c.is_finite() || c.is_nan());
    }

    #[test]
    fn stat_corr_less_than_two_rows_returns_nan() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 10i64, "x".to_string())];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let stat = df.stat();
        let r = stat.corr("a", "b").unwrap();
        assert!(r.is_nan());
    }

    #[test]
    fn stat_cov_constant_column() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 5i64, "a".to_string()), (1i64, 5i64, "b".to_string())];
        let df = spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap();
        let stat = df.stat();
        let c = stat.cov("k", "v").unwrap();
        assert!(c.is_nan() || c == 0.0);
    }
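
    // Hedged sketch of a corr_matrix test: it only asserts that the call
    // succeeds on the helper frame above, since the exact output schema is
    // not pinned down here.
    #[test]
    fn stat_corr_matrix_on_numeric_columns() {
        let df = test_df();
        let stat = df.stat();
        let result = stat.corr_matrix();
        assert!(result.is_ok(), "corr_matrix should succeed on numeric columns");
    }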
}