memega/ops/
distance.rs

1use std::mem::swap;
2
3use eyre::{eyre, Result};
4use num_traits::{Num, NumAssign};
5
/// Generalised distance between two slices.
///
/// Elements are paired position by position and `f` is summed over the
/// overlapping prefix. When the slices differ in length, every unpaired
/// trailing element of the longer one adds `missing` to the result.
pub fn dist_fn<T>(s1: &[T], s2: &[T], missing: f64, mut f: impl FnMut(&T, &T) -> f64) -> f64 {
    // Cost of the elements that have no partner in the other slice.
    let length_penalty = (s1.len() as f64 - s2.len() as f64).abs() * missing;
    // `zip` stops at the shorter slice, so only the common prefix is visited.
    s1.iter()
        .zip(s2)
        .fold(length_penalty, |total, (a, b)| total + f(a, b))
}
16
17// Norm 1 distance
18pub fn dist_abs<T: Num + NumAssign + Copy + PartialOrd>(mut a: T, mut b: T) -> T {
19    if a < b {
20        swap(&mut a, &mut b);
21    }
22    a - b
23}
24
25// Norm 1 distance - manhattan distance.
26pub fn dist1<T: Num + NumAssign + Copy + PartialOrd>(s1: &[T], s2: &[T]) -> T {
27    let max = s1.len().max(s2.len());
28    let mut dist = T::zero();
29    for i in 0..max {
30        let zero = T::zero();
31        let a = s1.get(i).unwrap_or(&zero);
32        let b = s2.get(i).unwrap_or(&zero);
33        dist += dist_abs(*a, *b);
34    }
35    dist
36}
37
/// Norm 2 (Euclidean) distance between two slices.
///
/// Returns the square root of the sum of squared differences of elements at
/// matching positions. If the slices differ in length, the missing positions
/// of the shorter one are treated as `0.0`.
#[must_use]
pub fn dist2(s1: &[f64], s2: &[f64]) -> f64 {
    let longest = s1.len().max(s2.len());
    // Accumulate front to back starting from 0.0.
    let sum_of_squares = (0..longest).fold(0.0_f64, |acc, idx| {
        let lhs = s1.get(idx).copied().unwrap_or(0.0);
        let rhs = s2.get(idx).copied().unwrap_or(0.0);
        let delta = lhs - rhs;
        acc + delta * delta
    });
    sum_of_squares.sqrt()
}
50
/// Number of positions at which the two slices differ.
///
/// Positions present in both slices count when their elements are unequal;
/// every position that exists only in the longer slice also counts as a
/// difference.
pub fn count_different<T: PartialEq>(s1: &[T], s2: &[T]) -> usize {
    let (shorter, longer) = if s1.len() <= s2.len() {
        (s1.len(), s2.len())
    } else {
        (s2.len(), s1.len())
    };
    // `zip` stops at the shorter slice, so only shared positions are compared.
    let mismatched = s1.iter().zip(s2).filter(|(a, b)| a != b).count();
    mismatched + longer - shorter
}
63
64// Kendall tau distance: https://en.wikipedia.org/wiki/Kendall_tau_distance
65pub fn kendall_tau<T: PartialOrd>(s1: &[T], s2: &[T]) -> Result<usize> {
66    if s1.len() != s2.len() {
67        return Err(eyre!("must be same length"));
68    }
69    let mut count = 0;
70    for i in 0..s1.len() {
71        for j in (i + 1)..s2.len() {
72            if (s1[i] < s1[j]) != (s2[i] < s2[j]) {
73                count += 1;
74            }
75        }
76    }
77    Ok(count)
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_different() {
        // (left, right, expected number of differing positions)
        let cases: [(&[i32], &[i32], usize); 4] = [
            (&[1], &[1], 0),
            (&[1], &[2], 1),
            // A length mismatch counts each unpaired element as a difference.
            (&[1], &[1, 2], 1),
            (&[1, 2], &[1], 1),
        ];
        for &(lhs, rhs, expected) in &cases {
            assert_eq!(count_different(lhs, rhs), expected);
        }
    }

    #[test]
    fn test_kendall_tau() {
        // (left, right, expected number of discordant pairs)
        let cases: [(&[i32], &[i32], usize); 5] = [
            // Single-element inputs have no pairs to compare.
            (&[1], &[1], 0),
            (&[1], &[2], 0),
            (&[1, 2], &[1, 2], 0),
            (&[1, 2], &[2, 1], 1),
            (&[1, 2, 3, 4, 5], &[3, 4, 1, 2, 5], 4),
        ];
        for &(lhs, rhs, expected) in &cases {
            assert_eq!(kendall_tau(lhs, rhs).unwrap(), expected);
        }
    }
}