1use cyanea_core::{CyaneaError, Result, Summarizable};
7
8use crate::rank::{rank, RankMethod};
9
10pub fn pearson(x: &[f64], y: &[f64]) -> Result<f64> {
14 validate_paired(x, y)?;
15
16 let n = x.len() as f64;
17 let mean_x: f64 = x.iter().sum::<f64>() / n;
18 let mean_y: f64 = y.iter().sum::<f64>() / n;
19
20 let mut cov = 0.0;
21 let mut var_x = 0.0;
22 let mut var_y = 0.0;
23 for (xi, yi) in x.iter().zip(y.iter()) {
24 let dx = xi - mean_x;
25 let dy = yi - mean_y;
26 cov += dx * dy;
27 var_x += dx * dx;
28 var_y += dy * dy;
29 }
30
31 let denom = (var_x * var_y).sqrt();
32 if denom == 0.0 {
33 return Ok(0.0);
34 }
35 Ok(cov / denom)
36}
37
38pub fn spearman(x: &[f64], y: &[f64]) -> Result<f64> {
43 validate_paired(x, y)?;
44 let rx = rank(x, RankMethod::Average);
45 let ry = rank(y, RankMethod::Average);
46 pearson(&rx, &ry)
47}
48
49fn validate_paired(x: &[f64], y: &[f64]) -> Result<()> {
50 if x.len() != y.len() {
51 return Err(CyaneaError::InvalidInput(format!(
52 "correlation: x and y must have the same length ({} vs {})",
53 x.len(),
54 y.len(),
55 )));
56 }
57 if x.len() < 2 {
58 return Err(CyaneaError::InvalidInput(
59 "correlation: need at least 2 observations".into(),
60 ));
61 }
62 Ok(())
63}
64
65#[derive(Debug, Clone)]
69pub struct CorrelationMatrix {
70 data: Vec<f64>,
72 size: usize,
74 labels: Option<Vec<String>>,
76}
77
78impl CorrelationMatrix {
79 pub fn from_rows(rows: &[&[f64]]) -> Result<Self> {
84 Self::build(rows, None)
85 }
86
87 pub fn from_rows_labeled(rows: &[&[f64]], labels: &[&str]) -> Result<Self> {
89 if labels.len() != rows.len() {
90 return Err(CyaneaError::InvalidInput(
91 "CorrelationMatrix: labels length must match rows length".into(),
92 ));
93 }
94 let labels: Vec<String> = labels.iter().map(|s| s.to_string()).collect();
95 Self::build(rows, Some(labels))
96 }
97
98 fn build(rows: &[&[f64]], labels: Option<Vec<String>>) -> Result<Self> {
99 if rows.is_empty() {
100 return Err(CyaneaError::InvalidInput(
101 "CorrelationMatrix: need at least one variable".into(),
102 ));
103 }
104 let obs_len = rows[0].len();
105 for (i, row) in rows.iter().enumerate() {
106 if row.len() != obs_len {
107 return Err(CyaneaError::InvalidInput(format!(
108 "CorrelationMatrix: row {} has {} observations, expected {}",
109 i,
110 row.len(),
111 obs_len,
112 )));
113 }
114 }
115
116 let n = rows.len();
117 #[cfg(feature = "parallel")]
118 let data = {
119 use rayon::prelude::*;
120 let upper: Vec<Vec<(usize, f64)>> = (0..n)
121 .into_par_iter()
122 .map(|i| {
123 ((i + 1)..n)
124 .map(|j| {
125 let r = pearson(rows[i], rows[j]).unwrap();
126 (j, r)
127 })
128 .collect()
129 })
130 .collect();
131 let mut data = vec![0.0; n * n];
132 for i in 0..n {
133 data[i * n + i] = 1.0;
134 for &(j, r) in &upper[i] {
135 data[i * n + j] = r;
136 data[j * n + i] = r;
137 }
138 }
139 data
140 };
141 #[cfg(not(feature = "parallel"))]
142 let data = {
143 let mut data = vec![0.0; n * n];
144 for i in 0..n {
145 data[i * n + i] = 1.0;
146 for j in (i + 1)..n {
147 let r = pearson(rows[i], rows[j])?;
148 data[i * n + j] = r;
149 data[j * n + i] = r;
150 }
151 }
152 data
153 };
154
155 Ok(Self {
156 data,
157 size: n,
158 labels,
159 })
160 }
161
162 pub fn get(&self, i: usize, j: usize) -> f64 {
164 self.data[i * self.size + j]
165 }
166
167 pub fn n(&self) -> usize {
169 self.size
170 }
171
172 pub fn labels(&self) -> Option<&[String]> {
174 self.labels.as_deref()
175 }
176}
177
178impl Summarizable for CorrelationMatrix {
179 fn summary(&self) -> String {
180 format!("CorrelationMatrix: {}x{}", self.size, self.size)
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point coefficients.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` lies within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn pearson_perfect_positive() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.0, 4.0, 6.0, 8.0, 10.0];
        assert_close(pearson(&x, &y).unwrap(), 1.0);
    }

    #[test]
    fn pearson_perfect_negative() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [10.0, 8.0, 6.0, 4.0, 2.0];
        assert_close(pearson(&x, &y).unwrap(), -1.0);
    }

    #[test]
    fn pearson_zero_correlation() {
        // Points on the four axis directions: the co-moment sums to zero.
        let x = [1.0, 0.0, -1.0, 0.0];
        let y = [0.0, 1.0, 0.0, -1.0];
        assert_close(pearson(&x, &y).unwrap(), 0.0);
    }

    #[test]
    fn pearson_constant_series() {
        // A zero-variance series is reported as 0.0 rather than NaN.
        let x = [3.0, 3.0, 3.0];
        let y = [1.0, 2.0, 3.0];
        assert_close(pearson(&x, &y).unwrap(), 0.0);
    }

    #[test]
    fn pearson_length_mismatch() {
        assert!(pearson(&[1.0, 2.0], &[1.0]).is_err());
    }

    #[test]
    fn pearson_too_short() {
        assert!(pearson(&[1.0], &[2.0]).is_err());
    }

    #[test]
    fn spearman_monotonic() {
        // Monotonic but non-linear: the ranks agree exactly.
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [1.0, 8.0, 27.0, 64.0, 125.0];
        assert_close(spearman(&x, &y).unwrap(), 1.0);
    }

    #[test]
    fn spearman_reverse() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [5.0, 4.0, 3.0, 2.0, 1.0];
        assert_close(spearman(&x, &y).unwrap(), -1.0);
    }

    #[test]
    fn spearman_length_mismatch() {
        assert!(spearman(&[1.0, 2.0], &[1.0]).is_err());
    }

    #[test]
    fn correlation_matrix_diagonal() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let c = [7.0, 8.0, 9.0];
        let cm = CorrelationMatrix::from_rows(&[&a[..], &b[..], &c[..]]).unwrap();
        assert_eq!(cm.n(), 3);
        for i in 0..3 {
            assert_close(cm.get(i, i), 1.0);
        }
    }

    #[test]
    fn correlation_matrix_symmetric() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        let cm = CorrelationMatrix::from_rows(&[&a[..], &b[..]]).unwrap();
        assert_close(cm.get(0, 1), cm.get(1, 0));
    }

    #[test]
    fn correlation_matrix_off_diagonal_values() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        let c = [2.0, 4.0, 6.0, 8.0];
        let cm = CorrelationMatrix::from_rows(&[&a[..], &b[..], &c[..]]).unwrap();
        assert_close(cm.get(0, 1), -1.0);
        assert_close(cm.get(0, 2), 1.0);
        assert_close(cm.get(1, 2), -1.0);
    }

    #[test]
    fn correlation_matrix_labels() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];

        let labeled =
            CorrelationMatrix::from_rows_labeled(&[&a[..], &b[..]], &["a", "b"]).unwrap();
        let labels = labeled.labels().expect("labels were supplied");
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0], "a");
        assert_eq!(labels[1], "b");

        let unlabeled = CorrelationMatrix::from_rows(&[&a[..], &b[..]]).unwrap();
        assert!(unlabeled.labels().is_none());
    }

    #[test]
    fn correlation_matrix_rejects_bad_input() {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0];
        // No variables at all.
        assert!(CorrelationMatrix::from_rows(&[]).is_err());
        // Rows with differing observation counts.
        assert!(CorrelationMatrix::from_rows(&[&a[..], &b[..]]).is_err());
        // Label count differs from row count.
        assert!(CorrelationMatrix::from_rows_labeled(&[&a[..]], &["a", "b"]).is_err());
    }

    #[test]
    fn correlation_matrix_summary() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let cm = CorrelationMatrix::from_rows(&[&a[..], &b[..]]).unwrap();
        assert_eq!(cm.summary(), "CorrelationMatrix: 2x2");
    }
}