1use super::traits::{SimdComplex, SimdVector};
8use core::arch::x86_64::*;
9
/// Four packed `f64` lanes held in one AVX 256-bit register.
///
/// `repr(transparent)` guarantees this wrapper has exactly the layout and
/// ABI of the raw `__m256d` it contains.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct AvxF64(pub __m256d);
14
/// Eight packed `f32` lanes held in one AVX 256-bit register.
///
/// `repr(transparent)` guarantees this wrapper has exactly the layout and
/// ABI of the raw `__m256` it contains.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct AvxF32(pub __m256);
19
// SAFETY: both wrappers hold only a plain SIMD register value — no pointers,
// no interior mutability, no thread affinity — so sending them to, or
// sharing them between, threads cannot violate memory safety.
unsafe impl Send for AvxF64 {}
unsafe impl Sync for AvxF64 {}
unsafe impl Send for AvxF32 {}
unsafe impl Sync for AvxF32 {}
25
/// [`SimdVector`] over four packed `f64` lanes in one AVX `__m256d`.
///
/// NOTE(review): every intrinsic here requires AVX support at runtime;
/// callers appear to be expected to gate on feature detection (the tests
/// below check `is_x86_feature_detected!("avx")` first) — confirm the
/// trait documents this contract.
impl SimdVector for AvxF64 {
    type Scalar = f64;
    // One 256-bit register holds four 64-bit floats.
    const LANES: usize = 4;

    /// Broadcasts `value` into all four lanes.
    #[inline]
    fn splat(value: f64) -> Self {
        // SAFETY: pure register operation; sound whenever AVX is available
        // (see note on the impl).
        unsafe { Self(_mm256_set1_pd(value)) }
    }

    /// Loads four doubles starting at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 4 `f64`s and 32-byte aligned.
    #[inline]
    unsafe fn load_aligned(ptr: *const f64) -> Self {
        // SAFETY: caller upholds pointer validity and 32-byte alignment.
        unsafe { Self(_mm256_load_pd(ptr)) }
    }

    /// Loads four doubles starting at `ptr`; no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 4 `f64`s.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f64) -> Self {
        // SAFETY: caller upholds pointer validity; `loadu` tolerates any
        // alignment.
        unsafe { Self(_mm256_loadu_pd(ptr)) }
    }

    /// Stores all four lanes to `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 4 `f64`s and 32-byte aligned.
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f64) {
        // SAFETY: caller upholds pointer validity and 32-byte alignment.
        unsafe { _mm256_store_pd(ptr, self.0) }
    }

    /// Stores all four lanes to `ptr`; no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 4 `f64`s.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f64) {
        // SAFETY: caller upholds pointer validity.
        unsafe { _mm256_storeu_pd(ptr, self.0) }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_add_pd(self.0, other.0)) }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_sub_pd(self.0, other.0)) }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_mul_pd(self.0, other.0)) }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_div_pd(self.0, other.0)) }
    }
}
75
#[allow(dead_code)]
impl AvxF64 {
    /// Builds a vector with lanes `[a, b, c, d]` (lane 0 = `a`).
    #[inline]
    pub fn new(a: f64, b: f64, c: f64, d: f64) -> Self {
        // SAFETY: pure register operation. `_mm256_set_pd` takes its
        // arguments highest-lane-first, hence the reversed order here.
        unsafe { Self(_mm256_set_pd(d, c, b, a)) }
    }

    /// Returns lane `idx` (0-based; debug-asserts `idx < 4`).
    #[inline]
    pub fn extract(self, idx: usize) -> f64 {
        debug_assert!(idx < 4);
        let mut arr = [0.0_f64; 4];
        // SAFETY: `arr` is valid for writing 4 f64s and `store_unaligned`
        // imposes no alignment requirement.
        unsafe { self.store_unaligned(arr.as_mut_ptr()) };
        arr[idx]
    }

    /// Negates every lane by flipping the IEEE-754 sign bit.
    #[inline]
    pub fn negate(self) -> Self {
        // SAFETY: pure register operations.
        unsafe {
            let sign_mask = _mm256_set1_pd(-0.0); // only the sign bit set
            Self(_mm256_xor_pd(self.0, sign_mask))
        }
    }

