1use crate::pool;
23
/// Vectorised natural exponential of four `f32` lanes (NEON).
///
/// Algorithm:
/// 1. Clamp `x` to `[-87.3, 88.7]`. With these bounds the final exponent adjustment keeps the
///    result a finite, normal `f32`; inputs outside the range saturate instead of giving 0/inf.
/// 2. Range-reduce: `n = round(x * log2(e))`, `r = x - n * ln2`. `ln2` is split into a high and
///    a low part so the subtraction loses less precision.
/// 3. Approximate `exp(r)` with the degree-6 Taylor polynomial
///    `1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120 + r^6/720` in Horner form (fused multiply-adds).
/// 4. Scale by `2^n` by adding `n << 23` to the IEEE-754 bit pattern, i.e. to the exponent field.
///
/// # Safety
/// The body only operates on registers and performs no memory access. The caller must ensure
/// NEON is available; it is in the default feature set of Rust's standard aarch64 targets.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
// Lets the body call NEON intrinsics without wrapping each call in its own `unsafe` block.
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn neon_exp4(x: std::arch::aarch64::float32x4_t) -> std::arch::aarch64::float32x4_t {
    use std::arch::aarch64::*;
    let x = vmaxq_f32(x, vdupq_n_f32(-87.3));
    let x = vminq_f32(x, vdupq_n_f32(88.7));
    let inv_ln2 = vdupq_n_f32(std::f32::consts::LOG2_E);
    // ln2 split as hi + lo (Cody-Waite style) for an accurate reduced argument.
    let ln2_hi = vdupq_n_f32(0.693_145_75);
    let ln2_lo = vdupq_n_f32(1.428_606_8e-6);
    let n = vrndnq_f32(vmulq_f32(x, inv_ln2));
    // r = (x - n*ln2_hi) - n*ln2_lo
    let r = vfmsq_f32(vfmsq_f32(x, n, ln2_hi), n, ln2_lo);
    let c1 = vdupq_n_f32(1.0);
    // Horner evaluation of the Taylor series, starting from 1/720.
    let mut p = vdupq_n_f32(0.001_388_888_9);
    p = vfmaq_f32(vdupq_n_f32(0.008_333_334), p, r);
    p = vfmaq_f32(vdupq_n_f32(0.041_666_668), p, r);
    p = vfmaq_f32(vdupq_n_f32(0.166_666_67), p, r);
    p = vfmaq_f32(vdupq_n_f32(0.5), p, r);
    p = vfmaq_f32(c1, p, r);
    p = vfmaq_f32(c1, p, r);
    // Multiply by 2^n by adding n to the biased exponent bits.
    let ni = vcvtq_s32_f32(n);
    vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(p), vshlq_n_s32(ni, 23)))
}
51
/// Vectorised natural exponential of eight `f32` lanes (AVX2 + FMA).
///
/// Same algorithm as the NEON variant:
/// 1. Clamp `x` to `[-87.3, 88.7]` so the exponent adjustment stays in the normal `f32` range.
/// 2. Round `x * log2(e)` to the nearest integer `n`, then form `r = x - n * ln2` using a split
///    hi/lo `ln2`.
/// 3. Evaluate `exp(r)` with the degree-6 Taylor polynomial (coefficients 1/720 ... 1/2, 1, 1)
///    in Horner form using FMA.
/// 4. Apply `2^n` by adding `n << 23` to the float's bit pattern.
///
/// # Safety
/// This function only exists when `avx2` and `fma` are enabled at compile time (see the
/// `cfg`), so the intrinsics it uses are available whenever it can be called. It performs no
/// memory access.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
#[inline(always)]
// Lets the body call intrinsics without wrapping each call in its own `unsafe` block.
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn avx2_exp8(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;
    let x = _mm256_max_ps(x, _mm256_set1_ps(-87.3));
    let x = _mm256_min_ps(x, _mm256_set1_ps(88.7));
    let inv_ln2 = _mm256_set1_ps(1.442695040888963);
    // ln2 split as hi + lo for an accurate reduced argument.
    let ln2_hi = _mm256_set1_ps(0.693145751953125);
    let ln2_lo = _mm256_set1_ps(1.428606765330187e-6);
    let n = _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(_mm256_mul_ps(
        x, inv_ln2,
    ));
    // r = (x - n*ln2_hi) - n*ln2_lo
    let r = _mm256_fnmadd_ps(n, ln2_lo, _mm256_fnmadd_ps(n, ln2_hi, x));
    let c1 = _mm256_set1_ps(1.0);
    // Horner evaluation of the Taylor series, starting from 1/720.
    let mut p = _mm256_set1_ps(0.001388888888888889);
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.008333333333333333));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.041666666666666664));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.16666666666666666));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.5));
    p = _mm256_fmadd_ps(p, r, c1);
    p = _mm256_fmadd_ps(p, r, c1);
    // n is already integral, so this conversion is exact.
    let ni = _mm256_cvtps_epi32(n);
    // Multiply by 2^n by adding n to the biased exponent bits.
    let shifted = _mm256_slli_epi32::<23>(ni);
    _mm256_castsi256_ps(_mm256_add_epi32(_mm256_castps_si256(p), shifted))
}
88
/// Fused bias-add + exact GELU over an `m x n` row-major matrix, in place (NEON).
///
/// Computes `data[r * n + j] = gelu(data[r * n + j] + bias[j])` for every `r < m`, `j < n`,
/// where `gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`. `erf` uses the Abramowitz-Stegun 7.1.26
/// rational approximation with the same coefficients as `scalar_gelu`, which handles the
/// `n % 4` tail columns. The `exp(-x^2)` factor uses [`neon_exp4`].
///
/// # Panics
/// Panics if `data.len() < m * n` or `bias.len() < n`. These checks are required for soundness:
/// the vector loop reads and writes through raw pointers without bounds checks.
#[cfg(target_arch = "aarch64")]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::aarch64::*;
    assert!(
        data.len() >= m * n,
        "bias_gelu: data has {} elements, need at least m * n = {}",
        data.len(),
        m * n
    );
    assert!(
        bias.len() >= n,
        "bias_gelu: bias has {} elements, need at least n = {}",
        bias.len(),
        n
    );
    let chunks = n / 4;
    // SAFETY: NEON is available on aarch64. For every `row < m` and `c < chunks`,
    // `off + 4 <= row * n + n <= m * n <= data.len()` and `c * 4 + 4 <= n <= bias.len()`
    // (asserted above), so every 4-lane load/store stays in bounds.
    unsafe {
        let half = vdupq_n_f32(0.5);
        let one = vdupq_n_f32(1.0);
        let inv_sqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz-Stegun 7.1.26 constants.
        let p = vdupq_n_f32(0.3275911);
        let a1 = vdupq_n_f32(0.254_829_6);
        let a2 = vdupq_n_f32(-0.284_496_72);
        let a3 = vdupq_n_f32(1.421_413_8);
        let a4 = vdupq_n_f32(-1.453_152_1);
        let a5 = vdupq_n_f32(1.061_405_4);
        let neg_one = vdupq_n_f32(-1.0);
        let zero = vdupq_n_f32(0.0);

        for row in 0..m {
            let base = row * n;
            for c in 0..chunks {
                let off = base + c * 4;
                let ptr = data.as_mut_ptr().add(off);
                let x = vaddq_f32(vld1q_f32(ptr), vld1q_f32(bias.as_ptr().add(c * 4)));
                let erf_arg = vmulq_f32(x, inv_sqrt2);
                let xa = vabsq_f32(erf_arg);
                // erf is odd: evaluate on |arg| and restore the sign afterwards.
                let sign = vbslq_f32(vcgeq_f32(erf_arg, zero), one, neg_one);
                let denom = vfmaq_f32(one, p, xa);
                let t = vdivq_f32(one, denom);
                let mut y = a5;
                y = vfmaq_f32(a4, y, t);
                y = vfmaq_f32(a3, y, t);
                y = vfmaq_f32(a2, y, t);
                y = vfmaq_f32(a1, y, t);
                y = vmulq_f32(y, t);
                let exp_val = neon_exp4(vnegq_f32(vmulq_f32(xa, xa)));
                let erf_val = vmulq_f32(sign, vfmsq_f32(one, y, exp_val));
                vst1q_f32(ptr, vmulq_f32(x, vmulq_f32(half, vaddq_f32(one, erf_val))));
            }
            for i in (chunks * 4)..n {
                let x = data[base + i] + bias[i];
                data[base + i] = scalar_gelu(x);
            }
        }
    }
}
138
/// Fused bias-add + exact GELU over an `m x n` row-major matrix, in place (AVX2 + FMA).
///
/// Computes `data[r * n + j] = gelu(data[r * n + j] + bias[j])` for every `r < m`, `j < n`,
/// where `gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`. `erf` uses the Abramowitz-Stegun 7.1.26
/// rational approximation with the same coefficients as `scalar_gelu`, which handles the
/// `n % 8` tail columns. The `exp(-x^2)` factor uses [`avx2_exp8`].
///
/// # Panics
/// Panics if `data.len() < m * n` or `bias.len() < n`. These checks are required for soundness:
/// the vector loop reads and writes through raw pointers without bounds checks.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
    use std::arch::x86_64::*;
    assert!(
        data.len() >= m * n,
        "bias_gelu: data has {} elements, need at least m * n = {}",
        data.len(),
        m * n
    );
    assert!(
        bias.len() >= n,
        "bias_gelu: bias has {} elements, need at least n = {}",
        bias.len(),
        n
    );
    let chunks = n / 8;
    // SAFETY: AVX2/FMA are enabled at compile time (cfg). For every `row < m` and `c < chunks`,
    // `off + 8 <= row * n + n <= m * n <= data.len()` and `c * 8 + 8 <= n <= bias.len()`
    // (asserted above). All loads/stores are the unaligned `loadu`/`storeu` forms.
    unsafe {
        let half = _mm256_set1_ps(0.5);
        let one = _mm256_set1_ps(1.0);
        let inv_sqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz-Stegun 7.1.26 constants.
        let p = _mm256_set1_ps(0.3275911);
        let a1 = _mm256_set1_ps(0.254829592);
        let a2 = _mm256_set1_ps(-0.284496736);
        let a3 = _mm256_set1_ps(1.421413741);
        let a4 = _mm256_set1_ps(-1.453152027);
        let a5 = _mm256_set1_ps(1.061405429);
        let neg_one = _mm256_set1_ps(-1.0);
        let zero = _mm256_set1_ps(0.0);
        // Clearing the sign bit gives |v|.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));

        for row in 0..m {
            let base = row * n;
            for c in 0..chunks {
                let off = base + c * 8;
                let ptr = data.as_mut_ptr().add(off);
                let x = _mm256_add_ps(
                    _mm256_loadu_ps(ptr),
                    _mm256_loadu_ps(bias.as_ptr().add(c * 8)),
                );
                let erf_arg = _mm256_mul_ps(x, inv_sqrt2);
                let xa = _mm256_and_ps(erf_arg, abs_mask);
                // erf is odd: evaluate on |arg| and restore the sign afterwards.
                let ge0 = _mm256_cmp_ps::<_CMP_GE_OQ>(erf_arg, zero);
                let sign = _mm256_blendv_ps(neg_one, one, ge0);
                let denom = _mm256_fmadd_ps(p, xa, one);
                let t = _mm256_div_ps(one, denom);
                let mut y = a5;
                y = _mm256_fmadd_ps(y, t, a4);
                y = _mm256_fmadd_ps(y, t, a3);
                y = _mm256_fmadd_ps(y, t, a2);
                y = _mm256_fmadd_ps(y, t, a1);
                y = _mm256_mul_ps(y, t);
                let exp_val = avx2_exp8(_mm256_sub_ps(zero, _mm256_mul_ps(xa, xa)));
                let erf_val = _mm256_mul_ps(sign, _mm256_fnmadd_ps(y, exp_val, one));
                _mm256_storeu_ps(
                    ptr,
                    _mm256_mul_ps(x, _mm256_mul_ps(half, _mm256_add_ps(one, erf_val))),
                );
            }
            for i in (chunks * 8)..n {
                let x = data[base + i] + bias[i];
                data[base + i] = scalar_gelu(x);
            }
        }
    }
}
199
200#[cfg(not(any(
201 target_arch = "aarch64",
202 all(
203 target_arch = "x86_64",
204 target_feature = "avx2",
205 target_feature = "fma"
206 )
207)))]
208pub fn bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
209 for row in 0..m {
210 let base = row * n;
211 for i in 0..n {
212 let x = data[base + i] + bias[i];
213 data[base + i] = scalar_gelu(x);
214 }
215 }
216}
217
218pub fn par_bias_gelu(data: &mut [f32], bias: &[f32], m: usize, n: usize) {
220 let cfg = crate::config::RuntimeConfig::global();
221 if m * n < cfg.par_threshold || m < cfg.min_rows_per_thread {
222 bias_gelu(data, bias, m, n);
223 return;
224 }
225 let data_ptr = data.as_mut_ptr() as usize;
226 let bias_ptr = bias.as_ptr() as usize;
227 pool::par_for(m, cfg.min_rows_per_thread, &|off, cnt| unsafe {
228 let d = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(off * n), cnt * n);
229 let b = std::slice::from_raw_parts(bias_ptr as *const f32, n);
230 bias_gelu(d, b, cnt, n);
231 });
232}
233
/// SiLU / swish activation in place (NEON): `x * sigmoid(x) = x / (1 + exp(-x))`.
///
/// Four lanes at a time compute `x * (1 / (1 + exp(-x)))` with [`neon_exp4`]. The
/// `len % 4` tail uses scalar `f32::exp`. Bounds come from `data.len()`, so every slice
/// length is accepted.
#[cfg(target_arch = "aarch64")]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;
    let chunks = data.len() / 4;
    // SAFETY: `c * 4 + 4 <= chunks * 4 <= data.len()`, and NEON is available on aarch64.
    unsafe {
        let one = vdupq_n_f32(1.0);
        for c in 0..chunks {
            let ptr = data.as_mut_ptr().add(c * 4);
            let x = vld1q_f32(ptr);
            let exp_neg = neon_exp4(vnegq_f32(x));
            let sigmoid = vdivq_f32(one, vaddq_f32(one, exp_neg));
            vst1q_f32(ptr, vmulq_f32(x, sigmoid));
        }
    }
    // Scalar tail for the last `len % 4` elements.
    for i in (chunks * 4)..data.len() {
        let x = data[i];
        data[i] = x / (1.0 + (-x).exp());
    }
}
256
/// SiLU / swish activation in place (AVX2 + FMA): `x / (1 + exp(-x))`.
///
/// Eight lanes at a time use [`avx2_exp8`]. The `len % 8` tail uses scalar `f32::exp`.
/// Bounds come from `data.len()`, so every slice length is accepted.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn silu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;
    let chunks = data.len() / 8;
    // SAFETY: `off + 8 <= chunks * 8 <= data.len()`; AVX2/FMA are enabled at compile time.
    unsafe {
        let one = _mm256_set1_ps(1.0);
        let zero = _mm256_set1_ps(0.0);
        for c in 0..chunks {
            let off = c * 8;
            let ptr = data.as_mut_ptr().add(off);
            let x = _mm256_loadu_ps(ptr);
            let neg_x = _mm256_sub_ps(zero, x);
            // Divide x directly by (1 + e^-x) instead of multiplying by a reciprocal.
            let denom = _mm256_add_ps(one, avx2_exp8(neg_x));
            _mm256_storeu_ps(ptr, _mm256_div_ps(x, denom));
        }
        // Scalar tail for the last `len % 8` elements.
        for i in (chunks * 8)..data.len() {
            let x = data[i];
            data[i] = x / (1.0 + (-x).exp());
        }
    }
}
283
/// Portable SiLU / swish activation in place: each element `x` becomes `x / (1 + exp(-x))`.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn silu_inplace(data: &mut [f32]) {
    data.iter_mut().for_each(|slot| {
        let x = *slot;
        let denom = 1.0 + (-x).exp();
        *slot = x / denom;
    });
}
298
/// LayerNorm of a single row of length `h` (NEON).
///
/// `output[i] = (input[i] - mean) / sqrt(var + eps) * gamma[i] + beta[i]` for `i < h`.
/// Mean and variance come from one pass accumulating `sum(x)` and `sum(x^2)`:
/// `var = E[x^2] - mean^2`, clamped at 0 to absorb rounding. This one-pass form can lose
/// precision when `|mean|` is large relative to the standard deviation.
///
/// # Panics
/// Panics if `input`, `gamma`, `beta` or `output` has fewer than `h` elements. These checks are
/// required for soundness because the vector loops use unchecked pointer loads/stores.
#[cfg(target_arch = "aarch64")]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::aarch64::*;
    assert!(input.len() >= h, "layer_norm_row: input has {} elements, need h = {}", input.len(), h);
    assert!(gamma.len() >= h, "layer_norm_row: gamma has {} elements, need h = {}", gamma.len(), h);
    assert!(beta.len() >= h, "layer_norm_row: beta has {} elements, need h = {}", beta.len(), h);
    assert!(
        output.len() >= h,
        "layer_norm_row: output has {} elements, need h = {}",
        output.len(),
        h
    );
    let inv_hf = 1.0 / h as f32;
    let chunks = h / 4;
    // SAFETY: every vector access touches `[c * 4, c * 4 + 4)` with `c < chunks`, so it stays
    // below `h`, which every slice covers (asserted above). NEON is available on aarch64.
    unsafe {
        // Pass 1: lane-wise sum and sum of squares.
        let mut vsum = vdupq_n_f32(0.0);
        let mut vsumsq = vdupq_n_f32(0.0);
        for c in 0..chunks {
            let x = vld1q_f32(input.as_ptr().add(c * 4));
            vsum = vaddq_f32(vsum, x);
            vsumsq = vfmaq_f32(vsumsq, x, x);
        }
        let mut sum = vaddvq_f32(vsum);
        let mut sumsq = vaddvq_f32(vsumsq);
        for i in (chunks * 4)..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }
        let mean = sum * inv_hf;
        let var = (sumsq * inv_hf - mean * mean).max(0.0);
        let inv = 1.0 / (var + eps).sqrt();
        let vmean = vdupq_n_f32(mean);
        let vinv = vdupq_n_f32(inv);
        // Pass 2: normalise, then apply the affine transform `norm * gamma + beta`.
        for c in 0..chunks {
            let off = c * 4;
            let x = vld1q_f32(input.as_ptr().add(off));
            let norm = vmulq_f32(vsubq_f32(x, vmean), vinv);
            vst1q_f32(
                output.as_mut_ptr().add(off),
                vfmaq_f32(
                    vld1q_f32(beta.as_ptr().add(off)),
                    norm,
                    vld1q_f32(gamma.as_ptr().add(off)),
                ),
            );
        }
        for i in (chunks * 4)..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
