use super::traits::{SimdComplex, SimdVector};
use core::arch::x86_64::*;
/// Two packed `f64` lanes held in a single SSE2 `__m128d` register.
/// Lane 0 (low) is the first element, lane 1 (high) the second; for
/// complex use the layout is `[re, im]`.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Sse2F64(pub __m128d);
/// Four packed `f32` lanes held in a single SSE `__m128` register.
/// Lane 0 is the lowest; for complex use the layout is two pairs,
/// `[re0, im0, re1, im1]`.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Sse2F32(pub __m128);
// SAFETY: both types are `#[repr(transparent)]` wrappers around plain SIMD
// register values (`__m128d`/`__m128`) — pure data with no pointers or
// shared state — so moving or sharing them across threads is sound.
// (NOTE(review): the raw vector types are already Send + Sync, so these
// impls are redundant but harmless.)
unsafe impl Send for Sse2F64 {}
unsafe impl Sync for Sse2F64 {}
unsafe impl Send for Sse2F32 {}
unsafe impl Sync for Sse2F32 {}
impl SimdVector for Sse2F64 {
    type Scalar = f64;
    const LANES: usize = 2;

    /// Broadcasts `x` into both lanes.
    #[inline]
    fn splat(x: f64) -> Self {
        // SAFETY: SSE2 is part of the x86_64 baseline, always available.
        unsafe { Self(_mm_set1_pd(x)) }
    }

    /// Loads two lanes from `src`.
    ///
    /// # Safety
    /// `src` must be valid for reading 16 bytes and 16-byte aligned.
    #[inline]
    unsafe fn load_aligned(src: *const f64) -> Self {
        // SAFETY: validity and alignment are the caller's obligation.
        unsafe { Self(_mm_load_pd(src)) }
    }

    /// Loads two lanes from `src` with no alignment requirement.
    ///
    /// # Safety
    /// `src` must be valid for reading 16 bytes.
    #[inline]
    unsafe fn load_unaligned(src: *const f64) -> Self {
        // SAFETY: validity is the caller's obligation.
        unsafe { Self(_mm_loadu_pd(src)) }
    }

    /// Stores both lanes to `dst`.
    ///
    /// # Safety
    /// `dst` must be valid for writing 16 bytes and 16-byte aligned.
    #[inline]
    unsafe fn store_aligned(self, dst: *mut f64) {
        // SAFETY: validity and alignment are the caller's obligation.
        unsafe { _mm_store_pd(dst, self.0) }
    }

    /// Stores both lanes to `dst` with no alignment requirement.
    ///
    /// # Safety
    /// `dst` must be valid for writing 16 bytes.
    #[inline]
    unsafe fn store_unaligned(self, dst: *mut f64) {
        // SAFETY: validity is the caller's obligation.
        unsafe { _mm_storeu_pd(dst, self.0) }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        // SAFETY: SSE2 is part of the x86_64 baseline.
        unsafe { Self(_mm_add_pd(self.0, rhs.0)) }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        // SAFETY: SSE2 is part of the x86_64 baseline.
        unsafe { Self(_mm_sub_pd(self.0, rhs.0)) }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        // SAFETY: SSE2 is part of the x86_64 baseline.
        unsafe { Self(_mm_mul_pd(self.0, rhs.0)) }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, rhs: Self) -> Self {
        // SAFETY: SSE2 is part of the x86_64 baseline.
        unsafe { Self(_mm_div_pd(self.0, rhs.0)) }
    }
}
impl Sse2F64 {
    /// Builds `[a, b]` with `a` in the low lane and `b` in the high lane
    /// (`_mm_set_pd` lists its arguments high-first, hence the reversal).
    #[inline]
    pub fn new(a: f64, b: f64) -> Self {
        unsafe { Self(_mm_set_pd(b, a)) }
    }

    /// Returns the low lane.
    #[inline]
    pub fn low(self) -> f64 {
        unsafe { _mm_cvtsd_f64(self.0) }
    }

