ndarray_stats/
correlation.rs

1use crate::errors::EmptyInput;
2use ndarray::prelude::*;
3use num_traits::{Float, FromPrimitive};
4
5/// Extension trait for `ndarray` providing functions
6/// to compute different correlation measures.
7pub trait CorrelationExt<A> {
8    /// Return the covariance matrix `C` for a 2-dimensional
9    /// array of observations `M`.
10    ///
11    /// Let `(r, o)` be the shape of `M`:
12    /// - `r` is the number of random variables;
13    /// - `o` is the number of observations we have collected
14    /// for each random variable.
15    ///
16    /// Every column in `M` is an experiment: a single observation for each
17    /// random variable.
18    /// Each row in `M` contains all the observations for a certain random variable.
19    ///
20    /// The parameter `ddof` specifies the "delta degrees of freedom". For
21    /// example, to calculate the population covariance, use `ddof = 0`, or to
22    /// calculate the sample covariance (unbiased estimate), use `ddof = 1`.
23    ///
24    /// The covariance of two random variables is defined as:
25    ///
26    /// ```text
27    ///                1       n
28    /// cov(X, Y) = ――――――――   ∑ (xᵢ - x̅)(yᵢ - y̅)
29    ///             n - ddof  i=1
30    /// ```
31    ///
32    /// where
33    ///
34    /// ```text
35    ///     1   n
36    /// x̅ = ―   ∑ xᵢ
37    ///     n  i=1
38    /// ```
39    /// and similarly for ̅y.
40    ///
41    /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
42    ///
43    /// **Panics** if `ddof` is negative or greater than or equal to the number of
44    /// observations, or if the type cast of `n_observations` from `usize` to `A` fails.
45    ///
46    /// # Example
47    ///
48    /// ```
49    /// use ndarray::{aview2, arr2};
50    /// use ndarray_stats::CorrelationExt;
51    ///
52    /// let a = arr2(&[[1., 3., 5.],
53    ///                [2., 4., 6.]]);
54    /// let covariance = a.cov(1.).unwrap();
55    /// assert_eq!(
56    ///    covariance,
57    ///    aview2(&[[4., 4.], [4., 4.]])
58    /// );
59    /// ```
60    fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
61    where
62        A: Float + FromPrimitive;
63
64    /// Return the [Pearson correlation coefficients](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
65    /// for a 2-dimensional array of observations `M`.
66    ///
67    /// Let `(r, o)` be the shape of `M`:
68    /// - `r` is the number of random variables;
69    /// - `o` is the number of observations we have collected
70    /// for each random variable.
71    ///
72    /// Every column in `M` is an experiment: a single observation for each
73    /// random variable.
74    /// Each row in `M` contains all the observations for a certain random variable.
75    ///
76    /// The Pearson correlation coefficient of two random variables is defined as:
77    ///
78    /// ```text
79    ///              cov(X, Y)
80    /// rho(X, Y) = ――――――――――――
81    ///             std(X)std(Y)
82    /// ```
83    ///
84    /// Let `R` be the matrix returned by this function. Then
85    /// ```text
86    /// R_ij = rho(X_i, X_j)
87    /// ```
88    ///
89    /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
90    ///
91    /// **Panics** if the type cast of `n_observations` from `usize` to `A` fails or
92    /// if the standard deviation of one of the random variables is zero and
93    /// division by zero panics for type A.
94    ///
95    /// # Example
96    ///
97    /// ```
98    /// use approx;
99    /// use ndarray::arr2;
100    /// use ndarray_stats::CorrelationExt;
101    /// use approx::AbsDiffEq;
102    ///
103    /// let a = arr2(&[[1., 3., 5.],
104    ///                [2., 4., 6.]]);
105    /// let corr = a.pearson_correlation().unwrap();
106    /// let epsilon = 1e-7;
107    /// assert!(
108    ///     corr.abs_diff_eq(
109    ///         &arr2(&[
110    ///             [1., 1.],
111    ///             [1., 1.],
112    ///         ]),
113    ///         epsilon
114    ///     )
115    /// );
116    /// ```
117    fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
118    where
119        A: Float + FromPrimitive;
120
121    private_decl! {}
122}
123
124impl<A: 'static> CorrelationExt<A> for ArrayRef2<A> {
125    fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
126    where
127        A: Float + FromPrimitive,
128    {
129        let observation_axis = Axis(1);
130        let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap();
131        let dof = if ddof >= n_observations {
132            panic!(
133                "`ddof` needs to be strictly smaller than the \
134                 number of observations provided for each \
135                 random variable!"
136            )
137        } else {
138            n_observations - ddof
139        };
140        let mean = self.mean_axis(observation_axis);
141        match mean {
142            Some(mean) => {
143                let denoised = self - mean.insert_axis(observation_axis);
144                let covariance = denoised.dot(&denoised.t());
145                Ok(covariance.mapv_into(|x| x / dof))
146            }
147            None => Err(EmptyInput),
148        }
149    }
150
151    fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
152    where
153        A: Float + FromPrimitive,
154    {
155        match self.dim() {
156            (n, m) if n > 0 && m > 0 => {
157                let observation_axis = Axis(1);
158                // The ddof value doesn't matter, as long as we use the same one
159                // for computing covariance and standard deviation
160                // We choose 0 as it is the smallest number admitted by std_axis
161                let ddof = A::zero();
162                let cov = self.cov(ddof).unwrap();
163                let std = self
164                    .std_axis(observation_axis, ddof)
165                    .insert_axis(observation_axis);
166                let std_matrix = std.dot(&std.t());
167                // element-wise division
168                Ok(cov / std_matrix)
169            }
170            _ => Err(EmptyInput),
171        }
172    }
173
174    private_impl! {}
175}
176
#[cfg(test)]
mod cov_tests {
    use super::*;
    use ndarray::array;
    use ndarray_rand::{rand, rand_distr::Uniform, RandomExt};
    use quickcheck_macros::quickcheck;