352
/// LayerNorm of a single row of length `h` (AVX2 + FMA).
///
/// `output[i] = (input[i] - mean) / sqrt(var + eps) * gamma[i] + beta[i]` for `i < h`.
/// Mean and variance come from one pass accumulating `sum(x)` and `sum(x^2)`:
/// `var = E[x^2] - mean^2`, clamped at 0 to absorb rounding. This one-pass form can lose
/// precision when `|mean|` is large relative to the standard deviation.
///
/// # Panics
/// Panics if `input`, `gamma`, `beta` or `output` has fewer than `h` elements. These checks are
/// required for soundness because the vector loops use unchecked pointer loads/stores.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    use std::arch::x86_64::*;
    assert!(input.len() >= h, "layer_norm_row: input has {} elements, need h = {}", input.len(), h);
    assert!(gamma.len() >= h, "layer_norm_row: gamma has {} elements, need h = {}", gamma.len(), h);
    assert!(beta.len() >= h, "layer_norm_row: beta has {} elements, need h = {}", beta.len(), h);
    assert!(
        output.len() >= h,
        "layer_norm_row: output has {} elements, need h = {}",
        output.len(),
        h
    );
    let inv_hf = 1.0 / h as f32;
    let chunks = h / 8;
    // SAFETY: every vector access touches `[c * 8, c * 8 + 8)` with `c < chunks`, so it stays
    // below `h`, which every slice covers (asserted above). AVX2/FMA are enabled at compile time.
    unsafe {
        // Pass 1: lane-wise sum and sum of squares.
        let mut vsum = _mm256_setzero_ps();
        let mut vsumsq = _mm256_setzero_ps();
        for c in 0..chunks {
            let x = _mm256_loadu_ps(input.as_ptr().add(c * 8));
            vsum = _mm256_add_ps(vsum, x);
            vsumsq = _mm256_fmadd_ps(x, x, vsumsq);
        }
        // Horizontal reduction: 8 -> 4 -> 2 -> 1 lanes.
        let hsum = {
            let lo = _mm256_castps256_ps128(vsum);
            let hi = _mm256_extractf128_ps::<1>(vsum);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };
        let hsumsq = {
            let lo = _mm256_castps256_ps128(vsumsq);
            let hi = _mm256_extractf128_ps::<1>(vsumsq);
            let s4 = _mm_add_ps(lo, hi);
            let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
            let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
            _mm_cvtss_f32(s1)
        };
        let mut sum = hsum;
        let mut sumsq = hsumsq;
        for i in (chunks * 8)..h {
            sum += input[i];
            sumsq += input[i] * input[i];
        }
        let mean = sum * inv_hf;
        let var = (sumsq * inv_hf - mean * mean).max(0.0);
        let inv = 1.0 / (var + eps).sqrt();
        let vmean = _mm256_set1_ps(mean);
        let vinv = _mm256_set1_ps(inv);
        // Pass 2: normalise, then apply the affine transform `norm * gamma + beta`.
        for c in 0..chunks {
            let off = c * 8;
            let x = _mm256_loadu_ps(input.as_ptr().add(off));
            let norm = _mm256_mul_ps(_mm256_sub_ps(x, vmean), vinv);
            let g = _mm256_loadu_ps(gamma.as_ptr().add(off));
            let b = _mm256_loadu_ps(beta.as_ptr().add(off));
            _mm256_storeu_ps(output.as_mut_ptr().add(off), _mm256_fmadd_ps(norm, g, b));
        }
        for i in (chunks * 8)..h {
            output[i] = (input[i] - mean) * inv * gamma[i] + beta[i];
        }
    }
}
418
/// Portable LayerNorm of a single row of length `h`.
///
/// Mean and variance come from one fold over `(sum(x), sum(x^2))` (`var = E[x^2] - mean^2`,
/// clamped at 0). Each output is `(x - mean) / sqrt(var + eps) * gamma + beta`. Too-short
/// slices panic on indexing.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn layer_norm_row(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    h: usize,
    eps: f32,
) {
    let (sum, sumsq) = input[..h]
        .iter()
        .fold((0f32, 0f32), |(s, sq), &v| (s + v, sq + v * v));
    let inv_len = 1.0 / h as f32;
    let mean = sum * inv_len;
    let variance = (sumsq * inv_len - mean * mean).max(0.0);
    let scale = 1.0 / (variance + eps).sqrt();
    for idx in 0..h {
        output[idx] = (input[idx] - mean) * scale * gamma[idx] + beta[idx];
    }
}
449
/// Inference-mode batch normalization using fixed (running) statistics.
///
/// `x` and `out` are read as `[rows, channels]` row-major with `rows = x.len() / channels`.
/// A trailing partial row is ignored and its `out` entries are left untouched. For each
/// element in channel `c`:
/// `out = gamma[c] * (x - mean[c]) / sqrt(var[c] + eps) + beta[c]`.
///
/// The per-channel factor `1 / sqrt(var[c] + eps)` is computed once per call instead of once
/// per element. It is the same expression, so results are bit-identical.
///
/// # Panics
/// Panics if at least one full row exists and `gamma`, `beta`, `mean` or `var` has fewer than
/// `channels` elements, or if `out` is shorter than the rows being written.
pub fn batch_norm_inference(
    x: &[f32],
    gamma: &[f32],
    beta: &[f32],
    mean: &[f32],
    var: &[f32],
    out: &mut [f32],
    channels: usize,
    eps: f32,
) {
    let n = x.len() / channels.max(1);
    // `take` (not slicing) keeps the original behaviour of only failing on a short `var`
    // when a row actually needs it.
    let inv_std: Vec<f32> = var
        .iter()
        .take(channels)
        .map(|&v| 1.0 / (v + eps).sqrt())
        .collect();
    for i in 0..n {
        for c in 0..channels {
            let idx = i * channels + c;
            let xhat = (x[idx] - mean[c]) * inv_std[c];
            out[idx] = gamma[c] * xhat + beta[c];
        }
    }
}
474
/// Gradient of inference-mode batch norm with respect to its input.
///
/// With fixed statistics the layer is affine per channel, so
/// `dx = dy * gamma[c] / sqrt(var[c] + eps)`. `x` is only used to derive the row count
/// (`x.len() / channels`), and `_mean` is unused because the mean does not affect this gradient.
/// `dy`/`dx` share the `[rows, channels]` row-major layout of `x`.
///
/// The per-channel `1 / sqrt(var[c] + eps)` is computed once per call instead of once per
/// element; results are bit-identical.
///
/// # Panics
/// Panics if a full row exists and `gamma` or `var` has fewer than `channels` elements, or if
/// `dy`/`dx` are shorter than the rows being processed.
pub fn batch_norm_inference_backward_input(
    x: &[f32],
    gamma: &[f32],
    _mean: &[f32],
    var: &[f32],
    dy: &[f32],
    dx: &mut [f32],
    channels: usize,
    eps: f32,
) {
    let n = x.len() / channels.max(1);
    let inv_std: Vec<f32> = var
        .iter()
        .take(channels)
        .map(|&v| 1.0 / (v + eps).sqrt())
        .collect();
    for i in 0..n {
        for c in 0..channels {
            let idx = i * channels + c;
            dx[idx] = dy[idx] * gamma[c] * inv_std[c];
        }
    }
}
495
/// Gradient of inference-mode batch norm with respect to `gamma`.
///
/// `dgamma[c] = sum over rows of dy * (x - mean[c]) / sqrt(var[c] + eps)`, with `x`/`dy` read
/// as `[rows, channels]` row-major (`rows = x.len() / channels`). `dgamma` is zeroed first, so
/// it is overwritten rather than accumulated into.
///
/// The per-channel `1 / sqrt(var[c] + eps)` is computed once per call instead of once per
/// element; results are bit-identical.
///
/// # Panics
/// Panics if a full row exists and `mean`, `var` or `dgamma` has fewer than `channels` elements,
/// or if `dy` is shorter than `x`'s full rows.
pub fn batch_norm_inference_backward_gamma(
    x: &[f32],
    mean: &[f32],
    var: &[f32],
    dy: &[f32],
    dgamma: &mut [f32],
    channels: usize,
    eps: f32,
) {
    dgamma.fill(0.0);
    let n = x.len() / channels.max(1);
    let inv_std: Vec<f32> = var
        .iter()
        .take(channels)
        .map(|&v| 1.0 / (v + eps).sqrt())
        .collect();
    for i in 0..n {
        for c in 0..channels {
            let idx = i * channels + c;
            let xhat = (x[idx] - mean[c]) * inv_std[c];
            dgamma[c] += dy[idx] * xhat;
        }
    }
}
517
/// Gradient of batch norm with respect to `beta`: the per-channel sum of `dy` over all rows.
///
/// `dy` is read as `[rows, channels]` row-major with `rows = dy.len() / channels`, and a
/// trailing partial row is ignored. `dbeta` is zeroed first; with `channels == 0` only that
/// zeroing happens.
pub fn batch_norm_inference_backward_beta(dy: &[f32], dbeta: &mut [f32], channels: usize) {
    dbeta.fill(0.0);
    if channels == 0 {
        return;
    }
    for row in dy.chunks_exact(channels) {
        for (c, &grad) in row.iter().enumerate() {
            dbeta[c] += grad;
        }
    }
}
528
529pub fn residual_bias_layer_norm(
532 a: &[f32],
533 b: &[f32],
534 bias: &[f32],
535 gamma: &[f32],
536 beta: &[f32],
537 output: &mut [f32],
538 n: usize,
539 h: usize,
540 eps: f32,
541) {
542 let mut tmp = vec![0f32; h];
544 for row in 0..n {
545 let base = row * h;
546 for i in 0..h {
547 tmp[i] = a[base + i] + b[base + i] + bias[i];
548 }
549 layer_norm_row(&tmp, gamma, beta, &mut output[base..base + h], h, eps);
550 }
551}
552
/// Row-wise RMSNorm of `a + b + bias`, followed by an affine `gamma`/`beta`.
///
/// For each of the `n` rows of length `h` (row-major), with `v = a + b + bias`:
/// `output = v / sqrt(mean(v^2) + eps) * gamma + beta`. The fused value `v` is recomputed in
/// the second pass rather than stored, so no scratch buffer is needed.
pub fn residual_bias_rms_norm(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    n: usize,
    h: usize,
    eps: f32,
) {
    let inv_h = 1.0 / h as f32;
    for row in 0..n {
        let base = row * h;
        let fused = |i: usize| a[base + i] + b[base + i] + bias[i];
        let sumsq = (0..h).map(fused).fold(0f32, |acc, v| acc + v * v);
        let inv_rms = (sumsq * inv_h + eps).sqrt().recip();
        for i in 0..h {
            output[base + i] = fused(i) * inv_rms * gamma[i] + beta[i];
        }
    }
}
581
582pub fn par_residual_bias_ln(
584 a: &[f32],
585 b: &[f32],
586 bias: &[f32],
587 gamma: &[f32],
588 beta: &[f32],
589 output: &mut [f32],
590 n: usize,
591 h: usize,
592 eps: f32,
593) {
594 let cfg = crate::config::RuntimeConfig::global();
595 if n * h < cfg.par_threshold || n < cfg.min_rows_per_thread {
596 residual_bias_layer_norm(a, b, bias, gamma, beta, output, n, h, eps);
597 return;
598 }
599 let a_ptr = a.as_ptr() as usize;
600 let b_ptr = b.as_ptr() as usize;
601 let o_ptr = output.as_mut_ptr() as usize;
602 let bias_ptr = bias.as_ptr() as usize;
603 let gamma_ptr = gamma.as_ptr() as usize;
604 let beta_ptr = beta.as_ptr() as usize;
605 pool::par_for(n, cfg.min_rows_per_thread, &|off, cnt| unsafe {
606 let a_s = std::slice::from_raw_parts((a_ptr as *const f32).add(off * h), cnt * h);
607 let b_s = std::slice::from_raw_parts((b_ptr as *const f32).add(off * h), cnt * h);
608 let o_s = std::slice::from_raw_parts_mut((o_ptr as *mut f32).add(off * h), cnt * h);
609 let bi = std::slice::from_raw_parts(bias_ptr as *const f32, h);
610 let g = std::slice::from_raw_parts(gamma_ptr as *const f32, h);
611 let be = std::slice::from_raw_parts(beta_ptr as *const f32, h);
612 residual_bias_layer_norm(a_s, b_s, bi, g, be, o_s, cnt, h, eps);
613 });
614}
615
/// Numerically stable row-wise softmax over a `rows x cols` row-major matrix, in place (NEON).
///
/// For each row: find the row maximum, replace each element with `exp(x - max)` (vector lanes
/// via [`neon_exp4`], the `cols % 4` tail via `f32::exp`) while summing, then multiply by
/// `1 / sum`.
///
/// # Panics
/// Panics if `data.len() < rows * cols`. This check is required for soundness because every
/// access in the loop body, including the scalar tail, goes through unchecked raw pointers.
#[cfg(target_arch = "aarch64")]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::aarch64::*;
    assert!(
        data.len() >= rows * cols,
        "neon_softmax: data has {} elements, need at least rows * cols = {}",
        data.len(),
        rows * cols
    );
    let chunks = cols / 4;
    // SAFETY: for `row < rows`, `ptr = data + row * cols`, and every access is `ptr.add(i)`
    // with `i < cols` (vector accesses cover `[c * 4, c * 4 + 4)` with `c < chunks`), so all
    // accesses stay below `rows * cols <= data.len()`. NEON is available on aarch64.
    unsafe {
        for row in 0..rows {
            let base = row * cols;
            let ptr = data.as_mut_ptr().add(base);

            // Pass 1: row maximum, for numerical stability.
            let mut vmax = vdupq_n_f32(f32::NEG_INFINITY);
            for c in 0..chunks {
                vmax = vmaxq_f32(vmax, vld1q_f32(ptr.add(c * 4)));
            }
            let mut max_val = vmaxvq_f32(vmax);
            for i in (chunks * 4)..cols {
                max_val = max_val.max(*ptr.add(i));
            }

            // Pass 2: exponentiate in place and accumulate the sum.
            let vmx = vdupq_n_f32(max_val);
            let mut vsum = vdupq_n_f32(0.0);
            for c in 0..chunks {
                let off = c * 4;
                let e = neon_exp4(vsubq_f32(vld1q_f32(ptr.add(off)), vmx));
                vst1q_f32(ptr.add(off), e);
                vsum = vaddq_f32(vsum, e);
            }
            let mut sum = vaddvq_f32(vsum);
            for i in (chunks * 4)..cols {
                let e = (*ptr.add(i) - max_val).exp();
                *ptr.add(i) = e;
                sum += e;
            }

            // Pass 3: normalise by the row sum.
            let vinv = vdupq_n_f32(1.0 / sum);
            for c in 0..chunks {
                let off = c * 4;
                vst1q_f32(ptr.add(off), vmulq_f32(vld1q_f32(ptr.add(off)), vinv));
            }
            let inv = 1.0 / sum;
            for i in (chunks * 4)..cols {
                *ptr.add(i) *= inv;
            }
        }
    }
}
667
/// Numerically stable row-wise softmax over a `rows x cols` row-major matrix, in place
/// (AVX2 + FMA variant; the `neon_` name is shared by every `cfg` variant of this function).
///
/// For each row: find the row maximum, replace each element with `exp(x - max)` (vector lanes
/// via [`avx2_exp8`], the `cols % 8` tail via `f32::exp`) while summing, then multiply by
/// `1 / sum`.
///
/// # Panics
/// Panics if `data.len() < rows * cols`. This check is required for soundness because every
/// access in the loop body, including the scalar tail, goes through unchecked raw pointers.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    use std::arch::x86_64::*;
    assert!(
        data.len() >= rows * cols,
        "neon_softmax: data has {} elements, need at least rows * cols = {}",
        data.len(),
        rows * cols
    );
    let chunks = cols / 8;
    // SAFETY: for `r < rows`, `row = data + r * cols`, and every access is `row.add(i)` with
    // `i < cols` (vector accesses cover `[c * 8, c * 8 + 8)` with `c < chunks`), so all accesses
    // stay below `rows * cols <= data.len()`. AVX2/FMA are enabled at compile time.
    unsafe {
        for r in 0..rows {
            let row = data.as_mut_ptr().add(r * cols);
            // Pass 1: row maximum, for numerical stability.
            let mut vmax = _mm256_set1_ps(f32::NEG_INFINITY);
            for c in 0..chunks {
                vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(row.add(c * 8)));
            }
            let mut max_v = {
                let lo = _mm256_castps256_ps128(vmax);
                let hi = _mm256_extractf128_ps::<1>(vmax);
                let s4 = _mm_max_ps(lo, hi);
                let s2 = _mm_max_ps(s4, _mm_movehl_ps(s4, s4));
                let s1 = _mm_max_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
                _mm_cvtss_f32(s1)
            };
            for i in (chunks * 8)..cols {
                let v = *row.add(i);
                if v > max_v {
                    max_v = v;
                }
            }
            // Pass 2: exponentiate in place and accumulate the sum.
            let vmax = _mm256_set1_ps(max_v);
            let mut vsum = _mm256_setzero_ps();
            for c in 0..chunks {
                let off = c * 8;
                let e = avx2_exp8(_mm256_sub_ps(_mm256_loadu_ps(row.add(off)), vmax));
                _mm256_storeu_ps(row.add(off), e);
                vsum = _mm256_add_ps(vsum, e);
            }
            let mut sum_v = {
                let lo = _mm256_castps256_ps128(vsum);
                let hi = _mm256_extractf128_ps::<1>(vsum);
                let s4 = _mm_add_ps(lo, hi);
                let s2 = _mm_add_ps(s4, _mm_movehl_ps(s4, s4));
                let s1 = _mm_add_ss(s2, _mm_shuffle_ps::<0x55>(s2, s2));
                _mm_cvtss_f32(s1)
            };
            for i in (chunks * 8)..cols {
                let v = (*row.add(i) - max_v).exp();
                *row.add(i) = v;
                sum_v += v;
            }
            // Pass 3: normalise by the row sum.
            let vinv = _mm256_set1_ps(1.0 / sum_v);
            for c in 0..chunks {
                let off = c * 8;
                _mm256_storeu_ps(
                    row.add(off),
                    _mm256_mul_ps(_mm256_loadu_ps(row.add(off)), vinv),
                );
            }
            let inv_sum = 1.0 / sum_v;
            for i in (chunks * 8)..cols {
                *row.add(i) *= inv_sum;
            }
        }
    }
}
736
/// Portable fallback for targets with neither aarch64 NEON nor AVX2+FMA: forwards
/// `(data, rows, cols)` unchanged to `crate::naive::softmax`.
///
/// All three `cfg` variants share this name and signature, so callers compile unchanged on
/// every target. NOTE(review): presumably `naive::softmax` is the same in-place row-wise
/// softmax over a `rows x cols` buffer; confirm against its definition.
#[cfg(not(any(
    target_arch = "aarch64",
    all(
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    )
)))]
pub fn neon_softmax(data: &mut [f32], rows: usize, cols: usize) {
    crate::naive::softmax(data, rows, cols);
}
748
/// Exact (erf-based) GELU applied in place (NEON): `x * 0.5 * (1 + erf(x / sqrt(2)))`.
///
/// `erf` uses the Abramowitz-Stegun 7.1.26 rational approximation:
/// `erf(a) ~ sign(a) * (1 - t * poly(t) * exp(-a^2))` with `t = 1 / (1 + p * |a|)`.
/// The `exp` term uses [`neon_exp4`]. The last `len % 4` elements go through `scalar_gelu`,
/// which uses the same coefficients. Bounds come from `data.len()`, so every slice length is
/// accepted.
#[cfg(target_arch = "aarch64")]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::aarch64::*;
    let len = data.len();
    let chunks = len / 4;
    // SAFETY: vector accesses cover `[c * 4, c * 4 + 4)` with `c < chunks = len / 4`;
    // NEON is available on aarch64.
    unsafe {
        let half = vdupq_n_f32(0.5);
        let one = vdupq_n_f32(1.0);
        let inv_sqrt2 = vdupq_n_f32(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz-Stegun 7.1.26 constants.
        let p = vdupq_n_f32(0.3275911);
        let a1 = vdupq_n_f32(0.254_829_6);
        let a2 = vdupq_n_f32(-0.284_496_72);
        let a3 = vdupq_n_f32(1.421_413_8);
        let a4 = vdupq_n_f32(-1.453_152_1);
        let a5 = vdupq_n_f32(1.061_405_4);
        let neg_one = vdupq_n_f32(-1.0);
        let zero = vdupq_n_f32(0.0);

        for c in 0..chunks {
            let ptr = data.as_mut_ptr().add(c * 4);
            let x = vld1q_f32(ptr);
            let erf_arg = vmulq_f32(x, inv_sqrt2);
            let xa = vabsq_f32(erf_arg);
            // erf is odd: evaluate on |arg| and restore the sign afterwards.
            let sign = vbslq_f32(vcgeq_f32(erf_arg, zero), one, neg_one);
            let denom = vfmaq_f32(one, p, xa);
            let t = vdivq_f32(one, denom);
            // Horner: y = t * (a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
            let mut y = a5;
            y = vfmaq_f32(a4, y, t);
            y = vfmaq_f32(a3, y, t);
            y = vfmaq_f32(a2, y, t);
            y = vfmaq_f32(a1, y, t);
            y = vmulq_f32(y, t);
            let exp_val = neon_exp4(vnegq_f32(vmulq_f32(xa, xa)));
            let erf_val = vmulq_f32(sign, vfmsq_f32(one, y, exp_val));
            vst1q_f32(ptr, vmulq_f32(x, vmulq_f32(half, vaddq_f32(one, erf_val))));
        }
        // Scalar tail.
        for i in (chunks * 4)..len {
            data[i] = scalar_gelu(data[i]);
        }
    }
}
793
/// Exact (erf-based) GELU applied in place (AVX2 + FMA): `x * 0.5 * (1 + erf(x / sqrt(2)))`.
///
/// `erf` uses the Abramowitz-Stegun 7.1.26 rational approximation. The `exp(-a^2)` term uses
/// [`avx2_exp8`]. The last `len % 8` elements go through `scalar_gelu`, which uses the same
/// coefficients. Bounds come from `data.len()`, so every slice length is accepted.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    target_feature = "fma"
))]
pub fn gelu_inplace(data: &mut [f32]) {
    use std::arch::x86_64::*;
    let chunks = data.len() / 8;
    // SAFETY: vector accesses cover `[off, off + 8)` with `off = c * 8`, `c < len / 8`;
    // AVX2/FMA are enabled at compile time.
    unsafe {
        let half = _mm256_set1_ps(0.5);
        let one = _mm256_set1_ps(1.0);
        let inv_sqrt2 = _mm256_set1_ps(std::f32::consts::FRAC_1_SQRT_2);
        // Abramowitz-Stegun 7.1.26 constants.
        let p = _mm256_set1_ps(0.3275911);
        let a1 = _mm256_set1_ps(0.254829592);
        let a2 = _mm256_set1_ps(-0.284496736);
        let a3 = _mm256_set1_ps(1.421413741);
        let a4 = _mm256_set1_ps(-1.453152027);
        let a5 = _mm256_set1_ps(1.061405429);
        let neg_one = _mm256_set1_ps(-1.0);
        let zero = _mm256_set1_ps(0.0);
        // Clearing the sign bit gives |v|.
        let abs_mask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fff_ffff));
        for c in 0..chunks {
            let off = c * 8;
            let ptr = data.as_mut_ptr().add(off);
            let x = _mm256_loadu_ps(ptr);
            let erf_arg = _mm256_mul_ps(x, inv_sqrt2);
            let xa = _mm256_and_ps(erf_arg, abs_mask);
            // erf is odd: evaluate on |arg| and restore the sign afterwards.
            let ge0 = _mm256_cmp_ps::<_CMP_GE_OQ>(erf_arg, zero);
            let sign = _mm256_blendv_ps(neg_one, one, ge0);
            let denom = _mm256_fmadd_ps(p, xa, one);
            let t = _mm256_div_ps(one, denom);
            // Horner: y = t * (a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
            let mut y = a5;
            y = _mm256_fmadd_ps(y, t, a4);
            y = _mm256_fmadd_ps(y, t, a3);
            y = _mm256_fmadd_ps(y, t, a2);
            y = _mm256_fmadd_ps(y, t, a1);
            y = _mm256_mul_ps(y, t);
            let exp_val = avx2_exp8(_mm256_sub_ps(zero, _mm256_mul_ps(xa, xa)));
            let erf_val = _mm256_mul_ps(sign, _mm256_fnmadd_ps(y, exp_val, one));
            _mm256_storeu_ps(
                ptr,
                _mm256_mul_ps(x, _mm256_mul_ps(half, _mm256_add_ps(one, erf_val))),
            );
        }
        // Scalar tail.
        for i in (chunks * 8)..data.len() {
            data[i] = scalar_gelu(data[i]);
        }
    }
}
843
844#[cfg(not(any(
845 target_arch = "aarch64",
846 all(
847 target_arch = "x86_64",
848 target_feature = "avx2",
849 target_feature = "fma"
850 )
851)))]
852pub fn gelu_inplace(data: &mut [f32]) {
853 for v in data.iter_mut() {
854 *v = scalar_gelu(*v);
855 }
856}
857
/// Element count below which the `par_*_inplace` activation helpers stay on the calling thread
/// (1 Mi elements).
const ACTIVATION_PAR_MIN: usize = 1024 * 1024;
869
/// Tanh-form GELU approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
pub fn scalar_gelu_approx(x: f32) -> f32 {
    // sqrt(2 / pi)
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const CUBIC_COEFF: f32 = 0.044_715;
    let inner = SQRT_2_OVER_PI * (x + CUBIC_COEFF * x * x * x);
    0.5 * x * (1.0 + inner.tanh())
}
884
885pub fn gelu_approx_inplace(data: &mut [f32]) {
886 for v in data.iter_mut() {
887 *v = scalar_gelu_approx(*v);
888 }
889}
890
891pub fn par_gelu_approx_inplace(data: &mut [f32]) {
892 let len = data.len();
893 if len < ACTIVATION_PAR_MIN {
894 gelu_approx_inplace(data);
895 return;
896 }
897 let cfg = crate::config::RuntimeConfig::global();
898 let chunk = 512;
899 let rows = len / chunk;
900 if rows < 2 {
901 gelu_approx_inplace(data);
902 return;
903 }
904 let data_ptr = data.as_mut_ptr() as usize;
905 pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
906 let start = off * chunk;
907 let end = if off + cnt >= rows {
908 len
909 } else {
910 (off + cnt) * chunk
911 };
912 let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
913 gelu_approx_inplace(s);
914 });
915 let done = rows * chunk;
916 if done < len {
917 gelu_approx_inplace(&mut data[done..]);
918 }
919}
920
921pub fn par_gelu_inplace(data: &mut [f32]) {
922 let len = data.len();
923 if len < ACTIVATION_PAR_MIN {
924 gelu_inplace(data);
925 return;
926 }
927 let cfg = crate::config::RuntimeConfig::global();
928 let chunk = 512;
929 let rows = len / chunk;
930 if rows < 2 {
931 gelu_inplace(data);
932 return;
933 }
934 let data_ptr = data.as_mut_ptr() as usize;
935 pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
936 let start = off * chunk;
937 let end = if off + cnt >= rows {
938 len
939 } else {
940 (off + cnt) * chunk
941 };
942 let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
943 gelu_inplace(s);
944 });
945 let done = rows * chunk;
946 if done < len {
947 gelu_inplace(&mut data[done..]);
948 }
949}
950
951pub fn par_silu_inplace(data: &mut [f32]) {
953 let len = data.len();
954 if len < ACTIVATION_PAR_MIN {
955 silu_inplace(data);
956 return;
957 }
958 let cfg = crate::config::RuntimeConfig::global();
959 let chunk = 512;
960 let rows = len / chunk;
961 if rows < 2 {
962 silu_inplace(data);
963 return;
964 }
965 let data_ptr = data.as_mut_ptr() as usize;
966 pool::par_for(rows, cfg.min_rows_per_thread, &|off, cnt| unsafe {
967 let start = off * chunk;
968 let end = if off + cnt >= rows {
969 len
970 } else {
971 (off + cnt) * chunk
972 };
973 let s = std::slice::from_raw_parts_mut((data_ptr as *mut f32).add(start), end - start);
974 silu_inplace(s);
975 });
976 let done = rows * chunk;
977 if done < len {
978 silu_inplace(&mut data[done..]);
979 }
980}
981
/// Small-matrix SGEMM for aarch64: `c = a * b` with row-major `a: m x k`, `b: k x n`,
/// `c: m x n`. `c` is overwritten, not accumulated into.
///
/// Columns are processed four at a time with NEON. For each 4-wide column strip, a tile of up to
/// eight rows of `c` is held in registers while streaming over `k`. Rows are processed in tiles
/// of at most eight, so any `m` is supported. The `n % 4` trailing columns use a scalar dot
/// product.
///
/// # Panics
/// Panics if `a.len() < m * k`, `b.len() < k * n` or `c.len() < m * n`. These checks are
/// required for soundness because the vector path uses unchecked pointer accesses.
#[cfg(target_arch = "aarch64")]
pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    use std::arch::aarch64::*;
    // Rows of `c` kept in NEON accumulators at once.
    const MR: usize = 8;
    assert!(a.len() >= m * k, "neon_sgemm_small: a has {} elements, need m * k = {}", a.len(), m * k);
    assert!(b.len() >= k * n, "neon_sgemm_small: b has {} elements, need k * n = {}", b.len(), k * n);
    assert!(c.len() >= m * n, "neon_sgemm_small: c has {} elements, need m * n = {}", c.len(), m * n);
    let n4 = n / 4;
    for i0 in (0..m).step_by(MR) {
        let rows = MR.min(m - i0);
        for j4 in 0..n4 {
            let j = j4 * 4;
            // SAFETY: `j + 4 <= n`, so `kk * n + j + 4 <= k * n <= b.len()`;
            // `(i0 + r) * k + kk < m * k <= a.len()`; `(i0 + r) * n + j + 4 <= m * n <= c.len()`.
            // NEON is available on aarch64.
            unsafe {
                let mut acc = [vdupq_n_f32(0.0); MR];
                for kk in 0..k {
                    let bv = vld1q_f32(b.as_ptr().add(kk * n + j));
                    for (r, acc_r) in acc.iter_mut().enumerate().take(rows) {
                        let av = vdupq_n_f32(*a.as_ptr().add((i0 + r) * k + kk));
                        *acc_r = vfmaq_f32(*acc_r, av, bv);
                    }
                }
                for (r, acc_r) in acc.iter().enumerate().take(rows) {
                    vst1q_f32(c.as_mut_ptr().add((i0 + r) * n + j), *acc_r);
                }
            }
        }
    }
    // Scalar dot products for the trailing `n % 4` columns.
    for j in (n4 * 4)..n {
        for i in 0..m {
            let mut sum = 0f32;
            for kk in 0..k {
                sum += a[i * k + kk] * b[kk * n + j];
            }
            c[i * n + j] = sum;
        }
    }
}
1020
/// Portable fallback for non-aarch64 targets: forwards `(a, b, c, m, k, n)` unchanged to
/// `crate::naive::matmul`. The name and signature match the aarch64 variant, so callers compile
/// unchanged on every target. NOTE(review): presumably the same row-major `c = a * b`
/// semantics as the NEON version; confirm against `naive::matmul`.
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_sgemm_small(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    crate::naive::matmul(a, b, c, m, k, n);
}
1025
/// Small-matrix `c = a * b` (via `neon_sgemm_small`) followed by
/// `crate::blas::bias_add(c, bias, m, n)`.
///
/// NOTE(review): presumably `bias_add` broadcasts a length-`n` `bias` across each of the `m`
/// rows of `c`; confirm against `crate::blas::bias_add`.
#[cfg(target_arch = "aarch64")]
pub fn neon_sgemm_bias_small(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
) {
    neon_sgemm_small(a, b, c, m, k, n);
    crate::blas::bias_add(c, bias, m, n);
}
1040
/// Portable fallback for non-aarch64 targets: `crate::naive::matmul` followed by
/// `crate::naive::bias_add`, with the same arguments as the aarch64 variant.
///
/// NOTE(review): the aarch64 variant uses `crate::blas::bias_add` while this one uses
/// `crate::naive::bias_add`; presumably they are equivalent. Confirm they agree on bias layout.
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_sgemm_bias_small(
    a: &[f32],
    b: &[f32],
    bias: &[f32],
    c: &mut [f32],
    m: usize,
    k: usize,
    n: usize,
) {
    crate::naive::matmul(a, b, c, m, k, n);
    crate::naive::bias_add(c, bias, m, n);
}
1054
/// Exact (erf-based) GELU for one value: `0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// Used for the scalar tails of the SIMD GELU kernels and by the portable fallbacks.
fn scalar_gelu(x: f32) -> f32 {
    let one_plus_erf = 1.0 + scalar_erf(x * std::f32::consts::FRAC_1_SQRT_2);
    x * 0.5 * one_plus_erf
}