    /// Returns the high lane (duplicated into the low position, then read).
    #[inline]
    pub fn high(self) -> f64 {
        unsafe { _mm_cvtsd_f64(_mm_unpackhi_pd(self.0, self.0)) }
    }

    /// Swaps the two lanes: `[a, b]` -> `[b, a]`.
    #[inline]
    pub fn swap(self) -> Self {
        // FIX: `_mm_shuffle_pd` takes its selector as a const generic in
        // current `core::arch`; the old three-argument call form does not
        // compile on Rust 1.52+.
        unsafe { Self(_mm_shuffle_pd::<0b01>(self.0, self.0)) }
    }

    /// Broadcasts the low lane into both lanes.
    #[inline]
    pub fn dup_low(self) -> Self {
        unsafe { Self(_mm_unpacklo_pd(self.0, self.0)) }
    }

    /// Broadcasts the high lane into both lanes.
    #[inline]
    pub fn dup_high(self) -> Self {
        unsafe { Self(_mm_unpackhi_pd(self.0, self.0)) }
    }

    /// Negates only the high lane: `[a, b]` -> `[a, -b]`.
    #[inline]
    pub fn negate_high(self) -> Self {
        unsafe {
            // XOR with -0.0 flips the sign bit of the selected lane only.
            let sign_mask = _mm_set_pd(-0.0, 0.0);
            Self(_mm_xor_pd(self.0, sign_mask))
        }
    }

    /// Negates only the low lane: `[a, b]` -> `[-a, b]`.
    #[inline]
    pub fn negate_low(self) -> Self {
        unsafe {
            let sign_mask = _mm_set_pd(0.0, -0.0);
            Self(_mm_xor_pd(self.0, sign_mask))
        }
    }

    /// Negates both lanes.
    #[inline]
    pub fn negate(self) -> Self {
        unsafe {
            let sign_mask = _mm_set1_pd(-0.0);
            Self(_mm_xor_pd(self.0, sign_mask))
        }
    }

    /// Horizontal add: both lanes of the result hold `a + b`.
    #[inline]
    pub fn hadd(self) -> Self {
        unsafe {
            // Const-generic selector (see `swap`).
            let swapped = _mm_shuffle_pd::<0b01>(self.0, self.0);
            Self(_mm_add_pd(self.0, swapped))
        }
    }

    /// Horizontal subtract: the result is `[a - b, b - a]`.
    #[inline]
    pub fn hsub(self) -> Self {
        unsafe {
            // Const-generic selector (see `swap`).
            let swapped = _mm_shuffle_pd::<0b01>(self.0, self.0);
            Self(_mm_sub_pd(self.0, swapped))
        }
    }

    /// Interleaves the low lanes: `[self.low, other.low]`.
    #[inline]
    pub fn unpack_lo(self, other: Self) -> Self {
        unsafe { Self(_mm_unpacklo_pd(self.0, other.0)) }
    }

    /// Interleaves the high lanes: `[self.high, other.high]`.
    #[inline]
    pub fn unpack_hi(self, other: Self) -> Self {
        unsafe { Self(_mm_unpackhi_pd(self.0, other.0)) }
    }
}
impl SimdComplex for Sse2F64 {
    /// Complex multiply of one `[re, im]` pair: `self * other`.
    #[inline]
    fn cmul(self, other: Self) -> Self {
        unsafe {
            // a = ar + i*ai, b = br + i*bi.
            let a_re_re = _mm_unpacklo_pd(self.0, self.0); // [ar, ar]
            let a_im_im = _mm_unpackhi_pd(self.0, self.0); // [ai, ai]
            // FIX: `_mm_shuffle_pd` takes its selector as a const generic;
            // the old three-argument call form does not compile on Rust 1.52+.
            let b_im_re = _mm_shuffle_pd::<0b01>(other.0, other.0); // [bi, br]
            let prod1 = _mm_mul_pd(a_re_re, other.0); // [ar*br, ar*bi]
            let prod2 = _mm_mul_pd(a_im_im, b_im_re); // [ai*bi, ai*br]
            // Negate the low lane of prod2 so the sum yields
            // [ar*br - ai*bi, ar*bi + ai*br].
            let sign = _mm_set_pd(0.0, -0.0);
            let prod2_signed = _mm_xor_pd(prod2, sign);
            Self(_mm_add_pd(prod1, prod2_signed))
        }
    }

