use super::traits::{SimdComplex, SimdVector};
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Returns the runtime SVE vector length in bytes, or 0 when SVE is
/// unavailable on the current system.
#[cfg(all(target_arch = "aarch64", feature = "sve"))]
pub fn sve_vector_length_bytes() -> usize {
    detect_sve_length()
}
/// Fallback: SVE is only meaningful on aarch64 with the `sve` feature
/// enabled, so every other configuration reports a length of 0 ("no SVE").
#[cfg(not(all(target_arch = "aarch64", feature = "sve")))]
pub fn sve_vector_length_bytes() -> usize {
    0
}
/// Queries the kernel for the calling thread's SVE vector length in bytes.
///
/// Uses `prctl(PR_SVE_GET_VL)` (see prctl(2)); the vector length is carried
/// in the low 16 bits of the return value. Returns 0 when SVE is not
/// reported in the hwcaps or the `prctl` call fails, so callers can treat
/// 0 uniformly as "no SVE". Previously this was a stub that always returned
/// 0, which made the feature-gated `sve_vector_length_bytes` useless on
/// real SVE hardware.
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
fn detect_sve_length() -> usize {
    // prctl(2) constants for SVE vector-length control (linux/prctl.h).
    const PR_SVE_GET_VL: libc::c_int = 51;
    const PR_SVE_VL_LEN_MASK: libc::c_int = 0xffff;
    // Skip the syscall entirely when the hwcaps say SVE is absent
    // (also covers builds without the `sve` feature, where `has_sve`
    // is the `false` fallback).
    if !has_sve() {
        return 0;
    }
    // SAFETY: PR_SVE_GET_VL only reads the calling thread's SVE
    // configuration and takes no pointer arguments.
    #[allow(unsafe_code)]
    let vl = unsafe { libc::prctl(PR_SVE_GET_VL) };
    if vl < 0 {
        0
    } else {
        (vl & PR_SVE_VL_LEN_MASK) as usize
    }
}
/// Non-Linux aarch64 targets have no portable SVE length query wired up
/// here, so report 0 ("no SVE").
#[cfg(all(target_arch = "aarch64", not(target_os = "linux")))]
fn detect_sve_length() -> usize {
    0
}
/// Number of `f64` lanes in one SVE vector, or 0 when SVE is unavailable.
pub fn sve_f64_lanes() -> usize {
    match sve_vector_length_bytes() {
        0 => 0,
        // One f64 lane per 8 bytes of vector width.
        bytes => bytes / 8,
    }
}
/// Number of `f32` lanes in one SVE vector, or 0 when SVE is unavailable.
pub fn sve_f32_lanes() -> usize {
    match sve_vector_length_bytes() {
        0 => 0,
        // One f32 lane per 4 bytes of vector width.
        bytes => bytes / 4,
    }
}
/// Returns `true` when the Linux kernel reports SVE support via the ELF
/// auxiliary vector (`AT_HWCAP` bit 22, `HWCAP_SVE`).
#[cfg(all(target_arch = "aarch64", target_os = "linux", feature = "sve"))]
pub fn has_sve() -> bool {
    #[allow(unsafe_code)]
    // SAFETY: `getauxval` only reads the process's auxiliary vector; it
    // takes no pointers and cannot violate memory safety.
    unsafe {
        // Values from <linux/auxvec.h> and <asm/hwcap.h>.
        const AT_HWCAP: libc::c_ulong = 16;
        const HWCAP_SVE: u64 = 1 << 22;
        let hwcap = libc::getauxval(AT_HWCAP);
        (hwcap & HWCAP_SVE) != 0
    }
}
/// Fallback: SVE detection requires aarch64 Linux with the `sve` feature;
/// every other configuration reports no SVE.
#[cfg(not(all(target_arch = "aarch64", target_os = "linux", feature = "sve")))]
pub fn has_sve() -> bool {
    false
}
/// Fixed-width 256-bit vector of four `f64` lanes, modelling an SVE
/// register at VL = 256 bits. The 32-byte alignment matches the 256-bit
/// vector width.
#[derive(Copy, Clone, Debug)]
#[repr(align(32))]
pub struct Sve256F64 {
    // Lane storage; lane i is data[i].
    data: [f64; 4],
}
/// Fixed-width 256-bit vector of eight `f32` lanes, modelling an SVE
/// register at VL = 256 bits. The 32-byte alignment matches the 256-bit
/// vector width.
#[derive(Copy, Clone, Debug)]
#[repr(align(32))]
pub struct Sve256F32 {
    // Lane storage; lane i is data[i].
    data: [f32; 8],
}
// NOTE: `Sve256F64` and `Sve256F32` hold only plain `f64`/`f32` arrays, so
// the compiler derives `Send` and `Sync` automatically; the previous manual
// `unsafe impl`s were redundant and have been removed to keep the crate's
// unsafe surface minimal. The auto impls preserve the public interface.
/// Scalar fallback implementation of the 4-lane `f64` vector operations.
/// Every operation is applied lane-wise; `fmadd`/`fmsub` use `mul_add` so
/// the multiply-add is performed with a single rounding.
impl SimdVector for Sve256F64 {
    type Scalar = f64;
    const LANES: usize = 4;

