fast_distances/distances/
correlation.rs1use ndarray::ArrayView1;
2use num::Float;
3
4pub fn correlation<T>(x: &ArrayView1<T>, y: &ArrayView1<T>) -> T
20where
21 T: Float,
22{
23 let mut mu_x = T::zero();
24 let mut mu_y = T::zero();
25 let mut norm_x = T::zero();
26 let mut norm_y = T::zero();
27 let mut dot_product = T::zero();
28
29 for i in 0..x.len() {
31 mu_x = mu_x + x[i];
32 mu_y = mu_y + y[i];
33 }
34
35 mu_x = mu_x / T::from(x.len()).unwrap();
36 mu_y = mu_y / T::from(y.len()).unwrap();
37
38 for i in 0..x.len() {
40 let shifted_x = x[i] - mu_x;
41 let shifted_y = y[i] - mu_y;
42 norm_x = norm_x + shifted_x * shifted_x;
43 norm_y = norm_y + shifted_y * shifted_y;
44 dot_product = dot_product + shifted_x * shifted_y;
45 }
46
47 if norm_x.is_zero() && norm_y.is_zero() {
48 T::zero()
49 } else if dot_product.is_zero() {
50 T::one()
51 } else {
52 T::one() - (dot_product / (norm_x.sqrt() * norm_y.sqrt()))
53 }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    // Asserts that `actual` is within `tol` of `expected`.
    //
    // Used instead of `assert_eq!` for computed distances: their exact bits
    // depend on rounding (e.g. `sqrt(2) * sqrt(2)` is not exactly `2` in f32),
    // so exact comparisons would pin rounding noise rather than the result.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr, $tol:expr) => {{
            let (actual, expected, tol) = ($actual, $expected, $tol);
            assert!(
                (actual - expected).abs() <= tol,
                "{:?} is not within {:?} of {:?}",
                actual,
                tol,
                expected
            );
        }};
    }

    #[test]
    fn test_correlation_basic_f32() {
        // Perfectly correlated inputs: distance is 0 up to rounding.
        let x = arr1(&[1.0_f32, 2.0, 3.0]);
        let y = arr1(&[4.0_f32, 5.0, 6.0]);
        let result = correlation(&x.view(), &y.view());
        assert_close!(result, 0.0_f32, 1e-6_f32);
    }

    #[test]
    fn test_correlation_anti_f32() {
        // Perfectly anti-correlated inputs: distance is 2 up to rounding.
        let x = arr1(&[1.0_f32, 2.0, 3.0]);
        let y = arr1(&[3.0_f32, 2.0, 1.0]);
        let result = correlation(&x.view(), &y.view());
        assert_close!(result, 2.0_f32, 1e-6_f32);
    }

    #[test]
    fn test_correlation_zero_norm_f32() {
        // Exactly one constant input takes the sentinel branch: exact 1.
        let x = arr1(&[0.0_f32, 0.0, 0.0]);
        let y = arr1(&[1.0_f32, 2.0, 3.0]);
        let result = correlation(&x.view(), &y.view());
        assert_eq!(result, 1.0_f32);
    }

    #[test]
    fn test_correlation_zero_both_norm_f32() {
        // Both inputs constant takes the sentinel branch: exact 0.
        let x = arr1(&[0.0_f32, 0.0, 0.0]);
        let y = arr1(&[0.0_f32, 0.0, 0.0]);
        let result = correlation(&x.view(), &y.view());
        assert_eq!(result, 0.0_f32);
    }

    #[test]
    fn test_correlation_basic_f64() {
        let x = arr1(&[1.0_f64, 2.0, 3.0]);
        let y = arr1(&[4.0_f64, 5.0, 6.0]);
        let result = correlation(&x.view(), &y.view());
        assert_close!(result, 0.0_f64, 1e-12_f64);
    }

    #[test]
    fn test_correlation_partial_f64() {
        // Deviations (-1.5, -0.5, 0.5, 1.5) and (-0.5, -1.5, 1.5, 0.5):
        // r = 3 / 5 = 0.6, so the distance is 0.4.
        let x = arr1(&[1.0_f64, 2.0, 3.0, 4.0]);
        let y = arr1(&[2.0_f64, 1.0, 4.0, 3.0]);
        let result = correlation(&x.view(), &y.view());
        assert_close!(result, 0.4_f64, 1e-12_f64);
    }

    #[test]
    fn test_correlation_uncorrelated_f64() {
        // Deviations (-1, 0, 1) and (1/3, -2/3, 1/3) are orthogonal: r = 0.
        let x = arr1(&[1.0_f64, 2.0, 3.0]);
        let y = arr1(&[1.0_f64, 0.0, 1.0]);
        let result = correlation(&x.view(), &y.view());
        assert_close!(result, 1.0_f64, 1e-12_f64);
    }

    #[test]
    fn test_correlation_zero_norm_f64() {
        let x = arr1(&[0.0_f64, 0.0, 0.0]);
        let y = arr1(&[1.0_f64, 2.0, 3.0]);
        let result = correlation(&x.view(), &y.view());
        assert_eq!(result, 1.0_f64);
    }

    #[test]
    fn test_correlation_zero_both_norm_f64() {
        let x = arr1(&[0.0_f64, 0.0, 0.0]);
        let y = arr1(&[0.0_f64, 0.0, 0.0]);
        let result = correlation(&x.view(), &y.view());
        assert_eq!(result, 0.0_f64);
    }

    #[test]
    fn test_correlation_empty_f64() {
        // Two empty vectors behave like two constant vectors.
        let x = arr1::<f64>(&[]);
        let y = arr1::<f64>(&[]);
        let result = correlation(&x.view(), &y.view());
        assert_eq!(result, 0.0_f64);
    }

    #[test]
    #[should_panic]
    fn test_correlation_length_mismatch_panics() {
        let x = arr1(&[1.0_f64, 2.0, 3.0]);
        let y = arr1(&[1.0_f64, 2.0]);
        let _ = correlation(&x.view(), &y.view());
    }
}