use std::{arch::x86_64::*, cmp};
use bytemuck::{must_cast_slice, must_cast_slice_mut};
use seq_macro::seq;
#[inline]
#[target_feature(enable = "avx2")]
fn transpose_2x2_matrices(x: &mut __m256i, y: &mut __m256i) {
    // Gather the low 128-bit halves of `x`/`y` into one vector and the high
    // halves into another, so corresponding rows sit in matching 16-bit lanes.
    let lo_halves = _mm256_permute2x128_si256(*x, *y, 0x20);
    let hi_halves = _mm256_permute2x128_si256(*x, *y, 0x31);
    // Masked XOR-swap: exchange each odd bit of `lo_halves` with the even bit
    // one position below it in `hi_halves` — transposing every 2x2 bit block.
    let odd_bit_mask = _mm256_set1_epi16(0b1010101010101010_u16 as i16);
    let swap_bits = _mm256_and_si256(
        _mm256_xor_si256(lo_halves, _mm256_slli_epi16(hi_halves, 1)),
        odd_bit_mask,
    );
    let lo_halves = _mm256_xor_si256(lo_halves, swap_bits);
    let hi_halves = _mm256_xor_si256(hi_halves, _mm256_srli_epi16(swap_bits, 1));
    // Scatter the halves back into their original vectors.
    *x = _mm256_permute2x128_si256(lo_halves, hi_halves, 0x20);
    *y = _mm256_permute2x128_si256(lo_halves, hi_halves, 0x31);
}
#[inline]
#[target_feature(enable = "avx2")]
fn partial_swap_sub_matrices<const SHIFT_AMOUNT: i32, const MASK: u64>(
    x: &mut __m256i,
    y: &mut __m256i,
) {
    // Masked XOR-swap within each 64-bit lane: the bits of `x` selected by
    // MASK trade places with the bits of `y` that are SHIFT_AMOUNT positions
    // below them.
    let shifted_y = _mm256_slli_epi64::<SHIFT_AMOUNT>(*y);
    let lane_mask = _mm256_set1_epi64x(MASK as i64);
    let swap_bits = _mm256_and_si256(_mm256_xor_si256(*x, shifted_y), lane_mask);
    *x = _mm256_xor_si256(*x, swap_bits);
    *y = _mm256_xor_si256(*y, _mm256_srli_epi64::<SHIFT_AMOUNT>(swap_bits));
}
#[inline]
#[target_feature(enable = "avx2")]
fn partial_swap_64x64_matrices(x: &mut __m256i, y: &mut __m256i) {
    // Interleave the 64-bit lanes of `x` and `y`: the even lanes of both end
    // up in `x`, the odd lanes in `y` (lane order preserved within each pair).
    (*x, *y) = (
        _mm256_unpacklo_epi64(*x, *y),
        _mm256_unpackhi_epi64(*x, *y),
    );
}
#[target_feature(enable = "avx2")]
pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
for chunk in in_out.chunks_exact_mut(2) {
if let [x, y] = chunk {
transpose_2x2_matrices(x, y);
} else {
unreachable!("chunk size is 2")
}
}
seq!(N in 1..=5 {
const SHIFT_~N: i32 = 1 << N;
const MASK_~N: u64 = match N {
1 => mask(0b1100, 4),
2 => mask(0b11110000, 8),
3 => mask(0b1111111100000000, 16),
4 => mask(0b11111111111111110000000000000000, 32),
5 => 0xffffffff00000000,
_ => unreachable!(),
};
#[allow(clippy::eq_op)] const OFFSET~N: usize = 1 << (N - 1);
for chunk in in_out.chunks_exact_mut(2 * OFFSET~N) {
let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET~N);
for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
partial_swap_sub_matrices::<SHIFT_~N, MASK_~N>(x, y);
}
}
});
const SHIFT_6: usize = 6;
const OFFSET_6: usize = 1 << (SHIFT_6 - 1);
for chunk in in_out.chunks_exact_mut(2 * OFFSET_6) {
let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET_6);
for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
partial_swap_64x64_matrices(x, y);
}
}
}
/// Tiles `pattern` (occupying the low `pattern_len` bits) across all 64 bits.
///
/// Each iteration doubles the filled width by OR-ing the mask with a shifted
/// copy of itself, so the result repeats `pattern` every `pattern_len` bits.
/// If `pattern_len >= 64`, `pattern` is returned unchanged.
const fn mask(pattern: u64, pattern_len: u32) -> u64 {
    let mut tiled = pattern;
    let mut filled_bits = pattern_len;
    while filled_bits < 64 {
        tiled |= tiled << filled_bits;
        filled_bits *= 2;
    }
    tiled
}
/// AVX2 transpose of a bit matrix stored row-major in `input`, written
/// row-major to `output`.
///
/// The matrix has `rows` rows and `input.len() * 8 / rows` bit-columns.
/// The column count does not need to be a multiple of 128; a trailing
/// partial column block is handled by [`handle_rest_cols`].
///
/// # Panics
/// Panics if `input.len() != output.len()`, `rows < 128`, `rows` is not a
/// multiple of 128, or `input.len()` is not divisible by `rows`.
#[target_feature(enable = "avx2")]
pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
    assert_eq!(input.len(), output.len());
    assert!(rows >= 128, "Number of rows must be >= 128.");
    assert_eq!(
        0,
        input.len() % rows,
        "input.len() must be divisible by rows"
    );
    assert_eq!(0, rows % 128, "Number of rows must be a multiple of 128.");
    let cols = input.len() * 8 / rows;
    assert_eq!(0, cols % 8, "Number of columns must be a multiple of 8.");
    // Scratch space for up to four 128x128 bit blocks (4 * 64 * 32 bytes).
    let mut buf = [_mm256_setzero_si256(); 64 * 4];
    let in_stride = cols / 8;
    let out_stride = rows / 8;
    let r_main = rows / 128;
    let c_main = cols / 128;
    let c_rest = cols % 128;
    for i in 0..r_main {
        let mut j = 0;
        while j < c_main {
            let input_offset = i * 128 * in_stride + j * 16;
            // Process as many adjacent 128x128 blocks at once as fit in the
            // current 64-byte cache line, so each input row's cache line is
            // touched once per pass over the rows.
            let curr_addr = input[input_offset..].as_ptr().addr();
            let next_cache_line_addr = (curr_addr + 1).next_multiple_of(64);
            let blocks_in_cache_line = (next_cache_line_addr - curr_addr) / 16;
            let remaining_blocks_in_cache_line = if blocks_in_cache_line == 0 {
                // Unaligned data: block rows straddle cache lines anyway, so
                // simply take up to four blocks at a time.
                4
            } else {
                blocks_in_cache_line
            };
            let remaining_blocks_in_cache_line =
                cmp::min(remaining_blocks_in_cache_line, c_main - j);
            let buf_as_bytes: &mut [u8] = must_cast_slice_mut(&mut buf);
            // Gather 16 bytes per input row into each block's scratch area.
            // The macro body uses its argument so the `4` arm below is
            // instantiated with a compile-time constant trip count, letting
            // the compiler fully unroll the inner copy loop.
            macro_rules! loading_loop {
                ($nblocks:expr) => {
                    for k in 0..128 {
                        let src_slice = &input[input_offset + k * in_stride
                            ..input_offset + k * in_stride + 16 * $nblocks];
                        for block in 0..$nblocks {
                            buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16]
                                .copy_from_slice(&src_slice[block * 16..(block + 1) * 16]);
                        }
                    }
                };
            }
            match remaining_blocks_in_cache_line {
                4 => loading_loop!(4),
                other => loading_loop!(other),
            }
            // Transpose each gathered 128x128 block in place.
            for block in 0..remaining_blocks_in_cache_line {
                avx_transpose128x128(
                    (&mut buf[block * 64..(block + 1) * 64])
                        .try_into()
                        .expect("slice has length 64"),
                );
            }
            let mut output_offset = j * 128 * out_stride + i * 16;
            let buf_as_bytes: &[u8] = must_cast_slice(&buf);
            if out_stride == 16 {
                // rows == 128: the transposed blocks are contiguous in the
                // output, so a single bulk copy suffices.
                let dst_slice = &mut output
                    [output_offset..output_offset + 16 * 128 * remaining_blocks_in_cache_line];
                dst_slice.copy_from_slice(&buf_as_bytes[..remaining_blocks_in_cache_line * 2048]);
            } else {
                // Scatter each transposed 16-byte row to its output row.
                for block in 0..remaining_blocks_in_cache_line {
                    for k in 0..128 {
                        let src_slice =
                            &buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16];
                        let dst_slice = &mut output
                            [output_offset + k * out_stride..output_offset + k * out_stride + 16];
                        dst_slice.copy_from_slice(src_slice);
                    }
                    output_offset += 128 * out_stride;
                }
            }
            j += remaining_blocks_in_cache_line;
        }
        // Trailing column block narrower than 128 bits.
        if c_rest > 0 {
            handle_rest_cols(input, output, &mut buf, in_stride, out_stride, c_rest, i, j);
        }
    }
}
/// Transposes the final, partial column block (`c_rest < 128` bit-columns,
/// a multiple of 8) of input row-block `i` into the output.
///
/// The partial block is zero-padded to a full 128x128 bit block inside the
/// first quarter of `buf`, transposed with [`avx_transpose128x128`], and only
/// the `c_rest` valid rows of the result are written to `output`.
#[inline(never)]
#[target_feature(enable = "avx2")]
#[allow(clippy::too_many_arguments)]
fn handle_rest_cols(
    input: &[u8],
    output: &mut [u8],
    // Scratch buffer from the caller; only the first 64 vectors are used.
    buf: &mut [__m256i; 256],
    in_stride: usize,
    out_stride: usize,
    c_rest: usize,
    i: usize,
    j: usize,
) {
    // Byte offset of the partial block: row-block i, column-block j.
    let input_offset = i * 128 * in_stride + j * 16;
    let remaining_cols_bytes = c_rest / 8;
    // Zero the block so the padding columns transpose to zero bits.
    buf[0..64].fill(_mm256_setzero_si256());
    let buf_as_bytes: &mut [u8] = must_cast_slice_mut(buf);
    // Gather: copy the valid bytes of each of the 128 input rows into the
    // corresponding 16-byte row of the scratch block.
    for k in 0..128 {
        let src_row_offset = input_offset + k * in_stride;
        let src_slice = &input[src_row_offset..src_row_offset + remaining_cols_bytes];
        let buf_offset = k * 16;
        buf_as_bytes[buf_offset..buf_offset + remaining_cols_bytes].copy_from_slice(src_slice);
    }
    avx_transpose128x128((&mut buf[..64]).try_into().expect("slice has length 64"));
    // Scatter: only the first `c_rest` transposed rows are real columns of
    // the input; write each as a 16-byte (128-bit) output row.
    let output_offset = j * 128 * out_stride + i * 16;
    let buf_as_bytes: &[u8] = must_cast_slice(&*buf);
    for k in 0..c_rest {
        let src_slice = &buf_as_bytes[k * 16..(k + 1) * 16];
        let dst_slice =
            &mut output[output_offset + k * out_stride..output_offset + k * out_stride + 16];
        dst_slice.copy_from_slice(src_slice);
    }
}
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use std::arch::x86_64::_mm256_setzero_si256;
    // `fill_bytes` is declared on `RngCore`, so that trait must be in scope
    // for the method calls below to resolve (`Rng` alone is not sufficient
    // and was otherwise unused).
    use rand::{RngCore, SeedableRng, rngs::StdRng};
    use super::{avx_transpose128x128, transpose_bitmatrix};

    /// Transposing a random 128x128 block twice must restore the original.
    #[test]
    fn test_avx_transpose128() {
        unsafe {
            let mut v = [_mm256_setzero_si256(); 64];
            StdRng::seed_from_u64(42).fill_bytes(bytemuck::cast_slice_mut(&mut v));
            let orig = v;
            avx_transpose128x128(&mut v);
            avx_transpose128x128(&mut v);
            let mut failed = false;
            for (i, (o, t)) in orig.into_iter().zip(v).enumerate() {
                let o = bytemuck::cast::<_, [u128; 2]>(o);
                let t = bytemuck::cast::<_, [u128; 2]>(t);
                if o != t {
                    eprintln!("difference in block {i}");
                    eprintln!("orig: {o:?}");
                    eprintln!("tran: {t:?}");
                    failed = true;
                }
            }
            if failed {
                panic!("double transposed is different than original")
            }
        }
    }

    /// AVX2 result must match the portable reference implementation.
    #[test]
    fn test_avx_transpose() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);
        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);
        assert_eq!(sse_transposed, avx_transposed);
    }

    /// Exercises the cache-line handling for deliberately misaligned input
    /// (start address made a multiple of 3, i.e. not 16-byte aligned).
    #[test]
    fn test_avx_transpose_unaligned_data() {
        let rows = 128 * 2;
        let cols = 128 * 2;
        let mut v = vec![0_u8; rows * (cols + 128) / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);
        let v = {
            let addr = v.as_ptr().addr();
            let offset = addr.next_multiple_of(3) - addr;
            &v[offset..offset + rows * cols / 8]
        };
        assert_eq!(0, v.as_ptr().addr() % 3);
        let mut avx_transposed = v.to_owned();
        let mut sse_transposed = v.to_owned();
        unsafe {
            transpose_bitmatrix(v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(v, &mut sse_transposed, rows);
        assert_eq!(sse_transposed, avx_transposed);
    }

    /// Column count divisible by 4 * 128 exercises the multi-block fast path.
    #[test]
    fn test_avx_transpose_larger_cols_divisible_by_4_times_128() {
        let rows = 128;
        let cols = 128 * 8;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);
        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);
        assert_eq!(sse_transposed, avx_transposed);
    }

    /// Column count that is a multiple of 8 but not of 128 exercises the
    /// partial trailing-column path (`handle_rest_cols`).
    #[test]
    fn test_avx_transpose_larger_cols_divisible_by_8() {
        let rows = 128;
        let cols = 128 + 32;
        let mut v = vec![0_u8; rows * cols / 8];
        StdRng::seed_from_u64(42).fill_bytes(&mut v);
        let mut avx_transposed = v.clone();
        let mut sse_transposed = v.clone();
        unsafe {
            transpose_bitmatrix(&v, &mut avx_transposed, rows);
        }
        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);
        assert_eq!(sse_transposed, avx_transposed);
    }
}