1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
//! [Kendall's tau rank correlation](https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient).
//! At this point this is basically a copy-paste
//! from [Apache Commons Math](http://commons.apache.org/proper/commons-math/) library with some
//! additions taken from [scipy](https://github.com/scipy/scipy).
//!
//! Example usage:
//! ```
//! let res = kendalls::tau_b(&[1, 2, 3], &[3, 4, 5]);
//! assert_eq!(res, Ok(1.0));
//! ```
//! If you want to compute the correlation for a type that does not implement `Ord`, such as
//! `f64`, you have to either provide a custom comparator function or implement the `Ord` trait
//! for your own floating point wrapper type (see the [float](https://crates.io/crates/float) crate).
//!
//! ```
//! use std::cmp::Ordering;
//!
//! let res = kendalls::tau_b_with_comparator(
//!            &[1.0, 2.0],
//!            &[3.0, 4.0],
//!            |a: &f64, b: &f64| a.partial_cmp(&b).unwrap_or(Ordering::Greater),
//!        );
//! assert_eq!(res, Ok(1.0));
//! ```
//!
//! The function will return an error if you pass empty arrays into it or `x` and `y` arrays'
//! dimensions are not equal.
use std::cmp::Ordering;
use std::error::Error as StdError;
use std::fmt::{Display, Error as FmtError, Formatter};
use std::result::Result;

/// Reasons why a Kendall's tau computation rejects its input slices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The `x` and `y` slices have different lengths; `expected` is the length of `x` and
    /// `got` is the length of `y`.
    DimensionMismatch { expected: usize, got: usize },
    /// The input slices are empty.
    InsufficientLength,
}

impl Display for Error {
    /// Writes a short, human readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        match self {
            Error::InsufficientLength => f.write_str("insufficient array length"),
            Error::DimensionMismatch { expected, got } => {
                write!(f, "dimension mismatch: {} != {}", expected, got)
            }
        }
    }
}

impl StdError for Error {}

/// Computes Kendall's Tau-b rank correlation coefficient of two equally long slices whose
/// element type has a total order.
///
/// Tau-b is defined as
///
/// `tau = (P - Q) / sqrt((P + Q + T) * (P + Q + U))`
///
/// with `P` the number of concordant pairs, `Q` the number of discordant pairs, `T` the number
/// of pairs tied only in `x` and `U` the number of pairs tied only in `y`. A pair that is tied
/// in both `x` and `y` is counted in neither `T` nor `U`.
///
/// This is [`tau_b_with_comparator`] with the natural ordering of `T` as the comparator, so it
/// reports the same errors: a dimension mismatch when the slices differ in length and an
/// insufficient length when they are empty.
pub fn tau_b<T>(x: &[T], y: &[T]) -> Result<f64, Error>
where
    T: Ord + Clone + Default,
{
    tau_b_with_comparator(x, y, T::cmp)
}

