ndarray_stats/correlation.rs

1use crate::errors::EmptyInput;
2use ndarray::prelude::*;
3use ndarray::Data;
4use num_traits::{Float, FromPrimitive};
5
6/// Extension trait for `ArrayBase` providing functions
7/// to compute different correlation measures.
8pub trait CorrelationExt<A, S>
9where
10    S: Data<Elem = A>,
11{
12    /// Return the covariance matrix `C` for a 2-dimensional
13    /// array of observations `M`.
14    ///
15    /// Let `(r, o)` be the shape of `M`:
16    /// - `r` is the number of random variables;
17    /// - `o` is the number of observations we have collected
18    /// for each random variable.
19    ///
20    /// Every column in `M` is an experiment: a single observation for each
21    /// random variable.
22    /// Each row in `M` contains all the observations for a certain random variable.
23    ///
24    /// The parameter `ddof` specifies the "delta degrees of freedom". For
25    /// example, to calculate the population covariance, use `ddof = 0`, or to
26    /// calculate the sample covariance (unbiased estimate), use `ddof = 1`.
27    ///
28    /// The covariance of two random variables is defined as:
29    ///
30    /// ```text
31    ///                1       n
32    /// cov(X, Y) = ――――――――   ∑ (xᵢ - x̅)(yᵢ - y̅)
33    ///             n - ddof  i=1
34    /// ```
35    ///
36    /// where
37    ///
38    /// ```text
39    ///     1   n
40    /// x̅ = ―   ∑ xᵢ
41    ///     n  i=1
42    /// ```
43    /// and similarly for ̅y.
44    ///
45    /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
46    ///
47    /// **Panics** if `ddof` is negative or greater than or equal to the number of
48    /// observations, or if the type cast of `n_observations` from `usize` to `A` fails.
49    ///
50    /// # Example
51    ///
52    /// ```
53    /// use ndarray::{aview2, arr2};
54    /// use ndarray_stats::CorrelationExt;
55    ///
56    /// let a = arr2(&[[1., 3., 5.],
57    ///                [2., 4., 6.]]);
58    /// let covariance = a.cov(1.).unwrap();
59    /// assert_eq!(
60    ///    covariance,
61    ///    aview2(&[[4., 4.], [4., 4.]])
62    /// );
63    /// ```
64    fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
65    where
66        A: Float + FromPrimitive;
67
68    /// Return the [Pearson correlation coefficients](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
69    /// for a 2-dimensional array of observations `M`.
70    ///
71    /// Let `(r, o)` be the shape of `M`:
72    /// - `r` is the number of random variables;
73    /// - `o` is the number of observations we have collected
74    /// for each random variable.
75    ///
76    /// Every column in `M` is an experiment: a single observation for each
77    /// random variable.
78    /// Each row in `M` contains all the observations for a certain random variable.
79    ///
80    /// The Pearson correlation coefficient of two random variables is defined as:
81    ///
82    /// ```text
83    ///              cov(X, Y)
84    /// rho(X, Y) = ――――――――――――
85    ///             std(X)std(Y)
86    /// ```
87    ///
88    /// Let `R` be the matrix returned by this function. Then
89    /// ```text
90    /// R_ij = rho(X_i, X_j)
91    /// ```
92    ///
93    /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`.
94    ///
95    /// **Panics** if the type cast of `n_observations` from `usize` to `A` fails or
96    /// if the standard deviation of one of the random variables is zero and
97    /// division by zero panics for type A.
98    ///
99    /// # Example
100    ///
101    /// ```
102    /// use approx;
103    /// use ndarray::arr2;
104    /// use ndarray_stats::CorrelationExt;
105    /// use approx::AbsDiffEq;
106    ///
107    /// let a = arr2(&[[1., 3., 5.],
108    ///                [2., 4., 6.]]);
109    /// let corr = a.pearson_correlation().unwrap();
110    /// let epsilon = 1e-7;
111    /// assert!(
112    ///     corr.abs_diff_eq(
113    ///         &arr2(&[
114    ///             [1., 1.],
115    ///             [1., 1.],
116    ///         ]),
117    ///         epsilon
118    ///     )
119    /// );
120    /// ```
121    fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
122    where
123        A: Float + FromPrimitive;
124
125    private_decl! {}
126}
127
128impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
129where
130    S: Data<Elem = A>,
131{
132    fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
133    where
134        A: Float + FromPrimitive,
135    {
136        let observation_axis = Axis(1);
137        let n_observations = A::from_usize(self.len_of(observation_axis)).unwrap();
138        let dof = if ddof >= n_observations {
139            panic!(
140                "`ddof` needs to be strictly smaller than the \
141                 number of observations provided for each \
142                 random variable!"
143            )
144        } else {
145            n_observations - ddof
146        };
147        let mean = self.mean_axis(observation_axis);
148        match mean {
149            Some(mean) => {
150                let denoised = self - &mean.insert_axis(observation_axis);
151                let covariance = denoised.dot(&denoised.t());
152                Ok(covariance.mapv_into(|x| x / dof))
153            }
154            None => Err(EmptyInput),
155        }
156    }
157
158    fn pearson_correlation(&self) -> Result<Array2<A>, EmptyInput>
159    where
160        A: Float + FromPrimitive,
161    {
162        match self.dim() {
163            (n, m) if n > 0 && m > 0 => {
164                let observation_axis = Axis(1);
165                // The ddof value doesn't matter, as long as we use the same one
166                // for computing covariance and standard deviation
167                // We choose 0 as it is the smallest number admitted by std_axis
168                let ddof = A::zero();
169                let cov = self.cov(ddof).unwrap();
170                let std = self
171                    .std_axis(observation_axis, ddof)
172                    .insert_axis(observation_axis);
173                let std_matrix = std.dot(&std.t());
174                // element-wise division
175                Ok(cov / std_matrix)
176            }
177            _ => Err(EmptyInput),
178        }
179    }
180
181    private_impl! {}
182}
183
#[cfg(test)]
mod cov_tests {
    use super::*;
    use ndarray::array;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use quickcheck_macros::quickcheck;

