// ronn_core — simd.rs

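//! SIMD (Single Instruction, Multiple Data) optimizations.
//!
//! Provides vectorized `f32` kernels (dot product, element-wise addition, ReLU) aimed at
//! 2-8x speedups over scalar code on supported CPUs. CPU features are detected at runtime,
//! and each operation falls back to a scalar implementation when the required instructions
//! are unavailable or the target is not x86_64.
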
// The intrinsics exist (and are used) only on x86_64; gating the import keeps the scalar
// fallbacks compiling on every other architecture.
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// CPU feature detection for SIMD support.
///
/// `Default` yields a value with every feature disabled, which is the same
/// result `SimdFeatures::detect()` produces on non-x86_64 targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SimdFeatures {
    /// SSE2 support (part of the x86_64 baseline, so always present there)
    pub sse2: bool,
    /// AVX support (256-bit floating-point vectors)
    pub avx: bool,
    /// AVX2 support (256-bit integer operations)
    pub avx2: bool,
    /// AVX-512 Foundation support (512-bit vectors)
    pub avx512f: bool,
    /// FMA (fused multiply-add) support
    pub fma: bool,
}

23impl SimdFeatures {
24    /// Detect available SIMD features at runtime
25    #[cfg(target_arch = "x86_64")]
26    pub fn detect() -> Self {
27        Self {
28            sse2: is_x86_feature_detected!("sse2"),
29            avx: is_x86_feature_detected!("avx"),
30            avx2: is_x86_feature_detected!("avx2"),
31            avx512f: is_x86_feature_detected!("avx512f"),
32            fma: is_x86_feature_detected!("fma"),
33        }
34    }
35
36    #[cfg(not(target_arch = "x86_64"))]
37    pub fn detect() -> Self {
38        Self {
39            sse2: false,
40            avx: false,
41            avx2: false,
42            avx512f: false,
43            fma: false,
44        }
45    }
46
47    /// Get the best available SIMD level
48    pub fn best_simd(&self) -> SimdLevel {
49        if self.avx512f {
50            SimdLevel::Avx512
51        } else if self.avx2 {
52            SimdLevel::Avx2
53        } else if self.avx {
54            SimdLevel::Avx
55        } else if self.sse2 {
56            SimdLevel::Sse2
57        } else {
58            SimdLevel::Scalar
59        }
60    }
61}
62
/// Ordered SIMD capability tiers, from no vector support up to AVX-512.
///
/// The derived `Ord` ranks a wider instruction set above a narrower one
/// (`Avx512 > Avx2 > Avx > Sse2 > Scalar`); the explicit discriminants encode
/// the same ranking as integers.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum SimdLevel {
    /// No vector instructions: plain scalar code.
    Scalar = 0,
    /// SSE2, 128-bit registers.
    Sse2 = 1,
    /// AVX, 256-bit floating-point registers.
    Avx = 2,
    /// AVX2, 256-bit registers with integer operations.
    Avx2 = 3,
    /// AVX-512 Foundation, 512-bit registers.
    Avx512 = 4,
}

78impl SimdLevel {
79    /// Get the vector width in bytes for this SIMD level
80    pub fn vector_width(&self) -> usize {
81        match self {
82            SimdLevel::Scalar => 1,
83            SimdLevel::Sse2 => 16,
84            SimdLevel::Avx | SimdLevel::Avx2 => 32,
85            SimdLevel::Avx512 => 64,
86        }
87    }
88
89    /// Get number of f32 elements per vector
90    pub fn f32_lanes(&self) -> usize {
91        self.vector_width() / 4
92    }
93}
94
95/// Vectorized dot product (f32)
96///
97/// # Performance
98/// - Scalar: 1x baseline
99/// - AVX2: 4-8x faster
100/// - AVX-512: 8-16x faster
101///
102/// # Safety
103/// Requires aligned input arrays
104#[inline]
105pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
106    assert_eq!(a.len(), b.len(), "Arrays must have equal length");
107
108    let features = SimdFeatures::detect();
109
110    #[cfg(target_arch = "x86_64")]
111    {
112        if features.avx2 && features.fma {
113            unsafe { dot_product_f32_avx2_fma(a, b) }
114        } else if features.avx {
115            unsafe { dot_product_f32_avx(a, b) }
116        } else {
117            dot_product_f32_scalar(a, b)
118        }
119    }
120
121    #[cfg(not(target_arch = "x86_64"))]
122    {
123        dot_product_f32_scalar(a, b)
124    }
125}
126
127/// Scalar fallback for dot product
128#[inline]
129fn dot_product_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
130    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
131}
132
133/// AVX implementation of dot product
134#[cfg(target_arch = "x86_64")]
135#[target_feature(enable = "avx")]
136#[inline]
137unsafe fn dot_product_f32_avx(a: &[f32], b: &[f32]) -> f32 {
138    unsafe {
139        let len = a.len();
140        let mut sum = _mm256_setzero_ps();
141
142        // Process 8 elements at a time
143        let chunks = len / 8;
144        for i in 0..chunks {
145            let idx = i * 8;
146            let va = _mm256_loadu_ps(a.as_ptr().add(idx));
147            let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
148            sum = _mm256_add_ps(sum, _mm256_mul_ps(va, vb));
149        }
150
151        // Horizontal sum
152        let mut result = horizontal_sum_avx(sum);
153
154        // Handle remaining elements
155        for i in (chunks * 8)..len {
156            result += a[i] * b[i];
157        }
158
159        result
160    }
161}
162
163/// AVX2 + FMA implementation of dot product (fastest on modern CPUs)
164#[cfg(target_arch = "x86_64")]
165#[target_feature(enable = "avx2,fma")]
166#[inline]
167unsafe fn dot_product_f32_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
168    unsafe {
169        let len = a.len();
170        let mut sum = _mm256_setzero_ps();
171
172        // Process 8 elements at a time with FMA
173        let chunks = len / 8;
174        for i in 0..chunks {
175            let idx = i * 8;
176            let va = _mm256_loadu_ps(a.as_ptr().add(idx));
177            let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
178            // Fused multiply-add: sum = sum + (va * vb)
179            sum = _mm256_fmadd_ps(va, vb, sum);
180        }
181
182        // Horizontal sum
183        let mut result = horizontal_sum_avx(sum);
184
185        // Handle remaining elements
186        for i in (chunks * 8)..len {
187            result += a[i] * b[i];
188        }
189
190        result
191    }
192}
193
194/// Horizontal sum of AVX vector
195#[cfg(target_arch = "x86_64")]
196#[target_feature(enable = "avx")]
197#[inline]
198unsafe fn horizontal_sum_avx(v: __m256) -> f32 {
199    unsafe {
200        // Split into high and low 128-bit lanes
201        let hi = _mm256_extractf128_ps(v, 1);
202        let lo = _mm256_castps256_ps128(v);
203        let sum128 = _mm_add_ps(hi, lo);
204
205        // Horizontal add within 128-bit
206        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
207        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
208
209        _mm_cvtss_f32(sum32)
210    }
211}
212
213/// Vectorized element-wise addition
214///
215/// # Performance
216/// - AVX2: 4-8x faster than scalar
217#[inline]
218pub fn add_f32(a: &[f32], b: &[f32], result: &mut [f32]) {
219    assert_eq!(a.len(), b.len());
220    assert_eq!(a.len(), result.len());
221
222    let features = SimdFeatures::detect();
223
224    #[cfg(target_arch = "x86_64")]
225    {
226        if features.avx2 {
227            unsafe { add_f32_avx2(a, b, result) }
228        } else {
229            add_f32_scalar(a, b, result)
230        }
231    }
232
233    #[cfg(not(target_arch = "x86_64"))]
234    {
235        add_f32_scalar(a, b, result)
236    }
237}
238
239/// Scalar fallback for addition
240#[inline]
241fn add_f32_scalar(a: &[f32], b: &[f32], result: &mut [f32]) {
242    for i in 0..a.len() {
243        result[i] = a[i] + b[i];
244    }
245}
246
247/// AVX2 implementation of element-wise addition
248#[cfg(target_arch = "x86_64")]
249#[target_feature(enable = "avx2")]
250#[inline]
251unsafe fn add_f32_avx2(a: &[f32], b: &[f32], result: &mut [f32]) {
252    unsafe {
253        let len = a.len();
254        let chunks = len / 8;
255
256        // Process 8 elements at a time
257        for i in 0..chunks {
258            let idx = i * 8;
259            let va = _mm256_loadu_ps(a.as_ptr().add(idx));
260            let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
261            let sum = _mm256_add_ps(va, vb);
262            _mm256_storeu_ps(result.as_mut_ptr().add(idx), sum);
263        }
264
265        // Handle remaining elements
266        for i in (chunks * 8)..len {
267            result[i] = a[i] + b[i];
268        }
269    }
270}
271
272/// Vectorized ReLU activation
273///
274/// # Performance
275/// - AVX2: 4-8x faster than scalar
276#[inline]
277pub fn relu_f32(input: &[f32], output: &mut [f32]) {
278    assert_eq!(input.len(), output.len());
279
280    let features = SimdFeatures::detect();
281
282    #[cfg(target_arch = "x86_64")]
283    {
284        if features.avx2 {
285            unsafe { relu_f32_avx2(input, output) }
286        } else {
287            relu_f32_scalar(input, output)
288        }
289    }
290
291    #[cfg(not(target_arch = "x86_64"))]
292    {
293        relu_f32_scalar(input, output)
294    }
295}
296
297/// Scalar ReLU
298#[inline]
299fn relu_f32_scalar(input: &[f32], output: &mut [f32]) {
300    for i in 0..input.len() {
301        output[i] = input[i].max(0.0);
302    }
303}
304
305/// AVX2 ReLU implementation
306#[cfg(target_arch = "x86_64")]
307#[target_feature(enable = "avx2")]
308#[inline]
309unsafe fn relu_f32_avx2(input: &[f32], output: &mut [f32]) {
310    unsafe {
311        let len = input.len();
312        let chunks = len / 8;
313        let zero = _mm256_setzero_ps();
314
315        // Process 8 elements at a time
316        for i in 0..chunks {
317            let idx = i * 8;
318            let v = _mm256_loadu_ps(input.as_ptr().add(idx));
319            let relu = _mm256_max_ps(v, zero);
320            _mm256_storeu_ps(output.as_mut_ptr().add(idx), relu);
321        }
322
323        // Handle remaining elements
324        for i in (chunks * 8)..len {
325            output[i] = input[i].max(0.0);
326        }
327    }
328}
329
330/// Get global SIMD features (cached detection)
331pub fn simd_features() -> SimdFeatures {
332    static FEATURES: std::sync::OnceLock<SimdFeatures> = std::sync::OnceLock::new();
333    *FEATURES.get_or_init(SimdFeatures::detect)
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_detection() {
        let features = SimdFeatures::detect();
        let level = features.best_simd();
        println!("Detected SIMD level: {:?}", level);
        println!("Features: {:?}", features);

        #[cfg(target_arch = "x86_64")]
        {
            assert!(features.sse2, "x86_64 always has SSE2");
        }
    }

    #[test]
    fn test_simd_features_cached_matches_detect() {
        let cached = simd_features();
        let fresh = SimdFeatures::detect();
        assert_eq!(cached.sse2, fresh.sse2);
        assert_eq!(cached.avx, fresh.avx);
        assert_eq!(cached.avx2, fresh.avx2);
        assert_eq!(cached.avx512f, fresh.avx512f);
        assert_eq!(cached.fma, fresh.fma);
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-5);
    }

    #[test]
    fn test_dot_product_every_tail_length() {
        // Lengths 0..=33 cover empty input, tail-only input (< 8 elements),
        // exact multiples of 8 and every remainder size after whole chunks.
        for n in 0..=33usize {
            let a: Vec<f32> = (0..n).map(|i| i as f32 - 5.0).collect();
            let b: Vec<f32> = (0..n).map(|i| (i % 3) as f32).collect();
            // Small integers: every partial sum is exact in f32, so the SIMD
            // and scalar summation orders must agree exactly.
            let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
            assert_eq!(dot_product_f32(&a, &b), expected, "length {}", n);
        }
    }

    #[test]
    #[should_panic(expected = "Arrays must have equal length")]
    fn test_dot_product_length_mismatch_panics() {
        let _ = dot_product_f32(&[1.0, 2.0, 3.0], &[1.0, 2.0]);
    }

    #[test]
    fn test_add_vectorized() {
        let a = vec![1.0; 100];
        let b = vec![2.0; 100];
        let mut result = vec![0.0; 100];

        add_f32(&a, &b, &mut result);

        for &r in &result {
            assert!((r - 3.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_add_odd_length() {
        // 19 elements: two full 8-lane chunks plus a 3-element tail.
        let a: Vec<f32> = (0..19).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..19).map(|i| 0.5 * i as f32).collect();
        let mut result = vec![0.0; 19];

        add_f32(&a, &b, &mut result);

        for ((r, x), y) in result.iter().zip(&a).zip(&b) {
            assert_eq!(*r, x + y);
        }
    }

    #[test]
    fn test_relu() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0; 5];

        relu_f32(&input, &mut output);

        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];
        for (o, e) in output.iter().zip(&expected) {
            assert!((o - e).abs() < 1e-5);
        }
    }

    #[test]
    fn test_relu_simd_path() {
        // 19 elements: two full 8-lane chunks plus a 3-element tail, so both
        // the vector kernel and the scalar remainder are exercised.
        let input: Vec<f32> = (0..19).map(|i| i as f32 - 9.0).collect();
        let mut output = vec![f32::NAN; 19];

        relu_f32(&input, &mut output);

        for (o, x) in output.iter().zip(&input) {
            assert_eq!(*o, x.max(0.0));
        }
    }

    #[test]
    fn test_large_dot_product() {
        let size = 10_000;
        let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..size).map(|i| (size - i) as f32).collect();

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        // Allow for small floating point differences
        let relative_error = ((result - expected) / expected).abs();
        assert!(relative_error < 1e-4);
    }

    #[test]
    fn test_simd_level_comparison() {
        assert!(SimdLevel::Avx512 > SimdLevel::Avx2);
        assert!(SimdLevel::Avx2 > SimdLevel::Avx);
        assert!(SimdLevel::Avx > SimdLevel::Sse2);
        assert!(SimdLevel::Sse2 > SimdLevel::Scalar);
    }

    #[test]
    fn test_vector_widths() {
        assert_eq!(SimdLevel::Scalar.vector_width(), 1);
        assert_eq!(SimdLevel::Sse2.vector_width(), 16);
        assert_eq!(SimdLevel::Avx.vector_width(), 32);
        assert_eq!(SimdLevel::Avx2.vector_width(), 32);
        assert_eq!(SimdLevel::Avx512.vector_width(), 64);

        assert_eq!(SimdLevel::Avx2.f32_lanes(), 8);
        assert_eq!(SimdLevel::Avx512.f32_lanes(), 16);
    }
}