    /// Broadcasts `value` into all four lanes.
    #[inline]
    fn splat(value: f64) -> Self {
        Self { data: [value; 4] }
    }

    /// Loads four consecutive `f64` values starting at `ptr`. The scalar
    /// fallback does not exploit alignment, so this is the same as the
    /// unaligned load.
    #[inline]
    unsafe fn load_aligned(ptr: *const f64) -> Self {
        unsafe { Self::load_unaligned(ptr) }
    }

    /// Loads four consecutive `f64` values starting at `ptr`.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f64) -> Self {
        let mut data = [0.0_f64; 4];
        unsafe {
            for i in 0..4 {
                data[i] = *ptr.add(i);
            }
        }
        Self { data }
    }

    /// Stores all four lanes to `ptr`; same as the unaligned store in this
    /// scalar fallback.
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f64) {
        unsafe { self.store_unaligned(ptr) }
    }

    /// Stores all four lanes to `ptr`.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f64) {
        unsafe {
            for i in 0..4 {
                *ptr.add(i) = self.data[i];
            }
        }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for i in 0..4 {
            data[i] = self.data[i] + other.data[i];
        }
        Self { data }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for i in 0..4 {
            data[i] = self.data[i] - other.data[i];
        }
        Self { data }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for i in 0..4 {
            data[i] = self.data[i] * other.data[i];
        }
        Self { data }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for i in 0..4 {
            data[i] = self.data[i] / other.data[i];
        }
        Self { data }
    }

    /// Fused multiply-add: `self * a + b` per lane with a single rounding.
    #[inline]
    fn fmadd(self, a: Self, b: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for i in 0..4 {
            data[i] = self.data[i].mul_add(a.data[i], b.data[i]);
        }
        Self { data }
    }

    /// Fused multiply-subtract: `self * a - b` per lane with a single
    /// rounding.
    #[inline]
    fn fmsub(self, a: Self, b: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for i in 0..4 {
            data[i] = self.data[i].mul_add(a.data[i], -b.data[i]);
        }
        Self { data }
    }
}
/// Lane-access, shuffle, and predicated (masked) helpers for [`Sve256F64`].
#[allow(dead_code)]
impl Sve256F64 {
    /// Builds a vector from four explicit lane values.
    #[inline]
    pub fn new(a: f64, b: f64, c: f64, d: f64) -> Self {
        Self { data: [a, b, c, d] }
    }

    /// Reads lane `idx` (panics if `idx >= 4`).
    #[inline]
    pub fn extract(self, idx: usize) -> f64 {
        self.data[idx]
    }

    /// Lane-wise negation.
    #[inline]
    pub fn negate(self) -> Self {
        let mut data = self.data;
        for lane in &mut data {
            *lane = -*lane;
        }
        Self { data }
    }

    /// Lanes 0 and 1 as a tuple.
    #[inline]
    pub fn low_pair(self) -> (f64, f64) {
        (self.data[0], self.data[1])
    }

    /// Lanes 2 and 3 as a tuple.
    #[inline]
    pub fn high_pair(self) -> (f64, f64) {
        (self.data[2], self.data[3])
    }

