use crate::error::{EvalError, EvalResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

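/// An expected Pearson correlation between two named fields, together with
/// the absolute tolerance applied when checking the observed value.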
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedCorrelation {
    pub field1: String,
    pub field2: String,
    pub expected_r: f64,
    pub tolerance: f64,
}

impl ExpectedCorrelation {
    pub fn new(field1: impl Into<String>, field2: impl Into<String>, expected_r: f64) -> Self {
        Self {
            field1: field1.into(),
            field2: field2.into(),
            expected_r,
            tolerance: 0.10,
        }
    }

    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }
}

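/// The outcome of checking one field pair: the observed correlation, the
/// expected value and deviation (when an expectation was supplied), and a
/// two-sided p-value for the observed correlation.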
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationCheckResult {
    pub field1: String,
    pub field2: String,
    pub observed_r: f64,
    pub expected_r: Option<f64>,
    pub deviation: Option<f64>,
    pub within_tolerance: bool,
    pub p_value: f64,
    pub sample_size: usize,
}

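/// Results of a full correlation analysis. `correlation_matrix` stores the
/// upper triangle (excluding the diagonal) of the pairwise correlation
/// matrix, flattened row by row in the order of `fields`.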
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationAnalysis {
    pub sample_size: usize,
    pub fields: Vec<String>,
    pub correlation_matrix: Vec<f64>,
    pub correlation_checks: Vec<CorrelationCheckResult>,
    pub checks_passed: usize,
    pub checks_failed: usize,
    pub passes: bool,
    pub issues: Vec<String>,
}

impl CorrelationAnalysis {
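    /// Looks up the correlation between two fields by name. Returns 1.0 when
    /// both names refer to the same field and `None` if either field is
    /// unknown. The index arithmetic below maps the pair onto the flattened
    /// upper-triangular `correlation_matrix`.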
    pub fn get_correlation(&self, field1: &str, field2: &str) -> Option<f64> {
        let idx1 = self.fields.iter().position(|f| f == field1)?;
        let idx2 = self.fields.iter().position(|f| f == field2)?;

        if idx1 == idx2 {
            return Some(1.0);
        }

        let (i, j) = if idx1 < idx2 {
            (idx1, idx2)
        } else {
            (idx2, idx1)
        };

        let n = self.fields.len();
        let mut matrix_idx = 0;
        for row in 0..i {
            matrix_idx += n - row - 1;
        }
        matrix_idx += j - i - 1;

        self.correlation_matrix.get(matrix_idx).copied()
    }
}

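/// Computes pairwise Pearson correlations over a dataset and validates them
/// against a configurable list of expected correlations.
///
/// A minimal usage sketch; the `crate::correlation` import path below is an
/// assumption about where this module is mounted:
///
/// ```ignore
/// use std::collections::HashMap;
/// use crate::correlation::{CorrelationAnalyzer, ExpectedCorrelation};
///
/// let mut data = HashMap::new();
/// data.insert("x".to_string(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
/// data.insert("y".to_string(), vec![2.0, 4.0, 6.0, 8.0, 10.0]);
///
/// let analysis = CorrelationAnalyzer::new()
///     .with_expected_correlations(vec![ExpectedCorrelation::new("x", "y", 1.0)])
///     .analyze(&data)
///     .unwrap();
/// assert!(analysis.passes);
/// ```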
pub struct CorrelationAnalyzer {
    expected_correlations: Vec<ExpectedCorrelation>,
    significance_level: f64,
}

impl CorrelationAnalyzer {
    pub fn new() -> Self {
        Self {
            expected_correlations: Vec::new(),
            significance_level: 0.05,
        }
    }

    pub fn with_expected_correlations(mut self, correlations: Vec<ExpectedCorrelation>) -> Self {
        self.expected_correlations = correlations;
        self
    }

    pub fn with_significance_level(mut self, level: f64) -> Self {
        self.significance_level = level;
        self
    }

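    /// Analyzes every pair of fields in `data`. All fields must contain the
    /// same number of values and at least 3 samples are required. Expected
    /// correlations that reference missing fields are reported in `issues`
    /// rather than failing the whole analysis.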
    pub fn analyze(&self, data: &HashMap<String, Vec<f64>>) -> EvalResult<CorrelationAnalysis> {
        if data.is_empty() {
            return Err(EvalError::MissingData("No data provided".to_string()));
        }

        let lengths: Vec<usize> = data.values().map(|v| v.len()).collect();
        if !lengths.iter().all(|&l| l == lengths[0]) {
            return Err(EvalError::InvalidParameter(
                "All fields must have the same number of values".to_string(),
            ));
        }

        let sample_size = lengths[0];
        if sample_size < 3 {
            return Err(EvalError::InsufficientData {
                required: 3,
                actual: sample_size,
            });
        }

        let fields: Vec<String> = data.keys().cloned().collect();
        let n_fields = fields.len();

        let mut correlation_matrix = Vec::new();
        for i in 0..n_fields {
            for j in (i + 1)..n_fields {
                let field1 = &fields[i];
                let field2 = &fields[j];
                let values1 = data.get(field1).unwrap();
                let values2 = data.get(field2).unwrap();
                let r = pearson_correlation(values1, values2);
                correlation_matrix.push(r);
            }
        }

        let mut correlation_checks = Vec::new();
        let mut issues = Vec::new();

        for expected in &self.expected_correlations {
            let values1 = match data.get(&expected.field1) {
                Some(v) => v,
                None => {
                    issues.push(format!("Field '{}' not found in data", expected.field1));
                    continue;
                }
            };
            let values2 = match data.get(&expected.field2) {
                Some(v) => v,
                None => {
                    issues.push(format!("Field '{}' not found in data", expected.field2));
                    continue;
                }
            };

            let observed_r = pearson_correlation(values1, values2);
            let p_value = correlation_p_value(observed_r, sample_size);
            let deviation = (observed_r - expected.expected_r).abs();
            let within_tolerance = deviation <= expected.tolerance;

            if !within_tolerance {
                issues.push(format!(
                    "Correlation between '{}' and '{}': expected {:.3}, got {:.3} (deviation {:.3} > tolerance {:.3})",
                    expected.field1, expected.field2, expected.expected_r, observed_r, deviation, expected.tolerance
                ));
            }

            correlation_checks.push(CorrelationCheckResult {
                field1: expected.field1.clone(),
                field2: expected.field2.clone(),
                observed_r,
                expected_r: Some(expected.expected_r),
                deviation: Some(deviation),
                within_tolerance,
                p_value,
                sample_size,
            });
        }

        let checks_passed = correlation_checks
            .iter()
            .filter(|c| c.within_tolerance)
            .count();
        let checks_failed = correlation_checks.len() - checks_passed;
        let passes = checks_failed == 0;

        Ok(CorrelationAnalysis {
            sample_size,
            fields,
            correlation_matrix,
            correlation_checks,
            checks_passed,
            checks_failed,
            passes,
            issues,
        })
    }

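    /// Analyzes a single pair of value vectors without named fields; the
    /// result uses the placeholder names "field1" and "field2" and performs
    /// no tolerance check.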
    pub fn analyze_pair(
        &self,
        values1: &[f64],
        values2: &[f64],
    ) -> EvalResult<CorrelationCheckResult> {
        if values1.len() != values2.len() {
            return Err(EvalError::InvalidParameter(
                "Value vectors must have the same length".to_string(),
            ));
        }

        let n = values1.len();
        if n < 3 {
            return Err(EvalError::InsufficientData {
                required: 3,
                actual: n,
            });
        }

        let observed_r = pearson_correlation(values1, values2);
        let p_value = correlation_p_value(observed_r, n);

        Ok(CorrelationCheckResult {
            field1: "field1".to_string(),
            field2: "field2".to_string(),
            observed_r,
            expected_r: None,
            deviation: None,
            within_tolerance: true,
            p_value,
            sample_size: n,
        })
    }
}

impl Default for CorrelationAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

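/// Pearson product-moment correlation coefficient. Returns 0.0 for inputs
/// with fewer than two samples or with zero variance; panics if the slices
/// have different lengths.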
pub fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    assert_eq!(x.len(), y.len(), "Vectors must have the same length");

    let n = x.len() as f64;
    if n < 2.0 {
        return 0.0;
    }

    let mean_x: f64 = x.iter().sum::<f64>() / n;
    let mean_y: f64 = y.iter().sum::<f64>() / n;

    let mut cov = 0.0;
    let mut var_x = 0.0;
    let mut var_y = 0.0;

    for i in 0..x.len() {
        let dx = x[i] - mean_x;
        let dy = y[i] - mean_y;
        cov += dx * dy;
        var_x += dx * dx;
        var_y += dy * dy;
    }

    if var_x <= 0.0 || var_y <= 0.0 {
        return 0.0;
    }

    cov / (var_x.sqrt() * var_y.sqrt())
}

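/// Spearman rank correlation: the Pearson correlation of the ranks, with
/// ties assigned their average rank.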
pub fn spearman_correlation(x: &[f64], y: &[f64]) -> f64 {
    assert_eq!(x.len(), y.len(), "Vectors must have the same length");

    let n = x.len();
    if n < 2 {
        return 0.0;
    }

    let rank_x = calculate_ranks(x);
    let rank_y = calculate_ranks(y);

    pearson_correlation(&rank_x, &rank_y)
}

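/// Assigns 1-based ranks to `values`; tied values (within 1e-10) share the
/// average of the ranks they occupy.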
fn calculate_ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut indexed: Vec<(usize, f64)> = values.iter().cloned().enumerate().collect();
    indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        let mut j = i;
        while j < n && (indexed[j].1 - indexed[i].1).abs() < 1e-10 {
            j += 1;
        }

        let avg_rank = (i + j) as f64 / 2.0 + 0.5;
        for k in i..j {
            ranks[indexed[k].0] = avg_rank;
        }

        i = j;
    }

    ranks
}

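/// Two-sided p-value for a Pearson correlation `r` computed from `n`
/// samples, using the t-statistic t = r * sqrt((n - 2) / (1 - r^2)) with
/// n - 2 degrees of freedom.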
fn correlation_p_value(r: f64, n: usize) -> f64 {
    if n <= 2 {
        return 1.0;
    }

    if r.abs() >= 1.0 {
        return 0.0;
    }

    let df = n - 2;
    let t = r * ((df as f64) / (1.0 - r * r)).sqrt();

    let t_abs = t.abs();
    2.0 * student_t_cdf(-t_abs, df as f64)
}

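/// Approximate CDF of Student's t-distribution. Falls back to the normal
/// CDF for more than 30 degrees of freedom; otherwise uses the regularized
/// incomplete beta function.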
fn student_t_cdf(t: f64, df: f64) -> f64 {
    if df > 30.0 {
        return normal_cdf(t);
    }

    let t2 = t * t;
    let prob = 0.5 * incomplete_beta(df / 2.0, 0.5, df / (df + t2));

    if t > 0.0 {
        1.0 - prob
    } else {
        prob
    }
}

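/// Standard normal CDF via the error function.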
fn normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

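/// Error function approximation (Abramowitz & Stegun formula 7.1.26),
/// accurate to roughly 1e-7.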
fn erf(x: f64) -> f64 {
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

    sign * y
}

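/// Regularized incomplete beta function I_x(a, b), evaluated with a
/// continued fraction (modified Lentz's method). This is an approximation:
/// no symmetry transformation is applied, so convergence degrades for x
/// close to 1, and the beta term relies on the approximate `ln_gamma` below.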
fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }

    let lbeta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
    let front = (x.powf(a) * (1.0 - x).powf(b)) / lbeta.exp();

    let mut c: f64 = 1.0;
    let mut d: f64 = 1.0 / (1.0 - (a + b) * x / (a + 1.0)).max(1e-30);
    let mut h = d;

    for m in 1..100 {
        let m = m as f64;
        let d1 = m * (b - m) * x / ((a + 2.0 * m - 1.0) * (a + 2.0 * m));
        let d2 = -(a + m) * (a + b + m) * x / ((a + 2.0 * m) * (a + 2.0 * m + 1.0));

        d = 1.0 / (1.0 + d1 * d).max(1e-30);
        c = 1.0 + d1 / c.max(1e-30);
        h *= c * d;

        d = 1.0 / (1.0 + d2 * d).max(1e-30);
        c = 1.0 + d2 / c.max(1e-30);
        h *= c * d;

        if ((c * d) - 1.0).abs() < 1e-8 {
            break;
        }
    }

    front * h / a
}

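/// Stirling-type approximation of ln(Gamma(x)) for x > 0; returns infinity
/// for non-positive inputs.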
fn ln_gamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::INFINITY;
    }
    0.5 * (2.0 * std::f64::consts::PI / x).ln() + x * ((x + 1.0 / (12.0 * x)).ln() - 1.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pearson_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let r = pearson_correlation(&x, &y);
        assert!((r - 1.0).abs() < 0.001);

        let y_neg = vec![10.0, 8.0, 6.0, 4.0, 2.0];
        let r_neg = pearson_correlation(&x, &y_neg);
        assert!((r_neg + 1.0).abs() < 0.001);

        let x_rand = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_rand = vec![3.0, 1.0, 4.0, 5.0, 2.0];
        let r_rand = pearson_correlation(&x_rand, &y_rand);
        assert!(
            r_rand.abs() < 0.7,
            "Expected weak correlation, got {}",
            r_rand
        );
    }

    #[test]
    fn test_spearman_correlation() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let r = spearman_correlation(&x, &y);
        assert!((r - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_correlation_analyzer() {
        let mut data = HashMap::new();
        data.insert(
            "x".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        );
        data.insert(
            "y".to_string(),
            vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0],
        );
        data.insert(
            "z".to_string(),
            vec![10.0, 8.0, 6.0, 4.0, 2.0, 1.0, 3.0, 5.0, 7.0, 9.0],
        );

        let analyzer = CorrelationAnalyzer::new().with_expected_correlations(vec![
            ExpectedCorrelation::new("x", "y", 1.0).with_tolerance(0.01),
        ]);

        let result = analyzer.analyze(&data).unwrap();
        assert_eq!(result.sample_size, 10);
        assert!(result.passes);

        let r_xy = result.get_correlation("x", "y").unwrap();
        assert!((r_xy - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_correlation_failure() {
        let mut data = HashMap::new();
        data.insert("x".to_string(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        data.insert("y".to_string(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);

        let analyzer = CorrelationAnalyzer::new().with_expected_correlations(vec![
            ExpectedCorrelation::new("x", "y", 0.8).with_tolerance(0.1),
        ]);

        let result = analyzer.analyze(&data).unwrap();
        assert!(!result.passes);
        assert_eq!(result.checks_failed, 1);
    }

    #[test]
    fn test_correlation_p_value() {
        let x: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|&v| v * 2.0 + 1.0).collect();

        let r = pearson_correlation(&x, &y);
        let p = correlation_p_value(r, x.len());

        assert!(r > 0.99);
        assert!(p < 0.001);
    }

    #[test]
    fn test_rank_calculation() {
        let values = vec![1.0, 3.0, 2.0, 3.0, 5.0];
        let ranks = calculate_ranks(&values);

        assert!((ranks[0] - 1.0).abs() < 0.001);
        assert!((ranks[2] - 2.0).abs() < 0.001);
        assert!((ranks[1] - 3.5).abs() < 0.001);
        assert!((ranks[3] - 3.5).abs() < 0.001);
        assert!((ranks[4] - 5.0).abs() < 0.001);
    }
}