1use crate::pool;
23
/// Approximates `exp(x)` on four `f32` lanes with NEON.
///
/// The input is clamped to `[-87.3, 88.7]`, split as `x = n * ln(2) + r` with `n` rounded to
/// the nearest integer (`ln(2)` is applied as a high part plus a low correction), `exp(r)` is
/// evaluated with a degree-6 polynomial in Horner form, and the result is scaled by `2^n` by
/// adding `n << 23` to the bit pattern of the polynomial value.
///
/// # Safety
///
/// The body only calls NEON intrinsics on register values and dereferences no pointers; it is
/// an `unsafe fn` because those intrinsics are invoked without an inner `unsafe` block.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn neon_exp4(x: std::arch::aarch64::float32x4_t) -> std::arch::aarch64::float32x4_t {
    use std::arch::aarch64::*;

    const LN2_HI: f32 = 0.693_145_75;
    const LN2_LO: f32 = 1.428_606_8e-6;
    // Polynomial coefficients after the leading 1/720 term, highest power first.
    const TAIL_COEFFS: [f32; 6] = [0.008_333_334, 0.041_666_668, 0.166_666_67, 0.5, 1.0, 1.0];

    // Clamp (lower bound first, then upper bound).
    let clamped = vminq_f32(vmaxq_f32(x, vdupq_n_f32(-87.3)), vdupq_n_f32(88.7));

    // n = round(x / ln 2), r = x - n * ln2_hi - n * ln2_lo.
    let exponent = vrndnq_f32(vmulq_f32(clamped, vdupq_n_f32(std::f32::consts::LOG2_E)));
    let reduced = vfmsq_f32(
        vfmsq_f32(clamped, exponent, vdupq_n_f32(LN2_HI)),
        exponent,
        vdupq_n_f32(LN2_LO),
    );

    // Horner evaluation: poly = coeff + poly * r for each remaining coefficient.
    let mut poly = vdupq_n_f32(0.001_388_888_9);
    for &coeff in TAIL_COEFFS.iter() {
        poly = vfmaq_f32(vdupq_n_f32(coeff), poly, reduced);
    }

    // Multiply by 2^n by bumping the exponent field of the float bit pattern.
    let scale_bits = vshlq_n_s32(vcvtq_s32_f32(exponent), 23);
    vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(poly), scale_bits))
}
51
/// Approximates `exp(x)` on eight `f32` lanes with AVX2 + FMA.
///
/// The input is clamped to `[-87.3, 88.7]`, split as `x = n * ln(2) + r` with `n` rounded to
/// the nearest integer (`ln(2)` is applied as a high part plus a low correction), `exp(r)` is
/// evaluated with a degree-6 polynomial in Horner form, and the result is scaled by `2^n` by
/// adding `n << 23` to the bit pattern of the polynomial value.
///
/// # Safety
///
/// The body only calls AVX2/FMA intrinsics on register values and dereferences no pointers;
/// it is an `unsafe fn` because those intrinsics are invoked without an inner `unsafe` block.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn avx2_exp8(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;

    const ROUND_NEAREST: i32 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC;
    // Polynomial coefficients after the leading 1/720 term, highest power first.
    const TAIL_COEFFS: [f32; 6] = [
        0.008333333333333333,
        0.041666666666666664,
        0.16666666666666666,
        0.5,
        1.0,
        1.0,
    ];

    // Clamp (lower bound first, then upper bound).
    let clamped = _mm256_min_ps(
        _mm256_max_ps(x, _mm256_set1_ps(-87.3)),
        _mm256_set1_ps(88.7),
    );

    // n = round(x / ln 2), r = x - n * ln2_hi - n * ln2_lo.
    let exponent = _mm256_round_ps::<ROUND_NEAREST>(_mm256_mul_ps(
        clamped,
        _mm256_set1_ps(1.442695040888963),
    ));
    let reduced = _mm256_fnmadd_ps(
        exponent,
        _mm256_set1_ps(1.428606765330187e-6),
        _mm256_fnmadd_ps(exponent, _mm256_set1_ps(0.693145751953125), clamped),
    );

    // Horner evaluation: poly = poly * r + coeff for each remaining coefficient.
    let mut poly = _mm256_set1_ps(0.001388888888888889);
    for &coeff in TAIL_COEFFS.iter() {
        poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(coeff));
    }

    // Multiply by 2^n by bumping the exponent field of the float bit pattern.
    let scale_bits = _mm256_slli_epi32::<23>(_mm256_cvtps_epi32(exponent));
    _mm256_castsi256_ps(_mm256_add_epi32(_mm256_castps_si256(poly), scale_bits))
}
88
/// Adds `bias` to every row of the row-major `m x n` matrix in `data` and applies GELU
/// (`x * 0.5 * (1 + erf(x / sqrt(2)))`) in place, four lanes at a time with NEON.
///
/// Columns that do not fill a whole 4-lane vector are handled by `scalar_gelu`.
///
/// # Panics
///
/// Panics if `m * n` overflows `usize`, if `data` holds fewer than `m * n` elements, or if
/// the matrix is non-empty and `bias` holds fewer than `n` elements. Without these checks the
/// vector loads and stores would read and write out of bounds.
#[cfg(target_arch = "aarch64")]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::aarch64::*;

    let total = m
        .checked_mul(n)
        .expect("bias_gelu: m * n overflows usize");
    assert!(
        data.len() >= total,
        "bias_gelu: data has {} elements, need m * n = {}",
        data.len(),
        total
    );
    if total == 0 {
        return;
    }
    assert!(
        bias.len() >= n,
        "bias_gelu: bias has {} elements, need n = {}",
        bias.len(),
        n
    );

    // Split each row into a part made of whole 4-lane vectors and a scalar tail.
    let vec_len = n - n % 4;
    let (bias_body, bias_tail) = bias[..n].split_at(vec_len);

    // SAFETY: every vector load/store goes through a `chunks_exact(4)` /
    // `chunks_exact_mut(4)` chunk, which is exactly four contiguous, in-bounds `f32`s.
    unsafe {
        let half = vdupq_n_f32(0.5);
        let one = vdupq_n_f32(1.0);
        let inv_sqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        // erf approximation constants (same values as `scalar_erf`).
        let p = vdupq_n_f32(0.3275911);
        let a1 = vdupq_n_f32(0.254_829_6);
        let a2 = vdupq_n_f32(-0.284_496_72);
        let a3 = vdupq_n_f32(1.421_413_8);
        let a4 = vdupq_n_f32(-1.453_152_1);
        let a5 = vdupq_n_f32(1.061_405_4);
        let neg_one = vdupq_n_f32(-1.0);
        let zero = vdupq_n_f32(0.0);

        for row in data[..total].chunks_exact_mut(n) {
            let (body, tail) = row.split_at_mut(vec_len);
            for (lanes, b) in body.chunks_exact_mut(4).zip(bias_body.chunks_exact(4)) {
                let ptr = lanes.as_mut_ptr();
                let x = vaddq_f32(vld1q_f32(ptr), vld1q_f32(b.as_ptr()));
                let erf_arg = vmulq_f32(x, inv_sqrt2);
                let xa = vabsq_f32(erf_arg);
                // erf is odd: evaluate on |x| and restore the sign afterwards.
                let sign = vbslq_f32(vcgeq_f32(erf_arg, zero), one, neg_one);
                let denom = vfmaq_f32(one, p, xa);
                let t = vdivq_f32(one, denom);
                // Horner evaluation of the degree-5 polynomial in t.
                let mut y = a5;
                y = vfmaq_f32(a4, y, t);
                y = vfmaq_f32(a3, y, t);
                y = vfmaq_f32(a2, y, t);
                y = vfmaq_f32(a1, y, t);
                y = vmulq_f32(y, t);
                let exp_val = neon_exp4(vnegq_f32(vmulq_f32(xa, xa)));
                // erf = sign * (1 - y * exp(-x^2))
                let erf_val = vmulq_f32(sign, vfmsq_f32(one, y, exp_val));
                vst1q_f32(ptr, vmulq_f32(x, vmulq_f32(half, vaddq_f32(one, erf_val))));
            }
            for (v, &b) in tail.iter_mut().zip(bias_tail) {
                *v = scalar_gelu(*v + b);
            }
        }
    }
}
138
/// Adds `bias` to every row of the row-major `m x n` matrix in `data` and applies GELU
/// (`x * 0.5 * (1 + erf(x / sqrt(2)))`) in place, eight lanes at a time with AVX2 + FMA.
///
/// Columns that do not fill a whole 8-lane vector are handled by `scalar_gelu`.
///
/// # Panics
///
/// Panics if `m * n` overflows `usize`, if `data` holds fewer than `m * n` elements, or if
/// the matrix is non-empty and `bias` holds fewer than `n` elements. Without these checks the
/// vector loads and stores would read and write out of bounds.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::x86_64::*;

    let total = m
        .checked_mul(n)
        .expect("bias_gelu: m * n overflows usize");
    assert!(
        data.len() >= total,
        "bias_gelu: data has {} elements, need m * n = {}",
        data.len(),
        total
    );
    if total == 0 {
        return;
    }
    assert!(
        bias.len() >= n,
        "bias_gelu: bias has {} elements, need n = {}",
        bias.len(),
        n
    );

    // Split each row into a part made of whole 8-lane vectors and a scalar tail.
    let vec_len = n - n % 8;
    let (bias_body, bias_tail) = bias[..n].split_at(vec_len);

    // SAFETY: every vector load/store goes through a `chunks_exact(8)` /
    // `chunks_exact_mut(8)` chunk, which is exactly eight contiguous, in-bounds `f32`s.
    unsafe {
        let half = _mm256_set1_ps(0.5);
        let one = _mm256_set1_ps(1.0);
        let inv_sqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        // erf approximation constants.
        let p = _mm256_set1_ps(0.3275911);
        let a1 = _mm256_set1_ps(0.254829592);
        let a2 = _mm256_set1_ps(-0.284496736);
        let a3 = _mm256_set1_ps(1.421413741);
        let a4 = _mm256_set1_ps(-1.453152027);
        let a5 = _mm256_set1_ps(1.061405429);
        let neg_one = _mm256_set1_ps(-1.0);
        let zero = _mm256_set1_ps(0.0);
        // Clearing the top bit of each lane yields the absolute value.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));

        for row in data[..total].chunks_exact_mut(n) {
            let (body, tail) = row.split_at_mut(vec_len);
            for (lanes, b) in body.chunks_exact_mut(8).zip(bias_body.chunks_exact(8)) {
                let ptr = lanes.as_mut_ptr();
                let x = _mm256_add_ps(_mm256_loadu_ps(ptr), _mm256_loadu_ps(b.as_ptr()));
                let erf_arg = _mm256_mul_ps(x, inv_sqrt2);
                let xa = _mm256_and_ps(erf_arg, abs_mask);
                // erf is odd: evaluate on |x| and restore the sign afterwards.
                let ge0 = _mm256_cmp_ps::<_CMP_GE_OQ>(erf_arg, zero);
                let sign = _mm256_blendv_ps(neg_one, one, ge0);
                let denom = _mm256_fmadd_ps(p, xa, one);
                let t = _mm256_div_ps(one, denom);
                // Horner evaluation of the degree-5 polynomial in t.
                let mut y = a5;
                y = _mm256_fmadd_ps(y, t, a4);
                y = _mm256_fmadd_ps(y, t, a3);
                y = _mm256_fmadd_ps(y, t, a2);
                y = _mm256_fmadd_ps(y, t, a1);
                y = _mm256_mul_ps(y, t);
                let exp_val = avx2_exp8(_mm256_sub_ps(zero, _mm256_mul_ps(xa, xa)));
                // erf = sign * (1 - y * exp(-x^2))
                let erf_val = _mm256_mul_ps(sign, _mm256_fnmadd_ps(y, exp_val, one));
                _mm256_storeu_ps(
                    ptr,
                    _mm256_mul_ps(x, _mm256_mul_ps(half, _mm256_add_ps(one, erf_val))),
                );
            }
            for (v, &b) in tail.iter_mut().zip(bias_tail) {
                *v = scalar_gelu(*v + b);
            }
        }
    }
}
199
200#[cfg(not(any(
201 target_arch = "aarch64",
202 all(
203 target_arch = "x86_64",
204 target_feature = "avx2",
205 target_feature = "fma"
206 )
207)))]
208pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
209 for row in 0..m {
210 let base = row * n;
211 for i in 0..n {
212 let x = data[base + i] + bias[i];
213 data[base + i] = scalar_gelu(x);
214 }
215 }
216}
217
218pub fn par_bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
220 let cfg = crate::config::RuntimeConfig::global();
221 if m * n < cfg.par_threshold || m < cfg.min_rows_per_thread {
222 bias_gelu(data, bias, m, n);
223 return;
224 }
225 let data_ptr = data.as_mut_ptr() as usize;
226 let bias_ptr = bias.as_ptr() as usize;
227 pool::par_for(m, cfg.min_rows_per_thread, &|off, cnt| unsafe {
228 let d = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(off * n), cnt * n);
229 let b = std::slice::from_raw_parts(bias_ptr as *const f32, n);
230 bias_gelu(d, b, cnt, n);
231 });
232}
233
/// Applies SiLU (`x * sigmoid(x)`) to `data` in place using NEON.
///
/// Whole groups of four elements use `neon_exp4` for the exponential; the remaining
/// elements use the scalar `f32::exp`.
#[cfg(target_arch = "aarch64")]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;

    let mut quads = data.chunks_exact_mut(4);
    // SAFETY: each `quad` is exactly four contiguous `f32`s, matching the 4-lane load/store.
    unsafe {
        let ones = vdupq_n_f32(1.0);
        for quad in &mut quads {
            let lane_ptr = quad.as_mut_ptr();
            let x = vld1q_f32(lane_ptr);
            let denom = vaddq_f32(ones, neon_exp4(vnegq_f32(x)));
            let sigmoid = vdivq_f32(ones, denom);
            vst1q_f32(lane_ptr, vmulq_f32(x, sigmoid));
        }
    }
    for slot in quads.into_remainder() {
        let x = *slot;
        *slot = x / (1.0 + (-x).exp());
    }
}
256
/// Applies SiLU (`x / (1 + exp(-x))`) to `data` in place using AVX2 + FMA.
///
/// Whole groups of eight elements use `avx2_exp8` for the exponential; the remaining
/// elements use the scalar `f32::exp`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;

    let mut octets = data.chunks_exact_mut(8);
    // SAFETY: each `octet` is exactly eight contiguous `f32`s, matching the unaligned
    // 8-lane load/store.
    unsafe {
        let ones = _mm256_set1_ps(1.0);
        let zeros = _mm256_set1_ps(0.0);
        for octet in &mut octets {
            let lane_ptr = octet.as_mut_ptr();
            let x = _mm256_loadu_ps(lane_ptr);
            // Negate as 0 - x.
            let negated = _mm256_sub_ps(zeros, x);
            let denom = _mm256_add_ps(ones, avx2_exp8(negated));
            _mm256_storeu_ps(lane_ptr, _mm256_div_ps(x, denom));
        }
    }
    for slot in octets.into_remainder() {
        let x = *slot;
        *slot = x / (1.0 + (-x).exp());
    }
}
283
/// Scalar fallback: replaces every element `x` of `data` with `x / (1 + exp(-x))` (SiLU).
///
/// Compiled only when neither the aarch64 nor the x86_64 AVX2+FMA variant applies.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn silu_inplace(data: &mut [f32]) {
    data.iter_mut().for_each(|slot| {
        let x = *slot;
        *slot = x / (1.0 + (-x).exp());
    });
}
298
/// Layer-normalizes one row of `h` elements with NEON:
/// `output[i] = (input[i] - mean) / sqrt(var + eps) * gamma[i] + beta[i]`.
///
/// Mean and variance are computed in one pass from the sum and the sum of squares.
/// Because `E[x^2] - mean^2` can round to a small negative number (for example on
/// near-constant rows with a large mean), a negative variance is clamped to zero before the
/// square root; NaN is left untouched so non-finite inputs still propagate.
///
/// # Panics
///
/// Panics if `input`, `gamma`, `beta` or `output` holds fewer than `h` elements. Without
/// this check the vector loads and stores would access memory out of bounds.
#[cfg(target_arch = "aarch64")]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::aarch64::*;

    assert!(
        input.len() >= h && gamma.len() >= h && beta.len() >= h && output.len() >= h,
        "layer_norm_row: every slice must hold at least h = {} elements",
        h
    );

    let inv_hf = 1.0 / h as f32;
    // Number of leading elements covered by whole 4-lane vectors.
    let vec_len = h - h % 4;

    // SAFETY: every vector access covers `off..off + 4` with `off + 4 <= vec_len <= h`, and
    // all four slices were just checked to hold at least `h` elements.
    unsafe {
        let mut vsum = vdupq_n_f32(0.0);
        let mut vsumsq = vdupq_n_f32(0.0);
        for off in (0..vec_len).step_by(4) {
            let x = vld1q_f32(input.as_ptr().add(off));
            vsum = vaddq_f32(vsum, x);
            vsumsq = vfmaq_f32(vsumsq, x, x);
        }
        let mut sum = vaddvq_f32(vsum);
        let mut sumsq = vaddvq_f32(vsumsq);
        for i in vec_len..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }

        let mean = sum * inv_hf;
        let raw_var = sumsq * inv_hf - mean * mean;
        // Clamp rounding-induced negative variance; `NaN < 0.0` is false, so NaN passes through.
        let var = if raw_var < 0.0 { 0.0 } else { raw_var };
        let inv = 1.0 / (var + eps).sqrt();

        let vmean = vdupq_n_f32(mean);
        let vinv = vdupq_n_f32(inv);
        for off in (0..vec_len).step_by(4) {
            let x = vld1q_f32(input.as_ptr().add(off));
            let norm = vmulq_f32(vsubq_f32(x, vmean), vinv);
            // beta + norm * gamma
            vst1q_f32(
                output.as_mut_ptr().add(off),
                vfmaq_f32(
                    vld1q_f32(beta.as_ptr().add(off)),
                    norm,
                    vld1q_f32(gamma.as_ptr().add(off)),
                ),
            );
        }
        for i in vec_len..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
