// trueno/blis/microkernels/neon.rs

/// 8x8 single-precision rank-`k` update using AArch64 NEON intrinsics.
///
/// With `i` and `j` both in `0..8`, performs
///
/// ```text
/// c[j * ldc + i] += sum over p in 0..k of a[p * 8 + i] * b[p * 8 + j]
/// ```
///
/// Each multiply-add is issued through `vfmaq_f32` (fused multiply-add), and the
/// terms are accumulated in increasing `p` order starting from the value already
/// stored in `c`.
///
/// # Parameters
///
/// * `k`   - number of 8-element slices consumed from both `a` and `b`.
/// * `a`   - panel read as 8 consecutive `f32` per step `p` (index `p * 8 + i`).
/// * `b`   - panel read as 8 consecutive `f32` per step `p` (index `p * 8 + j`);
///           every element is broadcast across a vector before use.
/// * `c`   - output tile: 8 groups of 8 contiguous `f32`, group `j` starting at
///           offset `j * ldc`. The tile is read, updated and written back.
/// * `ldc` - distance, in elements, between consecutive groups of `c`.
///
/// When `k == 0` the tile is loaded and stored back unchanged.
///
/// # Safety
///
/// The caller must guarantee that:
///
/// * `a` is valid for reads of `8 * k` `f32` values;
/// * `b` is valid for reads of `8 * k` `f32` values;
/// * `c` is valid for reads and writes of `7 * ldc + 8` `f32` values;
/// * the memory written through `c` is not accessed concurrently.
///
/// All 16 tile vectors are loaded before any is stored, so with `ldc < 8` the
/// 8-element groups overlap and later stores overwrite parts of earlier ones.
/// Callers presumably always pass `ldc >= 8` (TODO confirm against call sites).
#[cfg(target_arch = "aarch64")]
pub unsafe fn microkernel_8x8_neon(
    k: usize,
    a: *const f32,
    b: *const f32,
    c: *mut f32,
    ldc: usize,
) {
    use std::arch::aarch64::*;

    // Load the 8x8 tile into 16 vector accumulators: `cJ0` holds elements 0..4
    // and `cJ1` holds elements 4..8 of group `J` (offset `J * ldc`).
    let mut c00 = vld1q_f32(c);
    let mut c01 = vld1q_f32(c.add(4));
    let mut c10 = vld1q_f32(c.add(ldc));
    let mut c11 = vld1q_f32(c.add(ldc + 4));
    let mut c20 = vld1q_f32(c.add(2 * ldc));
    let mut c21 = vld1q_f32(c.add(2 * ldc + 4));
    let mut c30 = vld1q_f32(c.add(3 * ldc));
    let mut c31 = vld1q_f32(c.add(3 * ldc + 4));
    let mut c40 = vld1q_f32(c.add(4 * ldc));
    let mut c41 = vld1q_f32(c.add(4 * ldc + 4));
    let mut c50 = vld1q_f32(c.add(5 * ldc));
    let mut c51 = vld1q_f32(c.add(5 * ldc + 4));
    let mut c60 = vld1q_f32(c.add(6 * ldc));
    let mut c61 = vld1q_f32(c.add(6 * ldc + 4));
    let mut c70 = vld1q_f32(c.add(7 * ldc));
    let mut c71 = vld1q_f32(c.add(7 * ldc + 4));

    for p in 0..k {
        // Eight consecutive values of `a` for this step, split into two vectors.
        let a0 = vld1q_f32(a.add(p * 8));
        let a1 = vld1q_f32(a.add(p * 8 + 4));

        // Each of the eight `b` values for this step, broadcast to all lanes.
        let b0 = vld1q_dup_f32(b.add(p * 8));
        let b1 = vld1q_dup_f32(b.add(p * 8 + 1));
        let b2 = vld1q_dup_f32(b.add(p * 8 + 2));
        let b3 = vld1q_dup_f32(b.add(p * 8 + 3));
        let b4 = vld1q_dup_f32(b.add(p * 8 + 4));
        let b5 = vld1q_dup_f32(b.add(p * 8 + 5));
        let b6 = vld1q_dup_f32(b.add(p * 8 + 6));
        let b7 = vld1q_dup_f32(b.add(p * 8 + 7));

        // Outer-product update: group `j` accumulates `a * b[j]`.
        c00 = vfmaq_f32(c00, a0, b0);
        c01 = vfmaq_f32(c01, a1, b0);
        c10 = vfmaq_f32(c10, a0, b1);
        c11 = vfmaq_f32(c11, a1, b1);
        c20 = vfmaq_f32(c20, a0, b2);
        c21 = vfmaq_f32(c21, a1, b2);
        c30 = vfmaq_f32(c30, a0, b3);
        c31 = vfmaq_f32(c31, a1, b3);
        c40 = vfmaq_f32(c40, a0, b4);
        c41 = vfmaq_f32(c41, a1, b4);
        c50 = vfmaq_f32(c50, a0, b5);
        c51 = vfmaq_f32(c51, a1, b5);
        c60 = vfmaq_f32(c60, a0, b6);
        c61 = vfmaq_f32(c61, a1, b6);
        c70 = vfmaq_f32(c70, a0, b7);
        c71 = vfmaq_f32(c71, a1, b7);
    }

    // Write the accumulators back to the same locations they were loaded from.
    vst1q_f32(c, c00);
    vst1q_f32(c.add(4), c01);
    vst1q_f32(c.add(ldc), c10);
    vst1q_f32(c.add(ldc + 4), c11);
    vst1q_f32(c.add(2 * ldc), c20);
    vst1q_f32(c.add(2 * ldc + 4), c21);
    vst1q_f32(c.add(3 * ldc), c30);
    vst1q_f32(c.add(3 * ldc + 4), c31);
    vst1q_f32(c.add(4 * ldc), c40);
    vst1q_f32(c.add(4 * ldc + 4), c41);
    vst1q_f32(c.add(5 * ldc), c50);
    vst1q_f32(c.add(5 * ldc + 4), c51);
    vst1q_f32(c.add(6 * ldc), c60);
    vst1q_f32(c.add(6 * ldc + 4), c61);
    vst1q_f32(c.add(7 * ldc), c70);
    vst1q_f32(c.add(7 * ldc + 4), c71);
}