Skip to main content

sklears_simd/
intrinsics.rs

1//! Intrinsic function wrappers and compiler optimization hints
2//!
//! This module provides thin wrappers around SIMD intrinsics (several of them `unsafe fn`s
//! with documented preconditions) and compiler optimization hints intended to improve code
//! generation and performance.
5
/// Compiler hints for optimization
pub mod hints {
    /// Empty marker function called on the path a caller expects *not* to take.
    ///
    /// `#[cold]` marks a function as unlikely to be called, so a call to it marks the
    /// surrounding path as unlikely. This is the usual stable-Rust substitute for the
    /// unstable `core::intrinsics::likely`/`unlikely`.
    #[cold]
    fn cold_marker() {}

    /// Hint to the compiler that `b` is usually `true`; returns `b` unchanged.
    ///
    /// Only the `false` path gets a cold marker. The value itself stays visible to the
    /// optimizer, unlike routing it through `black_box`, which blocks constant folding.
    #[inline(always)]
    pub fn likely(b: bool) -> bool {
        if !b {
            cold_marker();
        }
        b
    }

    /// Hint to the compiler that `b` is usually `false`; returns `b` unchanged.
    ///
    /// Only the `true` path gets a cold marker; the value is not hidden from the optimizer.
    #[inline(always)]
    pub fn unlikely(b: bool) -> bool {
        if b {
            cold_marker();
        }
        b
    }

    /// Hint to the compiler that this code path is unreachable.
    ///
    /// # Safety
    ///
    /// Calling this function when the code path is actually reachable is undefined behaviour.
    #[inline(always)]
    pub unsafe fn unreachable_unchecked() -> ! {
        // SAFETY: the caller guarantees this point is never reached.
        // `core::hint` is available with and without `std`.
        unsafe { core::hint::unreachable_unchecked() }
    }

    /// Performs a volatile read of a local `i32`.
    ///
    /// A volatile access may not be removed or merged by the optimizer. Placed in a loop
    /// body, it is intended to stop the loop from being auto-vectorized.
    #[inline(always)]
    pub fn prevent_vectorization() {
        let zero = 0i32;
        // SAFETY: `&zero` points to a live, aligned, initialized `i32` for the whole read.
        let _ = unsafe { core::ptr::read_volatile(&zero as *const i32) };
    }

    /// Placeholder for a "please vectorize" hint.
    ///
    /// Stable Rust has no such directive, so this function intentionally does nothing.
    #[inline(always)]
    pub fn force_vectorization() {}
}
84
/// Memory alignment utilities
pub mod alignment {
    /// Returns `true` if the address of `ptr` is a multiple of `alignment` bytes.
    ///
    /// `alignment` is normally a power of two, but any value is accepted. An `alignment`
    /// of `0` does not panic: following [`usize::is_multiple_of`], only the null address
    /// counts as 0-aligned.
    #[inline(always)]
    pub fn is_aligned<T>(ptr: *const T, alignment: usize) -> bool {
        (ptr as usize).is_multiple_of(alignment)
    }

    /// Returns `ptr` unchanged while telling the optimizer it is `alignment`-aligned.
    ///
    /// In builds with debug assertions, a misaligned `ptr` panics instead of causing
    /// undefined behaviour.
    ///
    /// # Safety
    ///
    /// `ptr` must be aligned to `alignment` bytes; calling this function with a
    /// misaligned pointer is undefined behaviour.
    #[inline(always)]
    pub unsafe fn assume_aligned<T>(ptr: *const T, alignment: usize) -> *const T {
        debug_assert!(
            is_aligned(ptr, alignment),
            "assume_aligned: {ptr:p} is not {alignment}-byte aligned"
        );
        if !is_aligned(ptr, alignment) {
            // SAFETY: the caller guarantees alignment, so this branch is unreachable. That
            // lets the optimizer treat the alignment as a known fact afterwards.
            unsafe { core::hint::unreachable_unchecked() }
        }
        ptr
    }

