1use std::cmp::Ordering;
31use std::error::Error as StdError;
32use std::fmt::{Display, Error as FmtError, Formatter};
33use std::result::Result;
34
/// Reasons a tau-b computation rejects its input.
///
/// The type is cheap to duplicate (it only carries lengths), so it derives `Clone`, and its
/// equality is total, so it derives `Eq` alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The two input slices have different lengths.
    ///
    /// `expected` is the length of the first slice (`x`) and `got` the length of the
    /// second slice (`y`).
    DimensionMismatch { expected: usize, got: usize },
    /// The input slices contain no observations.
    InsufficientLength,
}
40
41impl Display for Error {
42 fn fmt(&self, f: &mut Formatter) -> Result<(), FmtError> {
43 match self {
44 Error::InsufficientLength => write!(f, "insufficient array length"),
45 Error::DimensionMismatch { expected, got } => {
46 write!(f, "dimension mismatch: {expected} != {got}")
47 }
48 }
49 }
50}
51
// Marks `Error` as a standard error type. The trait's default methods are used as-is:
// the impl body is empty, so no custom `source` is provided.
impl StdError for Error {}
53
54pub fn tau_b<T>(x: &[T], y: &[T]) -> Result<(f64, f64), Error>
64where
65 T: Ord + Clone + Default,
66{
67 tau_b_with_comparator(x, y, |a, b| a.cmp(b))
68}
69
70#[allow(clippy::many_single_char_names)]
73pub fn tau_b_with_comparator<T, F>(x: &[T], y: &[T], mut comparator: F) -> Result<(f64, f64), Error>
74where
75 T: PartialOrd + Clone + Default,
76 F: FnMut(&T, &T) -> Ordering,
77{
78 if x.len() != y.len() {
79 return Err(Error::DimensionMismatch {
80 expected: x.len(),
81 got: y.len(),
82 });
83 }
84
85 if x.is_empty() {
86 return Err(Error::InsufficientLength);
87 }
88
89 let n = x.len();
90
91 let mut pairs: Vec<(T, T)> = x.iter().cloned().zip(y.iter().cloned()).collect();
92
93 pairs.sort_unstable_by(|pair1, pair2| {
94 let res = comparator(&pair1.0, &pair2.0);
95 if res == Ordering::Equal {
96 comparator(&pair1.1, &pair2.1)
97 } else {
98 res
99 }
100 });
101
102 let mut v1_part_1 = 0usize;
103 let mut v2_part_1 = 0isize;
104
105 let mut tied_x_pairs = 0usize;
106 let mut tied_xy_pairs = 0usize;
107 let mut vt = 0usize;
108 let mut consecutive_x_ties = 1usize;
109 let mut consecutive_xy_ties = 1usize;
110
111 for i in 1..n {
112 let prev = &pairs[i - 1];
113 let curr = &pairs[i];
114 if curr.0 == prev.0 {
115 consecutive_x_ties += 1;
116 if curr.1 == prev.1 {
117 consecutive_xy_ties += 1;
118 } else {
119 tied_xy_pairs += sum(consecutive_xy_ties - 1);
120 consecutive_xy_ties = 1;
121 }
122 } else {
123 update_x_group(
124 &mut vt,
125 &mut tied_x_pairs,
126 &mut tied_xy_pairs,
127 &mut v1_part_1,
128 &mut v2_part_1,
129 consecutive_x_ties,
130 consecutive_xy_ties,
131 );
132 consecutive_x_ties = 1;
133 consecutive_xy_ties = 1;
134 }
135 }
136
137 update_x_group(
138 &mut vt,
139 &mut tied_x_pairs,
140 &mut tied_xy_pairs,
141 &mut v1_part_1,
142 &mut v2_part_1,
143 consecutive_x_ties,
144 consecutive_xy_ties,
145 );
146
147 let mut swaps = 0usize;
148 let mut pairs_dest: Vec<(T, T)> = vec![(Default::default(), Default::default()); n];
149
150 let mut segment_size = 1usize;
151 while segment_size < n {
152 for offset in (0..n).step_by(2 * segment_size) {
153 let mut i = offset;
154 let i_end = n.min(i + segment_size);
155 let mut j = i_end;
156 let j_end = n.min(j + segment_size);
157 let mut copy_location = offset;
158
159 while i < i_end && j < j_end {
160 let a = &pairs[i].1;
161 let b = &pairs[j].1;
162
163 if a.partial_cmp(b).unwrap_or(Ordering::Greater) == Ordering::Greater {
164 pairs_dest[copy_location] = pairs[j].clone();
165 j += 1;
166 swaps += i_end - i;
167 } else {
168 pairs_dest[copy_location] = pairs[i].clone();
169 i += 1;
170 }
171
172 copy_location += 1;
173 }
174
175 while i < i_end {
176 pairs_dest[copy_location] = pairs[i].clone();
177 i += 1;
178 copy_location += 1
179 }
180
181 while j < j_end {
182 pairs_dest[copy_location] = pairs[j].clone();
183 j += 1;
184 copy_location += 1
185 }
186 }
187 std::mem::swap(&mut pairs, &mut pairs_dest);
188
189 segment_size <<= 1;
190 }
191
192 let mut v1_part_2 = 0usize;
193 let mut v2_part_2 = 0isize;
194 let mut tied_y_pairs = 0usize;
195 let mut consecutive_y_ties = 1usize;
196 let mut vu = 0usize;
197
198 for j in 1..n {
199 let prev = &pairs[j - 1];
200 let curr = &pairs[j];
201 if curr.1 == prev.1 {
202 consecutive_y_ties += 1;
203 } else {
204 update_y_group(
205 &mut vu,
206 &mut tied_y_pairs,
207 &mut v1_part_2,
208 &mut v2_part_2,
209 consecutive_y_ties,
210 );
211 consecutive_y_ties = 1;
212 }
213 }
214
215 update_y_group(
216 &mut vu,
217 &mut tied_y_pairs,
218 &mut v1_part_2,
219 &mut v2_part_2,
220 consecutive_y_ties,
221 );
222
223 let v1 = (v1_part_1 * v1_part_2) as f64;
225 let v2 = (v2_part_1 * v2_part_2) as f64;
226
227 let num_pairs_f: f64 = ((n * (n - 1)) as f64) / 2.0; let tied_x_pairs_f: f64 = tied_x_pairs as f64;
230 let tied_y_pairs_f: f64 = tied_y_pairs as f64;
231 let tied_xy_pairs_f: f64 = tied_xy_pairs as f64;
232 let swaps_f: f64 = (2 * swaps) as f64;
233
234 let concordant_minus_discordant =
239 num_pairs_f - tied_x_pairs_f - tied_y_pairs_f + tied_xy_pairs_f - swaps_f;
240
241 let non_tied_pairs_multiplied = (num_pairs_f - tied_x_pairs_f) * (num_pairs_f - tied_y_pairs_f);
243
244 let tau_b = concordant_minus_discordant / non_tied_pairs_multiplied.sqrt();
245
246 let v0 = (n * (n - 1)) * (2 * n + 5);
248 let n_f = n as f64;
249
250 let v0_isize = v0 as isize;
251 let vt_isize = vt as isize;
252 let vu_isize = vu as isize;
253 let var_s = (v0_isize - vt_isize - vu_isize) as f64 / 18.0
254 + v1 / (2.0 * n_f * (n_f - 1.0))
255 + v2 / (9.0 * n_f * (n_f - 1.0) * (n_f - 2.0));
256
257 let s = tau_b * non_tied_pairs_multiplied.sqrt();
258 let z = s / var_s.sqrt();
259
260 Ok((tau_b.clamp(-1.0, 1.0), z))
262}
263
/// Returns `1 + 2 + ... + n`, the `n`-th triangular number (`0` for `n == 0`).
///
/// Callers pass `group_len - 1` to obtain the number of unordered pairs within a tie
/// group of `group_len` elements.
#[inline]
fn sum(n: usize) -> usize {
    let successor = n + 1_usize;
    n * successor / 2_usize
}
268
/// Folds one finished run of tied `x` values into the running tie statistics.
///
/// For a run of `t = consecutive_x_ties` elements the accumulators grow by:
///
/// * `vt`: `t * (t - 1) * (2 * t + 5)`
/// * `v1_part_1`: `t * (t - 1)`
/// * `v2_part_1`: `t * (t - 1) * (t - 2)` (signed, because `t - 2` is negative for `t == 1`)
/// * `tied_x_pairs`: `t * (t - 1) / 2`, the number of pairs inside the run
///
/// `tied_xy_pairs` grows by the number of pairs inside the trailing run of
/// `consecutive_xy_ties` fully tied observations. Both run lengths must be at least 1.
fn update_x_group(
    vt: &mut usize,
    tied_x_pairs: &mut usize,
    tied_xy_pairs: &mut usize,
    v1_part_1: &mut usize,
    v2_part_1: &mut isize,
    consecutive_x_ties: usize,
    consecutive_xy_ties: usize,
) {
    let t = consecutive_x_ties;
    let t_less_one = t - 1;

    *vt += t * t_less_one * (2 * t + 5);
    *v1_part_1 += t * t_less_one;

    let t_signed = t as isize;
    *v2_part_1 += t_signed * (t_signed - 1) * (t_signed - 2);

    // Pairs within a run of k elements: 1 + 2 + ... + (k - 1) = (k - 1) * k / 2.
    *tied_x_pairs += t_less_one * t / 2;

    let xy_less_one = consecutive_xy_ties - 1;
    *tied_xy_pairs += xy_less_one * (xy_less_one + 1) / 2;
}
288
/// Folds one finished run of tied `y` values into the running tie statistics.
///
/// For a run of `u = consecutive_y_ties` elements the accumulators grow by:
///
/// * `vu`: `u * (u - 1) * (2 * u + 5)`
/// * `v1_part_2`: `u * (u - 1)`
/// * `v2_part_2`: `u * (u - 1) * (u - 2)` (signed, because `u - 2` is negative for `u == 1`)
/// * `tied_y_pairs`: `u * (u - 1) / 2`, the number of pairs inside the run
///
/// The run length must be at least 1.
fn update_y_group(
    vu: &mut usize,
    tied_y_pairs: &mut usize,
    v1_part_2: &mut usize,
    v2_part_2: &mut isize,
    consecutive_y_ties: usize,
) {
    let u = consecutive_y_ties;
    let u_less_one = u - 1;

    *vu += u * u_less_one * (2 * u + 5);
    *v1_part_2 += u * u_less_one;

    let u_signed = u as isize;
    *v2_part_2 += u_signed * (u_signed - 1) * (u_signed - 2);

    // Pairs within a run of k elements: 1 + 2 + ... + (k - 1) = (k - 1) * k / 2.
    *tied_y_pairs += u_less_one * u / 2;
}
305
#[cfg(test)]
mod tests {
    use super::*;
    use float_cmp::assert_approx_eq;

