use crate::Ops;
#[cfg(tract_sve)]
use crate::frame::mmm::ImplementationQuality::ManuallyOptimized;
#[cfg(tract_sve)]
use crate::mmm::*;
#[cfg(tract_sve)]
use crate::pack::PackedFormat;
#[cfg(tract_sve)]
use tract_data::prelude::f16;
/// Fusion predicate used by the f32 and f16 kernels declared in this file.
///
/// Returns `false` for `LeakyRelu`, `QScale`, `RoundingShiftRight` and
/// `ShiftLeft`; every other `FusedSpec` is accepted.
#[cfg(tract_sve)]
const CAN_FUSE: fn(&FusedSpec) -> bool = |spec| match spec {
    FusedSpec::LeakyRelu(_)
    | FusedSpec::QScale(_, _, _)
    | FusedSpec::RoundingShiftRight(_, _)
    | FusedSpec::ShiftLeft(_) => false,
    _ => true,
};
/// Fusion predicate used by the i32 kernels declared in this file: only
/// `LeakyRelu` is rejected.
#[cfg(tract_sve)]
const CAN_FUSE_I32: fn(&FusedSpec) -> bool = |spec| match spec {
    FusedSpec::LeakyRelu(_) => false,
    _ => true,
};
/// Runtime gate for the kernels that only require SVE2.
#[cfg(tract_sve)]
const SVE2: fn() -> bool = has_sve2;
/// Runtime gate for the f16 kernels: SVE2 must be detected and
/// `crate::arm64::has_fp16()` must hold. The fp16 check is only evaluated when
/// SVE2 is present.
#[cfg(tract_sve)]
const SVE2_FP16: fn() -> bool = || {
    if !has_sve2() {
        return false;
    }
    crate::arm64::has_fp16()
};
/// Raw declarations of the SVE kernels. They live in an `extern "C"` block,
/// i.e. they are implemented outside this Rust file and called through the C ABI.
#[cfg(tract_sve)]
mod sve_sys {
use crate::frame::mmm::FusedKerSpec;
use tract_data::prelude::f16;
unsafe extern "C" {
// Matrix-multiplication kernels. Each receives a `*const FusedKerSpec<T>` named
// `ops`, with `T` matching the element type in the symbol name.
// NOTE(review): the `isize` result is presumably a status code consumed by the
// `MMMRustKernel!` wrappers -- confirm against the macro definition.
pub fn sve_mmm_f32_kernel(ops: *const FusedKerSpec<f32>) -> isize;
pub fn sve_mmv_f32_64x1_kernel(ops: *const FusedKerSpec<f32>) -> isize;
pub fn sve_mmm_i32_kernel(ops: *const FusedKerSpec<i32>) -> isize;
pub fn sve_mmm_i32_64x1_kernel(ops: *const FusedKerSpec<i32>) -> isize;
pub fn sve_mmm_f16_kernel(ops: *const FusedKerSpec<f16>) -> isize;
pub fn sve_mmv_f16_64x1_kernel(ops: *const FusedKerSpec<f16>) -> isize;
// RMS-norm kernel. The safe wrapper `sve_rms_norm_f32` passes a mutable slice's
// pointer as `buf` and its length as `n`; it returns nothing.
pub fn sve_rms_norm_f32_kernel(buf: *mut f32, n: i64, eps: f32);
}
}
/// Runs the external SVE RMS-norm kernel over `buf`, which it receives as a
/// mutable pointer plus element count.
///
/// `eps` is forwarded unchanged. An empty slice returns immediately without
/// calling the kernel.
///
/// NOTE(review): nothing here checks for SVE support; the callers visible in
/// this file (`plug` and the tests) check `has_sve2()` first -- confirm that
/// any other caller does the same.
#[cfg(tract_sve)]
pub fn sve_rms_norm_f32(buf: &mut [f32], eps: f32) {
    let len = buf.len();
    if len > 0 {
        // SAFETY: the pointer and the element count both come from the same
        // live, exclusively borrowed slice.
        unsafe { sve_sys::sve_rms_norm_f32_kernel(buf.as_mut_ptr(), len as i64, eps) }
    }
}
// f32 kernel: exposes the external symbol `sve_mmm_f32_kernel` under the name
// `sve_mmm_f32_8x8`, gated by the `SVE2` predicate and limited to the fused ops
// accepted by `CAN_FUSE`. `plug` installs it as the `mmm_f32` choice.
// NOTE(review): `(8, 8)` is presumably the tile size (rows, columns), as the
// kernel name suggests -- confirm against the `MMMRustKernel!` definition.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_f32_kernel => sve_mmm_f32_8x8<f32>(8, 8)
where(SVE2)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
// f32 kernel declared with `(64, 1)`: exposes `sve_mmv_f32_64x1_kernel` as
// `sve_mmv_f32_64x1`, gated by `SVE2` and limited to the fused ops accepted by
// `CAN_FUSE`. `plug` installs it as the `mmv_f32` choice.
// NOTE(review): `(64, 1)` is presumably a 64-row, single-column tile
// (matrix-vector product) -- confirm against the `MMMRustKernel!` definition.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmv_f32_64x1_kernel => sve_mmv_f32_64x1<f32>(64, 1)
where(SVE2)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
// i32 kernel: exposes `sve_mmm_i32_kernel` as `sve_mmm_i32_8x8`, gated by `SVE2`
// and limited to the fused ops accepted by `CAN_FUSE_I32`. `plug` installs it
// as the `qmmm_i32` choice.
// The `packing[1] = i8i8` clause adds a packing in which both operands use
// `PackedFormat::new(DatumType::I8, 8, 16)`.
// NOTE(review): `store(i8)` presumably allows results to be stored as i8, and
// the `PackedFormat::new` arguments are presumably (type, panel width,
// alignment) -- confirm against the macro and `PackedFormat` definitions.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_i32_kernel => sve_mmm_i32_8x8<i32>(8, 8)
where(SVE2)
can_fuse(CAN_FUSE_I32)
packing[1] = i8i8 => |k| k.with_packing(
PackedFormat::new(DatumType::I8, 8, 16),
PackedFormat::new(DatumType::I8, 8, 16),
);
quality(ManuallyOptimized)
store(i8)
);
// i32 kernel declared with `(64, 1)`: exposes `sve_mmm_i32_64x1_kernel` as
// `sve_mmm_i32_64x1`, gated by `SVE2` and limited to the fused ops accepted by
// `CAN_FUSE_I32`. `plug` installs it as the `qmmv_i32` choice.
// The `packing[1] = i8i8` clause adds a packing whose first operand uses
// `PackedFormat::new(DatumType::I8, 64, 16)` and whose second uses
// `PackedFormat::new(DatumType::I8, 1, 1)`.
// NOTE(review): unlike the f32/f16 64x1 kernels, the symbol is spelled `mmm`
// rather than `mmv`; `store(i8)` semantics are defined by the macro -- confirm.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_i32_64x1_kernel => sve_mmm_i32_64x1<i32>(64, 1)
where(SVE2)
can_fuse(CAN_FUSE_I32)
packing[1] = i8i8 => |k| k.with_packing(
PackedFormat::new(DatumType::I8, 64, 16),
PackedFormat::new(DatumType::I8, 1, 1),
);
quality(ManuallyOptimized)
store(i8)
);
// f16 kernel: exposes `sve_mmm_f16_kernel` as `sve_mmm_f16_8x8`, limited to the
// fused ops accepted by `CAN_FUSE`. It is gated by `SVE2_FP16`, i.e. it needs
// both SVE2 and the arm64 fp16 check. `plug` installs it as the `mmm_f16`
// choice when fp16 is available.
// NOTE(review): `(8, 8)` is presumably the tile size -- confirm against the
// `MMMRustKernel!` definition.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmm_f16_kernel => sve_mmm_f16_8x8<f16>(8, 8)
where(SVE2_FP16)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
// f16 kernel declared with `(64, 1)`: exposes `sve_mmv_f16_64x1_kernel` as
// `sve_mmv_f16_64x1`, gated by `SVE2_FP16` and limited to the fused ops
// accepted by `CAN_FUSE`. `plug` installs it as the `mmv_f16` choice when fp16
// is available.
// NOTE(review): `(64, 1)` is presumably a 64-row, single-column tile -- confirm
// against the `MMMRustKernel!` definition.
#[cfg(tract_sve)]
MMMRustKernel!(sve_sys::sve_mmv_f16_64x1_kernel => sve_mmv_f16_64x1<f16>(64, 1)
where(SVE2_FP16)
can_fuse(CAN_FUSE)
quality(ManuallyOptimized)
);
/// Reports whether the running CPU advertises SVE.
///
/// On Linux this reads the `AT_HWCAP` entry of the auxiliary vector through
/// `getauxval` and tests the SVE bit. Setting the `TRACT_SVE_DISABLE`
/// environment variable (to any value, even empty) forces the answer to
/// `false` without querying the hardware.
#[cfg(target_os = "linux")]
pub fn has_sve() -> bool {
    use std::ffi::c_ulong;

    unsafe extern "C" {
        // The C prototype is `unsigned long getauxval(unsigned long)`, so the
        // declaration uses `c_ulong` rather than a fixed-width integer.
        fn getauxval(kind: c_ulong) -> c_ulong;
    }
    const AT_HWCAP: c_ulong = 16;
    const HWCAP_SVE: c_ulong = 1 << 22;

    if std::env::var_os("TRACT_SVE_DISABLE").is_some() {
        return false;
    }
    // SAFETY: `getauxval` takes a plain integer and only reads the process'
    // auxiliary vector; it has no pointer arguments or other preconditions.
    unsafe { (getauxval(AT_HWCAP) & HWCAP_SVE) != 0 }
}