    /// Mutable-pointer counterpart of [`assume_aligned`]: returns `ptr` unchanged while
    /// telling the optimizer it is `alignment`-aligned.
    ///
    /// In builds with debug assertions, a misaligned `ptr` panics instead of causing
    /// undefined behaviour.
    ///
    /// # Safety
    ///
    /// `ptr` must be aligned to `alignment` bytes; calling this function with a
    /// misaligned pointer is undefined behaviour.
    #[inline(always)]
    pub unsafe fn assume_aligned_mut<T>(ptr: *mut T, alignment: usize) -> *mut T {
        debug_assert!(
            is_aligned(ptr, alignment),
            "assume_aligned_mut: {ptr:p} is not {alignment}-byte aligned"
        );
        if !is_aligned(ptr, alignment) {
            // SAFETY: the caller guarantees alignment, so this branch is unreachable.
            unsafe { core::hint::unreachable_unchecked() }
        }
        ptr
    }
}
123
124/// SIMD intrinsic wrappers for safe usage
125#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
126pub mod x86 {
127    #[cfg(feature = "no-std")]
128    use core::arch::x86_64::*;
129    #[cfg(not(feature = "no-std"))]
130    use core::arch::x86_64::*;
131
132    /// Safe wrapper for SSE2 operations
133    pub mod sse2 {
134        use super::*;
135
136        /// Safe horizontal add for __m128
137        pub fn horizontal_add_f32(v: __m128) -> f32 {
138            unsafe {
139                let temp = _mm_hadd_ps(v, v);
140                let temp = _mm_hadd_ps(temp, temp);
141                _mm_cvtss_f32(temp)
142            }
143        }
144
145        /// Load aligned f32 vector from raw pointer.
146        ///
147        /// # Safety
148        ///
149        /// `ptr` must be non-null, aligned to a 16-byte boundary, and point to at least 4 valid
150        /// `f32` values. Passing a misaligned or dangling pointer is undefined behaviour.
151        pub unsafe fn load_aligned_f32(ptr: *const f32) -> __m128 {
152            debug_assert!(super::super::alignment::is_aligned(ptr, 16));
153            unsafe { _mm_load_ps(ptr) }
154        }
155
156        /// Store aligned f32 vector to raw pointer.
157        ///
158        /// # Safety
159        ///
160        /// `ptr` must be non-null, aligned to a 16-byte boundary, and point to at least 4 writable
161        /// `f32` slots. Passing a misaligned or dangling pointer is undefined behaviour.
162        pub unsafe fn store_aligned_f32(ptr: *mut f32, v: __m128) {
163            debug_assert!(super::super::alignment::is_aligned(ptr, 16));
164            unsafe { _mm_store_ps(ptr, v) }
165        }
166
167        /// Safe fused multiply-add.
168        ///
169        /// # Safety
170        ///
171        /// The caller must ensure that the `fma` target feature is enabled at runtime.
172        #[target_feature(enable = "fma")]
173        pub unsafe fn fma_f32(a: __m128, b: __m128, c: __m128) -> __m128 {
174            _mm_fmadd_ps(a, b, c)
175        }
176    }
177
178    /// Safe wrapper for AVX2 operations
179    pub mod avx2 {
180        use super::*;
181
182        /// Safe horizontal add for __m256
183        pub fn horizontal_add_f32(v: __m256) -> f32 {
184            unsafe {
185                let hi = _mm256_extractf128_ps(v, 1);
186                let lo = _mm256_castps256_ps128(v);
187                let sum128 = _mm_add_ps(hi, lo);
188                let temp = _mm_hadd_ps(sum128, sum128);
189                let temp = _mm_hadd_ps(temp, temp);
190                _mm_cvtss_f32(temp)
191            }
192        }
193
194        /// Load aligned f32 vector from raw pointer.
195        ///
196        /// # Safety
197        ///
198        /// `ptr` must be non-null, aligned to a 32-byte boundary, and point to at least 8 valid
199        /// `f32` values. Passing a misaligned or dangling pointer is undefined behaviour.
200        pub unsafe fn load_aligned_f32(ptr: *const f32) -> __m256 {
201            debug_assert!(super::super::alignment::is_aligned(ptr, 32));
202            unsafe { _mm256_load_ps(ptr) }
203        }
204
205        /// Store aligned f32 vector to raw pointer.
206        ///
207        /// # Safety
208        ///
209        /// `ptr` must be non-null, aligned to a 32-byte boundary, and point to at least 8 writable
210        /// `f32` slots. Passing a misaligned or dangling pointer is undefined behaviour.
211        pub unsafe fn store_aligned_f32(ptr: *mut f32, v: __m256) {
212            debug_assert!(super::super::alignment::is_aligned(ptr, 32));
213            unsafe { _mm256_store_ps(ptr, v) }
214        }
215
216        /// Safe fused multiply-add.
217        ///
218        /// # Safety
219        ///
220        /// The caller must ensure that the `fma` target feature is enabled at runtime.
221        #[target_feature(enable = "fma")]
222        pub unsafe fn fma_f32(a: __m256, b: __m256, c: __m256) -> __m256 {
223            _mm256_fmadd_ps(a, b, c)
224        }
225
226        /// Safe blend operation with compile-time mask
227        pub fn blend_f32<const MASK: i32>(a: __m256, b: __m256) -> __m256 {
228            unsafe { _mm256_blend_ps(a, b, MASK) }
229        }
230    }
231
232    /// Safe wrapper for AVX-512 operations (when available)
233    #[cfg(target_feature = "avx512f")]
234    pub mod avx512 {
235        use super::*;
236
237        /// Safe vector load with alignment check
238        pub fn load_aligned_f32(ptr: *const f32) -> __m512 {
239            debug_assert!(super::super::alignment::is_aligned(ptr, 64));
240            unsafe { _mm512_load_ps(ptr) }
241        }
242
243        /// Safe vector store with alignment check  
244        pub fn store_aligned_f32(ptr: *mut f32, v: __m512) {
245            debug_assert!(super::super::alignment::is_aligned(ptr, 64));
246            unsafe { _mm512_store_ps(ptr, v) }
247        }
248
249        /// Safe fused multiply-add
250        pub fn fma_f32(a: __m512, b: __m512, c: __m512) -> __m512 {
251            unsafe { _mm512_fmadd_ps(a, b, c) }
252        }
253
254        /// Safe horizontal reduction sum
255        pub fn reduce_add_f32(v: __m512) -> f32 {
256            unsafe { _mm512_reduce_add_ps(v) }
257        }
258    }
259}
260
/// ARM NEON intrinsic wrappers
///
/// The safe wrappers here assume the target enables `neon`, as the standard aarch64
/// targets do.
#[cfg(target_arch = "aarch64")]
pub mod neon {
    // `core::arch` is available with and without `std`, so a single import serves both
    // configurations of the `no-std` feature.
    use core::arch::aarch64::*;

