// datasynth_eval/statistical/correlation.rs
//! Correlation analysis for evaluating cross-field dependencies.
//!
//! Validates that generated data maintains expected correlations between
//! related fields (e.g., amount vs. line items, processing time vs. complexity).

use crate::error::{EvalError, EvalResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Expected correlation between two fields.
///
/// Describes the target Pearson coefficient for a pair of fields and the
/// maximum absolute deviation allowed before the check is reported as failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedCorrelation {
    /// First field name
    pub field1: String,
    /// Second field name
    pub field2: String,
    /// Expected Pearson correlation coefficient
    pub expected_r: f64,
    /// Maximum allowed absolute deviation |observed_r - expected_r|
    pub tolerance: f64,
}

23impl ExpectedCorrelation {
24    /// Create a new expected correlation.
25    pub fn new(field1: impl Into<String>, field2: impl Into<String>, expected_r: f64) -> Self {
26        Self {
27            field1: field1.into(),
28            field2: field2.into(),
29            expected_r,
30            tolerance: 0.10, // 0.10 default tolerance
31        }
32    }
33
34    /// Set tolerance.
35    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
36        self.tolerance = tolerance;
37        self
38    }
39}
40
/// Result of correlation check for a pair of fields.
///
/// Produced once per configured expectation by `CorrelationAnalyzer::analyze`,
/// and by `CorrelationAnalyzer::analyze_pair` (without an expectation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationCheckResult {
    /// First field
    pub field1: String,
    /// Second field
    pub field2: String,
    /// Observed Pearson correlation
    pub observed_r: f64,
    /// Expected correlation (`None` when no expectation was specified)
    pub expected_r: Option<f64>,
    /// Absolute deviation from expected (`None` when no expectation)
    pub deviation: Option<f64>,
    /// Whether the deviation is within tolerance
    /// (always `true` when produced without an expectation)
    pub within_tolerance: bool,
    /// Two-tailed p-value for correlation significance
    pub p_value: f64,
    /// Sample size
    pub sample_size: usize,
}

/// Full correlation matrix analysis results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationAnalysis {
    /// Sample size (number of rows per field)
    pub sample_size: usize,
    /// Field names, in the order used to lay out `correlation_matrix`
    pub fields: Vec<String>,
    /// Pairwise Pearson correlations packed as the strict upper triangle
    /// (diagonal excluded), row-major; index by name via `get_correlation`
    pub correlation_matrix: Vec<f64>,
    /// Individual correlation check results
    pub correlation_checks: Vec<CorrelationCheckResult>,
    /// Number of checks that passed
    pub checks_passed: usize,
    /// Number of checks that failed
    pub checks_failed: usize,
    /// Overall pass status (`true` iff no check failed)
    pub passes: bool,
    /// Issues found (missing fields, out-of-tolerance correlations)
    pub issues: Vec<String>,
}

83impl CorrelationAnalysis {
84    /// Get correlation between two fields by name.
85    pub fn get_correlation(&self, field1: &str, field2: &str) -> Option<f64> {
86        let idx1 = self.fields.iter().position(|f| f == field1)?;
87        let idx2 = self.fields.iter().position(|f| f == field2)?;
88
89        if idx1 == idx2 {
90            return Some(1.0);
91        }
92
93        let (i, j) = if idx1 < idx2 {
94            (idx1, idx2)
95        } else {
96            (idx2, idx1)
97        };
98
99        // Calculate index in upper triangular matrix
100        let n = self.fields.len();
101        let mut matrix_idx = 0;
102        for row in 0..i {
103            matrix_idx += n - row - 1;
104        }
105        matrix_idx += j - i - 1;
106
107        self.correlation_matrix.get(matrix_idx).copied()
108    }
109}
110
/// Analyzer for correlation analysis.
///
/// Configure expectations via `with_expected_correlations`, then call
/// `analyze` with column-oriented data.
pub struct CorrelationAnalyzer {
    /// Expected correlations to validate
    expected_correlations: Vec<ExpectedCorrelation>,
    /// Significance level for p-value tests
    /// (NOTE(review): stored but not read by any method in this file —
    /// callers must compare `p_value` themselves; confirm intent)
    significance_level: f64,
}