    /// Complex multiply by the conjugate: `self * conj(other)`.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        unsafe {
            let a_re_re = _mm_unpacklo_pd(self.0, self.0); // [ar, ar]
            let a_im_im = _mm_unpackhi_pd(self.0, self.0); // [ai, ai]
            // Const-generic selector (see `cmul`).
            let b_im_re = _mm_shuffle_pd::<0b01>(other.0, other.0); // [bi, br]
            let prod1 = _mm_mul_pd(a_re_re, other.0); // [ar*br, ar*bi]
            let prod2 = _mm_mul_pd(a_im_im, b_im_re); // [ai*bi, ai*br]
            // Negate the high lane of prod1 so the sum yields
            // [ar*br + ai*bi, ai*br - ar*bi].
            let sign = _mm_set_pd(-0.0, 0.0);
            let prod1_signed = _mm_xor_pd(prod1, sign);
            Self(_mm_add_pd(prod1_signed, prod2))
        }
    }
}
impl SimdVector for Sse2F32 {
    type Scalar = f32;
    const LANES: usize = 4;

    /// Broadcasts `x` into all four lanes.
    #[inline]
    fn splat(x: f32) -> Self {
        // SAFETY: SSE is part of the x86_64 baseline, always available.
        unsafe { Self(_mm_set1_ps(x)) }
    }

    /// Loads four lanes from `src`.
    ///
    /// # Safety
    /// `src` must be valid for reading 16 bytes and 16-byte aligned.
    #[inline]
    unsafe fn load_aligned(src: *const f32) -> Self {
        // SAFETY: validity and alignment are the caller's obligation.
        unsafe { Self(_mm_load_ps(src)) }
    }

    /// Loads four lanes from `src` with no alignment requirement.
    ///
    /// # Safety
    /// `src` must be valid for reading 16 bytes.
    #[inline]
    unsafe fn load_unaligned(src: *const f32) -> Self {
        // SAFETY: validity is the caller's obligation.
        unsafe { Self(_mm_loadu_ps(src)) }
    }

    /// Stores all four lanes to `dst`.
    ///
    /// # Safety
    /// `dst` must be valid for writing 16 bytes and 16-byte aligned.
    #[inline]
    unsafe fn store_aligned(self, dst: *mut f32) {
        // SAFETY: validity and alignment are the caller's obligation.
        unsafe { _mm_store_ps(dst, self.0) }
    }

