use crate::error::TruenoError;
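/// Scalar transpose of a rectangular sub-region: for every `r` in `rows` and
/// `c` in `cols`, copies `a[r * src_cols + c]` into `b[c * dst_rows + r]`.
/// Used for small matrices and for the edge remainders of the blocked kernels.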
#[inline(always)]
fn transpose_region(
a: &[f32],
b: &mut [f32],
rows: std::ops::Range<usize>,
cols: std::ops::Range<usize>,
src_cols: usize,
dst_rows: usize,
) {
for r in rows {
let src_base = r * src_cols;
for c in cols.clone() {
b[c * dst_rows + r] = a[src_base + c];
}
}
}
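/// Transposes a single 8x8 f32 tile entirely in registers using AVX2
/// unpack/shuffle/permute steps.
///
/// # Safety
/// AVX2 must be available. `src` must be valid for reads of eight rows of
/// eight `f32`s spaced `src_stride` elements apart, and `dst` must be valid
/// for the corresponding writes spaced `dst_stride` elements apart.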
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn transpose_8x8_avx2(
src: *const f32,
src_stride: usize,
dst: *mut f32,
dst_stride: usize,
) {
unsafe {
use std::arch::x86_64::*;
let r0 = _mm256_loadu_ps(src);
let r1 = _mm256_loadu_ps(src.add(src_stride));
let r2 = _mm256_loadu_ps(src.add(src_stride * 2));
let r3 = _mm256_loadu_ps(src.add(src_stride * 3));
let r4 = _mm256_loadu_ps(src.add(src_stride * 4));
let r5 = _mm256_loadu_ps(src.add(src_stride * 5));
let r6 = _mm256_loadu_ps(src.add(src_stride * 6));
let r7 = _mm256_loadu_ps(src.add(src_stride * 7));
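// Stage 1: interleave adjacent row pairs within each 128-bit lane.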
let t0 = _mm256_unpacklo_ps(r0, r1);
let t1 = _mm256_unpackhi_ps(r0, r1);
let t2 = _mm256_unpacklo_ps(r2, r3);
let t3 = _mm256_unpackhi_ps(r2, r3);
let t4 = _mm256_unpacklo_ps(r4, r5);
let t5 = _mm256_unpackhi_ps(r4, r5);
let t6 = _mm256_unpacklo_ps(r6, r7);
let t7 = _mm256_unpackhi_ps(r6, r7);
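// Stage 2: shuffle the pairs so each 128-bit lane holds one column from four consecutive rows.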
let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);
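// Stage 3: recombine 128-bit lanes so each register holds a full transposed row.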
let v0 = _mm256_permute2f128_ps(u0, u4, 0x20);
let v1 = _mm256_permute2f128_ps(u1, u5, 0x20);
let v2 = _mm256_permute2f128_ps(u2, u6, 0x20);
let v3 = _mm256_permute2f128_ps(u3, u7, 0x20);
let v4 = _mm256_permute2f128_ps(u0, u4, 0x31);
let v5 = _mm256_permute2f128_ps(u1, u5, 0x31);
let v6 = _mm256_permute2f128_ps(u2, u6, 0x31);
let v7 = _mm256_permute2f128_ps(u3, u7, 0x31);
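// Each stored row is one column of the source tile.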
_mm256_storeu_ps(dst, v0);
_mm256_storeu_ps(dst.add(dst_stride), v1);
_mm256_storeu_ps(dst.add(dst_stride * 2), v2);
_mm256_storeu_ps(dst.add(dst_stride * 3), v3);
_mm256_storeu_ps(dst.add(dst_stride * 4), v4);
_mm256_storeu_ps(dst.add(dst_stride * 5), v5);
_mm256_storeu_ps(dst.add(dst_stride * 6), v6);
_mm256_storeu_ps(dst.add(dst_stride * 7), v7);
}
}
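/// Transposes the row-major `rows x cols` matrix `a` into the row-major
/// `cols x rows` matrix `b`.
///
/// Returns `TruenoError::InvalidInput` if either slice length differs from
/// `rows * cols`. Dispatches to a tiled AVX2 kernel on x86_64 when the CPU
/// supports it, otherwise to a cache-blocked scalar fallback.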
pub fn transpose(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) -> Result<(), TruenoError> {
debug_assert!(!a.is_empty(), "Contract transpose: input is empty");
debug_assert!(rows > 0 && cols > 0, "Contract transpose: zero dimensions");
let expected = rows * cols;
if a.len() != expected || b.len() != expected {
return Err(TruenoError::InvalidInput(format!(
"transpose size mismatch: a[{}], b[{}], expected {}",
a.len(),
b.len(),
expected
)));
}
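// Small matrices: blocking and SIMD dispatch overhead outweigh any benefit.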
if expected < 64 {
transpose_region(a, b, 0..rows, 0..cols, cols, rows);
return Ok(());
}
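// Runtime dispatch: prefer the AVX2 kernel when the CPU supports it.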
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe {
return transpose_avx2_impl(rows, cols, a, b);
}
}
}
transpose_scalar_impl(rows, cols, a, b);
Ok(())
}
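/// AVX2 transpose: walks the matrix in `TILE`x`TILE` cache tiles and
/// transposes each tile as 8x8 register blocks; rows and columns that do not
/// fill an 8x8 block are finished by the scalar `transpose_region`.
///
/// # Safety
/// The caller must ensure AVX2 is available (checked in `transpose`).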
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn transpose_avx2_impl(
rows: usize,
cols: usize,
a: &[f32],
b: &mut [f32],
) -> Result<(), TruenoError> {
use std::arch::x86_64::*;
// TILE: cache-blocking tile edge (in elements); BLOCK: 8x8 register-block edge.
const TILE: usize = 64;
const BLOCK: usize = 8;
let rb_end = rows / BLOCK * BLOCK;
let cb_end = cols / BLOCK * BLOCK;
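// Tall matrices (rows >= 4 * cols): iterate row blocks innermost so stores stream sequentially along the destination rows of `b`.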
let tall_skinny = rows >= 4 * cols;
unsafe {
for rt in (0..rb_end).step_by(TILE) {
let rt_end = (rt + TILE).min(rb_end);
for ct in (0..cb_end).step_by(TILE) {
let ct_end = (ct + TILE).min(cb_end);
if tall_skinny {
for c0 in (ct..ct_end).step_by(BLOCK) {
for r0 in (rt..rt_end).step_by(BLOCK) {
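// Prefetch the next block's destination rows one iteration ahead.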
if r0 + BLOCK < rt_end {
let pf_dst = b.as_ptr().add(c0 * rows + r0 + BLOCK);
_mm_prefetch(pf_dst as *const i8, _MM_HINT_T0);
_mm_prefetch(pf_dst.add(rows) as *const i8, _MM_HINT_T0);
}
let src = a.as_ptr().add(r0 * cols + c0);
let dst = b.as_mut_ptr().add(c0 * rows + r0);
transpose_8x8_avx2(src, cols, dst, rows);
}
}
} else {
for r0 in (rt..rt_end).step_by(BLOCK) {
for c0 in (ct..ct_end).step_by(BLOCK) {
let src = a.as_ptr().add(r0 * cols + c0);
let dst = b.as_mut_ptr().add(c0 * rows + r0);
transpose_8x8_avx2(src, cols, dst, rows);
}
}
}
}
}
}
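// Scalar cleanup for the right-edge columns and bottom rows that do not fill an 8x8 block.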
if cb_end < cols {
transpose_region(a, b, 0..rb_end, cb_end..cols, cols, rows);
}
if rb_end < rows {
transpose_region(a, b, rb_end..rows, 0..cols, cols, rows);
}
Ok(())
}
fn transpose_scalar_impl(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
const BLOCK: usize = 8;
let row_blocks = rows / BLOCK;
let col_blocks = cols / BLOCK;
for rb in 0..row_blocks {
for cb in 0..col_blocks {
let rs = rb * BLOCK;
let cs = cb * BLOCK;
transpose_region(a, b, rs..rs + BLOCK, cs..cs + BLOCK, cols, rows);
}
}
let col_rem = col_blocks * BLOCK;
if col_rem < cols {
transpose_region(a, b, 0..row_blocks * BLOCK, col_rem..cols, cols, rows);
}
let row_rem = row_blocks * BLOCK;
if row_rem < rows {
transpose_region(a, b, row_rem..rows, 0..cols, cols, rows);
}
}
#[cfg(test)]
mod tests {
use super::*;
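/// Straightforward reference transpose used to validate the optimized paths.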
fn transpose_naive(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
for i in 0..rows {
for j in 0..cols {
b[j * rows + i] = a[i * cols + j];
}
}
}
#[test]
fn test_element_correctness() {
for (rows, cols) in [(4, 5), (8, 8), (16, 32), (31, 17), (64, 64)] {
let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
let mut b = vec![0.0f32; rows * cols];
transpose(rows, cols, &a, &mut b).unwrap();
for i in 0..rows {
for j in 0..cols {
assert_eq!(b[j * rows + i], a[i * cols + j], "({i},{j}) {rows}×{cols}");
}
}
}
}
#[test]
fn test_involution() {
for (rows, cols) in [(7, 13), (16, 16), (33, 17), (64, 128)] {
let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.1 + 0.37).collect();
let mut b = vec![0.0f32; rows * cols];
let mut c = vec![0.0f32; rows * cols];
transpose(rows, cols, &a, &mut b).unwrap();
transpose(cols, rows, &b, &mut c).unwrap();
assert_eq!(a, c, "Involution failed for {rows}×{cols}");
}
}
#[test]
fn test_non_aligned() {
for (rows, cols) in [(7, 13), (17, 3), (1, 32), (32, 1), (1, 1), (3, 3)] {
let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
let mut b_test = vec![0.0f32; rows * cols];
let mut b_ref = vec![0.0f32; rows * cols];
transpose(rows, cols, &a, &mut b_test).unwrap();
transpose_naive(rows, cols, &a, &mut b_ref);
assert_eq!(b_test, b_ref, "Mismatch for {rows}×{cols}");
}
}
#[test]
fn test_avx2_scalar_parity() {
let rows = 2048;
let cols = 128;
let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.001).collect();
let mut b_scalar = vec![0.0f32; rows * cols];
let mut b_dispatch = vec![0.0f32; rows * cols];
transpose_scalar_impl(rows, cols, &a, &mut b_scalar);
transpose(rows, cols, &a, &mut b_dispatch).unwrap();
assert_eq!(b_scalar, b_dispatch, "AVX2 vs scalar mismatch at 2048×128");
}
#[test]
fn test_identity() {
for n in [4, 8, 16, 32] {
let mut a = vec![0.0f32; n * n];
for i in 0..n {
a[i * n + i] = 1.0;
}
let mut b = vec![0.0f32; n * n];
transpose(n, n, &a, &mut b).unwrap();
assert_eq!(a, b, "Identity not preserved for {n}×{n}");
}
}
#[test]
fn test_attention_shape() {
let rows = 2048;
let cols = 128;
let a: Vec<f32> =
(0..rows * cols).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
let mut b_test = vec![0.0f32; rows * cols];
let mut b_ref = vec![0.0f32; rows * cols];
transpose(rows, cols, &a, &mut b_test).unwrap();
transpose_naive(rows, cols, &a, &mut b_ref);
assert_eq!(b_test, b_ref, "Attention shape 2048×128 mismatch");
}
#[test]
fn test_dimension_mismatch() {
let a = vec![1.0f32; 12];
let mut b = vec![0.0f32; 10];
assert!(transpose(3, 4, &a, &mut b).is_err());
}
#[test]
fn test_small_matrix() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut b = vec![0.0f32; 6];
transpose(2, 3, &a, &mut b).unwrap();
assert_eq!(b, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
}