119impl CorrelationAnalyzer {
120    /// Create a new correlation analyzer.
121    pub fn new() -> Self {
122        Self {
123            expected_correlations: Vec::new(),
124            significance_level: 0.05,
125        }
126    }
127
128    /// Add expected correlations to validate.
129    pub fn with_expected_correlations(mut self, correlations: Vec<ExpectedCorrelation>) -> Self {
130        self.expected_correlations = correlations;
131        self
132    }
133
134    /// Set significance level.
135    pub fn with_significance_level(mut self, level: f64) -> Self {
136        self.significance_level = level;
137        self
138    }
139
140    /// Analyze correlations in the provided data.
141    ///
142    /// `data` is a map from field name to values for that field.
143    /// All value vectors must have the same length.
144    pub fn analyze(&self, data: &HashMap<String, Vec<f64>>) -> EvalResult<CorrelationAnalysis> {
145        if data.is_empty() {
146            return Err(EvalError::MissingData("No data provided".to_string()));
147        }
148
149        // Verify all columns have same length
150        let lengths: Vec<usize> = data.values().map(|v| v.len()).collect();
151        if !lengths.iter().all(|&l| l == lengths[0]) {
152            return Err(EvalError::InvalidParameter(
153                "All fields must have same number of values".to_string(),
154            ));
155        }
156
157        let sample_size = lengths[0];
158        if sample_size < 3 {
159            return Err(EvalError::InsufficientData {
160                required: 3,
161                actual: sample_size,
162            });
163        }
164
165        // Get ordered field names
166        let fields: Vec<String> = data.keys().cloned().collect();
167        let n_fields = fields.len();
168
169        // Calculate full correlation matrix
170        let mut correlation_matrix = Vec::new();
171        for i in 0..n_fields {
172            for j in (i + 1)..n_fields {
173                let field1 = &fields[i];
174                let field2 = &fields[j];
175                let values1 = data.get(field1).unwrap();
176                let values2 = data.get(field2).unwrap();
177                let r = pearson_correlation(values1, values2);
178                correlation_matrix.push(r);
179            }
180        }
181
182        // Check expected correlations
183        let mut correlation_checks = Vec::new();
184        let mut issues = Vec::new();
185
186        for expected in &self.expected_correlations {
187            let values1 = match data.get(&expected.field1) {
188                Some(v) => v,
189                None => {
190                    issues.push(format!("Field '{}' not found in data", expected.field1));
191                    continue;
192                }
193            };
194            let values2 = match data.get(&expected.field2) {
195                Some(v) => v,
196                None => {
197                    issues.push(format!("Field '{}' not found in data", expected.field2));
198                    continue;
199                }
200            };
201
202            let observed_r = pearson_correlation(values1, values2);
203            let p_value = correlation_p_value(observed_r, sample_size);
204            let deviation = (observed_r - expected.expected_r).abs();
205            let within_tolerance = deviation <= expected.tolerance;
206
207            if !within_tolerance {
208                issues.push(format!(
209                    "Correlation between '{}' and '{}': expected {:.3}, got {:.3} (deviation {:.3} > tolerance {:.3})",
210                    expected.field1, expected.field2, expected.expected_r, observed_r, deviation, expected.tolerance
211                ));
212            }
213
214            correlation_checks.push(CorrelationCheckResult {
215                field1: expected.field1.clone(),
216                field2: expected.field2.clone(),
217                observed_r,
218                expected_r: Some(expected.expected_r),
219                deviation: Some(deviation),
220                within_tolerance,
221                p_value,
222                sample_size,
223            });
224        }
225
226        let checks_passed = correlation_checks
227            .iter()
228            .filter(|c| c.within_tolerance)
229            .count();
230        let checks_failed = correlation_checks.len() - checks_passed;
231        let passes = checks_failed == 0;
232
233        Ok(CorrelationAnalysis {
234            sample_size,
235            fields,
236            correlation_matrix,
237            correlation_checks,
238            checks_passed,
239            checks_failed,
240            passes,
241            issues,
242        })
243    }
244
245    /// Analyze correlations from paired data (simpler interface for two fields).
246    pub fn analyze_pair(
247        &self,
248        values1: &[f64],
249        values2: &[f64],
250    ) -> EvalResult<CorrelationCheckResult> {
251        if values1.len() != values2.len() {
252            return Err(EvalError::InvalidParameter(
253                "Value vectors must have same length".to_string(),
254            ));
255        }
256
257        let n = values1.len();
258        if n < 3 {
259            return Err(EvalError::InsufficientData {
260                required: 3,
261                actual: n,
262            });
263        }
264
265        let observed_r = pearson_correlation(values1, values2);
266        let p_value = correlation_p_value(observed_r, n);
267
268        Ok(CorrelationCheckResult {
269            field1: "field1".to_string(),
270            field2: "field2".to_string(),
271            observed_r,
272            expected_r: None,
273            deviation: None,
274            within_tolerance: true,
275            p_value,
276            sample_size: n,
277        })
278    }
279}
280
impl Default for CorrelationAnalyzer {
    /// Equivalent to `CorrelationAnalyzer::new()`.
    fn default() -> Self {
        Self::new()
    }
}

