use crate::Ops;
#[cfg(tract_sve)]
use crate::frame::mmm::ImplementationQuality::ManuallyOptimized;
#[cfg(tract_sve)]
use crate::mmm::*;
#[cfg(tract_sve)]
use crate::pack::PackedFormat;
#[cfg(tract_sve)]
use tract_data::prelude::f16;
/// Fusion filter used by the f32 and f16 SVE kernels declared in this file.
///
/// Yields `false` for `LeakyRelu`, `QScale`, `RoundingShiftRight` and `ShiftLeft`,
/// and `true` for every other `FusedSpec` variant.
#[cfg(tract_sve)]
const CAN_FUSE: fn(&FusedSpec) -> bool = |spec| {
    let rejected = matches!(
        spec,
        FusedSpec::LeakyRelu(_)
            | FusedSpec::QScale(_, _, _)
            | FusedSpec::RoundingShiftRight(_, _)
            | FusedSpec::ShiftLeft(_)
    );
    !rejected
};
/// Fusion filter used by the i32 SVE kernels declared in this file.
///
/// Only `LeakyRelu` is rejected; every other `FusedSpec` variant yields `true`.
#[cfg(tract_sve)]
const CAN_FUSE_I32: fn(&FusedSpec) -> bool = |spec| {
    let is_leaky_relu = matches!(spec, FusedSpec::LeakyRelu(_));
    !is_leaky_relu
};
/// Predicate handed to the `where(...)` clause of the f32 and i32 kernels below.
///
/// It is simply `has_sve2` stored as a function pointer, so it also honours the
/// `TRACT_SVE_DISABLE` override on Linux.
#[cfg(tract_sve)]
const SVE2: fn() -> bool = has_sve2;
/// Predicate handed to the `where(...)` clause of the f16 kernels below.
///
/// True only when `has_sve2()` and `crate::arm64::has_fp16()` both hold; the fp16
/// probe is not evaluated when SVE2 is missing.
#[cfg(tract_sve)]
const SVE2_FP16: fn() -> bool = || {
    let sve2 = has_sve2();
    sve2 && crate::arm64::has_fp16()
};
/// Raw C-ABI declarations of the SVE kernels, which are implemented outside this file.
///
/// Every entry point takes a pointer to `FusedKerSpec<T>` for its element type and
/// returns an `isize` (presumably a status code -- confirm against the kernel
/// sources). Calling any of them is `unsafe`.
#[cfg(tract_sve)]
mod sve_sys {
    use crate::frame::mmm::FusedKerSpec;
    use tract_data::prelude::f16;