/// Reports whether the running CPU advertises SVE.
///
/// Detection is only implemented for Linux; on every other OS this always
/// returns `false`.
#[cfg(not(target_os = "linux"))]
pub fn has_sve() -> bool {
    false
}
/// Reports whether the running CPU advertises SVE2.
///
/// On Linux this reads the `AT_HWCAP2` entry of the auxiliary vector through
/// `getauxval` and tests the SVE2 bit. Setting the `TRACT_SVE_DISABLE`
/// environment variable (to any value, even empty) forces the answer to
/// `false` without querying the hardware.
#[cfg(target_os = "linux")]
pub fn has_sve2() -> bool {
    use std::ffi::c_ulong;

    unsafe extern "C" {
        // The C prototype is `unsigned long getauxval(unsigned long)`, so the
        // declaration uses `c_ulong` rather than a fixed-width integer.
        fn getauxval(kind: c_ulong) -> c_ulong;
    }
    const AT_HWCAP2: c_ulong = 26;
    const HWCAP2_SVE2: c_ulong = 1 << 1;

    if std::env::var_os("TRACT_SVE_DISABLE").is_some() {
        return false;
    }
    // SAFETY: `getauxval` takes a plain integer and only reads the process'
    // auxiliary vector; it has no pointer arguments or other preconditions.
    unsafe { (getauxval(AT_HWCAP2) & HWCAP2_SVE2) != 0 }
}

