#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// Runtime-detected CPU capability flags for AArch64 SVE/SVE2 SIMD support.
///
/// Populated by [`detect_sve_capabilities`]; on non-AArch64 targets every
/// flag is `false` and the vector length is 0.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SveCapabilities {
// True when the CPU reports the SVE feature at runtime.
pub has_sve: bool,
// True when the CPU reports the SVE2 feature at runtime.
pub has_sve2: bool,
// SVE vector length in bytes; 0 when SVE is unavailable.
pub vector_len_bytes: usize,
// True when the CPU reports the dot-product ("dotprod") feature.
pub has_dotprod: bool,
}
// `Default` delegates to runtime detection so `SveCapabilities::default()`
// reflects the actual host CPU rather than an all-false placeholder.
impl Default for SveCapabilities {
fn default() -> Self {
detect_sve_capabilities()
}
}
/// Query the host CPU for SVE, SVE2, and dot-product support.
///
/// On AArch64 the flags come from runtime feature detection; the vector
/// length is read from hardware only when SVE is actually present. On any
/// other architecture every field is false/zero.
pub fn detect_sve_capabilities() -> SveCapabilities {
#[cfg(target_arch = "aarch64")]
{
let (sve, sve2, dotprod) = (
std::arch::is_aarch64_feature_detected!("sve"),
std::arch::is_aarch64_feature_detected!("sve2"),
std::arch::is_aarch64_feature_detected!("dotprod"),
);
// The vector length is only meaningful (and only safe to read) with SVE.
let vl = if sve {
// SAFETY: SVE availability was confirmed at runtime just above.
unsafe { read_vector_length_bytes() }
} else {
0
};
SveCapabilities {
has_sve: sve,
has_sve2: sve2,
vector_len_bytes: vl,
has_dotprod: dotprod,
}
}
#[cfg(not(target_arch = "aarch64"))]
{
// No SVE outside AArch64: report everything absent.
SveCapabilities {
has_sve: false,
has_sve2: false,
vector_len_bytes: 0,
has_dotprod: false,
}
}
}
/// Read the SVE vector length in bytes from the hardware.
///
/// # Safety
/// Must only be called after runtime detection has confirmed SVE support;
/// executing `rdvl` on a CPU without SVE raises an illegal-instruction fault.
#[cfg(target_arch = "aarch64")]
unsafe fn read_vector_length_bytes() -> usize {
#[cfg(target_feature = "sve")]
{
let vl: usize;
// RDVL Xd, #1 yields VL (the vector length in bytes) times the
// immediate, i.e. the vector length itself here.
core::arch::asm!(
"rdvl {vl}, #1",
vl = out(reg) vl,
// NOTE(review): `pure` lets the compiler cache/CSE this read; that
// assumes VL never changes at runtime (e.g. via prctl) — confirm.
options(nostack, pure, nomem)
);
vl
}
#[cfg(not(target_feature = "sve"))]
{
// Built without `+sve`, so the assembler cannot emit `rdvl`; fall back
// to the architectural minimum of 16 bytes (128 bits). NOTE(review):
// this may under-report the true VL on wider hardware when the binary
// is compiled without the SVE target feature — confirm build flags if
// the exact length matters.
16_usize }
}
/// Runtime check for SVE support; always `false` off AArch64.
#[inline]
pub fn has_sve() -> bool {
#[cfg(target_arch = "aarch64")]
return std::arch::is_aarch64_feature_detected!("sve");
#[cfg(not(target_arch = "aarch64"))]
return false;
}
/// Runtime check for SVE2 support; always `false` off AArch64.
#[inline]
pub fn has_sve2() -> bool {
#[cfg(target_arch = "aarch64")]
return std::arch::is_aarch64_feature_detected!("sve2");
#[cfg(not(target_arch = "aarch64"))]
return false;
}
/// Element-wise addition `out[i] = a[i] + b[i]` over the common prefix of
/// the three slices; `out` elements past the shortest length are untouched.
/// Dispatches to the NEON path on AArch64 when available, else scalar code.
pub fn sve_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
#[cfg(target_arch = "aarch64")]
if std::arch::is_aarch64_feature_detected!("neon") {
// SAFETY: NEON availability was verified at runtime just above.
return unsafe { neon_add_f32_path(a, b, out) };
}
scalar_add_f32(a, b, out)
}
/// NEON implementation of element-wise f32 addition.
///
/// Processes four lanes per iteration, then finishes the 0..=3 element tail
/// scalar-wise. Only the common prefix of `a`, `b`, and `out` is written.
///
/// # Safety
/// Caller must ensure the NEON feature is available at runtime (the public
/// wrapper checks `is_aarch64_feature_detected!("neon")` before calling).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_add_f32_path(a: &[f32], b: &[f32], out: &mut [f32]) {
// Clamp to the shortest slice so every pointer access below is in-bounds.
let len = a.len().min(b.len()).min(out.len());
let mut i = 0;
// Main loop: 4 f32 lanes per 128-bit NEON vector.
while i + 4 <= len {
let va = vld1q_f32(a.as_ptr().add(i));
let vb = vld1q_f32(b.as_ptr().add(i));
vst1q_f32(out.as_mut_ptr().add(i), vaddq_f32(va, vb));
i += 4;
}
// Scalar tail for the remaining elements.
while i < len {
out[i] = a[i] + b[i];
i += 1;
}
}
/// Portable fallback: element-wise addition over the common prefix of the
/// slices (`zip` truncates to the shortest, matching the min-length loop).
fn scalar_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
for ((dst, &x), &y) in out.iter_mut().zip(a).zip(b) {
*dst = x + y;
}
}
/// Dot product over the common prefix of `a` and `b`.
/// Dispatches to the NEON path on AArch64 when available, else scalar code.
pub fn sve_dot_f32(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "aarch64")]
if std::arch::is_aarch64_feature_detected!("neon") {
// SAFETY: NEON availability was verified at runtime just above.
return unsafe { neon_dot_f32_path(a, b) };
}
scalar_dot_f32(a, b)
}
/// NEON implementation of the f32 dot product over the common prefix of
/// `a` and `b`.
///
/// Accumulates four lanes per iteration with fused multiply-add
/// (`vfmaq_f32`: acc + va * vb), reduces with `vaddvq_f32`, then adds the
/// scalar tail. Lane-wise accumulation reorders the summation relative to
/// the scalar path, so results can differ by normal f32 rounding.
///
/// # Safety
/// Caller must ensure the NEON feature is available at runtime.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_dot_f32_path(a: &[f32], b: &[f32]) -> f32 {
// Clamp to the shorter slice so the pointer reads below stay in-bounds.
let len = a.len().min(b.len());
let mut i = 0;
let mut acc = vdupq_n_f32(0.0_f32);
while i + 4 <= len {
let va = vld1q_f32(a.as_ptr().add(i));
let vb = vld1q_f32(b.as_ptr().add(i));
acc = vfmaq_f32(acc, va, vb);
i += 4;
}
// Horizontal add of the four partial sums, then the scalar tail.
let mut result = vaddvq_f32(acc);
while i < len {
result += a[i] * b[i];
i += 1;
}
result
}
/// Portable fallback: dot product over the common prefix of the slices.
/// The iterator sum folds left-to-right from 0.0, matching the index loop.
fn scalar_dot_f32(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}
/// Sum of all elements of `a`; 0.0 for an empty slice.
/// Dispatches to the NEON path on AArch64 when available, else scalar code.
pub fn sve_sum_f32(a: &[f32]) -> f32 {
if a.is_empty() {
return 0.0;
}
#[cfg(target_arch = "aarch64")]
if std::arch::is_aarch64_feature_detected!("neon") {
// SAFETY: NEON availability was verified at runtime just above.
return unsafe { neon_sum_f32_path(a) };
}
scalar_sum_f32(a)
}
/// NEON implementation of the f32 slice sum.
///
/// Accumulates four lanes per iteration, horizontally reduces with
/// `vaddvq_f32`, then adds the 0..=3 element scalar tail. Lane-wise
/// accumulation reorders the summation relative to the scalar path, so
/// results can differ by normal f32 rounding.
///
/// # Safety
/// Caller must ensure the NEON feature is available at runtime.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_sum_f32_path(a: &[f32]) -> f32 {
let len = a.len();
let mut i = 0;
let mut acc = vdupq_n_f32(0.0_f32);
while i + 4 <= len {
acc = vaddq_f32(acc, vld1q_f32(a.as_ptr().add(i)));
i += 4;
}
// Horizontal add of the four partial sums, then the scalar tail.
let mut result = vaddvq_f32(acc);
while i < len {
result += a[i];
i += 1;
}
result
}
/// Portable fallback: sum of all elements. The iterator sum folds
/// left-to-right from 0.0, matching the accumulate loop exactly.
fn scalar_sum_f32(a: &[f32]) -> f32 {
a.iter().sum()
}
/// Scale: `out[i] = a[i] * scale` over the common prefix of `a` and `out`;
/// trailing `out` elements are untouched.
/// Dispatches to the NEON path on AArch64 when available, else scalar code.
pub fn sve_scale_f32(a: &[f32], scale: f32, out: &mut [f32]) {
#[cfg(target_arch = "aarch64")]
if std::arch::is_aarch64_feature_detected!("neon") {
// SAFETY: NEON availability was verified at runtime just above.
return unsafe { neon_scale_f32_path(a, scale, out) };
}
scalar_scale_f32(a, scale, out)
}
/// NEON implementation of scaling: `out[i] = a[i] * scale` over the common
/// prefix of `a` and `out`; trailing `out` elements are untouched.
///
/// # Safety
/// Caller must ensure the NEON feature is available at runtime.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_scale_f32_path(a: &[f32], scale: f32, out: &mut [f32]) {
// Clamp to the shorter slice so every pointer access below is in-bounds.
let len = a.len().min(out.len());
let mut i = 0;
// Broadcast the scale factor to all four lanes once, outside the loop.
let vscale = vdupq_n_f32(scale);
while i + 4 <= len {
let va = vld1q_f32(a.as_ptr().add(i));
vst1q_f32(out.as_mut_ptr().add(i), vmulq_f32(va, vscale));
i += 4;
}
// Scalar tail for the remaining elements.
while i < len {
out[i] = a[i] * scale;
i += 1;
}
}
/// Portable fallback: scale over the common prefix of the slices
/// (`zip` truncates to the shortest, matching the min-length loop).
fn scalar_scale_f32(a: &[f32], scale: f32, out: &mut [f32]) {
for (dst, &x) in out.iter_mut().zip(a) {
*dst = x * scale;
}
}
#[cfg(test)]
mod tests {
use super::*;
// Detection must succeed on every architecture and report an internally
// consistent capability set; the free-function helpers must agree with it.
#[test]
fn test_sve_detection_does_not_panic() {
let caps = detect_sve_capabilities();
if caps.has_sve {
// The SVE architecture requires VL >= 128 bits in multiples of 128 bits.
assert!(
caps.vector_len_bytes >= 16,
"SVE vector length must be at least 16 bytes"
);
assert_eq!(
caps.vector_len_bytes % 16,
0,
"SVE vector length must be a multiple of 16 bytes"
);
} else {
// Without SVE the vector length must be reported as 0, not a guess.
assert_eq!(caps.vector_len_bytes, 0);
}
assert_eq!(has_sve(), caps.has_sve);
assert_eq!(has_sve2(), caps.has_sve2);
}
// SVE2 is a superset of SVE, so its flag must never be set alone.
#[test]
fn test_has_sve2_implies_has_sve() {
if has_sve2() {
assert!(has_sve(), "SVE2 implies SVE");
}
}
// 5 elements: exercises one full 4-lane SIMD block plus a 1-element tail.
#[test]
fn test_sve_add_f32_basic() {
let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
let b = vec![10.0_f32, 20.0, 30.0, 40.0, 50.0];
let mut out = vec![0.0_f32; 5];
sve_add_f32(&a, &b, &mut out);
let expected = [11.0, 22.0, 33.0, 44.0, 55.0];
for (i, &exp) in expected.iter().enumerate() {
assert!(
(out[i] - exp).abs() < 1e-6,
"out[{i}]={} expected {exp}",
out[i]
);
}
}
// 100 elements: many full SIMD blocks, no tail (100 % 4 == 0).
#[test]
fn test_sve_add_f32_large() {
let n = 100;
let a = vec![1.0_f32; n];
let b = vec![2.0_f32; n];
let mut out = vec![0.0_f32; n];
sve_add_f32(&a, &b, &mut out);
for v in &out {
assert!((*v - 3.0).abs() < 1e-6);
}
}
// Exactly 4 elements: a single SIMD block with no scalar tail.
#[test]
fn test_sve_dot_f32_basic() {
let a = vec![1.0_f32, 2.0, 3.0, 4.0];
let b = vec![1.0_f32; 4];
let result = sve_dot_f32(&a, &b);
assert!((result - 10.0).abs() < 1e-5, "expected 10.0 got {result}");
}
// 21 elements (5 blocks + 1 tail): SIMD result must match the scalar
// reference within f32 rounding tolerance (summation order differs).
#[test]
fn test_sve_dot_f32_matches_scalar() {
let a: Vec<f32> = (0..21).map(|i| i as f32).collect();
let b: Vec<f32> = (0..21).map(|i| (i as f32) * 0.5).collect();
let scalar = scalar_dot_f32(&a, &b);
let simd = sve_dot_f32(&a, &b);
assert!(
(scalar - simd).abs() <= scalar.abs() * 1e-5 + 1e-5,
"scalar={scalar} simd={simd}"
);
}
// 7 elements: one SIMD block plus a 3-element scalar tail.
#[test]
fn test_sve_sum_f32_basic() {
let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
let result = sve_sum_f32(&a);
assert!((result - 28.0).abs() < 1e-5, "expected 28.0 got {result}");
}
// Empty input takes the early-return path and must not reach SIMD code.
#[test]
fn test_sve_sum_f32_empty() {
assert_eq!(sve_sum_f32(&[]), 0.0);
}
// 5 elements: one SIMD block plus a 1-element tail.
#[test]
fn test_sve_scale_f32_basic() {
let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
let mut out = vec![0.0_f32; 5];
sve_scale_f32(&a, 3.0, &mut out);
let expected = [3.0, 6.0, 9.0, 12.0, 15.0];
for (i, &exp) in expected.iter().enumerate() {
assert!(
(out[i] - exp).abs() < 1e-5,
"out[{i}]={} expected {exp}",
out[i]
);
}
}
// 17 elements (4 blocks + 1 tail) with negative values and a non-trivial
// scale: SIMD must match the scalar reference element-wise.
#[test]
fn test_sve_scale_f32_matches_scalar() {
let a: Vec<f32> = (0..17).map(|i| i as f32 - 8.0).collect();
let scale = 1.23456_f32;
let mut ref_out = vec![0.0_f32; 17];
let mut simd_out = vec![0.0_f32; 17];
scalar_scale_f32(&a, scale, &mut ref_out);
sve_scale_f32(&a, scale, &mut simd_out);
for i in 0..17 {
assert!((ref_out[i] - simd_out[i]).abs() < 1e-5, "mismatch at i={i}");
}
}
}