/// Abramowitz-Stegun 7.1.26 rational approximation of `erf(x)`.
///
/// Evaluated on `|x|` and mirrored by sign (`erf` is odd): with `t = 1 / (1 + p * |x|)`,
/// `erf(|x|) ~ 1 - t * (a1 + t*(a2 + t*(a3 + t*(a4 + t*a5)))) * exp(-x^2)`.
/// Non-negative inputs (including -0.0) use sign +1; everything else, including NaN, uses -1.
fn scalar_erf(x: f32) -> f32 {
    // a4, a3, a2, a1 in Horner order; a5 seeds the fold.
    const COEFFS: [f32; 4] = [-1.453_152_1, 1.421_413_8, -0.284_496_72, 0.254_829_6];
    let sign = if x >= 0.0 { 1.0f32 } else { -1.0 };
    let xa = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * xa);
    let horner = COEFFS.iter().fold(1.061_405_4f32, |acc, &a| a + t * acc);
    let y = t * horner;
    sign * (1.0 - y * (-xa * xa).exp())
}
1070
/// LayerNorm over the channel axis of an NCHW tensor ("LayerNorm2d").
///
/// For every image `b` and spatial position `(y, x)`, the `channels` values at that position
/// are normalised with their own mean and (biased) variance, then scaled by `gamma[c]` and
/// shifted by `beta[c]`. The variance uses a two-pass computation.
pub fn layer_norm2d_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    eps: f32,
) {
    let spatial = h * w;
    let channel_count = channels as f32;
    for img in 0..batch {
        let img_base = img * channels * spatial;
        for pos in 0..spatial {
            // Flat index of channel `c` at this spatial position.
            let at = |c: usize| img_base + c * spatial + pos;
            let mean = (0..channels).fold(0.0f32, |acc, c| acc + input[at(c)]) / channel_count;
            let var = (0..channels).fold(0.0f32, |acc, c| {
                let d = input[at(c)] - mean;
                acc + d * d
            }) / channel_count;
            let inv = 1.0 / (var + eps).sqrt();
            for c in 0..channels {
                let idx = at(c);
                output[idx] = (input[idx] - mean) * inv * gamma[c] + beta[c];
            }
        }
    }
}
1106
/// Direct (scatter-style) 2-D transposed convolution over NCHW tensors, without bias.
///
/// Each input pixel `input[ni, ic, iy, ix]` is scattered to output position
/// `(iy * sh + ky * dh - ph, ix * sw + kx * dw - pw)` for every kernel tap `(ky, kx)`.
/// Positions outside `[0, h_out) x [0, w_out)` are skipped. The weight is indexed as
/// `[c_in, c_out / groups, kh, kw]` (matching PyTorch's `ConvTranspose2d` weight layout), and
/// input channel `ic` only feeds the output channels of its own group.
///
/// `output` is zeroed first. Input pixels equal to `0.0` are skipped as a shortcut.
///
/// # Panics
/// Panics if `groups == 0`, or if there is input to process while `c_in < groups` (division by
/// zero). Also panics if any buffer is too small for the given shapes.
pub fn conv_transpose2d_nchw(
    input: &[f32],
    weight: &[f32],
    output: &mut [f32],
    n: usize,
    c_in: usize,
    h: usize,
    w: usize,
    c_out: usize,
    h_out: usize,
    w_out: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    dh: usize,
    dw: usize,
    groups: usize,
) {
    output.fill(0.0);
    let in_per_group = c_in / groups;
    let out_per_group = c_out / groups;
    for batch_idx in 0..n {
        for ic in 0..c_in {
            let group = ic / in_per_group;
            for iy in 0..h {
                for ix in 0..w {
                    let v = input[((batch_idx * c_in + ic) * h + iy) * w + ix];
                    if v == 0.0 {
                        continue;
                    }
                    for ky in 0..kh {
                        // Unpadded row; taps that land in the padding or past `h_out` are dropped.
                        let oy = match (iy * sh + ky * dh).checked_sub(ph) {
                            Some(row) if row < h_out => row,
                            _ => continue,
                        };
                        for kx in 0..kw {
                            let ox = match (ix * sw + kx * dw).checked_sub(pw) {
                                Some(col) if col < w_out => col,
                                _ => continue,
                            };
                            for oc_off in 0..out_per_group {
                                let oc = group * out_per_group + oc_off;
                                let wt = weight[((ic * out_per_group + oc_off) * kh + ky) * kw + kx];
                                output[((batch_idx * c_out + oc) * h_out + oy) * w_out + ox] += v * wt;
                            }
                        }
                    }
                }
            }
        }
    }
}
1174
/// GroupNorm over an NCHW tensor.
///
/// Channels are split into `num_groups` consecutive groups of `channels / num_groups`. For each
/// image and group, the mean and (biased) variance are taken over all `H * W` values of all
/// channels in the group, using two passes. Each channel is then normalised and scaled/shifted
/// by its own `gamma[c]` / `beta[c]`.
///
/// # Panics
/// Panics if `num_groups == 0` (division by zero) or if a buffer is too small for the shape.
pub fn group_norm_nchw(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    output: &mut [f32],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    num_groups: usize,
    eps: f32,
) {
    let channels_per_group = channels / num_groups;
    let plane_len = h * w;
    let count = (channels_per_group * plane_len) as f32;
    // Index range of the contiguous H*W plane of channel `ch` in image `img`.
    let plane = |img: usize, ch: usize| {
        let start = (img * channels + ch) * plane_len;
        start..start + plane_len
    };
    for img in 0..batch {
        for group in 0..num_groups {
            let first = group * channels_per_group;
            let members = first..first + channels_per_group;

            let mut total = 0.0f32;
            for ch in members.clone() {
                total += input[plane(img, ch)].iter().sum::<f32>();
            }
            let mean = total / count;

            let sq_dev = members
                .clone()
                .flat_map(|ch| input[plane(img, ch)].iter())
                .fold(0.0f32, |acc, &v| {
                    let d = v - mean;
                    acc + d * d
                });
            let var = sq_dev / count;
            let inv = 1.0 / (var + eps).sqrt();

            for ch in members {
                let (scale, shift) = (gamma[ch], beta[ch]);
                let range = plane(img, ch);
                let src = &input[range.clone()];
                let dst = &mut output[range];
                for (d, &s) in dst.iter_mut().zip(src) {
                    *d = (s - mean) * inv * scale + shift;
                }
            }
        }
    }
}
1227
/// Nearest-neighbour 2x upsampling of `channels` planes of size `h x w` (CHW layout).
///
/// Output pixel `(oy, ox)` of each `2h x 2w` plane copies input pixel `(oy / 2, ox / 2)`.
pub fn resize_nearest_2x_nchw(
    input: &[f32],
    output: &mut [f32],
    channels: usize,
    h: usize,
    w: usize,
) {
    let (out_h, out_w) = (h * 2, w * 2);
    let in_plane = h * w;
    let out_plane = out_h * out_w;
    for ch in 0..channels {
        let src = &input[ch * in_plane..(ch + 1) * in_plane];
        let dst = &mut output[ch * out_plane..(ch + 1) * out_plane];
        for oy in 0..out_h {
            let src_row = (oy / 2) * w;
            for ox in 0..out_w {
                dst[oy * out_w + ox] = src[src_row + ox / 2];
            }
        }
    }
}
1253
// Unit tests for the activation and normalisation kernels. They run against whichever `cfg`
// variant (NEON, AVX2+FMA or portable) the test build selects.
#[cfg(test)]
mod tests {
    use super::*;