/// Reports whether the running CPU advertises SVE2.
///
/// Detection is only implemented for Linux; on every other OS this always
/// returns `false`.
#[cfg(not(target_os = "linux"))]
pub fn has_sve2() -> bool {
    false
}
/// Returns the value that the hand-encoded instruction below leaves in `x0`;
/// `plug` logs it as the SVE vector length in bytes.
///
/// NOTE(review): this is a safe function that runs an SVE instruction without
/// any feature check, so it presumably faults on CPUs without SVE. The only
/// caller visible in this file (`plug`) calls it after `has_sve2()` returned
/// `true` -- confirm any other caller does the same.
#[cfg(target_os = "linux")]
// `plug` only references this function from a `log::info!` call.
// NOTE(review): the allow presumably silences builds where that use is not
// counted -- confirm.
#[allow(dead_code)]
pub fn rdvl_bytes() -> u64 {
let vl: u64;
unsafe {
std::arch::asm!(
// 0x04bf5020 is the encoding of `rdvl x0, #1`. It is emitted through `.inst`,
// presumably so that the assembler does not need SVE enabled. The output is
// pinned to `x0` because the destination register is baked into the encoding.
".inst 0x04bf5020", out("x0") vl,
// Declared as touching neither memory, the stack nor the condition flags.
options(nomem, nostack, preserves_flags),
);
}
vl
}
/// Installs the SVE2 kernels into `ops` when the running CPU supports them.
///
/// * SVE2 detected: on Linux the vector length is logged. When the crate is
///   built with `cfg(tract_sve)`, the f32 and i32 matmul selectors and the f32
///   RMS-norm entry are replaced, and the kernels are appended to
///   `ops.mmm_impls`. The f16 selectors and kernels are added only if
///   `crate::arm64::has_fp16()` also holds.
/// * SVE without SVE2, or no SVE at all: a message is logged and `ops` is left
///   untouched.
pub fn plug(ops: &mut Ops) {
    // Keeps `ops` referenced in builds where `cfg(tract_sve)` is off.
    let _ = ops;

    // Without SVE2 there is nothing to register: report and leave.
    if !has_sve2() {
        if has_sve() {
            log::info!("SVE (v1) present; SVE2 kernels not enabled");
        } else {
            log::info!("No SVE optimisation");
        }
        return;
    }

    #[cfg(target_os = "linux")]
    log::info!("SVE2 optimisation available (VL = {} bytes)", rdvl_bytes());

    #[cfg(tract_sve)]
    {
        // Floating-point and quantized selectors always pick the SVE kernels.
        ops.mmm_f32 = Box::new(|_, _, _| sve_mmm_f32_8x8.mmm());
        ops.mmv_f32 = Box::new(|_, _| sve_mmv_f32_64x1.mmm());
        ops.qmmm_i32 = Box::new(|_, _, _| sve_mmm_i32_8x8.mmm());
        ops.qmmv_i32 = Box::new(|_, _| sve_mmm_i32_64x1.mmm());
        ops.rms_norm_f32 = Box::new(sve_rms_norm_f32);

        let base_kernels = [
            sve_mmm_f32_8x8.mmm(),
            sve_mmv_f32_64x1.mmm(),
            sve_mmm_i32_8x8.mmm(),
            sve_mmm_i32_64x1.mmm(),
        ];
        ops.mmm_impls.extend_from_slice(&base_kernels);

        // The half-precision kernels additionally need fp16 support.
        if crate::arm64::has_fp16() {
            ops.mmm_f16 = Box::new(|_, _, _| sve_mmm_f16_8x8.mmm());
            ops.mmv_f16 = Box::new(|_, _| sve_mmv_f16_64x1.mmm());
            let half_kernels = [sve_mmm_f16_8x8.mmm(), sve_mmv_f16_64x1.mmm()];
            ops.mmm_impls.extend_from_slice(&half_kernels);
        }
    }
}
/// Tests for `sve_rms_norm_f32`, comparing it with a scalar reference and with
/// the arm64 SIMD implementation. Tests that reach the kernel skip themselves
/// when `has_sve2()` is false.
#[cfg(all(test, tract_sve))]
mod rms_norm_tests {
    use super::*;