    /// Returns the sum of the four `f32` lanes of `v`.
    pub fn horizontal_add_f32(v: float32x4_t) -> f32 {
        // SAFETY: register-only NEON operation; assumes `neon` is enabled (see module docs).
        unsafe { vaddvq_f32(v) }
    }

    /// Loads four `f32` values from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid, non-null, and point to at least 4 initialized `f32` values
    /// aligned to a 16-byte boundary. The 16-byte alignment is this wrapper's contract and is
    /// checked in debug builds; `vld1q_f32` itself does not require it.
    pub unsafe fn load_aligned_f32(ptr: *const f32) -> float32x4_t {
        debug_assert!(super::alignment::is_aligned(ptr, 16));
        // SAFETY: the caller guarantees `ptr` is readable for four `f32`s.
        unsafe { vld1q_f32(ptr) }
    }

    /// Stores the four lanes of `v` to `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid, non-null, and point to writable storage for at least 4 `f32`
    /// values aligned to a 16-byte boundary (alignment is checked in debug builds).
    pub unsafe fn store_aligned_f32(ptr: *mut f32, v: float32x4_t) {
        debug_assert!(super::alignment::is_aligned(ptr, 16));
        // SAFETY: the caller guarantees `ptr` is writable for four `f32`s.
        unsafe { vst1q_f32(ptr, v) }
    }