    /// Shuffles lanes independently within each 128-bit half, per the
    /// immediate-operand semantics of `_mm256_shuffle_pd`.
    #[inline]
    pub fn shuffle_within_lanes<const MASK: i32>(self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_shuffle_pd(self.0, self.0, MASK)) }
    }

    /// Swaps the low and high 128-bit halves: `[a,b,c,d] -> [c,d,a,b]`.
    #[inline]
    pub fn swap_lanes(self) -> Self {
        // SAFETY: pure register operation. Immediate 0x01 selects the
        // source's high half for the result's low half and vice versa.
        unsafe { Self(_mm256_permute2f128_pd(self.0, self.0, 0x01)) }
    }

    /// Interleaves the even (low) lane of each 128-bit half:
    /// `[a0,a1,a2,a3]` x `[b0,b1,b2,b3] -> [a0,b0,a2,b2]`.
    #[inline]
    pub fn unpack_lo(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_unpacklo_pd(self.0, other.0)) }
    }

    /// Interleaves the odd (high) lane of each 128-bit half:
    /// `[a0,a1,a2,a3]` x `[b0,b1,b2,b3] -> [a1,b1,a3,b3]`.
    #[inline]
    pub fn unpack_hi(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_unpackhi_pd(self.0, other.0)) }
    }

    /// Per-lane select: bit `i` of `MASK` set picks `other`'s lane `i`;
    /// clear picks `self`'s.
    #[inline]
    pub fn blend<const MASK: i32>(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_blend_pd(self.0, other.0, MASK)) }
    }

    /// Returns the low 128 bits (lanes 0-1) as an SSE2 vector; the cast is
    /// a zero-cost reinterpretation.
    #[inline]
    pub fn low_128(self) -> super::sse2::Sse2F64 {
        // SAFETY: pure register operation (reinterpreting cast).
        unsafe { super::sse2::Sse2F64(_mm256_castpd256_pd128(self.0)) }
    }

    /// Returns the high 128 bits (lanes 2-3) as an SSE2 vector.
    #[inline]
    pub fn high_128(self) -> super::sse2::Sse2F64 {
        // SAFETY: pure register operation.
        unsafe { super::sse2::Sse2F64(_mm256_extractf128_pd(self.0, 1)) }
    }
}
145
/// Complex arithmetic on interleaved `[re, im]` pairs: each 128-bit half
/// of the register holds one complex `f64` number.
impl SimdComplex for AvxF64 {
    /// Complex multiply, per pair: `(ar + ai·i)·(br + bi·i)`.
    #[inline]
    fn cmul(self, other: Self) -> Self {
        // SAFETY: pure register operations.
        unsafe {
            // Per 128-bit half: duplicate the real part -> [ar, ar] ...
            let a_re = _mm256_unpacklo_pd(self.0, self.0);
            // ... and the imaginary part -> [ai, ai].
            let a_im = _mm256_unpackhi_pd(self.0, self.0);

            // Swap re/im within each pair: [br, bi] -> [bi, br].
            let b_swap = _mm256_shuffle_pd(other.0, other.0, 0b0101);

            let prod1 = _mm256_mul_pd(a_re, other.0); // [ar·br, ar·bi]
            let prod2 = _mm256_mul_pd(a_im, b_swap); // [ai·bi, ai·br]

            // addsub subtracts in even lanes and adds in odd lanes:
            // [ar·br − ai·bi, ar·bi + ai·br] = a·b.
            Self(_mm256_addsub_pd(prod1, prod2))
        }
    }

