use super::Cpu;
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
pub struct CurrentCpu {}
impl CurrentCpu {
#[cfg(target_arch = "aarch64")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vaddvq_f32(x)
}
#[cfg(target_arch = "arm")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
}
}
impl Cpu for CurrentCpu {
type Unit = float32x4_t;
type Array = [float32x4_t; Self::ARR];
const STEP: usize = 32;
const EPR: usize = 4;
unsafe fn zero() -> Self::Unit {
vdupq_n_f32(0.0)
}
unsafe fn from_f32(x: f32) -> Self::Unit {
vdupq_n_f32(x)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); Self::ARR]
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
vld1q_f32(mem_addr)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
vaddq_f32(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
vfmaq_f32(a, b, c)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
vst1q_f32(mem_addr, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..Self::ARR / 2 {
x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
}
for i in 0..Self::ARR / 4 {
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
}
for i in 0..Self::ARR / 8 {
x[8 * i] = vaddq_f32(x[8 * i], x[8 * i + 4]);
}
*y = Self::reduce_one(x[0]);
}
}
mod fp16 {
use super::super::CpuF16;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
use half::f16;
use std::arch::asm;
#[cfg(not(target_feature = "fp16"))]
mod inner {
use super::*;
pub struct CurrentCpuF16 {}
impl CpuF16 for CurrentCpuF16 {
type Unit = float32x4_t;
type Array = [float32x4_t; Self::ARR];
const STEP: usize = 32;
const EPR: usize = 4;
unsafe fn zero() -> Self::Unit {
vdupq_n_f32(0.0)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); Self::ARR]
}
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
let bits = vld1_u16(mem_addr as *const u16);
let result: Self::Unit;
asm!(
"fcvtl {dst:v}.4s, {src:v}.4h",
src = in(vreg) bits,
dst = out(vreg) result,
options(pure, nomem, nostack),
);
result
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
vaddq_f32(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
vfmaq_f32(a, b, c)
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..Self::ARR / 2 {
x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
}
for i in 0..Self::ARR / 4 {
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
}
for i in 0..Self::ARR / 8 {
x[8 * i] = vaddq_f32(x[8 * i], x[8 * i + 4]);
}
*y = vaddvq_f32(x[0]);
}
unsafe fn from_f32(v: f32) -> Self::Unit {
vdupq_n_f32(v)
}
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
let bits: uint16x4_t;
asm!(
"fcvtn {dst:v}.4h, {src:v}.4s",
src = in(vreg) a,
dst = out(vreg) bits,
options(pure, nomem, nostack),
);
vst1_u16(mem_addr as *mut u16, bits);
}
}
}
#[cfg(target_feature = "fp16")]
mod inner {
use super::*;
pub struct CurrentCpuF16 {}
impl CurrentCpuF16 {
fn reduce_one(x: float16x8_t) -> f32 {
let result: f32;
unsafe {
asm!(
"fcvtl {bot:v}.4s, {v:v}.4h",
"fcvtl2 {top:v}.4s, {v:v}.8h",
"fadd {bot:v}.4s, {bot:v}.4s, {top:v}.4s",
"faddp {bot:v}.4s, {bot:v}.4s, {bot:v}.4s",
"faddp {bot:v}.4s, {bot:v}.4s, {bot:v}.4s",
"mov {res:s}, {bot:v}.s[0]",
v = in(vreg) x,
top = lateout(vreg) _,
bot = out(vreg) _,
res = out(vreg) result,
options(nostack, preserves_flags)
);
}
result
}
}
impl CpuF16 for CurrentCpuF16 {
type Unit = float16x8_t;
type Array = [float16x8_t; Self::ARR];
const STEP: usize = 32;
const EPR: usize = 8;
const FLUSH_INTERVAL: usize = 16;
unsafe fn zero() -> Self::Unit {
let result: Self::Unit;
asm!(
"movi {v:v}.8h, #0",
v = out(vreg) result,
options(nostack, pure, nomem)
);
result
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); Self::ARR]
}
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
std::ptr::read_unaligned(mem_addr.cast())
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
let result: Self::Unit;
asm!(
"fadd {v:v}.8h, {a:v}.8h, {b:v}.8h",
v = out(vreg) result,
a = in(vreg) a,
b = in(vreg) b,
options(pure, nomem, nostack)
);
result
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
let result: Self::Unit;
asm!(
"fmla {a:v}.8h, {b:v}.8h, {c:v}.8h",
a = inout(vreg) a => result,
b = in(vreg) b,
c = in(vreg) c,
options(pure, nomem, nostack)
);
result
}
unsafe fn vec_reduce(x: Self::Array, y: *mut f32) {
let mut sum = 0.0f32;
for item in x {
sum += Self::reduce_one(item);
}
*y = sum;
}
unsafe fn from_f32(v: f32) -> Self::Unit {
let result: Self::Unit;
asm!(
"fcvt {tmp:h}, {src:s}",
"fmov {w:w}, {tmp:h}",
"dup {dst:v}.8h, {w:w}",
src = in(vreg) v,
tmp = out(vreg) _,
w = out(reg) _,
dst = out(vreg) result,
options(nostack, pure, nomem),
);
result
}
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
asm!(
"st1 {{ {vec:v}.8h }}, [{ptr}]",
vec = in(vreg) a,
ptr = in(reg) mem_addr,
options(nostack, preserves_flags)
);
}
}
}
pub use inner::CurrentCpuF16;
}
pub use fp16::CurrentCpuF16;
mod bf16 {
use super::super::CpuBF16;
use core::arch::aarch64::*;
use half::bf16;
#[cfg(not(target_feature = "bf16"))]
mod inner {
use super::*;
pub struct CurrentCpuBF16 {}
impl CpuBF16 for CurrentCpuBF16 {
type Unit = float32x4_t;
type Array = [float32x4_t; Self::ARR];
const STEP: usize = 32;
const EPR: usize = 4;
unsafe fn zero() -> Self::Unit {
vdupq_n_f32(0.0)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); Self::ARR]
}
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
let i = vld1_u16(mem_addr as *const u16);
let s = vshll_n_u16::<16>(i);
vreinterpretq_f32_u32(s)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
vaddq_f32(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
vfmaq_f32(a, b, c)
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..Self::ARR / 2 {
x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
}
for i in 0..Self::ARR / 4 {
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
}
for i in 0..Self::ARR / 8 {
x[8 * i] = vaddq_f32(x[8 * i], x[8 * i + 4]);
}
*y = vaddvq_f32(x[0]);
}
unsafe fn from_f32(v: f32) -> Self::Unit {
vdupq_n_f32(v)
}
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
let s = vreinterpretq_u32_f32(a);
let h = vshrn_n_u32::<16>(s);
vst1_u16(mem_addr as *mut u16, h);
}
}
}
#[cfg(target_feature = "bf16")]
mod inner {
use super::*;
use std::arch::asm;
pub struct CurrentCpuBF16 {}
impl CpuBF16 for CurrentCpuBF16 {
type Unit = uint16x8_t;
type Array = [uint16x8_t; Self::ARR];
const STEP: usize = 32;
const EPR: usize = 8;
unsafe fn zero() -> Self::Unit {
vreinterpretq_u16_f32(vdupq_n_f32(0.0))
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); Self::ARR]
}
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
vld1q_u16(mem_addr as *const u16)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
let result: Self::Unit;
asm!(
"fadd {v:v}.4s, {a:v}.4s, {b:v}.4s",
v = out(vreg) result,
a = in(vreg) a,
b = in(vreg) b,
options(pure, nomem, nostack)
);
result
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
let result: Self::Unit;
asm!(
"bfdot {acc:v}.4s, {b:v}.8h, {c:v}.8h",
acc = inout(vreg) a => result,
b = in(vreg) b,
c = in(vreg) c,
options(pure, nomem, nostack)
);
result
}
unsafe fn vec_reduce(x: Self::Array, y: *mut f32) {
let mut xf: [float32x4_t; Self::ARR] = std::mem::transmute(x);
for i in 0..Self::ARR / 2 {
xf[2 * i] = vaddq_f32(xf[2 * i], xf[2 * i + 1]);
}
for i in 0..Self::ARR / 4 {
xf[4 * i] = vaddq_f32(xf[4 * i], xf[4 * i + 2]);
}
*y = vaddvq_f32(xf[0]);
}
unsafe fn from_f32(v: f32) -> Self::Unit {
vreinterpretq_u16_f32(vdupq_n_f32(v))
}
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
vst1q_u16(mem_addr as *mut u16, a);
}
}
}
pub use inner::CurrentCpuBF16;
}
pub use bf16::CurrentCpuBF16;