kdtree_simd/simd_euclidean/
mod.rs

//!
//! ```rust
//! # use simd_euclidean::*;
//! # use rand::*;
//! for &i in [16, 32, 64, 128].iter() {
//!     // Dispatches to F32x4 below 64 elements and to F32x8 at 64 elements or more
//!     let mut rng = rand::thread_rng();
//!     let a = (0..i).map(|_| rng.gen::<f32>()).collect::<Vec<f32>>();
//!     let b = (0..i).map(|_| rng.gen::<f32>()).collect::<Vec<f32>>();
//!
//!     let v = Vectorized::distance(&a, &b);
//!     let n = Naive::distance(&a, &b);
//!     assert!((n-v).abs() < 0.00001);
//! }
//! ```
16
17#[macro_use]
18mod macros;
19
20mod f32x4;
21mod f32x8;
22
23mod f64x2;
24mod f64x4;
25
26pub use self::f32x4::F32x4;
27pub use self::f32x8::F32x8;
28pub use self::f64x2::F64x2;
29pub use self::f64x4::F64x4;
30
/// Scalar (non-SIMD) distance computation between two values of the same type.
///
/// This is the contract that [`scalar_euclidean`] dispatches through.
pub trait Naive {
    /// Type of the value produced by both distance methods.
    type Output;
    /// Associated type that no signature in this trait refers to.
    /// NOTE(review): presumably the scalar element type — confirm against the
    /// `impl_naive!` macro.
    type Ty;

    /// Squared euclidean distance between `self` and `other`.
    fn squared_distance(self, other: Self) -> Self::Output;

    /// Euclidean distance between `self` and `other`.
    fn distance(self, other: Self) -> Self::Output;
}
38
/// SIMD-capable distance computation between two values of the same type.
///
/// This is the contract that [`vector_euclidean`] dispatches through.
pub trait Vectorized {
    /// Type of the value produced by both distance methods.
    type Output;

    /// Squared euclidean distance between `self` and `other`.
    fn squared_distance(self, other: Self) -> Self::Output;

    /// Euclidean distance between `self` and `other`.
    fn distance(self, other: Self) -> Self::Output;
}
44
// Expand `impl_naive!` once for `f64` and once for `f32`.
// NOTE(review): the macro body is not visible here (presumably it lives in the
// `macros` module and emits the `Naive` impls used by `scalar_euclidean`) —
// confirm there what the two arguments mean.
impl_naive!(f64, f64);
impl_naive!(f32, f32);
47
48/// Calculate the euclidean distance between two slices of equal length
49///
50/// # Panics
51///
52/// Will panic if the lengths of the slices are not equal
53pub fn scalar_euclidean<T: Naive>(a: T, b: T) -> T::Output {
54    Naive::distance(a, b)
55}
56
57/// SIMD-capable calculation of the euclidean distance between two slices
58/// of equal length
59///
60/// ```rust
61/// # use simd_euclidean::*;
62/// let distance = vector_euclidean(&[0.1, 0.2, 0.3, 0.4f32] as &[f32], &[0.4, 0.3, 0.2, 0.1f32]);
63/// ```
64/// # Panics
65///
66/// Will panic if the lengths of the slices are not equal
67pub fn vector_euclidean<T: Vectorized>(a: T, b: T) -> T::Output {
68    Vectorized::distance(a, b)
69}
70
71impl Vectorized for &[f32] {
72    type Output = f32;
73    fn squared_distance(self, other: Self) -> Self::Output {
74        if self.len() >= 64 {
75            F32x8::squared_distance(self, other)
76        } else {
77            F32x4::squared_distance(self, other)
78        }
79    }
80
81    fn distance(self, other: Self) -> Self::Output {
82        Vectorized::squared_distance(self, other).sqrt()
83    }
84}
85
86impl Vectorized for &Vec<f32> {
87    type Output = f32;
88    fn squared_distance(self, other: Self) -> Self::Output {
89        if self.len() >= 64 {
90            F32x8::squared_distance(self, other)
91        } else {
92            F32x4::squared_distance(self, other)
93        }
94    }
95
96    fn distance(self, other: Self) -> Self::Output {
97        Vectorized::squared_distance(self, other).sqrt()
98    }
99}
100
101impl Vectorized for &[f64] {
102    type Output = f64;
103    fn squared_distance(self, other: Self) -> Self::Output {
104        if self.len() >= 16 {
105            F64x4::squared_distance(self, other)
106        } else {
107            F64x2::squared_distance(self, other)
108        }
109    }
110
111    fn distance(self, other: Self) -> Self::Output {
112        Vectorized::squared_distance(self, other).sqrt()
113    }
114}
115
116impl Vectorized for &Vec<f64> {
117    type Output = f64;
118    fn squared_distance(self, other: Self) -> Self::Output {
119        if self.len() >= 16 {
120            F64x4::squared_distance(self, other)
121        } else {
122            F64x2::squared_distance(self, other)
123        }
124    }
125
126    fn distance(self, other: Self) -> Self::Output {
127        Vectorized::squared_distance(self, other).sqrt()
128    }
129}
130
/// Compares the vectorized distance against the scalar one and smoke-tests
/// the `F32x4` arithmetic operators.
#[cfg(test)]
mod test {
    use super::*;

