// oxifft/simd/avx.rs (web-capture navigation residue removed)
1//! AVX SIMD implementation for x86_64.
2//!
3//! Provides 256-bit SIMD operations using AVX intrinsics.
4//! - f64: 4 lanes (256-bit = 4 × 64-bit)
5//! - f32: 8 lanes (256-bit = 8 × 32-bit)
6
7use super::traits::{SimdComplex, SimdVector};
8use core::arch::x86_64::*;
9
/// AVX f64 vector type: 4 lanes of `f64` packed in one 256-bit `__m256d`.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct AvxF64(pub __m256d);

/// AVX f32 vector type: 8 lanes of `f32` packed in one 256-bit `__m256`.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct AvxF32(pub __m256);

// Safety: both wrappers are plain-old-data (no pointers, no interior
// mutability), so sending/sharing them across threads is sound.
// NOTE(review): `__m256d`/`__m256` already implement Send/Sync, so these
// manual impls appear redundant (harmless) — confirm before removing.
unsafe impl Send for AvxF64 {}
unsafe impl Sync for AvxF64 {}
unsafe impl Send for AvxF32 {}
unsafe impl Sync for AvxF32 {}
25
impl SimdVector for AvxF64 {
    type Scalar = f64;
    // 256-bit register / 64-bit elements = 4 lanes.
    const LANES: usize = 4;

    /// Broadcast `value` into all 4 lanes.
    ///
    /// NOTE(review): every method in this impl executes AVX instructions;
    /// callers are expected to have verified AVX support first (e.g. via
    /// `is_x86_feature_detected!("avx")`, as the tests do) — confirm the
    /// dispatch layer guarantees this.
    #[inline]
    fn splat(value: f64) -> Self {
        unsafe { Self(_mm256_set1_pd(value)) }
    }

    /// Load 4 `f64`s from an aligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 4 consecutive `f64`s and aligned
    /// to 32 bytes (`_mm256_load_pd` requirement).
    #[inline]
    unsafe fn load_aligned(ptr: *const f64) -> Self {
        unsafe { Self(_mm256_load_pd(ptr)) }
    }