    /// Fused multiply-add: returns `a * b + c` per lane.
    pub fn fma_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
        // `vfmaq_f32` takes the accumulator first: `vfmaq_f32(c, a, b)` = `c + a * b`.
        // SAFETY: register-only NEON operation; assumes `neon` is enabled (see module docs).
        unsafe { vfmaq_f32(c, a, b) }
    }
}
301
/// Branch prediction and loop optimization
pub mod optimization {
    /// Calls `f(0)`, `f(1)`, …, `f(iterations - 1)` in order.
    ///
    /// No unrolling directive is emitted. Because this is `#[inline(always)]`, the trip count
    /// is visible at the call site, and the optimizer may choose to unroll when it is a small
    /// constant.
    #[inline(always)]
    pub fn suggest_unroll<F>(iterations: usize, f: F)
    where
        F: FnMut(usize),
    {
        (0..iterations).for_each(f);
    }

    /// Prefetch hint for an upcoming memory access at `ptr`.
    ///
    /// On x86 / x86_64 this issues `_mm_prefetch` with the `_MM_HINT_T0` hint. On other
    /// architectures it does nothing. `ptr` is never dereferenced.
    #[inline(always)]
    pub fn prefetch_hint<T>(ptr: *const T) {
        // Per-arch paths: `core::arch::x86_64` does not exist on 32-bit x86.
        #[cfg(target_arch = "x86")]
        use core::arch::x86::{_mm_prefetch, _MM_HINT_T0};
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        // SAFETY: prefetch instructions are hints and do not fault on invalid addresses, so
        // any `ptr` value is acceptable. SSE is assumed available, as elsewhere in this file.
        unsafe {
            _mm_prefetch(ptr.cast::<i8>(), _MM_HINT_T0);
        }

        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
        let _ = ptr;
    }

    /// Empty function marked `#[cold]`.
    ///
    /// `#[cold]` marks a function as unlikely to be called, so calling this on a rarely-taken
    /// path marks that path as unlikely.
    #[cold]
    pub fn cold_path() {}

    /// Empty, always-inlined marker for hot code. It does not affect code generation.
    #[inline(always)]
    pub fn hot_path() {}
}
349
/// Auto-vectorization helpers
pub mod vectorization {
    /// Writes `op(src[i])` into `dest[i]` for every index.
    ///
    /// # Panics
    ///
    /// Panics if `src.len() != dest.len()`.
    pub fn vectorize_simple_op<T, F>(src: &[T], dest: &mut [T], op: F)
    where
        T: Copy,
        F: Fn(T) -> T,
    {
        assert_eq!(src.len(), dest.len());

        // Zipped iterators carry no per-element bounds checks, leaving a plain loop for the
        // auto-vectorizer.
        for (out, &value) in dest.iter_mut().zip(src) {
            *out = op(value);
        }
    }

    /// Writes `op(a[i], b[i])` into `dest[i]` for every index.
    ///
    /// # Panics
    ///
    /// Panics if `a`, `b`, and `dest` do not all have the same length.
    pub fn vectorize_binary_op<T, F>(a: &[T], b: &[T], dest: &mut [T], op: F)
    where
        T: Copy,
        F: Fn(T, T) -> T,
    {
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), dest.len());

        for (out, (&x, &y)) in dest.iter_mut().zip(a.iter().zip(b)) {
            *out = op(x, y);
        }
    }

    /// Writes `op(src[i])` into `dest[i]` for `i = 0, stride, 2 * stride, …` while
    /// `i < src.len()`. Other elements of `dest` are left untouched.
    ///
    /// An empty `src` is a no-op for any `stride`.
    ///
    /// # Panics
    ///
    /// If `src` is non-empty, panics before writing anything when `stride` is 0 or when
    /// `dest` is too short to hold the last visited index.
    pub fn vectorize_strided<T, F>(src: &[T], dest: &mut [T], stride: usize, op: F)
    where
        T: Copy,
        F: Fn(T) -> T,
    {
        let Some(last) = src.len().checked_sub(1) else {
            return;
        };
        // A zero stride would revisit index 0 forever.
        assert!(stride != 0, "vectorize_strided: stride must be non-zero");

        // The last visited index is the largest multiple of `stride` that is <= `last`.
        // Checking it up front avoids panicking after partial writes.
        let last_index = last - last % stride;
        let dest_len = dest.len();
        assert!(
            last_index < dest_len,
            "vectorize_strided: dest has length {dest_len} but index {last_index} is written"
        );

        for i in (0..=last).step_by(stride) {
            dest[i] = op(src[i]);
        }
    }
}
396
/// Performance measurement utilities
pub mod perf {
    #[cfg(not(feature = "no-std"))]
    use std::time::Instant;

