robin_sparkless/dataframe/stats.rs

use super::DataFrame;
use polars::datatypes::DataType;
use polars::prelude::{DataFrame as PlDataFrame, NamedFrom, PolarsError, Series};

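/// Column statistics (`cov`, `corr`, `corr_matrix`) computed over a borrowed
/// [`DataFrame`]; obtained via `df.stat()`.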
pub struct DataFrameStat<'a> {
    pub(crate) df: &'a DataFrame,
}

impl<'a> DataFrameStat<'a> {
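    /// Sample covariance between `col1` and `col2` (both cast to `Float64`).
    ///
    /// Rows where either value is null are skipped; returns `NaN` when fewer than
    /// two complete pairs remain. Computed as
    /// `(sum(x*y) - n * mean(x) * mean(y)) / (n - 1)`.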
    pub fn cov(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        // Cast both columns of the underlying polars frame to f64.
        let s1 = self.df.df.as_ref().column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = self.df.df.as_ref().column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("cov: need float column".into()))?;
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut n = 0_usize;
        // Accumulate sums over rows where both values are non-null.
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
            }
        }
        // Sample covariance needs at least two complete pairs.
        if n < 2 {
            return Ok(f64::NAN);
        }
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov)
    }
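
    /// Pearson correlation coefficient between `col1` and `col2` (both cast to `Float64`).
    ///
    /// Rows where either value is null are skipped. Returns `NaN` when fewer than
    /// two complete pairs remain or when either column has zero standard deviation.
    /// Computed as the sample covariance divided by the product of the sample
    /// standard deviations.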
    pub fn corr(&self, col1: &str, col2: &str) -> Result<f64, PolarsError> {
        let c1 = self.df.resolve_column_name(col1)?;
        let c2 = self.df.resolve_column_name(col2)?;
        // Cast both columns of the underlying polars frame to f64.
        let s1 = self.df.df.as_ref().column(c1.as_str())?.cast(&DataType::Float64)?;
        let s2 = self.df.df.as_ref().column(c2.as_str())?.cast(&DataType::Float64)?;
        let a = s1
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        let b = s2
            .f64()
            .map_err(|_| PolarsError::ComputeError("corr: need float column".into()))?;
        let mut sum_ab = 0.0_f64;
        let mut sum_a = 0.0_f64;
        let mut sum_b = 0.0_f64;
        let mut sum_a2 = 0.0_f64;
        let mut sum_b2 = 0.0_f64;
        let mut n = 0_usize;
        // Accumulate sums over rows where both values are non-null.
        for (x, y) in a.into_iter().zip(b.into_iter()) {
            if let (Some(xv), Some(yv)) = (x, y) {
                n += 1;
                sum_a += xv;
                sum_b += yv;
                sum_ab += xv * yv;
                sum_a2 += xv * xv;
                sum_b2 += yv * yv;
            }
        }
        if n < 2 {
            return Ok(f64::NAN);
        }
        let mean_a = sum_a / n as f64;
        let mean_b = sum_b / n as f64;
        // Sample standard deviations (Bessel-corrected via the n / (n - 1) factor).
        let std_a = ((sum_a2 / n as f64 - mean_a * mean_a) * (n as f64 / (n as f64 - 1.0))).sqrt();
        let std_b = ((sum_b2 / n as f64 - mean_b * mean_b) * (n as f64 / (n as f64 - 1.0))).sqrt();
        // Correlation is undefined when either column is constant.
        if std_a == 0.0 || std_b == 0.0 {
            return Ok(f64::NAN);
        }
        let cov = (sum_ab - n as f64 * mean_a * mean_b) / (n as f64 - 1.0);
        Ok(cov / (std_a * std_b))
    }
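
    /// Pairwise Pearson correlation matrix over all numeric (integer and float) columns.
    ///
    /// The result has one `Float64` column per numeric input column and one row per
    /// numeric input column, with 1.0 on the diagonal. Returns an empty `DataFrame`
    /// when the frame has no numeric columns.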
    pub fn corr_matrix(&self) -> Result<DataFrame, PolarsError> {
        let pl_df = self.df.df.as_ref();
        // Only numeric (integer and float) columns participate in the matrix.
        let numeric_cols: Vec<String> = pl_df
            .get_columns()
            .iter()
            .filter(|s| {
                matches!(
                    s.dtype(),
                    DataType::Int8
                        | DataType::Int16
                        | DataType::Int32
                        | DataType::Int64
                        | DataType::UInt8
                        | DataType::UInt16
                        | DataType::UInt32
                        | DataType::UInt64
                        | DataType::Float32
                        | DataType::Float64
                )
            })
            .map(|s| s.name().to_string())
            .collect();
        if numeric_cols.is_empty() {
            return Ok(DataFrame::from_polars_with_options(
                PlDataFrame::default(),
                self.df.case_sensitive,
            ));
        }
        let mut columns: Vec<Series> = Vec::with_capacity(numeric_cols.len());
        for (i, name_i) in numeric_cols.iter().enumerate() {
            let mut row_vals = Vec::with_capacity(numeric_cols.len());
            for (j, name_j) in numeric_cols.iter().enumerate() {
                // A column's correlation with itself is exactly 1.
                let r = if i == j {
                    1.0_f64
                } else {
                    self.corr(name_i, name_j)?
                };
                row_vals.push(Some(r));
            }
            let series = Series::new(name_i.as_str().into(), row_vals);
            columns.push(series);
        }
        let out_pl = PlDataFrame::new(columns.into_iter().map(|s| s.into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            out_pl,
            self.df.case_sensitive,
        ))
    }
}

#[cfg(test)]
mod tests {
    use crate::{DataFrame, SparkSession};

    fn test_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![
            (1i64, 25i64, "a".to_string()),
            (2i64, 30i64, "b".to_string()),
            (3i64, 35i64, "c".to_string()),
        ];
        spark
            .create_dataframe(tuples, vec!["id", "age", "name"])
            .unwrap()
    }

    #[test]
    fn stat_corr_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let r = stat.corr("id", "age").unwrap();
        assert!(
            r.is_nan() || (-1.0 - 1e-10..=1.0 + 1e-10).contains(&r),
            "corr should be in [-1,1] or NaN, got {r}"
        );
    }

    #[test]
    fn stat_cov_two_columns() {
        let df = test_df();
        let stat = df.stat();
        let c = stat.cov("id", "age").unwrap();
        assert!(c.is_finite() || c.is_nan());
    }

    #[test]
    fn stat_corr_less_than_two_rows_returns_nan() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 10i64, "x".to_string())];
        let df = spark.create_dataframe(tuples, vec!["a", "b", "c"]).unwrap();
        let stat = df.stat();
        let r = stat.corr("a", "b").unwrap();
        assert!(r.is_nan());
    }

    #[test]
    fn stat_cov_constant_column() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 5i64, "a".to_string()), (1i64, 5i64, "b".to_string())];
        let df = spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap();
        let stat = df.stat();
        let c = stat.cov("k", "v").unwrap();
        assert!(c.is_nan() || c == 0.0);
    }
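
    // `corr` returns NaN when a column is constant: the sample standard deviation
    // is zero, so the `std_a == 0.0 || std_b == 0.0` guard short-circuits.
    #[test]
    fn stat_corr_constant_column_returns_nan() {
        let spark = SparkSession::builder()
            .app_name("stat_tests")
            .get_or_create();
        let tuples = vec![(1i64, 5i64, "a".to_string()), (1i64, 5i64, "b".to_string())];
        let df = spark
            .create_dataframe(tuples, vec!["k", "v", "label"])
            .unwrap();
        let stat = df.stat();
        let r = stat.corr("k", "v").unwrap();
        assert!(r.is_nan());
    }

    // Smoke test: `corr_matrix` should succeed on a frame with at least one numeric column.
    #[test]
    fn stat_corr_matrix_builds() {
        let df = test_df();
        let stat = df.stat();
        assert!(stat.corr_matrix().is_ok());
    }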
}