    unsafe extern "C" {
        // f32 kernels.
        pub fn sve_mmm_f32_kernel(spec: *const FusedKerSpec<f32>) -> isize;
        pub fn sve_mmv_f32_64x1_kernel(spec: *const FusedKerSpec<f32>) -> isize;

        // f16 kernels.
        pub fn sve_mmm_f16_kernel(spec: *const FusedKerSpec<f16>) -> isize;
        pub fn sve_mmv_f16_64x1_kernel(spec: *const FusedKerSpec<f16>) -> isize;

        // i32 kernels. Note that the 64x1 symbol is spelled `mmm`, not `mmv`.
        pub fn sve_mmm_i32_kernel(spec: *const FusedKerSpec<i32>) -> isize;
        pub fn sve_mmm_i32_64x1_kernel(spec: *const FusedKerSpec<i32>) -> isize;
    }
}
// f32 kernel `sve_mmm_f32_8x8`, backed by the extern symbol
// `sve_sys::sve_mmm_f32_kernel`. The `(8, 8)` pair is presumably the tile size of
// the kernel -- confirm against the `MMMRustKernel!` definition.
// It is gated by `SVE2`, filtered by `CAN_FUSE`, and tagged `ManuallyOptimized`.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_f32_kernel => sve_mmm_f32_8x8<f32>(8, 8)
where(SVE2)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
// f32 kernel `sve_mmv_f32_64x1`, backed by the extern symbol
// `sve_sys::sve_mmv_f32_64x1_kernel`. The `(64, 1)` pair is presumably the tile
// size, i.e. a matrix-vector shape -- confirm against `MMMRustKernel!`.
// It is gated by `SVE2`, filtered by `CAN_FUSE`, and tagged `ManuallyOptimized`.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmv_f32_64x1_kernel => sve_mmv_f32_64x1<f32>(64, 1)
where(SVE2)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
// i32 kernel `sve_mmm_i32_8x8`, backed by the extern symbol
// `sve_sys::sve_mmm_i32_kernel`, gated by `SVE2` and filtered by `CAN_FUSE_I32`.
// On top of the macro's default setup it declares a packing at index 1 named
// `i8i8`, built with `PackedFormat::new(DatumType::I8, 8, 16)` for both operands,
// and adds `store(i8)`. Presumably this lets the kernel take i8 inputs and write
// i8 outputs -- confirm against the `MMMRustKernel!` definition.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_i32_kernel => sve_mmm_i32_8x8<i32>(8, 8)
where(SVE2)
can_fuse(CAN_FUSE_I32)
packing[1] = i8i8 => |k| k.with_packing(
PackedFormat::new(DatumType::I8, 8, 16),
PackedFormat::new(DatumType::I8, 8, 16),
);
quality(ManuallyOptimized)
store(i8)
);
// i32 kernel `sve_mmm_i32_64x1`, backed by the extern symbol
// `sve_sys::sve_mmm_i32_64x1_kernel`, gated by `SVE2` and filtered by
// `CAN_FUSE_I32`. `plug` installs it as `ops.qmmv_i32`.
// Its `i8i8` packing (index 1) is asymmetric: the first operand uses
// `PackedFormat::new(DatumType::I8, 64, 16)`, the second
// `PackedFormat::new(DatumType::I8, 1, 1)`. `store(i8)` is declared as for the
// 8x8 variant. The meaning of the `PackedFormat::new` arguments is not visible
// here -- confirm against its definition.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_i32_64x1_kernel => sve_mmm_i32_64x1<i32>(64, 1)
where(SVE2)
can_fuse(CAN_FUSE_I32)
packing[1] = i8i8 => |k| k.with_packing(
PackedFormat::new(DatumType::I8, 64, 16),
PackedFormat::new(DatumType::I8, 1, 1),
);
quality(ManuallyOptimized)
store(i8)
);
// f16 kernel `sve_mmm_f16_8x8`, backed by the extern symbol
// `sve_sys::sve_mmm_f16_kernel`. Unlike the f32 kernels it is gated by
// `SVE2_FP16`, which requires both SVE2 and `crate::arm64::has_fp16()`.
// It is filtered by `CAN_FUSE` and tagged `ManuallyOptimized`.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_f16_kernel => sve_mmm_f16_8x8<f16>(8, 8)
where(SVE2_FP16)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
// f16 kernel `sve_mmv_f16_64x1`, backed by the extern symbol
// `sve_sys::sve_mmv_f16_64x1_kernel`. It is gated by `SVE2_FP16`, which requires
// both SVE2 and `crate::arm64::has_fp16()`.
// It is filtered by `CAN_FUSE` and tagged `ManuallyOptimized`.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmv_f16_64x1_kernel => sve_mmv_f16_64x1<f16>(64, 1)
where(SVE2_FP16)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
/// Reports whether the CPU advertises SVE (v1).
///
/// Returns `false` whenever the `TRACT_SVE_DISABLE` environment variable is set, to
/// any value. Otherwise it tests bit 22 (`HWCAP_SVE`) of the `AT_HWCAP` entry
/// returned by `getauxval`.
#[cfg(target_os = "linux")]
pub fn has_sve() -> bool {
    unsafe extern "C" {
        // NOTE(review): declared with `u64`, which matches C's `unsigned long` only
        // on LP64 targets.
        fn getauxval(t: u64) -> u64;
    }
    const AT_HWCAP: u64 = 16;
    const HWCAP_SVE: u64 = 1 << 22;

    // `&&` short-circuits: the auxiliary vector is only read when no override is set.
    std::env::var_os("TRACT_SVE_DISABLE").is_none() && {
        // SAFETY: `getauxval` receives an integer key by value and returns an
        // integer; no pointers or other caller-side invariants are involved.
        let hwcap = unsafe { getauxval(AT_HWCAP) };
        (hwcap & HWCAP_SVE) != 0
    }
}

