// proximipy/numerics/f32vector.rs
use std::simd::{num::SimdFloat, Simd};
2
/// Number of `f32` lanes packed into one SIMD vector.
const SIMD_LANECOUNT: usize = 8;
/// SIMD vector type holding `SIMD_LANECOUNT` lanes of `f32`.
type SimdF32 = Simd<f32, SIMD_LANECOUNT>;
5
/// A read-only view over a borrowed slice of `f32` values.
///
/// The view stores only a shared reference, so cloning it copies the
/// reference and never the underlying elements.
#[derive(Clone, Debug)]
pub struct F32Vector<'a> {
    /// The borrowed elements this view reads from.
    array: &'a [f32],
}
10
11impl<'a> F32Vector<'a> {
12 pub fn len(&self) -> usize {
13 self.array.len()
14 }
15
16 pub fn is_empty(&self) -> bool {
17 self.array.is_empty()
18 }
19
20 #[inline]
32 pub fn l2_dist_squared(&self, othr: &F32Vector<'a>) -> f32 {
33 debug_assert!(self.len() == othr.len());
34 debug_assert!(self.len() % SIMD_LANECOUNT == 0);
35
36 let mut intermediate_sum_x8 = Simd::<f32, SIMD_LANECOUNT>::splat(0.0);
37
38 let self_chunks = self.array.chunks_exact(SIMD_LANECOUNT);
39 let othr_chunks = othr.array.chunks_exact(SIMD_LANECOUNT);
40
41 for (slice_self, slice_othr) in self_chunks.zip(othr_chunks) {
42 let f32x8_slf = SimdF32::from_slice(slice_self);
43 let f32x8_oth = SimdF32::from_slice(slice_othr);
44 let diff = f32x8_slf - f32x8_oth;
45 intermediate_sum_x8 += diff * diff;
46 }
47
48 intermediate_sum_x8.reduce_sum() }
50
51 #[inline]
59 pub fn l2_dist(&self, other: &F32Vector<'a>) -> f32 {
60 self.l2_dist_squared(other).sqrt()
61 }
62}
63
64impl<'a> From<&'a [f32]> for F32Vector<'a> {
65 fn from(value: &'a [f32]) -> Self {
66 F32Vector { array: value }
67 }
68}
69
70impl PartialEq for F32Vector<'_> {
71 fn eq(&self, other: &Self) -> bool {
72 self.array
73 .iter()
74 .zip(other.array.iter())
75 .all(|(&a, &b)| a == b)
76 }
77}
78
79impl Eq for F32Vector<'_> {}
80
81impl std::hash::Hash for F32Vector<'_> {
82 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83 for &value in self.array {
85 value.to_bits().hash(state); }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{QuickCheck, TestResult};

    /// Absolute tolerance, used when the compared values are close to zero.
    const TOLERANCE: f32 = 1e-8;

    /// Relative tolerance, used for larger magnitudes. The SIMD kernel and
    /// the scalar reference add the same terms in a different order, so their
    /// results can differ by rounding that grows with the size of the sum.
    const RELATIVE_TOLERANCE: f32 = 1e-4;

    /// Returns `true` when `actual` is within tolerance of `target`.
    ///
    /// The comparison is strict, so an infinite or NaN difference never
    /// counts as close.
    fn close(actual: f32, target: f32) -> bool {
        let scale = actual.abs().max(target.abs());
        let allowed = TOLERANCE.max(RELATIVE_TOLERANCE * scale);
        (target - actual).abs() < allowed
    }

    /// A distance is valid when it is finite and non-negative.
    fn is_valid_l2(suspect: f32) -> bool {
        suspect.is_finite() && suspect >= 0.0
    }

    /// Largest multiple of `SIMD_LANECOUNT` that does not exceed `len`.
    fn lane_aligned(len: usize) -> usize {
        len / SIMD_LANECOUNT * SIMD_LANECOUNT
    }

    /// Logical implication: `premise` being true requires `conclusion`.
    fn implies(premise: bool, conclusion: bool) -> bool {
        !premise || conclusion
    }

    /// Scalar reference implementation of the squared L2 distance.
    fn l2_spec(v1: &F32Vector<'_>, v2: &F32Vector<'_>) -> f32 {
        v1.array
            .iter()
            .zip(v2.array.iter())
            .map(|(&x, &y)| {
                let diff = x - y;
                diff * diff
            })
            .sum()
    }

    #[test]
    fn self_sim_is_zero() {
        fn qc_self_sim_is_zero(totest: Vec<f32>) -> TestResult {
            let usable = &totest[..lane_aligned(totest.len())];
            // Non-finite inputs make `x - x` NaN, so they are out of scope here.
            if usable.iter().any(|x| !x.is_finite()) {
                return TestResult::discard();
            }
            let testvec = F32Vector::from(usable);
            let selfsim = testvec.l2_dist(&testvec);
            TestResult::from_bool(is_valid_l2(selfsim) && close(selfsim, 0.0))
        }

        QuickCheck::new()
            .tests(10_000)
            .min_tests_passed(1_000)
            .quickcheck(qc_self_sim_is_zero as fn(Vec<f32>) -> TestResult);
    }

    #[test]
    fn squared_invariant() {
        /// Checks that ordering by squared distance and ordering by distance
        /// agree. `sqrt` is monotone but not strictly so in `f32` (distinct
        /// squares may round to the same root), hence only these
        /// implications are required to hold.
        fn orders_agree(sq_a: f32, sq_b: f32, root_a: f32, root_b: f32) -> bool {
            implies(sq_a <= sq_b, root_a <= root_b) && implies(root_a < root_b, sq_a < sq_b)
        }

        fn qc_squared_invariant(u: Vec<f32>, v: Vec<f32>, w: Vec<f32>, x: Vec<f32>) -> TestResult {
            let all_vecs = [u, v, w, x];
            let shortest = all_vecs.iter().map(Vec::len).min().unwrap_or(0);
            let min_length = lane_aligned(shortest);
            let all_vectors: Vec<F32Vector> = all_vecs
                .iter()
                .map(|vec| F32Vector::from(&vec[..min_length]))
                .collect();

            let d1_squared = all_vectors[0].l2_dist_squared(&all_vectors[1]);
            let d2_squared = all_vectors[2].l2_dist_squared(&all_vectors[3]);

            let d1_root = all_vectors[0].l2_dist(&all_vectors[1]);
            let d2_root = all_vectors[2].l2_dist(&all_vectors[3]);

            let forward = orders_agree(d1_squared, d2_squared, d1_root, d2_root);
            let backward = orders_agree(d2_squared, d1_squared, d2_root, d1_root);
            TestResult::from_bool(forward && backward)
        }

        QuickCheck::new().tests(10_000).quickcheck(
            qc_squared_invariant as fn(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) -> TestResult,
        );
    }

    #[test]
    fn simd_matches_spec() {
        fn qc_simd_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult {
            let min_length = lane_aligned(u.len().min(v.len()));
            let u_f32v = F32Vector::from(&u[..min_length]);
            let v_f32v = F32Vector::from(&v[..min_length]);

            let simd = u_f32v.l2_dist_squared(&v_f32v);
            let spec = l2_spec(&u_f32v, &v_f32v);

            // Non-finite results must match in kind; finite ones in value.
            let agrees = if simd.is_infinite() {
                spec.is_infinite()
            } else if simd.is_nan() {
                spec.is_nan()
            } else {
                close(simd, spec)
            };
            TestResult::from_bool(agrees)
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(qc_simd_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult);
    }
}