    /// Interleaves the low halves of the two vectors: `[a0, b0, a1, b1]`.
    #[inline]
    pub fn zip_lo(self, other: Self) -> Self {
        Self {
            data: [self.data[0], other.data[0], self.data[1], other.data[1]],
        }
    }

    /// Interleaves the high halves of the two vectors: `[a2, b2, a3, b3]`.
    #[inline]
    pub fn zip_hi(self, other: Self) -> Self {
        Self {
            data: [self.data[2], other.data[2], self.data[3], other.data[3]],
        }
    }

    /// Adds `other` only in lanes where `mask` is true; masked-off lanes
    /// keep `self`'s value (SVE merging-predication semantics).
    #[inline]
    pub fn add_predicated(self, other: Self, mask: [bool; 4]) -> Self {
        let mut data = self.data;
        for i in 0..4 {
            if mask[i] {
                data[i] += other.data[i];
            }
        }
        Self { data }
    }

    /// Loads lanes where `mask` is true; masked-off lanes are zeroed
    /// (SVE zeroing-predication semantics).
    ///
    /// # Safety
    /// `ptr` must be valid for a read at every offset `i` where `mask[i]`
    /// is true; masked-off offsets are never dereferenced.
    #[inline]
    pub unsafe fn load_predicated(ptr: *const f64, mask: [bool; 4]) -> Self {
        let mut data = [0.0_f64; 4];
        unsafe {
            for i in 0..4 {
                if mask[i] {
                    data[i] = *ptr.add(i);
                }
            }
        }
        Self { data }
    }

    /// Stores only the lanes where `mask` is true.
    ///
    /// # Safety
    /// `ptr` must be valid for a write at every offset `i` where `mask[i]`
    /// is true; masked-off offsets are never written.
    #[inline]
    pub unsafe fn store_predicated(self, ptr: *mut f64, mask: [bool; 4]) {
        unsafe {
            for i in 0..4 {
                if mask[i] {
                    *ptr.add(i) = self.data[i];
                }
            }
        }
    }
}
/// Complex arithmetic treating the vector as two interleaved `(re, im)`
/// pairs: lanes `[re0, im0, re1, im1]`.
impl SimdComplex for Sve256F64 {
    /// Complex multiply per pair: `(a+bi)(c+di) = (ac - bd) + (ad + bc)i`.
    #[inline]
    fn cmul(self, other: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for pair in 0..2 {
            let re = 2 * pair;
            let im = re + 1;
            let (a, b) = (self.data[re], self.data[im]);
            let (c, d) = (other.data[re], other.data[im]);
            data[re] = a.mul_add(c, -(b * d));
            data[im] = a.mul_add(d, b * c);
        }
        Self { data }
    }

    /// Complex multiply by the conjugate of `other`: `(a+bi)(c-di)`.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        let mut data = [0.0_f64; 4];
        for pair in 0..2 {
            let re = 2 * pair;
            let im = re + 1;
            let (a, b) = (self.data[re], self.data[im]);
            // Conjugation = negating the imaginary part of `other`.
            let (c, d) = (other.data[re], -other.data[im]);
            data[re] = a.mul_add(c, -(b * d));
            data[im] = a.mul_add(d, b * c);
        }
        Self { data }
    }
}
/// Scalar fallback implementation of the 8-lane `f32` vector operations.
/// Every operation is applied lane-wise; `fmadd`/`fmsub` use `mul_add` so
/// the multiply-add is performed with a single rounding.
impl SimdVector for Sve256F32 {
    type Scalar = f32;
    const LANES: usize = 8;

    /// Broadcasts `value` into all eight lanes.
    #[inline]
    fn splat(value: f32) -> Self {
        Self { data: [value; 8] }
    }

    /// Loads eight consecutive `f32` values starting at `ptr`. The scalar
    /// fallback does not exploit alignment.
    #[inline]
    unsafe fn load_aligned(ptr: *const f32) -> Self {
        let mut data = [0.0_f32; 8];
        unsafe {
            for i in 0..8 {
                data[i] = *ptr.add(i);
            }
        }
        Self { data }
    }

