Skip to main content

lattice_embed/simd/
normalize.rs

1//! SIMD-accelerated vector normalization.
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6#[cfg(target_arch = "aarch64")]
7use std::arch::aarch64::*;
8
9use super::simd_config;
10
11#[cfg(target_arch = "x86_64")]
12use super::dot_product::{horizontal_sum_avx2, horizontal_sum_avx512};
13
14#[cfg(target_arch = "aarch64")]
15use super::dot_product::horizontal_sum_neon;
16
17/// **Unstable**: SIMD dispatch layer; use `lattice_embed::utils::normalize` for the stable wrapper.
18#[inline]
19pub fn normalize(vector: &mut [f32]) {
20    let config = simd_config();
21
22    #[cfg(target_arch = "x86_64")]
23    {
24        if config.avx512f_enabled {
25            // SAFETY: Runtime feature detection verified AVX-512F. The mutable
26            // slice is valid for the call lifetime; the callee uses unaligned
27            // loads/stores and chunk/remainder bounds that stay inside the slice.
28            return unsafe { normalize_avx512_unrolled(vector) };
29        }
30        if config.avx2_enabled && config.fma_enabled {
31            // SAFETY: Runtime feature detection verified AVX2+FMA. The mutable
32            // slice is valid for the call lifetime; the callee uses unaligned
33            // loads/stores and chunk/remainder bounds that stay inside the slice.
34            return unsafe { normalize_avx2_unrolled(vector) };
35        }
36    }
37
38    #[cfg(target_arch = "aarch64")]
39    {
40        if config.neon_enabled {
41            // SAFETY: NEON is available on aarch64. The mutable slice is valid
42            // for the call lifetime; the callee uses unaligned loads/stores and
43            // bounded chunk/remainder loops that stay inside the slice.
44            return unsafe { normalize_neon_unrolled(vector) };
45        }
46    }
47
48    normalize_scalar(vector)
49}
50
/// Portable scalar L2 normalization.
///
/// Multiplies every element of `vector` by the reciprocal of its Euclidean norm.
/// When the norm is not strictly positive, the slice is left as-is. That covers
/// an empty or all-zero slice, or one whose sum of squares is NaN.
pub(crate) fn normalize_scalar(vector: &mut [f32]) {
    let sum_of_squares: f32 = vector.iter().map(|&v| v * v).sum();
    let length = sum_of_squares.sqrt();

    if length > 0.0 {
        // `recip()` is defined as `1.0 / self`, so this is the same scale factor.
        let scale = length.recip();
        for v in vector.iter_mut() {
            *v *= scale;
        }
    }
}
59
60/// AVX-512F-accelerated normalization with 4x unrolling.
61///
62/// Performs two passes:
63/// 1. Compute L2 norm
64/// 2. Scale each element by 1 / norm
65///
66/// # Safety
67///
68/// Caller must ensure:
69/// - CPU supports AVX-512F instructions (verified via `simd_config()`)
70///
71/// Memory safety:
72/// - Uses `_mm512_loadu_ps`/`_mm512_storeu_ps` for unaligned access
73/// - Pointer arithmetic stays within slice bounds via chunk calculation
74/// - Remainder loops use safe indexing
75#[cfg(target_arch = "x86_64")]
76#[target_feature(enable = "avx512f")]
77unsafe fn normalize_avx512_unrolled(vector: &mut [f32]) {
78    const SIMD_WIDTH: usize = 16;
79    const UNROLL: usize = 4;
80    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
81
82    let n = vector.len();
83    let chunks = n / CHUNK_SIZE;
84    let main_processed = chunks * CHUNK_SIZE;
85    let remaining = n - main_processed;
86    let remaining_chunks = remaining / SIMD_WIDTH;
87
88    // First pass: compute L2 norm with 4 accumulators
89    let mut norm0 = _mm512_setzero_ps();
90    let mut norm1 = _mm512_setzero_ps();
91    let mut norm2 = _mm512_setzero_ps();
92    let mut norm3 = _mm512_setzero_ps();
93
94    for i in 0..chunks {
95        let base = i * CHUNK_SIZE;
96
97        let v0 = _mm512_loadu_ps(vector.as_ptr().add(base));
98        norm0 = _mm512_fmadd_ps(v0, v0, norm0);
99
100        let v1 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
101        norm1 = _mm512_fmadd_ps(v1, v1, norm1);
102
103        let v2 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
104        norm2 = _mm512_fmadd_ps(v2, v2, norm2);
105
106        let v3 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
107        norm3 = _mm512_fmadd_ps(v3, v3, norm3);
108    }
109
110    let norm_vec = _mm512_add_ps(_mm512_add_ps(norm0, norm1), _mm512_add_ps(norm2, norm3));
111
112    // Remainder for norm calculation with single-register AVX-512F loop
113    let mut norm_remainder = _mm512_setzero_ps();
114    for i in 0..remaining_chunks {
115        let offset = main_processed + i * SIMD_WIDTH;
116        let v = _mm512_loadu_ps(vector.as_ptr().add(offset));
117        norm_remainder = _mm512_fmadd_ps(v, v, norm_remainder);
118    }
119
120    let mut norm_sq = horizontal_sum_avx512(norm_vec) + horizontal_sum_avx512(norm_remainder);
121
122    // Scalar tail for norm (recomputed inline to avoid cross-pass variable dependency)
123    for i in (main_processed + remaining_chunks * SIMD_WIDTH)..n {
124        norm_sq += vector[i] * vector[i];
125    }
126
127    let norm = norm_sq.sqrt();
128    if norm == 0.0 {
129        return;
130    }
131
132    let inv_norm = 1.0 / norm;
133    let inv_norm_vec = _mm512_set1_ps(inv_norm);
134
135    // Second pass: scale by inverse norm with 4x unrolling
136    for i in 0..chunks {
137        let base = i * CHUNK_SIZE;
138
139        let v0 = _mm512_loadu_ps(vector.as_ptr().add(base));
140        _mm512_storeu_ps(
141            vector.as_mut_ptr().add(base),
142            _mm512_mul_ps(v0, inv_norm_vec),
143        );
144
145        let v1 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
146        _mm512_storeu_ps(
147            vector.as_mut_ptr().add(base + SIMD_WIDTH),
148            _mm512_mul_ps(v1, inv_norm_vec),
149        );
150
151        let v2 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
152        _mm512_storeu_ps(
153            vector.as_mut_ptr().add(base + SIMD_WIDTH * 2),
154            _mm512_mul_ps(v2, inv_norm_vec),
155        );
156
157        let v3 = _mm512_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
158        _mm512_storeu_ps(
159            vector.as_mut_ptr().add(base + SIMD_WIDTH * 3),
160            _mm512_mul_ps(v3, inv_norm_vec),
161        );
162    }
163
164    // Remainder for scaling with single-register AVX-512F loop
165    for i in 0..remaining_chunks {
166        let offset = main_processed + i * SIMD_WIDTH;
167        let v = _mm512_loadu_ps(vector.as_ptr().add(offset));
168        _mm512_storeu_ps(
169            vector.as_mut_ptr().add(offset),
170            _mm512_mul_ps(v, inv_norm_vec),
171        );
172    }
173
174    // Final scalar remainder (recomputed inline to avoid cross-pass variable dependency)
175    for i in (main_processed + remaining_chunks * SIMD_WIDTH)..n {
176        vector[i] *= inv_norm;
177    }
178}
179
180/// AVX2-accelerated normalization with 4x unrolling.
181///
182/// # Safety
183///
184/// Caller must ensure:
185/// - CPU supports AVX2 and FMA instructions (verified via `simd_config()`)
186///
187/// Memory safety:
188/// - Uses `_mm256_loadu_ps`/`_mm256_storeu_ps` for unaligned access (safe for any alignment)
189/// - Pointer arithmetic stays within slice bounds via chunk calculation
190/// - Remainder loop uses safe indexing
191#[cfg(target_arch = "x86_64")]
192#[target_feature(enable = "avx2", enable = "fma")]
193unsafe fn normalize_avx2_unrolled(vector: &mut [f32]) {
194    const SIMD_WIDTH: usize = 8;
195    const UNROLL: usize = 4;
196    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
197    let n = vector.len();
198    let chunks = n / CHUNK_SIZE;
199
200    // First pass: compute L2 norm with 4 accumulators
201    let mut norm0 = _mm256_setzero_ps();
202    let mut norm1 = _mm256_setzero_ps();
203    let mut norm2 = _mm256_setzero_ps();
204    let mut norm3 = _mm256_setzero_ps();
205
206    for i in 0..chunks {
207        let base = i * CHUNK_SIZE;
208
209        let v0 = _mm256_loadu_ps(vector.as_ptr().add(base));
210        norm0 = _mm256_fmadd_ps(v0, v0, norm0);
211
212        let v1 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
213        norm1 = _mm256_fmadd_ps(v1, v1, norm1);
214
215        let v2 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
216        norm2 = _mm256_fmadd_ps(v2, v2, norm2);
217
218        let v3 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
219        norm3 = _mm256_fmadd_ps(v3, v3, norm3);
220    }
221
222    let norm_vec = _mm256_add_ps(_mm256_add_ps(norm0, norm1), _mm256_add_ps(norm2, norm3));
223    let mut norm_sq = horizontal_sum_avx2(norm_vec);
224
225    // Remainder for norm calculation
226    for i in (chunks * CHUNK_SIZE)..n {
227        norm_sq += vector[i] * vector[i];
228    }
229
230    let norm = norm_sq.sqrt();
231    if norm == 0.0 {
232        return;
233    }
234
235    let inv_norm = 1.0 / norm;
236    let inv_norm_vec = _mm256_set1_ps(inv_norm);
237
238    // Second pass: divide by norm with 4x unrolling
239    for i in 0..chunks {
240        let base = i * CHUNK_SIZE;
241
242        let v0 = _mm256_loadu_ps(vector.as_ptr().add(base));
243        _mm256_storeu_ps(
244            vector.as_mut_ptr().add(base),
245            _mm256_mul_ps(v0, inv_norm_vec),
246        );
247
248        let v1 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH));
249        _mm256_storeu_ps(
250            vector.as_mut_ptr().add(base + SIMD_WIDTH),
251            _mm256_mul_ps(v1, inv_norm_vec),
252        );
253
254        let v2 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 2));
255        _mm256_storeu_ps(
256            vector.as_mut_ptr().add(base + SIMD_WIDTH * 2),
257            _mm256_mul_ps(v2, inv_norm_vec),
258        );
259
260        let v3 = _mm256_loadu_ps(vector.as_ptr().add(base + SIMD_WIDTH * 3));
261        _mm256_storeu_ps(
262            vector.as_mut_ptr().add(base + SIMD_WIDTH * 3),
263            _mm256_mul_ps(v3, inv_norm_vec),
264        );
265    }
266
267    // Remainder for scaling
268    for i in (chunks * CHUNK_SIZE)..n {
269        vector[i] *= inv_norm;
270    }
271}
272
273/// NEON-accelerated normalization with 4x unrolling.
274///
275/// # Safety
276///
277/// Caller must ensure:
278/// - Running on aarch64 (NEON is mandatory, always available)
279///
280/// Memory safety:
281/// - Uses `vld1q_f32`/`vst1q_f32` for loads/stores (handles any alignment)
282/// - Pointer arithmetic stays within slice bounds via chunk calculation
283/// - Remainder loop uses safe iteration
284#[cfg(target_arch = "aarch64")]
285#[inline]
286unsafe fn normalize_neon_unrolled(vector: &mut [f32]) {
287    const SIMD_WIDTH: usize = 4;
288    const UNROLL: usize = 4;
289    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
290    let n = vector.len();
291    let chunks = n / CHUNK_SIZE;
292
293    // First pass: compute L2 norm with 4 accumulators
294    let mut norm0 = vdupq_n_f32(0.0);
295    let mut norm1 = vdupq_n_f32(0.0);
296    let mut norm2 = vdupq_n_f32(0.0);
297    let mut norm3 = vdupq_n_f32(0.0);
298
299    for i in 0..chunks {
300        let base = i * CHUNK_SIZE;
301
302        let v0 = vld1q_f32(vector.as_ptr().add(base));
303        norm0 = vfmaq_f32(norm0, v0, v0);
304
305        let v1 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH));
306        norm1 = vfmaq_f32(norm1, v1, v1);
307
308        let v2 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 2));
309        norm2 = vfmaq_f32(norm2, v2, v2);
310
311        let v3 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 3));
312        norm3 = vfmaq_f32(norm3, v3, v3);
313    }
314
315    let norm_vec = vaddq_f32(vaddq_f32(norm0, norm1), vaddq_f32(norm2, norm3));
316    let mut norm_sq = horizontal_sum_neon(norm_vec);
317
318    // Remainder for norm calculation
319    for val in vector.iter().skip(chunks * CHUNK_SIZE) {
320        norm_sq += val * val;
321    }
322
323    let norm = norm_sq.sqrt();
324    if norm == 0.0 {
325        return;
326    }
327
328    let inv_norm = 1.0 / norm;
329    let inv_norm_vec = vdupq_n_f32(inv_norm);
330
331    // Second pass: divide by norm with 4x unrolling
332    for i in 0..chunks {
333        let base = i * CHUNK_SIZE;
334
335        let v0 = vld1q_f32(vector.as_ptr().add(base));
336        vst1q_f32(vector.as_mut_ptr().add(base), vmulq_f32(v0, inv_norm_vec));
337
338        let v1 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH));
339        vst1q_f32(
340            vector.as_mut_ptr().add(base + SIMD_WIDTH),
341            vmulq_f32(v1, inv_norm_vec),
342        );
343
344        let v2 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 2));
345        vst1q_f32(
346            vector.as_mut_ptr().add(base + SIMD_WIDTH * 2),
347            vmulq_f32(v2, inv_norm_vec),
348        );
349
350        let v3 = vld1q_f32(vector.as_ptr().add(base + SIMD_WIDTH * 3));
351        vst1q_f32(
352            vector.as_mut_ptr().add(base + SIMD_WIDTH * 3),
353            vmulq_f32(v3, inv_norm_vec),
354        );
355    }
356
357    // Remainder for scaling
358    for val in vector.iter_mut().skip(chunks * CHUNK_SIZE) {
359        *val *= inv_norm;
360    }
361}