Skip to main content

trueno/blis/microkernels/
neon.rs

//! ARM NEON Microkernel
//!
//! Contains the NEON SIMD microkernel for ARM64 (aarch64) targets.

/// NEON microkernel (8x8 output tile)
///
/// Accumulates `k` rank-1 updates into an 8x8 tile of `C`, in place:
///
/// ```text
/// c[j * ldc + i] += sum(a[p * 8 + i] * b[p * 8 + j] for p in 0..k)    for i, j in 0..8
/// ```
///
/// The tile is read into sixteen 4-lane accumulators, updated with fused
/// multiply-adds, and written back once after the loop.
///
/// # Arguments
///
/// * `k`   - Number of rank-1 updates. With `k == 0` the tile is loaded and
///           stored back unchanged.
/// * `a`   - Packed panel holding 8 consecutive `f32` per step `p`
///           (`a[p * 8 .. p * 8 + 8]`).
/// * `b`   - Packed panel holding 8 consecutive `f32` per step `p`
///           (`b[p * 8 .. p * 8 + 8]`); element `j` is broadcast across a vector.
/// * `c`   - Output tile; column `j` is the 8 consecutive `f32` starting at
///           `c + j * ldc`.
/// * `ldc` - Distance, in `f32` elements, between the starts of consecutive
///           columns of `c`.
///
/// # Safety
///
/// The caller must guarantee that:
///
/// * `a` is valid for reads of `8 * k` `f32` values;
/// * `b` is valid for reads of `8 * k` `f32` values;
/// * `c` is valid for reads and writes of `7 * ldc + 8` `f32` values;
/// * `ldc >= 8`, so the eight stored columns do not overlap (checked only by a
///   `debug_assert!`);
/// * all three pointers are properly aligned for `f32`, and the memory behind
///   `c` is not accessed through any other pointer or reference for the
///   duration of the call.
#[cfg(target_arch = "aarch64")]
// SAFETY: Caller ensures NEON is available (always on aarch64) and pointers/dims are valid
pub unsafe fn microkernel_8x8_neon(
    k: usize,
    a: *const f32,
    b: *const f32,
    c: *mut f32,
    ldc: usize,
) {
    use std::arch::aarch64::*;

    // With a column stride below 8 the stores at the end would overwrite each
    // other and silently corrupt the tile.
    debug_assert!(
        ldc >= 8,
        "ldc must be at least 8 for an 8x8 tile, got {}",
        ldc
    );

    // SAFETY: the caller upholds the contract in the `# Safety` section above,
    // so every pointer formed below stays inside the `a`, `b` and `c` buffers.
    unsafe {
        // Base pointer of each of the 8 columns of the tile.
        let col0 = c;
        let col1 = c.add(ldc);
        let col2 = c.add(2 * ldc);
        let col3 = c.add(3 * ldc);
        let col4 = c.add(4 * ldc);
        let col5 = c.add(5 * ldc);
        let col6 = c.add(6 * ldc);
        let col7 = c.add(7 * ldc);

        // Load C into registers (8 columns, split into 2x float32x4):
        // `cJ0` holds elements 0..4 of column J, `cJ1` holds elements 4..8.
        let mut c00 = vld1q_f32(col0);
        let mut c01 = vld1q_f32(col0.add(4));
        let mut c10 = vld1q_f32(col1);
        let mut c11 = vld1q_f32(col1.add(4));
        let mut c20 = vld1q_f32(col2);
        let mut c21 = vld1q_f32(col2.add(4));
        let mut c30 = vld1q_f32(col3);
        let mut c31 = vld1q_f32(col3.add(4));
        let mut c40 = vld1q_f32(col4);
        let mut c41 = vld1q_f32(col4.add(4));
        let mut c50 = vld1q_f32(col5);
        let mut c51 = vld1q_f32(col5.add(4));
        let mut c60 = vld1q_f32(col6);
        let mut c61 = vld1q_f32(col6.add(4));
        let mut c70 = vld1q_f32(col7);
        let mut c71 = vld1q_f32(col7.add(4));

        for p in 0..k {
            // Both packed panels advance by 8 floats per step.
            let a_step = a.add(p * 8);
            let b_step = b.add(p * 8);

            // The 8 values of A for this step, as two 4-lane vectors.
            let a0 = vld1q_f32(a_step);
            let a1 = vld1q_f32(a_step.add(4));

            // Each of the 8 values of B for this step, broadcast to all lanes.
            let b0 = vld1q_dup_f32(b_step);
            let b1 = vld1q_dup_f32(b_step.add(1));
            let b2 = vld1q_dup_f32(b_step.add(2));
            let b3 = vld1q_dup_f32(b_step.add(3));
            let b4 = vld1q_dup_f32(b_step.add(4));
            let b5 = vld1q_dup_f32(b_step.add(5));
            let b6 = vld1q_dup_f32(b_step.add(6));
            let b7 = vld1q_dup_f32(b_step.add(7));

            // Rank-1 update: column J += A-vector * B[J].
            c00 = vfmaq_f32(c00, a0, b0);
            c01 = vfmaq_f32(c01, a1, b0);
            c10 = vfmaq_f32(c10, a0, b1);
            c11 = vfmaq_f32(c11, a1, b1);
            c20 = vfmaq_f32(c20, a0, b2);
            c21 = vfmaq_f32(c21, a1, b2);
            c30 = vfmaq_f32(c30, a0, b3);
            c31 = vfmaq_f32(c31, a1, b3);
            c40 = vfmaq_f32(c40, a0, b4);
            c41 = vfmaq_f32(c41, a1, b4);
            c50 = vfmaq_f32(c50, a0, b5);
            c51 = vfmaq_f32(c51, a1, b5);
            c60 = vfmaq_f32(c60, a0, b6);
            c61 = vfmaq_f32(c61, a1, b6);
            c70 = vfmaq_f32(c70, a0, b7);
            c71 = vfmaq_f32(c71, a1, b7);
        }

        // Write the accumulated tile back to C.
        vst1q_f32(col0, c00);
        vst1q_f32(col0.add(4), c01);
        vst1q_f32(col1, c10);
        vst1q_f32(col1.add(4), c11);
        vst1q_f32(col2, c20);
        vst1q_f32(col2.add(4), c21);
        vst1q_f32(col3, c30);
        vst1q_f32(col3.add(4), c31);
        vst1q_f32(col4, c40);
        vst1q_f32(col4.add(4), c41);
        vst1q_f32(col5, c50);
        vst1q_f32(col5.add(4), c51);
        vst1q_f32(col6, c60);
        vst1q_f32(col6.add(4), c61);
        vst1q_f32(col7, c70);
        vst1q_f32(col7.add(4), c71);
    }
}