kendalls/
lib.rs

1//! [Kendall's tau rank correlation](https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient).
2//! Initially the library was based on
3//! [Apache Commons Math](http://commons.apache.org/proper/commons-math/) library with some
4//! additions taken from [scipy](https://github.com/scipy/scipy)
5//! and R [cor.test](https://github.com/SurajGupta/r-source/blob/master/src/library/stats/R/cor.test.R) function.
6//!
7//! Example usage:
8//! ```
9//! let (tau_b, significance) = kendalls::tau_b(&[1, 2, 3], &[3, 4, 5]).unwrap();
10//! assert_eq!(tau_b, 1.0);
11//! assert_eq!(significance, 1.5666989036012806);
12//! ```
//! To compute the correlation for element types that do not implement `Ord`, such as `f64`,
//! either provide a custom comparator function (see [`tau_b_with_comparator`]) or use a
//! floating point type that implements `Ord` (see the [float](https://crates.io/crates/float) crate).
16//!
17//! ```
18//! use std::cmp::Ordering;
19//!
20//! let (tau_b, _significance) = kendalls::tau_b_with_comparator(
21//!     &[1.0, 2.0],
22//!     &[3.0, 4.0],
23//!     |a: &f64, b: &f64| a.partial_cmp(&b).unwrap_or(Ordering::Greater),
24//! ).unwrap();
25//! assert_eq!(tau_b, 1.0);
26//! ```
27//!
//! Both functions return an error if `x` and `y` have different lengths or are empty.
//! The second value of the returned tuple is the standardized statistic `z = S / sqrt(Var(S))`
//! of `S = P - Q`, not a p-value.
30use std::cmp::Ordering;
31use std::error::Error as StdError;
32use std::fmt::{Display, Error as FmtError, Formatter};
33use std::result::Result;
34
/// Errors returned by [`tau_b`] and [`tau_b_with_comparator`].
// `Clone`, `Copy` and `Eq` are free for this small enum and let callers store, copy and
// compare errors without workarounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// `x` and `y` have different lengths; `expected` is the length of `x` and `got` the
    /// length of `y`.
    DimensionMismatch { expected: usize, got: usize },
    /// The input arrays are empty.
    InsufficientLength,
}

impl Display for Error {
    fn fmt(&self, f: &mut Formatter) -> Result<(), FmtError> {
        match self {
            Error::InsufficientLength => write!(f, "insufficient array length"),
            Error::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: {expected} != {got}")
            }
        }
    }
}

