// proximipy/numerics/f32vector.rs

1use std::simd::{num::SimdFloat, Simd};
2
3const SIMD_LANECOUNT: usize = 8;
4type SimdF32 = Simd<f32, SIMD_LANECOUNT>;
5
/// Borrowed view over a slice of `f32` values, providing SIMD-accelerated
/// L2 distance computations.
///
/// The struct only holds a shared reference, so it is `Copy`: passing it by
/// value is as cheap as passing the underlying slice.
#[derive(Debug, Clone, Copy)]
pub struct F32Vector<'a> {
    /// The borrowed data; never copied or mutated.
    array: &'a [f32],
}
10
11impl<'a> F32Vector<'a> {
12    pub fn len(&self) -> usize {
13        self.array.len()
14    }
15
16    pub fn is_empty(&self) -> bool {
17        self.array.is_empty()
18    }
19
20    /// # Usage
21    /// Computes the **SQUARED** L2 distance between two vectors.
22    /// This is cheaper to compute than the regular L2 distance.
23    /// This is typically useful when comparing two distances :
24    ///
25    /// dist(u,v) < dist(w, x) ⇔ dist(u,v) ** 2 < dist(w,x) ** 2
26    ///
27    /// # Panics
28    ///
29    /// Panics in debug mode if the two vectors have different lengths.
30    /// In release mode, the longest vector will be silently truncated.
31    #[inline]
32    pub fn l2_dist_squared(&self, othr: &F32Vector<'a>) -> f32 {
33        debug_assert!(self.len() == othr.len());
34        debug_assert!(self.len() % SIMD_LANECOUNT == 0);
35
36        let mut intermediate_sum_x8 = Simd::<f32, SIMD_LANECOUNT>::splat(0.0);
37
38        let self_chunks = self.array.chunks_exact(SIMD_LANECOUNT);
39        let othr_chunks = othr.array.chunks_exact(SIMD_LANECOUNT);
40
41        for (slice_self, slice_othr) in self_chunks.zip(othr_chunks) {
42            let f32x8_slf = SimdF32::from_slice(slice_self);
43            let f32x8_oth = SimdF32::from_slice(slice_othr);
44            let diff = f32x8_slf - f32x8_oth;
45            intermediate_sum_x8 += diff * diff;
46        }
47
48        intermediate_sum_x8.reduce_sum() // 8-to-1 sum
49    }
50
51    /// # Usage
52    /// Computes the L2 distance between two vectors.
53    ///
54    /// # Panics
55    ///
56    /// Panics in debug mode if the two vectors have different lengths.
57    /// In release mode, the longest vector will be silently truncated.
58    #[inline]
59    pub fn l2_dist(&self, other: &F32Vector<'a>) -> f32 {
60        self.l2_dist_squared(other).sqrt()
61    }
62}
63
64impl<'a> From<&'a [f32]> for F32Vector<'a> {
65    fn from(value: &'a [f32]) -> Self {
66        F32Vector { array: value }
67    }
68}
69
70impl PartialEq for F32Vector<'_> {
71    fn eq(&self, other: &Self) -> bool {
72        self.array
73            .iter()
74            .zip(other.array.iter())
75            .all(|(&a, &b)| a == b)
76    }
77}
78
79impl Eq for F32Vector<'_> {}
80
81impl std::hash::Hash for F32Vector<'_> {
82    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83        // Iterate through each element of the slice and hash it
84        for &value in self.array {
85            value.to_bits().hash(state); // Convert `f32` to its bit representation for consistent hashing
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{QuickCheck, TestResult};

    /// Absolute tolerance, for results that should be (essentially) exactly zero.
    const TOLERANCE: f32 = 1e-8;

    /// Relative tolerance for comparing two evaluations of the same sum that
    /// accumulate terms in a different order (SIMD lanes vs. sequential).
    /// `f32::EPSILON` is ~1.19e-7, so the absolute `TOLERANCE` rejects even a
    /// one-ulp difference once the result is above ~0.125. Reordering a sum of
    /// `n` non-negative terms can move it by roughly `(n - 1) * f32::EPSILON`
    /// relative, so 1e-4 leaves headroom for several hundred terms.
    const REL_TOLERANCE: f32 = 1e-4;

    /// `true` if `actual` is within the absolute `TOLERANCE` of `target`.
    fn close(actual: f32, target: f32) -> bool {
        (target - actual).abs() < TOLERANCE
    }

    /// `true` if two finite values agree up to `REL_TOLERANCE` relative error,
    /// falling back to the absolute `TOLERANCE` near zero.
    /// Returns `false` if either value is infinite or NaN.
    fn close_rel(actual: f32, target: f32) -> bool {
        if !actual.is_finite() || !target.is_finite() {
            return false;
        }
        let scale = actual.abs().max(target.abs());
        (target - actual).abs() <= TOLERANCE.max(REL_TOLERANCE * scale)
    }

    /// A valid L2 distance is finite and non-negative.
    fn is_valid_l2(suspect: f32) -> bool {
        suspect.is_finite() && suspect >= 0.0
    }

    /// Scalar reference implementation of the squared L2 distance.
    fn l2_spec<'a>(v1: F32Vector<'a>, v2: F32Vector<'a>) -> f32 {
        v1.array
            .iter()
            .zip(v2.array.iter())
            .map(|(&x, &y)| {
                let diff = x - y;
                diff * diff
            })
            .sum()
    }

    #[test]
    fn self_sim_is_zero() {
        fn qc_self_sim_is_zero(totest: Vec<f32>) -> TestResult {
            let usable_length = totest.len() / SIMD_LANECOUNT * SIMD_LANECOUNT;
            if totest[0..usable_length].iter().any(|x| !x.is_finite()) {
                return TestResult::discard();
            }
            let testvec = F32Vector::from(&totest[0..usable_length]);
            let selfsim = testvec.l2_dist(&testvec);
            TestResult::from_bool(is_valid_l2(selfsim) && close(selfsim, 0.0))
        }

        QuickCheck::new()
            .tests(10_000)
            // force that less than 90% of tests are discarded due to precondition violations
            // i.e. at least 10% of inputs should be valid so that we cover a good range
            .min_tests_passed(1_000)
            .quickcheck(qc_self_sim_is_zero as fn(Vec<f32>) -> TestResult);
    }

    #[test]
    // verifies the claim in the documentation of l2_dist_squared
    // i.e. dist(u,v) < dist(w, x) ⇔ dist(u,v) ** 2 < dist(w,x) ** 2
    fn squared_invariant() {
        fn qc_squared_invariant(u: Vec<f32>, v: Vec<f32>, w: Vec<f32>, x: Vec<f32>) -> TestResult {
            // NaN distances make both sides of each comparison false, so no
            // NaN filtering is needed here.
            let all_vecs = [u, v, w, x];
            let min_length = all_vecs.iter().map(|vec| vec.len()).min().unwrap()
                / SIMD_LANECOUNT
                * SIMD_LANECOUNT;
            let all_vectors: Vec<F32Vector> = all_vecs
                .iter()
                .map(|vec| F32Vector::from(&vec[..min_length]))
                .collect();

            let d1_squared = all_vectors[0].l2_dist_squared(&all_vectors[1]);
            let d2_squared = all_vectors[2].l2_dist_squared(&all_vectors[3]);

            let d1_root = all_vectors[0].l2_dist(&all_vectors[1]);
            let d2_root = all_vectors[2].l2_dist(&all_vectors[3]);

            let sanity_check1 = (d1_squared < d2_squared) == (d1_root < d2_root);
            let sanity_check2 = (d1_squared <= d2_squared) == (d1_root <= d2_root);
            TestResult::from_bool(sanity_check1 && sanity_check2)
        }

        QuickCheck::new().tests(10_000).quickcheck(
            qc_squared_invariant as fn(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) -> TestResult,
        );
    }

    #[test]
    fn simd_matches_spec() {
        fn qc_simd_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult {
            let min_length = u.len().min(v.len()) / SIMD_LANECOUNT * SIMD_LANECOUNT;
            let (u_f32v, v_f32v) = (
                F32Vector::from(&u[0..min_length]),
                F32Vector::from(&v[0..min_length]),
            );
            let simd = u_f32v.l2_dist_squared(&v_f32v);
            let spec = l2_spec(u_f32v, v_f32v);

            if simd.is_infinite() {
                TestResult::from_bool(spec.is_infinite())
            } else if simd.is_nan() {
                TestResult::from_bool(spec.is_nan())
            } else {
                // The SIMD kernel sums per lane and then reduces, while the spec
                // sums sequentially, so compare with a relative tolerance.
                TestResult::from_bool(close_rel(simd, spec))
            }
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(qc_simd_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult);
    }
}