    // Rows that never vary must yield an all-zero covariance matrix.
    #[quickcheck]
    fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
        let (n_variables, n_samples) = (3, 4);
        let constant = Array::from_elem((n_variables, n_samples), value);
        let covariance = constant.cov(1.).unwrap();
        abs_diff_eq!(
            covariance,
            &Array::zeros((n_variables, n_variables)),
            epsilon = 1e-8,
        )
    }

    // cov(X, Y) == cov(Y, X), so the matrix must equal its transpose.
    #[quickcheck]
    fn covariance_matrix_is_symmetric(bound: f64) -> bool {
        let (n_variables, n_samples) = (3, 4);
        let limit = bound.abs();
        let samples = Array::random(
            (n_variables, n_samples),
            Uniform::new(-limit, limit).unwrap(),
        );
        let covariance = samples.cov(1.).unwrap();
        abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
    }

    // A `ddof` at or above the number of observations must panic.
    #[test]
    #[should_panic]
    fn test_invalid_ddof() {
        let (n_variables, n_samples) = (3, 4);
        let samples = Array::random(
            (n_variables, n_samples),
            Uniform::new(0., 10.).unwrap(),
        );
        let excess = rand::random::<f64>().abs();
        let invalid_ddof = (n_samples as f64) + excess;
        let _ = samples.cov(invalid_ddof);
    }

    // Zero random variables with some observations gives an empty 0x0 matrix.
    #[test]
    fn test_covariance_zero_variables() {
        let no_variables = Array2::<f32>::zeros((0, 2));
        let outcome = no_variables.cov(1.);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().shape(), &[0, 0]);
    }

    // Zero observations is reported as `EmptyInput`.
    #[test]
    fn test_covariance_zero_observations() {
        let no_observations = Array2::<f32>::zeros((2, 0));
        // Negative ddof (-1 < 0) to avoid invalid-ddof panic
        let outcome = no_observations.cov(-1.);
        assert_eq!(outcome, Err(EmptyInput));
    }

    // A completely empty array is reported as `EmptyInput` as well.
    #[test]
    fn test_covariance_zero_variables_zero_observations() {
        let nothing = Array2::<f32>::zeros((0, 0));
        // Negative ddof (-1 < 0) to avoid invalid-ddof panic
        let outcome = nothing.cov(-1.);
        assert_eq!(outcome, Err(EmptyInput));
    }

    // Compares against a reference covariance matrix computed with numpy.
    #[test]
    fn test_covariance_for_random_array() {
        let a = array![
            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
        ];
        let numpy_covariance = array![
            [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
            [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
            [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
            [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
            [-0.06443992, -0.06715555, -0.06129912, -0.02355557, 0.09909855]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8);
    }

    // We lose precision on badly conditioned input, hence the failing assert.
    #[test]
    #[should_panic]
    fn test_covariance_for_badly_conditioned_array() {
        let a: Array2<f64> = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],];
        let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]];
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &expected_covariance, epsilon = 1e-24);
    }
}
282
#[cfg(test)]
mod pearson_correlation_tests {
    use super::*;
    use ndarray::{array, Array};
    use ndarray_rand::{rand_distr::Uniform, RandomExt};
    use quickcheck_macros::quickcheck;

