1use crate::errors::EmptyInput;
2use ndarray::prelude::*;
3use num_traits::{Float, FromPrimitive};
4
5pub trait CorrelationExt<A> {
8 fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
61 where
62 A: Float + FromPrimitive;
63
64 fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
118 where
119 A: Float + FromPrimitive;
120
121 private_decl! {}
122}
123
124impl<A: 'static> CorrelationExt<A> for ArrayRef2<A> {
125 fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
126 where
127 A: Float + FromPrimitive,
128 {
129 let observation_axis = Axis(1);
130 let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap();
131 let dof = if ddof >= n_observations {
132 panic!(
133 "`ddof` needs to be strictly smaller than the \
134 number of observations provided for each \
135 random variable!"
136 )
137 } else {
138 n_observations - ddof
139 };
140 let mean = self.mean_axis(observation_axis);
141 match mean {
142 Some(mean) => {
143 let denoised = self - mean.insert_axis(observation_axis);
144 let covariance = denoised.dot(&denoised.t());
145 Ok(covariance.mapv_into(|x| x / dof))
146 }
147 None => Err(EmptyInput),
148 }
149 }
150
151 fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
152 where
153 A: Float + FromPrimitive,
154 {
155 match self.dim() {
156 (n, m) if n > 0 && m > 0 => {
157 let observation_axis = Axis(1);
158 let ddof = A::zero();
162 let cov = self.cov(ddof).unwrap();
163 let std = self
164 .std_axis(observation_axis, ddof)
165 .insert_axis(observation_axis);
166 let std_matrix = std.dot(&std.t());
167 Ok(cov / std_matrix)
169 }
170 _ => Err(EmptyInput),
171 }
172 }
173
174 private_impl! {}
175}
176
#[cfg(test)]
mod cov_tests {
    use super::*;
    use ndarray::array;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use quickcheck_macros::quickcheck;

    #[quickcheck]
    fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::from_elem((n_random_variables, n_observations), value);
        abs_diff_eq!(
            a.cov(1.).unwrap(),
            &Array::zeros((n_random_variables, n_random_variables)),
            epsilon = 1e-8,
        )
    }

    #[quickcheck]
    fn covariance_matrix_is_symmetric(bound: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random(
            (n_random_variables, n_observations),
            Uniform::new(-bound.abs(), bound.abs()).unwrap(),
        );
        let covariance = a.cov(1.).unwrap();
        abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
    }

    /// `ddof == n_observations` is the smallest invalid value (the denominator
    /// `n - ddof` would be zero), so it must trigger the `ddof` panic.
    /// `expected` pins the panic to that check rather than any unrelated panic.
    #[test]
    #[should_panic(expected = "`ddof` needs to be strictly smaller than the number of observations")]
    fn test_invalid_ddof() {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array2::<f64>::zeros((n_random_variables, n_observations));
        let invalid_ddof = n_observations as f64;
        let _ = a.cov(invalid_ddof);
    }

    /// Any `ddof` strictly above the number of observations must panic too.
    #[test]
    #[should_panic(expected = "`ddof` needs to be strictly smaller than the number of observations")]
    fn test_ddof_greater_than_n_observations() {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array2::<f64>::zeros((n_random_variables, n_observations));
        let invalid_ddof = n_observations as f64 + 0.5;
        let _ = a.cov(invalid_ddof);
    }

    #[test]
    fn test_covariance_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let cov = a.cov(1.);
        assert!(cov.is_ok());
        assert_eq!(cov.unwrap().shape(), &[0, 0]);
    }

    #[test]
    fn test_covariance_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        let cov = a.cov(-1.);
        assert_eq!(cov, Err(EmptyInput));
    }

    #[test]
    fn test_covariance_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        let cov = a.cov(-1.);
        assert_eq!(cov, Err(EmptyInput));
    }

    #[test]
    fn test_covariance_for_random_array() {
        let a = array![
            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
        ];
        let numpy_covariance = array![
            [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
            [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
            [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
            [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
            [
                -0.06443992,
                -0.06715555,
                -0.06129912,
                -0.02355557,
                0.09909855
            ]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8);
    }

    /// Documents a known precision limitation: with entries spanning roughly
    /// eighteen orders of magnitude, the computed covariance does not match
    /// the exact value within `epsilon = 1e-24`, so the assertion is expected
    /// to fail (hence `#[should_panic]`).
    #[test]
    #[should_panic]
    fn test_covariance_for_badly_conditioned_array() {
        let a: Array2<f64> = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],];
        let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]];
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &expected_covariance, epsilon = 1e-24);
    }
}
282
#[cfg(test)]
mod pearson_correlation_tests {
    use super::*;
    use ndarray::array;
    use ndarray::Array;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use quickcheck_macros::quickcheck;

    #[quickcheck]
    fn output_matrix_is_symmetric(bound: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random(
            (n_random_variables, n_observations),
            Uniform::new(-bound.abs(), bound.abs()).unwrap(),
        );
        let pearson_correlation = a.pearson_correlation().unwrap();
        abs_diff_eq!(
            pearson_correlation.view(),
            pearson_correlation.t(),
            epsilon = 1e-8
        )
    }

    /// Constant rows have zero standard deviation, so every coefficient is
    /// `0 / 0`, i.e. `NaN`.
    #[quickcheck]
    fn constant_random_variables_have_nan_correlation(value: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::from_elem((n_random_variables, n_observations), value);
        a.pearson_correlation()
            .unwrap()
            .iter()
            .all(|x| x.is_nan())
    }

    #[test]
    fn test_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let pearson_correlation = a.pearson_correlation();
        assert_eq!(pearson_correlation, Err(EmptyInput))
    }

    #[test]
    fn test_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson, Err(EmptyInput));
    }

    #[test]
    fn test_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson, Err(EmptyInput));
    }

    #[test]
    fn test_for_random_array() {
        let a = array![
            [0.16351516, 0.56863268, 0.16924196, 0.72579120],
            [0.44342453, 0.19834387, 0.25411802, 0.62462382],
            [0.97162731, 0.29958849, 0.17338142, 0.80198342],
            [0.91727132, 0.79817799, 0.62237124, 0.38970998],
            [0.26979716, 0.20887228, 0.95454999, 0.96290785]
        ];
        let numpy_corrcoeff = array![
            [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648],
            [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398],
            [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776],
            [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213],
            [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(
            a.pearson_correlation().unwrap(),
            numpy_corrcoeff,
            epsilon = 1e-7
        );
    }
}