1use crate::errors::EmptyInput;
2use ndarray::prelude::*;
3use ndarray::Data;
4use num_traits::{Float, FromPrimitive};
5
/// Extension trait adding covariance and Pearson correlation matrices to
/// array types whose elements are of type `A`.
///
/// In the implementation for `ArrayBase<S, Ix2>` found in this file, every
/// row of the array is one random variable and every column is one
/// observation (observations run along `Axis(1)`).
pub trait CorrelationExt<A, S>
where
    S: Data<Elem = A>,
{
    /// Returns the covariance matrix of the random variables stored in `self`.
    ///
    /// For an input with `M` rows and `N` columns the implementation in this
    /// file subtracts each row's mean from that row, multiplies the centered
    /// matrix by its own transpose and divides every entry by `N - ddof`,
    /// producing an `M × M` matrix:
    ///
    /// ```text
    /// cov[i, j] = sum_k (x[i, k] - mean_i) * (x[j, k] - mean_j) / (N - ddof)
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `Err(EmptyInput)` when the mean along the observation axis
    /// cannot be computed; the tests in this file exercise this with arrays
    /// that have zero observations.
    ///
    /// # Panics
    ///
    /// The implementation panics when `ddof` is greater than or equal to the
    /// number of observations, and when the number of observations cannot be
    /// converted to `A`.
    fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
    where
        A: Float + FromPrimitive;

    /// Returns the matrix of Pearson correlation coefficients of the random
    /// variables stored in `self`.
    ///
    /// The implementation in this file divides the covariance matrix
    /// element-wise by the outer product of the per-variable standard
    /// deviations, using `ddof = 0` for both quantities. For an input with
    /// `M` rows the result is therefore `M × M`. The tests in this file
    /// check that constant variables produce `NaN` entries.
    ///
    /// # Errors
    ///
    /// Returns `Err(EmptyInput)` when the array has zero rows or zero
    /// columns.
    fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
    where
        A: Float + FromPrimitive;

    // NOTE(review): `private_decl!` is defined outside this file; presumably
    // it adds a hidden item that prevents downstream crates from implementing
    // this trait — confirm against the macro definition.
    private_decl! {}
}
127
128impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
129where
130 S: Data<Elem = A>,
131{
132 fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
133 where
134 A: Float + FromPrimitive,
135 {
136 let observation_axis = Axis(1);
137 let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap();
138 let dof = if ddof >= n_observations {
139 panic!(
140 "`ddof` needs to be strictly smaller than the \
141 number of observations provided for each \
142 random variable!"
143 )
144 } else {
145 n_observations - ddof
146 };
147 let mean = self.mean_axis(observation_axis);
148 match mean {
149 Some(mean) => {
150 let denoised = self - &mean.insert_axis(observation_axis);
151 let covariance = denoised.dot(&denoised.t());
152 Ok(covariance.mapv_into(|x| x / dof))
153 }
154 None => Err(EmptyInput),
155 }
156 }
157
158 fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
159 where
160 A: Float + FromPrimitive,
161 {
162 match self.dim() {
163 (n, m) if n > 0 && m > 0 => {
164 let observation_axis = Axis(1);
165 let ddof = A::zero();
169 let cov = self.cov(ddof).unwrap();
170 let std = self
171 .std_axis(observation_axis, ddof)
172 .insert_axis(observation_axis);
173 let std_matrix = std.dot(&std.t());
174 Ok(cov / std_matrix)
176 }
177 _ => Err(EmptyInput),
178 }
179 }
180
181 private_impl! {}
182}
183
/// Tests for `CorrelationExt::cov`.
#[cfg(test)]
mod cov_tests {
    use super::*;
    use ndarray::array;
    use ndarray_rand::rand;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use quickcheck_macros::quickcheck;