    /// Load 4 `f64`s from a possibly-unaligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 4 consecutive `f64`s.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f64) -> Self {
        unsafe { Self(_mm256_loadu_pd(ptr)) }
    }

    /// Store all 4 lanes to an aligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 4 consecutive `f64`s and aligned
    /// to 32 bytes (`_mm256_store_pd` requirement).
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f64) {
        unsafe { _mm256_store_pd(ptr, self.0) }
    }

    /// Store all 4 lanes to a possibly-unaligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 4 consecutive `f64`s.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f64) {
        unsafe { _mm256_storeu_pd(ptr, self.0) }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        unsafe { Self(_mm256_add_pd(self.0, other.0)) }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        unsafe { Self(_mm256_sub_pd(self.0, other.0)) }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        unsafe { Self(_mm256_mul_pd(self.0, other.0)) }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        unsafe { Self(_mm256_div_pd(self.0, other.0)) }
    }
}
75
76#[allow(dead_code)]
77impl AvxF64 {
78    /// Create a vector from four f64 values: [a, b, c, d]
79    #[inline]
80    pub fn new(a: f64, b: f64, c: f64, d: f64) -> Self {
81        unsafe { Self(_mm256_set_pd(d, c, b, a)) }
82    }
83
84    /// Extract element at index (0-3).
85    #[inline]
86    pub fn extract(self, idx: usize) -> f64 {
87        debug_assert!(idx < 4);
88        let mut arr = [0.0_f64; 4];
89        unsafe { self.store_unaligned(arr.as_mut_ptr()) };
90        arr[idx]
91    }
92
93    /// Negate all elements.
94    #[inline]
95    pub fn negate(self) -> Self {
96        unsafe {
97            let sign_mask = _mm256_set1_pd(-0.0);
98            Self(_mm256_xor_pd(self.0, sign_mask))
99        }
100    }
101
102    /// Permute elements within 128-bit lanes.
103    /// Each 128-bit lane is permuted independently.
104    #[inline]
105    pub fn shuffle_within_lanes<const MASK: i32>(self) -> Self {
106        unsafe { Self(_mm256_shuffle_pd(self.0, self.0, MASK)) }
107    }
108
109    /// Swap 128-bit lanes: [a, b, c, d] -> [c, d, a, b]
110    #[inline]
111    pub fn swap_lanes(self) -> Self {
112        unsafe { Self(_mm256_permute2f128_pd(self.0, self.0, 0x01)) }
113    }
114
115    /// Interleave low elements from two vectors.
116    #[inline]
117    pub fn unpack_lo(self, other: Self) -> Self {
118        unsafe { Self(_mm256_unpacklo_pd(self.0, other.0)) }
119    }
120
121    /// Interleave high elements from two vectors.
122    #[inline]
123    pub fn unpack_hi(self, other: Self) -> Self {
124        unsafe { Self(_mm256_unpackhi_pd(self.0, other.0)) }
125    }
126
127    /// Blend elements from two vectors based on mask.
128    #[inline]
129    pub fn blend<const MASK: i32>(self, other: Self) -> Self {
130        unsafe { Self(_mm256_blend_pd(self.0, other.0, MASK)) }
131    }
132
133    /// Get the low 128-bit lane as SSE vector.
134    #[inline]
135    pub fn low_128(self) -> super::sse2::Sse2F64 {
136        unsafe { super::sse2::Sse2F64(_mm256_castpd256_pd128(self.0)) }
137    }
138
139    /// Get the high 128-bit lane as SSE vector.
140    #[inline]
141    pub fn high_128(self) -> super::sse2::Sse2F64 {
142        unsafe { super::sse2::Sse2F64(_mm256_extractf128_pd(self.0, 1)) }
143    }
144}
145
146impl SimdComplex for AvxF64 {
147    /// Complex multiply for 2 interleaved complex numbers.
148    ///
149    /// Format: [re0, im0, re1, im1]
150    /// Computes: [(re0*re0' - im0*im0', re0*im0' + im0*re0'), (re1*re1' - im1*im1', re1*im1' + im1*re1')]
151    #[inline]
152    fn cmul(self, other: Self) -> Self {
153        unsafe {
154            // Duplicate real parts: [re0, re0, re1, re1]
155            let a_re = _mm256_unpacklo_pd(self.0, self.0);
156            // Duplicate imag parts: [im0, im0, im1, im1]
157            let a_im = _mm256_unpackhi_pd(self.0, self.0);
158
159            // Swap pairs in b: [im0, re0, im1, re1]
160            let b_swap = _mm256_shuffle_pd(other.0, other.0, 0b0101);
161
162            // prod1 = a_re * b = [re0*re0', re0*im0', re1*re1', re1*im1']
163            let prod1 = _mm256_mul_pd(a_re, other.0);
164            // prod2 = a_im * b_swap = [im0*im0', im0*re0', im1*im1', im1*re1']
165            let prod2 = _mm256_mul_pd(a_im, b_swap);
166
167            // Combine with addsub: [re*re - im*im, re*im + im*re, ...]
168            Self(_mm256_addsub_pd(prod1, prod2))
169        }
170    }
171
172    /// Complex conjugate multiply.
173    #[inline]
174    fn cmul_conj(self, other: Self) -> Self {
175        unsafe {
176            let a_re = _mm256_unpacklo_pd(self.0, self.0);
177            let a_im = _mm256_unpackhi_pd(self.0, self.0);
178            let b_swap = _mm256_shuffle_pd(other.0, other.0, 0b0101);
179
180            let prod1 = _mm256_mul_pd(a_re, other.0);
181            let prod2 = _mm256_mul_pd(a_im, b_swap);
182
183            // For conjugate: swap add/sub pattern
184            // [re*re + im*im, -re*im + im*re]
185            let sign = _mm256_set_pd(-0.0, 0.0, -0.0, 0.0);
186            let prod1_signed = _mm256_xor_pd(prod1, sign);
187            Self(_mm256_add_pd(prod1_signed, prod2))
188        }
189    }
190}
191
impl SimdVector for AvxF32 {
    type Scalar = f32;
    // 256-bit register / 32-bit elements = 8 lanes.
    const LANES: usize = 8;

    /// Broadcast `value` into all 8 lanes.
    ///
    /// NOTE(review): as with `AvxF64`, every method here executes AVX
    /// instructions; callers must have verified AVX support beforehand.
    #[inline]
    fn splat(value: f32) -> Self {
        unsafe { Self(_mm256_set1_ps(value)) }
    }