/// Computes Kendall's Tau-b like [`tau_b`], but orders values with a caller supplied
/// `comparator`, so it also works for types without an [`Ord`] implementation (such as `f64`).
///
/// The comparator decides both the ordering and what counts as a tie: two values are tied
/// exactly when it returns [`Ordering::Equal`]. It is expected to describe a total order.
///
/// # Errors
///
/// * [`Error::DimensionMismatch`] if `x` and `y` differ in length.
/// * [`Error::InsufficientLength`] if the slices are empty.
///
/// # Returns
///
/// The coefficient, limited to `[-1.0, 1.0]`. When all observations are tied in `x` or all are
/// tied in `y` (which includes single element input) the coefficient is undefined and `NaN` is
/// returned.
pub fn tau_b_with_comparator<T, F>(x: &[T], y: &[T], mut comparator: F) -> Result<f64, Error>
where
    T: PartialOrd + Clone + Default,
    F: FnMut(&T, &T) -> Ordering,
{
    if x.len() != y.len() {
        return Err(Error::DimensionMismatch {
            expected: x.len(),
            got: y.len(),
        });
    }

    if x.is_empty() {
        return Err(Error::InsufficientLength);
    }

    let n = x.len();

    // Order the observations by x and, within equal x, by y. After this every run of tied x
    // values is contiguous and already ascending in y, so it contributes no inversions below.
    let mut pairs: Vec<(T, T)> = x.iter().cloned().zip(y.iter().cloned()).collect();
    pairs.sort_by(|a, b| comparator(&a.0, &b.0).then_with(|| comparator(&a.1, &b.1)));

    // Count the pairs tied in x and the pairs tied in both x and y by measuring the runs of
    // equal neighbours; a run of length k holds k * (k - 1) / 2 tied pairs.
    let mut tied_x_pairs = 0usize;
    let mut tied_xy_pairs = 0usize;
    let mut consecutive_x_ties = 1usize;
    let mut consecutive_xy_ties = 1usize;

    for window in pairs.windows(2) {
        let (prev, curr) = (&window[0], &window[1]);
        let same_x = comparator(&curr.0, &prev.0) == Ordering::Equal;

        if same_x {
            consecutive_x_ties += 1;
        } else {
            tied_x_pairs += sum(consecutive_x_ties - 1);
            consecutive_x_ties = 1;
        }

        if same_x && comparator(&curr.1, &prev.1) == Ordering::Equal {
            consecutive_xy_ties += 1;
        } else {
            // The run of identical (x, y) observations ends as soon as either coordinate
            // changes; close it and start a new one.
            tied_xy_pairs += sum(consecutive_xy_ties - 1);
            consecutive_xy_ties = 1;
        }
    }

    tied_x_pairs += sum(consecutive_x_ties - 1);
    tied_xy_pairs += sum(consecutive_xy_ties - 1);

    // Bottom-up merge sort by y that counts inversions. Each inversion is a pair ordered one
    // way by x and strictly the other way by y, i.e. a discordant pair.
    let mut swaps = 0usize;
    let mut pairs_dest: Vec<(T, T)> = vec![(T::default(), T::default()); n];

    let mut segment_size = 1usize;
    while segment_size < n {
        for offset in (0..n).step_by(2 * segment_size) {
            let mut i = offset;
            let i_end = n.min(i + segment_size);
            let mut j = i_end;
            let j_end = n.min(j + segment_size);

            for slot in pairs_dest[offset..j_end].iter_mut() {
                let take_left = if i < i_end && j < j_end {
                    // Equal y values are ties, not inversions, so the left element wins.
                    comparator(&pairs[i].1, &pairs[j].1) != Ordering::Greater
                } else {
                    i < i_end
                };

                if take_left {
                    *slot = pairs[i].clone();
                    i += 1;
                } else {
                    // Everything still waiting in the left run is strictly greater than the
                    // element taken from the right run (zero once the left run is exhausted).
                    swaps += i_end - i;
                    *slot = pairs[j].clone();
                    j += 1;
                }
            }
        }

        std::mem::swap(&mut pairs, &mut pairs_dest);

        segment_size <<= 1;
    }

    // `pairs` is now ordered by y, so tied y values are adjacent.
    let mut tied_y_pairs = 0usize;
    let mut consecutive_y_ties = 1usize;

    for window in pairs.windows(2) {
        if comparator(&window[1].1, &window[0].1) == Ordering::Equal {
            consecutive_y_ties += 1;
        } else {
            tied_y_pairs += sum(consecutive_y_ties - 1);
            consecutive_y_ties = 1;
        }
    }

    tied_y_pairs += sum(consecutive_y_ties - 1);

    // Switch to floating point before combining the counts: the subtraction below could go
    // negative and products such as `n * (n - 1)` or `2 * swaps` could overflow `usize`.
    let n_f = n as f64;
    let num_pairs_f = n_f * (n_f - 1.0) / 2.0;
    let tied_x_pairs_f = tied_x_pairs as f64;
    let tied_y_pairs_f = tied_y_pairs as f64;
    let tied_xy_pairs_f = tied_xy_pairs as f64;
    let swaps_f = 2.0 * swaps as f64;

    // Note that tot = con + dis + (xtie - ntie) + (ytie - ntie) + ntie
    //               = con + dis + xtie + ytie - ntie
    //
    //           C-D = tot - xtie - ytie + ntie - 2 * dis
    let concordant_minus_discordant =
        num_pairs_f - tied_x_pairs_f - tied_y_pairs_f + tied_xy_pairs_f - swaps_f;
    let non_tied_pairs_multiplied = (num_pairs_f - tied_x_pairs_f) * (num_pairs_f - tied_y_pairs_f);

    // With no untied pair in x or in y the coefficient is undefined. Report that explicitly:
    // clamping with `max`/`min` would silently turn the NaN of `0 / 0` into -1.0.
    if non_tied_pairs_multiplied == 0.0 {
        return Ok(f64::NAN);
    }

    let tau = concordant_minus_discordant / non_tied_pairs_multiplied.sqrt();

    // limit range to fix computational errors
    Ok(tau.max(-1.0).min(1.0))
}