    /// Complex multiply with one operand conjugated.
    ///
    /// NOTE(review): the sign pattern below produces
    /// `[ar·br + ai·bi, ai·br − ar·bi]`, i.e. `self · conj(other)`.
    /// The name could also be read as `conj(self) · other` — confirm
    /// against the trait's documented convention.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        // SAFETY: pure register operations.
        unsafe {
            let a_re = _mm256_unpacklo_pd(self.0, self.0); // [ar, ar]
            let a_im = _mm256_unpackhi_pd(self.0, self.0); // [ai, ai]
            let b_swap = _mm256_shuffle_pd(other.0, other.0, 0b0101); // [bi, br]

            let prod1 = _mm256_mul_pd(a_re, other.0); // [ar·br, ar·bi]
            let prod2 = _mm256_mul_pd(a_im, b_swap); // [ai·bi, ai·br]

            // Negate only the odd (imaginary) lanes of prod1. `_mm256_set_pd`
            // takes lanes highest-first, so lane order is [0.0, -0.0, 0.0, -0.0].
            let sign = _mm256_set_pd(-0.0, 0.0, -0.0, 0.0);
            let prod1_signed = _mm256_xor_pd(prod1, sign);
            Self(_mm256_add_pd(prod1_signed, prod2))
        }
    }
}
191
/// [`SimdVector`] over eight packed `f32` lanes in one AVX `__m256`.
///
/// NOTE(review): as with the `f64` impl, AVX availability is assumed to be
/// checked by the caller — confirm the trait documents this contract.
impl SimdVector for AvxF32 {
    type Scalar = f32;
    // One 256-bit register holds eight 32-bit floats.
    const LANES: usize = 8;