/// Calculate the Pearson correlation coefficient of two equal-length slices.
///
/// Returns 0.0 for fewer than two points, or when either input has zero
/// variance (the coefficient is undefined in those cases).
///
/// # Panics
/// Panics if the slices differ in length.
pub fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    assert_eq!(x.len(), y.len(), "Vectors must have same length");

    let n = x.len() as f64;
    if n < 2.0 {
        return 0.0;
    }

    let mean_x = x.iter().sum::<f64>() / n;
    let mean_y = y.iter().sum::<f64>() / n;

    // Accumulate covariance and both variances in a single pass over the pairs.
    let (cov, var_x, var_y) =
        x.iter()
            .zip(y.iter())
            .fold((0.0, 0.0, 0.0), |(cov, vx, vy), (&xi, &yi)| {
                let dx = xi - mean_x;
                let dy = yi - mean_y;
                (cov + dx * dy, vx + dx * dx, vy + dy * dy)
            });

    // A constant column has no defined correlation; report 0.0.
    if var_x <= 0.0 || var_y <= 0.0 {
        return 0.0;
    }

    cov / (var_x.sqrt() * var_y.sqrt())
}

318/// Calculate Spearman rank correlation coefficient.
319pub fn spearman_correlation(x: &[f64], y: &[f64]) -> f64 {
320    assert_eq!(x.len(), y.len(), "Vectors must have same length");
321
322    let n = x.len();
323    if n < 2 {
324        return 0.0;
325    }
326
327    // Calculate ranks
328    let rank_x = calculate_ranks(x);
329    let rank_y = calculate_ranks(y);
330
331    // Pearson correlation of ranks
332    pearson_correlation(&rank_x, &rank_y)
333}
334
/// Assign 1-based ranks to `values`, giving tied entries the average of the
/// ranks they span. Values within 1e-10 of each other are treated as tied.
fn calculate_ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();

    // Order of indices by ascending value; NaN comparisons fall back to Equal.
    let mut order: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    order.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

    let mut ranks = vec![0.0; n];
    let mut start = 0;
    while start < n {
        // Extend `end` past the run of values tied with `order[start]`.
        let mut end = start;
        while end < n && (order[end].1 - order[start].1).abs() < 1e-10 {
            end += 1;
        }

        // Sorted positions start..end occupy ranks start+1 ..= end; each tied
        // entry gets the mean of that range.
        let shared_rank = (start + end) as f64 / 2.0 + 0.5;
        for &(original_idx, _) in &order[start..end] {
            ranks[original_idx] = shared_rank;
        }

        start = end;
    }

    ranks
}

