use std::cmp::Ordering;
use ndarray::{ArrayBase, ArrayView1, AsArray, Ix1, ViewRepr};
use crate::error::ImgalError;
use crate::statistics::weighted_merge_sort_mut;
use crate::traits::numeric::AsNumeric;
/// Compute a weighted Kendall's Tau-b style rank correlation between two
/// 1-dimensional datasets.
///
/// Both inputs are ranked with `rank_with_weights`, the observations are
/// ordered by their `data_a` rank, and the `data_b` ranks (with their weights)
/// are passed to `weighted_merge_sort_mut`. Its return value is used as the
/// weighted discordance term of the coefficient:
///
/// ```text
/// n0  = ((Σw)² - Σw²) / 2
/// tau = (n0 - 2 * swaps) / sqrt((n0 - ties_a) * (n0 - ties_b))
/// ```
///
/// # Arguments
///
/// * `data_a` - The first dataset.
/// * `data_b` - The second dataset, same length as `data_a`.
/// * `weights` - One weight per observation, same length as `data_a`.
///
/// # Returns
///
/// * `Ok(0.0)` when there are fewer than two observations, or when the
///   denominator is zero or NaN.
/// * `Ok(f64::NAN)` when every element of `data_a`, or every element of
///   `data_b`, is equal to its first element.
/// * `Ok(tau)` otherwise, clamped to `[-1.0, 1.0]`.
///
/// # Errors
///
/// * `ImgalError::MismatchedArrayLengths` when `data_b` or `weights` does not
///   have the same length as `data_a`; the error names the array that differs.
/// * Any error returned by `weighted_merge_sort_mut` is propagated.
pub fn weighted_kendall_tau_b<'a, T, A>(
    data_a: A,
    data_b: A,
    weights: &[f64],
) -> Result<f64, ImgalError>
where
    A: AsArray<'a, T, Ix1>,
    T: 'a + AsNumeric,
{
    let view_a: ArrayBase<ViewRepr<&'a T>, Ix1> = data_a.into();
    let view_b: ArrayBase<ViewRepr<&'a T>, Ix1> = data_b.into();

    // Validate lengths one array at a time so the error reports the array
    // (and the length) that actually disagrees with `data_a`.
    let dl = view_a.len();
    if dl != view_b.len() {
        return Err(ImgalError::MismatchedArrayLengths {
            a_arr_name: "data_a",
            a_arr_len: dl,
            b_arr_name: "data_b",
            b_arr_len: view_b.len(),
        });
    }
    if dl != weights.len() {
        return Err(ImgalError::MismatchedArrayLengths {
            a_arr_name: "data_a",
            a_arr_len: dl,
            b_arr_name: "weights",
            b_arr_len: weights.len(),
        });
    }

    // No pairs can be formed from fewer than two observations.
    if dl < 2 {
        return Ok(0.0);
    }

    // A constant dataset has no rank information; report NaN.
    let data_a_uniform = view_a.iter().all(|&v| v == view_a[0]);
    let data_b_uniform = view_b.iter().all(|&v| v == view_b[0]);
    if data_a_uniform || data_b_uniform {
        return Ok(f64::NAN);
    }

    let (a_ranks, a_tie_corr) = rank_with_weights(view_a, weights);
    let (b_ranks, b_tie_corr) = rank_with_weights(view_b, weights);

    // Pair up the ranks, remembering each observation's original index so its
    // weight can be carried along after sorting.
    let mut rank_pairs: Vec<(i32, i32, usize)> = a_ranks
        .iter()
        .zip(b_ranks.iter())
        .enumerate()
        .map(|(i, (&a, &b))| (a, b, i))
        .collect();

    // Stable sort on the `a` rank only.
    // NOTE(review): observations tied in `data_a` keep their input order here,
    // and the numerator below does not subtract tied pairs, so the result for
    // data with ties depends on how `weighted_merge_sort_mut` counts them —
    // confirm this matches the intended Tau-b tie handling.
    rank_pairs.sort_by_key(|&(a, _, _)| a);

    let (mut b_sorted, mut w_sorted): (Vec<i32>, Vec<f64>) = rank_pairs
        .iter()
        .map(|&(_, b, i)| (b, weights[i]))
        .unzip();

    // Propagate a failure instead of panicking inside a fallible API.
    let swaps = weighted_merge_sort_mut(&mut b_sorted, &mut w_sorted)?;

    // Total weighted pairs: sum over i < j of w_i * w_j.
    let total_w: f64 = weights.iter().sum();
    let sum_w_sqr: f64 = weights.iter().map(|w| w.powi(2)).sum();
    let total_w_pairs = (total_w.powi(2) - sum_w_sqr) / 2.0;

    let c_pairs = total_w_pairs - swaps;
    let numer = c_pairs - swaps;
    let denom = ((total_w_pairs - a_tie_corr) * (total_w_pairs - b_tie_corr)).sqrt();

    if denom != 0.0 && !denom.is_nan() {
        // The ratio can leave [-1, 1]; clamp it back into the valid range.
        Ok((numer / denom).clamp(-1.0, 1.0))
    } else {
        Ok(0.0)
    }
}
/// Rank `data` and accumulate the weighted tie correction.
///
/// Elements are ordered ascending and split into groups of equal values. Every
/// member of a group receives the same integer rank
/// `first_rank_of_group + (group_size - 1) / 2` (integer division), so ranks
/// are identical within a group and strictly increasing between groups.
///
/// Values that are not comparable with themselves (e.g. a floating-point NaN)
/// are ordered after every comparable value and are treated as tied with one
/// another. This keeps the sort comparator a consistent total order and
/// guarantees the grouping loop always advances.
///
/// # Arguments
///
/// * `data` - The values to rank.
/// * `weights` - One weight per element of `data`; indexed with the same
///   indices as `data`, so it must be at least as long.
///
/// # Returns
///
/// * `(ranks, tie_corr)` where `ranks[i]` is the rank of `data[i]` and
///   `tie_corr` is the sum of `w_k * w_l` over every unordered pair `(k, l)`
///   that falls in the same tie group.
fn rank_with_weights<T>(data: ArrayView1<T>, weights: &[f64]) -> (Vec<i32>, f64)
where
    T: AsNumeric,
{
    // Total order built on `partial_cmp`: when two values are incomparable,
    // the one that is incomparable with itself sorts last; two such values
    // compare equal.
    let total_order = |x: T, y: T| -> Ordering {
        x.partial_cmp(&y).unwrap_or_else(|| {
            let x_unordered = x.partial_cmp(&x).is_none();
            let y_unordered = y.partial_cmp(&y).is_none();
            x_unordered.cmp(&y_unordered)
        })
    };

    let dl = data.len();
    let mut indices: Vec<usize> = (0..dl).collect();
    indices.sort_by(|&a, &b| total_order(data[a], data[b]));

    let mut ranks: Vec<i32> = vec![0; dl];
    let mut tie_corr = 0.0;
    let mut cur_rank = 1;
    let mut start = 0;
    while start < dl {
        let cur_val = data[indices[start]];

        // `total_order` is reflexive, so the element at `start` always matches
        // and the group is never empty.
        let group_len = indices[start..]
            .iter()
            .take_while(|&&idx| total_order(data[idx], cur_val) == Ordering::Equal)
            .count();
        let end = start + group_len;
        let group = &indices[start..end];

        // NOTE: `as` truncates for inputs longer than `i32::MAX` elements.
        let group_size = group_len as i32;
        let group_rank = cur_rank + (group_size - 1) / 2;
        for &idx in group {
            ranks[idx] = group_rank;
        }

        // Sum of w_k * w_l over all pairs in the group, computed in one pass:
        // each weight is multiplied by the running sum of the weights before it.
        if let Some((&first, rest)) = group.split_first() {
            let mut w_prefix = weights[first];
            let mut tie_group_corr = 0.0;
            for &idx in rest {
                tie_group_corr += w_prefix * weights[idx];
                w_prefix += weights[idx];
            }
            tie_corr += tie_group_corr;
        }

        cur_rank += group_size;
        start = end;
    }
    (ranks, tie_corr)
}