    // rho(X, Y) == rho(Y, X), so the matrix must equal its transpose.
    #[quickcheck]
    fn output_matrix_is_symmetric(bound: f64) -> bool {
        let (n_variables, n_samples) = (3, 4);
        let limit = bound.abs();
        let samples = Array::random(
            (n_variables, n_samples),
            Uniform::new(-limit, limit).unwrap(),
        );
        let correlation = samples.pearson_correlation().unwrap();
        abs_diff_eq!(correlation.view(), correlation.t(), epsilon = 1e-8)
    }

    // Constant rows have zero standard deviation, so every coefficient is NaN.
    #[quickcheck]
    fn constant_random_variables_have_nan_correlation(value: f64) -> bool {
        let (n_variables, n_samples) = (3, 4);
        let constant = Array::from_elem((n_variables, n_samples), value);
        let correlation = constant.pearson_correlation().unwrap();
        correlation.iter().all(|x| x.is_nan())
    }

    // Zero random variables is reported as `EmptyInput`.
    #[test]
    fn test_zero_variables() {
        let no_variables = Array2::<f32>::zeros((0, 2));
        let outcome = no_variables.pearson_correlation();
        assert_eq!(outcome, Err(EmptyInput));
    }

    // Zero observations is reported as `EmptyInput`.
    #[test]
    fn test_zero_observations() {
        let no_observations = Array2::<f32>::zeros((2, 0));
        let outcome = no_observations.pearson_correlation();
        assert_eq!(outcome, Err(EmptyInput));
    }

    // A completely empty array is reported as `EmptyInput` as well.
    #[test]
    fn test_zero_variables_zero_observations() {
        let nothing = Array2::<f32>::zeros((0, 0));
        let outcome = nothing.pearson_correlation();
        assert_eq!(outcome, Err(EmptyInput));
    }

    // Compares against reference correlation coefficients computed with numpy.
    #[test]
    fn test_for_random_array() {
        let a = array![
            [0.16351516, 0.56863268, 0.16924196, 0.72579120],
            [0.44342453, 0.19834387, 0.25411802, 0.62462382],
            [0.97162731, 0.29958849, 0.17338142, 0.80198342],
            [0.91727132, 0.79817799, 0.62237124, 0.38970998],
            [0.26979716, 0.20887228, 0.95454999, 0.96290785]
        ];
        let numpy_corrcoeff = array![
            [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648],
            [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398],
            [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776],
            [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213],
            [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(
            a.pearson_correlation().unwrap(),
            numpy_corrcoeff,
            epsilon = 1e-7
        );
    }
}