362/// Calculate p-value for correlation coefficient using t-distribution approximation.
363fn correlation_p_value(r: f64, n: usize) -> f64 {
364    if n <= 2 {
365        return 1.0;
366    }
367
368    if r.abs() >= 1.0 {
369        return 0.0;
370    }
371
372    // t-statistic: t = r * sqrt((n-2) / (1-r²))
373    let df = n - 2;
374    let t = r * ((df as f64) / (1.0 - r * r)).sqrt();
375
376    // Two-tailed p-value using t-distribution
377    let t_abs = t.abs();
378    2.0 * student_t_cdf(-t_abs, df as f64)
379}
380
381/// Student-t CDF approximation.
382fn student_t_cdf(t: f64, df: f64) -> f64 {
383    // For large df, use normal approximation
384    if df > 30.0 {
385        return normal_cdf(t);
386    }
387
388    // Use beta function approximation
389    let t2 = t * t;
390    let prob = 0.5 * incomplete_beta(df / 2.0, 0.5, df / (df + t2));
391
392    if t > 0.0 {
393        1.0 - prob
394    } else {
395        prob
396    }
397}
398
/// Standard normal CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2.
fn normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

404/// Error function approximation.
405fn erf(x: f64) -> f64 {
406    let a1 = 0.254829592;
407    let a2 = -0.284496736;
408    let a3 = 1.421413741;
409    let a4 = -1.453152027;
410    let a5 = 1.061405429;
411    let p = 0.3275911;
412
413    let sign = if x < 0.0 { -1.0 } else { 1.0 };
414    let x = x.abs();
415
416    let t = 1.0 / (1.0 + p * x);
417    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
418
419    sign * y
420}
421
/// Regularized incomplete beta function I_x(a, b), evaluated with Lentz's
/// continued-fraction algorithm (capped at 100 iteration pairs).
///
/// NOTE(review): the usual symmetry transform for x > (a+1)/(a+b+2) is not
/// applied, so convergence is only fast on part of the domain — acceptable
/// for the t-CDF arguments produced in this module, but verify before
/// reusing elsewhere.
fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
    // Degenerate endpoints of the distribution.
    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }

    // Prefactor x^a (1-x)^b / B(a, b); the final division by `a` happens
    // in the return expression.
    let lbeta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
    let front = (x.powf(a) * (1.0 - x).powf(b)) / lbeta.exp();

    // Lentz's algorithm for continued fraction.
    // NOTE(review): `.max(1e-30)` flips the sign of a negative denominator
    // instead of guarding |d| < FPMIN as in the textbook formulation —
    // confirm this clamping is intended.
    let mut c: f64 = 1.0;
    let mut d: f64 = 1.0 / (1.0 - (a + b) * x / (a + 1.0)).max(1e-30);
    let mut h = d;

    for m in 1..100 {
        let m = m as f64;
        // Even-step (d1) and odd-step (d2) continued-fraction coefficients.
        let d1 = m * (b - m) * x / ((a + 2.0 * m - 1.0) * (a + 2.0 * m));
        let d2 = -(a + m) * (a + b + m) * x / ((a + 2.0 * m) * (a + 2.0 * m + 1.0));

        d = 1.0 / (1.0 + d1 * d).max(1e-30);
        c = 1.0 + d1 / c.max(1e-30);
        h *= c * d;

        d = 1.0 / (1.0 + d2 * d).max(1e-30);
        c = 1.0 + d2 / c.max(1e-30);
        h *= c * d;

        // Converged when the multiplicative update is ~1.
        if ((c * d) - 1.0).abs() < 1e-8 {
            break;
        }
    }

    front * h / a
}

