ghostflow_core/simd_ops.rs

//! Advanced SIMD optimizations for tensor operations
//!
//! This module provides SIMD-accelerated implementations of common element-wise
//! operations: vector addition, vector multiplication, dot product, and ReLU.
//! CPU features are detected at runtime, and an AVX2, SSE4.1, or scalar fallback
//! path is selected accordingly.
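//!
//! # Example
//!
//! A minimal usage sketch of the dispatching entry points (this assumes the
//! module is exported as `ghostflow_core::simd_ops`; adjust the path to match
//! the actual crate layout):
//!
//! ```ignore
//! use ghostflow_core::simd_ops::simd_add_f32;
//!
//! let a = [1.0f32, 2.0, 3.0, 4.0];
//! let b = [4.0f32, 3.0, 2.0, 1.0];
//! let mut out = [0.0f32; 4];
//! simd_add_f32(&a, &b, &mut out);
//! assert_eq!(out, [5.0f32; 4]);
//! ```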

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// SIMD-optimized vector addition
#[inline]
pub fn simd_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_add_f32_avx2(a, b, out) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_add_f32_sse(a, b, out) }
        } else {
            scalar_add_f32(a, b, out)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_add_f32(a, b, out)
    }
}

/// AVX2 implementation of vector addition
/// Safety: the CPU must support AVX2, and `b` and `out` must be at least as long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_add_f32_avx2(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    // Process 8 elements at a time with AVX2
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vout = _mm256_add_ps(va, vb);
        _mm256_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 8;
    }

    // Handle remaining elements
    while i < len {
        out[i] = a[i] + b[i];
        i += 1;
    }
}

/// SSE implementation of vector addition
/// Safety: the CPU must support SSE4.1, and `b` and `out` must be at least as long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_add_f32_sse(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    // Process 4 elements at a time with SSE
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vout = _mm_add_ps(va, vb);
        _mm_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 4;
    }

    // Handle remaining elements
    while i < len {
        out[i] = a[i] + b[i];
        i += 1;
    }
}

/// Scalar fallback for vector addition
#[inline]
fn scalar_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    for i in 0..a.len() {
        out[i] = a[i] + b[i];
    }
}

/// SIMD-optimized vector multiplication
#[inline]
pub fn simd_mul_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_mul_f32_avx2(a, b, out) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_mul_f32_sse(a, b, out) }
        } else {
            scalar_mul_f32(a, b, out)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_mul_f32(a, b, out)
    }
}

/// AVX2 implementation of vector multiplication
/// Safety: the CPU must support AVX2, and `b` and `out` must be at least as long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_mul_f32_avx2(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    // Process 8 elements at a time with AVX2
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vout = _mm256_mul_ps(va, vb);
        _mm256_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 8;
    }

    // Handle remaining elements
    while i < len {
        out[i] = a[i] * b[i];
        i += 1;
    }
}

/// SSE implementation of vector multiplication
/// Safety: the CPU must support SSE4.1, and `b` and `out` must be at least as long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_mul_f32_sse(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    // Process 4 elements at a time with SSE
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vout = _mm_mul_ps(va, vb);
        _mm_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 4;
    }

    // Handle remaining elements
    while i < len {
        out[i] = a[i] * b[i];
        i += 1;
    }
}

/// Scalar fallback for vector multiplication
#[inline]
fn scalar_mul_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    for i in 0..a.len() {
        out[i] = a[i] * b[i];
    }
}

/// SIMD-optimized dot product
#[inline]
pub fn simd_dot_f32(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_dot_f32_avx2(a, b) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_dot_f32_sse(a, b) }
        } else {
            scalar_dot_f32(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_dot_f32(a, b)
    }
}

/// AVX2 implementation of dot product
/// Safety: the CPU must support AVX2, and `b` must be at least as long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_dot_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut i = 0;
    let mut sum = _mm256_setzero_ps();

    // Accumulate 8 lane-wise products at a time
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vprod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, vprod);
        i += 8;
    }

    // Horizontal sum: reinterpret the accumulator as 8 f32 lanes and add them up
    let mut result = 0.0f32;
    let sum_array: [f32; 8] = std::mem::transmute(sum);
    for &val in &sum_array {
        result += val;
    }

    // Handle remaining elements
    while i < len {
        result += a[i] * b[i];
        i += 1;
    }

    result
}

/// SSE implementation of dot product
/// Safety: the CPU must support SSE4.1, and `b` must be at least as long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_dot_f32_sse(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    // Accumulate 4 lane-wise products at a time
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vprod = _mm_mul_ps(va, vb);
        sum = _mm_add_ps(sum, vprod);
        i += 4;
    }

    // Horizontal sum: reinterpret the accumulator as 4 f32 lanes and add them up
    let mut result = 0.0f32;
    let sum_array: [f32; 4] = std::mem::transmute(sum);
    for &val in &sum_array {
        result += val;
    }

    // Handle remaining elements
    while i < len {
        result += a[i] * b[i];
        i += 1;
    }

    result
}

/// Scalar fallback for dot product
#[inline]
fn scalar_dot_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// SIMD-optimized ReLU activation
#[inline]
pub fn simd_relu_f32(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_relu_f32_avx2(input, output) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_relu_f32_sse(input, output) }
        } else {
            scalar_relu_f32(input, output)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_relu_f32(input, output)
    }
}

/// AVX2 implementation of ReLU
/// Safety: the CPU must support AVX2, and `output` must be at least as long as `input`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_relu_f32_avx2(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut i = 0;
    let zero = _mm256_setzero_ps();

    // Process 8 elements at a time: max(x, 0) lane-wise
    while i + 8 <= len {
        let v = _mm256_loadu_ps(input.as_ptr().add(i));
        let vout = _mm256_max_ps(v, zero);
        _mm256_storeu_ps(output.as_mut_ptr().add(i), vout);
        i += 8;
    }

    // Handle remaining elements
    while i < len {
        output[i] = input[i].max(0.0);
        i += 1;
    }
}

/// SSE implementation of ReLU
/// Safety: the CPU must support SSE4.1, and `output` must be at least as long as `input`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_relu_f32_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut i = 0;
    let zero = _mm_setzero_ps();

    // Process 4 elements at a time: max(x, 0) lane-wise
    while i + 4 <= len {
        let v = _mm_loadu_ps(input.as_ptr().add(i));
        let vout = _mm_max_ps(v, zero);
        _mm_storeu_ps(output.as_mut_ptr().add(i), vout);
        i += 4;
    }

    // Handle remaining elements
    while i < len {
        output[i] = input[i].max(0.0);
        i += 1;
    }
}

/// Scalar fallback for ReLU
#[inline]
fn scalar_relu_f32(input: &[f32], output: &mut [f32]) {
    for i in 0..input.len() {
        output[i] = input[i].max(0.0);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_add() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0f32, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let mut out = vec![0.0f32; 8];

        simd_add_f32(&a, &b, &mut out);

        for i in 0..8 {
            assert_eq!(out[i], 9.0);
        }
    }
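
    // Covers a length that is not a multiple of the SIMD width (8 for AVX2,
    // 4 for SSE), so both the vectorized loop and the scalar remainder path
    // are exercised.
    #[test]
    fn test_simd_add_remainder() {
        let a: Vec<f32> = (0..11).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..11).map(|i| (10 - i) as f32).collect();
        let mut out = vec![0.0f32; 11];

        simd_add_f32(&a, &b, &mut out);

        // Each pair sums to exactly 10.0
        for &v in &out {
            assert_eq!(v, 10.0);
        }
    }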

    #[test]
    fn test_simd_mul() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0];
        let mut out = vec![0.0f32; 4];

        simd_mul_f32(&a, &b, &mut out);

        assert_eq!(out, vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_simd_dot() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![5.0f32, 6.0, 7.0, 8.0];

        let result = simd_dot_f32(&a, &b);

        assert_eq!(result, 70.0); // 1*5 + 2*6 + 3*7 + 4*8
    }
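
    // Covers the dot product's scalar tail with a length (10) that is not a
    // multiple of either SIMD width; all values are small integers, so the
    // f32 result is exact on every dispatch path.
    #[test]
    fn test_simd_dot_remainder() {
        let a: Vec<f32> = (1..=10).map(|i| i as f32).collect();
        let b = vec![1.0f32; 10];

        let result = simd_dot_f32(&a, &b);

        assert_eq!(result, 55.0); // 1 + 2 + ... + 10
    }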

    #[test]
    fn test_simd_relu() {
        let input = vec![-2.0f32, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0f32; 5];

        simd_relu_f32(&input, &mut output);

        assert_eq!(output, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }
}