    /// Scalar reference: multiplies every element by
    /// `1 / sqrt(mean(x^2) + eps)`, in place.
    fn scalar_ref(buf: &mut [f32], eps: f32) {
        let n = buf.len() as f32;
        let s: f32 = buf.iter().map(|x| x * x).sum();
        let inv_std = (s / n + eps).sqrt().recip();
        for x in buf.iter_mut() {
            *x *= inv_std;
        }
    }

    /// Asserts that `got` and `want` have the same length and agree element by
    /// element.
    ///
    /// The per-element tolerance is the larger of an absolute `1e-5` and a
    /// relative bound that grows with `sqrt(n)`, where `n` is the reduction
    /// length.
    fn close_enough(got: &[f32], want: &[f32], n: usize) {
        // `zip` stops at the shorter slice, so a length mismatch would
        // otherwise go unnoticed and the comparison would pass vacuously.
        assert_eq!(
            got.len(),
            want.len(),
            "length mismatch: got {} want {}",
            got.len(),
            want.len()
        );
        let rel = 1e-5 + (n as f32).sqrt() * 1e-7;
        let abs = 1e-5;
        for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() {
            let tol = (rel * w.abs().max(1.0)).max(abs);
            let diff = (g - w).abs();
            // A NaN `diff` fails this comparison, so NaN outputs are rejected.
            assert!(diff <= tol, "idx {i}: got {g} want {w} diff {diff} tol {tol}");
        }
    }

