generic_simd/arch/x86/complex.rs

1#[cfg(target_arch = "x86")]
2use core::arch::x86::*;
3#[cfg(target_arch = "x86_64")]
4use core::arch::x86_64::*;
5
6use crate::{
7    arch::{generic, x86::*, Token},
8    scalar::Scalar,
9    shim::{Shim2, Shim4, Shim8, ShimToken},
10    vector::{width, Native, Vector},
11};
12use num_complex::Complex;
13
14impl Native<Sse> for Complex<f32> {
15    type Width = width::W2;
16}
17
18impl Native<Sse> for Complex<f64> {
19    type Width = width::W1;
20}
21
22impl Native<Avx> for Complex<f32> {
23    type Width = width::W4;
24}
25
26impl Native<Avx> for Complex<f64> {
27    type Width = width::W2;
28}
29
/// Two `Complex<f32>` values held in one 128-bit SSE register.
///
/// Components are interleaved as `[re0, im0, re1, im1]`.
///
/// Requires feature `"complex"`.
// Lowercase name follows the `<scalar>x<lanes>` SIMD type naming convention.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct cf32x2(__m128);

/// One `Complex<f64>` value held in one 128-bit SSE register.
///
/// Components are laid out as `[re, im]`.
///
/// Requires feature `"complex"`.
// Lowercase name follows the `<scalar>x<lanes>` SIMD type naming convention.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct cf64x1(__m128d);

/// Four `Complex<f32>` values held in one 256-bit AVX register.
///
/// Components are interleaved as `[re0, im0, re1, im1, re2, im2, re3, im3]`.
///
/// Requires feature `"complex"`.
// Lowercase name follows the `<scalar>x<lanes>` SIMD type naming convention.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct cf32x4(__m256);

