use core::ptr::{swap, swap_nonoverlapping};
#[cfg(feature = "parallel")]
use core::sync::atomic::{AtomicPtr, Ordering};
/// `log2` of the largest diagonal block side that `transpose_in_place_square` transposes with a
/// direct element-wise loop instead of recursing further (i.e. blocks up to `8 × 8`).
const BASE_CASE_LOG: usize = 3;
/// Regions with fewer elements than this are swapped directly by `transpose_swap` rather than
/// split again. Equal to `2^(2 * BASE_CASE_LOG)`, i.e. 64.
const BASE_CASE_ELEMENT_THRESHOLD: usize = (1 << BASE_CASE_LOG) * (1 << BASE_CASE_LOG);
/// Element count at which recursive steps are fanned out with `rayon::join`
/// (`transpose_in_place_square` compares with `>=`, `transpose_swap` with `>`).
#[cfg(feature = "parallel")]
const PARALLEL_RECURSION_THRESHOLD: usize = 1024;
/// Transposes, in place, the `2^log_size × 2^log_size` diagonal block whose top-left corner is at
/// row `x`, column `x` of a row-major matrix stored in `arr` with row stride `2^log_stride`.
///
/// For every pair of indices `x <= j < i < x + 2^log_size`, the element at (row `j`, column `i`)
/// is swapped with the element at (row `i`, column `j`); the diagonal is left untouched.
///
/// # Safety
///
/// Every index `(r << log_stride) + c` with `x <= r, c < x + 2^log_size` must be in bounds of
/// `arr`. No bounds checks are performed.
unsafe fn transpose_in_place_square_small<T>(
    arr: &mut [T],
    log_stride: usize,
    log_size: usize,
    x: usize,
) {
    // Derive both swap operands from one raw base pointer. Calling `get_unchecked_mut` twice in a
    // single expression reborrows the whole slice for the second call, which (under Stacked
    // Borrows) invalidates the pointer produced by the first call before `swap` uses it.
    let base = arr.as_mut_ptr();
    let end = x + (1 << log_size);
    for i in (x + 1)..end {
        for j in x..i {
            // SAFETY: both indices lie inside the block (caller contract), and they are distinct
            // because `j < i`, so both pointers are valid for reads and writes.
            unsafe {
                swap(
                    base.add(i + (j << log_stride)),
                    base.add((i << log_stride) + j),
                );
            }
        }
    }
}
/// Exchanges a `rows × cols` region starting at `a` with the transpose of the `cols × rows`
/// region starting at `b`: element `(i, j)` of the `a` region (offset `i * width_outer_mat + j`)
/// is swapped with element `(j, i)` of the `b` region (offset `j * width_outer_mat + i`).
///
/// The longer side (columns on a tie) is halved recursively until a region holds fewer than
/// `BASE_CASE_ELEMENT_THRESHOLD` elements, which are then swapped one by one. With the
/// `parallel` feature, regions of more than `PARALLEL_RECURSION_THRESHOLD` elements process
/// their two halves through `rayon::join`.
///
/// # Safety
///
/// For all `i < rows` and `j < cols`, `a.add(i * width_outer_mat + j)` and
/// `b.add(j * width_outer_mat + i)` must be valid for reads and writes, and the two regions must
/// be disjoint (elements are exchanged with `swap_nonoverlapping`).
pub(super) unsafe fn transpose_swap<T: Copy>(
    a: *mut T,
    b: *mut T,
    width_outer_mat: usize,
    (rows, cols): (usize, usize),
) {
    let size = rows * cols;

    if size < BASE_CASE_ELEMENT_THRESHOLD {
        for row in 0..rows {
            let a_row = row * width_outer_mat;
            for col in 0..cols {
                let b_row = col * width_outer_mat;
                // SAFETY: both offsets fall inside their regions (caller contract), and the
                // regions are disjoint.
                unsafe {
                    swap_nonoverlapping(a.add(a_row + col), b.add(b_row + row), 1);
                }
            }
        }
        return;
    }

    // Split the longer side in half. `first` / `second` are the dimensions of the two halves;
    // the second half starts `a_step` elements past `a` and `b_step` elements past `b`.
    let (first, second, a_step, b_step) = if rows > cols {
        let top = rows / 2;
        ((top, cols), (rows - top, cols), top * width_outer_mat, top)
    } else {
        let left = cols / 2;
        ((rows, left), (rows, cols - left), left, left * width_outer_mat)
    };

    #[cfg(feature = "parallel")]
    {
        if size > PARALLEL_RECURSION_THRESHOLD {
            // Raw pointers are not `Send`; `AtomicPtr` carries them into the rayon closures.
            // The two halves touch disjoint elements.
            let shared_a = AtomicPtr::new(a);
            let shared_b = AtomicPtr::new(b);
            rayon::join(
                || unsafe {
                    transpose_swap(
                        shared_a.load(Ordering::Relaxed),
                        shared_b.load(Ordering::Relaxed),
                        width_outer_mat,
                        first,
                    );
                },
                || unsafe {
                    transpose_swap(
                        shared_a.load(Ordering::Relaxed).add(a_step),
                        shared_b.load(Ordering::Relaxed).add(b_step),
                        width_outer_mat,
                        second,
                    );
                },
            );
            return;
        }
    }

    // SAFETY: each half is a sub-region of the caller's regions, so the contract carries over.
    unsafe {
        transpose_swap(a, b, width_outer_mat, first);
        transpose_swap(a.add(a_step), b.add(b_step), width_outer_mat, second);
    }
}
pub(crate) unsafe fn transpose_in_place_square<T>(
arr: &mut [T],
log_stride: usize,
log_size: usize,
x: usize,
) where
T: Copy + Send + Sync,
{
if log_size <= BASE_CASE_LOG {
unsafe {
transpose_in_place_square_small(arr, log_stride, log_size, x);
}
return;
}
let log_half_size = log_size - 1;
let half = 1 << log_half_size;
let stride = 1 << log_stride;
#[cfg(feature = "parallel")]
{
let elements = 1 << (2 * log_size);
if elements >= PARALLEL_RECURSION_THRESHOLD {
let base = AtomicPtr::new(arr.as_mut_ptr());
let len = arr.len();
rayon::join(
|| unsafe {
transpose_in_place_square(
core::slice::from_raw_parts_mut(base.load(Ordering::Relaxed), len),
log_stride,
log_half_size,
x,
);
},
|| {
rayon::join(
|| unsafe {
let ptr = base.load(Ordering::Relaxed);
transpose_swap(
ptr.add((x << log_stride) + (x + half)),
ptr.add(((x + half) << log_stride) + x),
stride,
(half, half),
);
},
|| unsafe {
transpose_in_place_square(
core::slice::from_raw_parts_mut(base.load(Ordering::Relaxed), len),
log_stride,
log_half_size,
x + half,
);
},
)
},
);
return;
}
}
let ptr = arr.as_mut_ptr();
unsafe {
transpose_in_place_square(arr, log_stride, log_half_size, x);
transpose_swap(
ptr.add((x << log_stride) + (x + half)),
ptr.add(((x + half) << log_stride) + x),
stride,
(half, half),
);
transpose_in_place_square(arr, log_stride, log_half_size, x + half);
}
}
#[cfg(test)]
mod tests {
    extern crate alloc;
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;