    /// Orders floats for the tests; incomparable values are ranked as `Greater`.
    fn float_order(a: &f64, b: &f64) -> Ordering {
        a.partial_cmp(b).unwrap_or(Ordering::Greater)
    }

    /// Data with long runs of tied `x` values that also contain tied `(x, y)` pairs.
    #[test]
    fn xy_consecutive_pair_test() {
        let x = [
            12.0, 14.0, 14.0, 17.0, 19.0, 19.0, 19.0, 19.0, 19.0, 20.0, 21.0, 21.0, 21.0, 21.0,
            21.0, 22.0, 23.0, 24.0, 24.0, 24.0, 26.0, 26.0, 27.0,
        ];
        let y = [
            11.0, 4.0, 4.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0,
            4.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];

        let (tau, z) = tau_b_with_comparator(&x, &y, float_order).unwrap();

        assert_approx_eq!(f64, tau, -0.3762015410475098);
        assert_approx_eq!(f64, z, -2.09764910068664);
    }

    /// Two small data sets with ties in both variables.
    #[test]
    fn shifted_test() {
        let first_x = [1.0, 1.0, 2.0, 2.0, 3.0, 3.0];
        let first_y = [1.0, 2.0, 2.0, 3.0, 3.0, 4.0];
        let (tau, z) = tau_b_with_comparator(&first_x[..], &first_y[..], float_order).unwrap();
        assert_approx_eq!(f64, tau, 0.8006407690254358);
        assert_approx_eq!(f64, z, 2.0526, epsilon = 0.0001);

        let second_x = [12.0, 2.0, 1.0, 12.0, 2.0];
        let second_y = [1.0, 4.0, 7.0, 1.0, 0.0];
        let (tau, z) = tau_b_with_comparator(&second_x[..], &second_y[..], float_order).unwrap();
        assert_approx_eq!(f64, tau, -0.4714045207910316);
        assert_approx_eq!(f64, z, -1.0742, epsilon = 0.0001);
    }