352
/// Layer-normalizes one row of `h` elements with AVX2 + FMA:
/// `output[i] = (input[i] - mean) / sqrt(var + eps) * gamma[i] + beta[i]`.
///
/// Mean and variance are computed in one pass from the sum and the sum of squares.
/// Because `E[x^2] - mean^2` can round to a small negative number (for example on
/// near-constant rows with a large mean), a negative variance is clamped to zero before the
/// square root; NaN is left untouched so non-finite inputs still propagate.
///
/// # Panics
///
/// Panics if `input`, `gamma`, `beta` or `output` holds fewer than `h` elements. Without
/// this check the vector loads and stores would access memory out of bounds.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::x86_64::*;

    assert!(
        input.len() >= h && gamma.len() >= h && beta.len() >= h && output.len() >= h,
        "layer_norm_row: every slice must hold at least h = {} elements",
        h
    );

    let inv_hf = 1.0 / h as f32;
    // Number of leading elements covered by whole 8-lane vectors.
    let vec_len = h - h % 8;

    // SAFETY: every vector access covers `off..off + 8` with `off + 8 <= vec_len <= h`, and
    // all four slices were just checked to hold at least `h` elements.
    unsafe {
        // Horizontal sum of the eight lanes: 256 -> 128 -> 64 -> 32 bits.
        let hsum = |v: __m256| -> f32 {
            let lo = _mm256_castps256_ps128(v);
            let hi = _mm256_extractf128_ps::<1>(v);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };

        let mut vsum = _mm256_setzero_ps();
        let mut vsumsq = _mm256_setzero_ps();
        for off in (0..vec_len).step_by(8) {
            let x = _mm256_loadu_ps(input.as_ptr().add(off));
            vsum = _mm256_add_ps(vsum, x);
            vsumsq = _mm256_fmadd_ps(x, x, vsumsq);
        }
        let mut sum = hsum(vsum);
        let mut sumsq = hsum(vsumsq);
        for i in vec_len..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }

        let mean = sum * inv_hf;
        let raw_var = sumsq * inv_hf - mean * mean;
        // Clamp rounding-induced negative variance; `NaN < 0.0` is false, so NaN passes through.
        let var = if raw_var < 0.0 { 0.0 } else { raw_var };
        let inv = 1.0 / (var + eps).sqrt();

        let vmean = _mm256_set1_ps(mean);
        let vinv = _mm256_set1_ps(inv);
        for off in (0..vec_len).step_by(8) {
            let x = _mm256_loadu_ps(input.as_ptr().add(off));
            let norm = _mm256_mul_ps(_mm256_sub_ps(x, vmean), vinv);
            let g = _mm256_loadu_ps(gamma.as_ptr().add(off));
            let b = _mm256_loadu_ps(beta.as_ptr().add(off));
            // norm * gamma + beta
            _mm256_storeu_ps(output.as_mut_ptr().add(off), _mm256_fmadd_ps(norm, g, b));
        }
        for i in vec_len..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
