rustml/
distance.rs

//! Functions to compute the distance between vectors and between 2D points.
2
3extern crate libc;
4
5use self::libc::{c_int, c_double, c_float};
6use matrix::*;
7use norm::{L2Norm, Norm};
8use blas::{cblas_daxpy, cblas_saxpy};
9use geometry::Point2D;
10
11pub trait DistancePoint2D<T> {
12    fn euclid(&self, other: &Point2D<T>) -> T;
13}
14
/// Implements `DistancePoint2D<$t>` for `Point2D<$t>` for every listed float type.
///
/// The distance is computed with `hypot`, i.e. `sqrt(dx^2 + dy^2)` without
/// materialising the squares. Squaring first overflows to `inf` for large
/// coordinate differences (e.g. `1e20f32.powi(2)` already exceeds `f32::MAX`),
/// and `hypot` also avoids the corresponding underflow for tiny differences.
macro_rules! distance_point2d_impl {
    ($($t:ty)*) => ($(

        impl DistancePoint2D<$t> for Point2D<$t> {
            fn euclid(&self, other: &Point2D<$t>) -> $t {
                (self.x - other.x).hypot(self.y - other.y)
            }
        }
    )*)
}
25
26distance_point2d_impl!{ f32 f64 }
27
28
29// ----------------------------------------------------------------------------
30
/// A metric measuring how far apart two vectors with elements of type `T` are.
pub trait Distance<T> {
    /// Returns the distance between the vectors `a` and `b`.
    ///
    /// Returns `None` when the distance cannot be computed for the given
    /// inputs; each implementation documents its own failure conditions.
    fn compute(a: &[T], b: &[T]) -> Option<T>;
}
37
/// Marker type selecting the euclidean (L2) distance.
///
/// It carries no data; use it through the `Distance` trait, e.g.
/// `Euclid::compute(&a, &b)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Euclid;
39
40impl Distance<f64> for Euclid {
41
42    /// Computes the euclidean distance between the vector `a` and `b`.
43    ///
44    /// Returns `None` if the two vectors have a different length.
45    ///
46    /// # Implementation details
47    ///
48    /// First the BLAS function `cblas_daxpy` is used to compute the
49    /// difference between the vectors. This requires O(n) additional space
50    /// if `n` is the number of elements of each vector. Then, the result
51    /// of the L2 norm of the difference is returned.
52    fn compute(a: &[f64], b: &[f64]) -> Option<f64> {
53
54        // TODO handling of NaN and stuff like this
55        if a.len() != b.len() {
56            return None;
57        }
58
59        // c = b.clone() does not work here because cblas_daxpy
60        // modifies the content of c and cloned() on a slice does
61        // not create a copy.
62        let c: Vec<f64> = b.to_vec();
63
64        unsafe {
65            cblas_daxpy(
66                a.len()     as c_int,
67                -1.0        as c_double,
68                a.as_ptr()  as *const c_double,
69                1           as c_int,
70                c.as_ptr()  as *mut c_double,
71                1           as c_int
72            );
73        }
74        Some(L2Norm::compute(&c))
75    }
76}
77
78impl Distance<f32> for Euclid {
79
80    /// Computes the euclidean distance between the vector `a` and `b`.
81    ///
82    /// Returns `None` if the two vectors have a different length.
83    ///
84    /// # Implementation details
85    ///
86    /// First the BLAS function `cblas_daxpy` is used to compute the
87    /// difference between the vectors. This requires O(n) additional space
88    /// if `n` is the number of elements of each vector. Then, the result
89    /// of the L2 norm of the difference is returned.
90    fn compute(a: &[f32], b: &[f32]) -> Option<f32> {
91
92        // TODO handling of NaN and stuff like this
93        if a.len() != b.len() {
94            return None;
95        }
96
97        // c = b.clone() does not work here because cblas_daxpy
98        // modifies the content of c and cloned() on a slice does
99        // not create a copy.
100        let c: Vec<f32> = b.to_vec();
101
102        unsafe {
103            cblas_saxpy(
104                a.len()     as c_int,
105                -1.0        as c_float,
106                a.as_ptr()  as *const c_float,
107                1           as c_int,
108                c.as_ptr()  as *mut c_float,
109                1           as c_int
110            );
111        }
112        Some(L2Norm::compute(&c))
113    }
114}
115
116pub fn all_pair_distances(m: &Matrix<f64>) -> Matrix<f64> {
117
118    let mut r = Matrix::fill(0.0, m.rows(), m.rows());
119
120    // TODO handling of NaN and stuff like this
121    for (i, row1) in m.row_iter().enumerate() {
122        for (j, row2) in m.row_iter_at(i + 1).enumerate() {
123            let p = j + i + 1;
124            let d = Euclid::compute(row1, row2).unwrap();
125            r.set(i, p, d);
126            r.set(p, i, d);
127        }
128    }
129    r
130}
131
#[cfg(test)]
mod tests {
    use matrix::*;
    use super::*;
    use geometry::Point2D;