    /// Strictly increasing data is perfectly correlated.
    #[test]
    fn simple_correlated_data() {
        let (tau, z) = tau_b(&[1, 2, 3], &[3, 4, 5]).unwrap();
        assert_eq!(tau, 1.0);
        assert_approx_eq!(f64, z, 1.5666989036012806);
    }

    /// Reversing one variable flips the sign of both results.
    #[test]
    fn simple_correlated_reversed() {
        let (tau, z) = tau_b(&[1, 2, 3], &[5, 4, 3]).unwrap();
        assert_eq!(tau, -1.0);
        assert_approx_eq!(f64, z, -1.5666989036012806);
    }

    /// One swapped neighbour pair among four untied observations.
    #[test]
    fn simple_jumble() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [1.0, 3.0, 2.0, 4.0];

        let expected_tau_b = (5.0 - 1.0) / 6.0;
        let expected_z = 1.3587324409735149;

        let outcome = tau_b_with_comparator(&x, &y, float_order);
        assert_eq!(outcome, Ok((expected_tau_b, expected_z)));
    }

    /// Equal numbers of concordant and discordant pairs cancel out exactly.
    #[test]
    fn balanced_jumble() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [1.0, 4.0, 3.0, 2.0];

        let outcome = tau_b_with_comparator(&x, &y, float_order);
        assert_eq!(outcome, Ok((0.0, 0.0)));
    }

    /// Inputs of different lengths are rejected with both lengths reported.
    #[test]
    fn fails_if_dimentions_does_not_match() {
        let expected_error = Error::DimensionMismatch {
            expected: 3,
            got: 2,
        };
        assert_eq!(tau_b(&[1, 2, 3], &[5, 4]), Err(expected_error));
    }

    /// Empty inputs are rejected.
    #[test]
    fn fails_if_arrays_are_empty() {
        let outcome = tau_b::<i32>(&[], &[]);
        assert_eq!(outcome, Err(Error::InsufficientLength));
    }

    #[test]
    fn it_format_dimension_mismatch_error() {
        let error = Error::DimensionMismatch {
            expected: 2,
            got: 1,
        };
        assert_eq!("dimension mismatch: 2 != 1", error.to_string());
    }

    #[test]
    fn it_format_insufficient_length_error() {
        let error = Error::InsufficientLength;
        assert_eq!("insufficient array length", error.to_string());
    }

    /// Heavily tied data must not panic anywhere in the computation.
    #[test]
    fn test_subtract_with_overflow() {
        let x = [
            -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, -0.1309, 6.8901,
        ];
        let y = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];

        let outcome = std::panic::catch_unwind(|| {
            let (_tau, _significance) = tau_b_with_comparator(&x, &y, float_order).unwrap();
        });
        assert!(outcome.is_ok());
    }
}