    // Checks the scalar erf-based GELU against a reference value (GELU(1.5) ~ 1.3990).
    #[test]
    fn gelu_correctness() {
        let x = 1.5f32;
        let g = scalar_gelu(x);
        assert!((g - 1.3990).abs() < 0.01, "gelu(1.5) = {g}");
    }

    // Smoke test for the fused bias + GELU on a 2x4 matrix. With positive inputs and bias every
    // output must stay positive.
    #[test]
    fn bias_gelu_works() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let bias = vec![0.1, 0.2, 0.3, 0.4];
        bias_gelu(&mut data, &bias, 2, 4);
        for &v in &data {
            assert!(v > 0.0, "bias_gelu produced {v}");
        }
    }

    // Runs the forward pass and all three backward helpers on a 2x4 batch. It checks finiteness,
    // a non-trivial gamma gradient, and that dbeta is the per-channel sum of dy (2 rows of 1.0).
    #[test]
    fn batch_norm_inference_roundtrip() {
        let c = 4usize;
        let x: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let gamma = vec![1.0; c];
        let beta = vec![0.0; c];
        let mean = vec![2.5, 2.5, 2.5, 2.5];
        let var = vec![1.0; c];
        let mut y = vec![0.0; 8];
        batch_norm_inference(&x, &gamma, &beta, &mean, &var, &mut y, c, 1e-5);
        let mut dx = vec![0.0; 8];
        let dy = vec![1.0; 8];
        let mut dgamma = vec![0.0; c];
        let mut dbeta = vec![0.0; c];
        batch_norm_inference_backward_input(&x, &gamma, &mean, &var, &dy, &mut dx, c, 1e-5);
        batch_norm_inference_backward_gamma(&x, &mean, &var, &dy, &mut dgamma, c, 1e-5);
        batch_norm_inference_backward_beta(&dy, &mut dbeta, c);
        assert!(y.iter().all(|v| v.is_finite()));
        assert!(dx.iter().all(|v| v.is_finite()));
        assert!(dgamma.iter().any(|&v| v.abs() > 1e-6));
        assert_eq!(dbeta, vec![2.0, 2.0, 2.0, 2.0]);
    }

    // LayerNorm of [1, 2, 3, 4]: mean 2.5, variance 1.25, so the end points map to
    // +/-1.5 / sqrt(1.25) ~ +/-1.342, and the normalised row sums to ~0.
    #[test]
    fn layer_norm_unit_test() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let gamma = vec![1.0; 4];
        let beta = vec![0.0; 4];
        let mut output = vec![0.0; 4];
        layer_norm_row(&input, &gamma, &beta, &mut output, 4, 1e-5);
        assert!((output[0] - -1.342).abs() < 0.01);
        assert!((output[3] - 1.342).abs() < 0.01);
        let sum: f32 = output.iter().sum();
        assert!(sum.abs() < 0.01, "LN sum should be ~0, got {sum}");
    }

    // The parallel path must agree with the sequential kernel. Note the naming: `n` is the row
    // count and `m` the column count here, matching bias_gelu's `(data, bias, rows, cols)` order.
    #[test]
    fn par_bias_gelu_matches_sequential() {
        let n = 100;
        let m = 64;
        let mut data_par = vec![0.5f32; n * m];
        let mut data_seq = data_par.clone();
        let bias = vec![0.1f32; m];

        bias_gelu(&mut data_seq, &bias, n, m);
        par_bias_gelu(&mut data_par, &bias, n, m);

        let max_diff: f32 = data_par
            .iter()
            .zip(data_seq.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max);
        assert!(max_diff < 1e-6, "par vs seq diff: {max_diff}");
    }
}