    /// Builds an `n`-element input from `f`, runs both the SVE kernel and the
    /// scalar reference on it, and compares the results. Skips when SVE2 is
    /// not available.
    fn check(n: usize, f: impl Fn(usize) -> f32) {
        if !has_sve2() {
            eprintln!("SVE2 not present, skipping (n={n})");
            return;
        }
        let mut sve_buf: Vec<f32> = (0..n).map(&f).collect();
        let mut ref_buf = sve_buf.clone();
        sve_rms_norm_f32(&mut sve_buf, 1e-5);
        scalar_ref(&mut ref_buf, 1e-5);
        close_enough(&sve_buf, &ref_buf, n);
    }

    /// An empty slice takes the wrapper's early return, so no SVE2 check is
    /// needed here.
    #[test]
    fn empty_is_noop() {
        let mut x: Vec<f32> = vec![];
        sve_rms_norm_f32(&mut x, 1e-5);
        assert!(x.is_empty());
    }

    /// Small lengths from 1 to 33 elements.
    #[test]
    fn short_below_step() {
        for n in [1usize, 3, 7, 8, 15, 16, 17, 31, 32, 33] {
            check(n, |i| ((i as f32 * 0.13).sin() * 5.0) - 0.5);
        }
    }

    /// 1024 elements plus a 7-element remainder.
    #[test]
    fn matches_reference_1024_with_tail() {
        check(1024 + 7, |i| (i as f32 * 0.07).cos() * 3.0);
    }

    #[test]
    fn matches_reference_4096() {
        check(4096, |i| ((i as f32 * 0.001).sin() * 4.0) + ((i as f32 * 0.013).cos() * 0.5));
    }

    /// All-zero input: the mean square is zero, so only `eps` is under the
    /// square root.
    #[test]
    fn all_zero() {
        check(256, |_| 0.0);
    }

    /// Compares the SVE kernel with `crate::arm64::arm64simd_rms_norm_f32`
    /// under the same tolerance as the scalar comparison.
    #[test]
    fn matches_neon_bit_close() {
        if !has_sve2() {
            eprintln!("SVE2 not present, skipping");
            return;
        }
        for n in [16usize, 64, 1024, 1024 + 7, 4096, 8192] {
            let x: Vec<f32> = (0..n).map(|i| (i as f32 * 0.11).sin() * 2.5).collect();
            let mut sve_out = x.clone();
            let mut neon_out = x.clone();
            sve_rms_norm_f32(&mut sve_out, 1e-5);
            crate::arm64::arm64simd_rms_norm_f32(&mut neon_out, 1e-5);
            close_enough(&sve_out, &neon_out, n);
        }
    }
}