    /// Row-major `2^log_size × 2^log_size` matrix whose entries equal their flat indices.
    fn generate_matrix(log_size: usize) -> Vec<u32> {
        let size = 1 << log_size;
        (0..size * size).collect()
    }

    /// Out-of-place reference transpose of a square row-major `2^log_size` matrix.
    fn transpose_reference(input: &[u32], log_size: usize) -> Vec<u32> {
        let size = 1 << log_size;
        let mut transposed = vec![0; size * size];
        for i in 0..size {
            for j in 0..size {
                transposed[j * size + i] = input[i * size + j];
            }
        }
        transposed
    }

    #[test]
    fn transpose_square() {
        for log_size in 1..=10 {
            let size = 1 << log_size;
            let mut mat = generate_matrix(log_size);
            let expected = transpose_reference(&mat, log_size);
            unsafe {
                transpose_in_place_square(&mut mat, log_size, log_size, 0);
            }
            assert_eq!(mat, expected, "Transpose failed for {size}x{size} matrix");
        }
    }

    /// Transposes diagonal sub-blocks (`x > 0`, `log_size < log_stride`) of a larger matrix and
    /// checks that only the selected block changes.
    #[test]
    fn transpose_square_sub_block() {
        let log_stride = 6;
        let stride: usize = 1 << log_stride;
        for log_size in 0..log_stride {
            let size: usize = 1 << log_size;
            for &x in &[0, (stride - size) / 2, stride - size] {
                let mut mat = generate_matrix(log_stride);
                let mut expected = mat.clone();
                for i in 0..size {
                    for j in 0..size {
                        expected[(x + j) * stride + (x + i)] = mat[(x + i) * stride + (x + j)];
                    }
                }
                unsafe {
                    transpose_in_place_square(&mut mat, log_stride, log_size, x);
                }
                assert_eq!(
                    mat, expected,
                    "Sub-block transpose failed for {size}x{size} block at x = {x}"
                );
            }
        }
    }
}