/// The error carries no underlying cause, so the default `source` (`None`) is correct.
impl StdError for Error {}
53
54/// Implementation of Kendall's Tau-b rank correlation between two arrays.
55///
56/// The definition of Kendall’s tau that is used is:
57///
58/// `tau = (P - Q) / sqrt((P + Q + T) * (P + Q + U))`
59///
60/// where P is the number of concordant pairs, Q the number of discordant pairs, T the number of
61/// ties only in x, and U the number of ties only in y. If a tie occurs for the same pair in
62/// both x and y, it is not added to either T or U.
63pub fn tau_b<T>(x: &[T], y: &[T]) -> Result<(f64, f64), Error>
64where
65    T: Ord + Clone + Default,
66{
67    tau_b_with_comparator(x, y, |a, b| a.cmp(b))
68}
69
70/// The same as `tau_b` but also allow to specify custom comparator for numbers for
71/// which [Ord] trait is not defined.
72#[allow(clippy::many_single_char_names)]
73pub fn tau_b_with_comparator<T, F>(x: &[T], y: &[T], mut comparator: F) -> Result<(f64, f64), Error>
74where
75    T: PartialOrd + Clone + Default,
76    F: FnMut(&T, &T) -> Ordering,
77{
78    if x.len() != y.len() {
79        return Err(Error::DimensionMismatch {
80            expected: x.len(),
81            got: y.len(),
82        });
83    }
84
85    if x.is_empty() {
86        return Err(Error::InsufficientLength);
87    }
88
89    let n = x.len();
90
91    let mut pairs: Vec<(T, T)> = x.iter().cloned().zip(y.iter().cloned()).collect();
92
93    pairs.sort_unstable_by(|pair1, pair2| {
94        let res = comparator(&pair1.0, &pair2.0);
95        if res == Ordering::Equal {
96            comparator(&pair1.1, &pair2.1)
97        } else {
98            res
99        }
100    });
101
102    let mut v1_part_1 = 0usize;
103    let mut v2_part_1 = 0isize;
104
105    let mut tied_x_pairs = 0usize;
106    let mut tied_xy_pairs = 0usize;
107    let mut vt = 0usize;
108    let mut consecutive_x_ties = 1usize;
109    let mut consecutive_xy_ties = 1usize;
110
111    for i in 1..n {
112        let prev = &pairs[i - 1];
113        let curr = &pairs[i];
114        if curr.0 == prev.0 {
115            consecutive_x_ties += 1;
116            if curr.1 == prev.1 {
117                consecutive_xy_ties += 1;
118            } else {
119                tied_xy_pairs += sum(consecutive_xy_ties - 1);
120                consecutive_xy_ties = 1;
121            }
122        } else {
123            update_x_group(
124                &mut vt,
125                &mut tied_x_pairs,
126                &mut tied_xy_pairs,
127                &mut v1_part_1,
128                &mut v2_part_1,
129                consecutive_x_ties,
130                consecutive_xy_ties,
131            );
132            consecutive_x_ties = 1;
133            consecutive_xy_ties = 1;
134        }
135    }
136
137    update_x_group(
138        &mut vt,
139        &mut tied_x_pairs,
140        &mut tied_xy_pairs,
141        &mut v1_part_1,
142        &mut v2_part_1,
143        consecutive_x_ties,
144        consecutive_xy_ties,
145    );
146
147    let mut swaps = 0usize;
148    let mut pairs_dest: Vec<(T, T)> = vec![(Default::default(), Default::default()); n];
149
150    let mut segment_size = 1usize;
151    while segment_size < n {
152        for offset in (0..n).step_by(2 * segment_size) {
153            let mut i = offset;
154            let i_end = n.min(i + segment_size);
155            let mut j = i_end;
156            let j_end = n.min(j + segment_size);
157            let mut copy_location = offset;
158
159            while i < i_end && j < j_end {
160                let a = &pairs[i].1;
161                let b = &pairs[j].1;
162
163                if a.partial_cmp(b).unwrap_or(Ordering::Greater) == Ordering::Greater {
164                    pairs_dest[copy_location] = pairs[j].clone();
165                    j += 1;
166                    swaps += i_end - i;
167                } else {
168                    pairs_dest[copy_location] = pairs[i].clone();
169                    i += 1;
170                }
171
172                copy_location += 1;
173            }
174
175            while i < i_end {
176                pairs_dest[copy_location] = pairs[i].clone();
177                i += 1;
178                copy_location += 1
179            }
180
181            while j < j_end {
182                pairs_dest[copy_location] = pairs[j].clone();
183                j += 1;
184                copy_location += 1
185            }
186        }
187        std::mem::swap(&mut pairs, &mut pairs_dest);
188
189        segment_size <<= 1;
190    }
191
192    let mut v1_part_2 = 0usize;
193    let mut v2_part_2 = 0isize;
194    let mut tied_y_pairs = 0usize;
195    let mut consecutive_y_ties = 1usize;
196    let mut vu = 0usize;
197
198    for j in 1..n {
199        let prev = &pairs[j - 1];
200        let curr = &pairs[j];
201        if curr.1 == prev.1 {
202            consecutive_y_ties += 1;
203        } else {
204            update_y_group(
205                &mut vu,
206                &mut tied_y_pairs,
207                &mut v1_part_2,
208                &mut v2_part_2,
209                consecutive_y_ties,
210            );
211            consecutive_y_ties = 1;
212        }
213    }
214
215    update_y_group(
216        &mut vu,
217        &mut tied_y_pairs,
218        &mut v1_part_2,
219        &mut v2_part_2,
220        consecutive_y_ties,
221    );
222
223    // Generates T1 and T2 for significance
224    let v1 = (v1_part_1 * v1_part_2) as f64;
225    let v2 = (v2_part_1 * v2_part_2) as f64;
226
227    // Prevents overflow on subtraction
228    let num_pairs_f: f64 = ((n * (n - 1)) as f64) / 2.0; // sum(n - 1).as_();
229    let tied_x_pairs_f: f64 = tied_x_pairs as f64;
230    let tied_y_pairs_f: f64 = tied_y_pairs as f64;
231    let tied_xy_pairs_f: f64 = tied_xy_pairs as f64;
232    let swaps_f: f64 = (2 * swaps) as f64;
233
234    // Note that tot = con + dis + (xtie - ntie) + (ytie - ntie) + ntie
235    //               = con + dis + xtie + ytie - ntie
236    //
237    //           C-D = tot - xtie - ytie + ntie - 2 * dis
238    let concordant_minus_discordant =
239        num_pairs_f - tied_x_pairs_f - tied_y_pairs_f + tied_xy_pairs_f - swaps_f;
240
241    // non_tied_pairs_multiplied = ((n0 - n1) * (n0 - n2)).sqrt()
242    let non_tied_pairs_multiplied = (num_pairs_f - tied_x_pairs_f) * (num_pairs_f - tied_y_pairs_f);
243
244    let tau_b = concordant_minus_discordant / non_tied_pairs_multiplied.sqrt();
245
246    // Significance
247    let v0 = (n * (n - 1)) * (2 * n + 5);
248    let n_f = n as f64;
249
250    let v0_isize = v0 as isize;
251    let vt_isize = vt as isize;
252    let vu_isize = vu as isize;
253    let var_s = (v0_isize - vt_isize - vu_isize) as f64 / 18.0
254        + v1 / (2.0 * n_f * (n_f - 1.0))
255        + v2 / (9.0 * n_f * (n_f - 1.0) * (n_f - 2.0));
256
257    let s = tau_b * non_tied_pairs_multiplied.sqrt();
258    let z = s / var_s.sqrt();
259
260    // Limit range to fix computational errors
261    Ok((tau_b.clamp(-1.0, 1.0), z))
262}
263
264#[inline]
265fn sum(n: usize) -> usize {
266    n * (n + 1_usize) / 2_usize
267}
268
269/// Updated vt, v1_part_1, v2_part_1, tied_x_pairs, tied_xy_pairs variables with current tied group in X
270fn update_x_group(
271    vt: &mut usize,
272    tied_x_pairs: &mut usize,
273    tied_xy_pairs: &mut usize,
274    v1_part_1: &mut usize,
275    v2_part_1: &mut isize,
276    consecutive_x_ties: usize,
277    consecutive_xy_ties: usize,
278) {
279    *vt += consecutive_x_ties * (consecutive_x_ties - 1) * (2 * consecutive_x_ties + 5);
280    *v1_part_1 += consecutive_x_ties * (consecutive_x_ties - 1);
281
282    let consecutive_x_ties_i = consecutive_x_ties as isize;
283    *v2_part_1 += consecutive_x_ties_i * (consecutive_x_ties_i - 1) * (consecutive_x_ties_i - 2);
284
285    *tied_x_pairs += sum(consecutive_x_ties - 1);
286    *tied_xy_pairs += sum(consecutive_xy_ties - 1);
287}
288
289/// Updated vu, tied_y_pairs, v1_part_2 and v2_part_2 variables with current tied group in Y
290fn update_y_group(
291    vu: &mut usize,
292    tied_y_pairs: &mut usize,
293    v1_part_2: &mut usize,
294    v2_part_2: &mut isize,
295    consecutive_y_ties: usize,
296) {
297    *vu += consecutive_y_ties * (consecutive_y_ties - 1) * (2 * consecutive_y_ties + 5);
298    *v1_part_2 += consecutive_y_ties * (consecutive_y_ties - 1);
299
300    let consecutive_y_ties_i = consecutive_y_ties as isize;
301    *v2_part_2 += consecutive_y_ties_i * (consecutive_y_ties_i - 1) * (consecutive_y_ties_i - 2);
302
303    *tied_y_pairs += sum(consecutive_y_ties - 1);
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use float_cmp::assert_approx_eq;

    /// Comparator for `f64` shared by these tests; incomparable values (`NaN`) are reported
    /// as `Ordering::Greater`.
    fn f64_cmp(a: &f64, b: &f64) -> Ordering {
        a.partial_cmp(b).unwrap_or(Ordering::Greater)
    }

    #[test]
    fn xy_consecutive_pair_test() {
        let x = [
            12.0, 14.0, 14.0, 17.0, 19.0, 19.0, 19.0, 19.0, 19.0, 20.0, 21.0, 21.0, 21.0, 21.0,
            21.0, 22.0, 23.0, 24.0, 24.0, 24.0, 26.0, 26.0, 27.0,
        ];
        let y = [
            11.0, 4.0, 4.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0,
            4.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];

        let (tau, z) = tau_b_with_comparator(&x, &y, f64_cmp).unwrap();

        assert_approx_eq!(f64, tau, -0.3762015410475098);
        assert_approx_eq!(f64, z, -2.09764910068664);
    }

    #[test]
    fn shifted_test() {
        let x = [1.0, 1.0, 2.0, 2.0, 3.0, 3.0];
        let y = [1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let (tau, z) = tau_b_with_comparator(&x, &y, f64_cmp).unwrap();
        assert_approx_eq!(f64, tau, 0.8006407690254358);
        assert_approx_eq!(f64, z, 2.0526, epsilon = 0.0001);

        let x = [12.0, 2.0, 1.0, 12.0, 2.0];
        let y = [1.0, 4.0, 7.0, 1.0, 0.0];
        let (tau, z) = tau_b_with_comparator(&x, &y, f64_cmp).unwrap();
        assert_approx_eq!(f64, tau, -0.4714045207910316);
        assert_approx_eq!(f64, z, -1.0742, epsilon = 0.0001);
    }

    #[test]
    fn simple_correlated_data() {
        let (tau, z) = tau_b(&[1, 2, 3], &[3, 4, 5]).unwrap();
        assert_eq!(tau, 1.0);
        assert_approx_eq!(f64, z, 1.5666989036012806);
    }

    #[test]
    fn simple_correlated_reversed() {
        let (tau, z) = tau_b(&[1, 2, 3], &[5, 4, 3]).unwrap();
        assert_eq!(tau, -1.0);
        assert_approx_eq!(f64, z, -1.5666989036012806);
    }

    #[test]
    fn simple_jumble() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [1.0, 3.0, 2.0, 4.0];

        // 6 pairs: (A,B) (A,C) (A,D) (B,C) (B,D) (C,D)
        // (B,C) is discordant, the other 5 are concordant
        let (tau, z) = tau_b_with_comparator(&x, &y, f64_cmp).unwrap();

        // Computed floats are compared approximately, not bit-for-bit.
        assert_approx_eq!(f64, tau, (5.0 - 1.0) / 6.0);
        assert_approx_eq!(f64, z, 1.3587324409735149);
    }

    #[test]
    fn balanced_jumble() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [1.0, 4.0, 3.0, 2.0];

        // 6 pairs: (A,B) (A,C) (A,D) (B,C) (B,D) (C,D)
        // (A,B) (A,C), (A,D) are concordant, the other 3 are discordant, so P - Q is exactly 0.
        assert_eq!(tau_b_with_comparator(&x, &y, f64_cmp), Ok((0.0, 0.0)));
    }

    #[test]
    fn fails_if_dimensions_do_not_match() {
        let res = tau_b(&[1, 2, 3], &[5, 4]);
        assert_eq!(
            res,
            Err(Error::DimensionMismatch {
                expected: 3,
                got: 2
            })
        );
    }

    #[test]
    fn fails_if_arrays_are_empty() {
        let res = tau_b::<i32>(&[], &[]);
        assert_eq!(res, Err(Error::InsufficientLength));
    }

    #[test]
    fn it_format_dimension_mismatch_error() {
        let error = Error::DimensionMismatch {
            expected: 2,
            got: 1,
        };
        assert_eq!("dimension mismatch: 2 != 1", error.to_string());
    }

    #[test]
    fn it_format_insufficient_length_error() {
        let error = Error::InsufficientLength;
        assert_eq!("insufficient array length", error.to_string());
    }

    /// Regression test: the tie-corrected variance has a negative intermediate term for this
    /// data, which used to panic on unsigned subtraction. A panic fails the test directly.
    #[test]
    fn test_subtract_with_overflow() {
        let x = [
            -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, 6.8901,
        ];
        let y = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let (tau, significance) = tau_b_with_comparator(&x, &y, f64_cmp).unwrap();
        assert!(tau.is_finite());
        assert!(significance.is_finite());
    }
}