    /// Loads eight consecutive `f32` values starting at `ptr`.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        let mut data = [0.0_f32; 8];
        unsafe {
            for i in 0..8 {
                data[i] = *ptr.add(i);
            }
        }
        Self { data }
    }

    /// Stores all eight lanes to `ptr`.
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f32) {
        unsafe {
            for i in 0..8 {
                *ptr.add(i) = self.data[i];
            }
        }
    }

    /// Stores all eight lanes to `ptr`.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f32) {
        unsafe {
            for i in 0..8 {
                *ptr.add(i) = self.data[i];
            }
        }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for i in 0..8 {
            data[i] = self.data[i] + other.data[i];
        }
        Self { data }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for i in 0..8 {
            data[i] = self.data[i] - other.data[i];
        }
        Self { data }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for i in 0..8 {
            data[i] = self.data[i] * other.data[i];
        }
        Self { data }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for i in 0..8 {
            data[i] = self.data[i] / other.data[i];
        }
        Self { data }
    }

    /// Fused multiply-add: `self * a + b` per lane with a single rounding.
    #[inline]
    fn fmadd(self, a: Self, b: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for i in 0..8 {
            data[i] = self.data[i].mul_add(a.data[i], b.data[i]);
        }
        Self { data }
    }

    /// Fused multiply-subtract: `self * a - b` per lane with a single
    /// rounding. Provided explicitly — mirroring `Sve256F64::fmsub` — so
    /// both element widths share the same fused semantics rather than the
    /// f32 side falling back to the trait's default.
    #[inline]
    fn fmsub(self, a: Self, b: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for i in 0..8 {
            data[i] = self.data[i].mul_add(a.data[i], -b.data[i]);
        }
        Self { data }
    }
}
/// Lane-access helpers for [`Sve256F32`].
#[allow(dead_code)]
impl Sve256F32 {
    /// Builds a vector from eight explicit lane values.
    #[inline]
    pub fn new(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> Self {
        Self {
            data: [a, b, c, d, e, f, g, h],
        }
    }

    /// Reads lane `idx` (panics if `idx >= 8`).
    #[inline]
    pub fn extract(self, idx: usize) -> f32 {
        self.data[idx]
    }

    /// Lane-wise negation.
    #[inline]
    pub fn negate(self) -> Self {
        let mut data = self.data;
        for lane in &mut data {
            *lane = -*lane;
        }
        Self { data }
    }
}
/// Complex arithmetic treating the vector as four interleaved `(re, im)`
/// pairs: lanes `[re0, im0, re1, im1, ...]`.
impl SimdComplex for Sve256F32 {
    /// Complex multiply per pair: `(a+bi)(c+di) = (ac - bd) + (ad + bc)i`.
    #[inline]
    fn cmul(self, other: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for pair in 0..4 {
            let re = 2 * pair;
            let im = re + 1;
            let (a, b) = (self.data[re], self.data[im]);
            let (c, d) = (other.data[re], other.data[im]);
            data[re] = a.mul_add(c, -(b * d));
            data[im] = a.mul_add(d, b * c);
        }
        Self { data }
    }

    /// Complex multiply by the conjugate of `other`: `(a+bi)(c-di)`.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        let mut data = [0.0_f32; 8];
        for pair in 0..4 {
            let re = 2 * pair;
            let im = re + 1;
            let (a, b) = (self.data[re], self.data[im]);
            // Conjugation = negating the imaginary part of `other`.
            let (c, d) = (other.data[re], -other.data[im]);
            data[re] = a.mul_add(c, -(b * d));
            data[im] = a.mul_add(d, b * c);
        }
        Self { data }
    }
}
/// Software model of an SVE predicate register: one boolean per lane.
#[derive(Copy, Clone, Debug)]
pub struct SvePredicate<const N: usize> {
    /// `active[i]` is `true` when lane `i` participates in an operation.
    pub active: [bool; N],
}

impl<const N: usize> SvePredicate<N> {
    /// Predicate with every lane enabled (like SVE `ptrue`).
    #[inline]
    pub fn all_true() -> Self {
        Self { active: [true; N] }
    }

    /// Predicate with every lane disabled (like SVE `pfalse`).
    #[inline]
    pub fn all_false() -> Self {
        Self { active: [false; N] }
    }