/// Two `Complex<f64>` values held in one 256-bit AVX register.
///
/// Components are interleaved as `[re0, im0, re1, im1]`.
///
/// Requires feature `"complex"`.
// Lowercase name follows the `<scalar>x<lanes>` SIMD type naming convention.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct cf64x2(__m256d);
61
62impl Scalar<Sse, width::W1> for Complex<f32> {
63    type Vector = ShimToken<generic::cf32x1, Self, Sse>;
64}
65
66impl Scalar<Sse, width::W2> for Complex<f32> {
67    type Vector = cf32x2;
68}
69
70impl Scalar<Sse, width::W4> for Complex<f32> {
71    type Vector = Shim2<cf32x2, Complex<f32>>;
72}
73
74impl Scalar<Sse, width::W8> for Complex<f32> {
75    type Vector = Shim4<cf32x2, Complex<f32>>;
76}
77
78impl Scalar<Sse, width::W1> for Complex<f64> {
79    type Vector = cf64x1;
80}
81
82impl Scalar<Sse, width::W2> for Complex<f64> {
83    type Vector = Shim2<cf64x1, Complex<f64>>;
84}
85
86impl Scalar<Sse, width::W4> for Complex<f64> {
87    type Vector = Shim4<cf64x1, Complex<f64>>;
88}
89
90impl Scalar<Sse, width::W8> for Complex<f64> {
91    type Vector = Shim8<cf64x1, Self>;
92}
93
94impl Scalar<Avx, width::W1> for Complex<f32> {
95    type Vector = ShimToken<generic::cf32x1, Self, Avx>;
96}
97
98impl Scalar<Avx, width::W2> for Complex<f32> {
99    type Vector = ShimToken<cf32x2, Self, Avx>;
100}
101
102impl Scalar<Avx, width::W4> for Complex<f32> {
103    type Vector = cf32x4;
104}
105
106impl Scalar<Avx, width::W8> for Complex<f32> {
107    type Vector = Shim2<cf32x4, Complex<f32>>;
108}
109
110impl Scalar<Avx, width::W1> for Complex<f64> {
111    type Vector = ShimToken<cf64x1, Self, Avx>;
112}
113
114impl Scalar<Avx, width::W2> for Complex<f64> {
115    type Vector = cf64x2;
116}
117
118impl Scalar<Avx, width::W4> for Complex<f64> {
119    type Vector = Shim2<cf64x2, Complex<f64>>;
120}
121
122impl Scalar<Avx, width::W8> for Complex<f64> {
123    type Vector = Shim4<cf64x2, Complex<f64>>;
124}
125
126arithmetic_ops! {
127    feature: Sse::new_unchecked(),
128    for cf32x2:
129        add -> (_mm_add_ps),
130        sub -> (_mm_sub_ps),
131        mul -> (mul_cf32x2),
132        div -> (div_cf32x2)
133}
134
135arithmetic_ops! {
136    feature: Sse::new_unchecked(),
137    for cf64x1:
138        add -> (_mm_add_pd),
139        sub -> (_mm_sub_pd),
140        mul -> (mul_cf64x1),
141        div -> (div_cf64x1)
142}
143
144arithmetic_ops! {
145    feature: Avx::new_unchecked(),
146    for cf32x4:
147        add -> (_mm256_add_ps),
148        sub -> (_mm256_sub_ps),
149        mul -> (mul_cf32x4),
150        div -> (div_cf32x4)
151}
152
153arithmetic_ops! {
154    feature: Avx::new_unchecked(),
155    for cf64x2:
156        add -> (_mm256_add_pd),
157        sub -> (_mm256_sub_pd),
158        mul -> (mul_cf64x2),
159        div -> (div_cf64x2)
160}
161
/// Multiplies two pairs of interleaved `Complex<f32>` values lane by lane.
///
/// Each register holds `[re0, im0, re1, im1]`. `addsub` combines `a.re * b` with
/// `a.im * swap(b)`. It subtracts in the real lanes and adds in the imaginary lanes,
/// which yields `(a.re*b.re - a.im*b.im, a.re*b.im + a.im*b.re)`.
///
/// # Safety
/// The caller must ensure the CPU supports SSE3.
#[target_feature(enable = "sse3")]
#[inline]
unsafe fn mul_cf32x2(a: __m128, b: __m128) -> __m128 {
    // 0xb1 swaps each (re, im) pair: [b.im, b.re, ...]
    let b_swapped = _mm_shuffle_ps(b, b, 0xb1);
    let real_terms = _mm_mul_ps(_mm_moveldup_ps(a), b);
    let imag_terms = _mm_mul_ps(_mm_movehdup_ps(a), b_swapped);
    _mm_addsub_ps(real_terms, imag_terms)
}
170
/// Multiplies one `Complex<f64>` by another.
///
/// The register holds `[re, im]`. `addsub` combines `a.re * b` with `a.im * swap(b)`.
/// It subtracts in the real lane and adds in the imaginary lane.
///
/// # Safety
/// The caller must ensure the CPU supports SSE3.
#[target_feature(enable = "sse3")]
#[inline]
unsafe fn mul_cf64x1(a: __m128d, b: __m128d) -> __m128d {
    let a_re = _mm_movedup_pd(a); // [a.re, a.re]
    let a_im = _mm_unpackhi_pd(a, a); // [a.im, a.im]
    let b_swapped = _mm_shuffle_pd(b, b, 0b01); // [b.im, b.re]
    _mm_addsub_pd(_mm_mul_pd(a_re, b), _mm_mul_pd(a_im, b_swapped))
}
179
/// Divides two pairs of interleaved `Complex<f32>` values lane by lane.
///
/// Computes `a * conj(b) / |b|^2`:
/// - real: `(a.re * b.re + a.im * b.im) / (b.re * b.re + b.im * b.im)`
/// - imag: `(a.im * b.re - a.re * b.im) / (b.re * b.re + b.im * b.im)`
///
/// `|b|^2` is formed without scaling, so it can overflow or underflow for extreme `b`.
///
/// # Safety
/// The caller must ensure the CPU supports SSE3.
#[target_feature(enable = "sse3")]
#[inline]
unsafe fn div_cf32x2(a: __m128, b: __m128) -> __m128 {
    let divisor_re = _mm_moveldup_ps(b);
    let divisor_im = _mm_movehdup_ps(b);
    let denominator = _mm_add_ps(
        _mm_mul_ps(divisor_re, divisor_re),
        _mm_mul_ps(divisor_im, divisor_im),
    );

    // The cross terms [a.im*b.im, a.re*b.im] are negated so that `addsub` adds them
    // in the real lanes and subtracts them in the imaginary lanes.
    let a_swapped = _mm_shuffle_ps(a, a, 0xb1);
    let cross = _mm_xor_ps(_mm_mul_ps(a_swapped, divisor_im), _mm_set1_ps(-0.));
    let numerator = _mm_addsub_ps(_mm_mul_ps(a, divisor_re), cross);

    _mm_div_ps(numerator, denominator)
}
196
/// Divides one `Complex<f64>` by another as `a * conj(b) / |b|^2`.
///
/// - real: `(a.re * b.re + a.im * b.im) / (b.re * b.re + b.im * b.im)`
/// - imag: `(a.im * b.re - a.re * b.im) / (b.re * b.re + b.im * b.im)`
///
/// `|b|^2` is formed without scaling, so it can overflow or underflow for extreme `b`.
///
/// # Safety
/// The caller must ensure the CPU supports SSE3.
#[target_feature(enable = "sse3")]
#[inline]
unsafe fn div_cf64x1(a: __m128d, b: __m128d) -> __m128d {
    let divisor_re = _mm_movedup_pd(b); // [b.re, b.re]
    let divisor_im = _mm_unpackhi_pd(b, b); // [b.im, b.im]
    let denominator = _mm_add_pd(
        _mm_mul_pd(divisor_re, divisor_re),
        _mm_mul_pd(divisor_im, divisor_im),
    );

    // Negated cross terms: `addsub` then adds them in the real lane and
    // subtracts them in the imaginary lane.
    let a_swapped = _mm_shuffle_pd(a, a, 0b01); // [a.im, a.re]
    let cross = _mm_xor_pd(_mm_mul_pd(a_swapped, divisor_im), _mm_set1_pd(-0.));
    let numerator = _mm_addsub_pd(_mm_mul_pd(a, divisor_re), cross);

    _mm_div_pd(numerator, denominator)
}
212
/// Multiplies four pairs of interleaved `Complex<f32>` values lane by lane.
///
/// Same scheme as the SSE kernel. The duplicate and shuffle instructions act within
/// each 128-bit half, so every (re, im) pair is handled independently.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[inline]
unsafe fn mul_cf32x4(a: __m256, b: __m256) -> __m256 {
    // 0xb1 swaps each (re, im) pair: [b.im, b.re, ...]
    let b_swapped = _mm256_shuffle_ps(b, b, 0xb1);
    let real_terms = _mm256_mul_ps(_mm256_moveldup_ps(a), b);
    let imag_terms = _mm256_mul_ps(_mm256_movehdup_ps(a), b_swapped);
    _mm256_addsub_ps(real_terms, imag_terms)
}
221
/// Multiplies two pairs of interleaved `Complex<f64>` values lane by lane.
///
/// Each register holds `[re0, im0, re1, im1]`. `addsub` combines `a.re * b` with
/// `a.im * swap(b)`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[inline]
unsafe fn mul_cf64x2(a: __m256d, b: __m256d) -> __m256d {
    let a_re = _mm256_movedup_pd(a); // [a0.re, a0.re, a1.re, a1.re]
    let a_im = _mm256_unpackhi_pd(a, a); // [a0.im, a0.im, a1.im, a1.im]
    let b_swapped = _mm256_shuffle_pd(b, b, 0b0101); // [b0.im, b0.re, b1.im, b1.re]
    _mm256_addsub_pd(_mm256_mul_pd(a_re, b), _mm256_mul_pd(a_im, b_swapped))
}
230
/// Divides four pairs of interleaved `Complex<f32>` values lane by lane.
///
/// Computes `a * conj(b) / |b|^2`:
/// - real: `(a.re * b.re + a.im * b.im) / (b.re * b.re + b.im * b.im)`
/// - imag: `(a.im * b.re - a.re * b.im) / (b.re * b.re + b.im * b.im)`
///
/// `|b|^2` is formed without scaling, so it can overflow or underflow for extreme `b`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[inline]
unsafe fn div_cf32x4(a: __m256, b: __m256) -> __m256 {
    let divisor_re = _mm256_moveldup_ps(b);
    let divisor_im = _mm256_movehdup_ps(b);
    let denominator = _mm256_add_ps(
        _mm256_mul_ps(divisor_re, divisor_re),
        _mm256_mul_ps(divisor_im, divisor_im),
    );

    // Negated cross terms: `addsub` then adds them in the real lanes and
    // subtracts them in the imaginary lanes.
    let a_swapped = _mm256_shuffle_ps(a, a, 0xb1);
    let cross = _mm256_xor_ps(_mm256_mul_ps(a_swapped, divisor_im), _mm256_set1_ps(-0.));
    let numerator = _mm256_addsub_ps(_mm256_mul_ps(a, divisor_re), cross);

    _mm256_div_ps(numerator, denominator)
}
247
/// Divides two pairs of interleaved `Complex<f64>` values as `a * conj(b) / |b|^2`.
///
/// - real: `(a.re * b.re + a.im * b.im) / (b.re * b.re + b.im * b.im)`
/// - imag: `(a.im * b.re - a.re * b.im) / (b.re * b.re + b.im * b.im)`
///
/// `|b|^2` is formed without scaling, so it can overflow or underflow for extreme `b`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX.
#[target_feature(enable = "avx")]
#[inline]
unsafe fn div_cf64x2(a: __m256d, b: __m256d) -> __m256d {
    let divisor_re = _mm256_movedup_pd(b); // [b0.re, b0.re, b1.re, b1.re]
    let divisor_im = _mm256_unpackhi_pd(b, b); // [b0.im, b0.im, b1.im, b1.im]
    let denominator = _mm256_add_pd(
        _mm256_mul_pd(divisor_re, divisor_re),
        _mm256_mul_pd(divisor_im, divisor_im),
    );

    // Negated cross terms: `addsub` then adds them in the real lanes and
    // subtracts them in the imaginary lanes.
    let a_swapped = _mm256_shuffle_pd(a, a, 0b0101);
    let cross = _mm256_xor_pd(_mm256_mul_pd(a_swapped, divisor_im), _mm256_set1_pd(-0.));
    let numerator = _mm256_addsub_pd(_mm256_mul_pd(a, divisor_re), cross);

    _mm256_div_pd(numerator, denominator)
}
263
264impl core::ops::Neg for cf32x2 {
265    type Output = Self;
266
267    #[inline]
268    fn neg(self) -> Self {
269        Self(unsafe { _mm_xor_ps(self.0, _mm_set1_ps(-0.)) })
270    }
271}
272
273impl core::ops::Neg for cf64x1 {
274    type Output = Self;
275
276    #[inline]
277    fn neg(self) -> Self {
278        Self(unsafe { _mm_xor_pd(self.0, _mm_set1_pd(-0.)) })
279    }
280}
281
282impl core::ops::Neg for cf32x4 {
283    type Output = Self;
284
285    #[inline]
286    fn neg(self) -> Self {
287        Self(unsafe { _mm256_xor_ps(self.0, _mm256_set1_ps(-0.)) })
288    }
289}
290
291impl core::ops::Neg for cf64x2 {
292    type Output = Self;
293
294    #[inline]
295    fn neg(self) -> Self {
296        Self(unsafe { _mm256_xor_pd(self.0, _mm256_set1_pd(-0.)) })
297    }
298}
299
300as_slice! { cf32x2 }
301as_slice! { cf32x4 }
302as_slice! { cf64x1 }
303as_slice! { cf64x2 }
304
305unsafe impl Vector for cf32x2 {
306    type Scalar = Complex<f32>;
307
308    type Token = Sse;
309
310    type Width = crate::vector::width::W2;
311
312    type Underlying = __m128;
313
314    #[inline]
315    fn zeroed(_: Self::Token) -> Self {
316        Self(unsafe { _mm_setzero_ps() })
317    }
318
319    #[inline]
320    fn splat(_: Self::Token, from: Self::Scalar) -> Self {
321        Self(unsafe { _mm_set_ps(from.im, from.re, from.im, from.re) })
322    }
323}
324
325unsafe impl Vector for cf64x1 {
326    type Scalar = Complex<f64>;
327
328    type Token = Sse;
329
330    type Width = crate::vector::width::W1;
331
332    type Underlying = __m128d;
333
334    #[inline]
335    fn zeroed(_: Self::Token) -> Self {
336        Self(unsafe { _mm_setzero_pd() })
337    }
338
339    #[inline]
340    fn splat(_: Self::Token, from: Self::Scalar) -> Self {
341        Self(unsafe { _mm_set_pd(from.im, from.re) })
342    }
343}
344
345unsafe impl Vector for cf32x4 {
346    type Scalar = Complex<f32>;
347
348    type Token = Avx;
349
350    type Width = crate::vector::width::W4;
351
352    type Underlying = __m256;
353
354    #[inline]
355    fn zeroed(_: Self::Token) -> Self {
356        Self(unsafe { _mm256_setzero_ps() })
357    }
358
359    #[inline]
360    fn splat(_: Self::Token, from: Self::Scalar) -> Self {
361        unsafe {
362            Self(_mm256_setr_ps(
363                from.re, from.im, from.re, from.im, from.re, from.im, from.re, from.im,
364            ))
365        }
366    }
367}
368
369unsafe impl Vector for cf64x2 {
370    type Scalar = Complex<f64>;
371
372    type Token = Avx;
373
374    type Width = crate::vector::width::W2;
375
376    type Underlying = __m256d;
377
378    #[inline]
379    fn zeroed(_: Self::Token) -> Self {
380        Self(unsafe { _mm256_setzero_pd() })
381    }
382
383    #[inline]
384    fn splat(_: Self::Token, from: Self::Scalar) -> Self {
385        Self(unsafe { _mm256_setr_pd(from.re, from.im, from.re, from.im) })
386    }
387}
388
389impl crate::vector::Complex for cf32x2 {
390    type RealScalar = f32;
391
392    #[inline]
393    fn conj(self) -> Self {
394        Self(unsafe { _mm_xor_ps(self.0, _mm_set_ps(-0., 0., -0., 0.)) })
395    }
396
397    #[inline]
398    fn mul_i(self) -> Self {
399        Self(unsafe { _mm_addsub_ps(_mm_setzero_ps(), _mm_shuffle_ps(self.0, self.0, 0xb1)) })
400    }
401
402    #[inline]
403    fn mul_neg_i(self) -> Self {
404        unsafe {
405            let neg = _mm_addsub_ps(_mm_setzero_ps(), self.0);
406            Self(_mm_shuffle_ps(neg, neg, 0xb1))
407        }
408    }
409}
410
411impl crate::vector::Complex for cf64x1 {
412    type RealScalar = f64;
413
414    #[inline]
415    fn conj(self) -> Self {
416        Self(unsafe { _mm_xor_pd(self.0, _mm_set_pd(-0., 0.)) })
417    }
418
419    #[inline]
420    fn mul_i(self) -> Self {
421        Self(unsafe { _mm_addsub_pd(_mm_setzero_pd(), _mm_shuffle_pd(self.0, self.0, 0x1)) })
422    }
423
424    #[inline]
425    fn mul_neg_i(self) -> Self {
426        unsafe {
427            let neg = _mm_addsub_pd(_mm_setzero_pd(), self.0);
428            Self(_mm_shuffle_pd(neg, neg, 0x1))
429        }
430    }
431}
432
433impl crate::vector::Complex for cf32x4 {
434    type RealScalar = f32;
435
436    #[inline]
437    fn conj(self) -> Self {
438        Self(unsafe { _mm256_xor_ps(self.0, _mm256_set_ps(-0., 0., -0., 0., -0., 0., -0., 0.)) })
439    }
440
441    #[inline]
442    fn mul_i(self) -> Self {
443        Self(unsafe {
444            _mm256_addsub_ps(_mm256_setzero_ps(), _mm256_shuffle_ps(self.0, self.0, 0xb1))
445        })
446    }
447
448    #[inline]
449    fn mul_neg_i(self) -> Self {
450        unsafe {
451            let neg = _mm256_addsub_ps(_mm256_setzero_ps(), self.0);
452            Self(_mm256_shuffle_ps(neg, neg, 0xb1))
453        }
454    }
455}
456
457impl crate::vector::Complex for cf64x2 {
458    type RealScalar = f64;
459
460    #[inline]
461    fn conj(self) -> Self {
462        Self(unsafe { _mm256_xor_pd(self.0, _mm256_set_pd(-0., 0., -0., 0.)) })
463    }
464
465    #[inline]
466    fn mul_i(self) -> Self {
467        Self(unsafe {
468            _mm256_addsub_pd(_mm256_setzero_pd(), _mm256_shuffle_pd(self.0, self.0, 0x5))
469        })
470    }
471
472    #[inline]
473    fn mul_neg_i(self) -> Self {
474        unsafe {
475            let neg = _mm256_addsub_pd(_mm256_setzero_pd(), self.0);
476            Self(_mm256_shuffle_pd(neg, neg, 0x5))
477        }
478    }
479}