    /// Broadcasts `value` into all eight lanes.
    #[inline]
    fn splat(value: f32) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_set1_ps(value)) }
    }

    /// Loads eight floats starting at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 8 `f32`s and 32-byte aligned.
    #[inline]
    unsafe fn load_aligned(ptr: *const f32) -> Self {
        // SAFETY: caller upholds pointer validity and 32-byte alignment.
        unsafe { Self(_mm256_load_ps(ptr)) }
    }

    /// Loads eight floats starting at `ptr`; no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for reading 8 `f32`s.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        // SAFETY: caller upholds pointer validity; `loadu` tolerates any
        // alignment.
        unsafe { Self(_mm256_loadu_ps(ptr)) }
    }

    /// Stores all eight lanes to `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 8 `f32`s and 32-byte aligned.
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f32) {
        // SAFETY: caller upholds pointer validity and 32-byte alignment.
        unsafe { _mm256_store_ps(ptr, self.0) }
    }

    /// Stores all eight lanes to `ptr`; no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for writing 8 `f32`s.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f32) {
        // SAFETY: caller upholds pointer validity.
        unsafe { _mm256_storeu_ps(ptr, self.0) }
    }

    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_add_ps(self.0, other.0)) }
    }

    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_sub_ps(self.0, other.0)) }
    }

    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_mul_ps(self.0, other.0)) }
    }

    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        // SAFETY: pure register operation.
        unsafe { Self(_mm256_div_ps(self.0, other.0)) }
    }
}
241
242#[allow(dead_code)]
243impl AvxF32 {
244 #[inline]
246 pub fn new(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> Self {
247 unsafe { Self(_mm256_set_ps(h, g, f, e, d, c, b, a)) }
248 }
249
250 #[inline]
252 pub fn negate(self) -> Self {
253 unsafe {
254 let sign_mask = _mm256_set1_ps(-0.0);
255 Self(_mm256_xor_ps(self.0, sign_mask))
256 }
257 }
258
259 #[inline]
261 pub fn unpack_lo(self, other: Self) -> Self {
262 unsafe { Self(_mm256_unpacklo_ps(self.0, other.0)) }
263 }
264
265 #[inline]
267 pub fn unpack_hi(self, other: Self) -> Self {
268 unsafe { Self(_mm256_unpackhi_ps(self.0, other.0)) }
269 }
270
271 #[inline]
273 pub fn swap_lanes(self) -> Self {
274 unsafe { Self(_mm256_permute2f128_ps(self.0, self.0, 0x01)) }
275 }
276
277 #[inline]
279 pub fn low_128(self) -> super::sse2::Sse2F32 {
280 unsafe { super::sse2::Sse2F32(_mm256_castps256_ps128(self.0)) }
281 }
282
283 #[inline]
285 pub fn high_128(self) -> super::sse2::Sse2F32 {
286 unsafe { super::sse2::Sse2F32(_mm256_extractf128_ps(self.0, 1)) }
287 }
288}
289
/// Complex arithmetic on interleaved `[re, im]` `f32` pairs: each 128-bit
/// half of the register holds two complex numbers.
impl SimdComplex for AvxF32 {
    /// Complex multiply, per pair: `(ar + ai·i)·(br + bi·i)`.
    #[inline]
    fn cmul(self, other: Self) -> Self {
        // SAFETY: pure register operations.
        unsafe {
            // Mask 0b10_10_00_00 picks lanes [0,0,2,2] per 128-bit half:
            // duplicate real parts -> [re0, re0, re1, re1].
            let a_re = _mm256_shuffle_ps(self.0, self.0, 0b1010_0000);
            // Mask 0b11_11_01_01 picks lanes [1,1,3,3]:
            // duplicate imaginary parts -> [im0, im0, im1, im1].
            let a_im = _mm256_shuffle_ps(self.0, self.0, 0b1111_0101);

            // Mask 0b10_11_00_01 picks lanes [1,0,3,2]: swap re/im per pair.
            let b_swap = _mm256_shuffle_ps(other.0, other.0, 0b1011_0001);

            let prod1 = _mm256_mul_ps(a_re, other.0); // [ar·br, ar·bi, ...]
            let prod2 = _mm256_mul_ps(a_im, b_swap); // [ai·bi, ai·br, ...]

            // addsub subtracts in even lanes and adds in odd lanes:
            // [ar·br − ai·bi, ar·bi + ai·br, ...] = a·b.
            Self(_mm256_addsub_ps(prod1, prod2))
        }
    }

    /// Complex multiply with one operand conjugated.
    ///
    /// NOTE(review): as in the `f64` version, the sign pattern computes
    /// `self · conj(other)` — confirm against the trait's documented
    /// convention.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        // SAFETY: pure register operations.
        unsafe {
            let a_re = _mm256_shuffle_ps(self.0, self.0, 0b1010_0000); // [re, re, ...]
            let a_im = _mm256_shuffle_ps(self.0, self.0, 0b1111_0101); // [im, im, ...]
            let b_swap = _mm256_shuffle_ps(other.0, other.0, 0b1011_0001); // re/im swapped

            let prod1 = _mm256_mul_ps(a_re, other.0); // [ar·br, ar·bi, ...]
            let prod2 = _mm256_mul_ps(a_im, b_swap); // [ai·bi, ai·br, ...]

            // Negate only the odd (imaginary) lanes of prod1 before adding;
            // `_mm256_set_ps` takes lanes highest-first.
            let sign = _mm256_set_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
            let prod1_signed = _mm256_xor_ps(prod1, sign);
            Self(_mm256_add_ps(prod1_signed, prod2))
        }
    }
}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Runtime gate: every test bails out early on hosts without AVX.
    fn has_avx() -> bool {
        is_x86_feature_detected!("avx")
    }

    #[test]
    fn test_avx_f64_basic() {
        if !has_avx() {
            return;
        }

        let a = AvxF64::splat(2.0);
        let b = AvxF64::splat(3.0);

        let sum = a.add(b);
        assert_eq!(sum.extract(0), 5.0);
        assert_eq!(sum.extract(3), 5.0);

        let diff = a.sub(b);
        assert_eq!(diff.extract(0), -1.0);

        let prod = a.mul(b);
        assert_eq!(prod.extract(0), 6.0);
    }

    #[test]
    fn test_avx_f64_new() {
        if !has_avx() {
            return;
        }

        let v = AvxF64::new(1.0, 2.0, 3.0, 4.0);
        assert_eq!(v.extract(0), 1.0);
        assert_eq!(v.extract(1), 2.0);
        assert_eq!(v.extract(2), 3.0);
        assert_eq!(v.extract(3), 4.0);
    }

    #[test]
    fn test_avx_f64_negate() {
        if !has_avx() {
            return;
        }

        let v = AvxF64::new(1.0, -2.0, 0.0, 4.0).negate();
        assert_eq!(v.extract(0), -1.0);
        assert_eq!(v.extract(1), 2.0);
        assert_eq!(v.extract(3), -4.0);
    }

    #[test]
    fn test_avx_f64_cmul() {
        if !has_avx() {
            return;
        }

        // a = [1+2i, 3+4i], b = [5+6i, 7+8i]:
        // (1+2i)(5+6i) = -7+16i, (3+4i)(7+8i) = -11+52i.
        let a = AvxF64::new(1.0, 2.0, 3.0, 4.0);
        let b = AvxF64::new(5.0, 6.0, 7.0, 8.0);
        let c = a.cmul(b);
        assert!((c.extract(0) - (-7.0)).abs() < 1e-10);
        assert!((c.extract(1) - 16.0).abs() < 1e-10);
        assert!((c.extract(2) - (-11.0)).abs() < 1e-10);
        assert!((c.extract(3) - 52.0).abs() < 1e-10);
    }

    #[test]
    fn test_avx_f64_cmul_conj() {
        if !has_avx() {
            return;
        }

        // cmul_conj implements self * conj(other):
        // (1+2i)*conj(5+6i) = (1+2i)(5-6i) = 17+4i,
        // (3+4i)*conj(7+8i) = (3+4i)(7-8i) = 53+4i.
        let a = AvxF64::new(1.0, 2.0, 3.0, 4.0);
        let b = AvxF64::new(5.0, 6.0, 7.0, 8.0);
        let c = a.cmul_conj(b);
        assert!((c.extract(0) - 17.0).abs() < 1e-10);
        assert!((c.extract(1) - 4.0).abs() < 1e-10);
        assert!((c.extract(2) - 53.0).abs() < 1e-10);
        assert!((c.extract(3) - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_avx_f64_load_store() {
        if !has_avx() {
            return;
        }

        let data = [1.0_f64, 2.0, 3.0, 4.0];
        unsafe {
            let v = AvxF64::load_unaligned(data.as_ptr());
            assert_eq!(v.extract(0), 1.0);
            assert_eq!(v.extract(3), 4.0);

            let mut out = [0.0_f64; 4];
            v.store_unaligned(out.as_mut_ptr());
            assert_eq!(out, [1.0, 2.0, 3.0, 4.0]);
        }
    }

    #[test]
    fn test_avx_f32_basic() {
        if !has_avx() {
            return;
        }

        let a = AvxF32::splat(2.0);
        let b = AvxF32::splat(3.0);

        let sum = a.add(b);
        let mut out = [0.0_f32; 8];
        unsafe { sum.store_unaligned(out.as_mut_ptr()) };
        assert_eq!(out, [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_avx_f32_cmul() {
        if !has_avx() {
            return;
        }

        // Pairs: (1+2i, 3+4i, 1+0i, 0+1i) * (5+6i, 7+8i, 1+0i, 0+1i)
        // = (-7+16i, -11+52i, 1+0i, -1+0i).
        let a = AvxF32::new(1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0);
        let b = AvxF32::new(5.0, 6.0, 7.0, 8.0, 1.0, 0.0, 0.0, 1.0);
        let c = a.cmul(b);
        let mut out = [0.0_f32; 8];
        unsafe { c.store_unaligned(out.as_mut_ptr()) };
        let expected = [-7.0_f32, 16.0, -11.0, 52.0, 1.0, 0.0, -1.0, 0.0];
        for (got, want) in out.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_avx_f32_cmul_conj() {
        if !has_avx() {
            return;
        }

        // cmul_conj implements self * conj(other):
        // (1+2i)*conj(5+6i) = 17+4i, (3+4i)*conj(7+8i) = 53+4i,
        // 1*conj(1) = 1, i*conj(i) = 1.
        let a = AvxF32::new(1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0);
        let b = AvxF32::new(5.0, 6.0, 7.0, 8.0, 1.0, 0.0, 0.0, 1.0);
        let c = a.cmul_conj(b);
        let mut out = [0.0_f32; 8];
        unsafe { c.store_unaligned(out.as_mut_ptr()) };
        let expected = [17.0_f32, 4.0, 53.0, 4.0, 1.0, 0.0, 1.0, 0.0];
        for (got, want) in out.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }
}