    // `std::time::Duration` is a re-export of `core::time::Duration`, so one import serves
    // both the std and no-std builds.
    use core::time::Duration;

    /// Runs `op` once and returns its result together with the elapsed time, measured with
    /// [`Instant`].
    #[cfg(not(feature = "no-std"))]
    pub fn time_operation<F, R>(op: F) -> (R, Duration)
    where
        F: FnOnce() -> R,
    {
        let start = Instant::now();
        let result = op();
        (result, start.elapsed())
    }

    /// No-std fallback: runs `op` once and reports a zero duration, because no clock is
    /// available here.
    #[cfg(feature = "no-std")]
    pub fn time_operation<F, R>(op: F) -> (R, Duration)
    where
        F: FnOnce() -> R,
    {
        let result = op();
        (result, Duration::from_nanos(0))
    }

    /// Reads the CPU time-stamp counter (x86 / x86_64 only).
    ///
    /// No fence is issued around the read; pair it with [`memory_fence`] if ordering matters.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    pub fn rdtsc() -> u64 {
        // Per-arch paths: `core::arch::x86_64` does not exist on 32-bit x86.
        #[cfg(target_arch = "x86")]
        use core::arch::x86::_rdtsc;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::_rdtsc;

        // SAFETY: `_rdtsc` takes no arguments and accesses no memory.
        unsafe { _rdtsc() }
    }

    /// Issues a sequentially consistent memory fence.
    pub fn memory_fence() {
        core::sync::atomic::fence(core::sync::atomic::Ordering::SeqCst);
    }
}
457
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use std::time::Instant;

    #[test]
    fn test_alignment_check() {
        let data = [1.0f32; 16];
        let ptr = data.as_ptr();

        // An `[f32; N]` array is always aligned to `align_of::<f32>()`, wherever it lives.
        // This one is on the stack; no allocator is involved.
        assert!(alignment::is_aligned(ptr, core::mem::align_of::<f32>()));
    }

    #[test]
    fn test_vectorization_helpers() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0];
        let mut dest = vec![0.0f32; 4];

        vectorization::vectorize_simple_op(&src, &mut dest, |x| x * 2.0);

        assert_eq!(dest, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_binary_vectorization() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![1.0f32, 1.0, 1.0, 1.0];
        let mut dest = vec![0.0f32; 4];

        vectorization::vectorize_binary_op(&a, &b, &mut dest, |x, y| x + y);

        assert_eq!(dest, vec![2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_strided_vectorization() {
        let src = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mut dest = [0.0f32; 5];

        vectorization::vectorize_strided(&src, &mut dest, 2, |x| x + 1.0);

        // Only indices 0, 2 and 4 are written.
        assert_eq!(dest, [2.0, 0.0, 4.0, 0.0, 6.0]);
    }

    #[test]
    fn test_performance_timing() {
        let outer = Instant::now();
        let (result, duration) = perf::time_operation(|| (0..1000).sum::<i32>());
        let outer_elapsed = outer.elapsed();

        assert_eq!(result, 499500);
        // The inner span lies within the outer one, so with a monotonic clock it cannot be
        // longer. Asserting `duration > 0` instead may fail when the clock does not tick
        // during such a short operation.
        assert!(duration <= outer_elapsed);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_sse2_horizontal_add() {
        // Per-arch paths: `core::arch::x86_64` does not exist on 32-bit x86.
        #[cfg(target_arch = "x86")]
        use core::arch::x86::_mm_setr_ps;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::_mm_setr_ps;

        // SAFETY: `_mm_setr_ps` only builds a register value from four scalars.
        let v = unsafe { _mm_setr_ps(1.0, 2.0, 3.0, 4.0) };

        let sum = x86::sse2::horizontal_add_f32(v);
        assert!((sum - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_optimization_unroll() {
        let mut sum = 0;
        optimization::suggest_unroll(10, |i| {
            sum += i;
        });
        assert_eq!(sum, 45); // 0+1+2+...+9 = 45
    }
}