#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
use core::mem::MaybeUninit;
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
use core::sync::atomic::{AtomicUsize, Ordering};
/// Issues a `PRFM PSTL1KEEP` hint for the cache line containing `ptr`
/// (prefetch for store, into L1, with the "keep" retention policy).
///
/// The asm block is declared `readonly`, `nostack` and `preserves_flags`: it does
/// not write memory, touch the stack, or change the condition flags.
///
/// # Safety
///
/// Prefetch instructions are hints. Callers in this file only pass pointers
/// computed inside the output buffer; presumably any address is acceptable,
/// but keep to in-bounds pointers — TODO confirm against the Arm ARM.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn prefetch_write(ptr: *const u8) {
    // SAFETY: a prefetch hint neither reads nor writes program-visible memory.
    unsafe {
        core::arch::asm!(
            "prfm pstl1keep, [{addr}]",
            addr = in(reg) ptr,
            options(readonly, nostack, preserves_flags),
        );
    }
}
/// Inputs with at most this many elements are transposed by the plain scalar
/// loop (`transpose_small_*`).
#[cfg(any(target_arch = "aarch64", test))]
const SMALL_LEN: usize = 255;
/// Inputs with more than `SMALL_LEN` and at most this many elements use the
/// single-level tiled kernels (`transpose_tiled_*`); larger inputs use the
/// recursive subdivision (`transpose_recursive_*`).
#[cfg(any(target_arch = "aarch64", test))]
const MEDIUM_LEN: usize = 1024 * 1024;
/// Side length, in elements, of the square tiles the tiled kernels iterate over.
/// NOTE(review): the tile kernels are named and unrolled for 16x16 tiles, so
/// this presumably must stay 16 — confirm before changing.
#[cfg(any(target_arch = "aarch64", test))]
const TILE_SIZE: usize = 16;
/// The recursive kernels stop subdividing once both the row span and the column
/// span are at most this many elements (or either span is at most 2).
#[cfg(target_arch = "aarch64")]
const RECURSIVE_LIMIT: usize = 128;
/// With the `parallel` feature, inputs of at least this many elements are split
/// into row bands processed on rayon's thread pool.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
const PARALLEL_THRESHOLD: usize = 4 * 1024 * 1024;
/// Transposes the row-major `height × width` matrix `input` into `output`.
///
/// `input` is read as `height` rows of `width` elements. On return,
/// `output[x * height + y] == input[y * width + x]` for every `x < width` and
/// `y < height`, so `output` holds `width` rows of `height` elements.
///
/// On `aarch64`, element types whose size *and* alignment match `u32` or `u64`
/// are copied bit-for-bit by NEON kernels. All other element types, and all
/// other architectures, go through `transpose::transpose`.
///
/// # Panics
///
/// Panics if `width * height` overflows `usize`, or if `input.len()` or
/// `output.len()` differs from `width * height`.
#[inline]
pub fn transpose<T: Copy + Send + Sync>(
    input: &[T],
    output: &mut [T],
    width: usize,
    height: usize,
) {
    // An unchecked product wraps in release builds. That could let slices of
    // the wrong length pass the checks below and drive the unchecked kernels
    // out of bounds.
    let len = width
        .checked_mul(height)
        .expect("width * height overflows usize");
    assert_eq!(
        input.len(),
        len,
        "Input length {} doesn't match width*height = {}",
        input.len(),
        len
    );
    assert_eq!(
        output.len(),
        len,
        "Output length {} doesn't match width*height = {}",
        output.len(),
        len
    );
    if len == 0 {
        return;
    }
    #[cfg(target_arch = "aarch64")]
    {
        // The NEON kernels dereference the buffers as `u32`/`u64`. `T` must
        // therefore be at least as aligned as that integer, not merely the same
        // size. For example, `[u8; 4]` (align 1) and `[u32; 2]` (align 4) must
        // take the generic path.
        // NOTE(review): the kernels also assume `T` has no padding bytes. The
        // `T: Copy` bound cannot enforce that.
        if core::mem::size_of::<T>() == core::mem::size_of::<u32>()
            && core::mem::align_of::<T>() >= core::mem::align_of::<u32>()
        {
            // SAFETY: both slices hold exactly `len = width * height` elements.
            // `T` has the size and alignment of `u32`. The `&`/`&mut` borrows
            // guarantee the buffers do not overlap.
            unsafe {
                transpose_neon_4b(
                    input.as_ptr().cast::<u32>(),
                    output.as_mut_ptr().cast::<u32>(),
                    width,
                    height,
                );
            }
            return;
        }
        if core::mem::size_of::<T>() == core::mem::size_of::<u64>()
            && core::mem::align_of::<T>() >= core::mem::align_of::<u64>()
        {
            // SAFETY: as above, with `T` matching the size and alignment of `u64`.
            unsafe {
                transpose_neon_8b(
                    input.as_ptr().cast::<u64>(),
                    output.as_mut_ptr().cast::<u64>(),
                    width,
                    height,
                );
            }
            return;
        }
    }
    transpose::transpose(input, output, width, height);
}
/// Chooses a 4-byte-element transpose strategy based on the element count:
/// parallel bands (with the `parallel` feature), the scalar loop, single-level
/// tiling, or recursive subdivision.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u32` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_neon_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    let elements = width * height;
    #[cfg(feature = "parallel")]
    {
        if elements >= PARALLEL_THRESHOLD {
            // SAFETY: the caller's contract is forwarded unchanged.
            unsafe { transpose_neon_4b_parallel(input, output, width, height) };
            return;
        }
    }
    // SAFETY: the caller's contract is forwarded unchanged to each kernel.
    unsafe {
        match elements {
            0..=SMALL_LEN => transpose_small_4b(input, output, width, height),
            n if n <= MEDIUM_LEN => transpose_tiled_4b(input, output, width, height),
            _ => transpose_recursive_4b(input, output, 0, height, 0, width, width, height),
        }
    }
}
/// Transposes a `height × width` `u32` matrix on rayon's pool.
///
/// The input rows are cut into one band per thread, and each band is handed to
/// `transpose_region_tiled_4b` over the full width.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u32` values for the duration of the call.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
#[inline]
unsafe fn transpose_neon_4b_parallel(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
) {
    use rayon::prelude::*;
    let num_threads = rayon::current_num_threads();
    // Round each band up to whole tiles. A band height that is not a multiple of
    // TILE_SIZE sends that band's trailing rows, across the full width, through
    // the strided scalar remainder path. Now only the final band can be ragged.
    let rows_per_thread = height.div_ceil(num_threads).next_multiple_of(TILE_SIZE);
    // Raw pointers are neither `Send` nor `Sync`, so the addresses are carried
    // into the closure as integers.
    let inp = AtomicUsize::new(input as usize);
    let out = AtomicUsize::new(output as usize);
    (0..num_threads).into_par_iter().for_each(|thread_idx| {
        let row_start = thread_idx * rows_per_thread;
        let row_end = (row_start + rows_per_thread).min(height);
        if row_start < row_end {
            let input_ptr = inp.load(Ordering::Relaxed) as *const u32;
            let output_ptr = out.load(Ordering::Relaxed) as *mut u32;
            // SAFETY: bands are disjoint, and input row `y` only ever lands at
            // output offsets `x * height + y`. Tasks therefore never write the
            // same element. Bounds come from the caller's contract.
            unsafe {
                transpose_region_tiled_4b(
                    input_ptr, output_ptr, row_start, row_end, 0, width, width, height,
                );
            }
        }
    });
}
/// Scalar transpose for small `u32` matrices.
///
/// Output is written strictly sequentially (`output[x * height + y]` in
/// increasing order), and input is read column by column.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u32` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_small_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    let mut dst = output;
    for col in 0..width {
        for row in 0..height {
            // SAFETY: `row * width + col < width * height`. `dst` advances once
            // per element and ends one past the last output slot.
            unsafe {
                dst.write(input.add(row * width + col).read());
                dst = dst.add(1);
            }
        }
    }
}
/// Single-level tiled transpose of a `height × width` `u32` matrix.
///
/// Whole `TILE_SIZE × TILE_SIZE` tiles go through the NEON tile kernel. The
/// ragged right column strip, bottom row strip and corner go through
/// `transpose_block_scalar`.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u32` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tiled_4b(input: *const u32, output: *mut u32, width: usize, height: usize) {
    let extra_cols = width % TILE_SIZE;
    let extra_rows = height % TILE_SIZE;
    // Exclusive end of the area covered by whole tiles.
    let tiled_cols = width - extra_cols;
    let tiled_rows = height - extra_rows;
    // SAFETY: every tile/block below lies inside the `height × width` matrix.
    unsafe {
        for row in (0..tiled_rows).step_by(TILE_SIZE) {
            for col in (0..tiled_cols).step_by(TILE_SIZE) {
                transpose_tile_16x16_neon(input, output, width, height, col, row);
            }
            if extra_cols != 0 {
                transpose_block_scalar(
                    input, output, width, height, tiled_cols, row, extra_cols, TILE_SIZE,
                );
            }
        }
        if extra_rows != 0 {
            for col in (0..tiled_cols).step_by(TILE_SIZE) {
                transpose_block_scalar(
                    input, output, width, height, col, tiled_rows, TILE_SIZE, extra_rows,
                );
            }
            if extra_cols != 0 {
                transpose_block_scalar(
                    input, output, width, height, tiled_cols, tiled_rows, extra_cols, extra_rows,
                );
            }
        }
    }
}
/// Recursive transpose of the region `rows [row_start, row_end) × cols
/// [col_start, col_end)` of a row-major `total_rows × total_cols` `u32` matrix.
///
/// The longer side is halved until both sides are at most `RECURSIVE_LIMIT`
/// (or either side is at most 2). Each leaf is handed to
/// `transpose_region_tiled_4b`.
///
/// # Safety
///
/// The region must lie inside the matrix. `input` must be valid for reads and
/// `output` valid for writes of `total_rows * total_cols` `u32` values.
#[cfg(target_arch = "aarch64")]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_recursive_4b(
    input: *const u32,
    output: *mut u32,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let rows = row_end - row_start;
    let cols = col_end - col_start;
    let is_leaf =
        rows <= 2 || cols <= 2 || (rows <= RECURSIVE_LIMIT && cols <= RECURSIVE_LIMIT);
    // SAFETY: every call below addresses a sub-rectangle of the caller's region.
    unsafe {
        if is_leaf {
            transpose_region_tiled_4b(
                input, output, row_start, row_end, col_start, col_end, total_cols, total_rows,
            );
        } else if rows >= cols {
            let split = row_start + rows / 2;
            transpose_recursive_4b(
                input, output, row_start, split, col_start, col_end, total_cols, total_rows,
            );
            transpose_recursive_4b(
                input, output, split, row_end, col_start, col_end, total_cols, total_rows,
            );
        } else {
            let split = col_start + cols / 2;
            transpose_recursive_4b(
                input, output, row_start, row_end, col_start, split, total_cols, total_rows,
            );
            transpose_recursive_4b(
                input, output, row_start, row_end, split, col_end, total_cols, total_rows,
            );
        }
    }
}
/// Tiled transpose of the region `rows [row_start, row_end) × cols
/// [col_start, col_end)` of a row-major `total_rows × total_cols` `u32` matrix.
///
/// Whole tiles use the buffered NEON tile kernel. The ragged right strip,
/// bottom strip and corner use `transpose_block_scalar`.
///
/// # Safety
///
/// The region must lie inside the matrix. `input` must be valid for reads and
/// `output` valid for writes of `total_rows * total_cols` `u32` values.
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_region_tiled_4b(
    input: *const u32,
    output: *mut u32,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let extra_cols = (col_end - col_start) % TILE_SIZE;
    let extra_rows = (row_end - row_start) % TILE_SIZE;
    // Exclusive end of the part of the region covered by whole tiles.
    let tiled_col_end = col_end - extra_cols;
    let tiled_row_end = row_end - extra_rows;
    // SAFETY: every tile/block below lies inside the caller's region.
    unsafe {
        for row in (row_start..tiled_row_end).step_by(TILE_SIZE) {
            for col in (col_start..tiled_col_end).step_by(TILE_SIZE) {
                transpose_tile_16x16_neon_buffered(input, output, total_cols, total_rows, col, row);
            }
            if extra_cols != 0 {
                transpose_block_scalar(
                    input, output, total_cols, total_rows, tiled_col_end, row, extra_cols,
                    TILE_SIZE,
                );
            }
        }
        if extra_rows != 0 {
            for col in (col_start..tiled_col_end).step_by(TILE_SIZE) {
                transpose_block_scalar(
                    input, output, total_cols, total_rows, col, tiled_row_end, TILE_SIZE,
                    extra_rows,
                );
            }
            if extra_cols != 0 {
                transpose_block_scalar(
                    input, output, total_cols, total_rows, tiled_col_end, tiled_row_end,
                    extra_cols, extra_rows,
                );
            }
        }
    }
}
/// Transposes the 16x16 tile whose top-left input element is at column
/// `x_start`, row `y_start` of a row-major `height × width` `u32` matrix.
///
/// The tile is processed as a 4x4 grid of 4x4 NEON blocks, written straight
/// into `output`.
///
/// # Safety
///
/// The whole tile must lie inside the matrix. `input` must be valid for reads
/// and `output` valid for writes of `width * height` `u32` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    const SIDE: usize = 16;
    for band in (0..SIDE).step_by(4) {
        // SAFETY: every 4x4 block addressed below lies inside the tile.
        unsafe {
            let src_row = input.add((y_start + band) * width + x_start);
            let dst_base = output.add(x_start * height + y_start + band);
            for block in (0..SIDE).step_by(4) {
                transpose_4x4_neon(src_row.add(block), dst_base.add(block * height), width, height);
            }
        }
    }
}
/// Transposes one `TILE_SIZE × TILE_SIZE` tile of a row-major `height × width`
/// `u32` matrix through a stack buffer.
///
/// The tile is first transposed into `staging`, where element
/// `c * TILE_SIZE + r` holds input `(row y_start + r, col x_start + c)`. Each
/// staged row is then copied as one contiguous run into `output`, prefetching
/// the next destination row before each copy.
///
/// # Safety
///
/// The whole tile must lie inside the matrix. `input` must be valid for reads
/// and `output` valid for writes of `width * height` `u32` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_buffered(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    let mut staging = MaybeUninit::<[u32; TILE_SIZE * TILE_SIZE]>::uninit();
    let stage = staging.as_mut_ptr().cast::<u32>();
    // SAFETY: the 4x4 stores fully initialise `staging` before it is read.
    // All input/output offsets stay inside the tile.
    unsafe {
        for band in (0..TILE_SIZE).step_by(4) {
            let src_row = input.add((y_start + band) * width + x_start);
            for block in (0..TILE_SIZE).step_by(4) {
                transpose_4x4_neon(
                    src_row.add(block),
                    stage.add(block * TILE_SIZE + band),
                    width,
                    TILE_SIZE,
                );
            }
        }
        prefetch_write(output.add(x_start * height + y_start).cast::<u8>());
        for col in 0..TILE_SIZE {
            let dst = output.add((x_start + col) * height + y_start);
            if col + 1 < TILE_SIZE {
                prefetch_write(dst.add(height).cast::<u8>());
            }
            core::ptr::copy_nonoverlapping(stage.add(col * TILE_SIZE), dst, TILE_SIZE);
        }
    }
}
/// Scalar transpose of the `block_height × block_width` block whose top-left
/// input element is at column `x_start`, row `y_start` of a row-major
/// `height × width` `u32` matrix. Used for edges that do not fill a tile.
///
/// # Safety
///
/// The block must lie inside the matrix. `input` must be valid for reads and
/// `output` valid for writes of `width * height` `u32` values.
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_block_scalar(
    input: *const u32,
    output: *mut u32,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
    block_width: usize,
    block_height: usize,
) {
    for col in x_start..x_start + block_width {
        for row in y_start..y_start + block_height {
            // SAFETY: `(col, row)` lies inside the matrix per the caller contract.
            unsafe {
                let value = input.add(row * width + col).read();
                output.add(col * height + row).write(value);
            }
        }
    }
}
/// Transposes a 4x4 block of `u32`s.
///
/// Four rows are loaded from `src` (row pitch `src_stride` elements), and the
/// four columns are stored as rows at `dst` (row pitch `dst_stride` elements).
///
/// # Safety
///
/// `src + k * src_stride .. +4` must be readable and `dst + k * dst_stride .. +4`
/// writable for `k` in `0..4`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn transpose_4x4_neon(src: *const u32, dst: *mut u32, src_stride: usize, dst_stride: usize) {
    // SAFETY: all loads/stores stay within the ranges required above.
    unsafe {
        // Source rows: row_k = [k0 k1 k2 k3].
        let row0 = vld1q_u32(src);
        let row1 = vld1q_u32(src.add(src_stride));
        let row2 = vld1q_u32(src.add(2 * src_stride));
        let row3 = vld1q_u32(src.add(3 * src_stride));

        // Interleave 32-bit lanes of row pairs, viewed as 64-bit pairs:
        // even01 = [00 10 | 02 12], odd01 = [01 11 | 03 13] (likewise rows 2/3).
        let even01 = vreinterpretq_u64_u32(vtrn1q_u32(row0, row1));
        let odd01 = vreinterpretq_u64_u32(vtrn2q_u32(row0, row1));
        let even23 = vreinterpretq_u64_u32(vtrn1q_u32(row2, row3));
        let odd23 = vreinterpretq_u64_u32(vtrn2q_u32(row2, row3));

        // Interleave the 64-bit halves to complete each column: col_j = [0j 1j 2j 3j].
        let col0 = vreinterpretq_u32_u64(vtrn1q_u64(even01, even23));
        let col2 = vreinterpretq_u32_u64(vtrn2q_u64(even01, even23));
        let col1 = vreinterpretq_u32_u64(vtrn1q_u64(odd01, odd23));
        let col3 = vreinterpretq_u32_u64(vtrn2q_u64(odd01, odd23));

        vst1q_u32(dst, col0);
        vst1q_u32(dst.add(dst_stride), col1);
        vst1q_u32(dst.add(2 * dst_stride), col2);
        vst1q_u32(dst.add(3 * dst_stride), col3);
    }
}
/// Chooses an 8-byte-element transpose strategy based on the element count:
/// parallel bands (with the `parallel` feature), the scalar loop, single-level
/// tiling, or recursive subdivision.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u64` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_neon_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    let elements = width * height;
    #[cfg(feature = "parallel")]
    {
        if elements >= PARALLEL_THRESHOLD {
            // SAFETY: the caller's contract is forwarded unchanged.
            unsafe { transpose_neon_8b_parallel(input, output, width, height) };
            return;
        }
    }
    // SAFETY: the caller's contract is forwarded unchanged to each kernel.
    unsafe {
        match elements {
            0..=SMALL_LEN => transpose_small_8b(input, output, width, height),
            n if n <= MEDIUM_LEN => transpose_tiled_8b(input, output, width, height),
            _ => transpose_recursive_8b(input, output, 0, height, 0, width, width, height),
        }
    }
}
/// Transposes a `height × width` `u64` matrix on rayon's pool.
///
/// The input rows are cut into one band per thread, and each band is handed to
/// `transpose_region_tiled_8b` over the full width.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u64` values for the duration of the call.
#[cfg(all(target_arch = "aarch64", feature = "parallel"))]
#[inline]
unsafe fn transpose_neon_8b_parallel(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
) {
    use rayon::prelude::*;
    let num_threads = rayon::current_num_threads();
    // Round each band up to whole tiles. A band height that is not a multiple of
    // TILE_SIZE sends that band's trailing rows, across the full width, through
    // the strided scalar remainder path. Now only the final band can be ragged.
    let rows_per_thread = height.div_ceil(num_threads).next_multiple_of(TILE_SIZE);
    // Raw pointers are neither `Send` nor `Sync`, so the addresses are carried
    // into the closure as integers.
    let inp = AtomicUsize::new(input as usize);
    let out = AtomicUsize::new(output as usize);
    (0..num_threads).into_par_iter().for_each(|thread_idx| {
        let row_start = thread_idx * rows_per_thread;
        let row_end = (row_start + rows_per_thread).min(height);
        if row_start < row_end {
            let input_ptr = inp.load(Ordering::Relaxed) as *const u64;
            let output_ptr = out.load(Ordering::Relaxed) as *mut u64;
            // SAFETY: bands are disjoint, and input row `y` only ever lands at
            // output offsets `x * height + y`. Tasks therefore never write the
            // same element. Bounds come from the caller's contract.
            unsafe {
                transpose_region_tiled_8b(
                    input_ptr, output_ptr, row_start, row_end, 0, width, width, height,
                );
            }
        }
    });
}
/// Scalar transpose for small `u64` matrices.
///
/// Output is written strictly sequentially (`output[x * height + y]` in
/// increasing order), and input is read column by column.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u64` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_small_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    let mut dst = output;
    for col in 0..width {
        for row in 0..height {
            // SAFETY: `row * width + col < width * height`. `dst` advances once
            // per element and ends one past the last output slot.
            unsafe {
                dst.write(input.add(row * width + col).read());
                dst = dst.add(1);
            }
        }
    }
}
/// Single-level tiled transpose of a `height × width` `u64` matrix.
///
/// Whole `TILE_SIZE × TILE_SIZE` tiles go through the NEON tile kernel. The
/// ragged right column strip, bottom row strip and corner go through
/// `transpose_block_scalar_8b`.
///
/// # Safety
///
/// `input` must be valid for reads and `output` valid for writes of
/// `width * height` `u64` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tiled_8b(input: *const u64, output: *mut u64, width: usize, height: usize) {
    let extra_cols = width % TILE_SIZE;
    let extra_rows = height % TILE_SIZE;
    // Exclusive end of the area covered by whole tiles.
    let tiled_cols = width - extra_cols;
    let tiled_rows = height - extra_rows;
    // SAFETY: every tile/block below lies inside the `height × width` matrix.
    unsafe {
        for row in (0..tiled_rows).step_by(TILE_SIZE) {
            for col in (0..tiled_cols).step_by(TILE_SIZE) {
                transpose_tile_16x16_neon_8b(input, output, width, height, col, row);
            }
            if extra_cols != 0 {
                transpose_block_scalar_8b(
                    input, output, width, height, tiled_cols, row, extra_cols, TILE_SIZE,
                );
            }
        }
        if extra_rows != 0 {
            for col in (0..tiled_cols).step_by(TILE_SIZE) {
                transpose_block_scalar_8b(
                    input, output, width, height, col, tiled_rows, TILE_SIZE, extra_rows,
                );
            }
            if extra_cols != 0 {
                transpose_block_scalar_8b(
                    input, output, width, height, tiled_cols, tiled_rows, extra_cols, extra_rows,
                );
            }
        }
    }
}
/// Recursive transpose of the region `rows [row_start, row_end) × cols
/// [col_start, col_end)` of a row-major `total_rows × total_cols` `u64` matrix.
///
/// The longer side is halved until both sides are at most `RECURSIVE_LIMIT`
/// (or either side is at most 2). Each leaf is handed to
/// `transpose_region_tiled_8b`.
///
/// # Safety
///
/// The region must lie inside the matrix. `input` must be valid for reads and
/// `output` valid for writes of `total_rows * total_cols` `u64` values.
#[cfg(target_arch = "aarch64")]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_recursive_8b(
    input: *const u64,
    output: *mut u64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let rows = row_end - row_start;
    let cols = col_end - col_start;
    let is_leaf =
        rows <= 2 || cols <= 2 || (rows <= RECURSIVE_LIMIT && cols <= RECURSIVE_LIMIT);
    // SAFETY: every call below addresses a sub-rectangle of the caller's region.
    unsafe {
        if is_leaf {
            transpose_region_tiled_8b(
                input, output, row_start, row_end, col_start, col_end, total_cols, total_rows,
            );
        } else if rows >= cols {
            let split = row_start + rows / 2;
            transpose_recursive_8b(
                input, output, row_start, split, col_start, col_end, total_cols, total_rows,
            );
            transpose_recursive_8b(
                input, output, split, row_end, col_start, col_end, total_cols, total_rows,
            );
        } else {
            let split = col_start + cols / 2;
            transpose_recursive_8b(
                input, output, row_start, row_end, col_start, split, total_cols, total_rows,
            );
            transpose_recursive_8b(
                input, output, row_start, row_end, split, col_end, total_cols, total_rows,
            );
        }
    }
}
/// Tiled transpose of the region `rows [row_start, row_end) × cols
/// [col_start, col_end)` of a row-major `total_rows × total_cols` `u64` matrix.
///
/// Whole tiles use the buffered NEON tile kernel. The ragged right strip,
/// bottom strip and corner use `transpose_block_scalar_8b`.
///
/// # Safety
///
/// The region must lie inside the matrix. `input` must be valid for reads and
/// `output` valid for writes of `total_rows * total_cols` `u64` values.
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_region_tiled_8b(
    input: *const u64,
    output: *mut u64,
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    total_cols: usize,
    total_rows: usize,
) {
    let extra_cols = (col_end - col_start) % TILE_SIZE;
    let extra_rows = (row_end - row_start) % TILE_SIZE;
    // Exclusive end of the part of the region covered by whole tiles.
    let tiled_col_end = col_end - extra_cols;
    let tiled_row_end = row_end - extra_rows;
    // SAFETY: every tile/block below lies inside the caller's region.
    unsafe {
        for row in (row_start..tiled_row_end).step_by(TILE_SIZE) {
            for col in (col_start..tiled_col_end).step_by(TILE_SIZE) {
                transpose_tile_16x16_neon_8b_buffered(
                    input, output, total_cols, total_rows, col, row,
                );
            }
            if extra_cols != 0 {
                transpose_block_scalar_8b(
                    input, output, total_cols, total_rows, tiled_col_end, row, extra_cols,
                    TILE_SIZE,
                );
            }
        }
        if extra_rows != 0 {
            for col in (col_start..tiled_col_end).step_by(TILE_SIZE) {
                transpose_block_scalar_8b(
                    input, output, total_cols, total_rows, col, tiled_row_end, TILE_SIZE,
                    extra_rows,
                );
            }
            if extra_cols != 0 {
                transpose_block_scalar_8b(
                    input, output, total_cols, total_rows, tiled_col_end, tiled_row_end,
                    extra_cols, extra_rows,
                );
            }
        }
    }
}
/// Transposes the 16x16 tile whose top-left input element is at column
/// `x_start`, row `y_start` of a row-major `height × width` `u64` matrix.
///
/// The tile is processed as a 4x4 grid of 4x4 NEON blocks, written straight
/// into `output`.
///
/// # Safety
///
/// The whole tile must lie inside the matrix. `input` must be valid for reads
/// and `output` valid for writes of `width * height` `u64` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_8b(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    const SIDE: usize = 16;
    for band in (0..SIDE).step_by(4) {
        // SAFETY: every 4x4 block addressed below lies inside the tile.
        unsafe {
            let src_row = input.add((y_start + band) * width + x_start);
            let dst_base = output.add(x_start * height + y_start + band);
            for block in (0..SIDE).step_by(4) {
                transpose_4x4_neon_8b(
                    src_row.add(block),
                    dst_base.add(block * height),
                    width,
                    height,
                );
            }
        }
    }
}
/// Transposes one `TILE_SIZE × TILE_SIZE` tile of a row-major `height × width`
/// `u64` matrix through a stack buffer.
///
/// The tile is first transposed into `staging`, where element
/// `c * TILE_SIZE + r` holds input `(row y_start + r, col x_start + c)`. Each
/// staged row is then copied as one contiguous run into `output`, prefetching
/// the next destination row before each copy.
///
/// # Safety
///
/// The whole tile must lie inside the matrix. `input` must be valid for reads
/// and `output` valid for writes of `width * height` `u64` values.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn transpose_tile_16x16_neon_8b_buffered(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
) {
    let mut staging = MaybeUninit::<[u64; TILE_SIZE * TILE_SIZE]>::uninit();
    let stage = staging.as_mut_ptr().cast::<u64>();
    // SAFETY: the 4x4 stores fully initialise `staging` before it is read.
    // All input/output offsets stay inside the tile.
    unsafe {
        for band in (0..TILE_SIZE).step_by(4) {
            let src_row = input.add((y_start + band) * width + x_start);
            for block in (0..TILE_SIZE).step_by(4) {
                transpose_4x4_neon_8b(
                    src_row.add(block),
                    stage.add(block * TILE_SIZE + band),
                    width,
                    TILE_SIZE,
                );
            }
        }
        prefetch_write(output.add(x_start * height + y_start).cast::<u8>());
        for col in 0..TILE_SIZE {
            let dst = output.add((x_start + col) * height + y_start);
            if col + 1 < TILE_SIZE {
                prefetch_write(dst.add(height).cast::<u8>());
            }
            core::ptr::copy_nonoverlapping(stage.add(col * TILE_SIZE), dst, TILE_SIZE);
        }
    }
}
/// Scalar transpose of the `block_height × block_width` block whose top-left
/// input element is at column `x_start`, row `y_start` of a row-major
/// `height × width` `u64` matrix. Used for edges that do not fill a tile.
///
/// # Safety
///
/// The block must lie inside the matrix. `input` must be valid for reads and
/// `output` valid for writes of `width * height` `u64` values.
#[cfg(target_arch = "aarch64")]
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn transpose_block_scalar_8b(
    input: *const u64,
    output: *mut u64,
    width: usize,
    height: usize,
    x_start: usize,
    y_start: usize,
    block_width: usize,
    block_height: usize,
) {
    for col in x_start..x_start + block_width {
        for row in y_start..y_start + block_height {
            // SAFETY: `(col, row)` lies inside the matrix per the caller contract.
            unsafe {
                let value = input.add(row * width + col).read();
                output.add(col * height + row).write(value);
            }
        }
    }
}
/// Transposes a 4x4 block of `u64`s.
///
/// Each source row is held as two 2-lane vectors. Four rows are read from `src`
/// (row pitch `src_stride` elements), and the four columns are written as rows
/// at `dst` (row pitch `dst_stride` elements).
///
/// # Safety
///
/// `src + k * src_stride .. +4` must be readable and `dst + k * dst_stride .. +4`
/// writable for `k` in `0..4`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn transpose_4x4_neon_8b(
    src: *const u64,
    dst: *mut u64,
    src_stride: usize,
    dst_stride: usize,
) {
    // SAFETY: all loads/stores stay within the ranges required above.
    unsafe {
        // Row k split into left = [k0 k1] and right = [k2 k3].
        let left0 = vld1q_u64(src);
        let right0 = vld1q_u64(src.add(2));
        let left1 = vld1q_u64(src.add(src_stride));
        let right1 = vld1q_u64(src.add(src_stride + 2));
        let left2 = vld1q_u64(src.add(2 * src_stride));
        let right2 = vld1q_u64(src.add(2 * src_stride + 2));
        let left3 = vld1q_u64(src.add(3 * src_stride));
        let right3 = vld1q_u64(src.add(3 * src_stride + 2));

        // Output row j = [0j 1j | 2j 3j]: the upper half pairs rows 0/1, the
        // lower half pairs rows 2/3.
        let out0_upper = vtrn1q_u64(left0, left1);
        let out1_upper = vtrn2q_u64(left0, left1);
        let out2_upper = vtrn1q_u64(right0, right1);
        let out3_upper = vtrn2q_u64(right0, right1);
        let out0_lower = vtrn1q_u64(left2, left3);
        let out1_lower = vtrn2q_u64(left2, left3);
        let out2_lower = vtrn1q_u64(right2, right3);
        let out3_lower = vtrn2q_u64(right2, right3);

        vst1q_u64(dst, out0_upper);
        vst1q_u64(dst.add(2), out0_lower);
        vst1q_u64(dst.add(dst_stride), out1_upper);
        vst1q_u64(dst.add(dst_stride + 2), out1_lower);
        vst1q_u64(dst.add(2 * dst_stride), out2_upper);
        vst1q_u64(dst.add(2 * dst_stride + 2), out2_lower);
        vst1q_u64(dst.add(3 * dst_stride), out3_upper);
        vst1q_u64(dst.add(3 * dst_stride + 2), out3_lower);
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;
    use proptest::prelude::*;
    use super::*;
    /// Naive reference transpose: `output[x * height + y] = input[y * width + x]`.
    fn transpose_reference<T: Copy + Default>(input: &[T], width: usize, height: usize) -> Vec<T> {
        let mut output = vec![T::default(); width * height];
        for y in 0..height {
            for x in 0..width {
                output[x * height + y] = input[y * width + x];
            }
        }
        output
    }
    /// Dimension pairs that cover every dispatch tier:
    /// - empty input;
    /// - a single row or column;
    /// - the scalar path (at most `SMALL_LEN` elements);
    /// - exact and ragged tile multiples;
    /// - inputs just above `MEDIUM_LEN`, which take the recursive path.
    fn dimension_strategy() -> impl Strategy<Value = (usize, usize)> {
        let small_side = (SMALL_LEN as f64).sqrt() as usize;
        let medium_side = (MEDIUM_LEN as f64).sqrt() as usize;
        let large_side = medium_side + 1;
        prop_oneof![
            Just((0, 0)),
            (1..=100_usize).prop_map(|w| (w, 1)),
            (1..=100_usize).prop_map(|h| (1, h)),
            (1..=small_side, 1..=small_side),
            Just((4, 4)),
            Just((TILE_SIZE, TILE_SIZE)),
            Just((TILE_SIZE * 2, TILE_SIZE * 2)),
            Just((TILE_SIZE * 4, TILE_SIZE * 4)),
            (
                (TILE_SIZE + 1)..=(TILE_SIZE * 4 - 1),
                (TILE_SIZE + 1)..=(TILE_SIZE * 4 - 1)
            ),
            (50..=200_usize, 10..=50_usize),
            (10..=50_usize, 50..=200_usize),
            Just((large_side, large_side)),
            Just((large_side + 100, large_side + 100)),
            Just((large_side * 2, large_side / 2)),
            Just((large_side / 2, large_side * 2)),
            Just((large_side + 50, large_side + 75)),
        ]
    }
    proptest! {
        #[test]
        fn proptest_transpose_babybear((width, height) in dimension_strategy()) {
            if width == 0 || height == 0 {
                let input: [BabyBear; 0] = [];
                let mut output: [BabyBear; 0] = [];
                transpose(&input, &mut output, width, height);
                return Ok(());
            }
            let input: Vec<BabyBear> = (0..width * height)
                .map(|i| BabyBear::from_u64(i as u64))
                .collect();
            let mut output = vec![BabyBear::ZERO; width * height];
            transpose(&input, &mut output, width, height);
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(
                output,
                expected,
                "Transpose mismatch for {}×{} matrix",
                width,
                height
            );
        }
        #[test]
        fn proptest_transpose_u64((width, height) in dimension_strategy()) {
            if width == 0 || height == 0 || width * height > 100_000 {
                return Ok(());
            }
            let input: Vec<u64> = (0..width * height).map(|i| i as u64).collect();
            let mut output = vec![0u64; width * height];
            transpose(&input, &mut output, width, height);
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(output, expected);
        }
        #[test]
        fn proptest_transpose_u8((width, height) in dimension_strategy()) {
            if width == 0 || height == 0 || width * height > 100_000 {
                return Ok(());
            }
            let input: Vec<u8> = (0..width * height).map(|i| i as u8).collect();
            let mut output = vec![0u8; width * height];
            transpose(&input, &mut output, width, height);
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(output, expected);
        }
        /// Element types whose size matches a NEON kernel (4 or 8 bytes) but
        /// whose alignment is lower than `u32`/`u64` must still transpose
        /// correctly (via the generic path).
        #[test]
        fn proptest_transpose_underaligned((width, height) in dimension_strategy()) {
            if width == 0 || height == 0 || width * height > 100_000 {
                return Ok(());
            }
            // Size 4, alignment 1.
            let input: Vec<[u8; 4]> = (0..width * height)
                .map(|i| (i as u32).to_le_bytes())
                .collect();
            let mut output = vec![[0u8; 4]; width * height];
            transpose(&input, &mut output, width, height);
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(output, expected);
            // Size 8, alignment 4.
            let input: Vec<[u32; 2]> = (0..width * height)
                .map(|i| [i as u32, !(i as u32)])
                .collect();
            let mut output = vec![[0u32; 2]; width * height];
            transpose(&input, &mut output, width, height);
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(output, expected);
        }
        #[test]
        fn proptest_transpose_goldilocks((width, height) in dimension_strategy()) {
            if width == 0 || height == 0 {
                let input: [Goldilocks; 0] = [];
                let mut output: [Goldilocks; 0] = [];
                transpose(&input, &mut output, width, height);
                return Ok(());
            }
            let input: Vec<Goldilocks> = (0..width * height)
                .map(|i| Goldilocks::from_u64(i as u64))
                .collect();
            let mut output = vec![Goldilocks::ZERO; width * height];
            transpose(&input, &mut output, width, height);
            let expected = transpose_reference(&input, width, height);
            prop_assert_eq!(
                output,
                expected,
                "Transpose mismatch for {}×{} matrix",
                width,
                height
            );
        }
    }
}