lattice_embed/simd/
normalize.rs1#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6#[cfg(target_arch = "aarch64")]
7use std::arch::aarch64::*;
8
9use super::simd_config;
10
11#[cfg(target_arch = "x86_64")]
12use super::dot_product::{horizontal_sum_avx2, horizontal_sum_avx512};
13
14#[cfg(target_arch = "aarch64")]
15use super::dot_product::horizontal_sum_neon;
16
17#[inline]
19pub fn normalize(vector: &mut [f32]) {
20 let config = simd_config();
21
22 #[cfg(target_arch = "x86_64")]
23 {
24 if config.avx512f_enabled {
25 return unsafe { normalize_avx512_unrolled(vector) };
29 }
30 if config.avx2_enabled && config.fma_enabled {
31 return unsafe { normalize_avx2_unrolled(vector) };
35 }
36 }
37
38 #[cfg(target_arch = "aarch64")]
39 {
40 if config.neon_enabled {
41 return unsafe { normalize_neon_unrolled(vector) };
45 }
46 }
47
48 normalize_scalar(vector)
49}
50
/// Portable fallback: rescales `vector` in place to unit Euclidean (L2) length.
///
/// The sum of squares is accumulated front to back in `f32`. The slice is only
/// modified when the resulting norm is strictly greater than zero, so empty
/// slices, all-zero slices and slices whose norm evaluates to NaN are left
/// exactly as they were.
pub(crate) fn normalize_scalar(vector: &mut [f32]) {
    let sum_of_squares = vector.iter().fold(0.0_f32, |acc, &x| acc + x * x);
    let length = sum_of_squares.sqrt();

    if length > 0.0 {
        // Multiply by the reciprocal once instead of dividing every element.
        let scale = length.recip();
        for component in vector.iter_mut() {
            *component *= scale;
        }
    }
}
59
60#[cfg(target_arch = "x86_64")]
76#[target_feature(enable = "avx512f")]
77unsafe fn normalize_avx512_unrolled(vector: &mut [f32]) {
78 const SIMD_WIDTH: usize = 16;
79 const UNROLL: usize = 4;
80 const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
81
82 let n = vector.len();
83 let chunks = n / CHUNK_SIZE;
84 let main_processed = chunks * CHUNK_SIZE;
85 let remaining = n - main_processed;
86 let remaining_chunks = remaining / SIMD_WIDTH;
87
88 let mut norm0 = _mm512_setzero_ps();
90 let mut norm1 = _mm512_setzero_ps();
91 let mut norm2 = _mm512_setzero_ps();
92 let mut norm3 = _mm512_setzero_ps();
93
94 for i in 0..chunks {
95 let base = i * CHUNK_SIZE;
96
97 let v0 = _mm512_loadu_ps(vector.as_ptr().add(base));
98 norm0 = _mm512_fmadd_ps(v0, v0, norm0);
99
100 let v1 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
101 norm1 = _mm512_fmadd_ps(v1, v1, norm1);
102
103 let v2 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
104 norm2 = _mm512_fmadd_ps(v2, v2, norm2);
105
106 let v3 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
107 norm3 = _mm512_fmadd_ps(v3, v3, norm3);
108 }
109
110 let norm_vec = _mm512_add_ps(_mm512_add_ps(norm0, norm1), _mm512_add_ps(norm2, norm3));
111
112 let mut norm_remainder = _mm512_setzero_ps();
114 for i in 0..remaining_chunks {
115 let offset = main_processed + i * SIMD_WIDTH;
116 let v = _mm512_loadu_ps(vector.as_ptr().add(offset));
117 norm_remainder = _mm512_fmadd_ps(v, v, norm_remainder);
118 }
119
120 let mut norm_sq = horizontal_sum_avx512(norm_vec) + horizontal_sum_avx512(norm_remainder);
121
122 for i in (main_processed + remaining_chunks * SIMD_WIDTH)..n {
124 norm_sq += vector[i] * vector[i];
125 }
126
127 let norm = norm_sq.sqrt();
128 if norm == 0.0 {
129 return;
130 }
131
132 let inv_norm = 1.0 / norm;
133 let inv_norm_vec = _mm512_set1_ps(inv_norm);
134
135 for i in 0..chunks {
137 let base = i * CHUNK_SIZE;
138
139 let v0 = _mm512_loadu_ps(vector.as_ptr().add(base));
140 _mm512_storeu_ps(
141 vector.as_mut_ptr().add(base),
142 _mm512_mul_ps(v0, inv_norm_vec),
143 );
144
145 let v1 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
146 _mm512_storeu_ps(
147 vector.as_mut_ptr().add(base + SIMD_WIDTH),
148 _mm512_mul_ps(v1, inv_norm_vec),
149 );
150
151 let v2 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
152 _mm512_storeu_ps(
153 vector.as_mut_ptr().add(base + SIMD_WIDTH * 2),
154 _mm512_mul_ps(v2, inv_norm_vec),
155 );
156
157 let v3 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
158 _mm512_storeu_ps(
159 vector.as_mut_ptr().add(base + SIMD_WIDTH * 3),
160 _mm512_mul_ps(v3, inv_norm_vec),
161 );
162 }
163
164 for i in 0..remaining_chunks {
166 let offset = main_processed + i * SIMD_WIDTH;
167 let v = _mm512_loadu_ps(vector.as_ptr().add(offset));
168 _mm512_storeu_ps(
169 vector.as_mut_ptr().add(offset),
170 _mm512_mul_ps(v, inv_norm_vec),
171 );
172 }
173
174 for i in (main_processed + remaining_chunks * SIMD_WIDTH)..n {
176 vector[i] *= inv_norm;
177 }
178}
179
180#[cfg(target_arch = "x86_64")]
192#[target_feature(enable = "avx2", enable = "fma")]
193unsafe fn normalize_avx2_unrolled(vector: &mut [f32]) {
194 const SIMD_WIDTH: usize = 8;
195 const UNROLL: usize = 4;
196 const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
197 let n = vector.len();
198 let chunks = n / CHUNK_SIZE;
199
200 let mut norm0 = _mm256_setzero_ps();
202 let mut norm1 = _mm256_setzero_ps();
203 let mut norm2 = _mm256_setzero_ps();
204 let mut norm3 = _mm256_setzero_ps();
205
206 for i in 0..chunks {
207 let base = i * CHUNK_SIZE;
208
209 let v0 = _mm256_loadu_ps(vector.as_ptr().add(base));
210 norm0 = _mm256_fmadd_ps(v0, v0, norm0);
211
212 let v1 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
213 norm1 = _mm256_fmadd_ps(v1, v1, norm1);
214
215 let v2 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
216 norm2 = _mm256_fmadd_ps(v2, v2, norm2);
217
218 let v3 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
219 norm3 = _mm256_fmadd_ps(v3, v3, norm3);
220 }
221
222 let norm_vec = _mm256_add_ps(_mm256_add_ps(norm0, norm1), _mm256_add_ps(norm2, norm3));
223 let mut norm_sq = horizontal_sum_avx2(norm_vec);
224
225 for i in (chunks * CHUNK_SIZE)..n {
227 norm_sq += vector[i] * vector[i];
228 }
229
230 let norm = norm_sq.sqrt();
231 if norm == 0.0 {
232 return;
233 }
234
235 let inv_norm = 1.0 / norm;
236 let inv_norm_vec = _mm256_set1_ps(inv_norm);
237
238 for i in 0..chunks {
240 let base = i * CHUNK_SIZE;
241
242 let v0 = _mm256_loadu_ps(vector.as_ptr().add(base));
243 _mm256_storeu_ps(
244 vector.as_mut_ptr().add(base),
245 _mm256_mul_ps(v0, inv_norm_vec),
246 );
247
248 let v1 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
249 _mm256_storeu_ps(
250 vector.as_mut_ptr().add(base + SIMD_WIDTH),
251 _mm256_mul_ps(v1, inv_norm_vec),
252 );
253
254 let v2 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
255 _mm256_storeu_ps(
256 vector.as_mut_ptr().add(base + SIMD_WIDTH * 2),
257 _mm256_mul_ps(v2, inv_norm_vec),
258 );
259
260 let v3 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
261 _mm256_storeu_ps(
262 vector.as_mut_ptr().add(base + SIMD_WIDTH * 3),
263 _mm256_mul_ps(v3, inv_norm_vec),
264 );
265 }
266
267 for i in (chunks * CHUNK_SIZE)..n {
269 vector[i] *= inv_norm;
270 }
271}
272
/// NEON kernel: rescales `vector` in place to unit Euclidean (L2) length.
///
/// Pass 1 accumulates the sum of squares over 16-float blocks (four 4-lane
/// fused multiply-adds into four independent accumulators) and finishes the
/// remaining elements (fewer than 16) with a scalar loop. Pass 2 multiplies
/// every element by `1.0 / norm` using the same split.
///
/// The slice is left untouched when the norm is zero or NaN, matching
/// `normalize_scalar`, which only rescales when `norm > 0.0`.
///
/// # Safety
///
/// The CPU executing this function must support the NEON instructions used by
/// the `std::arch::aarch64` intrinsics below.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn normalize_neon_unrolled(vector: &mut [f32]) {
    const SIMD_WIDTH: usize = 4;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = vector.len();
    let chunks = n / CHUNK_SIZE;

    // Four separate accumulators, one per unrolled lane group
    // (presumably to keep the multiply-adds independent of each other).
    let mut norm0 = vdupq_n_f32(0.0);
    let mut norm1 = vdupq_n_f32(0.0);
    let mut norm2 = vdupq_n_f32(0.0);
    let mut norm3 = vdupq_n_f32(0.0);

    // Every load below reads 4 floats ending at or before
    // `chunks * CHUNK_SIZE <= n`, so it stays inside the slice.
    for i in 0..chunks {
        let base = i * CHUNK_SIZE;

        let v0 = vld1q_f32(vector.as_ptr().add(base));
        norm0 = vfmaq_f32(norm0, v0, v0);

        let v1 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH));
        norm1 = vfmaq_f32(norm1, v1, v1);

        let v2 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 2));
        norm2 = vfmaq_f32(norm2, v2, v2);

        let v3 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 3));
        norm3 = vfmaq_f32(norm3, v3, v3);
    }

    let norm_vec = vaddq_f32(vaddq_f32(norm0, norm1), vaddq_f32(norm2, norm3));
    let mut norm_sq = horizontal_sum_neon(norm_vec);

    // Scalar tail: fewer than 16 trailing elements.
    for val in vector.iter().skip(chunks * CHUNK_SIZE) {
        norm_sq += val * val;
    }

    let norm = norm_sq.sqrt();
    // Bail out on a zero norm (nothing to scale) and on a NaN norm, so this
    // kernel agrees with the scalar fallback instead of writing NaN everywhere.
    if norm == 0.0 || norm.is_nan() {
        return;
    }

    let inv_norm = 1.0 / norm;
    let inv_norm_vec = vdupq_n_f32(inv_norm);

    // Pass 2: scale in place, mirroring the block structure of pass 1.
    for i in 0..chunks {
        let base = i * CHUNK_SIZE;

        let v0 = vld1q_f32(vector.as_ptr().add(base));
        vst1q_f32(vector.as_mut_ptr().add(base), vmulq_f32(v0, inv_norm_vec));

        let v1 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH));
        vst1q_f32(
            vector.as_mut_ptr().add(base + SIMD_WIDTH),
            vmulq_f32(v1, inv_norm_vec),
        );

        let v2 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 2));
        vst1q_f32(
            vector.as_mut_ptr().add(base + SIMD_WIDTH * 2),
            vmulq_f32(v2, inv_norm_vec),
        );

        let v3 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 3));
        vst1q_f32(
            vector.as_mut_ptr().add(base + SIMD_WIDTH * 3),
            vmulq_f32(v3, inv_norm_vec),
        );
    }

    for val in vector.iter_mut().skip(chunks * CHUNK_SIZE) {
        *val *= inv_norm;
    }
}