#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::{
    float32x4_t, vaddq_f32, vdivq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vmaxq_f32, vminq_f32,
    vmulq_f32, vnegq_f32, vst1q_f32, vsubq_f32,
};
#[cfg(target_arch = "x86")]
use std::arch::x86::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{
    __m128, __m256, _mm_add_ps, _mm_castsi128_ps, _mm_cvtepi32_ps, _mm_cvtps_epi32, _mm_loadu_ps,
    _mm_max_ps, _mm_min_ps, _mm_mul_ps, _mm_set1_epi32, _mm_set1_ps, _mm_setzero_ps, _mm_storeu_ps,
    _mm_sub_ps, _mm256_add_ps, _mm256_castsi256_ps, _mm256_cvtepi32_ps, _mm256_cvtps_epi32,
    _mm256_loadu_ps, _mm256_max_ps, _mm256_min_ps, _mm256_mul_ps, _mm256_set1_epi32,
    _mm256_set1_ps, _mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};

use super::config::BinaryKind;

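// Optional vendor math backends, selected via Cargo features and target cfgs:
// Intel MKL's vector math routines on x86/x86_64, Arm Performance Libraries
// on non-macOS aarch64, and Accelerate's vForce/vDSP routines on macOS.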
#[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vsAdd(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsSub(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsMul(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsDiv(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn vsExp(n: i32, a: *const f32, y: *mut f32);
    fn vsSqrt(n: i32, a: *const f32, y: *mut f32);
    fn vsLn(n: i32, a: *const f32, y: *mut f32);
}

#[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn armpl_svexp_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svadd_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svsub_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svmul_f32(n: i32, a: *const f32, b: *const f32, y: *mut f32);
    fn armpl_svlog_f32(n: i32, x: *const f32, y: *mut f32);
    fn armpl_svsqrt_f32(n: i32, x: *const f32, y: *mut f32);
}

#[cfg(target_os = "macos")]
#[allow(unsafe_code, dead_code)]
unsafe extern "C" {
    fn vvexpf(result: *mut f32, input: *const f32, count: *const i32);
    fn vDSP_vadd(
        __A: *const f32,
        __IA: i32,
        __B: *const f32,
        __IB: i32,
        __C: *mut f32,
        __IC: i32,
        __N: u32,
    );
    fn vDSP_vsub(
        __B: *const f32,
        __IB: i32,
        __A: *const f32,
        __IA: i32,
        __C: *mut f32,
        __IC: i32,
        __N: u32,
    );
    fn vDSP_vmul(
        __A: *const f32,
        __IA: i32,
        __B: *const f32,
        __IB: i32,
        __C: *mut f32,
        __IC: i32,
        __N: u32,
    );
}

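/// Applies ReLU to `values` in place, dispatching at runtime to the widest
/// available SIMD implementation (AVX, then SSE, on x86/x86_64; NEON on
/// aarch64) and falling back to the unrolled scalar loop under Miri or when
/// no SIMD feature is detected.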
#[allow(unsafe_code)]
#[inline]
pub fn relu_slice_dispatch(values: &mut [f32]) {
    if cfg!(miri) {
        unsafe {
            relu_slice_scalar(values);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                relu_slice_avx(values);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                relu_slice_sse(values);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                relu_slice_neon(values);
            }
            return;
        }
    }

    unsafe {
        relu_slice_scalar(values);
    }
}

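/// Out-of-place ReLU: writes `max(input[i], 0.0)` into `output`, using the
/// same runtime dispatch order as `relu_slice_dispatch`.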
#[allow(unsafe_code)]
#[inline]
pub fn relu_to_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        unsafe {
            relu_to_slice_scalar(input, output);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                relu_to_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                relu_to_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                relu_to_slice_neon(input, output);
            }
            return;
        }
    }

    unsafe {
        relu_to_slice_scalar(input, output);
    }
}

#[inline]
#[allow(dead_code)]
pub(crate) fn sigmoid_slice(values: &mut [f32]) {
    for value in values {
        *value = sigmoid_scalar(*value);
    }
}

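/// Numerically stable scalar sigmoid: for non-negative inputs this computes
/// 1 / (1 + e^-x) and for negative inputs e^x / (1 + e^x), so the argument
/// passed to `exp` is never positive and the intermediate cannot overflow.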
#[inline]
pub(crate) fn sigmoid_scalar(value: f32) -> f32 {
    if value >= 0.0 {
        let z = (-value).exp();
        1.0 / (1.0 + z)
    } else {
        let z = value.exp();
        z / (1.0 + z)
    }
}

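/// Element-wise `exp` with runtime dispatch. Preference order: scalar under
/// Miri, Accelerate's `vvexpf` on Apple Silicon, MKL `vsExp` or Arm PL
/// `armpl_svexp_f32` when those features are enabled, then the hand-rolled
/// AVX/SSE/NEON kernels, and finally the scalar fallback.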
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn exp_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        exp_slice_scalar(input, output);
        return;
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let count = input.len() as i32;
        unsafe {
            vvexpf(output.as_mut_ptr(), input.as_ptr(), &count);
        }
        return;
    }

    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let count = input.len() as i32;
        unsafe { vsExp(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let count = input.len() as i32;
        unsafe { armpl_svexp_f32(count, input.as_ptr(), output.as_mut_ptr()) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                exp_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                exp_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                exp_slice_neon(input, output);
            }
            return;
        }
    }

    exp_slice_scalar(input, output);
}

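/// Computes `exp(input[i] - offset)` into `output` (e.g. the max-subtracted
/// exponentials of a softmax), with the usual AVX/SSE/NEON runtime dispatch.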
#[allow(unsafe_code)]
#[inline]
pub fn sub_exp_slice_dispatch(input: &[f32], offset: f32, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sub_exp_slice_scalar(input, offset, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                sub_exp_slice_avx(input, offset, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                sub_exp_slice_sse(input, offset, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                sub_exp_slice_neon(input, offset, output);
            }
            return;
        }
    }

    sub_exp_slice_scalar(input, offset, output);
}

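/// Element-wise sigmoid with the usual AVX/SSE/NEON runtime dispatch and a
/// scalar fallback.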
#[allow(unsafe_code, clippy::needless_return)]
#[inline]
pub fn sigmoid_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        sigmoid_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                sigmoid_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                sigmoid_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                sigmoid_slice_neon(input, output);
            }
            return;
        }
    }

    sigmoid_slice_dispatch_scalar(input, output);
}

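/// Fast NEON `exp` used by the sigmoid/tanh/SiLU kernels. Clamps the input
/// to [-88, 88], splits x = n*ln(2) + r with n rounded to nearest, builds
/// 2^n by shifting the biased exponent (n + 127) into the f32 exponent
/// field, and approximates e^r with the cubic 1 + r + r^2/2 + r^3/6
/// evaluated via fused multiply-adds.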
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn fast_exp_sigmoid_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vdupq_n_s32, vreinterpretq_f32_s32, vshlq_n_s32,
        vsubq_f32,
    };
    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));
    let n_f = vmulq_f32(x, vdupq_n_f32(std::f32::consts::LOG2_E));
    let n_i = vcvtnq_s32_f32(n_f);
    let r = vsubq_f32(
        x,
        vmulq_f32(vcvtq_f32_s32(n_i), vdupq_n_f32(std::f32::consts::LN_2)),
    );
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, vdupq_n_s32(127))));
    let p = vfmaq_f32(vdupq_n_f32(0.5), r, vdupq_n_f32(1.0 / 6.0));
    let p = vfmaq_f32(vdupq_n_f32(1.0), r, p);
    vmulq_f32(vfmaq_f32(vdupq_n_f32(1.0), r, p), pow2n)
}

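/// Hand-written NEON assembly sigmoid. The prologue pins constants in
/// v16-v23: the clamp bounds, the Schraudolph scale 2^23 / ln(2) (~12102203)
/// in v18, the bias 127 << 23 in v19, and 1.0 in v22 (v20/v21/v23 are loaded
/// but unused by the current loop bodies, and are assumed to survive across
/// the separate `asm!` blocks). The main loop handles 16 lanes per iteration,
/// computing 1 / (1 + e^-x) with e^-x obtained from the bit-trick exp
/// (scale, convert to int, add bias, reinterpret as f32); a 4-lane loop and
/// a scalar tail handle the remainder.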
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
unsafe fn sigmoid_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut inp = input.as_ptr();
    let mut out = output.as_mut_ptr();
    let mut remaining = len;

    if remaining >= 4 {
        unsafe {
            let c_neg88: f32 = -88.0;
            let c_pos88: f32 = 88.0;
            let c_schr_c: f32 = 12102203.0;
            let c_schr_b: i32 = 127 << 23;
            let c_sixth: f32 = 1.0 / 6.0;
            let c_half: f32 = 0.5;
            let c_one: f32 = 1.0;
            let c_127: i32 = 127;

            std::arch::asm!(
                "ld1r {{v16.4s}}, [{p_neg88}]",
                "ld1r {{v17.4s}}, [{p_pos88}]",
                "ld1r {{v18.4s}}, [{p_schr_c}]",
                "dup v19.4s, {p_schr_b:w}",
                "ld1r {{v20.4s}}, [{p_sixth}]",
                "ld1r {{v21.4s}}, [{p_half}]",
                "ld1r {{v22.4s}}, [{p_one}]",
                "dup v23.4s, {p_127:w}",
                p_neg88 = in(reg) &c_neg88,
                p_pos88 = in(reg) &c_pos88,
                p_schr_c = in(reg) &c_schr_c,
                p_schr_b = in(reg) c_schr_b,
                p_sixth = in(reg) &c_sixth,
                p_half = in(reg) &c_half,
                p_one = in(reg) &c_one,
                p_127 = in(reg) c_127,
                out("v16") _, out("v17") _, out("v18") _, out("v19") _,
                out("v20") _, out("v21") _, out("v22") _, out("v23") _,
            );

            while remaining >= 16 {
                std::arch::asm!(
                    "ldp q0, q1, [{inp}]",
                    "ldp q2, q3, [{inp}, #32]",
                    "add {inp}, {inp}, #64",
                    "fneg v0.4s, v0.4s",
                    "fneg v1.4s, v1.4s",
                    "fneg v2.4s, v2.4s",
                    "fneg v3.4s, v3.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmax v1.4s, v1.4s, v16.4s",
                    "fmax v2.4s, v2.4s, v16.4s",
                    "fmax v3.4s, v3.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmin v1.4s, v1.4s, v17.4s",
                    "fmin v2.4s, v2.4s, v17.4s",
                    "fmin v3.4s, v3.4s, v17.4s",
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fmul v1.4s, v1.4s, v18.4s",
                    "fmul v2.4s, v2.4s, v18.4s",
                    "fmul v3.4s, v3.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "fcvtzs v1.4s, v1.4s",
                    "fcvtzs v2.4s, v2.4s",
                    "fcvtzs v3.4s, v3.4s",
                    "add v0.4s, v0.4s, v19.4s",
                    "add v1.4s, v1.4s, v19.4s",
                    "add v2.4s, v2.4s, v19.4s",
                    "add v3.4s, v3.4s, v19.4s",
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fadd v1.4s, v22.4s, v1.4s",
                    "fadd v2.4s, v22.4s, v2.4s",
                    "fadd v3.4s, v22.4s, v3.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "fdiv v1.4s, v22.4s, v1.4s",
                    "fdiv v2.4s, v22.4s, v2.4s",
                    "fdiv v3.4s, v22.4s, v3.4s",
                    "stp q0, q1, [{out}]",
                    "stp q2, q3, [{out}, #32]",
                    "add {out}, {out}, #64",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _, out("v1") _, out("v2") _, out("v3") _,
                );
                remaining -= 16;
            }
            while remaining >= 4 {
                std::arch::asm!(
                    "ld1 {{v0.4s}}, [{inp}], #16",
                    "fneg v0.4s, v0.4s",
                    "fmax v0.4s, v0.4s, v16.4s",
                    "fmin v0.4s, v0.4s, v17.4s",
                    "fmul v0.4s, v0.4s, v18.4s",
                    "fcvtzs v0.4s, v0.4s",
                    "add v0.4s, v0.4s, v19.4s",
                    "fadd v0.4s, v22.4s, v0.4s",
                    "fdiv v0.4s, v22.4s, v0.4s",
                    "st1 {{v0.4s}}, [{out}], #16",
                    inp = inout(reg) inp,
                    out = inout(reg) out,
                    out("v0") _,
                );
                remaining -= 4;
            }
        }
    }

    for i in 0..remaining {
        unsafe {
            let x = *inp.add(i);
            *out.add(i) = 1.0 / (1.0 + (-x).exp());
        }
    }
}

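/// Element-wise tanh with AVX/SSE/NEON runtime dispatch; the scalar path
/// falls back to `f32::tanh`.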
#[allow(unsafe_code)]
#[inline]
pub fn tanh_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        tanh_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                tanh_slice_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                tanh_slice_sse(input, output);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                tanh_slice_neon(input, output);
            }
            return;
        }
    }

    tanh_slice_dispatch_scalar(input, output);
}

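/// Element-wise SiLU, `x * sigmoid(x)`, with NEON/AVX/SSE runtime dispatch
/// and a scalar fallback.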
#[allow(unsafe_code)]
#[inline]
pub fn silu_slice_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) {
        silu_slice_dispatch_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                silu_slice_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe { silu_slice_avx(input, output) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe { silu_slice_sse(input, output) };
            return;
        }
    }

    silu_slice_dispatch_scalar(input, output);
}

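/// Horizontal maximum of a slice; returns `f32::NEG_INFINITY` for an empty
/// input so it composes as the identity of `max`.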
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn max_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return f32::NEG_INFINITY;
    }

    if cfg!(miri) {
        return max_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { max_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            return unsafe { max_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { max_reduce_neon(data) };
        }
    }

    max_reduce_scalar(data)
}

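/// Horizontal sum of a slice; returns `0.0` for an empty input.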
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn add_reduce_dispatch(data: &[f32]) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    if cfg!(miri) {
        return add_reduce_scalar(data);
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            return unsafe { add_reduce_avx(data) };
        }
        if std::is_x86_feature_detected!("sse") {
            return unsafe { add_reduce_sse(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            return unsafe { add_reduce_neon(data) };
        }
    }

    add_reduce_scalar(data)
}

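/// Multiplies every element of `data` by `scalar` in place, with NEON/AVX/SSE
/// fast paths and a scalar loop for Miri, empty input, and the fallback case.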
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn mul_scalar_inplace_dispatch(data: &mut [f32], scalar: f32) {
    if cfg!(miri) || data.is_empty() {
        for v in data.iter_mut() {
            *v *= scalar;
        }
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe { mul_scalar_inplace_neon(data, scalar) };
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe { mul_scalar_inplace_avx(data, scalar) };
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe { mul_scalar_inplace_sse(data, scalar) };
            return;
        }
    }

    for v in data.iter_mut() {
        *v *= scalar;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn mul_scalar_inplace_neon(data: &mut [f32], scalar: f32) {
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = vdupq_n_f32(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = vld1q_f32(ptr.add(i));
        vst1q_f32(ptr.add(i), vmulq_f32(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn mul_scalar_inplace_avx(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm256_set1_ps(scalar);
    let mut i = 0usize;
    while i + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(i));
        _mm256_storeu_ps(ptr.add(i), _mm256_mul_ps(v, vs));
        i += 8;
    }
    let vs4 = _mm_set1_ps(scalar);
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs4));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn mul_scalar_inplace_sse(data: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = data.len();
    let ptr = data.as_mut_ptr();
    let vs = _mm_set1_ps(scalar);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(i));
        _mm_storeu_ps(ptr.add(i), _mm_mul_ps(v, vs));
        i += 4;
    }
    while i < len {
        *ptr.add(i) *= scalar;
        i += 1;
    }
}

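/// Fused multiply-add over slices: `acc[i] += a[i] * b[i]` for each lane,
/// dispatched to AVX/SSE/NEON with a scalar fallback.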
#[allow(unsafe_code, dead_code)]
#[inline]
pub fn fma_slice_dispatch(a: &[f32], b: &[f32], acc: &mut [f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), acc.len());

    if cfg!(miri) {
        unsafe {
            fma_slice_scalar(a, b, acc);
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                fma_slice_avx(a, b, acc);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                fma_slice_sse(a, b, acc);
            }
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                fma_slice_neon(a, b, acc);
            }
            return;
        }
    }

    unsafe {
        fma_slice_scalar(a, b, acc);
    }
}

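/// Element-wise add/sub/mul for equal-length slices. Vendor backends are
/// tried first (vDSP on macOS, MKL, Arm PL), then the SIMD kernels, then
/// scalar. Note the vDSP subtraction convention: `vDSP_vsub(B, 1, A, 1, C,
/// 1, n)` computes `C = A - B`, which is why `rhs` is passed first below.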
#[allow(unsafe_code, unreachable_code)]
#[inline]
pub fn binary_same_shape_dispatch(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    debug_assert_eq!(lhs.len(), rhs.len());
    debug_assert_eq!(lhs.len(), out.len());

    if cfg!(miri) {
        unsafe {
            binary_same_shape_scalar(lhs, rhs, out, kind);
        }
        return;
    }

    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    {
        let n = lhs.len() as u32;
        unsafe {
            match kind {
                BinaryKind::Add => {
                    vDSP_vadd(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                BinaryKind::Sub => {
                    vDSP_vsub(rhs.as_ptr(), 1, lhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
                BinaryKind::Mul => {
                    vDSP_vmul(lhs.as_ptr(), 1, rhs.as_ptr(), 1, out.as_mut_ptr(), 1, n)
                }
            }
        }
        return;
    }

    #[cfg(all(feature = "mkl", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let n = lhs.len() as i32;
        unsafe {
            match kind {
                BinaryKind::Add => vsAdd(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => vsSub(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => vsMul(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    #[cfg(all(feature = "armpl", target_arch = "aarch64", not(target_os = "macos")))]
    {
        let n = lhs.len() as i32;
        unsafe {
            match kind {
                BinaryKind::Add => armpl_svadd_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Sub => armpl_svsub_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
                BinaryKind::Mul => armpl_svmul_f32(n, lhs.as_ptr(), rhs.as_ptr(), out.as_mut_ptr()),
            }
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                binary_same_shape_avx(lhs, rhs, out, kind);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                binary_same_shape_sse(lhs, rhs, out, kind);
            }
            return;
        }
    }

    #[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                binary_same_shape_neon(lhs, rhs, out, kind);
            }
            return;
        }
    }

    unsafe {
        binary_same_shape_scalar(lhs, rhs, out, kind);
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_slice_scalar(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        let v0 = *ptr.add(index);
        let v1 = *ptr.add(index + 1);
        let v2 = *ptr.add(index + 2);
        let v3 = *ptr.add(index + 3);
        let v4 = *ptr.add(index + 4);
        let v5 = *ptr.add(index + 5);
        let v6 = *ptr.add(index + 6);
        let v7 = *ptr.add(index + 7);
        *ptr.add(index) = v0.max(0.0);
        *ptr.add(index + 1) = v1.max(0.0);
        *ptr.add(index + 2) = v2.max(0.0);
        *ptr.add(index + 3) = v3.max(0.0);
        *ptr.add(index + 4) = v4.max(0.0);
        *ptr.add(index + 5) = v5.max(0.0);
        *ptr.add(index + 6) = v6.max(0.0);
        *ptr.add(index + 7) = v7.max(0.0);
        index += 8;
    }

    while index < len {
        *ptr.add(index) = (*ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn relu_to_slice_scalar(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        *out_ptr.add(index + 1) = (*in_ptr.add(index + 1)).max(0.0);
        *out_ptr.add(index + 2) = (*in_ptr.add(index + 2)).max(0.0);
        *out_ptr.add(index + 3) = (*in_ptr.add(index + 3)).max(0.0);
        *out_ptr.add(index + 4) = (*in_ptr.add(index + 4)).max(0.0);
        *out_ptr.add(index + 5) = (*in_ptr.add(index + 5)).max(0.0);
        *out_ptr.add(index + 6) = (*in_ptr.add(index + 6)).max(0.0);
        *out_ptr.add(index + 7) = (*in_ptr.add(index + 7)).max(0.0);
        index += 8;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).max(0.0);
        index += 1;
    }
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn binary_same_shape_scalar(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    match kind {
        BinaryKind::Add => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) + *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) + *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) + *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) + *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) + *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) + *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) + *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) + *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Sub => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) - *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) - *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) - *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) - *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) - *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) - *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) - *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) - *right_ptr.add(index);
                index += 1;
            }
        }
        BinaryKind::Mul => {
            while index + 8 <= len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                *out_ptr.add(index + 1) = *left_ptr.add(index + 1) * *right_ptr.add(index + 1);
                *out_ptr.add(index + 2) = *left_ptr.add(index + 2) * *right_ptr.add(index + 2);
                *out_ptr.add(index + 3) = *left_ptr.add(index + 3) * *right_ptr.add(index + 3);
                *out_ptr.add(index + 4) = *left_ptr.add(index + 4) * *right_ptr.add(index + 4);
                *out_ptr.add(index + 5) = *left_ptr.add(index + 5) * *right_ptr.add(index + 5);
                *out_ptr.add(index + 6) = *left_ptr.add(index + 6) * *right_ptr.add(index + 6);
                *out_ptr.add(index + 7) = *left_ptr.add(index + 7) * *right_ptr.add(index + 7);
                index += 8;
            }
            while index < len {
                *out_ptr.add(index) = *left_ptr.add(index) * *right_ptr.add(index);
                index += 1;
            }
        }
    }
}

fn exp_slice_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.exp();
    }
}

fn sub_exp_slice_scalar(input: &[f32], offset: f32, output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = (v - offset).exp();
    }
}

fn sigmoid_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = sigmoid_scalar(v);
    }
}

fn tanh_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v.tanh();
    }
}

fn silu_slice_dispatch_scalar(input: &[f32], output: &mut [f32]) {
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let s = 1.0 / (1.0 + (-v).exp());
        *o = v * s;
    }
}

#[allow(dead_code)]
fn max_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = f32::NEG_INFINITY;
    for &v in data {
        acc = acc.max(v);
    }
    acc
}

#[allow(dead_code)]
fn add_reduce_scalar(data: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for &v in data {
        acc += v;
    }
    acc
}

#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fma_slice_scalar(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        *acc_ptr.add(index + 1) += *a_ptr.add(index + 1) * *b_ptr.add(index + 1);
        *acc_ptr.add(index + 2) += *a_ptr.add(index + 2) * *b_ptr.add(index + 2);
        *acc_ptr.add(index + 3) += *a_ptr.add(index + 3) * *b_ptr.add(index + 3);
        index += 4;
    }
    while index < len {
        *acc_ptr.add(index) += *a_ptr.add(index) * *b_ptr.add(index);
        index += 1;
    }
}

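/// Coarse Schraudolph-style SSE exp: clamps to [-87, 88], multiplies by
/// 2^23 / ln(2) (~12102203), converts to integer, and adds 127 << 23
/// (1065353216) so that reinterpreting the integer bits as an f32 yields an
/// approximation of e^x. Accurate only to a few percent; used where speed
/// matters more than precision (e.g. the sigmoid kernels).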
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn fast_exp_bittrick_sse(x: __m128) -> __m128 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm_add_epi32, _mm_cvtps_epi32, _mm_set1_epi32};
    let scale = _mm_set1_ps(12102203.0);
    let offset = _mm_set1_epi32(1065353216);
    let clamp_lo = _mm_set1_ps(-87.0);
    let clamp_hi = _mm_set1_ps(88.0);
    let x_clamped = _mm_max_ps(_mm_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm_cvtps_epi32(_mm_mul_ps(x_clamped, scale));
    _mm_castsi128_ps(_mm_add_epi32(val, offset))
}

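/// Higher-accuracy SSE exp: range reduction x = n*ln(2) + r with ln(2) split
/// into hi/lo parts to keep r accurate, followed by a degree-6 Taylor
/// polynomial (coefficients 1/k!) and scaling by 2^n through the exponent
/// field. `_mm_cvtps_epi32` rounds to nearest under the default MXCSR mode,
/// which is what the reduction assumes.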
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn fast_exp_sse(x: __m128) -> __m128 {
    let ln2_inv = _mm_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm_set1_ps(0.693_359_4);
    let ln2_lo = _mm_set1_ps(-2.121_944_4e-4);

    let c0 = _mm_set1_ps(1.0);
    let c1 = _mm_set1_ps(1.0);
    let c2 = _mm_set1_ps(0.5);
    let c3 = _mm_set1_ps(1.0 / 6.0);
    let c4 = _mm_set1_ps(1.0 / 24.0);
    let c5 = _mm_set1_ps(1.0 / 120.0);
    let c6 = _mm_set1_ps(1.0 / 720.0);

    let x = _mm_max_ps(_mm_set1_ps(-88.0), _mm_min_ps(_mm_set1_ps(88.0), x));

    let n_f = _mm_mul_ps(x, ln2_inv);
    let n_i = _mm_cvtps_epi32(n_f);
    let n_f = _mm_cvtepi32_ps(n_i);

    let r = _mm_sub_ps(
        _mm_sub_ps(x, _mm_mul_ps(n_f, ln2_hi)),
        _mm_mul_ps(n_f, ln2_lo),
    );

    let mut poly = _mm_add_ps(c5, _mm_mul_ps(r, c6));
    poly = _mm_add_ps(c4, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c1, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c0, _mm_mul_ps(r, poly));

    let pow2n = {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm_add_epi32, _mm_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm_add_epi32, _mm_slli_epi32};
        let bias = _mm_set1_epi32(127);
        _mm_castsi128_ps(_mm_slli_epi32::<23>(_mm_add_epi32(n_i, bias)))
    };

    _mm_mul_ps(poly, pow2n)
}

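/// AVX twin of `fast_exp_bittrick_sse`, processing eight lanes at a time.
/// Note that `_mm256_add_epi32` (and `_mm256_slli_epi32` below) are AVX2
/// integer intrinsics, so these paths assume AVX2-capable hardware even
/// though the dispatchers only probe for `avx`.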
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
#[inline]
unsafe fn fast_exp_bittrick_avx(x: __m256) -> __m256 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{_mm256_add_epi32, _mm256_cvtps_epi32, _mm256_set1_epi32};
    let scale = _mm256_set1_ps(12102203.0);
    let offset = _mm256_set1_epi32(1065353216);
    let clamp_lo = _mm256_set1_ps(-87.0);
    let clamp_hi = _mm256_set1_ps(88.0);
    let x_clamped = _mm256_max_ps(_mm256_min_ps(x, clamp_hi), clamp_lo);
    let val = _mm256_cvtps_epi32(_mm256_mul_ps(x_clamped, scale));
    _mm256_castsi256_ps(_mm256_add_epi32(val, offset))
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn fast_exp_avx(x: __m256) -> __m256 {
    let ln2_inv = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm256_set1_ps(0.693_359_4);
    let ln2_lo = _mm256_set1_ps(-2.121_944_4e-4);

    let c0 = _mm256_set1_ps(1.0);
    let c1 = _mm256_set1_ps(1.0);
    let c2 = _mm256_set1_ps(0.5);
    let c3 = _mm256_set1_ps(1.0 / 6.0);
    let c4 = _mm256_set1_ps(1.0 / 24.0);
    let c5 = _mm256_set1_ps(1.0 / 120.0);
    let c6 = _mm256_set1_ps(1.0 / 720.0);

    let x = _mm256_max_ps(
        _mm256_set1_ps(-88.0),
        _mm256_min_ps(_mm256_set1_ps(88.0), x),
    );

    let n_f = _mm256_mul_ps(x, ln2_inv);
    let n_i = _mm256_cvtps_epi32(n_f);
    let n_f = _mm256_cvtepi32_ps(n_i);

    let r = _mm256_sub_ps(
        _mm256_sub_ps(x, _mm256_mul_ps(n_f, ln2_hi)),
        _mm256_mul_ps(n_f, ln2_lo),
    );

    let mut poly = _mm256_add_ps(c5, _mm256_mul_ps(r, c6));
    poly = _mm256_add_ps(c4, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c3, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c2, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c1, _mm256_mul_ps(r, poly));
    poly = _mm256_add_ps(c0, _mm256_mul_ps(r, poly));

    let bias = _mm256_set1_epi32(127);
    let pow2n = {
        #[cfg(target_arch = "x86")]
        use std::arch::x86::{_mm256_add_epi32, _mm256_slli_epi32};
        #[cfg(target_arch = "x86_64")]
        use std::arch::x86_64::{_mm256_add_epi32, _mm256_slli_epi32};
        _mm256_castsi256_ps(_mm256_slli_epi32::<23>(_mm256_add_epi32(n_i, bias)))
    };

    _mm256_mul_ps(poly, pow2n)
}

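/// NEON analogue of `fast_exp_sse`: same hi/lo range reduction and degree-6
/// polynomial, with `vcvtnq_s32_f32` providing explicit round-to-nearest.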
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn fast_exp_neon(x: float32x4_t) -> float32x4_t {
    use std::arch::aarch64::{
        vaddq_s32, vcvtnq_s32_f32, vcvtq_f32_s32, vreinterpretq_f32_s32, vshlq_n_s32,
    };

    let ln2_inv = vdupq_n_f32(std::f32::consts::LOG2_E);
    let ln2_hi = vdupq_n_f32(0.693_359_4);
    let ln2_lo = vdupq_n_f32(-2.121_944_4e-4);

    let c0 = vdupq_n_f32(1.0);
    let c1 = vdupq_n_f32(1.0);
    let c2 = vdupq_n_f32(0.5);
    let c3 = vdupq_n_f32(1.0 / 6.0);
    let c4 = vdupq_n_f32(1.0 / 24.0);
    let c5 = vdupq_n_f32(1.0 / 120.0);
    let c6 = vdupq_n_f32(1.0 / 720.0);

    let x = vmaxq_f32(vdupq_n_f32(-88.0), vminq_f32(vdupq_n_f32(88.0), x));

    let n_f = vmulq_f32(x, ln2_inv);
    let n_i = vcvtnq_s32_f32(n_f);
    let n_f = vcvtq_f32_s32(n_i);

    let r = vsubq_f32(vsubq_f32(x, vmulq_f32(n_f, ln2_hi)), vmulq_f32(n_f, ln2_lo));

    let mut poly = vfmaq_f32(c5, r, c6);
    poly = vfmaq_f32(c4, r, poly);
    poly = vfmaq_f32(c3, r, poly);
    poly = vfmaq_f32(c2, r, poly);
    poly = vfmaq_f32(c1, r, poly);
    poly = vfmaq_f32(c0, r, poly);

    use std::arch::aarch64::vdupq_n_s32;
    let bias = vdupq_n_s32(127);
    let pow2n = vreinterpretq_f32_s32(vshlq_n_s32::<23>(vaddq_s32(n_i, bias)));

    vmulq_f32(poly, pow2n)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn exp_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let r = fast_exp_sse(v);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn exp_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 16 <= len {
        #[cfg(target_arch = "x86")]
        {
            use std::arch::x86::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        let v0 = _mm256_loadu_ps(in_ptr.add(index));
        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let r0 = fast_exp_avx(v0);
        let r1 = fast_exp_avx(v1);
        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        index += 16;
    }

    while index + 8 <= len {
        let v = _mm256_loadu_ps(in_ptr.add(index));
        let r = fast_exp_avx(v);
        _mm256_storeu_ps(out_ptr.add(index), r);
        index += 8;
    }

    if index < len {
        exp_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn exp_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = vld1q_f32(in_ptr.add(index));
        let r = fast_exp_neon(v);
        vst1q_f32(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn sub_exp_slice_sse(input: &[f32], offset: f32, output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let off = _mm_set1_ps(offset);
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let shifted = _mm_sub_ps(v, off);
        let r = fast_exp_sse(shifted);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn sub_exp_slice_avx(input: &[f32], offset: f32, output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let off = _mm256_set1_ps(offset);
    let mut index = 0usize;

    while index + 16 <= len {
        #[cfg(target_arch = "x86")]
        {
            use std::arch::x86::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::_mm_prefetch;
            _mm_prefetch::<3>(in_ptr.add(index + 16) as *const i8);
        }
        let v0 = _mm256_loadu_ps(in_ptr.add(index));
        let v1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let shifted0 = _mm256_sub_ps(v0, off);
        let shifted1 = _mm256_sub_ps(v1, off);
        let r0 = fast_exp_avx(shifted0);
        let r1 = fast_exp_avx(shifted1);
        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        index += 16;
    }

    while index + 8 <= len {
        let v = _mm256_loadu_ps(in_ptr.add(index));
        let shifted = _mm256_sub_ps(v, off);
        let r = fast_exp_avx(shifted);
        _mm256_storeu_ps(out_ptr.add(index), r);
        index += 8;
    }

    if index < len {
        sub_exp_slice_sse(&input[index..], offset, &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn sub_exp_slice_neon(input: &[f32], offset: f32, output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let off = vdupq_n_f32(offset);
    let mut index = 0usize;

    while index + 4 <= len {
        let v = vld1q_f32(in_ptr.add(index));
        let shifted = vsubq_f32(v, off);
        let r = fast_exp_neon(shifted);
        vst1q_f32(out_ptr.add(index), r);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index) - offset).exp();
        index += 1;
    }
}

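/// SSE sigmoid using the coarse bit-trick exp: a 16-element unrolled main
/// loop, a 4-element loop, and a scalar tail that falls back to the stable
/// `sigmoid_scalar`.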
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn sigmoid_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 16 <= len {
        let x0 = _mm_loadu_ps(in_ptr.add(index));
        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));

        let e0 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_sse(_mm_sub_ps(zero, x3));

        let r0 = _mm_div_ps(one, _mm_add_ps(one, e0));
        let r1 = _mm_div_ps(one, _mm_add_ps(one, e1));
        let r2 = _mm_div_ps(one, _mm_add_ps(one, e2));
        let r3 = _mm_div_ps(one, _mm_add_ps(one, e3));

        _mm_storeu_ps(out_ptr.add(index), r0);
        _mm_storeu_ps(out_ptr.add(index + 4), r1);
        _mm_storeu_ps(out_ptr.add(index + 8), r2);
        _mm_storeu_ps(out_ptr.add(index + 12), r3);

        index += 16;
    }

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let neg_x = _mm_sub_ps(zero, x);
        let exp_neg_x = fast_exp_bittrick_sse(neg_x);
        let denom = _mm_add_ps(one, exp_neg_x);
        let result = _mm_div_ps(one, denom);
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = sigmoid_scalar(*in_ptr.add(index));
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn sigmoid_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = _mm256_loadu_ps(in_ptr.add(index));
        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));

        let e0 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x0));
        let e1 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x1));
        let e2 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x2));
        let e3 = fast_exp_bittrick_avx(_mm256_sub_ps(zero, x3));

        let r0 = _mm256_div_ps(one, _mm256_add_ps(one, e0));
        let r1 = _mm256_div_ps(one, _mm256_add_ps(one, e1));
        let r2 = _mm256_div_ps(one, _mm256_add_ps(one, e2));
        let r3 = _mm256_div_ps(one, _mm256_add_ps(one, e3));

        _mm256_storeu_ps(out_ptr.add(index), r0);
        _mm256_storeu_ps(out_ptr.add(index + 8), r1);
        _mm256_storeu_ps(out_ptr.add(index + 16), r2);
        _mm256_storeu_ps(out_ptr.add(index + 24), r3);

        index += 32;
    }

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let neg_x = _mm256_sub_ps(zero, x);
        let exp_neg_x = fast_exp_bittrick_avx(neg_x);
        let denom = _mm256_add_ps(one, exp_neg_x);
        let result = _mm256_div_ps(one, denom);
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        sigmoid_slice_sse(&input[index..], &mut output[index..]);
    }
}

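/// SSE tanh via the sigmoid identity `tanh(x) = 2 * sigmoid(2x) - 1`: the
/// kernel computes e^(-2x) with `fast_exp_sse` and maps it through
/// 2 / (1 + e^(-2x)) - 1.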
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn tanh_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = _mm_set1_ps(2.0);
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let two_x = _mm_mul_ps(two, x);
        let neg_two_x = _mm_sub_ps(zero, two_x);
        let exp_neg = fast_exp_sse(neg_two_x);
        let sig = _mm_div_ps(one, _mm_add_ps(one, exp_neg));
        let result = _mm_sub_ps(_mm_mul_ps(two, sig), one);
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn tanh_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = _mm256_set1_ps(2.0);
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let two_x = _mm256_mul_ps(two, x);
        let neg_two_x = _mm256_sub_ps(zero, two_x);
        let exp_neg = fast_exp_avx(neg_two_x);
        let sig = _mm256_div_ps(one, _mm256_add_ps(one, exp_neg));
        let result = _mm256_sub_ps(_mm256_mul_ps(two, sig), one);
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        tanh_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn tanh_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let two = vdupq_n_f32(2.0);
    let one = vdupq_n_f32(1.0);
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = vld1q_f32(in_ptr.add(index));
        let x1 = vld1q_f32(in_ptr.add(index + 4));
        let x2 = vld1q_f32(in_ptr.add(index + 8));
        let x3 = vld1q_f32(in_ptr.add(index + 12));
        let x4 = vld1q_f32(in_ptr.add(index + 16));
        let x5 = vld1q_f32(in_ptr.add(index + 20));
        let x6 = vld1q_f32(in_ptr.add(index + 24));
        let x7 = vld1q_f32(in_ptr.add(index + 28));

        let e0 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x0)));
        let e1 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x1)));
        let e2 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x2)));
        let e3 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x3)));
        let e4 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x4)));
        let e5 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x5)));
        let e6 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x6)));
        let e7 = fast_exp_sigmoid_neon(vnegq_f32(vmulq_f32(two, x7)));

        vst1q_f32(
            out_ptr.add(index),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e0)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 4),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e1)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 8),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e2)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 12),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e3)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 16),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e4)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 20),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e5)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 24),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e6)), one),
        );
        vst1q_f32(
            out_ptr.add(index + 28),
            vsubq_f32(vdivq_f32(two, vaddq_f32(one, e7)), one),
        );
        index += 32;
    }

    while index + 4 <= len {
        let x = vld1q_f32(in_ptr.add(index));
        let two_x = vmulq_f32(two, x);
        let neg_two_x = vnegq_f32(two_x);
        let exp_neg = fast_exp_sigmoid_neon(neg_two_x);
        let denom = vaddq_f32(one, exp_neg);
        let result = vsubq_f32(vdivq_f32(two, denom), one);
        vst1q_f32(out_ptr.add(index), result);
        index += 4;
    }

    while index < len {
        *out_ptr.add(index) = (*in_ptr.add(index)).tanh();
        index += 1;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn silu_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = vdupq_n_f32(1.0);
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = vld1q_f32(in_ptr.add(index));
        let x1 = vld1q_f32(in_ptr.add(index + 4));
        let x2 = vld1q_f32(in_ptr.add(index + 8));
        let x3 = vld1q_f32(in_ptr.add(index + 12));
        let x4 = vld1q_f32(in_ptr.add(index + 16));
        let x5 = vld1q_f32(in_ptr.add(index + 20));
        let x6 = vld1q_f32(in_ptr.add(index + 24));
        let x7 = vld1q_f32(in_ptr.add(index + 28));

        let e0 = fast_exp_sigmoid_neon(vnegq_f32(x0));
        let e1 = fast_exp_sigmoid_neon(vnegq_f32(x1));
        let e2 = fast_exp_sigmoid_neon(vnegq_f32(x2));
        let e3 = fast_exp_sigmoid_neon(vnegq_f32(x3));
        let e4 = fast_exp_sigmoid_neon(vnegq_f32(x4));
        let e5 = fast_exp_sigmoid_neon(vnegq_f32(x5));
        let e6 = fast_exp_sigmoid_neon(vnegq_f32(x6));
        let e7 = fast_exp_sigmoid_neon(vnegq_f32(x7));

        vst1q_f32(
            out_ptr.add(index),
            vmulq_f32(x0, vdivq_f32(one, vaddq_f32(one, e0))),
        );
        vst1q_f32(
            out_ptr.add(index + 4),
            vmulq_f32(x1, vdivq_f32(one, vaddq_f32(one, e1))),
        );
        vst1q_f32(
            out_ptr.add(index + 8),
            vmulq_f32(x2, vdivq_f32(one, vaddq_f32(one, e2))),
        );
        vst1q_f32(
            out_ptr.add(index + 12),
            vmulq_f32(x3, vdivq_f32(one, vaddq_f32(one, e3))),
        );
        vst1q_f32(
            out_ptr.add(index + 16),
            vmulq_f32(x4, vdivq_f32(one, vaddq_f32(one, e4))),
        );
        vst1q_f32(
            out_ptr.add(index + 20),
            vmulq_f32(x5, vdivq_f32(one, vaddq_f32(one, e5))),
        );
        vst1q_f32(
            out_ptr.add(index + 24),
            vmulq_f32(x6, vdivq_f32(one, vaddq_f32(one, e6))),
        );
        vst1q_f32(
            out_ptr.add(index + 28),
            vmulq_f32(x7, vdivq_f32(one, vaddq_f32(one, e7))),
        );
        index += 32;
    }

    while index + 4 <= len {
        let x = vld1q_f32(in_ptr.add(index));
        let e = fast_exp_sigmoid_neon(vnegq_f32(x));
        let sig = vdivq_f32(one, vaddq_f32(one, e));
        vst1q_f32(out_ptr.add(index), vmulq_f32(x, sig));
        index += 4;
    }

    while index < len {
        let x = *in_ptr.add(index);
        let s = 1.0 / (1.0 + (-x).exp());
        *out_ptr.add(index) = x * s;
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn silu_slice_sse(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm_set1_ps(1.0);
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 16 <= len {
        let x0 = _mm_loadu_ps(in_ptr.add(index));
        let x1 = _mm_loadu_ps(in_ptr.add(index + 4));
        let x2 = _mm_loadu_ps(in_ptr.add(index + 8));
        let x3 = _mm_loadu_ps(in_ptr.add(index + 12));

        let e0 = fast_exp_sse(_mm_sub_ps(zero, x0));
        let e1 = fast_exp_sse(_mm_sub_ps(zero, x1));
        let e2 = fast_exp_sse(_mm_sub_ps(zero, x2));
        let e3 = fast_exp_sse(_mm_sub_ps(zero, x3));

        _mm_storeu_ps(
            out_ptr.add(index),
            _mm_mul_ps(x0, _mm_div_ps(one, _mm_add_ps(one, e0))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 4),
            _mm_mul_ps(x1, _mm_div_ps(one, _mm_add_ps(one, e1))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 8),
            _mm_mul_ps(x2, _mm_div_ps(one, _mm_add_ps(one, e2))),
        );
        _mm_storeu_ps(
            out_ptr.add(index + 12),
            _mm_mul_ps(x3, _mm_div_ps(one, _mm_add_ps(one, e3))),
        );

        index += 16;
    }

    while index + 4 <= len {
        let x = _mm_loadu_ps(in_ptr.add(index));
        let e = fast_exp_sse(_mm_sub_ps(zero, x));
        let sig = _mm_div_ps(one, _mm_add_ps(one, e));
        _mm_storeu_ps(out_ptr.add(index), _mm_mul_ps(x, sig));
        index += 4;
    }

    while index < len {
        let v = *in_ptr.add(index);
        let s = 1.0 / (1.0 + (-v).exp());
        *out_ptr.add(index) = v * s;
        index += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn silu_slice_avx(input: &[f32], output: &mut [f32]) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::_mm256_div_ps;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::_mm256_div_ps;

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let one = _mm256_set1_ps(1.0);
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    while index + 32 <= len {
        let x0 = _mm256_loadu_ps(in_ptr.add(index));
        let x1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let x2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let x3 = _mm256_loadu_ps(in_ptr.add(index + 24));

        let e0 = fast_exp_avx(_mm256_sub_ps(zero, x0));
        let e1 = fast_exp_avx(_mm256_sub_ps(zero, x1));
        let e2 = fast_exp_avx(_mm256_sub_ps(zero, x2));
        let e3 = fast_exp_avx(_mm256_sub_ps(zero, x3));

        _mm256_storeu_ps(
            out_ptr.add(index),
            _mm256_mul_ps(x0, _mm256_div_ps(one, _mm256_add_ps(one, e0))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 8),
            _mm256_mul_ps(x1, _mm256_div_ps(one, _mm256_add_ps(one, e1))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 16),
            _mm256_mul_ps(x2, _mm256_div_ps(one, _mm256_add_ps(one, e2))),
        );
        _mm256_storeu_ps(
            out_ptr.add(index + 24),
            _mm256_mul_ps(x3, _mm256_div_ps(one, _mm256_add_ps(one, e3))),
        );

        index += 32;
    }

    while index + 8 <= len {
        let x = _mm256_loadu_ps(in_ptr.add(index));
        let e = fast_exp_avx(_mm256_sub_ps(zero, x));
        let sig = _mm256_div_ps(one, _mm256_add_ps(one, e));
        _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(x, sig));
        index += 8;
    }

    if index < len {
        silu_slice_sse(&input[index..], &mut output[index..]);
    }
}

2093#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2098#[allow(unsafe_code)]
2099#[allow(unsafe_op_in_unsafe_fn)]
2100#[target_feature(enable = "sse")]
2101unsafe fn max_reduce_sse(data: &[f32]) -> f32 {
2102 let len = data.len();
2103 let ptr = data.as_ptr();
2104 let mut index = 0usize;
2105 let mut acc = _mm_set1_ps(f32::NEG_INFINITY);
2106
2107 while index + 4 <= len {
2108 let v = _mm_loadu_ps(ptr.add(index));
2109 acc = _mm_max_ps(acc, v);
2110 index += 4;
2111 }
2112
2113 let mut buf = [0.0f32; 4];
2115 _mm_storeu_ps(buf.as_mut_ptr(), acc);
2116 let mut result = buf[0].max(buf[1]).max(buf[2]).max(buf[3]);
2117
2118 while index < len {
2119 result = result.max(*ptr.add(index));
2120 index += 1;
2121 }
2122 result
2123}
2124
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn max_reduce_avx(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);

    while index + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(index));
        acc = _mm256_max_ps(acc, v);
        index += 8;
    }

    // Horizontal reduce: spill the accumulator and fold the eight lanes.
    let mut buf = [0.0f32; 8];
    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0];
    for &v in &buf[1..] {
        result = result.max(v);
    }

    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn max_reduce_neon(data: &[f32]) -> f32 {
    use std::arch::aarch64::vmaxvq_f32;

    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = vdupq_n_f32(f32::NEG_INFINITY);

    while index + 4 <= len {
        let v = vld1q_f32(ptr.add(index));
        acc = vmaxq_f32(acc, v);
        index += 4;
    }

    let mut result = vmaxvq_f32(acc);
    while index < len {
        result = result.max(*ptr.add(index));
        index += 1;
    }
    result
}

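/// SSE sum-reduction: 4-lane vector adds, a horizontal lane sum, then a
/// scalar pass over the tail. Lane-wise accumulation reorders the additions,
/// so the result can differ from a strictly sequential sum by rounding.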
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn add_reduce_sse(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm_setzero_ps();

    while index + 4 <= len {
        let v = _mm_loadu_ps(ptr.add(index));
        acc = _mm_add_ps(acc, v);
        index += 4;
    }

    // Horizontal sum of the four lanes.
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0] + buf[1] + buf[2] + buf[3];

    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn add_reduce_avx(data: &[f32]) -> f32 {
    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = _mm256_setzero_ps();

    while index + 8 <= len {
        let v = _mm256_loadu_ps(ptr.add(index));
        acc = _mm256_add_ps(acc, v);
        index += 8;
    }

    let mut buf = [0.0f32; 8];
    _mm256_storeu_ps(buf.as_mut_ptr(), acc);
    let mut result = buf[0] + buf[1] + buf[2] + buf[3] + buf[4] + buf[5] + buf[6] + buf[7];

    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn add_reduce_neon(data: &[f32]) -> f32 {
    use std::arch::aarch64::vaddvq_f32;

    let len = data.len();
    let ptr = data.as_ptr();
    let mut index = 0usize;
    let mut acc = vdupq_n_f32(0.0);

    while index + 4 <= len {
        let v = vld1q_f32(ptr.add(index));
        acc = vaddq_f32(acc, v);
        index += 4;
    }

    let mut result = vaddvq_f32(acc);
    while index < len {
        result += *ptr.add(index);
        index += 1;
    }
    result
}

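/// SSE multiply-accumulate: `acc[i] += a[i] * b[i]` in 4-lane chunks. Note
/// that the SSE and AVX paths issue a separate multiply and add, while the
/// NEON path below uses `vfmaq_f32`, a true fused multiply-add with a single
/// rounding step, so results may differ by one ulp between paths.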
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn fma_slice_sse(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let av = _mm_loadu_ps(a_ptr.add(index));
        let bv = _mm_loadu_ps(b_ptr.add(index));
        let cv = _mm_loadu_ps(acc_ptr.add(index));
        let result = _mm_add_ps(cv, _mm_mul_ps(av, bv));
        _mm_storeu_ps(acc_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn fma_slice_avx(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 8 <= len {
        let av = _mm256_loadu_ps(a_ptr.add(index));
        let bv = _mm256_loadu_ps(b_ptr.add(index));
        let cv = _mm256_loadu_ps(acc_ptr.add(index));
        let result = _mm256_add_ps(cv, _mm256_mul_ps(av, bv));
        _mm256_storeu_ps(acc_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        fma_slice_sse(&a[index..], &b[index..], &mut acc[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, dead_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn fma_slice_neon(a: &[f32], b: &[f32], acc: &mut [f32]) {
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let acc_ptr = acc.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let av = vld1q_f32(a_ptr.add(index));
        let bv = vld1q_f32(b_ptr.add(index));
        let cv = vld1q_f32(acc_ptr.add(index));
        let result = vfmaq_f32(cv, av, bv);
        vst1q_f32(acc_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        fma_slice_scalar(&a[index..], &b[index..], &mut acc[index..]);
    }
}

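/// In-place SSE ReLU: clamps each lane to `max(x, 0.0)`, with a scalar tail.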
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn relu_slice_sse(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let input = _mm_loadu_ps(ptr.add(index));
        let out = _mm_max_ps(input, zero);
        _mm_storeu_ps(ptr.add(index), out);
        index += 4;
    }

    if index < len {
        relu_slice_scalar(&mut values[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn relu_slice_avx(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    // 4x-unrolled main loop: 32 floats per iteration.
    while index + 32 <= len {
        let v0 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero);
        let v1 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 8)), zero);
        let v2 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 16)), zero);
        let v3 = _mm256_max_ps(_mm256_loadu_ps(ptr.add(index + 24)), zero);
        _mm256_storeu_ps(ptr.add(index), v0);
        _mm256_storeu_ps(ptr.add(index + 8), v1);
        _mm256_storeu_ps(ptr.add(index + 16), v2);
        _mm256_storeu_ps(ptr.add(index + 24), v3);
        index += 32;
    }

    while index + 8 <= len {
        _mm256_storeu_ps(
            ptr.add(index),
            _mm256_max_ps(_mm256_loadu_ps(ptr.add(index)), zero),
        );
        index += 8;
    }

    if index < len {
        relu_slice_sse(&mut values[index..]);
    }
}

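/// SSE elementwise binary op over equal-length slices; the `match` selects
/// add/sub/mul per 4-lane block, and the scalar kernel finishes the tail.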
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn binary_same_shape_sse(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let left = _mm_loadu_ps(left_ptr.add(index));
        let right = _mm_loadu_ps(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => _mm_add_ps(left, right),
            BinaryKind::Sub => _mm_sub_ps(left, right),
            BinaryKind::Mul => _mm_mul_ps(left, right),
        };
        _mm_storeu_ps(out_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

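/// AVX elementwise binary op. The `match` on the operator is hoisted out of
/// the hot loop, so each arm runs a straight-line 32-float unrolled kernel
/// with a software prefetch (`_MM_HINT_T0`) of the next block.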
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn binary_same_shape_avx(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    // Dispatch on the operator once, not per block.
    match kind {
        BinaryKind::Add => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_add_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_add_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_add_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_add_ps(a3, b3));
                index += 32;
            }
        }
        BinaryKind::Sub => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_sub_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_sub_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_sub_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_sub_ps(a3, b3));
                index += 32;
            }
        }
        BinaryKind::Mul => {
            while index + 32 <= len {
                #[cfg(target_arch = "x86")]
                {
                    use std::arch::x86::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::_mm_prefetch;
                    _mm_prefetch::<3>(left_ptr.add(index + 32) as *const i8);
                    _mm_prefetch::<3>(right_ptr.add(index + 32) as *const i8);
                }
                let a0 = _mm256_loadu_ps(left_ptr.add(index));
                let b0 = _mm256_loadu_ps(right_ptr.add(index));
                let a1 = _mm256_loadu_ps(left_ptr.add(index + 8));
                let b1 = _mm256_loadu_ps(right_ptr.add(index + 8));
                _mm256_storeu_ps(out_ptr.add(index), _mm256_mul_ps(a0, b0));
                _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_mul_ps(a1, b1));
                let a2 = _mm256_loadu_ps(left_ptr.add(index + 16));
                let b2 = _mm256_loadu_ps(right_ptr.add(index + 16));
                let a3 = _mm256_loadu_ps(left_ptr.add(index + 24));
                let b3 = _mm256_loadu_ps(right_ptr.add(index + 24));
                _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_mul_ps(a2, b2));
                _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_mul_ps(a3, b3));
                index += 32;
            }
        }
    }

    // 8-wide remainder after the unrolled blocks.
    while index + 8 <= len {
        let left = _mm256_loadu_ps(left_ptr.add(index));
        let right = _mm256_loadu_ps(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => _mm256_add_ps(left, right),
            BinaryKind::Sub => _mm256_sub_ps(left, right),
            BinaryKind::Mul => _mm256_mul_ps(left, right),
        };
        _mm256_storeu_ps(out_ptr.add(index), result);
        index += 8;
    }

    if index < len {
        binary_same_shape_sse(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn relu_slice_neon(values: &mut [f32]) {
    let len = values.len();
    let ptr = values.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let mut index = 0usize;

    // 8x-unrolled main loop: 32 floats per iteration.
    while index + 32 <= len {
        let v0 = vmaxq_f32(vld1q_f32(ptr.add(index)), zero);
        let v1 = vmaxq_f32(vld1q_f32(ptr.add(index + 4)), zero);
        let v2 = vmaxq_f32(vld1q_f32(ptr.add(index + 8)), zero);
        let v3 = vmaxq_f32(vld1q_f32(ptr.add(index + 12)), zero);
        let v4 = vmaxq_f32(vld1q_f32(ptr.add(index + 16)), zero);
        let v5 = vmaxq_f32(vld1q_f32(ptr.add(index + 20)), zero);
        let v6 = vmaxq_f32(vld1q_f32(ptr.add(index + 24)), zero);
        let v7 = vmaxq_f32(vld1q_f32(ptr.add(index + 28)), zero);
        vst1q_f32(ptr.add(index), v0);
        vst1q_f32(ptr.add(index + 4), v1);
        vst1q_f32(ptr.add(index + 8), v2);
        vst1q_f32(ptr.add(index + 12), v3);
        vst1q_f32(ptr.add(index + 16), v4);
        vst1q_f32(ptr.add(index + 20), v5);
        vst1q_f32(ptr.add(index + 24), v6);
        vst1q_f32(ptr.add(index + 28), v7);
        index += 32;
    }

    while index + 4 <= len {
        vst1q_f32(ptr.add(index), vmaxq_f32(vld1q_f32(ptr.add(index)), zero));
        index += 4;
    }

    if index < len {
        relu_slice_scalar(&mut values[index..]);
    }
}

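/// Out-of-place SSE ReLU: writes `max(input[i], 0.0)` into `output`.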
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn relu_to_slice_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = _mm_setzero_ps();
    let mut index = 0usize;

    while index + 4 <= len {
        let v = _mm_loadu_ps(in_ptr.add(index));
        let r = _mm_max_ps(v, zero);
        _mm_storeu_ps(out_ptr.add(index), r);
        index += 4;
    }

    if index < len {
        relu_to_slice_scalar(&input[index..], &mut output[index..]);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn relu_to_slice_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = _mm256_setzero_ps();
    let mut index = 0usize;

    // 4x-unrolled main loop: 32 floats per iteration.
    while index + 32 <= len {
        let a0 = _mm256_loadu_ps(in_ptr.add(index));
        let a1 = _mm256_loadu_ps(in_ptr.add(index + 8));
        let a2 = _mm256_loadu_ps(in_ptr.add(index + 16));
        let a3 = _mm256_loadu_ps(in_ptr.add(index + 24));
        _mm256_storeu_ps(out_ptr.add(index), _mm256_max_ps(a0, zero));
        _mm256_storeu_ps(out_ptr.add(index + 8), _mm256_max_ps(a1, zero));
        _mm256_storeu_ps(out_ptr.add(index + 16), _mm256_max_ps(a2, zero));
        _mm256_storeu_ps(out_ptr.add(index + 24), _mm256_max_ps(a3, zero));
        index += 32;
    }

    while index + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(index),
            _mm256_max_ps(_mm256_loadu_ps(in_ptr.add(index)), zero),
        );
        index += 8;
    }

    if index < len {
        relu_to_slice_sse(&input[index..], &mut output[index..]);
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn relu_to_slice_neon(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let mut index = 0usize;

    // 32 floats per iteration, with loads and stores interleaved.
    while index + 32 <= len {
        let a0 = vld1q_f32(in_ptr.add(index));
        let a1 = vld1q_f32(in_ptr.add(index + 4));
        let a2 = vld1q_f32(in_ptr.add(index + 8));
        let a3 = vld1q_f32(in_ptr.add(index + 12));
        vst1q_f32(out_ptr.add(index), vmaxq_f32(a0, zero));
        vst1q_f32(out_ptr.add(index + 4), vmaxq_f32(a1, zero));
        let a4 = vld1q_f32(in_ptr.add(index + 16));
        let a5 = vld1q_f32(in_ptr.add(index + 20));
        vst1q_f32(out_ptr.add(index + 8), vmaxq_f32(a2, zero));
        vst1q_f32(out_ptr.add(index + 12), vmaxq_f32(a3, zero));
        let a6 = vld1q_f32(in_ptr.add(index + 24));
        let a7 = vld1q_f32(in_ptr.add(index + 28));
        vst1q_f32(out_ptr.add(index + 16), vmaxq_f32(a4, zero));
        vst1q_f32(out_ptr.add(index + 20), vmaxq_f32(a5, zero));
        vst1q_f32(out_ptr.add(index + 24), vmaxq_f32(a6, zero));
        vst1q_f32(out_ptr.add(index + 28), vmaxq_f32(a7, zero));
        index += 32;
    }

    while index + 4 <= len {
        vst1q_f32(
            out_ptr.add(index),
            vmaxq_f32(vld1q_f32(in_ptr.add(index)), zero),
        );
        index += 4;
    }

    if index < len {
        relu_to_slice_scalar(&input[index..], &mut output[index..]);
    }
}

/// NEON elementwise binary op. Compiled out on macOS, where the vDSP
/// bindings are expected to cover same-shape add/sub/mul instead.
#[cfg(all(target_arch = "aarch64", not(target_os = "macos")))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn binary_same_shape_neon(lhs: &[f32], rhs: &[f32], out: &mut [f32], kind: BinaryKind) {
    let len = lhs.len();
    let left_ptr = lhs.as_ptr();
    let right_ptr = rhs.as_ptr();
    let out_ptr = out.as_mut_ptr();
    let mut index = 0usize;

    while index + 4 <= len {
        let left = vld1q_f32(left_ptr.add(index));
        let right = vld1q_f32(right_ptr.add(index));
        let result = match kind {
            BinaryKind::Add => vaddq_f32(left, right),
            BinaryKind::Sub => vsubq_f32(left, right),
            BinaryKind::Mul => vmulq_f32(left, right),
        };
        vst1q_f32(out_ptr.add(index), result);
        index += 4;
    }

    if index < len {
        binary_same_shape_scalar(&lhs[index..], &rhs[index..], &mut out[index..], kind);
    }
}

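/// Accumulates one output row of a matrix multiply:
/// `out_row[j] += left_row[p] * right[p * n + j]` for `p in 0..k`, `j in 0..n`.
///
/// # Safety
/// `left_row` must be valid for reading `k` floats, `right` for `k * n`
/// floats, and `out_row` for reading and writing `n` floats; `out_row` must
/// not alias the inputs.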
#[inline]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn matmul_row_dispatch(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    if cfg!(miri) {
        matmul_row_scalar(left_row, right, out_row, k, n);
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            matmul_row_avx(left_row, right, out_row, k, n);
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            matmul_row_sse(left_row, right, out_row, k, n);
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            matmul_row_neon(left_row, right, out_row, k, n);
            return;
        }
    }

    matmul_row_scalar(left_row, right, out_row, k, n);
}

#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn matmul_row_scalar(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val = *left_row.add(p);
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            *out_row.add(col) += a_val * *b_row.add(col);
            *out_row.add(col + 1) += a_val * *b_row.add(col + 1);
            *out_row.add(col + 2) += a_val * *b_row.add(col + 2);
            *out_row.add(col + 3) += a_val * *b_row.add(col + 3);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += a_val * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn matmul_row_sse(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val = _mm_set1_ps(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            let b_vec = _mm_loadu_ps(b_row.add(col));
            let out_vec = _mm_loadu_ps(out_row.add(col));
            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val, b_vec));
            _mm_storeu_ps(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn matmul_row_avx(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val_avx = _mm256_set1_ps(*left_row.add(p));
        let a_val_sse = _mm_set1_ps(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 8 <= n {
            let b_vec = _mm256_loadu_ps(b_row.add(col));
            let out_vec = _mm256_loadu_ps(out_row.add(col));
            let result = _mm256_add_ps(out_vec, _mm256_mul_ps(a_val_avx, b_vec));
            _mm256_storeu_ps(out_row.add(col), result);
            col += 8;
        }
        // 4-wide SSE tail.
        while col + 4 <= n {
            let b_vec = _mm_loadu_ps(b_row.add(col));
            let out_vec = _mm_loadu_ps(out_row.add(col));
            let result = _mm_add_ps(out_vec, _mm_mul_ps(a_val_sse, b_vec));
            _mm_storeu_ps(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code)]
#[allow(unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn matmul_row_neon(
    left_row: *const f32,
    right: *const f32,
    out_row: *mut f32,
    k: usize,
    n: usize,
) {
    for p in 0..k {
        let a_val: float32x4_t = vdupq_n_f32(*left_row.add(p));
        let b_row = right.add(p * n);

        let mut col = 0usize;
        while col + 4 <= n {
            let b_vec = vld1q_f32(b_row.add(col));
            let out_vec = vld1q_f32(out_row.add(col));
            let result = vfmaq_f32(out_vec, a_val, b_vec);
            vst1q_f32(out_row.add(col), result);
            col += 4;
        }
        while col < n {
            *out_row.add(col) += *left_row.add(p) * *b_row.add(col);
            col += 1;
        }
    }
}

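/// Fused, numerically stable softmax over one row: a max pass, an
/// exponentiate-and-sum pass, and a normalization pass, each vectorized per
/// target, with a scalar fallback under Miri or when no SIMD is detected.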
#[allow(unsafe_code)]
#[inline]
pub fn softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) || input.is_empty() {
        softmax_row_fused_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                // SAFETY: NEON support was detected above.
                softmax_row_fused_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                // SAFETY: AVX support was detected above.
                softmax_row_fused_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                // SAFETY: SSE support was detected above.
                softmax_row_fused_sse(input, output);
            }
            return;
        }
    }

    softmax_row_fused_scalar(input, output);
}

fn softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // Pass 1: row maximum (for numerical stability).
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // Pass 2: exponentiate shifted values and accumulate the sum.
    let mut sum_exp = 0.0f32;
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        let e = (v - max_val).exp();
        *o = e;
        sum_exp += e;
    }

    // Pass 3: normalize by the sum.
    let inv = 1.0 / sum_exp;
    for o in output.iter_mut() {
        *o *= inv;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = vld1q_f32(in_ptr.add(i));
        acc_max = vmaxq_f32(acc_max, v);
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: exponentiate shifted values and accumulate the sum.
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let v0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let v1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let v2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let v3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        vst1q_f32(out_ptr.add(i), v0);
        vst1q_f32(out_ptr.add(i + 4), v1);
        vst1q_f32(out_ptr.add(i + 8), v2);
        vst1q_f32(out_ptr.add(i + 12), v3);
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(v0, v1), vaddq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        let v = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        vst1q_f32(out_ptr.add(i), v);
        acc_sum = vaddq_f32(acc_sum, v);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // Pass 3: normalize by the sum.
    let inv = vdupq_n_f32(1.0 / sum_exp);
    i = 0;
    while i + 16 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        vst1q_f32(
            out_ptr.add(i + 4),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 4)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 8)), inv),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vmulq_f32(vld1q_f32(out_ptr.add(i + 12)), inv),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(out_ptr.add(i), vmulq_f32(vld1q_f32(out_ptr.add(i)), inv));
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: exponentiate shifted values and accumulate the sum.
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        _mm_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm_add_ps(acc_sum, v);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // Pass 3: normalize by the sum.
    let inv = _mm_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: exponentiate shifted values and accumulate the sum.
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let v = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        _mm256_storeu_ps(out_ptr.add(i), v);
        acc_sum = _mm256_add_ps(acc_sum, v);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // 4-wide SSE tail of pass 2.
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let v = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        _mm_storeu_ps(out_ptr.add(i), v);
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), v);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        let e = (*in_ptr.add(i) - max_val).exp();
        *out_ptr.add(i) = e;
        sum_exp += e;
        i += 1;
    }

    // Pass 3: normalize by the sum.
    let inv8 = _mm256_set1_ps(1.0 / sum_exp);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_mul_ps(_mm256_loadu_ps(out_ptr.add(i)), inv8),
        );
        i += 8;
    }
    let inv4 = _mm_set1_ps(1.0 / sum_exp);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_mul_ps(_mm_loadu_ps(out_ptr.add(i)), inv4),
        );
        i += 4;
    }
    let inv_s = 1.0 / sum_exp;
    while i < len {
        *out_ptr.add(i) *= inv_s;
        i += 1;
    }
}

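/// Fused, numerically stable log-softmax over one row:
/// `out[i] = in[i] - (max + ln(sum_j exp(in[j] - max)))`.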
#[allow(unsafe_code)]
#[inline]
pub fn log_softmax_row_fused_dispatch(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());

    if cfg!(miri) || input.is_empty() {
        log_softmax_row_fused_scalar(input, output);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                // SAFETY: NEON support was detected above.
                log_softmax_row_fused_neon(input, output);
            }
            return;
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx") {
            unsafe {
                // SAFETY: AVX support was detected above.
                log_softmax_row_fused_avx(input, output);
            }
            return;
        }
        if std::is_x86_feature_detected!("sse") {
            unsafe {
                // SAFETY: SSE support was detected above.
                log_softmax_row_fused_sse(input, output);
            }
            return;
        }
    }

    log_softmax_row_fused_scalar(input, output);
}

fn log_softmax_row_fused_scalar(input: &[f32], output: &mut [f32]) {
    if input.is_empty() {
        return;
    }

    // Pass 1: row maximum (for numerical stability).
    let mut max_val = f32::NEG_INFINITY;
    for &v in input {
        max_val = max_val.max(v);
    }

    // Pass 2: sum of shifted exponentials.
    let mut sum_exp = 0.0f32;
    for &v in input {
        sum_exp += (v - max_val).exp();
    }

    // Pass 3: subtract the log-denominator.
    let log_denom = max_val + sum_exp.ln();
    for (o, &v) in output.iter_mut().zip(input.iter()) {
        *o = v - log_denom;
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn log_softmax_row_fused_neon(input: &[f32], output: &mut [f32]) {
    use std::arch::aarch64::{vaddvq_f32, vmaxvq_f32};

    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = vdupq_n_f32(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 16 <= len {
        let v0 = vld1q_f32(in_ptr.add(i));
        let v1 = vld1q_f32(in_ptr.add(i + 4));
        let v2 = vld1q_f32(in_ptr.add(i + 8));
        let v3 = vld1q_f32(in_ptr.add(i + 12));
        acc_max = vmaxq_f32(acc_max, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3)));
        i += 16;
    }
    while i + 4 <= len {
        acc_max = vmaxq_f32(acc_max, vld1q_f32(in_ptr.add(i)));
        i += 4;
    }
    let mut max_val = vmaxvq_f32(acc_max);
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: sum of shifted exponentials.
    let off = vdupq_n_f32(max_val);
    let mut acc_sum = vdupq_n_f32(0.0);
    i = 0;
    while i + 16 <= len {
        let e0 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        let e1 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), off));
        let e2 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), off));
        let e3 = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), off));
        acc_sum = vaddq_f32(acc_sum, vaddq_f32(vaddq_f32(e0, e1), vaddq_f32(e2, e3)));
        i += 16;
    }
    while i + 4 <= len {
        let e = fast_exp_neon(vsubq_f32(vld1q_f32(in_ptr.add(i)), off));
        acc_sum = vaddq_f32(acc_sum, e);
        i += 4;
    }
    let mut sum_exp = vaddvq_f32(acc_sum);
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // Pass 3: subtract the log-denominator.
    let log_denom = vdupq_n_f32(max_val + sum_exp.ln());
    i = 0;
    while i + 16 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 4),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 4)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 8),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 8)), log_denom),
        );
        vst1q_f32(
            out_ptr.add(i + 12),
            vsubq_f32(vld1q_f32(in_ptr.add(i + 12)), log_denom),
        );
        i += 16;
    }
    while i + 4 <= len {
        vst1q_f32(
            out_ptr.add(i),
            vsubq_f32(vld1q_f32(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "sse")]
unsafe fn log_softmax_row_fused_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 4 <= len {
        acc_max = _mm_max_ps(acc_max, _mm_loadu_ps(in_ptr.add(i)));
        i += 4;
    }
    let mut buf = [0.0f32; 4];
    _mm_storeu_ps(buf.as_mut_ptr(), acc_max);
    let mut max_val = buf[0].max(buf[1]).max(buf[2].max(buf[3]));
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: sum of shifted exponentials.
    let off = _mm_set1_ps(max_val);
    let mut acc_sum = _mm_setzero_ps();
    i = 0;
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm_add_ps(acc_sum, e);
        i += 4;
    }
    _mm_storeu_ps(buf.as_mut_ptr(), acc_sum);
    let mut sum_exp = buf[0] + buf[1] + buf[2] + buf[3];
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // Pass 3: subtract the log-denominator.
    let log_denom = _mm_set1_ps(max_val + sum_exp.ln());
    i = 0;
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom),
        );
        i += 4;
    }
    let log_denom_s = max_val + sum_exp.ln();
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_s;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx")]
unsafe fn log_softmax_row_fused_avx(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let in_ptr = input.as_ptr();
    let out_ptr = output.as_mut_ptr();

    // Pass 1: row maximum.
    let mut acc_max = _mm256_set1_ps(f32::NEG_INFINITY);
    let mut i = 0usize;
    while i + 8 <= len {
        acc_max = _mm256_max_ps(acc_max, _mm256_loadu_ps(in_ptr.add(i)));
        i += 8;
    }
    let mut buf8 = [0.0f32; 8];
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_max);
    let mut max_val = buf8[0];
    for &v in &buf8[1..] {
        max_val = max_val.max(v);
    }
    while i < len {
        max_val = max_val.max(*in_ptr.add(i));
        i += 1;
    }

    // Pass 2: sum of shifted exponentials.
    let off = _mm256_set1_ps(max_val);
    let mut acc_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let e = fast_exp_avx(_mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), off));
        acc_sum = _mm256_add_ps(acc_sum, e);
        i += 8;
    }
    _mm256_storeu_ps(buf8.as_mut_ptr(), acc_sum);
    let mut sum_exp: f32 = buf8.iter().sum();
    // 4-wide SSE tail of pass 2.
    let off4 = _mm_set1_ps(max_val);
    while i + 4 <= len {
        let e = fast_exp_sse(_mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), off4));
        let mut b4 = [0.0f32; 4];
        _mm_storeu_ps(b4.as_mut_ptr(), e);
        sum_exp += b4[0] + b4[1] + b4[2] + b4[3];
        i += 4;
    }
    while i < len {
        sum_exp += (*in_ptr.add(i) - max_val).exp();
        i += 1;
    }

    // Pass 3: subtract the log-denominator.
    let log_denom_val = max_val + sum_exp.ln();
    let log_denom8 = _mm256_set1_ps(log_denom_val);
    i = 0;
    while i + 8 <= len {
        _mm256_storeu_ps(
            out_ptr.add(i),
            _mm256_sub_ps(_mm256_loadu_ps(in_ptr.add(i)), log_denom8),
        );
        i += 8;
    }
    let log_denom4 = _mm_set1_ps(log_denom_val);
    while i + 4 <= len {
        _mm_storeu_ps(
            out_ptr.add(i),
            _mm_sub_ps(_mm_loadu_ps(in_ptr.add(i)), log_denom4),
        );
        i += 4;
    }
    while i < len {
        *out_ptr.add(i) = *in_ptr.add(i) - log_denom_val;
        i += 1;
    }
}

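// Every SIMD dispatch below is validated against its scalar reference path.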
#[cfg(test)]
mod tests {
    use super::*;

    fn assert_close(a: &[f32], b: &[f32], tol: f32) {
        assert_eq!(a.len(), b.len(), "length mismatch");
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            let d = (x - y).abs();
            assert!(d <= tol, "index {i}: {x} vs {y}, diff={d}, tolerance={tol}");
        }
    }

    #[test]
    fn exp_matches_scalar() {
        let input: Vec<f32> = (-20..=20).map(|i| i as f32 * 0.5).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        exp_slice_dispatch(&input, &mut simd_out);
        exp_slice_scalar(&input, &mut scalar_out);

        // Compare with relative error: exp spans several orders of magnitude here.
        for (i, (&s, &r)) in simd_out.iter().zip(scalar_out.iter()).enumerate() {
            let rel = if r.abs() > 1e-10 {
                (s - r).abs() / r.abs()
            } else {
                (s - r).abs()
            };
            assert!(
                rel < 1e-5,
                "exp mismatch at index {i}: simd={s}, scalar={r}, rel_err={rel}"
            );
        }
    }

    #[test]
    fn sigmoid_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        sigmoid_slice_dispatch(&input, &mut simd_out);
        sigmoid_slice_dispatch_scalar(&input, &mut scalar_out);

        // Loose tolerance: the SIMD path relies on the fast-exp approximation.
        assert_close(&simd_out, &scalar_out, 0.035);
    }

    #[test]
    fn tanh_dispatch_matches_scalar() {
        let input: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.3).collect();
        let mut simd_out = vec![0.0f32; input.len()];
        let mut scalar_out = vec![0.0f32; input.len()];

        tanh_slice_dispatch(&input, &mut simd_out);
        tanh_slice_dispatch_scalar(&input, &mut scalar_out);

        // Loose tolerance: the SIMD path relies on the fast-exp approximation.
        assert_close(&simd_out, &scalar_out, 2e-3);
    }

    #[test]
    fn max_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| (i as f32 * 0.7 - 12.0).sin()).collect();
        let simd_result = max_reduce_dispatch(&data);
        let scalar_result = max_reduce_scalar(&data);
        assert!((simd_result - scalar_result).abs() < 1e-6);
    }

    #[test]
    fn max_reduce_empty() {
        assert_eq!(max_reduce_dispatch(&[]), f32::NEG_INFINITY);
    }

    #[test]
    fn add_reduce_matches_scalar() {
        let data: Vec<f32> = (0..37).map(|i| i as f32 * 0.1).collect();
        let simd_result = add_reduce_dispatch(&data);
        let scalar_result = add_reduce_scalar(&data);
        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "simd={simd_result}, scalar={scalar_result}"
        );
    }

    #[test]
    fn add_reduce_empty() {
        assert_eq!(add_reduce_dispatch(&[]), 0.0);
    }

    #[test]
    #[allow(unsafe_code)]
    fn fma_matches_scalar() {
        let a: Vec<f32> = (0..33).map(|i| i as f32 * 0.3).collect();
        let b: Vec<f32> = (0..33).map(|i| (i as f32 * 0.7).sin()).collect();
        let mut simd_acc = vec![1.0f32; 33];
        let mut scalar_acc = vec![1.0f32; 33];

        fma_slice_dispatch(&a, &b, &mut simd_acc);
        unsafe { fma_slice_scalar(&a, &b, &mut scalar_acc) };

        assert_close(&simd_acc, &scalar_acc, 1e-5);
    }

    #[test]
    fn sigmoid_dispatch_boundary_values() {
        // Saturation behavior at extreme inputs.
        let input = vec![-100.0, -10.0, 0.0, 10.0, 100.0];
        let mut output = vec![0.0f32; 5];
        sigmoid_slice_dispatch(&input, &mut output);

        assert!(
            output[0] < 0.01,
            "sigmoid(-100) should be near 0: {}",
            output[0]
        );
        assert!(
            (output[2] - 0.5).abs() < 0.01,
            "sigmoid(0) should be near 0.5: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "sigmoid(100) should be near 1: {}",
            output[4]
        );
    }

    #[test]
    fn tanh_dispatch_boundary_values() {
        let input = vec![-100.0, -1.0, 0.0, 1.0, 100.0];
        let mut output = vec![0.0f32; 5];
        tanh_slice_dispatch(&input, &mut output);

        assert!(
            output[0] < -0.99,
            "tanh(-100) should be near -1: {}",
            output[0]
        );
        assert!(
            (output[2]).abs() < 0.01,
            "tanh(0) should be near 0: {}",
            output[2]
        );
        assert!(
            output[4] > 0.99,
            "tanh(100) should be near 1: {}",
            output[4]
        );
    }
}