    /// Stores all four lanes to `dst` with no alignment requirement.
    ///
    /// # Safety
    /// `dst` must be valid for writing 16 bytes.
    #[inline]
    unsafe fn store_unaligned(self, dst: *mut f32) {
        // SAFETY: validity is the caller's obligation.
        unsafe { _mm_storeu_ps(dst, self.0) }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        // SAFETY: SSE is part of the x86_64 baseline.
        unsafe { Self(_mm_add_ps(self.0, rhs.0)) }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        // SAFETY: SSE is part of the x86_64 baseline.
        unsafe { Self(_mm_sub_ps(self.0, rhs.0)) }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        // SAFETY: SSE is part of the x86_64 baseline.
        unsafe { Self(_mm_mul_ps(self.0, rhs.0)) }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, rhs: Self) -> Self {
        // SAFETY: SSE is part of the x86_64 baseline.
        unsafe { Self(_mm_div_ps(self.0, rhs.0)) }
    }
}
impl Sse2F32 {
    /// Builds `[a, b, c, d]` from lane 0 upward (`_mm_set_ps` lists its
    /// arguments high-first, hence the reversal).
    #[inline]
    pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
        unsafe { Self(_mm_set_ps(d, c, b, a)) }
    }

    /// Negates all four lanes by flipping each sign bit.
    #[inline]
    pub fn negate(self) -> Self {
        unsafe {
            let sign_mask = _mm_set1_ps(-0.0);
            Self(_mm_xor_ps(self.0, sign_mask))
        }
    }

    /// Shuffles lanes of `self`/`other` per the SSE `MASK` encoding
    /// (low two result lanes from `self`, high two from `other`).
    #[inline]
    pub fn shuffle<const MASK: i32>(self, other: Self) -> Self {
        // FIX: `_mm_shuffle_ps` takes its selector as a const generic; the
        // old three-argument form (with `MASK` as a runtime value) does not
        // compile.
        unsafe { Self(_mm_shuffle_ps::<MASK>(self.0, other.0)) }
    }

    /// `_mm_movehl_ps`: result is `[other.hi2, self.hi2]` — the high halves
    /// of `other` (low) and `self` (high).
    #[inline]
    pub fn move_hl(self, other: Self) -> Self {
        unsafe { Self(_mm_movehl_ps(self.0, other.0)) }
    }

    /// `_mm_movelh_ps`: result is `[self.lo2, other.lo2]` — the low halves
    /// of `self` (low) and `other` (high).
    #[inline]
    pub fn move_lh(self, other: Self) -> Self {
        unsafe { Self(_mm_movelh_ps(self.0, other.0)) }
    }

    /// Interleaves the low halves: `[a0, b0, a1, b1]`.
    #[inline]
    pub fn unpack_lo(self, other: Self) -> Self {
        unsafe { Self(_mm_unpacklo_ps(self.0, other.0)) }
    }

    /// Interleaves the high halves: `[a2, b2, a3, b3]`.
    #[inline]
    pub fn unpack_hi(self, other: Self) -> Self {
        unsafe { Self(_mm_unpackhi_ps(self.0, other.0)) }
    }
}
impl SimdComplex for Sse2F32 {
    /// Complex multiply of two packed pairs `[re0, im0, re1, im1]`:
    /// `self * other`, element-wise per pair.
    #[inline]
    fn cmul(self, other: Self) -> Self {
        unsafe {
            // FIX: `_mm_shuffle_ps` takes its selector as a const generic;
            // the old three-argument call form does not compile on Rust 1.52+.
            let a_re = _mm_shuffle_ps::<0b1010_0000>(self.0, self.0); // [ar0, ar0, ar1, ar1]
            let a_im = _mm_shuffle_ps::<0b1111_0101>(self.0, self.0); // [ai0, ai0, ai1, ai1]
            let b_swap = _mm_shuffle_ps::<0b1011_0001>(other.0, other.0); // [bi0, br0, bi1, br1]
            let prod1 = _mm_mul_ps(a_re, other.0); // [ar*br, ar*bi] per pair
            let prod2 = _mm_mul_ps(a_im, b_swap); // [ai*bi, ai*br] per pair
            // Negate lanes 0 and 2 of prod2 so the sum yields
            // [ar*br - ai*bi, ar*bi + ai*br] per pair.
            let sign = _mm_set_ps(0.0, -0.0, 0.0, -0.0);
            let prod2_signed = _mm_xor_ps(prod2, sign);
            Self(_mm_add_ps(prod1, prod2_signed))
        }
    }