/// Log gamma function via a Stirling-type approximation.
///
/// Returns +infinity for non-positive inputs, which are outside this
/// implementation's domain.
fn ln_gamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::INFINITY;
    }

    // ln Gamma(x) ~ 0.5 ln(2*pi/x) + x (ln(x + 1/(12x)) - 1)
    let lead = 0.5 * (2.0 * std::f64::consts::PI / x).ln();
    let body = x * ((x + 1.0 / (12.0 * x)).ln() - 1.0);
    lead + body
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pearson_correlation() {
        // Perfect positive correlation
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let r = pearson_correlation(&x, &y);
        assert!((r - 1.0).abs() < 0.001);

        // Perfect negative correlation
        let y_neg = vec![10.0, 8.0, 6.0, 4.0, 2.0];
        let r_neg = pearson_correlation(&x, &y_neg);
        assert!((r_neg + 1.0).abs() < 0.001);

        // Low correlation (values chosen to have weak correlation)
        let x_rand = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_rand = vec![3.0, 1.0, 4.0, 5.0, 2.0];
        let r_rand = pearson_correlation(&x_rand, &y_rand);
        // Verify correlation is weak (not strongly positive or negative)
        assert!(
            r_rand.abs() < 0.7,
            "Expected weak correlation, got {}",
            r_rand
        );
    }

    #[test]
    fn test_spearman_correlation() {
        // A perfectly monotone relationship gives rank correlation 1.
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let r = spearman_correlation(&x, &y);
        assert!((r - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_correlation_analyzer() {
        let mut data = HashMap::new();
        data.insert(
            "x".to_string(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        );
        data.insert(
            "y".to_string(),
            vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0],
        );
        data.insert(
            "z".to_string(),
            vec![10.0, 8.0, 6.0, 4.0, 2.0, 1.0, 3.0, 5.0, 7.0, 9.0],
        );

        let analyzer =
            CorrelationAnalyzer::new()
                .with_expected_correlations(vec![
                    ExpectedCorrelation::new("x", "y", 1.0).with_tolerance(0.01)
                ]);

        let result = analyzer.analyze(&data).unwrap();
        assert_eq!(result.sample_size, 10);
        assert!(result.passes);

        // Check we can retrieve correlation
        let r_xy = result.get_correlation("x", "y").unwrap();
        assert!((r_xy - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_correlation_failure() {
        let mut data = HashMap::new();
        data.insert("x".to_string(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        data.insert("y".to_string(), vec![5.0, 4.0, 3.0, 2.0, 1.0]); // Negative correlation

        let analyzer = CorrelationAnalyzer::new().with_expected_correlations(vec![
            ExpectedCorrelation::new("x", "y", 0.8).with_tolerance(0.1), // Expected positive
        ]);

        let result = analyzer.analyze(&data).unwrap();
        assert!(!result.passes);
        assert_eq!(result.checks_failed, 1);
    }

    #[test]
    fn test_correlation_p_value() {
        // Strong correlation with large sample should have low p-value
        let x: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|&v| v * 2.0 + 1.0).collect();

        let r = pearson_correlation(&x, &y);
        let p = correlation_p_value(r, x.len());

        assert!(r > 0.99);
        assert!(p < 0.001);
    }

    #[test]
    fn test_rank_calculation() {
        let values = vec![1.0, 3.0, 2.0, 3.0, 5.0]; // Note: ties at 3.0
        let ranks = calculate_ranks(&values);

        // 1.0 -> rank 1
        // 2.0 -> rank 2
        // 3.0, 3.0 -> ranks 3.5, 3.5 (average of 3 and 4)
        // 5.0 -> rank 5
        assert!((ranks[0] - 1.0).abs() < 0.001);
        assert!((ranks[2] - 2.0).abs() < 0.001);
        assert!((ranks[1] - 3.5).abs() < 0.001);
        assert!((ranks[3] - 3.5).abs() < 0.001);
        assert!((ranks[4] - 5.0).abs() < 0.001);
    }
}