// oximedia_codec/simd/x86/avx2.rs
1//! AVX2 SIMD implementation for x86_64.
2//!
3//! This module provides optimized implementations of SIMD operations
4//! using AVX2 instructions, available on Intel Haswell (2013) and later,
5//! and AMD Excavator (2015) and later processors.
6
7#![allow(unsafe_code)]
8
9use crate::simd::traits::{SimdOps, SimdOpsExt};
10use crate::simd::types::{I16x16, I16x8, I32x4, I32x8, U8x16, U8x32};
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// AVX2 SIMD implementation.
///
/// Zero-sized marker type; the trait impls dispatch to x86_64 intrinsics,
/// with portable scalar fallbacks on other architectures.
///
/// NOTE(review): the current intrinsic paths use only 128-bit (`_mm_*`)
/// operations; the 256-bit AVX2 register width is not yet exercised even
/// though the 256-bit lane types are imported — confirm whether widening
/// these paths is planned.
#[derive(Clone, Copy, Debug)]
pub struct Avx2Simd;
18
impl Avx2Simd {
    /// Create a new AVX2 SIMD instance.
    ///
    /// Construction itself is free and safe (the type is zero-sized).
    /// Callers must confirm support via [`Self::is_available`] before using
    /// the SIMD operations: the intrinsic-based trait methods assume AVX2
    /// (and the features it implies) are present on the running CPU.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Check if AVX2 is available at runtime.
    ///
    /// Uses CPUID-backed detection on x86_64; always `false` on every other
    /// target architecture.
    #[inline]
    #[must_use]
    pub fn is_available() -> bool {
        #[cfg(target_arch = "x86_64")]
        {
            is_x86_feature_detected!("avx2")
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            false
        }
    }
}
46
47impl SimdOps for Avx2Simd {
48    #[inline]
49    fn name(&self) -> &'static str {
50        "avx2"
51    }
52
53    #[inline]
54    fn is_available(&self) -> bool {
55        Self::is_available()
56    }
57
58    // ========================================================================
59    // Vector Arithmetic
60    // ========================================================================
61
62    #[inline]
63    #[cfg(target_arch = "x86_64")]
64    fn add_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
65        // SAFETY: AVX2 is checked at runtime before calling this
66        unsafe {
67            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
68            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
69            let result = _mm_add_epi16(a_vec, b_vec);
70            let mut out = I16x8::zero();
71            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
72            out
73        }
74    }
75
76    #[inline]
77    #[cfg(not(target_arch = "x86_64"))]
78    fn add_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
79        let mut result = I16x8::zero();
80        for i in 0..8 {
81            result[i] = a[i].wrapping_add(b[i]);
82        }
83        result
84    }
85
86    #[inline]
87    #[cfg(target_arch = "x86_64")]
88    fn sub_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
89        // SAFETY: AVX2 is checked at runtime
90        unsafe {
91            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
92            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
93            let result = _mm_sub_epi16(a_vec, b_vec);
94            let mut out = I16x8::zero();
95            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
96            out
97        }
98    }
99
100    #[inline]
101    #[cfg(not(target_arch = "x86_64"))]
102    fn sub_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
103        let mut result = I16x8::zero();
104        for i in 0..8 {
105            result[i] = a[i].wrapping_sub(b[i]);
106        }
107        result
108    }
109
110    #[inline]
111    #[cfg(target_arch = "x86_64")]
112    fn mul_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
113        // SAFETY: AVX2 is checked at runtime
114        unsafe {
115            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
116            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
117            let result = _mm_mullo_epi16(a_vec, b_vec);
118            let mut out = I16x8::zero();
119            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
120            out
121        }
122    }
123
124    #[inline]
125    #[cfg(not(target_arch = "x86_64"))]
126    fn mul_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
127        let mut result = I16x8::zero();
128        for i in 0..8 {
129            result[i] = a[i].wrapping_mul(b[i]);
130        }
131        result
132    }
133
134    #[inline]
135    #[cfg(target_arch = "x86_64")]
136    fn add_i32x4(&self, a: I32x4, b: I32x4) -> I32x4 {
137        // SAFETY: AVX2 is checked at runtime
138        unsafe {
139            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
140            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
141            let result = _mm_add_epi32(a_vec, b_vec);
142            let mut out = I32x4::zero();
143            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
144            out
145        }
146    }
147
148    #[inline]
149    #[cfg(not(target_arch = "x86_64"))]
150    fn add_i32x4(&self, a: I32x4, b: I32x4) -> I32x4 {
151        let mut result = I32x4::zero();
152        for i in 0..4 {
153            result[i] = a[i].wrapping_add(b[i]);
154        }
155        result
156    }
157
158    #[inline]
159    #[cfg(target_arch = "x86_64")]
160    fn sub_i32x4(&self, a: I32x4, b: I32x4) -> I32x4 {
161        // SAFETY: AVX2 is checked at runtime
162        unsafe {
163            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
164            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
165            let result = _mm_sub_epi32(a_vec, b_vec);
166            let mut out = I32x4::zero();
167            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
168            out
169        }
170    }
171
172    #[inline]
173    #[cfg(not(target_arch = "x86_64"))]
174    fn sub_i32x4(&self, a: I32x4, b: I32x4) -> I32x4 {
175        let mut result = I32x4::zero();
176        for i in 0..4 {
177            result[i] = a[i].wrapping_sub(b[i]);
178        }
179        result
180    }
181
182    // ========================================================================
183    // Min/Max/Clamp
184    // ========================================================================
185
186    #[inline]
187    #[cfg(target_arch = "x86_64")]
188    fn min_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
189        // SAFETY: AVX2 is checked at runtime
190        unsafe {
191            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
192            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
193            let result = _mm_min_epi16(a_vec, b_vec);
194            let mut out = I16x8::zero();
195            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
196            out
197        }
198    }
199
200    #[inline]
201    #[cfg(not(target_arch = "x86_64"))]
202    fn min_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
203        let mut result = I16x8::zero();
204        for i in 0..8 {
205            result[i] = a[i].min(b[i]);
206        }
207        result
208    }
209
210    #[inline]
211    #[cfg(target_arch = "x86_64")]
212    fn max_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
213        // SAFETY: AVX2 is checked at runtime
214        unsafe {
215            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
216            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
217            let result = _mm_max_epi16(a_vec, b_vec);
218            let mut out = I16x8::zero();
219            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
220            out
221        }
222    }
223
224    #[inline]
225    #[cfg(not(target_arch = "x86_64"))]
226    fn max_i16x8(&self, a: I16x8, b: I16x8) -> I16x8 {
227        let mut result = I16x8::zero();
228        for i in 0..8 {
229            result[i] = a[i].max(b[i]);
230        }
231        result
232    }
233
234    #[inline]
235    fn clamp_i16x8(&self, v: I16x8, min: i16, max: i16) -> I16x8 {
236        let min_vec = I16x8::splat(min);
237        let max_vec = I16x8::splat(max);
238        let clamped_min = self.max_i16x8(v, min_vec);
239        self.min_i16x8(clamped_min, max_vec)
240    }
241
242    #[inline]
243    #[cfg(target_arch = "x86_64")]
244    fn min_u8x16(&self, a: U8x16, b: U8x16) -> U8x16 {
245        // SAFETY: AVX2 is checked at runtime
246        unsafe {
247            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
248            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
249            let result = _mm_min_epu8(a_vec, b_vec);
250            let mut out = U8x16::zero();
251            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
252            out
253        }
254    }
255
256    #[inline]
257    #[cfg(not(target_arch = "x86_64"))]
258    fn min_u8x16(&self, a: U8x16, b: U8x16) -> U8x16 {
259        let mut result = U8x16::zero();
260        for i in 0..16 {
261            result[i] = a[i].min(b[i]);
262        }
263        result
264    }
265
266    #[inline]
267    #[cfg(target_arch = "x86_64")]
268    fn max_u8x16(&self, a: U8x16, b: U8x16) -> U8x16 {
269        // SAFETY: AVX2 is checked at runtime
270        unsafe {
271            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
272            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
273            let result = _mm_max_epu8(a_vec, b_vec);
274            let mut out = U8x16::zero();
275            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
276            out
277        }
278    }
279
280    #[inline]
281    #[cfg(not(target_arch = "x86_64"))]
282    fn max_u8x16(&self, a: U8x16, b: U8x16) -> U8x16 {
283        let mut result = U8x16::zero();
284        for i in 0..16 {
285            result[i] = a[i].max(b[i]);
286        }
287        result
288    }
289
290    #[inline]
291    fn clamp_u8x16(&self, v: U8x16, min: u8, max: u8) -> U8x16 {
292        let min_vec = U8x16::splat(min);
293        let max_vec = U8x16::splat(max);
294        let clamped_min = self.max_u8x16(v, min_vec);
295        self.min_u8x16(clamped_min, max_vec)
296    }
297
298    // ========================================================================
299    // Horizontal Operations
300    // ========================================================================
301
302    #[inline]
303    #[cfg(target_arch = "x86_64")]
304    fn horizontal_sum_i16x8(&self, v: I16x8) -> i32 {
305        // SAFETY: AVX2 is checked at runtime
306        unsafe {
307            let vec = _mm_loadu_si128(v.as_ptr().cast());
308            // Horizontal add to get pairs
309            let sum1 = _mm_hadd_epi16(vec, vec);
310            let sum2 = _mm_hadd_epi16(sum1, sum1);
311            let sum3 = _mm_hadd_epi16(sum2, sum2);
312            _mm_extract_epi16(sum3, 0) as i16 as i32
313        }
314    }
315
316    #[inline]
317    #[cfg(not(target_arch = "x86_64"))]
318    fn horizontal_sum_i16x8(&self, v: I16x8) -> i32 {
319        v.iter().map(|&x| i32::from(x)).sum()
320    }
321
322    #[inline]
323    #[cfg(target_arch = "x86_64")]
324    fn horizontal_sum_i32x4(&self, v: I32x4) -> i32 {
325        // SAFETY: AVX2 is checked at runtime
326        unsafe {
327            let vec = _mm_loadu_si128(v.as_ptr().cast());
328            let sum1 = _mm_hadd_epi32(vec, vec);
329            let sum2 = _mm_hadd_epi32(sum1, sum1);
330            _mm_extract_epi32(sum2, 0)
331        }
332    }
333
334    #[inline]
335    #[cfg(not(target_arch = "x86_64"))]
336    fn horizontal_sum_i32x4(&self, v: I32x4) -> i32 {
337        v.iter().sum()
338    }
339
340    // ========================================================================
341    // SAD (Sum of Absolute Differences)
342    // ========================================================================
343
344    #[inline]
345    #[cfg(target_arch = "x86_64")]
346    fn sad_u8x16(&self, a: U8x16, b: U8x16) -> u32 {
347        // SAFETY: AVX2 is checked at runtime
348        unsafe {
349            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
350            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
351            let sad = _mm_sad_epu8(a_vec, b_vec);
352            let low = _mm_extract_epi64(sad, 0) as u32;
353            let high = _mm_extract_epi64(sad, 1) as u32;
354            low + high
355        }
356    }
357
358    #[inline]
359    #[cfg(not(target_arch = "x86_64"))]
360    fn sad_u8x16(&self, a: U8x16, b: U8x16) -> u32 {
361        a.iter()
362            .zip(b.iter())
363            .map(|(&x, &y)| u32::from(x.abs_diff(y)))
364            .sum()
365    }
366
367    #[inline]
368    fn sad_8(&self, a: &[u8], b: &[u8]) -> u32 {
369        assert!(a.len() >= 8 && b.len() >= 8);
370        a[..8]
371            .iter()
372            .zip(b[..8].iter())
373            .map(|(&x, &y)| u32::from(x.abs_diff(y)))
374            .sum()
375    }
376
377    #[inline]
378    fn sad_16(&self, a: &[u8], b: &[u8]) -> u32 {
379        assert!(a.len() >= 16 && b.len() >= 16);
380        let mut a_vec = U8x16::zero();
381        let mut b_vec = U8x16::zero();
382        a_vec.copy_from_slice(&a[..16]);
383        b_vec.copy_from_slice(&b[..16]);
384        self.sad_u8x16(a_vec, b_vec)
385    }
386
387    // ========================================================================
388    // Widening/Narrowing
389    // ========================================================================
390
391    #[inline]
392    #[cfg(target_arch = "x86_64")]
393    fn widen_low_u8_to_i16(&self, v: U8x16) -> I16x8 {
394        // SAFETY: AVX2 is checked at runtime
395        unsafe {
396            let vec = _mm_loadu_si128(v.as_ptr().cast());
397            let zero = _mm_setzero_si128();
398            let result = _mm_unpacklo_epi8(vec, zero);
399            let mut out = I16x8::zero();
400            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
401            out
402        }
403    }
404
405    #[inline]
406    #[cfg(not(target_arch = "x86_64"))]
407    fn widen_low_u8_to_i16(&self, v: U8x16) -> I16x8 {
408        let mut result = I16x8::zero();
409        for i in 0..8 {
410            result[i] = i16::from(v[i]);
411        }
412        result
413    }
414
415    #[inline]
416    #[cfg(target_arch = "x86_64")]
417    fn widen_high_u8_to_i16(&self, v: U8x16) -> I16x8 {
418        // SAFETY: AVX2 is checked at runtime
419        unsafe {
420            let vec = _mm_loadu_si128(v.as_ptr().cast());
421            let zero = _mm_setzero_si128();
422            let result = _mm_unpackhi_epi8(vec, zero);
423            let mut out = I16x8::zero();
424            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
425            out
426        }
427    }
428
429    #[inline]
430    #[cfg(not(target_arch = "x86_64"))]
431    fn widen_high_u8_to_i16(&self, v: U8x16) -> I16x8 {
432        let mut result = I16x8::zero();
433        for i in 0..8 {
434            result[i] = i16::from(v[i + 8]);
435        }
436        result
437    }
438
439    #[inline]
440    #[cfg(target_arch = "x86_64")]
441    fn narrow_i32x4_to_i16x8(&self, low: I32x4, high: I32x4) -> I16x8 {
442        // SAFETY: AVX2 is checked at runtime
443        unsafe {
444            let low_vec = _mm_loadu_si128(low.as_ptr().cast());
445            let high_vec = _mm_loadu_si128(high.as_ptr().cast());
446            let result = _mm_packs_epi32(low_vec, high_vec);
447            let mut out = I16x8::zero();
448            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
449            out
450        }
451    }
452
453    #[inline]
454    #[cfg(not(target_arch = "x86_64"))]
455    fn narrow_i32x4_to_i16x8(&self, low: I32x4, high: I32x4) -> I16x8 {
456        let mut result = I16x8::zero();
457        for i in 0..4 {
458            result[i] = low[i].clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16;
459            result[i + 4] = high[i].clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16;
460        }
461        result
462    }
463
464    // ========================================================================
465    // Multiply-Add
466    // ========================================================================
467
468    #[inline]
469    fn madd_i16x8(&self, a: I16x8, b: I16x8, c: I16x8) -> I16x8 {
470        let prod = self.mul_i16x8(a, b);
471        self.add_i16x8(prod, c)
472    }
473
474    #[inline]
475    #[cfg(target_arch = "x86_64")]
476    fn pmaddwd(&self, a: I16x8, b: I16x8) -> I32x4 {
477        // SAFETY: AVX2 is checked at runtime
478        unsafe {
479            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
480            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
481            let result = _mm_madd_epi16(a_vec, b_vec);
482            let mut out = I32x4::zero();
483            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
484            out
485        }
486    }
487
488    #[inline]
489    #[cfg(not(target_arch = "x86_64"))]
490    fn pmaddwd(&self, a: I16x8, b: I16x8) -> I32x4 {
491        let mut result = I32x4::zero();
492        for i in 0..4 {
493            result[i] = i32::from(a[i * 2]) * i32::from(b[i * 2])
494                + i32::from(a[i * 2 + 1]) * i32::from(b[i * 2 + 1]);
495        }
496        result
497    }
498
499    // ========================================================================
500    // Shift Operations
501    // ========================================================================
502
503    #[inline]
504    #[cfg(target_arch = "x86_64")]
505    fn shr_i16x8(&self, v: I16x8, shift: u32) -> I16x8 {
506        // SAFETY: AVX2 is checked at runtime
507        unsafe {
508            let vec = _mm_loadu_si128(v.as_ptr().cast());
509            let shift_vec = _mm_cvtsi32_si128(shift as i32);
510            let result = _mm_sra_epi16(vec, shift_vec);
511            let mut out = I16x8::zero();
512            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
513            out
514        }
515    }
516
517    #[inline]
518    #[cfg(not(target_arch = "x86_64"))]
519    fn shr_i16x8(&self, v: I16x8, shift: u32) -> I16x8 {
520        let mut result = I16x8::zero();
521        for i in 0..8 {
522            result[i] = v[i] >> shift;
523        }
524        result
525    }
526
527    #[inline]
528    #[cfg(target_arch = "x86_64")]
529    fn shl_i16x8(&self, v: I16x8, shift: u32) -> I16x8 {
530        // SAFETY: AVX2 is checked at runtime
531        unsafe {
532            let vec = _mm_loadu_si128(v.as_ptr().cast());
533            let shift_vec = _mm_cvtsi32_si128(shift as i32);
534            let result = _mm_sll_epi16(vec, shift_vec);
535            let mut out = I16x8::zero();
536            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
537            out
538        }
539    }
540
541    #[inline]
542    #[cfg(not(target_arch = "x86_64"))]
543    fn shl_i16x8(&self, v: I16x8, shift: u32) -> I16x8 {
544        let mut result = I16x8::zero();
545        for i in 0..8 {
546            result[i] = v[i] << shift;
547        }
548        result
549    }
550
551    #[inline]
552    #[cfg(target_arch = "x86_64")]
553    fn shr_i32x4(&self, v: I32x4, shift: u32) -> I32x4 {
554        // SAFETY: AVX2 is checked at runtime
555        unsafe {
556            let vec = _mm_loadu_si128(v.as_ptr().cast());
557            let shift_vec = _mm_cvtsi32_si128(shift as i32);
558            let result = _mm_sra_epi32(vec, shift_vec);
559            let mut out = I32x4::zero();
560            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
561            out
562        }
563    }
564
565    #[inline]
566    #[cfg(not(target_arch = "x86_64"))]
567    fn shr_i32x4(&self, v: I32x4, shift: u32) -> I32x4 {
568        let mut result = I32x4::zero();
569        for i in 0..4 {
570            result[i] = v[i] >> shift;
571        }
572        result
573    }
574
575    #[inline]
576    #[cfg(target_arch = "x86_64")]
577    fn shl_i32x4(&self, v: I32x4, shift: u32) -> I32x4 {
578        // SAFETY: AVX2 is checked at runtime
579        unsafe {
580            let vec = _mm_loadu_si128(v.as_ptr().cast());
581            let shift_vec = _mm_cvtsi32_si128(shift as i32);
582            let result = _mm_sll_epi32(vec, shift_vec);
583            let mut out = I32x4::zero();
584            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
585            out
586        }
587    }
588
589    #[inline]
590    #[cfg(not(target_arch = "x86_64"))]
591    fn shl_i32x4(&self, v: I32x4, shift: u32) -> I32x4 {
592        let mut result = I32x4::zero();
593        for i in 0..4 {
594            result[i] = v[i] << shift;
595        }
596        result
597    }
598
599    // ========================================================================
600    // Averaging
601    // ========================================================================
602
603    #[inline]
604    #[cfg(target_arch = "x86_64")]
605    fn avg_u8x16(&self, a: U8x16, b: U8x16) -> U8x16 {
606        // SAFETY: AVX2 is checked at runtime
607        unsafe {
608            let a_vec = _mm_loadu_si128(a.as_ptr().cast());
609            let b_vec = _mm_loadu_si128(b.as_ptr().cast());
610            let result = _mm_avg_epu8(a_vec, b_vec);
611            let mut out = U8x16::zero();
612            _mm_storeu_si128(out.as_mut_ptr().cast(), result);
613            out
614        }
615    }
616
617    #[inline]
618    #[cfg(not(target_arch = "x86_64"))]
619    fn avg_u8x16(&self, a: U8x16, b: U8x16) -> U8x16 {
620        let mut result = U8x16::zero();
621        for i in 0..16 {
622            result[i] = ((u16::from(a[i]) + u16::from(b[i]) + 1) / 2) as u8;
623        }
624        result
625    }
626}
627
628impl SimdOpsExt for Avx2Simd {
629    #[inline]
630    fn load4_u8_to_i16x8(&self, src: &[u8]) -> I16x8 {
631        assert!(src.len() >= 4);
632        let mut result = I16x8::zero();
633        for i in 0..4 {
634            result[i] = i16::from(src[i]);
635        }
636        result
637    }
638
639    #[inline]
640    fn load8_u8_to_i16x8(&self, src: &[u8]) -> I16x8 {
641        assert!(src.len() >= 8);
642        let mut result = I16x8::zero();
643        for i in 0..8 {
644            result[i] = i16::from(src[i]);
645        }
646        result
647    }
648
649    #[inline]
650    fn store4_i16x8_as_u8(&self, v: I16x8, dst: &mut [u8]) {
651        assert!(dst.len() >= 4);
652        for i in 0..4 {
653            dst[i] = v[i].clamp(0, 255) as u8;
654        }
655    }
656
657    #[inline]
658    fn store8_i16x8_as_u8(&self, v: I16x8, dst: &mut [u8]) {
659        assert!(dst.len() >= 8);
660        for i in 0..8 {
661            dst[i] = v[i].clamp(0, 255) as u8;
662        }
663    }
664
665    #[inline]
666    fn transpose_4x4_i16(&self, rows: &[I16x8; 4]) -> [I16x8; 4] {
667        #[cfg(target_arch = "x86_64")]
668        {
669            // SAFETY: AVX2 is checked at runtime
670            unsafe {
671                // Load 4 rows
672                let r0 = _mm_loadl_epi64(rows[0].as_ptr().cast());
673                let r1 = _mm_loadl_epi64(rows[1].as_ptr().cast());
674                let r2 = _mm_loadl_epi64(rows[2].as_ptr().cast());
675                let r3 = _mm_loadl_epi64(rows[3].as_ptr().cast());
676
677                // Interleave pairs
678                let t0 = _mm_unpacklo_epi16(r0, r1);
679                let t1 = _mm_unpacklo_epi16(r2, r3);
680
681                // Final interleave
682                let o0 = _mm_unpacklo_epi32(t0, t1);
683                let o1 = _mm_unpackhi_epi32(t0, t1);
684                let o2 = _mm_unpacklo_epi32(_mm_unpackhi_epi16(r0, r1), _mm_unpackhi_epi16(r2, r3));
685                let o3 = _mm_unpackhi_epi32(_mm_unpackhi_epi16(r0, r1), _mm_unpackhi_epi16(r2, r3));
686
687                let mut out = [I16x8::zero(); 4];
688                _mm_storeu_si128(out[0].as_mut_ptr().cast(), o0);
689                _mm_storeu_si128(out[1].as_mut_ptr().cast(), o1);
690                _mm_storeu_si128(out[2].as_mut_ptr().cast(), o2);
691                _mm_storeu_si128(out[3].as_mut_ptr().cast(), o3);
692                out
693            }
694        }
695        #[cfg(not(target_arch = "x86_64"))]
696        {
697            let mut out = [I16x8::zero(); 4];
698            for i in 0..4 {
699                for j in 0..4 {
700                    out[i][j] = rows[j][i];
701                }
702            }
703            out
704        }
705    }
706
707    #[inline]
708    fn transpose_8x8_i16(&self, rows: &[I16x8; 8]) -> [I16x8; 8] {
709        #[cfg(target_arch = "x86_64")]
710        {
711            // SAFETY: AVX2 is checked at runtime
712            unsafe {
713                // Load all 8 rows
714                let r0 = _mm_loadu_si128(rows[0].as_ptr().cast());
715                let r1 = _mm_loadu_si128(rows[1].as_ptr().cast());
716                let r2 = _mm_loadu_si128(rows[2].as_ptr().cast());
717                let r3 = _mm_loadu_si128(rows[3].as_ptr().cast());
718                let r4 = _mm_loadu_si128(rows[4].as_ptr().cast());
719                let r5 = _mm_loadu_si128(rows[5].as_ptr().cast());
720                let r6 = _mm_loadu_si128(rows[6].as_ptr().cast());
721                let r7 = _mm_loadu_si128(rows[7].as_ptr().cast());
722
723                // First level of interleaving
724                let t0 = _mm_unpacklo_epi16(r0, r1);
725                let t1 = _mm_unpackhi_epi16(r0, r1);
726                let t2 = _mm_unpacklo_epi16(r2, r3);
727                let t3 = _mm_unpackhi_epi16(r2, r3);
728                let t4 = _mm_unpacklo_epi16(r4, r5);
729                let t5 = _mm_unpackhi_epi16(r4, r5);
730                let t6 = _mm_unpacklo_epi16(r6, r7);
731                let t7 = _mm_unpackhi_epi16(r6, r7);
732
733                // Second level
734                let u0 = _mm_unpacklo_epi32(t0, t2);
735                let u1 = _mm_unpackhi_epi32(t0, t2);
736                let u2 = _mm_unpacklo_epi32(t1, t3);
737                let u3 = _mm_unpackhi_epi32(t1, t3);
738                let u4 = _mm_unpacklo_epi32(t4, t6);
739                let u5 = _mm_unpackhi_epi32(t4, t6);
740                let u6 = _mm_unpacklo_epi32(t5, t7);
741                let u7 = _mm_unpackhi_epi32(t5, t7);
742
743                // Third level
744                let o0 = _mm_unpacklo_epi64(u0, u4);
745                let o1 = _mm_unpackhi_epi64(u0, u4);
746                let o2 = _mm_unpacklo_epi64(u1, u5);
747                let o3 = _mm_unpackhi_epi64(u1, u5);
748                let o4 = _mm_unpacklo_epi64(u2, u6);
749                let o5 = _mm_unpackhi_epi64(u2, u6);
750                let o6 = _mm_unpacklo_epi64(u3, u7);
751                let o7 = _mm_unpackhi_epi64(u3, u7);
752
753                let mut out = [I16x8::zero(); 8];
754                _mm_storeu_si128(out[0].as_mut_ptr().cast(), o0);
755                _mm_storeu_si128(out[1].as_mut_ptr().cast(), o1);
756                _mm_storeu_si128(out[2].as_mut_ptr().cast(), o2);
757                _mm_storeu_si128(out[3].as_mut_ptr().cast(), o3);
758                _mm_storeu_si128(out[4].as_mut_ptr().cast(), o4);
759                _mm_storeu_si128(out[5].as_mut_ptr().cast(), o5);
760                _mm_storeu_si128(out[6].as_mut_ptr().cast(), o6);
761                _mm_storeu_si128(out[7].as_mut_ptr().cast(), o7);
762                out
763            }
764        }
765        #[cfg(not(target_arch = "x86_64"))]
766        {
767            let mut out = [I16x8::zero(); 8];
768            for i in 0..8 {
769                for j in 0..8 {
770                    out[i][j] = rows[j][i];
771                }
772            }
773            out
774        }
775    }
776
777    #[inline]
778    fn butterfly_i16x8(&self, a: I16x8, b: I16x8) -> (I16x8, I16x8) {
779        let sum = self.add_i16x8(a, b);
780        let diff = self.sub_i16x8(a, b);
781        (sum, diff)
782    }
783
784    #[inline]
785    fn butterfly_i32x4(&self, a: I32x4, b: I32x4) -> (I32x4, I32x4) {
786        let sum = self.add_i32x4(a, b);
787        let diff = self.sub_i32x4(a, b);
788        (sum, diff)
789    }
790}