// avila_math/tensor/simd.rs

//! # SIMD Optimizations Module
//!
//! Vectorized operations using AVX2/SSE instructions for high-performance tensor operations.
//!
//! ## Features
//! - Vectorized dot products
//! - SIMD matrix multiplication
//! - Element-wise operations (add, mul, relu)
//! - Reduction operations (sum, max, min)
//!
//! ## Platform Support
//! - x86_64 with AVX2: Full SIMD acceleration
//! - Fallback: Pure Rust implementation for other platforms

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
17
18/// Dot product vetorizado usando AVX2
19///
20/// # Safety
21/// Usa instruções AVX2 quando disponível. Fallback para implementação escalar.
22///
23/// # Example
24/// ```
25/// use avila_math::tensor::simd::dot_product_simd;
26///
27/// let a = vec![1.0, 2.0, 3.0, 4.0];
28/// let b = vec![5.0, 6.0, 7.0, 8.0];
29/// let result = dot_product_simd(&a, &b);
30/// assert_eq!(result, 70.0); // 1*5 + 2*6 + 3*7 + 4*8
31/// ```
32pub fn dot_product_simd(a: &[f64], b: &[f64]) -> f64 {
33    assert_eq!(a.len(), b.len(), "Vectors must have same length");
34
35    #[cfg(target_arch = "x86_64")]
36    {
37        if is_x86_feature_detected!("avx2") {
38            unsafe { dot_product_avx2(a, b) }
39        } else {
40            dot_product_scalar(a, b)
41        }
42    }
43
44    #[cfg(not(target_arch = "x86_64"))]
45    {
46        dot_product_scalar(a, b)
47    }
48}
49
/// AVX2 implementation of the dot product (4 `f64` lanes per iteration).
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) and that `a.len() == b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len();
    let lanes = 4; // AVX2 processes 4 f64 at a time
    let chunks = len / lanes;

    let mut sum_vec = _mm256_setzero_pd();

    // Accumulate 4-wide products.
    for i in 0..chunks {
        let idx = i * lanes;
        // SAFETY: idx + 4 <= chunks * lanes <= len, so the loads stay in bounds.
        let a_vec = _mm256_loadu_pd(a.as_ptr().add(idx));
        let b_vec = _mm256_loadu_pd(b.as_ptr().add(idx));
        sum_vec = _mm256_add_pd(sum_vec, _mm256_mul_pd(a_vec, b_vec));
    }

    // Horizontal reduction of the vector accumulator to a scalar.
    let mut partial = [0.0; 4];
    _mm256_storeu_pd(partial.as_mut_ptr(), sum_vec);
    let mut sum: f64 = partial.iter().sum();

    // Scalar tail for elements that do not fill a full lane group.
    for i in (chunks * lanes)..len {
        sum += a[i] * b[i];
    }

    sum
}
83
/// Scalar fallback for the dot product.
fn dot_product_scalar(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b) {
        acc += x * y;
    }
    acc
}
88
89/// Multiplicação elemento-wise com SIMD
90///
91/// # Example
92/// ```
93/// use avila_math::tensor::simd::mul_elementwise_simd;
94///
95/// let a = vec![1.0, 2.0, 3.0, 4.0];
96/// let b = vec![2.0, 3.0, 4.0, 5.0];
97/// let result = mul_elementwise_simd(&a, &b);
98/// assert_eq!(result, vec![2.0, 6.0, 12.0, 20.0]);
99/// ```
100pub fn mul_elementwise_simd(a: &[f64], b: &[f64]) -> Vec<f64> {
101    assert_eq!(a.len(), b.len(), "Arrays must have same length");
102
103    #[cfg(target_arch = "x86_64")]
104    {
105        if is_x86_feature_detected!("avx2") {
106            unsafe { mul_elementwise_avx2(a, b) }
107        } else {
108            mul_elementwise_scalar(a, b)
109        }
110    }
111
112    #[cfg(not(target_arch = "x86_64"))]
113    {
114        mul_elementwise_scalar(a, b)
115    }
116}
117
/// AVX2 implementation of element-wise multiplication (4 `f64` lanes per iteration).
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`) and that `a.len() == b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_elementwise_avx2(a: &[f64], b: &[f64]) -> Vec<f64> {
    let len = a.len();
    let lanes = 4;
    let chunks = len / lanes;

    let mut result = vec![0.0; len];

    for i in 0..chunks {
        let idx = i * lanes;
        // SAFETY: idx + 4 <= chunks * lanes <= len, so loads and stores stay in bounds.
        let a_vec = _mm256_loadu_pd(a.as_ptr().add(idx));
        let b_vec = _mm256_loadu_pd(b.as_ptr().add(idx));
        _mm256_storeu_pd(result.as_mut_ptr().add(idx), _mm256_mul_pd(a_vec, b_vec));
    }

    // Scalar tail for elements that do not fill a full lane group.
    for i in (chunks * lanes)..len {
        result[i] = a[i] * b[i];
    }

    result
}
143
/// Scalar fallback for element-wise multiplication.
fn mul_elementwise_scalar(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b) {
        out.push(x * y);
    }
    out
}
147
148/// Adição elemento-wise com SIMD
149pub fn add_elementwise_simd(a: &[f64], b: &[f64]) -> Vec<f64> {
150    assert_eq!(a.len(), b.len());
151
152    #[cfg(target_arch = "x86_64")]
153    {
154        if is_x86_feature_detected!("avx2") {
155            unsafe { add_elementwise_avx2(a, b) }
156        } else {
157            add_elementwise_scalar(a, b)
158        }
159    }
160
161    #[cfg(not(target_arch = "x86_64"))]
162    {
163        add_elementwise_scalar(a, b)
164    }
165}
166
/// AVX2 implementation of element-wise addition (4 `f64` lanes per iteration).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_elementwise_avx2(a: &[f64], b: &[f64]) -> Vec<f64> {
    const LANES: usize = 4;
    let len = a.len();
    let simd_len = len - len % LANES;

    let mut out = vec![0.0; len];

    let mut idx = 0;
    while idx < simd_len {
        // SAFETY: idx + LANES <= simd_len <= len, so loads and stores stay in bounds.
        let va = _mm256_loadu_pd(a.as_ptr().add(idx));
        let vb = _mm256_loadu_pd(b.as_ptr().add(idx));
        _mm256_storeu_pd(out.as_mut_ptr().add(idx), _mm256_add_pd(va, vb));
        idx += LANES;
    }

    // Scalar tail for elements that do not fill a full lane group.
    for i in simd_len..len {
        out[i] = a[i] + b[i];
    }

    out
}
191
/// Scalar fallback for element-wise addition.
fn add_elementwise_scalar(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for (x, y) in a.iter().zip(b) {
        out.push(x + y);
    }
    out
}
195
196/// ReLU com SIMD (max(0, x))
197///
198/// # Example
199/// ```
200/// use avila_math::tensor::simd::relu_simd;
201///
202/// let input = vec![-1.0, 2.0, -3.0, 4.0];
203/// let output = relu_simd(&input);
204/// assert_eq!(output, vec![0.0, 2.0, 0.0, 4.0]);
205/// ```
206pub fn relu_simd(input: &[f64]) -> Vec<f64> {
207    #[cfg(target_arch = "x86_64")]
208    {
209        if is_x86_feature_detected!("avx2") {
210            unsafe { relu_avx2(input) }
211        } else {
212            relu_scalar(input)
213        }
214    }
215
216    #[cfg(not(target_arch = "x86_64"))]
217    {
218        relu_scalar(input)
219    }
220}
221
/// AVX2 implementation of ReLU (4 `f64` lanes per iteration).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(input: &[f64]) -> Vec<f64> {
    const LANES: usize = 4;
    let len = input.len();
    let simd_len = len - len % LANES;

    let mut out = vec![0.0; len];
    let zeros = _mm256_setzero_pd();

    let mut idx = 0;
    while idx < simd_len {
        // SAFETY: idx + LANES <= simd_len <= len, so loads and stores stay in bounds.
        let x_vec = _mm256_loadu_pd(input.as_ptr().add(idx));
        _mm256_storeu_pd(out.as_mut_ptr().add(idx), _mm256_max_pd(x_vec, zeros));
        idx += LANES;
    }

    // Scalar tail for elements that do not fill a full lane group.
    for i in simd_len..len {
        out[i] = input[i].max(0.0);
    }

    out
}
246
/// Scalar fallback for ReLU.
fn relu_scalar(input: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(x.max(0.0));
    }
    out
}
250
251/// Soma de todos os elementos com SIMD
252pub fn sum_simd(input: &[f64]) -> f64 {
253    #[cfg(target_arch = "x86_64")]
254    {
255        if is_x86_feature_detected!("avx2") {
256            unsafe { sum_avx2(input) }
257        } else {
258            input.iter().sum()
259        }
260    }
261
262    #[cfg(not(target_arch = "x86_64"))]
263    {
264        input.iter().sum()
265    }
266}
267
/// AVX2 implementation of the horizontal sum (4 `f64` lanes per iteration).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn sum_avx2(input: &[f64]) -> f64 {
    const LANES: usize = 4;
    let len = input.len();
    let simd_len = len - len % LANES;

    let mut acc = _mm256_setzero_pd();

    let mut idx = 0;
    while idx < simd_len {
        // SAFETY: idx + LANES <= simd_len <= len, so the load stays in bounds.
        acc = _mm256_add_pd(acc, _mm256_loadu_pd(input.as_ptr().add(idx)));
        idx += LANES;
    }

    // Spill the vector accumulator and reduce it to a scalar.
    let mut partial = [0.0; LANES];
    _mm256_storeu_pd(partial.as_mut_ptr(), acc);
    let mut total: f64 = partial.iter().sum();

    // Scalar tail for elements that do not fill a full lane group.
    for &x in &input[simd_len..] {
        total += x;
    }

    total
}
294
295/// Multiplicação escalar com SIMD
296pub fn mul_scalar_simd(input: &[f64], scalar: f64) -> Vec<f64> {
297    #[cfg(target_arch = "x86_64")]
298    {
299        if is_x86_feature_detected!("avx2") {
300            unsafe { mul_scalar_avx2(input, scalar) }
301        } else {
302            mul_scalar_scalar(input, scalar)
303        }
304    }
305
306    #[cfg(not(target_arch = "x86_64"))]
307    {
308        mul_scalar_scalar(input, scalar)
309    }
310}
311
/// AVX2 implementation of scalar multiplication (4 `f64` lanes per iteration).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_scalar_avx2(input: &[f64], scalar: f64) -> Vec<f64> {
    const LANES: usize = 4;
    let len = input.len();
    let simd_len = len - len % LANES;

    let mut out = vec![0.0; len];
    // Broadcast the scalar into all four lanes once, outside the loop.
    let scalar_vec = _mm256_set1_pd(scalar);

    let mut idx = 0;
    while idx < simd_len {
        // SAFETY: idx + LANES <= simd_len <= len, so loads and stores stay in bounds.
        let x_vec = _mm256_loadu_pd(input.as_ptr().add(idx));
        _mm256_storeu_pd(out.as_mut_ptr().add(idx), _mm256_mul_pd(x_vec, scalar_vec));
        idx += LANES;
    }

    // Scalar tail for elements that do not fill a full lane group.
    for i in simd_len..len {
        out[i] = input[i] * scalar;
    }

    out
}
336
/// Scalar fallback for scalar multiplication.
fn mul_scalar_scalar(input: &[f64], scalar: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(x * scalar);
    }
    out
}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dot_product_small() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        assert_eq!(dot_product_simd(&a, &b), 70.0); // 1*5 + 2*6 + 3*7 + 4*8 = 70
    }

    #[test]
    fn dot_product_matches_scalar_on_large_input() {
        let a: Vec<f64> = (0..100).map(|x| x as f64).collect();
        let b: Vec<f64> = (0..100).map(|x| x as f64 * 2.0).collect();
        let diff = dot_product_simd(&a, &b) - dot_product_scalar(&a, &b);
        assert!(diff.abs() < 1e-10);
    }

    #[test]
    fn elementwise_mul() {
        let result = mul_elementwise_simd(&[1.0, 2.0, 3.0, 4.0], &[2.0, 3.0, 4.0, 5.0]);
        assert_eq!(result, vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn elementwise_add() {
        let result = add_elementwise_simd(&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]);
        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn relu_clamps_negatives() {
        let result = relu_simd(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
        assert_eq!(result, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn sum_all_elements() {
        assert_eq!(sum_simd(&[1.0, 2.0, 3.0, 4.0, 5.0]), 15.0);
    }

    #[test]
    fn scalar_multiplication() {
        let result = mul_scalar_simd(&[1.0, 2.0, 3.0, 4.0], 2.5);
        assert_eq!(result, vec![2.5, 5.0, 7.5, 10.0]);
    }

    #[test]
    fn length_not_multiple_of_lane_width() {
        // Exercises the scalar tail path (len % 4 != 0).
        let a = [1.0, 2.0, 3.0, 4.0, 5.0];
        let b = [6.0, 7.0, 8.0, 9.0, 10.0];
        assert_eq!(dot_product_simd(&a, &b), 130.0); // 1*6 + 2*7 + 3*8 + 4*9 + 5*10
    }
}