    pub const XS: [f32; 72] = [
        6.1125, 10.795, 20.0, 0.0, 10.55, 10.63, 20.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.26, 10.73, 0.0,
        0.0, 20.0, 0.0, 10.4975, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 20.0, 20.0, 20.0,
        0.0, 0.0, 0.0, 0.0, 10.475, 6.0905, 20.0, 0.0, 20.0, 20.0, 0.0, 10.5375, 10.54, 10.575,
        0.0, 0.0, 0.0, 10.76, 10.755, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 20.0,
        20.0, 0.0, 20.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0,
    ];
    pub const YS: [f32; 72] = [
        6.0905, 20.0, 0.0, 20.0, 20.0, 0.0, 10.5375, 10.54, 10.575, 0.0, 0.0, 0.0, 10.76, 10.755,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 20.0, 20.0, 0.0, 20.0, 0.0, 0.0,
        20.0, 0.0, 0.0, 0.0, 20.0, 6.1125, 10.795, 20.0, 0.0, 10.55, 10.63, 20.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 10.26, 10.73, 0.0, 0.0, 20.0, 0.0, 10.4975, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        20.0, 0.0, 20.0, 20.0, 20.0, 0.0, 0.0, 0.0, 0.0, 10.475,
    ];

    /// Every prefix of the fixed inputs must give (nearly) the same distance
    /// through the dispatching impl, `F32x8` and `F32x4` as the scalar path.
    #[test]
    fn verify() {
        // Inclusive upper bound: the complete 72-element arrays are checked
        // too, not only the prefixes up to length 71.
        for i in 0..=XS.len() {
            let x = &XS[..i];
            let y = &YS[..i];
            let res = scalar_euclidean(x, y);

            let dispatched = Vectorized::distance(x, y);
            assert!(
                (dispatched - res).abs() < 0.0001,
                "iter {}, {} != {}",
                i,
                dispatched,
                res
            );

            let wide = F32x8::distance(x, y);
            assert!(
                (wide - res).abs() < 0.0001,
                "iter {}, {} != {}",
                i,
                wide,
                res
            );

            let narrow = F32x4::distance(x, y);
            assert!(
                (narrow - res).abs() < 0.0001,
                "iter {}, {} != {}",
                i,
                narrow,
                res
            );
        }
    }

    /// Random inputs of assorted lengths: vectorized and scalar results must
    /// agree within the tolerance.
    #[test]
    fn verify_random() {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let input_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024];

        for &i in input_sizes.iter() {
            // Add a random offset to each base size so the lengths vary
            // around it.
            let len = i + rng.gen_range(0, 16) as usize;
            let mut a = Vec::with_capacity(len);
            let mut b = Vec::with_capacity(len);

            for _ in 0..len {
                a.push(rng.gen::<f32>());
                b.push(rng.gen::<f32>());
            }

            let diff = (vector_euclidean(&a, &b) - scalar_euclidean(&a, &b)).abs();
            // Report the actual vector length, not the base size it was
            // derived from.
            assert!(diff <= 0.0001, "diff = {}, len = {}", diff, len);
        }
    }

    /// Lane-wise multiplication followed by a horizontal sum.
    #[test]
    fn smoke_mul() {
        let a = F32x4::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let b = F32x4::from_slice(&[4.0, 3.0, 2.0, 1.0]);
        let c = a * b;
        assert_eq!(c.horizontal_add(), 4.0 + 6.0 + 6.0 + 4.0);
    }

    /// In-place lane-wise multiplication followed by a horizontal sum.
    #[test]
    fn smoke_mul_assign() {
        let mut a = F32x4::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let b = F32x4::from_slice(&[4.0, 3.0, 2.0, 1.0]);
        a *= b;
        assert_eq!(a.horizontal_add(), 4.0 + 6.0 + 6.0 + 4.0);
    }

    /// Lane-wise addition.
    #[test]
    fn smoke_add() {
        let a = F32x4::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let b = F32x4::from_slice(&[4.0, 3.0, 2.0, 1.0]);
        let c = a + b;
        assert_eq!(c, F32x4::new(5.0, 5.0, 5.0, 5.0));
    }

    /// Lane-wise subtraction.
    #[test]
    fn smoke_sub() {
        let a = F32x4::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let b = F32x4::from_slice(&[4.0, 3.0, 2.0, 1.0]);
        let c = a - b;
        assert_eq!(c, F32x4::new(-3.0, -1.0, 1.0, 3.0));
    }
}