Skip to main content

cyanea_stats/
correlation.rs

//! Correlation analysis.
//!
//! Provides Pearson and Spearman correlation coefficients, and a
//! [`CorrelationMatrix`] of pairwise Pearson coefficients across multiple variables.
5
6use cyanea_core::{CyaneaError, Result, Summarizable};
7
8use crate::rank::{rank, RankMethod};
9
10/// Pearson product-moment correlation coefficient between `x` and `y`.
11///
12/// Returns 0.0 if either series is constant (zero variance).
13pub fn pearson(x: &[f64], y: &[f64]) -> Result<f64> {
14    validate_paired(x, y)?;
15
16    let n = x.len() as f64;
17    let mean_x: f64 = x.iter().sum::<f64>() / n;
18    let mean_y: f64 = y.iter().sum::<f64>() / n;
19
20    let mut cov = 0.0;
21    let mut var_x = 0.0;
22    let mut var_y = 0.0;
23    for (xi, yi) in x.iter().zip(y.iter()) {
24        let dx = xi - mean_x;
25        let dy = yi - mean_y;
26        cov += dx * dy;
27        var_x += dx * dx;
28        var_y += dy * dy;
29    }
30
31    let denom = (var_x * var_y).sqrt();
32    if denom == 0.0 {
33        return Ok(0.0);
34    }
35    Ok(cov / denom)
36}
37
38/// Spearman rank correlation coefficient between `x` and `y`.
39///
40/// Ranks both series with [`RankMethod::Average`], then computes Pearson
41/// correlation on the ranks.
42pub fn spearman(x: &[f64], y: &[f64]) -> Result<f64> {
43    validate_paired(x, y)?;
44    let rx = rank(x, RankMethod::Average);
45    let ry = rank(y, RankMethod::Average);
46    pearson(&rx, &ry)
47}
48
49fn validate_paired(x: &[f64], y: &[f64]) -> Result<()> {
50    if x.len() != y.len() {
51        return Err(CyaneaError::InvalidInput(format!(
52            "correlation: x and y must have the same length ({} vs {})",
53            x.len(),
54            y.len(),
55        )));
56    }
57    if x.len() < 2 {
58        return Err(CyaneaError::InvalidInput(
59            "correlation: need at least 2 observations".into(),
60        ));
61    }
62    Ok(())
63}
64
65// ── Correlation matrix ─────────────────────────────────────────────────────
66
/// Square matrix of pairwise Pearson correlation coefficients for a set of variables.
#[derive(Clone, Debug)]
pub struct CorrelationMatrix {
    /// Coefficients in row-major order; entry `(i, j)` is stored at `i * size + j`.
    data: Vec<f64>,
    /// Number of variables, i.e. the matrix is `size` × `size`.
    size: usize,
    /// Variable labels, present only when the matrix was built with them.
    labels: Option<Vec<String>>,
}
77
78impl CorrelationMatrix {
79    /// Build a correlation matrix from rows of observations.
80    ///
81    /// Each inner slice is one variable's observations (all must have the same
82    /// length and at least 2 elements).
83    pub fn from_rows(rows: &[&[f64]]) -> Result<Self> {
84        Self::build(rows, None)
85    }
86
87    /// Build a labeled correlation matrix.
88    pub fn from_rows_labeled(rows: &[&[f64]], labels: &[&str]) -> Result<Self> {
89        if labels.len() != rows.len() {
90            return Err(CyaneaError::InvalidInput(
91                "CorrelationMatrix: labels length must match rows length".into(),
92            ));
93        }
94        let labels: Vec<String> = labels.iter().map(|s| s.to_string()).collect();
95        Self::build(rows, Some(labels))
96    }
97
98    fn build(rows: &[&[f64]], labels: Option<Vec<String>>) -> Result<Self> {
99        if rows.is_empty() {
100            return Err(CyaneaError::InvalidInput(
101                "CorrelationMatrix: need at least one variable".into(),
102            ));
103        }
104        let obs_len = rows[0].len();
105        for (i, row) in rows.iter().enumerate() {
106            if row.len() != obs_len {
107                return Err(CyaneaError::InvalidInput(format!(
108                    "CorrelationMatrix: row {} has {} observations, expected {}",
109                    i,
110                    row.len(),
111                    obs_len,
112                )));
113            }
114        }
115
116        let n = rows.len();
117        #[cfg(feature = "parallel")]
118        let data = {
119            use rayon::prelude::*;
120            let upper: Vec<Vec<(usize, f64)>> = (0..n)
121                .into_par_iter()
122                .map(|i| {
123                    ((i + 1)..n)
124                        .map(|j| {
125                            let r = pearson(rows[i], rows[j]).unwrap();
126                            (j, r)
127                        })
128                        .collect()
129                })
130                .collect();
131            let mut data = vec![0.0; n * n];
132            for i in 0..n {
133                data[i * n + i] = 1.0;
134                for &(j, r) in &upper[i] {
135                    data[i * n + j] = r;
136                    data[j * n + i] = r;
137                }
138            }
139            data
140        };
141        #[cfg(not(feature = "parallel"))]
142        let data = {
143            let mut data = vec![0.0; n * n];
144            for i in 0..n {
145                data[i * n + i] = 1.0;
146                for j in (i + 1)..n {
147                    let r = pearson(rows[i], rows[j])?;
148                    data[i * n + j] = r;
149                    data[j * n + i] = r;
150                }
151            }
152            data
153        };
154
155        Ok(Self {
156            data,
157            size: n,
158            labels,
159        })
160    }
161
162    /// Get the correlation between variable `i` and variable `j`.
163    pub fn get(&self, i: usize, j: usize) -> f64 {
164        self.data[i * self.size + j]
165    }
166
167    /// Number of variables.
168    pub fn n(&self) -> usize {
169        self.size
170    }
171
172    /// Variable labels, if provided.
173    pub fn labels(&self) -> Option<&[String]> {
174        self.labels.as_deref()
175    }
176}
177
178impl Summarizable for CorrelationMatrix {
179    fn summary(&self) -> String {
180        format!("CorrelationMatrix: {}x{}", self.size, self.size)
181    }
182}
183
184// ── Tests ──────────────────────────────────────────────────────────────────
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` lies strictly within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn pearson_perfect_positive() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.0, 4.0, 6.0, 8.0, 10.0];
        assert_close(pearson(&x, &y).unwrap(), 1.0);
    }

    #[test]
    fn pearson_perfect_negative() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [10.0, 8.0, 6.0, 4.0, 2.0];
        assert_close(pearson(&x, &y).unwrap(), -1.0);
    }

    #[test]
    fn pearson_zero_correlation() {
        // Orthogonal pattern: the co-deviations sum to zero.
        let x = [1.0, 0.0, -1.0, 0.0];
        let y = [0.0, 1.0, 0.0, -1.0];
        assert_close(pearson(&x, &y).unwrap(), 0.0);
    }

    #[test]
    fn pearson_constant_series() {
        // A zero-variance series is reported as uncorrelated.
        let x = [3.0, 3.0, 3.0];
        let y = [1.0, 2.0, 3.0];
        assert_close(pearson(&x, &y).unwrap(), 0.0);
    }

    #[test]
    fn pearson_length_mismatch() {
        let outcome = pearson(&[1.0, 2.0], &[1.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn pearson_too_short() {
        let outcome = pearson(&[1.0], &[2.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn spearman_monotonic() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        // x^3 is monotonically increasing, so the ranks match exactly.
        let y = [1.0, 8.0, 27.0, 64.0, 125.0];
        assert_close(spearman(&x, &y).unwrap(), 1.0);
    }

    #[test]
    fn spearman_reverse() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [5.0, 4.0, 3.0, 2.0, 1.0];
        assert_close(spearman(&x, &y).unwrap(), -1.0);
    }

    #[test]
    fn correlation_matrix_diagonal() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let c = [7.0, 8.0, 9.0];
        let cm = CorrelationMatrix::from_rows(&[&a[..], &b[..], &c[..]]).unwrap();
        assert_eq!(cm.n(), 3);
        for k in 0..3 {
            assert_close(cm.get(k, k), 1.0);
        }
    }

    #[test]
    fn correlation_matrix_symmetric() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        let cm = CorrelationMatrix::from_rows(&[&a[..], &b[..]]).unwrap();
        assert_close(cm.get(0, 1), cm.get(1, 0));
    }

    #[test]
    fn correlation_matrix_summary() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let cm = CorrelationMatrix::from_rows(&[&a[..], &b[..]]).unwrap();
        assert_eq!(cm.summary(), "CorrelationMatrix: 2x2");
    }
}