    /// Load 8 `f32`s from an aligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 8 consecutive `f32`s and aligned
    /// to 32 bytes (`_mm256_load_ps` requirement).
    #[inline]
    unsafe fn load_aligned(ptr: *const f32) -> Self {
        unsafe { Self(_mm256_load_ps(ptr)) }
    }

    /// Load 8 `f32`s from a possibly-unaligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 8 consecutive `f32`s.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        unsafe { Self(_mm256_loadu_ps(ptr)) }
    }

    /// Store all 8 lanes to an aligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 8 consecutive `f32`s and aligned
    /// to 32 bytes (`_mm256_store_ps` requirement).
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f32) {
        unsafe { _mm256_store_ps(ptr, self.0) }
    }

    /// Store all 8 lanes to a possibly-unaligned pointer.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 8 consecutive `f32`s.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f32) {
        unsafe { _mm256_storeu_ps(ptr, self.0) }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        unsafe { Self(_mm256_add_ps(self.0, other.0)) }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        unsafe { Self(_mm256_sub_ps(self.0, other.0)) }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        unsafe { Self(_mm256_mul_ps(self.0, other.0)) }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        unsafe { Self(_mm256_div_ps(self.0, other.0)) }
    }
}
241
242#[allow(dead_code)]
243impl AvxF32 {
244    /// Create a vector from eight f32 values.
245    #[inline]
246    pub fn new(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> Self {
247        unsafe { Self(_mm256_set_ps(h, g, f, e, d, c, b, a)) }
248    }
249
250    /// Negate all elements.
251    #[inline]
252    pub fn negate(self) -> Self {
253        unsafe {
254            let sign_mask = _mm256_set1_ps(-0.0);
255            Self(_mm256_xor_ps(self.0, sign_mask))
256        }
257    }
258
259    /// Interleave low elements.
260    #[inline]
261    pub fn unpack_lo(self, other: Self) -> Self {
262        unsafe { Self(_mm256_unpacklo_ps(self.0, other.0)) }
263    }
264
265    /// Interleave high elements.
266    #[inline]
267    pub fn unpack_hi(self, other: Self) -> Self {
268        unsafe { Self(_mm256_unpackhi_ps(self.0, other.0)) }
269    }
270
271    /// Swap 128-bit lanes.
272    #[inline]
273    pub fn swap_lanes(self) -> Self {
274        unsafe { Self(_mm256_permute2f128_ps(self.0, self.0, 0x01)) }
275    }
276
277    /// Get the low 128-bit lane as SSE vector.
278    #[inline]
279    pub fn low_128(self) -> super::sse2::Sse2F32 {
280        unsafe { super::sse2::Sse2F32(_mm256_castps256_ps128(self.0)) }
281    }
282
283    /// Get the high 128-bit lane as SSE vector.
284    #[inline]
285    pub fn high_128(self) -> super::sse2::Sse2F32 {
286        unsafe { super::sse2::Sse2F32(_mm256_extractf128_ps(self.0, 1)) }
287    }
288}
289
290impl SimdComplex for AvxF32 {
291    /// Complex multiply for 4 interleaved complex numbers.
292    ///
293    /// Format: [re0, im0, re1, im1, re2, im2, re3, im3]
294    #[inline]
295    fn cmul(self, other: Self) -> Self {
296        unsafe {
297            // Duplicate real parts: [re0, re0, re1, re1, re2, re2, re3, re3]
298            let a_re = _mm256_shuffle_ps(self.0, self.0, 0b1010_0000);
299            // Duplicate imag parts: [im0, im0, im1, im1, im2, im2, im3, im3]
300            let a_im = _mm256_shuffle_ps(self.0, self.0, 0b1111_0101);
301
302            // Swap pairs in b: [im0, re0, im1, re1, im2, re2, im3, re3]
303            let b_swap = _mm256_shuffle_ps(other.0, other.0, 0b1011_0001);
304
305            let prod1 = _mm256_mul_ps(a_re, other.0);
306            let prod2 = _mm256_mul_ps(a_im, b_swap);
307
308            // addsub pattern: [re*re - im*im, re*im + im*re, ...]
309            Self(_mm256_addsub_ps(prod1, prod2))
310        }
311    }
312
313    /// Complex conjugate multiply.
314    #[inline]
315    fn cmul_conj(self, other: Self) -> Self {
316        unsafe {
317            let a_re = _mm256_shuffle_ps(self.0, self.0, 0b1010_0000);
318            let a_im = _mm256_shuffle_ps(self.0, self.0, 0b1111_0101);
319            let b_swap = _mm256_shuffle_ps(other.0, other.0, 0b1011_0001);
320
321            let prod1 = _mm256_mul_ps(a_re, other.0);
322            let prod2 = _mm256_mul_ps(a_im, b_swap);
323
324            // For conjugate: [re*re + im*im, -re*im + im*re]
325            let sign = _mm256_set_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
326            let prod1_signed = _mm256_xor_ps(prod1, sign);
327            Self(_mm256_add_ps(prod1_signed, prod2))
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Runtime guard: every test below is a silent no-op on CPUs without AVX.
    fn has_avx() -> bool {
        is_x86_feature_detected!("avx")
    }

    /// Spill an `AvxF64` into a scalar array so individual lanes can be checked.
    fn lanes_f64(v: AvxF64) -> [f64; 4] {
        let mut buf = [0.0_f64; 4];
        unsafe { v.store_unaligned(buf.as_mut_ptr()) };
        buf
    }

    /// Spill an `AvxF32` into a scalar array so individual lanes can be checked.
    fn lanes_f32(v: AvxF32) -> [f32; 8] {
        let mut buf = [0.0_f32; 8];
        unsafe { v.store_unaligned(buf.as_mut_ptr()) };
        buf
    }

    #[test]
    fn test_avx_f64_basic() {
        if !has_avx() {
            return;
        }

        let two = AvxF64::splat(2.0);
        let three = AvxF64::splat(3.0);

        let sum = two.add(three);
        assert_eq!(sum.extract(0), 5.0);
        assert_eq!(sum.extract(3), 5.0);

        assert_eq!(two.sub(three).extract(0), -1.0);
        assert_eq!(two.mul(three).extract(0), 6.0);
    }

    #[test]
    fn test_avx_f64_new() {
        if !has_avx() {
            return;
        }

        let v = AvxF64::new(1.0, 2.0, 3.0, 4.0);
        for (idx, want) in [1.0, 2.0, 3.0, 4.0].into_iter().enumerate() {
            assert_eq!(v.extract(idx), want);
        }
    }

    #[test]
    fn test_avx_f64_cmul() {
        if !has_avx() {
            return;
        }

        // (1+2i)*(5+6i) = (1*5-2*6) + (1*6+2*5)i = -7 + 16i
        // (3+4i)*(7+8i) = (3*7-4*8) + (3*8+4*7)i = -11 + 52i
        let a = AvxF64::new(1.0, 2.0, 3.0, 4.0);
        let b = AvxF64::new(5.0, 6.0, 7.0, 8.0);
        let got = lanes_f64(a.cmul(b));
        let want = [-7.0, 16.0, -11.0, 52.0];
        for (g, w) in got.iter().zip(want.iter()) {
            assert!((g - w).abs() < 1e-10);
        }
    }

    #[test]
    fn test_avx_f64_load_store() {
        if !has_avx() {
            return;
        }

        let input = [1.0_f64, 2.0, 3.0, 4.0];
        unsafe {
            let v = AvxF64::load_unaligned(input.as_ptr());
            assert_eq!(v.extract(0), 1.0);
            assert_eq!(v.extract(3), 4.0);

            let mut roundtrip = [0.0_f64; 4];
            v.store_unaligned(roundtrip.as_mut_ptr());
            assert_eq!(roundtrip, input);
        }
    }

    #[test]
    fn test_avx_f32_basic() {
        if !has_avx() {
            return;
        }

        let sum = AvxF32::splat(2.0).add(AvxF32::splat(3.0));
        assert_eq!(lanes_f32(sum), [5.0_f32; 8]);
    }

    #[test]
    fn test_avx_f32_cmul() {
        if !has_avx() {
            return;
        }

        // (1+2i)(5+6i) = -7+16i,  (3+4i)(7+8i) = -11+52i,
        // (1+0i)(1+0i) =  1+0i,   (0+1i)(0+1i) = -1+0i
        let a = AvxF32::new(1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0);
        let b = AvxF32::new(5.0, 6.0, 7.0, 8.0, 1.0, 0.0, 0.0, 1.0);
        let got = lanes_f32(a.cmul(b));
        let want = [-7.0, 16.0, -11.0, 52.0, 1.0, 0.0, -1.0, 0.0];
        for (g, w) in got.iter().zip(want.iter()) {
            assert!((g - w).abs() < 1e-5);
        }
    }
}
449}