/// Calculates the statistical significance (Z score) of a Kendall's tau value.
///
/// The score is `3 * tau * sqrt(n * (n - 1)) / sqrt(2 * (2 * n + 5))`, where `tau` is the
/// correlation coefficient and `n` the number of observations it was computed from. For
/// `n == 0` the `n - 1` factor saturates at zero, so the score is `0.0` for finite `tau`.
///
/// Typically any value greater than 1.96 is going to be statistically significant
/// against the Z-table with alpha set at 0.05.
pub fn significance(tau: f64, n: usize) -> f64 {
    // Do the arithmetic in f64: in `usize`, `n - 1` underflows for `n == 0` and
    // `n * (n - 1)` can overflow for large `n`.
    let n_tmp = n as f64 * n.saturating_sub(1) as f64;
    let deter = 2.0 * (2.0 * n as f64 + 5.0);
    (3.0 * tau * n_tmp.sqrt()) / deter.sqrt()
}

/// Returns `1 + 2 + ... + n`, i.e. `n * (n + 1) / 2`.
///
/// The even one of the two factors is halved before multiplying, so the intermediate product
/// cannot overflow while the final result still fits into `usize`.
#[inline]
fn sum(n: usize) -> usize {
    if n % 2 == 0 {
        (n / 2) * (n + 1)
    } else {
        // For odd `n`, `(n + 1) / 2 == n / 2 + 1`, written so that `n + 1` is never formed.
        n * (n / 2 + 1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Comparator for the `f64` based tests; values that cannot be compared are ordered as
    /// greater.
    fn float_order(a: &f64, b: &f64) -> Ordering {
        a.partial_cmp(b).unwrap_or(Ordering::Greater)
    }

    #[test]
    fn simple_correlated_data() {
        assert_eq!(tau_b(&[1, 2, 3], &[3, 4, 5]), Ok(1.0));
    }

    #[test]
    fn simple_correlated_reversed() {
        assert_eq!(tau_b(&[1, 2, 3], &[5, 4, 3]), Ok(-1.0));
    }

    #[test]
    fn simple_jumble() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [1.0, 3.0, 2.0, 4.0];

        // Of the 6 pairs (A,B) (A,C) (A,D) (B,C) (B,D) (C,D) only (B,C) is discordant,
        // the remaining 5 are concordant.
        let expected = (5.0 - 1.0) / 6.0;

        let actual = tau_b_with_comparator(&x, &y, float_order);
        assert_eq!(actual, Ok(expected));
    }

    #[test]
    fn balanced_jumble() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [1.0, 4.0, 3.0, 2.0];

        // Of the 6 pairs (A,B) (A,C) (A,D) (B,C) (B,D) (C,D) the three involving A are
        // concordant and the other three are discordant.
        let actual = tau_b_with_comparator(&x, &y, float_order);
        assert_eq!(actual, Ok(0.0));
    }

    #[test]
    fn fails_if_dimentions_does_not_match() {
        let expected = Error::DimensionMismatch {
            expected: 3,
            got: 2,
        };
        assert_eq!(tau_b(&[1, 2, 3], &[5, 4]), Err(expected));
    }

    #[test]
    fn fails_if_arrays_are_empty() {
        assert_eq!(tau_b::<i32>(&[], &[]), Err(Error::InsufficientLength));
    }

    #[test]
    fn it_format_dimension_mismatch_error() {
        let error = Error::DimensionMismatch {
            expected: 2,
            got: 1,
        };
        assert_eq!("dimension mismatch: 2 != 1", error.to_string());
    }

    #[test]
    fn it_format_insufficient_length_error() {
        let error = Error::InsufficientLength;
        assert_eq!("insufficient array length", error.to_string());
    }

    #[test]
    fn significance_computed_correctly_for_certain_values() {
        let z = significance(0.818, 12);
        assert!(3.7009 < z && z < 3.709);
    }
}