418
/// Scalar fallback: layer-normalizes one row of `h` elements,
/// `output[i] = (input[i] - mean) / sqrt(var + eps) * gamma[i] + beta[i]`.
///
/// Mean and variance are computed in one pass from the sum and the sum of squares.
/// Because `E[x^2] - mean^2` can round to a small negative number (for example on
/// near-constant rows with a large mean), a negative variance is clamped to zero before the
/// square root; NaN is left untouched so non-finite inputs still propagate.
///
/// Compiled only when neither the aarch64 nor the x86_64 AVX2+FMA variant applies.
///
/// # Panics
///
/// Panics if `input`, `gamma`, `beta` or `output` holds fewer than `h` elements.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    let input = &input[..h];
    let gamma = &gamma[..h];
    let beta = &beta[..h];
    let output = &mut output[..h];

    let inv_hf = 1.0 / h as f32;
    let (sum, sumsq) = input
        .iter()
        .fold((0f32, 0f32), |(s, q), &x| (s + x, q + x * x));

    let mean = sum * inv_hf;
    let raw_var = sumsq * inv_hf - mean * mean;
    // Clamp rounding-induced negative variance; `NaN < 0.0` is false, so NaN passes through.
    let var = if raw_var < 0.0 { 0.0 } else { raw_var };
    let inv = 1.0 / (var + eps).sqrt();

    for (((out, &x), &g), &b) in output.iter_mut().zip(input).zip(gamma).zip(beta) {
        *out = (x - mean) * inv * g + b;
    }
}
449
450pub fn residual_bias_layer_norm(
453 a: &[f32],
454 b: &[f32],
455 bias: &[f32],
456 gamma: &[f32],
457 beta: &[f32],
458 output: &mut [f32],
459 n: usize,
460 h: usize,
461 eps: f32,
462) {
463 let mut tmp = vec![0f32; h];
465 for row in 0..n {
466 let base = row * h;
467 for i in 0..h {
468 tmp[i] = a[base + i] + b[base + i] + bias[i];
469 }
470 layer_norm_row(&tmp, gamma, beta, &mut output[base..base + h], h, eps);
471 }
472}
473
/// For each of the `n` rows of width `h`, forms `v = a + b + bias` and writes
/// `v / sqrt(mean(v^2) + eps) * gamma + beta` into the matching row of `output`.
///
/// `a`, `b` and `output` are indexed as row-major `n x h`; `bias`, `gamma` and `beta` are
/// indexed per column. The sum `v` is recomputed in the second pass rather than stored.
pub fn residual_bias_rms_norm(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    n: usize,
    h: usize,
    eps: f32,
) {
    let inv_h = 1.0 / h as f32;
    for start in (0..n).map(|row| row * h) {
        // Residual sum for column `col` of the current row.
        let summed = |col: usize| a[start + col] + b[start + col] + bias[col];

        let sumsq = (0..h).fold(0f32, |acc, col| {
            let v = summed(col);
            acc + v * v
        });
        let inv_rms = (sumsq * inv_h + eps).sqrt().recip();

        for col in 0..h {
            output[start + col] = summed(col) * inv_rms * gamma[col] + beta[col];
        }
    }
}
502
503pub fn par_residual_bias_ln(
505 a: &[f32],
506 b: &[f32],
507 bias: &[f32],
508 gamma: &[f32],
509 beta: &[f32],
510 output: &mut [f32],
511 n: usize,
512 h: usize,
513 eps: f32,
514) {
515 let cfg = crate::config::RuntimeConfig::global();
516 if n * h < cfg.par_threshold || n < cfg.min_rows_per_thread {
517 residual_bias_layer_norm(a, b, bias, gamma, beta, output, n, h, eps);
518 return;
519 }
520 let a_ptr = a.as_ptr() as usize;
521 let b_ptr = b.as_ptr() as usize;
522 let o_ptr = output.as_mut_ptr() as usize;
523 let bias_ptr = bias.as_ptr() as usize;
524 let gamma_ptr = gamma.as_ptr() as usize;
525 let beta_ptr = beta.as_ptr() as usize;
526 pool::par_for(n, cfg.min_rows_per_thread, &|off, cnt| unsafe {
527 let a_s = std::slice::from_raw_parts((a_ptr as *const f32).add(off * h), cnt * h);
528 let b_s = std::slice::from_raw_parts((b_ptr as *const f32).add(off * h), cnt * h);
529 let o_s = std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
530 let bi = std::slice::from_raw_parts(bias_ptr as *const f32, h);
531 let g = std::slice::from_raw_parts(gamma_ptr as *const f32, h);
532 let be = std::slice::from_raw_parts(beta_ptr as *const f32, h);
533 residual_bias_layer_norm(a_s, b_s, bi, g, be, o_s, cnt, h, eps);
534 });
535}
536
/// Row-wise softmax over a row-major `rows x cols` matrix, in place, using NEON.
///
/// Each row is processed in three passes: find the row maximum, replace every element with
/// `exp(x - max)` while accumulating the sum, then scale by `1 / sum`. Whole groups of four
/// columns use `neon_exp4`; the remaining columns use the scalar `f32::exp`.
///
/// # Panics
///
/// Panics if `rows * cols` overflows `usize` or if `data` holds fewer than `rows * cols`
/// elements. Without this check the vector loads and stores would access memory out of
/// bounds.
#[cfg(target_arch = "aarch64")]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::aarch64::*;

    let total = rows
        .checked_mul(cols)
        .expect("neon_softmax: rows * cols overflows usize");
    assert!(
        data.len() >= total,
        "neon_softmax: data has {} elements, need rows * cols = {}",
        data.len(),
        total
    );
    if total == 0 {
        return;
    }

    // Number of leading columns covered by whole 4-lane vectors.
    let vec_len = cols - cols % 4;

    // SAFETY: every vector load/store goes through a `chunks_exact(4)` /
    // `chunks_exact_mut(4)` chunk, which is exactly four contiguous, in-bounds `f32`s.
    unsafe {
        for row in data[..total].chunks_exact_mut(cols) {
            let (body, tail) = row.split_at_mut(vec_len);

            // Pass 1: row maximum.
            let mut vmax = vdupq_n_f32(f32::NEG_INFINITY);
            for quad in body.chunks_exact(4) {
                vmax = vmaxq_f32(vmax, vld1q_f32(quad.as_ptr()));
            }
            let mut max_val = vmaxvq_f32(vmax);
            for &v in tail.iter() {
                max_val = max_val.max(v);
            }

            // Pass 2: exponentiate in place and accumulate the sum.
            let vmx = vdupq_n_f32(max_val);
            let mut vsum = vdupq_n_f32(0.0);
            for quad in body.chunks_exact_mut(4) {
                let e = neon_exp4(vsubq_f32(vld1q_f32(quad.as_ptr()), vmx));
                vst1q_f32(quad.as_mut_ptr(), e);
                vsum = vaddq_f32(vsum, e);
            }
            let mut sum = vaddvq_f32(vsum);
            for v in tail.iter_mut() {
                let e = (*v - max_val).exp();
                *v = e;
                sum += e;
            }

            // Pass 3: normalize.
            let inv = 1.0 / sum;
            let vinv = vdupq_n_f32(inv);
            for quad in body.chunks_exact_mut(4) {
                let scaled = vmulq_f32(vld1q_f32(quad.as_ptr()), vinv);
                vst1q_f32(quad.as_mut_ptr(), scaled);
            }
            for v in tail.iter_mut() {
                *v *= inv;
            }
        }
    }
}
588
/// Row-wise softmax over a row-major `rows x cols` matrix, in place, using AVX2 + FMA.
///
/// Despite the name this is the x86_64 implementation of the same entry point as the NEON
/// variant. Each row is processed in three passes: find the row maximum, replace every
/// element with `exp(x - max)` while accumulating the sum, then scale by `1 / sum`. Whole
/// groups of eight columns use `avx2_exp8`; the remaining columns use the scalar `f32::exp`.
///
/// # Panics
///
/// Panics if `rows * cols` overflows `usize` or if `data` holds fewer than `rows * cols`
/// elements. Without this check the vector loads and stores would access memory out of
/// bounds.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::x86_64::*;

    let total = rows
        .checked_mul(cols)
        .expect("neon_softmax: rows * cols overflows usize");
    assert!(
        data.len() >= total,
        "neon_softmax: data has {} elements, need rows * cols = {}",
        data.len(),
        total
    );
    if total == 0 {
        return;
    }

    // Number of leading columns covered by whole 8-lane vectors.
    let vec_len = cols - cols % 8;

    // SAFETY: every vector load/store goes through a `chunks_exact(8)` /
    // `chunks_exact_mut(8)` chunk, which is exactly eight contiguous, in-bounds `f32`s.
    unsafe {
        // Horizontal maximum of the eight lanes.
        let hmax = |v: __m256| -> f32 {
            let lo = _mm256_castps256_ps128(v);
            let hi = _mm256_extractf128_ps::<1>(v);
            let s4 = _mm_max_ps(lo, hi);
            let s2 = _mm_max_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_max_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };
        // Horizontal sum of the eight lanes.
        let hsum = |v: __m256| -> f32 {
            let lo = _mm256_castps256_ps128(v);
            let hi = _mm256_extractf128_ps::<1>(v);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };

        for row in data[..total].chunks_exact_mut(cols) {
            let (body, tail) = row.split_at_mut(vec_len);

            // Pass 1: row maximum.
            let mut vmax = _mm256_set1_ps(f32::NEG_INFINITY);
            for octet in body.chunks_exact(8) {
                vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(octet.as_ptr()));
            }
            let mut max_v = hmax(vmax);
            for &v in tail.iter() {
                if v > max_v {
                    max_v = v;
                }
            }

            // Pass 2: exponentiate in place and accumulate the sum.
            let vmax = _mm256_set1_ps(max_v);
            let mut vsum = _mm256_setzero_ps();
            for octet in body.chunks_exact_mut(8) {
                let e = avx2_exp8(_mm256_sub_ps(_mm256_loadu_ps(octet.as_ptr()), vmax));
                _mm256_storeu_ps(octet.as_mut_ptr(), e);
                vsum = _mm256_add_ps(vsum, e);
            }
            let mut sum_v = hsum(vsum);
            for v in tail.iter_mut() {
                let e = (*v - max_v).exp();
                *v = e;
                sum_v += e;
            }

            // Pass 3: normalize.
            let inv_sum = 1.0 / sum_v;
            let vinv = _mm256_set1_ps(inv_sum);
            for octet in body.chunks_exact_mut(8) {
                let scaled = _mm256_mul_ps(_mm256_loadu_ps(octet.as_ptr()), vinv);
                _mm256_storeu_ps(octet.as_mut_ptr(), scaled);
            }
            for v in tail.iter_mut() {
                *v *= inv_sum;
            }
        }
    }
}
657
/// Fallback softmax entry point for targets with neither the aarch64 nor the x86_64
/// AVX2+FMA implementation: forwards `data`, `rows` and `cols` unchanged to
/// `crate::naive::softmax`.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    crate::naive::softmax(data, rows, cols);
}
669
/// Applies GELU (`x * 0.5 * (1 + erf(x / sqrt(2)))`) to `data` in place using NEON.
///
/// Whole groups of four elements use a vectorized erf approximation built on `neon_exp4`;
/// the remaining elements go through `scalar_gelu`.
#[cfg(target_arch = "aarch64")]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;

    let mut quads = data.chunks_exact_mut(4);
    // SAFETY: each `quad` is exactly four contiguous `f32`s, matching the 4-lane load/store.
    unsafe {
        let v_half = vdupq_n_f32(0.5);
        let v_one = vdupq_n_f32(1.0);
        let v_neg_one = vdupq_n_f32(-1.0);
        let v_zero = vdupq_n_f32(0.0);
        let v_inv_sqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        let v_p = vdupq_n_f32(0.3275911);
        let v_a5 = vdupq_n_f32(1.061_405_4);
        // Remaining polynomial coefficients, highest power first (a4, a3, a2, a1).
        let lower_coeffs = [
            vdupq_n_f32(-1.453_152_1),
            vdupq_n_f32(1.421_413_8),
            vdupq_n_f32(-0.284_496_72),
            vdupq_n_f32(0.254_829_6),
        ];

        for quad in &mut quads {
            let lane_ptr = quad.as_mut_ptr();
            let x = vld1q_f32(lane_ptr);
            let scaled = vmulq_f32(x, v_inv_sqrt2);
            let magnitude = vabsq_f32(scaled);
            // erf is odd: evaluate on |x| and restore the sign afterwards.
            let sign = vbslq_f32(vcgeq_f32(scaled, v_zero), v_one, v_neg_one);
            let t = vdivq_f32(v_one, vfmaq_f32(v_one, v_p, magnitude));

            // Horner evaluation: poly = coeff + poly * t, then one final multiply by t.
            let mut poly = v_a5;
            for &coeff in lower_coeffs.iter() {
                poly = vfmaq_f32(coeff, poly, t);
            }
            poly = vmulq_f32(poly, t);

            let decay = neon_exp4(vnegq_f32(vmulq_f32(magnitude, magnitude)));
            let erf = vmulq_f32(sign, vfmsq_f32(v_one, poly, decay));
            let weight = vmulq_f32(v_half, vaddq_f32(v_one, erf));
            vst1q_f32(lane_ptr, vmulq_f32(x, weight));
        }
    }
    for slot in quads.into_remainder() {
        *slot = scalar_gelu(*slot);
    }
}
714
/// Applies GELU (`x * 0.5 * (1 + erf(x / sqrt(2)))`) to `data` in place using AVX2 + FMA.
///
/// Whole groups of eight elements use a vectorized erf approximation built on `avx2_exp8`;
/// the remaining elements go through `scalar_gelu`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;

    let mut octets = data.chunks_exact_mut(8);
    // SAFETY: each `octet` is exactly eight contiguous `f32`s, matching the unaligned
    // 8-lane load/store.
    unsafe {
        let v_half = _mm256_set1_ps(0.5);
        let v_one = _mm256_set1_ps(1.0);
        let v_neg_one = _mm256_set1_ps(-1.0);
        let v_zero = _mm256_set1_ps(0.0);
        let v_inv_sqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        let v_p = _mm256_set1_ps(0.3275911);
        let v_a5 = _mm256_set1_ps(1.061405429);
        // Remaining polynomial coefficients, highest power first (a4, a3, a2, a1).
        let lower_coeffs = [
            _mm256_set1_ps(-1.453152027),
            _mm256_set1_ps(1.421413741),
            _mm256_set1_ps(-0.284496736),
            _mm256_set1_ps(0.254829592),
        ];
        // Clearing the top bit of each lane yields the absolute value.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));

        for octet in &mut octets {
            let lane_ptr = octet.as_mut_ptr();
            let x = _mm256_loadu_ps(lane_ptr);
            let scaled = _mm256_mul_ps(x, v_inv_sqrt2);
            let magnitude = _mm256_and_ps(scaled, abs_mask);
            // erf is odd: evaluate on |x| and restore the sign afterwards.
            let non_negative = _mm256_cmp_ps::<_CMP_GE_OQ>(scaled, v_zero);
            let sign = _mm256_blendv_ps(v_neg_one, v_one, non_negative);
            let t = _mm256_div_ps(v_one, _mm256_fmadd_ps(v_p, magnitude, v_one));

            // Horner evaluation: poly = poly * t + coeff, then one final multiply by t.
            let mut poly = v_a5;
            for &coeff in lower_coeffs.iter() {
                poly = _mm256_fmadd_ps(poly, t, coeff);
            }
            poly = _mm256_mul_ps(poly, t);

            let decay = avx2_exp8(_mm256_sub_ps(v_zero, _mm256_mul_ps(magnitude, magnitude)));
            let erf = _mm256_mul_ps(sign, _mm256_fnmadd_ps(poly, decay, v_one));
            let weight = _mm256_mul_ps(v_half, _mm256_add_ps(v_one, erf));
            _mm256_storeu_ps(lane_ptr, _mm256_mul_ps(x, weight));
        }
    }
    for slot in octets.into_remainder() {
        *slot = scalar_gelu(*slot);
    }
}
764
765#[cfg(not(any(
766 target_arch = "aarch64",
767 all(
768 target_arch = "x86_64",
769 target_feature = "avx2",
770 target_feature = "fma"
771 )
772)))]
773pub fn gelu_inplace(data: &mut [f32]) {
774 for v in data.iter_mut() {
775 *v = scalar_gelu(*v);
776 }
777}
778
/// Minimum slice length (`1 << 20` = 1,048,576 elements) at which the `par_*_inplace`
/// activation functions in this file use the thread pool; shorter slices are processed
/// sequentially on the calling thread.
const ACTIVATION_PAR_MIN: usize = 1 << 20;
790
/// Tanh-based GELU approximation for a single value:
/// `0.5 * x * (1 + tanh(C * (x + A * x^3)))` with `C = 0.797_884_6` and `A = 0.044_715`.
#[inline]
pub fn scalar_gelu_approx(x: f32) -> f32 {
    const C: f32 = 0.797_884_6;
    const A: f32 = 0.044_715;

    let cubic_term = A * x * x * x;
    let gate = (C * (x + cubic_term)).tanh();
    0.5 * x * (1.0 + gate)
}
805
806pub fn gelu_approx_inplace(data: &mut [f32]) {
807 for v in data.iter_mut() {
808 *v = scalar_gelu_approx(*v);
809 }
810}
811
812pub fn par_gelu_approx_inplace(data: &mut [f32]) {
813 let len = data.len();
814 if len < ACTIVATION_PAR_MIN {
815 gelu_approx_inplace(data);
816 return;
817 }
818 let cfg = crate::config::RuntimeConfig::global();
819 let chunk = 512;
820 let rows = len / chunk;
821 if rows < 2 {
822 gelu_approx_inplace(data);
823 return;
824 }
825 let data_ptr = data.as_mut_ptr() as usize;
826 pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
827 let start = off * chunk;
828 let end = if off + cnt >= rows {
829 len
830 } else {
831 (off + cnt) * chunk
832 };
833 let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
834 gelu_approx_inplace(s);
835 });
836 let done = rows * chunk;
837 if done < len {
838 gelu_approx_inplace(&mut data[done..]);
839 }
840}
841
842pub fn par_gelu_inplace(data: &mut [f32]) {
843 let len = data.len();
844 if len < ACTIVATION_PAR_MIN {
845 gelu_inplace(data);
846 return;
847 }
848 let cfg = crate::config::RuntimeConfig::global();
849 let chunk = 512;
850 let rows = len / chunk;
851 if rows < 2 {
852 gelu_inplace(data);
853 return;
854 }
855 let data_ptr = data.as_mut_ptr() as usize;
856 pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
857 let start = off * chunk;
858 let end = if off + cnt >= rows {
859 len
860 } else {
861 (off + cnt) * chunk
862 };
863 let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
864 gelu_inplace(s);
865 });
866 let done = rows * chunk;
867 if done < len {
868 gelu_inplace(&mut data[done..]);
869 }
870}
871
872pub fn par_silu_inplace(data: &mut [f32]) {
874 let len = data.len();
875 if len < ACTIVATION_PAR_MIN {
876 silu_inplace(data);
877 return;
878 }
879 let cfg = crate::config::RuntimeConfig::global();
880 let chunk = 512;
881 let rows = len / chunk;
882 if rows < 2 {
883 silu_inplace(data);
884 return;
885 }
886 let data_ptr = data.as_mut_ptr() as usize;
887 pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
888 let start = off * chunk;
889 let end = if off + cnt >= rows {
890 len
891 } else {
892 (off + cnt) * chunk
893 };
894 let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
895 silu_inplace(s);
896 });
897 let done = rows * chunk;
898 if done < len {
899 silu_inplace(&mut data[done..]);
900 }
901}
902
/// Computes `c = a * b` for row-major `a` (`m x k`), `b` (`k x n`) and `c` (`m x n`) with
/// NEON, overwriting `c`.
///
/// Rows are processed in blocks of up to eight, each block keeping one 4-lane accumulator
/// per row for a group of four output columns. Columns that do not fill a whole 4-lane
/// vector are computed with a scalar dot product. Earlier the accumulator array was indexed
/// directly by the row number, which panicked for `m > 8` whenever `n >= 4`; any `m` is now
/// accepted and results for `m <= 8` are unchanged.
///
/// # Panics
///
/// For a non-empty output, panics if `m * k`, `k * n` or `m * n` overflows `usize`, or if
/// `a`, `b` or `c` holds fewer elements than its shape requires. Without these checks the
/// vector loads and stores would access memory out of bounds.
#[cfg(target_arch = "aarch64")]
pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use std::arch::aarch64::*;

    /// Maximum number of rows accumulated together.
    const ROW_BLOCK: usize = 8;

    if m == 0 || n == 0 {
        return;
    }
    let a_need = m.checked_mul(k).expect("neon_sgemm_small: m * k overflows usize");
    let b_need = k.checked_mul(n).expect("neon_sgemm_small: k * n overflows usize");
    let c_need = m.checked_mul(n).expect("neon_sgemm_small: m * n overflows usize");
    assert!(a.len() >= a_need, "neon_sgemm_small: a has {} elements, need {}", a.len(), a_need);
    assert!(b.len() >= b_need, "neon_sgemm_small: b has {} elements, need {}", b.len(), b_need);
    assert!(c.len() >= c_need, "neon_sgemm_small: c has {} elements, need {}", c.len(), c_need);

    let n4 = n / 4;

    // SAFETY: with `i < m`, `kk < k` and `j + 4 <= n`, the offsets `i * k + kk`,
    // `kk * n + j .. + 4` and `i * n + j .. + 4` are below `m * k`, `k * n` and `m * n`
    // respectively, and the slices were checked above to be at least that long.
    unsafe {
        for row0 in (0..m).step_by(ROW_BLOCK) {
            let block_rows = ROW_BLOCK.min(m - row0);
            for j4 in 0..n4 {
                let j = j4 * 4;
                let mut acc = [vdupq_n_f32(0.0); ROW_BLOCK];
                for kk in 0..k {
                    let bv = vld1q_f32(b.as_ptr().add(kk * n + j));
                    for (r, acc_r) in acc.iter_mut().enumerate().take(block_rows) {
                        let av = vdupq_n_f32(*a.as_ptr().add((row0 + r) * k + kk));
                        *acc_r = vfmaq_f32(*acc_r, av, bv);
                    }
                }
                for (r, acc_r) in acc.iter().enumerate().take(block_rows) {
                    vst1q_f32(c.as_mut_ptr().add((row0 + r) * n + j), *acc_r);
                }
            }
        }
    }

    // Scalar path for the trailing `n % 4` columns.
    for j in (n4 * 4)..n {
        for i in 0..m {
            let mut sum = 0f32;
            for kk in 0..k {
                sum += a[i * k + kk] * b[kk * n + j];
            }
            c[i * n + j] = sum;
        }
    }
}
941
/// Non-aarch64 fallback for the small matrix multiply: forwards all arguments unchanged to
/// `crate::naive::matmul`.
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    crate::naive::matmul(a, b, c, m, k, n);
}
946
/// Small matrix multiply followed by a bias add (aarch64): runs `neon_sgemm_small` to fill
/// `c` from `a` and `b`, then calls `crate::blas::bias_add(c, bias, m, n)`.
///
/// NOTE(review): the non-aarch64 variant calls `crate::naive::bias_add` instead of
/// `crate::blas::bias_add` — presumably the two are equivalent; confirm.
#[cfg(target_arch = "aarch64")]
pub fn neon_sgemm_bias_small(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
) {
    neon_sgemm_small(a, b, c, m, k, n);
    crate::blas::bias_add(c, bias, m, n);
}
961
/// Small matrix multiply followed by a bias add (non-aarch64 fallback): calls
/// `crate::naive::matmul` to fill `c` from `a` and `b`, then
/// `crate::naive::bias_add(c, bias, m, n)`.
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_sgemm_bias_small(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
) {
    crate::naive::matmul(a, b, c, m, k, n);
    crate::naive::bias_add(c, bias, m, n);
}
975
976fn scalar_gelu(x: f32) -> f32 {
979 x * 0.5 * (1.0 + scalar_erf(x * std::f32::consts::FRAC_1_SQRT_2))
980}
981
/// Polynomial approximation of the error function for a single value.
///
/// Evaluates `1 - poly(t) * exp(-x^2)` on `|x|` with `t = 1 / (1 + P * |x|)` and restores
/// the sign afterwards (erf is odd). The constants appear to be those of the
/// Abramowitz–Stegun formula 7.1.26.
fn scalar_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    // a1 ..= a5
    const A: [f32; 5] = [
        0.254_829_6,
        -0.284_496_72,
        1.421_413_8,
        -1.453_152_1,
        1.061_405_4,
    ];

    let sign = if x >= 0.0 { 1.0f32 } else { -1.0 };
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Horner evaluation from a5 down to a1, followed by the final multiply by t.
    let nested = A[..4].iter().rev().fold(A[4], |acc, &coeff| coeff + t * acc);
    let poly = t * nested;

    sign * (1.0 - poly * (-magnitude * magnitude).exp())
}
991
/// Normalizes an NCHW tensor across the channel axis, independently for every batch entry
/// and spatial position.
///
/// For each `(b, y, x)` the mean and (population) variance over the `channels` values are
/// computed, and each value is written to `output` as
/// `(v - mean) / sqrt(var + eps) * gamma[c] + beta[c]`.
pub fn layer_norm2d_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    eps: f32,
) {
    let spatial = h * w;
    // Flat NCHW offset of channel `c` at spatial position `pos` in batch entry `b`.
    let at = |b: usize, c: usize, pos: usize| ((b * channels + c) * spatial) + pos;

    for b in 0..batch {
        for pos in 0..spatial {
            let total = (0..channels).fold(0.0f32, |acc, c| acc + input[at(b, c, pos)]);
            let mean = total / channels as f32;

            let sq_dev = (0..channels).fold(0.0f32, |acc, c| {
                let d = input[at(b, c, pos)] - mean;
                acc + d * d
            });
            let var = sq_dev / channels as f32;
            let inv = 1.0 / (var + eps).sqrt();

            for c in 0..channels {
                let idx = at(b, c, pos);
                let normed = (input[idx] - mean) * inv;
                output[idx] = normed * gamma[c] + beta[c];
            }
        }
    }
}
1027
/// Direct (scatter-style) 2-D transposed convolution on NCHW tensors.
///
/// `output` is zeroed first. Every non-zero input element at `(iy, ix)` is multiplied by
/// each kernel tap and accumulated at output position
/// `(iy * sh + ky * dh - ph, ix * sw + kx * dw - pw)` when that position lies inside
/// `h_out x w_out`.
///
/// * `input` is indexed as `[n, c_in, h, w]`, `output` as `[n, c_out, h_out, w_out]`.
/// * `weight` is indexed as `[c_in, c_out / groups, kh, kw]`.
/// * `sh`/`sw` are strides, `ph`/`pw` paddings, `dh`/`dw` dilations.
/// * Input channel `ic` belongs to group `ic / (c_in / groups)` and only contributes to
///   that group's `c_out / groups` output channels.
///
/// # Panics
///
/// Panics on division by zero if `groups` is zero, or if `c_in < groups` while `n` and
/// `c_in` are non-zero, and on any out-of-bounds index into `input`, `weight` or `output`.
pub fn conv_transpose2d_nchw(
    input: &[f32],
    weight: &[f32],
    output: &mut [f32],
    n: usize,
    c_in: usize,
    h: usize,
    w: usize,
    c_out: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw: usize,
    groups: usize,
) {
    output.fill(0.0);
    let c_in_per_g = c_in / groups;
    let c_out_per_g = c_out / groups;

    for ni in 0..n {
        for ic in 0..c_in {
            let group = ic / c_in_per_g;
            // Per-channel bases hoisted out of the spatial and kernel loops.
            let in_row_base = (ni * c_in + ic) * h;
            let weight_base = ic * c_out_per_g;
            let out_ch_base = ni * c_out + group * c_out_per_g;

            for iy in 0..h {
                for ix in 0..w {
                    let v = input[(in_row_base + iy) * w + ix];
                    // Zero inputs contribute nothing; skip the kernel loops entirely.
                    if v == 0.0 {
                        continue;
                    }
                    for ky in 0..kh {
                        // Output row after removing padding; skip taps that fall outside.
                        let oy = match (iy * sh + ky * dh).checked_sub(ph) {
                            Some(oy) if oy < h_out => oy,
                            _ => continue,
                        };
                        for kx in 0..kw {
                            let ox = match (ix * sw + kx * dw).checked_sub(pw) {
                                Some(ox) if ox < w_out => ox,
                                _ => continue,
                            };
                            for oc_off in 0..c_out_per_g {
                                let wt = weight[((weight_base + oc_off) * kh + ky) * kw + kx];
                                let out_idx = ((out_ch_base + oc_off) * h_out + oy) * w_out + ox;
                                output[out_idx] += v * wt;
                            }
                        }
                    }
                }
            }
        }
    }
}
1095
/// Group normalization on an NCHW tensor.
///
/// Channels are split into `num_groups` consecutive groups of `channels / num_groups`
/// channels. For each batch entry and group, the mean and (population) variance over all
/// channels and spatial positions of the group are computed in two passes, and every value
/// is written to `output` as `(v - mean) / sqrt(var + eps) * gamma[c] + beta[c]`.
pub fn group_norm_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    num_groups: usize,
    eps: f32,
) {
    let cpg = channels / num_groups;
    let spatial = h * w;
    let count = (cpg * spatial) as f32;
    // Index range of the `spatial` elements of channel `c` in batch entry `b`.
    let plane = |b: usize, c: usize| {
        let start = (b * channels + c) * spatial;
        start..start + spatial
    };

    for b in 0..batch {
        for g in 0..num_groups {
            let first = g * cpg;
            let members = first..first + cpg;

            // Pass 1: mean, accumulated one channel plane at a time.
            let mut mean = 0.0f32;
            for c in members.clone() {
                mean += input[plane(b, c)].iter().sum::<f32>();
            }
            mean /= count;

            // Pass 2: variance as the mean squared deviation.
            let mut var = 0.0f32;
            for c in members.clone() {
                for &v in &input[plane(b, c)] {
                    let d = v - mean;
                    var += d * d;
                }
            }
            var /= count;
            let inv = 1.0 / (var + eps).sqrt();

            // Pass 3: normalize and apply the per-channel affine parameters.
            for c in members {
                let scale = gamma[c];
                let shift = beta[c];
                let src = &input[plane(b, c)];
                let dst = &mut output[plane(b, c)];
                for (out, &v) in dst.iter_mut().zip(src) {
                    *out = (v - mean) * inv * scale + shift;
                }
            }
        }
    }
}
1148
/// Upsamples each `h x w` channel plane of `input` to `2h x 2w` in `output` by copying every
/// source pixel into the corresponding 2x2 block (nearest-neighbour, factor 2).
///
/// `input` is indexed as `channels` consecutive `h * w` planes and `output` as `channels`
/// consecutive `(2h) * (2w)` planes.
pub fn resize_nearest_2x_nchw(
    input: &[f32],
    output: &mut [f32],
    channels: usize,
    h: usize,
    w: usize,
) {
    let out_h = h * 2;
    let out_w = w * 2;
    for c in 0..channels {
        let src = &input[c * h * w..(c + 1) * h * w];
        let dst = &mut output[c * out_h * out_w..(c + 1) * out_h * out_w];
        for y in 0..h {
            for x in 0..w {
                let pixel = src[y * w + x];
                // Top-left corner of the 2x2 destination block.
                let top = (y * 2) * out_w + (x * 2);
                let bottom = (y * 2 + 1) * out_w + (x * 2);
                dst[top] = pixel;
                dst[top + 1] = pixel;
                dst[bottom] = pixel;
                dst[bottom + 1] = pixel;
            }
        }
    }
}
1174
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute element-wise difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max)
    }

    /// `scalar_gelu` agrees with the reference value GELU(1.5) ~= 1.3998.
    #[test]
    fn gelu_correctness() {
        let x = 1.5f32;
        let g = scalar_gelu(x);
        assert!((g - 1.3990).abs() < 0.01, "gelu(1.5) = {g}");
    }

    /// `bias_gelu` adds the per-column bias and matches `scalar_gelu` element by element.
    #[test]
    fn bias_gelu_works() {
        let (rows, cols) = (2usize, 4usize);
        let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let bias = vec![0.1f32, 0.2, 0.3, 0.4];
        let mut data = input.clone();
        bias_gelu(&mut data, &bias, rows, cols);
        for (i, &v) in data.iter().enumerate() {
            assert!(v > 0.0, "bias_gelu produced {v}");
            // Vectorized paths use an approximate exp, so compare with a tolerance.
            let want = scalar_gelu(input[i] + bias[i % cols]);
            assert!(
                (v - want).abs() < 1e-4,
                "bias_gelu[{i}] = {v}, expected about {want}"
            );
        }
    }

    /// Normalizing `[1, 2, 3, 4]` with unit gamma and zero beta gives a zero-sum row with
    /// end points near +/-1.342.
    #[test]
    fn layer_norm_unit_test() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0; 4];
        let beta = vec![0.0; 4];
        let mut output = vec![0.0; 4];
        layer_norm_row(&input, &gamma, &beta, &mut output, 4, 1e-5);
        assert!((output[0] - -1.342).abs() < 0.01);
        assert!((output[3] - 1.342).abs() < 0.01);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() < 0.01, "LN sum should be ~0, got {sum}");
    }

    /// `par_bias_gelu` produces the same result as the sequential `bias_gelu`.
    #[test]
    fn par_bias_gelu_matches_sequential() {
        // Named after their role; `bias_gelu` takes (rows, cols) as its (m, n) arguments.
        let rows = 100;
        let cols = 64;
        let mut data_par = vec![0.5f32; rows * cols];
        let mut data_seq = data_par.clone();
        let bias = vec![0.1f32; cols];

        bias_gelu(&mut data_seq, &bias, rows, cols);
        par_bias_gelu(&mut data_par, &bias, rows, cols);

        let max_diff = max_abs_diff(&data_par, &data_seq);
        assert!(max_diff < 1e-6, "par vs seq diff: {max_diff}");
    }
}