/// SVE detection is only implemented for Linux; other operating systems report `false`.
#[cfg(not(target_os = "linux"))]
pub fn has_sve() -> bool {
    false
}
/// Reports whether the CPU advertises SVE2.
///
/// Returns `false` whenever the `TRACT_SVE_DISABLE` environment variable is set, to
/// any value. Otherwise it tests bit 1 (`HWCAP2_SVE2`) of the `AT_HWCAP2` entry
/// returned by `getauxval`.
#[cfg(target_os = "linux")]
pub fn has_sve2() -> bool {
    unsafe extern "C" {
        // NOTE(review): declared with `u64`, which matches C's `unsigned long` only
        // on LP64 targets.
        fn getauxval(t: u64) -> u64;
    }
    const AT_HWCAP2: u64 = 26;
    const HWCAP2_SVE2: u64 = 1 << 1;

    // `&&` short-circuits: the auxiliary vector is only read when no override is set.
    std::env::var_os("TRACT_SVE_DISABLE").is_none() && {
        // SAFETY: `getauxval` receives an integer key by value and returns an
        // integer; no pointers or other caller-side invariants are involved.
        let hwcap2 = unsafe { getauxval(AT_HWCAP2) };
        (hwcap2 & HWCAP2_SVE2) != 0
    }
}

/// SVE2 detection is only implemented for Linux; other operating systems report `false`.
#[cfg(not(target_os = "linux"))]
pub fn has_sve2() -> bool {
    false
}
/// Returns the SVE vector length in bytes, as read from the CPU.
///
/// The function performs no availability check of its own. Within this file it is
/// only called from `plug`, after `has_sve2()` has returned `true`. Calling it on a
/// CPU without SVE would presumably raise an undefined-instruction fault -- callers
/// should check `has_sve()` / `has_sve2()` first.
#[cfg(target_os = "linux")]
// May be unreferenced in some build configurations; its only caller in this file is `plug`.
#[allow(dead_code)]
pub fn rdvl_bytes() -> u64 {
let vl: u64;
unsafe {
// The instruction is given as a raw 32-bit word rather than a mnemonic
// (presumably so the build does not need an SVE-aware assembler).
// 0x04bf5020 should be the encoding of `rdvl x0, #1` -- confirm against the Arm
// ARM. The result is pinned to `x0` because that register is baked into the word.
// The options declare that the instruction touches no memory, no stack and no flags.
std::arch::asm!(
".inst 0x04bf5020", out("x0") vl,
options(nomem, nostack, preserves_flags),
);
}
vl
}
/// Registers the SVE2 kernels in `ops` when [`has_sve2`] reports support.
///
/// * With SVE2: logs the detected vector length (Linux only) and, in builds compiled
///   with the `tract_sve` cfg, installs the f32 and i32 kernels. The f16 kernels are
///   installed as well when `crate::arm64::has_fp16()` is true.
/// * With SVE but not SVE2, or with neither: only logs the situation and leaves
///   `ops` untouched.
pub fn plug(ops: &mut Ops) {
    // Touch `ops` so the parameter is not reported as unused when the `tract_sve`
    // call below is compiled out.
    let _ = ops;

    if has_sve2() {
        #[cfg(target_os = "linux")]
        log::info!("SVE2 optimisation available (VL = {} bytes)", rdvl_bytes());
        #[cfg(tract_sve)]
        install_sve2_kernels(ops);
    } else if has_sve() {
        log::info!("SVE (v1) present; SVE2 kernels not enabled");
    } else {
        log::info!("No SVE optimisation");
    }
}

/// Points the kernel selectors of `ops` at the SVE2 implementations and appends
/// those implementations to `ops.mmm_impls`.
///
/// The selector closures ignore their arguments and always return the same kernel.
#[cfg(tract_sve)]
fn install_sve2_kernels(ops: &mut Ops) {
    // f32 and i32 kernels only need SVE2, which the caller has already checked.
    ops.mmm_f32 = Box::new(|_, _, _| sve_mmm_f32_8x8.mmm());
    ops.mmv_f32 = Box::new(|_, _| sve_mmv_f32_64x1.mmm());
    ops.qmmm_i32 = Box::new(|_, _, _| sve_mmm_i32_8x8.mmm());
    ops.qmmv_i32 = Box::new(|_, _| sve_mmm_i32_64x1.mmm());
    ops.mmm_impls.extend_from_slice(&[
        sve_mmm_f32_8x8.mmm(),
        sve_mmv_f32_64x1.mmm(),
        sve_mmm_i32_8x8.mmm(),
        sve_mmm_i32_64x1.mmm(),
    ]);

    // f16 kernels additionally require fp16 support.
    if crate::arm64::has_fp16() {
        ops.mmm_f16 = Box::new(|_, _, _| sve_mmm_f16_8x8.mmm());
        ops.mmv_f16 = Box::new(|_, _| sve_mmv_f16_64x1.mmm());
        ops.mmm_impls.extend_from_slice(&[sve_mmm_f16_8x8.mmm(), sve_mmv_f16_64x1.mmm()]);
    }
}