fast_distances/distances/euclidean_grad.rs

1use ndarray::Array1;
2use num::Float;
3
4/// Computes the Euclidean distance and its gradient between two vectors.
5///
6/// The function calculates the Euclidean distance between two input vectors `x` and `y`
7/// and returns both the distance and the gradient. The gradient indicates how much each
8/// element in the input vectors contributes to the distance.
9///
10/// # Parameters
11///
12/// - **`x`:** An `Array1<T>` representing the first vector.
13/// - **`y`:** An `Array1<T>` representing the second vector.
14///
15/// # Type Parameter
16///
17/// - **`T`:** A generic type that must implement the `Float` trait. This ensures
18///   that the elements in vectors `x` and `y` can be used for arithmetic operations
19///   involving floating-point numbers.
20///
21/// # Returns
22///
23/// A tuple containing:
24/// 1. The Euclidean distance between the two input vectors, of type `T`.
25/// 2. An `Array1<T>` representing the gradient. Each element in this array corresponds to the contribution of each element in the input vectors towards the Euclidean distance.
26///
27/// # Panics
28///
29/// - If the input arrays do not have the same length, the function will panic with an appropriate error message.
30pub fn euclidean_grad<T>(x: &Array1<T>, y: &Array1<T>) -> (T, Vec<T>)
31where
32    T: Float,
33{
34    assert_eq!(x.len(), y.len(), "Input arrays must have the same length.");
35
36    let mut result = T::zero();
37    for i in 0..x.len() {
38        let diff = x[i] - y[i];
39        result = result + diff * diff;
40    }
41
42    let distance = result.sqrt();
43    let mut gradient = Vec::with_capacity(x.len());
44
45    // Calculate the gradient
46    for i in 0..x.len() {
47        let grad = (x[i] - y[i]) / (T::from(1e-6).unwrap() + distance);
48        gradient.push(grad);
49    }
50
51    (distance, gradient)
52}
53
#[cfg(test)]
mod tests {
    use ndarray::arr1;

    use super::*; // Brings `euclidean_grad` (and the `Float` trait) into scope.

    /// Asserts that `actual` lies strictly within `1e-6` of `expected`,
    /// panicking with `msg` otherwise.
    fn assert_close<T: Float>(actual: T, expected: T, msg: &str) {
        let tolerance = T::from(1e-6).unwrap();
        assert!((actual - expected).abs() < tolerance, "{}", msg);
    }

    #[test]
    fn test_euclidean_grad_f64() {
        let (dist, grad) = euclidean_grad(
            &arr1(&[1.0f64, 2.0, 3.0]),
            &arr1(&[4.0f64, 5.0, 6.0]),
        );

        // Every component differs by -3, so d = sqrt(27) and each gradient
        // entry is roughly -3 / sqrt(27) = -1 / sqrt(3).
        let expected_grad = -0.5773502691896257;
        assert_close(dist, 5.196152422706632, "Distance is incorrect for f64.");
        assert_close(grad[0], expected_grad, "Gradient[0] is incorrect for f64.");
        assert_close(grad[1], expected_grad, "Gradient[1] is incorrect for f64.");
        assert_close(grad[2], expected_grad, "Gradient[2] is incorrect for f64.");
    }

    #[test]
    fn test_euclidean_grad_f32() {
        let (dist, grad) = euclidean_grad(
            &arr1(&[1.0f32, 2.0, 3.0]),
            &arr1(&[4.0f32, 5.0, 6.0]),
        );

        let expected_grad = -0.57735026f32;
        assert_close(dist, 5.1961524f32, "Distance is incorrect for f32.");
        assert_close(grad[0], expected_grad, "Gradient[0] is incorrect for f32.");
        assert_close(grad[1], expected_grad, "Gradient[1] is incorrect for f32.");
        assert_close(grad[2], expected_grad, "Gradient[2] is incorrect for f32.");
    }

    #[test]
    fn test_euclidean_grad_zero_distance() {
        let point = arr1(&[1.0f64, 2.0, 3.0]);
        let (dist, grad) = euclidean_grad(&point, &point.clone());

        assert!(
            dist.abs() < 1e-6,
            "Distance should be 0 for identical vectors."
        );
        assert!(
            grad.iter().all(|g| g.abs() < 1e-6),
            "Gradient should be 0 for identical vectors."
        );
    }

    #[test]
    #[should_panic(expected = "Input arrays must have the same length.")]
    fn test_euclidean_grad_different_lengths() {
        // Mismatched lengths must trip the precondition assertion.
        euclidean_grad(&arr1(&[1.0f64, 2.0]), &arr1(&[4.0f64, 5.0, 6.0]));
    }
}
133}