    /// Asserts that `actual` is within `eps` of `expected`.
    ///
    /// The check is two-sided: results that are too small fail just like
    /// results that are too large.
    fn assert_approx(actual: f64, expected: f64, eps: f64) {
        assert!((actual - expected).abs() <= eps,
                "expected {} +/- {}, got {}", expected, eps, actual);
    }

    #[test]
    fn test_euclid() {

        let a: Vec<f64> = vec![1.0, 2.0, 3.0];
        let b: Vec<f64> = vec![2.0, 5.0, 13.0];
        let c: Vec<f64> = vec![2.0, 5.0, 13.0];
        let d: Vec<f64> = vec![1.0, 2.0, 3.0];
        assert_approx(Euclid::compute(&a, &b).unwrap(), 10.488088, 0.000001);
        // The inputs must be left untouched by the BLAS call.
        assert_eq!(b, c);
        assert_eq!(a, d);
    }

    #[test]
    fn test_euclid_f32() {

        let a: Vec<f32> = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = vec![2.0, 5.0, 13.0];
        let c: Vec<f32> = vec![2.0, 5.0, 13.0];
        let d: Vec<f32> = vec![1.0, 2.0, 3.0];
        let dist: f32 = Euclid::compute(&a, &b).unwrap();
        assert!((dist - 10.488088).abs() <= 0.0001);
        assert_eq!(b, c);
        assert_eq!(a, d);
    }

    #[test]
    fn test_euclid_length_mismatch() {

        let a: Vec<f64> = vec![1.0, 2.0, 3.0];
        let b: Vec<f64> = vec![1.0, 2.0];
        assert!(Euclid::compute(&a, &b).is_none());
    }

    #[test]
    fn test_all_pair_distances() {

        let m = mat![1.0, 2.0; 5.0, 12.0; 13.0, 27.0];
        let r = all_pair_distances(&m);

        assert_eq!(r.rows(), m.rows());
        assert_eq!(r.cols(), m.rows());
        assert_eq!(*r.get(0, 0).unwrap(), 0.0);
        assert_eq!(*r.get(1, 1).unwrap(), 0.0);
        assert_eq!(*r.get(2, 2).unwrap(), 0.0);

        assert_approx(*r.get(0, 1).unwrap(), 10.770, 0.001);
        assert_approx(*r.get(0, 2).unwrap(), 27.731, 0.001);
        assert_approx(*r.get(1, 0).unwrap(), 10.770, 0.001);
        assert_approx(*r.get(2, 0).unwrap(), 27.731, 0.001);

        assert_approx(*r.get(1, 2).unwrap(), 17.0, 0.001);
        assert_approx(*r.get(2, 1).unwrap(), 17.0, 0.001);
    }

    #[test]
    fn test_euclid_point2d() {

        let a: Point2D<f64> = Point2D::new(2.0, 8.0);
        let b: Point2D<f64> = Point2D::new(5.0, 12.0);
        assert_approx(a.euclid(&b), 5.0, 0.00001);

        let d: Point2D<f64> = Point2D::new(2.0, 8.0);
        assert_eq!(a.euclid(&d), 0.0);
    }
}
182