1#[cfg(target_arch = "x86_64")]
16use std::arch::x86_64::*;
17
18pub fn dot_product_simd(a: &[f64], b: &[f64]) -> f64 {
33 assert_eq!(a.len(), b.len(), "Vectors must have same length");
34
35 #[cfg(target_arch = "x86_64")]
36 {
37 if is_x86_feature_detected!("avx2") {
38 unsafe { dot_product_avx2(a, b) }
39 } else {
40 dot_product_scalar(a, b)
41 }
42 }
43
44 #[cfg(not(target_arch = "x86_64"))]
45 {
46 dot_product_scalar(a, b)
47 }
48}
49
/// Dot product of `a` and `b` using AVX2: 4 `f64` lanes per iteration,
/// followed by a scalar loop over the trailing elements.
///
/// # Safety
///
/// Callers must guarantee that the CPU supports AVX2 (e.g. checked via
/// `is_x86_feature_detected!("avx2")`) and that `a.len() == b.len()` —
/// the raw-pointer loads below read `a.len()` elements from both slices.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len(), "slices must be the same length");

    let len = a.len();
    let lanes = 4; // f64 lanes in a 256-bit register
    let chunks = len / lanes;

    // Vector accumulator: four independent partial sums.
    let mut sum_vec = _mm256_setzero_pd();

    for i in 0..chunks {
        let idx = i * lanes;
        let a_vec = _mm256_loadu_pd(a.as_ptr().add(idx));
        let b_vec = _mm256_loadu_pd(b.as_ptr().add(idx));
        sum_vec = _mm256_add_pd(sum_vec, _mm256_mul_pd(a_vec, b_vec));
    }

    // Horizontal reduction of the four partial sums.
    let mut partials = [0.0; 4];
    _mm256_storeu_pd(partials.as_mut_ptr(), sum_vec);
    let mut sum = partials.iter().sum::<f64>();

    // Scalar tail for the last `len % lanes` elements.
    for i in (chunks * lanes)..len {
        sum += a[i] * b[i];
    }

    sum
}
83
/// Scalar dot product fallback; pairs elements up to the shorter length.
fn dot_product_scalar(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
88
89pub fn mul_elementwise_simd(a: &[f64], b: &[f64]) -> Vec<f64> {
101 assert_eq!(a.len(), b.len(), "Arrays must have same length");
102
103 #[cfg(target_arch = "x86_64")]
104 {
105 if is_x86_feature_detected!("avx2") {
106 unsafe { mul_elementwise_avx2(a, b) }
107 } else {
108 mul_elementwise_scalar(a, b)
109 }
110 }
111
112 #[cfg(not(target_arch = "x86_64"))]
113 {
114 mul_elementwise_scalar(a, b)
115 }
116}
117
/// Element-wise product of `a` and `b` using AVX2 (4 `f64` lanes per
/// iteration), with a scalar loop for the trailing elements.
///
/// # Safety
///
/// Callers must guarantee that the CPU supports AVX2 (e.g. checked via
/// `is_x86_feature_detected!("avx2")`) and that `a.len() == b.len()` —
/// the raw-pointer loads below read `a.len()` elements from both slices.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_elementwise_avx2(a: &[f64], b: &[f64]) -> Vec<f64> {
    debug_assert_eq!(a.len(), b.len(), "slices must be the same length");

    let len = a.len();
    let lanes = 4; // f64 lanes in a 256-bit register
    let chunks = len / lanes;

    let mut result = vec![0.0; len];

    for i in 0..chunks {
        let idx = i * lanes;
        let a_vec = _mm256_loadu_pd(a.as_ptr().add(idx));
        let b_vec = _mm256_loadu_pd(b.as_ptr().add(idx));
        _mm256_storeu_pd(result.as_mut_ptr().add(idx), _mm256_mul_pd(a_vec, b_vec));
    }

    // Scalar tail for the last `len % lanes` elements.
    for i in (chunks * lanes)..len {
        result[i] = a[i] * b[i];
    }

    result
}
143
/// Scalar element-wise product fallback; pairs elements up to the shorter length.
fn mul_elementwise_scalar(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().zip(b.iter()) {
        out.push(x * y);
    }
    out
}
147
148pub fn add_elementwise_simd(a: &[f64], b: &[f64]) -> Vec<f64> {
150 assert_eq!(a.len(), b.len());
151
152 #[cfg(target_arch = "x86_64")]
153 {
154 if is_x86_feature_detected!("avx2") {
155 unsafe { add_elementwise_avx2(a, b) }
156 } else {
157 add_elementwise_scalar(a, b)
158 }
159 }
160
161 #[cfg(not(target_arch = "x86_64"))]
162 {
163 add_elementwise_scalar(a, b)
164 }
165}
166
/// Element-wise sum of `a` and `b` using AVX2 (4 `f64` lanes per
/// iteration), with a scalar loop for the trailing elements.
///
/// # Safety
///
/// Callers must guarantee that the CPU supports AVX2 (e.g. checked via
/// `is_x86_feature_detected!("avx2")`) and that `a.len() == b.len()` —
/// the raw-pointer loads below read `a.len()` elements from both slices.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_elementwise_avx2(a: &[f64], b: &[f64]) -> Vec<f64> {
    debug_assert_eq!(a.len(), b.len(), "slices must be the same length");

    let len = a.len();
    let lanes = 4; // f64 lanes in a 256-bit register
    let chunks = len / lanes;

    let mut result = vec![0.0; len];

    for i in 0..chunks {
        let idx = i * lanes;
        let a_vec = _mm256_loadu_pd(a.as_ptr().add(idx));
        let b_vec = _mm256_loadu_pd(b.as_ptr().add(idx));
        _mm256_storeu_pd(result.as_mut_ptr().add(idx), _mm256_add_pd(a_vec, b_vec));
    }

    // Scalar tail for the last `len % lanes` elements.
    for i in (chunks * lanes)..len {
        result[i] = a[i] + b[i];
    }

    result
}
191
/// Scalar element-wise sum fallback; pairs elements up to the shorter length.
fn add_elementwise_scalar(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().zip(b.iter()) {
        out.push(x + y);
    }
    out
}
195
196pub fn relu_simd(input: &[f64]) -> Vec<f64> {
207 #[cfg(target_arch = "x86_64")]
208 {
209 if is_x86_feature_detected!("avx2") {
210 unsafe { relu_avx2(input) }
211 } else {
212 relu_scalar(input)
213 }
214 }
215
216 #[cfg(not(target_arch = "x86_64"))]
217 {
218 relu_scalar(input)
219 }
220}
221
/// ReLU over `input` using AVX2: clamps 4 `f64` lanes at a time against a
/// zero vector via `_mm256_max_pd`, then handles the remainder scalar-wise.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn relu_avx2(input: &[f64]) -> Vec<f64> {
    const LANES: usize = 4;
    let n = input.len();
    let simd_end = n - n % LANES;

    let mut out = vec![0.0; n];
    let zero = _mm256_setzero_pd();

    let mut i = 0;
    while i < simd_end {
        let v = _mm256_loadu_pd(input.as_ptr().add(i));
        _mm256_storeu_pd(out.as_mut_ptr().add(i), _mm256_max_pd(v, zero));
        i += LANES;
    }

    // Scalar tail for the last `n % LANES` elements.
    for j in simd_end..n {
        out[j] = input[j].max(0.0);
    }

    out
}
246
/// Scalar ReLU fallback: clamps each element to be non-negative.
fn relu_scalar(input: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(x.max(0.0));
    }
    out
}
250
251pub fn sum_simd(input: &[f64]) -> f64 {
253 #[cfg(target_arch = "x86_64")]
254 {
255 if is_x86_feature_detected!("avx2") {
256 unsafe { sum_avx2(input) }
257 } else {
258 input.iter().sum()
259 }
260 }
261
262 #[cfg(not(target_arch = "x86_64"))]
263 {
264 input.iter().sum()
265 }
266}
267
/// Sum of all elements using AVX2: accumulates four running partial sums
/// in a 256-bit register, reduces them horizontally, then adds the scalar
/// tail.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn sum_avx2(input: &[f64]) -> f64 {
    const LANES: usize = 4;
    let n = input.len();
    let simd_end = n - n % LANES;

    let mut acc = _mm256_setzero_pd();
    let mut i = 0;
    while i < simd_end {
        acc = _mm256_add_pd(acc, _mm256_loadu_pd(input.as_ptr().add(i)));
        i += LANES;
    }

    // Spill the vector accumulator and reduce its four lanes.
    let mut partials = [0.0f64; LANES];
    _mm256_storeu_pd(partials.as_mut_ptr(), acc);
    let mut total = partials.iter().sum::<f64>();

    // Scalar tail for the last `n % LANES` elements.
    for &x in &input[simd_end..] {
        total += x;
    }

    total
}
294
295pub fn mul_scalar_simd(input: &[f64], scalar: f64) -> Vec<f64> {
297 #[cfg(target_arch = "x86_64")]
298 {
299 if is_x86_feature_detected!("avx2") {
300 unsafe { mul_scalar_avx2(input, scalar) }
301 } else {
302 mul_scalar_scalar(input, scalar)
303 }
304 }
305
306 #[cfg(not(target_arch = "x86_64"))]
307 {
308 mul_scalar_scalar(input, scalar)
309 }
310}
311
/// Scales `input` by `scalar` using AVX2: broadcasts the scalar into all
/// four lanes and multiplies 4 `f64` at a time, then handles the tail
/// scalar-wise.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_scalar_avx2(input: &[f64], scalar: f64) -> Vec<f64> {
    const LANES: usize = 4;
    let n = input.len();
    let simd_end = n - n % LANES;

    let mut out = vec![0.0; n];
    let factor = _mm256_set1_pd(scalar);

    let mut i = 0;
    while i < simd_end {
        let v = _mm256_loadu_pd(input.as_ptr().add(i));
        _mm256_storeu_pd(out.as_mut_ptr().add(i), _mm256_mul_pd(v, factor));
        i += LANES;
    }

    // Scalar tail for the last `n % LANES` elements.
    for j in simd_end..n {
        out[j] = input[j] * scalar;
    }

    out
}
336
/// Scalar fallback: multiplies each element by `scalar`.
fn mul_scalar_scalar(input: &[f64], scalar: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(x * scalar);
    }
    out
}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dot_product_simd() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![5.0, 6.0, 7.0, 8.0];
        assert_eq!(dot_product_simd(&lhs, &rhs), 70.0);
    }

    #[test]
    fn test_dot_product_large() {
        // SIMD and scalar paths must agree to within rounding noise.
        let lhs: Vec<f64> = (0..100).map(f64::from).collect();
        let rhs: Vec<f64> = (0..100).map(|i| f64::from(i) * 2.0).collect();

        let diff = dot_product_simd(&lhs, &rhs) - dot_product_scalar(&lhs, &rhs);
        assert!(diff.abs() < 1e-10);
    }

    #[test]
    fn test_mul_elementwise() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![2.0, 3.0, 4.0, 5.0];
        assert_eq!(mul_elementwise_simd(&lhs, &rhs), vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_add_elementwise() {
        let lhs = vec![1.0, 2.0, 3.0, 4.0];
        let rhs = vec![5.0, 6.0, 7.0, 8.0];
        assert_eq!(add_elementwise_simd(&lhs, &rhs), vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_relu_simd() {
        let xs = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        assert_eq!(relu_simd(&xs), vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sum_simd() {
        let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(sum_simd(&xs), 15.0);
    }

    #[test]
    fn test_mul_scalar() {
        let xs = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(mul_scalar_simd(&xs, 2.5), vec![2.5, 5.0, 7.5, 10.0]);
    }

    #[test]
    fn test_simd_unaligned_length() {
        // Length 5 forces the scalar-tail path of the AVX2 kernels.
        let lhs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let rhs = vec![6.0, 7.0, 8.0, 9.0, 10.0];
        assert_eq!(dot_product_simd(&lhs, &rhs), 130.0);
    }
}