    /// Complex multiply by the conjugate: `self * conj(other)` per pair.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        unsafe {
            // Const-generic selectors (see `cmul`).
            let a_re = _mm_shuffle_ps::<0b1010_0000>(self.0, self.0);
            let a_im = _mm_shuffle_ps::<0b1111_0101>(self.0, self.0);
            let b_swap = _mm_shuffle_ps::<0b1011_0001>(other.0, other.0);
            let prod1 = _mm_mul_ps(a_re, other.0);
            let prod2 = _mm_mul_ps(a_im, b_swap);
            // Negate lanes 1 and 3 of prod1 so the sum yields
            // [ar*br + ai*bi, ai*br - ar*bi] per pair.
            let sign = _mm_set_ps(-0.0, 0.0, -0.0, 0.0);
            let prod1_signed = _mm_xor_ps(prod1, sign);
            Self(_mm_add_ps(prod1_signed, prod2))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dumps the four lanes of an `Sse2F32` into a plain array.
    fn lanes_f32(v: Sse2F32) -> [f32; 4] {
        let mut buf = [0.0_f32; 4];
        unsafe { v.store_unaligned(buf.as_mut_ptr()) };
        buf
    }

    #[test]
    fn test_sse2_f64_basic() {
        let a = Sse2F64::splat(2.0);
        let b = Sse2F64::splat(3.0);
        // Each result should hold the same value in both lanes.
        for (got, want) in [(a.add(b), 5.0), (a.sub(b), -1.0), (a.mul(b), 6.0)] {
            assert_eq!(got.low(), want);
            assert_eq!(got.high(), want);
        }
    }

    #[test]
    fn test_sse2_f64_new() {
        let v = Sse2F64::new(1.0, 2.0);
        assert_eq!((v.low(), v.high()), (1.0, 2.0));
    }

    #[test]
    fn test_sse2_f64_swap() {
        let v = Sse2F64::new(1.0, 2.0).swap();
        assert_eq!((v.low(), v.high()), (2.0, 1.0));
    }

    #[test]
    fn test_sse2_f64_negate() {
        let v = Sse2F64::new(1.0, 2.0);
        assert_eq!((v.negate().low(), v.negate().high()), (-1.0, -2.0));
        assert_eq!((v.negate_low().low(), v.negate_low().high()), (-1.0, 2.0));
        assert_eq!((v.negate_high().low(), v.negate_high().high()), (1.0, -2.0));
    }

    #[test]
    fn test_sse2_f64_cmul() {
        // (1 + 2i) * (3 + 4i) = -5 + 10i
        let c = Sse2F64::new(1.0, 2.0).cmul(Sse2F64::new(3.0, 4.0));
        assert!((c.low() + 5.0).abs() < 1e-10);
        assert!((c.high() - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_sse2_f64_cmul_conj() {
        // (1 + 2i) * conj(3 + 4i) = 11 + 2i
        let c = Sse2F64::new(1.0, 2.0).cmul_conj(Sse2F64::new(3.0, 4.0));
        assert!((c.low() - 11.0).abs() < 1e-10);
        assert!((c.high() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_sse2_f64_load_store() {
        let input = [1.0_f64, 2.0];
        let v = unsafe { Sse2F64::load_unaligned(input.as_ptr()) };
        assert_eq!((v.low(), v.high()), (1.0, 2.0));
        let mut output = [0.0_f64; 2];
        unsafe { v.store_unaligned(output.as_mut_ptr()) };
        assert_eq!(output, input);
    }

    #[test]
    fn test_sse2_f32_basic() {
        let total = Sse2F32::splat(2.0).add(Sse2F32::splat(3.0));
        assert_eq!(lanes_f32(total), [5.0; 4]);
    }

    #[test]
    fn test_sse2_f32_new() {
        let v = Sse2F32::new(1.0, 2.0, 3.0, 4.0);
        assert_eq!(lanes_f32(v), [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_sse2_f32_cmul() {
        // (1+2i)(5+6i) = -7+16i, (3+4i)(7+8i) = -11+52i
        let a = Sse2F32::new(1.0, 2.0, 3.0, 4.0);
        let b = Sse2F32::new(5.0, 6.0, 7.0, 8.0);
        let got = lanes_f32(a.cmul(b));
        let want = [-7.0_f32, 16.0, -11.0, 52.0];
        for (g, w) in got.iter().zip(want.iter()) {
            assert!((g - w).abs() < 1e-5);
        }
    }
}