avx_arrow/
simd.rs

//! SIMD-accelerated operations for arrays
//!
//! Uses AVX2 instructions, when available, for up to a 4x speedup over scalar
//! code (four f64 lanes per 256-bit register).

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// SIMD-accelerated sum for f64 arrays
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn sum_f64_simd(data: &[f64]) -> f64 {
    let mut sum = 0.0;
    let len = data.len();
    let chunks = len / 4;

    if chunks > 0 {
        let mut accumulator = _mm256_setzero_pd();

        for i in 0..chunks {
            let offset = i * 4;
            let values = _mm256_loadu_pd(data.as_ptr().add(offset));
            accumulator = _mm256_add_pd(accumulator, values);
        }

        // Horizontal sum of the four accumulator lanes
        let mut temp = [0.0; 4];
        _mm256_storeu_pd(temp.as_mut_ptr(), accumulator);
        sum = temp.iter().sum();
    }

    // Handle remaining elements
    sum + data.iter().skip(chunks * 4).sum::<f64>()
}

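// Illustrative only (not part of the original API): a safe wrapper could detect
// AVX2 at runtime and fall back to a scalar sum, so callers never invoke the
// `unsafe` kernel on unsupported CPUs. The name `sum_f64` is hypothetical.
#[cfg(target_arch = "x86_64")]
pub fn sum_f64(data: &[f64]) -> f64 {
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { sum_f64_simd(data) }
    } else {
        data.iter().sum()
    }
}
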
/// SIMD-accelerated addition for f64 arrays
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn add_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;
        let a = _mm256_loadu_pd(left.as_ptr().add(offset));
        let b = _mm256_loadu_pd(right.as_ptr().add(offset));
        let sum = _mm256_add_pd(a, b);
        _mm256_storeu_pd(result.as_mut_ptr().add(offset), sum);
    }

    // Handle remaining elements
    for i in (chunks * 4)..len {
        result[i] = left[i] + right[i];
    }
}

/// SIMD-accelerated multiplication for f64 arrays
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn mul_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;
        let a = _mm256_loadu_pd(left.as_ptr().add(offset));
        let b = _mm256_loadu_pd(right.as_ptr().add(offset));
        let product = _mm256_mul_pd(a, b);
        _mm256_storeu_pd(result.as_mut_ptr().add(offset), product);
    }

    // Handle remaining elements
    for i in (chunks * 4)..len {
        result[i] = left[i] * right[i];
    }
}

/// SIMD-accelerated greater-than comparison of an f64 array against a scalar
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn gt_f64_simd(data: &[f64], scalar: f64, result: &mut [bool]) {
    let len = data.len().min(result.len());
    let chunks = len / 4;
    let scalar_vec = _mm256_set1_pd(scalar);

    for i in 0..chunks {
        let offset = i * 4;
        let values = _mm256_loadu_pd(data.as_ptr().add(offset));
        let cmp = _mm256_cmp_pd::<_CMP_GT_OQ>(values, scalar_vec);

        // Each lane of `cmp` is all ones (true) or all zeros (false), so the
        // stored bit pattern is nonzero exactly when the comparison held.
        let mut mask = [0u64; 4];
        _mm256_storeu_pd(mask.as_mut_ptr() as *mut f64, cmp);

        for j in 0..4 {
            result[offset + j] = mask[j] != 0;
        }
    }

    // Handle remaining elements
    for i in (chunks * 4)..len {
        result[i] = data[i] > scalar;
    }
}

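// Illustrative alternative (not in the original code): `_mm256_movemask_pd`
// collapses the sign bit of each lane into an i32, avoiding the store-and-reload
// of the 256-bit mask above. For a comparison result `cmp`:
//
//     let bits = _mm256_movemask_pd(cmp);
//     for j in 0..4 {
//         result[offset + j] = (bits >> j) & 1 != 0;
//     }
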
/// SIMD-accelerated subtraction for f64 arrays
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn sub_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;
        let a = _mm256_loadu_pd(left.as_ptr().add(offset));
        let b = _mm256_loadu_pd(right.as_ptr().add(offset));
        let diff = _mm256_sub_pd(a, b);
        _mm256_storeu_pd(result.as_mut_ptr().add(offset), diff);
    }

    // Handle remaining elements
    for i in (chunks * 4)..len {
        result[i] = left[i] - right[i];
    }
}

/// SIMD-accelerated division for f64 arrays
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn div_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;
        let a = _mm256_loadu_pd(left.as_ptr().add(offset));
        let b = _mm256_loadu_pd(right.as_ptr().add(offset));
        let quotient = _mm256_div_pd(a, b);
        _mm256_storeu_pd(result.as_mut_ptr().add(offset), quotient);
    }

    // Handle remaining elements
    for i in (chunks * 4)..len {
        result[i] = left[i] / right[i];
    }
}

/// SIMD-accelerated square root for f64 arrays
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn sqrt_f64_simd(data: &[f64], result: &mut [f64]) {
    let len = data.len().min(result.len());
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;
        let values = _mm256_loadu_pd(data.as_ptr().add(offset));
        let roots = _mm256_sqrt_pd(values);
        _mm256_storeu_pd(result.as_mut_ptr().add(offset), roots);
    }

    // Handle remaining elements
    for i in (chunks * 4)..len {
        result[i] = data[i].sqrt();
    }
}

/// SIMD-accelerated FMA (fused multiply-add): result = a * b + c
///
/// # Safety
/// The caller must ensure the CPU supports both AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn fma_f64_simd(a: &[f64], b: &[f64], c: &[f64], result: &mut [f64]) {
    let len = a.len().min(b.len()).min(c.len()).min(result.len());
    let chunks = len / 4;

    for i in 0..chunks {
        let offset = i * 4;
        let va = _mm256_loadu_pd(a.as_ptr().add(offset));
        let vb = _mm256_loadu_pd(b.as_ptr().add(offset));
        let vc = _mm256_loadu_pd(c.as_ptr().add(offset));
        let vr = _mm256_fmadd_pd(va, vb, vc);
        _mm256_storeu_pd(result.as_mut_ptr().add(offset), vr);
    }

    // Handle remaining elements with mul_add so the tail is fused (single
    // rounding), matching the SIMD lanes above.
    for i in (chunks * 4)..len {
        result[i] = a[i].mul_add(b[i], c[i]);
    }
}

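// Illustrative only (not part of the original API): a safe entry point for the
// FMA kernel would need to verify both AVX2 and FMA at runtime before the
// unsafe call is sound. The name `fma_f64` is hypothetical.
#[cfg(target_arch = "x86_64")]
pub fn fma_f64(a: &[f64], b: &[f64], c: &[f64], result: &mut [f64]) {
    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        // SAFETY: AVX2 and FMA support were verified at runtime just above.
        unsafe { fma_f64_simd(a, b, c, result) }
    } else {
        let len = a.len().min(b.len()).min(c.len()).min(result.len());
        for i in 0..len {
            result[i] = a[i].mul_add(b[i], c[i]);
        }
    }
}
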
/// Fallback sum for non-x86_64 targets
#[cfg(not(target_arch = "x86_64"))]
pub fn sum_f64_simd(data: &[f64]) -> f64 {
    data.iter().sum()
}

/// Fallback add for non-x86_64
#[cfg(not(target_arch = "x86_64"))]
pub fn add_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    for i in 0..len {
        result[i] = left[i] + right[i];
    }
}

/// Fallback mul for non-x86_64
#[cfg(not(target_arch = "x86_64"))]
pub fn mul_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    for i in 0..len {
        result[i] = left[i] * right[i];
    }
}

/// Fallback gt for non-x86_64
#[cfg(not(target_arch = "x86_64"))]
pub fn gt_f64_simd(data: &[f64], scalar: f64, result: &mut [bool]) {
    let len = data.len().min(result.len());
    for i in 0..len {
        result[i] = data[i] > scalar;
    }
}

/// Fallback sub for non-x86_64
#[cfg(not(target_arch = "x86_64"))]
pub fn sub_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    for i in 0..len {
        result[i] = left[i] - right[i];
    }
}

/// Fallback div for non-x86_64
#[cfg(not(target_arch = "x86_64"))]
pub fn div_f64_simd(left: &[f64], right: &[f64], result: &mut [f64]) {
    let len = left.len().min(right.len()).min(result.len());
    for i in 0..len {
        result[i] = left[i] / right[i];
    }
}

/// Fallback sqrt for non-x86_64
#[cfg(not(target_arch = "x86_64"))]
pub fn sqrt_f64_simd(data: &[f64], result: &mut [f64]) {
    let len = data.len().min(result.len());
    for i in 0..len {
        result[i] = data[i].sqrt();
    }
}

/// Fallback FMA for non-x86_64
#[cfg(not(target_arch = "x86_64"))]
pub fn fma_f64_simd(a: &[f64], b: &[f64], c: &[f64], result: &mut [f64]) {
    let len = a.len().min(b.len()).min(c.len()).min(result.len());
    for i in 0..len {
        // mul_add keeps the fused (single-rounding) semantics of a true FMA.
        result[i] = a[i].mul_add(b[i], c[i]);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: on x86_64 these tests call the AVX2 kernels directly and therefore
    // assume the machine running the test suite supports AVX2.

    #[test]
    fn test_sum_f64_simd() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let sum = unsafe { sum_f64_simd(&data) };
        assert_eq!(sum, 36.0);
    }

    #[test]
    fn test_add_f64_simd() {
        let left = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let right = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let mut result = vec![0.0; 5];

        unsafe { add_f64_simd(&left, &right, &mut result) };

        assert_eq!(result, vec![11.0, 22.0, 33.0, 44.0, 55.0]);
    }

    #[test]
    fn test_mul_f64_simd() {
        let left = vec![2.0, 3.0, 4.0, 5.0];
        let right = vec![10.0, 10.0, 10.0, 10.0];
        let mut result = vec![0.0; 4];

        unsafe { mul_f64_simd(&left, &right, &mut result) };

        assert_eq!(result, vec![20.0, 30.0, 40.0, 50.0]);
    }

    #[test]
    fn test_gt_f64_simd() {
        let data = vec![1.0, 5.0, 10.0, 15.0, 20.0];
        let mut result = vec![false; 5];

        unsafe { gt_f64_simd(&data, 8.0, &mut result) };

        assert_eq!(result, vec![false, false, true, true, true]);
    }

    #[test]
    fn test_sub_f64_simd() {
        let left = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let right = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut result = vec![0.0; 5];

        unsafe { sub_f64_simd(&left, &right, &mut result) };

        assert_eq!(result, vec![9.0, 18.0, 27.0, 36.0, 45.0]);
    }

    #[test]
    fn test_div_f64_simd() {
        let left = vec![100.0, 200.0, 300.0, 400.0];
        let right = vec![10.0, 10.0, 10.0, 10.0];
        let mut result = vec![0.0; 4];

        unsafe { div_f64_simd(&left, &right, &mut result) };

        assert_eq!(result, vec![10.0, 20.0, 30.0, 40.0]);
    }

    #[test]
    fn test_sqrt_f64_simd() {
        let data = vec![4.0, 9.0, 16.0, 25.0, 36.0];
        let mut result = vec![0.0; 5];

        unsafe { sqrt_f64_simd(&data, &mut result) };

        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_fma_f64_simd() {
        let a = vec![2.0, 3.0, 4.0, 5.0];
        let b = vec![10.0, 10.0, 10.0, 10.0];
        let c = vec![1.0, 2.0, 3.0, 4.0];
        let mut result = vec![0.0; 4];

        unsafe { fma_f64_simd(&a, &b, &c, &mut result) };

        assert_eq!(result, vec![21.0, 32.0, 43.0, 54.0]);
    }

    #[test]
    fn test_simd_non_aligned() {
        // Test with non-multiple of 4 length
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let sum = unsafe { sum_f64_simd(&data) };
        assert_eq!(sum, 28.0);
    }
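
    // A suggested companion test (not in the original): exercises the scalar
    // remainder loop of add_f64_simd with a length that is not a multiple of 4.
    #[test]
    fn test_add_non_aligned() {
        let left = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let right = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0];
        let mut result = vec![0.0; 7];

        unsafe { add_f64_simd(&left, &right, &mut result) };

        assert_eq!(result, vec![11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0]);
    }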
}