    #[quickcheck]
    fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::from_elem((n_random_variables, n_observations), value);
        abs_diff_eq!(
            a.cov(1.).unwrap(),
            &Array::zeros((n_random_variables, n_random_variables)),
            epsilon = 1e-8,
        )
    }

    #[quickcheck]
    fn covariance_matrix_is_symmetric(bound: f64) -> bool {
        let bound = bound.abs();
        // `Uniform::new` panics unless `low < high` with a finite range, so a zero
        // or non-finite bound would crash the property instead of exercising it.
        if !bound.is_finite() || bound <= 0. {
            return true;
        }
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random(
            (n_random_variables, n_observations),
            Uniform::new(-bound, bound),
        );
        let covariance = a.cov(1.).unwrap();
        abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
    }

    #[test]
    #[should_panic]
    fn test_invalid_ddof() {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
        // `ddof == n_observations` is the smallest invalid value (boundary case).
        let invalid_ddof = n_observations as f64;
        let _ = a.cov(invalid_ddof);
    }

    #[test]
    fn test_covariance_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let cov = a.cov(1.);
        assert!(cov.is_ok());
        assert_eq!(cov.unwrap().shape(), &[0, 0]);
    }

    #[test]
    fn test_covariance_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        // A negative ddof (-1 < 0) can never trip the invalid-ddof panic.
        let cov = a.cov(-1.);
        assert_eq!(cov, Err(EmptyInput));
    }

    #[test]
    fn test_covariance_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        // A negative ddof (-1 < 0) can never trip the invalid-ddof panic.
        let cov = a.cov(-1.);
        assert_eq!(cov, Err(EmptyInput));
    }

    #[test]
    fn test_covariance_for_random_array() {
        let a = array![
            [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457],
            [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245],
            [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036],
            [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539],
            [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258]
        ];
        let numpy_covariance = array![
            [0.05786248, 0.02614063, 0.06446215, 0.01285105, -0.06443992],
            [0.02614063, 0.08733569, 0.02436933, 0.01977437, -0.06715555],
            [0.06446215, 0.02436933, 0.10052129, 0.01393589, -0.06129912],
            [0.01285105, 0.01977437, 0.01393589, 0.00638795, -0.02355557],
            [
                -0.06443992,
                -0.06715555,
                -0.06129912,
                -0.02355557,
                0.09909855
            ]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8);
    }

    #[test]
    #[should_panic]
    // We lose precision, hence the failing assert
    fn test_covariance_for_badly_conditioned_array() {
        let a: Array2<f64> = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],];
        let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]];
        assert_abs_diff_eq!(a.cov(1.).unwrap(), &expected_covariance, epsilon = 1e-24);
    }
}
286
#[cfg(test)]
mod pearson_correlation_tests {
    use super::*;
    use ndarray::array;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use quickcheck_macros::quickcheck;

    #[quickcheck]
    fn output_matrix_is_symmetric(bound: f64) -> bool {
        let bound = bound.abs();
        // `Uniform::new` panics unless `low < high` with a finite range, so a zero
        // or non-finite bound would crash the property instead of exercising it.
        if !bound.is_finite() || bound <= 0. {
            return true;
        }
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::random(
            (n_random_variables, n_observations),
            Uniform::new(-bound, bound),
        );
        let pearson_correlation = a.pearson_correlation().unwrap();
        abs_diff_eq!(
            pearson_correlation.view(),
            pearson_correlation.t(),
            epsilon = 1e-8
        )
    }

    #[quickcheck]
    fn constant_random_variables_have_nan_correlation(value: f64) -> bool {
        let n_random_variables = 3;
        let n_observations = 4;
        let a = Array::from_elem((n_random_variables, n_observations), value);
        a.pearson_correlation()
            .unwrap()
            .iter()
            .all(|x| x.is_nan())
    }

    #[test]
    fn test_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let pearson_correlation = a.pearson_correlation();
        assert_eq!(pearson_correlation, Err(EmptyInput))
    }

    #[test]
    fn test_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson, Err(EmptyInput));
    }

    #[test]
    fn test_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson, Err(EmptyInput));
    }

    #[test]
    fn test_for_random_array() {
        let a = array![
            [0.16351516, 0.56863268, 0.16924196, 0.72579120],
            [0.44342453, 0.19834387, 0.25411802, 0.62462382],
            [0.97162731, 0.29958849, 0.17338142, 0.80198342],
            [0.91727132, 0.79817799, 0.62237124, 0.38970998],
            [0.26979716, 0.20887228, 0.95454999, 0.96290785]
        ];
        let numpy_corrcoeff = array![
            [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648],
            [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398],
            [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776],
            [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213],
            [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
        ];
        assert_eq!(a.ndim(), 2);
        assert_abs_diff_eq!(
            a.pearson_correlation().unwrap(),
            numpy_corrcoeff,
            epsilon = 1e-7
        );
    }
}