    // Property: when every entry of the input equals `value`, the covariance
    // matrix is (approximately) all zeros.
    #[quickcheck]
    fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::from_elem((n_random_variables, n_observations), value);
        abs_diff_eq!(
            a.cov(1.).unwrap(),
            &Array::zeros((n_random_variables, n_random_variables)),
            epsilon = 1e-8,
        )
    }

    // Property: the covariance matrix equals its own transpose for inputs
    // sampled uniformly from `[-|bound|, |bound|)`.
    #[quickcheck]
    fn covariance_matrix_is_symmetric(bound: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random(
            (n_random_variables, n_observations),
            Uniform::new(-bound.abs(), bound.abs()),
        );
        let covariance = a.cov(1.).unwrap();
        abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
    }

    // `cov` must panic when `ddof` is not strictly smaller than the number of
    // observations.
    #[test]
    #[should_panic]
    fn test_invalid_ddof() {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
        // Always at least `n_observations`, so the `ddof` check must trigger.
        let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
        let _ = a.cov(invalid_ddof);
    }

    // Zero random variables with a non-zero number of observations is not an
    // error: the result is an empty 0x0 matrix.
    #[test]
    fn test_covariance_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let cov = a.cov(1.);
        assert!(cov.is_ok());
        assert_eq!(cov.unwrap().shape(), &[0, 0]);
    }

    // Zero observations must be reported as `EmptyInput`.
    #[test]
    fn test_covariance_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        // Negative ddof (-1 < 0 observations) so that the `ddof` check in
        // `cov` does not panic before the empty input is detected.
        let cov = a.cov(-1.);
        assert_eq!(cov, Err(EmptyInput));
    }

    // An array with neither variables nor observations is also `EmptyInput`.
    #[test]
    fn test_covariance_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        // Negative ddof (-1 < 0 observations) so that the `ddof` check in
        // `cov` does not panic before the empty input is detected.
        let cov = a.cov(-1.);
        assert_eq!(cov, Err(EmptyInput));
    }

    // Compares `cov` with `ddof = 1` against a fixed reference matrix.
    // NOTE(review): judging by the variable name, the reference values were
    // presumably produced with NumPy — confirm the exact call used.
    #[test]
    fn test_covariance_for_random_array() {
        let a = array![
            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
        ];
        let numpy_covariance = array![
            [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
            [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
            [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
            [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
            [
                -0.06443992,
                -0.06715555,
                -0.06129912,
                -0.02355557,
                0.09909855
            ]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8);
    }

    // This test is expected to panic, i.e. the comparison below is expected
    // to fail for this input.
    // NOTE(review): presumably the failure comes from floating-point
    // precision loss on values of very different magnitude — confirm.
    #[test]
    #[should_panic]
    fn test_covariance_for_badly_conditioned_array() {
        let a: Array2<f64> = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],];
        let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]];
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &expected_covariance, epsilon = 1e-24);
    }
}
286
/// Tests for `CorrelationExt::pearson_correlation`.
#[cfg(test)]
mod pearson_correlation_tests {
    use super::*;
    use ndarray::array;
    use ndarray::Array;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use quickcheck_macros::quickcheck;

    // Property: the correlation matrix equals its own transpose for inputs
    // sampled uniformly from `[-|bound|, |bound|)`.
    #[quickcheck]
    fn output_matrix_is_symmetric(bound: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random(
            (n_random_variables, n_observations),
            Uniform::new(-bound.abs(), bound.abs()),
        );
        let pearson_correlation = a.pearson_correlation().unwrap();
        abs_diff_eq!(
            pearson_correlation.view(),
            pearson_correlation.t(),
            epsilon = 1e-8
        )
    }

    // Property: when every entry of the input equals `value`, every entry of
    // the correlation matrix is NaN.
    #[quickcheck]
    fn constant_random_variables_have_nan_correlation(value: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::from_elem((n_random_variables, n_observations), value);
        let pearson_correlation = a.pearson_correlation();
        // `all` replaces the former `map(..).fold(true, ..)`; the predicate
        // has no side effects, so short-circuiting does not change the result.
        pearson_correlation.unwrap().iter().all(|x| x.is_nan())
    }

    // Zero random variables must be reported as `EmptyInput`.
    #[test]
    fn test_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let pearson_correlation = a.pearson_correlation();
        assert_eq!(pearson_correlation, Err(EmptyInput));
    }

    // Zero observations must be reported as `EmptyInput`.
    #[test]
    fn test_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson, Err(EmptyInput));
    }

    // An array with neither variables nor observations is also `EmptyInput`.
    #[test]
    fn test_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson, Err(EmptyInput));
    }

    // Compares the correlation matrix against a fixed reference matrix.
    // NOTE(review): judging by the variable name, the reference values were
    // presumably produced with NumPy — confirm the exact call used.
    #[test]
    fn test_for_random_array() {
        let a = array![
            [0.16351516, 0.56863268, 0.16924196, 0.72579120],
            [0.44342453, 0.19834387, 0.25411802, 0.62462382],
            [0.97162731, 0.29958849, 0.17338142, 0.80198342],
            [0.91727132, 0.79817799, 0.62237124, 0.38970998],
            [0.26979716, 0.20887228, 0.95454999, 0.96290785]
        ];
        let numpy_corrcoeff = array![
            [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648],
            [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398],
            [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776],
            [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213],
            [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(
            a.pearson_correlation().unwrap(),
            numpy_corrcoeff,
            epsilon = 1e-7
        );
    }
}