use std::cmp::Ordering;
use std::error::Error as StdError;
use std::fmt::{Display, Error as FmtError, Formatter};
use std::result::Result;
#[derive(Debug, PartialEq)]
pub enum Error {
DimensionMismatch { expected: usize, got: usize },
InsufficientLength,
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter) -> Result<(), FmtError> {
match self {
Error::InsufficientLength => write!(f, "insufficient array length"),
Error::DimensionMismatch { expected, got } => {
write!(f, "dimension mismatch: {} != {}", expected, got)
}
}
}
}
impl StdError for Error {}
pub fn tau_b<T>(x: &[T], y: &[T]) -> Result<f64, Error>
where
T: Ord + Clone + Default,
{
tau_b_with_comparator(x, y, |a, b| a.cmp(b))
}
#[allow(clippy::many_single_char_names)]
pub fn tau_b_with_comparator<T, F>(x: &[T], y: &[T], mut comparator: F) -> Result<f64, Error>
where
T: PartialOrd + Clone + Default,
F: FnMut(&T, &T) -> Ordering,
{
if x.len() != y.len() {
return Err(Error::DimensionMismatch {
expected: x.len(),
got: y.len(),
});
}
if x.is_empty() {
return Err(Error::InsufficientLength);
}
let n = x.len();
let mut pairs: Vec<(T, T)> = Vec::with_capacity(n);
for pair in x.iter().cloned().zip(y.iter().cloned()) {
pairs.push(pair);
}
pairs.sort_by(|pair1, pair2| {
let res = comparator(&pair1.0, &pair2.0);
if res == Ordering::Equal {
comparator(&pair1.1, &pair2.1)
} else {
res
}
});
let mut tied_x_pairs = 0usize;
let mut tied_xy_pairs = 0usize;
let mut consecutive_x_ties = 1usize;
let mut consecutive_xy_ties = 1usize;
for i in 1..n {
let prev = &pairs[i - 1];
let curr = &pairs[i];
if curr.0 == prev.0 {
consecutive_x_ties += 1;
if curr.1 == prev.1 {
consecutive_xy_ties += 1;
} else {
tied_xy_pairs += sum(consecutive_xy_ties - 1);
consecutive_xy_ties = 1;
}
} else {
tied_x_pairs += sum(consecutive_x_ties - 1);
consecutive_x_ties = 1;
tied_xy_pairs += sum(consecutive_xy_ties - 1);
consecutive_xy_ties = 1;
}
}
tied_x_pairs += sum(consecutive_x_ties - 1);
tied_xy_pairs += sum(consecutive_xy_ties - 1);
let mut swaps = 0usize;
let mut pairs_dest: Vec<(T, T)> = vec![(Default::default(), Default::default()); n];
let mut segment_size = 1usize;
while segment_size < n {
for offset in (0..n).step_by(2 * segment_size) {
let mut i = offset;
let i_end = n.min(i + segment_size);
let mut j = i_end;
let j_end = n.min(j + segment_size);
let mut copy_location = offset;
while i < i_end || j < j_end {
if i < i_end {
if j < j_end {
let a = &pairs[i].1;
let b = &pairs[j].1;
if comparator(a, b) == Ordering::Greater {
pairs_dest[copy_location] = pairs[j].clone();
j += 1;
swaps += i_end - i;
} else {
pairs_dest[copy_location] = pairs[i].clone();
i += 1;
}
} else {
pairs_dest[copy_location] = pairs[i].clone();
i += 1;
}
} else {
pairs_dest[copy_location] = pairs[j].clone();
j += 1;
}
copy_location += 1;
}
}
std::mem::swap(&mut pairs, &mut pairs_dest);
segment_size <<= 1;
}
let mut tied_y_pairs = 0usize;
let mut consecutive_y_ties = 1usize;
for i in 1..n {
let prev = &pairs[i - 1];
let curr = &pairs[i];
if curr.1 == prev.1 {
consecutive_y_ties += 1;
} else {
tied_y_pairs += sum(consecutive_y_ties - 1);
consecutive_y_ties = 1;
}
}
tied_y_pairs += sum(consecutive_y_ties - 1);
let num_pairs_f: f64 = ((n * (n - 1)) as f64) / 2.0;
let tied_x_pairs_f: f64 = tied_x_pairs as f64;
let tied_y_pairs_f: f64 = tied_y_pairs as f64;
let tied_xy_pairs_f: f64 = tied_xy_pairs as f64;
let swaps_f: f64 = (2 * swaps) as f64;
let concordant_minus_discordant =
num_pairs_f - tied_x_pairs_f - tied_y_pairs_f + tied_xy_pairs_f - swaps_f;
let non_tied_pairs_multiplied = (num_pairs_f - tied_x_pairs_f) * (num_pairs_f - tied_y_pairs_f);
let tau = concordant_minus_discordant / non_tied_pairs_multiplied.sqrt();
Ok(tau.max(-1.0).min(1.0))
}
pub fn significance(tau: f64, n: usize) -> f64 {
let n_tmp: f64 = (n * (n - 1)) as f64;
let deter: f64 = (2 * (2 * n + 5)) as f64;
(3.0 * tau * n_tmp.sqrt()) / deter.sqrt()
}
#[inline]
fn sum(n: usize) -> usize {
n * (n + 1usize) / 2usize
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn xy_consecutive_pair_test() {
let x = vec![
12.0, 14.0, 14.0, 17.0, 19.0, 19.0, 19.0, 19.0, 19.0, 20.0, 21.0, 21.0, 21.0, 21.0, 21.0,
22.0, 23.0, 24.0, 24.0, 24.0, 26.0, 26.0, 27.0,
];
let y = vec![
11.0, 4.0, 4.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 4.0,
0.0, 0.0, 0.0, 0.0, 0.0,
];
let tau = tau_b_with_comparator(&x, &y, |a: &f64, b: &f64| {
a.partial_cmp(&b).unwrap_or(Ordering::Greater)
}).unwrap();
approx::assert_abs_diff_eq!(tau, -0.3762015410475098);
}
#[test]
fn shifted_test() {
let comparator = |a: &f64, b: &f64| a.partial_cmp(&b).unwrap_or(Ordering::Greater);
let x = &[1.0, 1.0, 2.0, 2.0, 3.0, 3.0];
let y = &[1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
let res = tau_b_with_comparator(&x[..], &y[..], comparator).unwrap();
approx::assert_abs_diff_eq!(res, 0.8006407690254358);
let x = &[12.0, 2.0, 1.0, 12.0, 2.0];
let y = &[1.0, 4.0, 7.0, 1.0, 0.0];
let res = tau_b_with_comparator(&x[..], &y[..], comparator).unwrap();
approx::assert_abs_diff_eq!(res, -0.4714045207910316);
}
#[test]
fn simple_correlated_data() {
let res = tau_b(&[1, 2, 3], &[3, 4, 5]).unwrap();
assert_eq!(res, 1.0);
}
#[test]
fn simple_correlated_reversed() {
let res = tau_b(&[1, 2, 3], &[5, 4, 3]).unwrap();
assert_eq!(res, -1.0);
}
#[test]
fn simple_jumble() {
let x = &[1.0, 2.0, 3.0, 4.0];
let y = &[1.0, 3.0, 2.0, 4.0];
let expected = (5.0 - 1.0) / 6.0;
assert_eq!(
tau_b_with_comparator(x, y, |a: &f64, b: &f64| a
.partial_cmp(&b)
.unwrap_or(Ordering::Greater)),
Ok(expected)
);
}
#[test]
fn balanced_jumble() {
let x = [1.0, 2.0, 3.0, 4.0];
let y = [1.0, 4.0, 3.0, 2.0];
assert_eq!(
tau_b_with_comparator(&x, &y, |a: &f64, b: &f64| a
.partial_cmp(&b)
.unwrap_or(Ordering::Greater)),
Ok(0.0)
);
}
#[test]
fn fails_if_dimentions_does_not_match() {
let res = tau_b(&[1, 2, 3], &[5, 4]);
assert_eq!(
res,
Err(Error::DimensionMismatch {
expected: 3,
got: 2
})
);
}
#[test]
fn fails_if_arrays_are_empty() {
let res = tau_b::<i32>(&[], &[]);
assert_eq!(res, Err(Error::InsufficientLength));
}
#[test]
fn it_format_dimension_mismatch_error() {
let error = Error::DimensionMismatch { expected: 2, got: 1 };
assert_eq!("dimension mismatch: 2 != 1", format!("{}", error));
}
#[test]
fn it_format_insufficient_length_error() {
let error = Error::InsufficientLength {} ;
assert_eq!("insufficient array length", format!("{}", error));
}
#[test]
fn significance_computed_correctly_for_certain_values() {
let res = significance(0.818, 12);
assert!(res > 3.7009);
assert!(res < 3.709);
}
}