    /// Enables the first `count` lanes, clamped to `N` (like SVE `whilelt`).
    #[inline]
    pub fn while_lt(count: usize) -> Self {
        let mut active = [false; N];
        let limit = count.min(N);
        for lane in active[..limit].iter_mut() {
            *lane = true;
        }
        Self { active }
    }

    /// Number of enabled lanes (like SVE `cntp`).
    #[inline]
    pub fn count_active(&self) -> usize {
        self.active.iter().fold(0, |n, &on| n + usize::from(on))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sve256_f64_basic() {
        // splat + add must hit every lane.
        let sum = Sve256F64::splat(2.0).add(Sve256F64::splat(3.0));
        for lane in 0..4 {
            assert_eq!(sum.extract(lane), 5.0);
        }
    }

    #[test]
    fn test_sve256_f64_new() {
        let v = Sve256F64::new(1.0, 2.0, 3.0, 4.0);
        for (lane, expected) in [1.0, 2.0, 3.0, 4.0].iter().enumerate() {
            assert_eq!(v.extract(lane), *expected);
        }
    }

    #[test]
    fn test_sve256_f64_fmadd() {
        // 2 * 3 + 4 == 10 in every lane.
        let result = Sve256F64::splat(2.0).fmadd(Sve256F64::splat(3.0), Sve256F64::splat(4.0));
        for lane in 0..4 {
            assert_eq!(result.extract(lane), 10.0);
        }
    }

    #[test]
    fn test_sve256_f64_load_store() {
        let input = [1.0_f64, 2.0, 3.0, 4.0];
        let mut output = [0.0_f64; 4];
        let v = unsafe { Sve256F64::load_unaligned(input.as_ptr()) };
        unsafe { v.store_unaligned(output.as_mut_ptr()) };
        assert_eq!(input, output);
    }

    #[test]
    fn test_sve256_f64_cmul() {
        // (3 + 4i)(1 + 2i) = -5 + 10i; (1 + 0i)(1 + 0i) = 1 + 0i.
        let product = Sve256F64::new(3.0, 4.0, 1.0, 0.0).cmul(Sve256F64::new(1.0, 2.0, 1.0, 0.0));
        let tol = 1e-10;
        for (lane, want) in [-5.0, 10.0, 1.0, 0.0].iter().enumerate() {
            assert!((product.extract(lane) - *want).abs() < tol);
        }
    }

    #[test]
    fn test_sve256_f64_predicated() {
        let base = Sve256F64::new(1.0, 2.0, 3.0, 4.0);
        let addend = Sve256F64::new(10.0, 20.0, 30.0, 40.0);
        // Only lanes 0 and 2 are active; lanes 1 and 3 must pass through.
        let result = base.add_predicated(addend, [true, false, true, false]);
        assert_eq!(result.extract(0), 11.0);
        assert_eq!(result.extract(1), 2.0);
        assert_eq!(result.extract(2), 33.0);
        assert_eq!(result.extract(3), 4.0);
    }

    #[test]
    fn test_sve256_f32_basic() {
        let product = Sve256F32::splat(2.0).mul(Sve256F32::splat(3.0));
        for lane in 0..8 {
            assert_eq!(product.extract(lane), 6.0);
        }
    }

    #[test]
    fn test_sve256_f32_cmul() {
        // First pair: (3 + 4i)(1 + 2i) = -5 + 10i.
        let lhs = Sve256F32::new(3.0, 4.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0);
        let rhs = Sve256F32::new(1.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0);
        let product = lhs.cmul(rhs);
        let tol = 1e-5;
        assert!((product.extract(0) - (-5.0)).abs() < tol);
        assert!((product.extract(1) - 10.0).abs() < tol);
    }

    #[test]
    fn test_sve_predicate() {
        assert_eq!(SvePredicate::<4>::all_true().count_active(), 4);
        let partial = SvePredicate::<4>::while_lt(2);
        assert_eq!(partial.count_active(), 2);
        assert!(partial.active[0]);
        assert!(partial.active[1]);
        assert!(!partial.active[2]);
        assert!(!partial.active[3]);
    }
}