#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Transposes an 8x2 row-major `f32` matrix held in two AVX registers.
///
/// `i0` followed by `i1` is read as sixteen consecutive values `a0..=a15`,
/// i.e. eight rows of two elements each. On return, `o0` holds the first
/// column (`a0, a2, ..., a14`) and `o1` holds the second column
/// (`a1, a3, ..., a15`), lowest lane first.
///
/// # Safety
///
/// The CPU executing this function must support AVX. Callers should either
/// be compiled with AVX enabled or check `is_x86_feature_detected!("avx")`
/// before calling.
#[inline]
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
pub unsafe fn transpose_8x2(i0: __m256, i1: __m256, o0: &mut __m256, o1: &mut __m256) {
    // `_mm256_permute2f128_ps` selectors: pick the low (resp. high) 128-bit
    // half of each operand.
    const LOW_HALVES: i32 = 0b0010_0000;
    const HIGH_HALVES: i32 = 0b0011_0001;
    // `_mm256_shuffle_ps` selectors: within each 128-bit half take elements
    // 0 and 2 (resp. 1 and 3) of the first operand, then of the second.
    const EVEN_ELEMENTS: i32 = 0b10_00_10_00;
    const ODD_ELEMENTS: i32 = 0b11_01_11_01;

    // lower = [a0 a1 a2 a3 | a8 a9 a10 a11], upper = [a4 a5 a6 a7 | a12 a13 a14 a15]
    let lower = _mm256_permute2f128_ps::<LOW_HALVES>(i0, i1);
    let upper = _mm256_permute2f128_ps::<HIGH_HALVES>(i0, i1);

    *o0 = _mm256_shuffle_ps::<EVEN_ELEMENTS>(lower, upper);
    *o1 = _mm256_shuffle_ps::<ODD_ELEMENTS>(lower, upper);
}
/// Transposes an 8x4 row-major `f32` matrix held in four AVX registers.
///
/// The inputs `i0, i1, i2, i3` are read as 32 consecutive values, i.e. eight
/// rows of four elements, two rows per register. Element `k` of the returned
/// array holds column `k` of that matrix: lane `r` of output `k` is the
/// element at row `r`, column `k`.
///
/// The shuffle masks are spelled out as constants rather than built with
/// `_MM_SHUFFLE`, which is not available on stable Rust.
///
/// # Safety
///
/// The CPU executing this function must support AVX. Callers should either
/// be compiled with AVX enabled or check `is_x86_feature_detected!("avx")`
/// before calling.
#[inline]
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
pub unsafe fn transpose_8x4(i0: __m256, i1: __m256, i2: __m256, i3: __m256) -> [__m256; 4] {
    // `_mm256_permute2f128_ps` selectors: low (resp. high) 128-bit half of
    // each operand. Equal to `_MM_SHUFFLE(0, 2, 0, 0)` / `_MM_SHUFFLE(0, 3, 0, 1)`.
    const LOW_HALVES: i32 = 0x20;
    const HIGH_HALVES: i32 = 0x31;
    // `_mm256_shuffle_ps` selectors: elements 0 and 2 (resp. 1 and 3) of each
    // operand within every 128-bit half. Equal to `_MM_SHUFFLE(2, 0, 2, 0)` /
    // `_MM_SHUFFLE(3, 1, 3, 1)`.
    const EVEN_ELEMENTS: i32 = 0b10_00_10_00;
    const ODD_ELEMENTS: i32 = 0b11_01_11_01;

    // Regroup the rows so that each 128-bit half holds exactly one row:
    // [row0 | row4], [row2 | row6], [row1 | row5], [row3 | row7].
    let rows_0_4 = _mm256_permute2f128_ps::<LOW_HALVES>(i0, i2);
    let rows_2_6 = _mm256_permute2f128_ps::<LOW_HALVES>(i1, i3);
    let rows_1_5 = _mm256_permute2f128_ps::<HIGH_HALVES>(i0, i2);
    let rows_3_7 = _mm256_permute2f128_ps::<HIGH_HALVES>(i1, i3);

    // Split every row pair into its even columns (0, 2) and odd columns (1, 3).
    let even_01_45 = _mm256_shuffle_ps::<EVEN_ELEMENTS>(rows_0_4, rows_1_5);
    let odd_01_45 = _mm256_shuffle_ps::<ODD_ELEMENTS>(rows_0_4, rows_1_5);
    let even_23_67 = _mm256_shuffle_ps::<EVEN_ELEMENTS>(rows_2_6, rows_3_7);
    let odd_23_67 = _mm256_shuffle_ps::<ODD_ELEMENTS>(rows_2_6, rows_3_7);

    // A second even/odd split separates column 0 from 2 and column 1 from 3,
    // leaving rows 0-3 in the low half and rows 4-7 in the high half.
    let col0 = _mm256_shuffle_ps::<EVEN_ELEMENTS>(even_01_45, even_23_67);
    let col1 = _mm256_shuffle_ps::<EVEN_ELEMENTS>(odd_01_45, odd_23_67);
    let col2 = _mm256_shuffle_ps::<ODD_ELEMENTS>(even_01_45, even_23_67);
    let col3 = _mm256_shuffle_ps::<ODD_ELEMENTS>(odd_01_45, odd_23_67);

    [col0, col1, col2, col3]
}
/// Transposes an 8x8 `f32` matrix whose rows are held in eight AVX registers.
///
/// `i0..i7` are rows 0 through 7. Element `k` of the returned array is
/// column `k` of the input: lane `r` of output `k` equals lane `k` of input
/// row `r`.
///
/// The shuffle masks are spelled out as constants rather than built with
/// `_MM_SHUFFLE`, which is not available on stable Rust.
///
/// # Safety
///
/// The CPU executing this function must support AVX. Callers should either
/// be compiled with AVX enabled or check `is_x86_feature_detected!("avx")`
/// before calling.
#[inline]
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
pub unsafe fn transpose_8x8(
    i0: __m256,
    i1: __m256,
    i2: __m256,
    i3: __m256,
    i4: __m256,
    i5: __m256,
    i6: __m256,
    i7: __m256,
) -> [__m256; 8] {
    // `_mm256_shuffle_ps` selectors: elements 0,1 (resp. 2,3) of each operand
    // within every 128-bit half. Equal to `_MM_SHUFFLE(1, 0, 1, 0)` /
    // `_MM_SHUFFLE(3, 2, 3, 2)`.
    const FIRST_PAIRS: i32 = 0b01_00_01_00;
    const SECOND_PAIRS: i32 = 0b11_10_11_10;
    // `_mm256_permute2f128_ps` selectors: low (resp. high) 128-bit half of
    // each operand.
    const LOW_HALVES: i32 = 0x20;
    const HIGH_HALVES: i32 = 0x31;

    // Stage 1: interleave neighbouring rows. For rows a and b:
    //   unpacklo -> [a0 b0 a1 b1 | a4 b4 a5 b5]
    //   unpackhi -> [a2 b2 a3 b3 | a6 b6 a7 b7]
    let r01_lo = _mm256_unpacklo_ps(i0, i1);
    let r01_hi = _mm256_unpackhi_ps(i0, i1);
    let r23_lo = _mm256_unpacklo_ps(i2, i3);
    let r23_hi = _mm256_unpackhi_ps(i2, i3);
    let r45_lo = _mm256_unpacklo_ps(i4, i5);
    let r45_hi = _mm256_unpackhi_ps(i4, i5);
    let r67_lo = _mm256_unpacklo_ps(i6, i7);
    let r67_hi = _mm256_unpackhi_ps(i6, i7);

    // Stage 2: combine two interleaved pairs so every 128-bit half holds one
    // column restricted to four rows. `top_cXY` is [col X | col Y] for rows
    // 0-3, `bot_cXY` the same for rows 4-7.
    let top_c04 = _mm256_shuffle_ps::<FIRST_PAIRS>(r01_lo, r23_lo);
    let top_c15 = _mm256_shuffle_ps::<SECOND_PAIRS>(r01_lo, r23_lo);
    let top_c26 = _mm256_shuffle_ps::<FIRST_PAIRS>(r01_hi, r23_hi);
    let top_c37 = _mm256_shuffle_ps::<SECOND_PAIRS>(r01_hi, r23_hi);
    let bot_c04 = _mm256_shuffle_ps::<FIRST_PAIRS>(r45_lo, r67_lo);
    let bot_c15 = _mm256_shuffle_ps::<SECOND_PAIRS>(r45_lo, r67_lo);
    let bot_c26 = _mm256_shuffle_ps::<FIRST_PAIRS>(r45_hi, r67_hi);
    let bot_c37 = _mm256_shuffle_ps::<SECOND_PAIRS>(r45_hi, r67_hi);

    // Stage 3: stitch the rows 0-3 and rows 4-7 pieces of each column together.
    [
        _mm256_permute2f128_ps::<LOW_HALVES>(top_c04, bot_c04),
        _mm256_permute2f128_ps::<LOW_HALVES>(top_c15, bot_c15),
        _mm256_permute2f128_ps::<LOW_HALVES>(top_c26, bot_c26),
        _mm256_permute2f128_ps::<LOW_HALVES>(top_c37, bot_c37),
        _mm256_permute2f128_ps::<HIGH_HALVES>(top_c04, bot_c04),
        _mm256_permute2f128_ps::<HIGH_HALVES>(top_c15, bot_c15),
        _mm256_permute2f128_ps::<HIGH_HALVES>(top_c26, bot_c26),
        _mm256_permute2f128_ps::<HIGH_HALVES>(top_c37, bot_c37),
    ]
}
// Gate the whole module on x86_64 so its imports are not left unused on
// other architectures.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;
    use std::mem::transmute;

    /// Reports whether the host CPU can run the AVX code under test, logging
    /// a note when it cannot so a skipped test is visible in the output.
    fn avx_available() -> bool {
        let available = is_x86_feature_detected!("avx");
        if !available {
            eprintln!("skipping: AVX is not available on this CPU");
        }
        available
    }

    /// Reinterprets a register as its eight `f32` lanes, lowest lane first.
    fn lanes(v: __m256) -> [f32; 8] {
        // SAFETY: `__m256` and `[f32; 8]` are both 32 bytes and every bit
        // pattern is a valid `f32`.
        unsafe { transmute(v) }
    }

    /// Asserts that two register sequences are lane-for-lane identical.
    fn assert_rows_eq(actual: &[__m256], expected: &[__m256]) {
        assert_eq!(actual.len(), expected.len(), "row count mismatch");
        for (i, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
            assert_eq!(lanes(got), lanes(want), "Row {} mismatch", i);
        }
    }

    #[test]
    fn test_transpose_8x4() {
        if !avx_available() {
            return;
        }
        // SAFETY: AVX support was verified at runtime just above.
        unsafe {
            // `_mm256_set_ps` takes the highest lane first, so each register
            // below reads right-to-left as two rows of four values.
            let i0 = _mm256_set_ps(13.0, 12.0, 11.0, 10.0, 3.0, 2.0, 1.0, 0.0);
            let i1 = _mm256_set_ps(33.0, 32.0, 31.0, 30.0, 23.0, 22.0, 21.0, 20.0);
            let i2 = _mm256_set_ps(53.0, 52.0, 51.0, 50.0, 43.0, 42.0, 41.0, 40.0);
            let i3 = _mm256_set_ps(73.0, 72.0, 71.0, 70.0, 63.0, 62.0, 61.0, 60.0);

            let transposed = transpose_8x4(i0, i1, i2, i3);

            let expected = [
                _mm256_set_ps(70.0, 60.0, 50.0, 40.0, 30.0, 20.0, 10.0, 0.0),
                _mm256_set_ps(71.0, 61.0, 51.0, 41.0, 31.0, 21.0, 11.0, 1.0),
                _mm256_set_ps(72.0, 62.0, 52.0, 42.0, 32.0, 22.0, 12.0, 2.0),
                _mm256_set_ps(73.0, 63.0, 53.0, 43.0, 33.0, 23.0, 13.0, 3.0),
            ];
            assert_rows_eq(&transposed, &expected);
        }
    }

    #[test]
    fn test_transpose_8x8() {
        if !avx_available() {
            return;
        }
        // SAFETY: AVX support was verified at runtime just above.
        unsafe {
            // Row r holds the values 10*r + c for columns c = 0..8.
            let i0 = _mm256_set_ps(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0);
            let i1 = _mm256_set_ps(17.0, 16.0, 15.0, 14.0, 13.0, 12.0, 11.0, 10.0);
            let i2 = _mm256_set_ps(27.0, 26.0, 25.0, 24.0, 23.0, 22.0, 21.0, 20.0);
            let i3 = _mm256_set_ps(37.0, 36.0, 35.0, 34.0, 33.0, 32.0, 31.0, 30.0);
            let i4 = _mm256_set_ps(47.0, 46.0, 45.0, 44.0, 43.0, 42.0, 41.0, 40.0);
            let i5 = _mm256_set_ps(57.0, 56.0, 55.0, 54.0, 53.0, 52.0, 51.0, 50.0);
            let i6 = _mm256_set_ps(67.0, 66.0, 65.0, 64.0, 63.0, 62.0, 61.0, 60.0);
            let i7 = _mm256_set_ps(77.0, 76.0, 75.0, 74.0, 73.0, 72.0, 71.0, 70.0);

            let transposed = transpose_8x8(i0, i1, i2, i3, i4, i5, i6, i7);

            let expected = [
                _mm256_set_ps(70.0, 60.0, 50.0, 40.0, 30.0, 20.0, 10.0, 0.0),
                _mm256_set_ps(71.0, 61.0, 51.0, 41.0, 31.0, 21.0, 11.0, 1.0),
                _mm256_set_ps(72.0, 62.0, 52.0, 42.0, 32.0, 22.0, 12.0, 2.0),
                _mm256_set_ps(73.0, 63.0, 53.0, 43.0, 33.0, 23.0, 13.0, 3.0),
                _mm256_set_ps(74.0, 64.0, 54.0, 44.0, 34.0, 24.0, 14.0, 4.0),
                _mm256_set_ps(75.0, 65.0, 55.0, 45.0, 35.0, 25.0, 15.0, 5.0),
                _mm256_set_ps(76.0, 66.0, 56.0, 46.0, 36.0, 26.0, 16.0, 6.0),
                _mm256_set_ps(77.0, 67.0, 57.0, 47.0, 37.0, 27.0, 17.0, 7.0),
            ];
            assert_rows_eq(&transposed, &expected);
        }
    }
}