#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use super::VectorBackend;

mod ops;

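/// AVX2 implementation of [`VectorBackend`] for `f32` slices.
///
/// Hot loops process eight `f32` lanes per iteration with 256-bit
/// intrinsics; any leftover elements are handled by scalar tail loops.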
pub struct Avx2Backend;

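// Callers must ensure the running CPU actually supports AVX2 (and FMA,
// where a method's `#[target_feature]` attribute also enables it) before
// invoking any of the methods below.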
impl VectorBackend for Avx2Backend {
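    // Element-wise arithmetic and the horizontal reductions below live in
    // the `ops` submodule; these thin wrappers expose them through the
    // trait with the required target features enabled.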
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
        unsafe { ops::arithmetic::add(a, b, result) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
        unsafe { ops::arithmetic::sub(a, b, result) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
        unsafe { ops::arithmetic::mul(a, b, result) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
        unsafe { ops::arithmetic::div(a, b, result) }
    }

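    // `dot` additionally enables FMA so the accumulation can use fused
    // multiply-adds (one rounding per lane per step).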
    #[inline]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
        unsafe { ops::reductions::dot(a, b) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sum(a: &[f32]) -> f32 {
        unsafe { ops::reductions::sum(a) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn max(a: &[f32]) -> f32 {
        unsafe { ops::reductions::max(a) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn min(a: &[f32]) -> f32 {
        unsafe { ops::reductions::min(a) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn argmax(a: &[f32]) -> usize {
        unsafe { ops::reductions::argmax(a) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn argmin(a: &[f32]) -> usize {
        unsafe { ops::reductions::argmin(a) }
    }

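    // Compensated (Kahan) summation: carries a running error term so long
    // sums lose less precision than the plain `sum` above.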
    #[inline]
    unsafe fn sum_kahan(a: &[f32]) -> f32 {
        unsafe { ops::reductions::sum_kahan(a) }
    }

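    // L2 norm: sqrt(dot(a, a)), reusing the FMA dot product above.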
    #[inline]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn norm_l2(a: &[f32]) -> f32 {
        unsafe {
            if a.is_empty() {
                return 0.0;
            }
            Self::dot(a, a).sqrt()
        }
    }

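    // L1 norm: sum of |a[i]|. |x| is computed branch-free by clearing the
    // IEEE-754 sign bit (AND with 0x7FFF_FFFF).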
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn norm_l1(a: &[f32]) -> f32 {
        unsafe {
            if a.is_empty() {
                return 0.0;
            }
            let len = a.len();
            let mut i = 0;
            let mut acc = _mm256_setzero_ps();
            let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let abs_va = _mm256_and_ps(va, sign_mask);
                acc = _mm256_add_ps(acc, abs_va);
                i += 8;
            }

            // Horizontal sum: fold 256 -> 128 bits, then 4 -> 2 -> 1 lanes.
            let mut result = {
                let sum_halves =
                    _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
                let temp = _mm_add_ps(sum_halves, _mm_movehl_ps(sum_halves, sum_halves));
                let temp = _mm_add_ss(temp, _mm_shuffle_ps(temp, temp, 1));
                _mm_cvtss_f32(temp)
            };

            // Scalar tail for the remaining < 8 elements.
            for &val in &a[i..] {
                result += val.abs();
            }
            result
        }
    }

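    // L-infinity norm: max of |a[i]|. Seeding the running maximum with zero
    // is safe because every compared value is non-negative after the abs
    // mask.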
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn norm_linf(a: &[f32]) -> f32 {
        unsafe {
            if a.is_empty() {
                return 0.0;
            }
            let len = a.len();
            let mut i = 0;
            let mut max_vec = _mm256_setzero_ps();
            let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let abs_va = _mm256_and_ps(va, sign_mask);
                max_vec = _mm256_max_ps(max_vec, abs_va);
                i += 8;
            }

            // Horizontal max: fold 256 -> 128 bits, then 4 -> 2 -> 1 lanes.
            let mut result = {
                let max_halves =
                    _mm_max_ps(_mm256_castps256_ps128(max_vec), _mm256_extractf128_ps(max_vec, 1));
                let temp = _mm_max_ps(max_halves, _mm_movehl_ps(max_halves, max_halves));
                let temp = _mm_max_ss(temp, _mm_shuffle_ps(temp, temp, 1));
                _mm_cvtss_f32(temp)
            };

            // Scalar tail for the remaining < 8 elements.
            for &val in &a[i..] {
                let abs_val = val.abs();
                if abs_val > result {
                    result = abs_val;
                }
            }
            result
        }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;
            let scalar_vec = _mm256_set1_ps(scalar);

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_mul_ps(va, scalar_vec);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j] * scalar;
            }
        }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn abs(a: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;
            let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_and_ps(va, sign_mask);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j].abs();
            }
        }
    }

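    // NB: for NaN lanes the min/max sequence below yields `min_val` (maxps
    // and minps return their second operand when either input is NaN),
    // while the scalar tail's `f32::clamp` propagates NaN, so the two
    // paths can disagree on NaN inputs.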
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;
            let vmin = _mm256_set1_ps(min_val);
            let vmax = _mm256_set1_ps(max_val);

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_min_ps(_mm256_max_ps(va, vmin), vmax);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j].clamp(min_val, max_val);
            }
        }
    }

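    // lerp(a, b, t) = a*(1 - t) + b*t, evaluated as fmadd(b, t, a*(1 - t));
    // this form reproduces the endpoints exactly at t = 0 and t = 1.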
    #[inline]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;
            let vt = _mm256_set1_ps(t);
            let v1_minus_t = _mm256_set1_ps(1.0 - t);

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vb = _mm256_loadu_ps(b.as_ptr().add(i));
                let vresult = _mm256_fmadd_ps(vb, vt, _mm256_mul_ps(va, v1_minus_t));
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j] * (1.0 - t) + b[j] * t;
            }
        }
    }

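    // Element-wise fused multiply-add: one rounding per element on the SIMD
    // path. The scalar tail computes a[j] * b[j] + c[j] with two roundings,
    // so the last few lanes can differ from the SIMD lanes by one ulp.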
    #[inline]
    #[target_feature(enable = "avx2,fma")]
    unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vb = _mm256_loadu_ps(b.as_ptr().add(i));
                let vc = _mm256_loadu_ps(c.as_ptr().add(i));
                let vresult = _mm256_fmadd_ps(va, vb, vc);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j] * b[j] + c[j];
            }
        }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn relu(a: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;
            let vzero = _mm256_setzero_ps();

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_max_ps(va, vzero);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j].max(0.0);
            }
        }
    }

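    // The transcendental activations have no AVX2 path and delegate to the
    // scalar backend.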
    #[inline]
    unsafe fn exp(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::exp(a, result) }
    }

    #[inline]
    unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::sigmoid(a, result) }
    }

    #[inline]
    unsafe fn gelu(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::gelu(a, result) }
    }

    #[inline]
    unsafe fn swish(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::swish(a, result) }
    }

    #[inline]
    unsafe fn tanh(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::tanh(a, result) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_sqrt_ps(va);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j].sqrt();
            }
        }
    }

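    // Reciprocal uses a full-precision divide so results match the scalar
    // 1.0 / x exactly, rather than the ~12-bit `_mm256_rcp_ps`
    // approximation.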
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn recip(a: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;
            let vone = _mm256_set1_ps(1.0);

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_div_ps(vone, va);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = 1.0 / a[j];
            }
        }
    }

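    // Logarithms and trigonometry likewise delegate to the scalar backend.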
    #[inline]
    unsafe fn ln(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::ln(a, result) }
    }

    #[inline]
    unsafe fn log2(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::log2(a, result) }
    }

    #[inline]
    unsafe fn log10(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::log10(a, result) }
    }

    #[inline]
    unsafe fn sin(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::sin(a, result) }
    }

    #[inline]
    unsafe fn cos(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::cos(a, result) }
    }

    #[inline]
    unsafe fn tan(a: &[f32], result: &mut [f32]) {
        unsafe { super::scalar::ScalarBackend::tan(a, result) }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn floor(a: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_floor_ps(va);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j].floor();
            }
        }
    }

    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn ceil(a: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let vresult = _mm256_ceil_ps(va);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j].ceil();
            }
        }
    }

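    // `f32::round` rounds halfway cases away from zero, while the hardware
    // `_mm256_round_ps(.., _MM_FROUND_TO_NEAREST_INT)` rounds them to even,
    // so round-half-away is emulated below as floor(|x| + 0.5) with the
    // original sign bit reapplied.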
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn round(a: &[f32], result: &mut [f32]) {
        unsafe {
            let len = a.len();
            let mut i = 0;

            let half = _mm256_set1_ps(0.5);
            let sign_mask = _mm256_set1_ps(f32::from_bits(0x8000_0000));
            let abs_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));

            while i + 8 <= len {
                let va = _mm256_loadu_ps(a.as_ptr().add(i));
                let sign = _mm256_and_ps(va, sign_mask);
                let abs_val = _mm256_and_ps(va, abs_mask);
                // Round the magnitude half-away-from-zero, then restore the
                // sign. NB: |x| + 0.5 is itself rounded in f32, so inputs
                // just below an exact .5 boundary can round up where
                // `f32::round` would not.
                let shifted = _mm256_add_ps(abs_val, half);
                let rounded_abs = _mm256_floor_ps(shifted);
                let vresult = _mm256_or_ps(rounded_abs, sign);
                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
                i += 8;
            }

            for j in i..len {
                result[j] = a[j].round();
            }
        }
    }
}