/// Converts a single-image tensor from NCHW (planar) to NHWC (interleaved) layout.
///
/// `src` holds `c` planes of `h * w` elements each; `dst` receives `h * w`
/// pixels of `c` interleaved channel values. Both slices must be exactly
/// `h * w * c` elements long.
///
/// On aarch64, byte-sized element types with at least one full 16x16 tile are
/// routed to a NEON transpose kernel; everything else uses the scalar gather
/// loop below.
pub(crate) fn nchw_to_nhwc<T: Copy>(src: &[T], h: usize, w: usize, c: usize, dst: &mut [T]) {
    let len = h * w * c;
    debug_assert_eq!(src.len(), len, "src length mismatch");
    debug_assert_eq!(dst.len(), len, "dst length mismatch");
    #[cfg(target_arch = "aarch64")]
    // The NEON kernel indexes raw pointers and uses `get_unchecked`, so the
    // length invariants must hold even in release builds (where the
    // `debug_assert`s above compile away). Mis-sized slices therefore fall
    // through to the bounds-checked scalar path instead of invoking UB.
    if std::mem::size_of::<T>() == 1
        && c >= 16
        && (h * w) >= 16
        && src.len() == len
        && dst.len() == len
    {
        // SAFETY: `size_of::<T>() == 1`, so reinterpreting the slices as byte
        // slices of the same length preserves both size and alignment, and
        // the length checks above guarantee the kernel's indexing stays in
        // bounds.
        unsafe {
            let src_bytes = core::slice::from_raw_parts(src.as_ptr() as *const u8, src.len());
            let dst_bytes =
                core::slice::from_raw_parts_mut(dst.as_mut_ptr() as *mut u8, dst.len());
            nchw_to_nhwc_u8_neon(src_bytes, h, w, c, dst_bytes);
        }
        return;
    }
    // Scalar fallback: for each pixel, gather its channel values from the
    // `c` planes and write them contiguously.
    let hw = h * w;
    let mut dst_idx = 0;
    for hi in 0..h {
        for wi in 0..w {
            let src_base = hi * w + wi;
            for ci in 0..c {
                dst[dst_idx] = src[ci * hw + src_base];
                dst_idx += 1;
            }
        }
    }
}
/// NEON-accelerated NCHW -> NHWC conversion for byte-sized elements.
///
/// Viewing `src` as a row-major `c x (h*w)` matrix, the conversion is a
/// matrix transpose into `dst` (`(h*w) x c`). Full 16x16 byte tiles are
/// transposed entirely in registers via a four-stage `vtrn1`/`vtrn2`
/// butterfly (element widths 8 -> 16 -> 32 -> 64 bits); pixels and channels
/// left over after tiling are copied by scalar tail loops.
///
/// # Safety
/// The caller must guarantee `src.len() == h * w * c`,
/// `dst.len() == h * w * c`, `h * w >= 16`, and `c >= 16`. These invariants
/// are only debug-asserted here, and every raw-pointer load/store and
/// `get_unchecked` access below relies on them.
#[cfg(target_arch = "aarch64")]
pub(crate) unsafe fn nchw_to_nhwc_u8_neon(
    src: &[u8],
    h: usize,
    w: usize,
    c: usize,
    dst: &mut [u8],
) {
    use core::arch::aarch64::*;
    let hw = h * w;
    debug_assert_eq!(src.len(), hw * c);
    debug_assert_eq!(dst.len(), hw * c);
    debug_assert!(hw >= 16 && c >= 16);
    let src_ptr = src.as_ptr();
    let dst_ptr = dst.as_mut_ptr();
    // Number of complete 16-wide tiles along the pixel and channel axes.
    let n_tiles_hw = hw / 16;
    let n_tiles_c = c / 16;
    for tile_c in 0..n_tiles_c {
        let c_base = tile_c * 16;
        for tile_hw in 0..n_tiles_hw {
            let hw_base = tile_hw * 16;
            // Load one 16x16 tile: register `rN` holds 16 consecutive
            // pixels of channel `c_base + N`.
            let r0 = vld1q_u8(src_ptr.add((c_base) * hw + hw_base));
            let r1 = vld1q_u8(src_ptr.add((c_base + 1) * hw + hw_base));
            let r2 = vld1q_u8(src_ptr.add((c_base + 2) * hw + hw_base));
            let r3 = vld1q_u8(src_ptr.add((c_base + 3) * hw + hw_base));
            let r4 = vld1q_u8(src_ptr.add((c_base + 4) * hw + hw_base));
            let r5 = vld1q_u8(src_ptr.add((c_base + 5) * hw + hw_base));
            let r6 = vld1q_u8(src_ptr.add((c_base + 6) * hw + hw_base));
            let r7 = vld1q_u8(src_ptr.add((c_base + 7) * hw + hw_base));
            let r8 = vld1q_u8(src_ptr.add((c_base + 8) * hw + hw_base));
            let r9 = vld1q_u8(src_ptr.add((c_base + 9) * hw + hw_base));
            let r10 = vld1q_u8(src_ptr.add((c_base + 10) * hw + hw_base));
            let r11 = vld1q_u8(src_ptr.add((c_base + 11) * hw + hw_base));
            let r12 = vld1q_u8(src_ptr.add((c_base + 12) * hw + hw_base));
            let r13 = vld1q_u8(src_ptr.add((c_base + 13) * hw + hw_base));
            let r14 = vld1q_u8(src_ptr.add((c_base + 14) * hw + hw_base));
            let r15 = vld1q_u8(src_ptr.add((c_base + 15) * hw + hw_base));
            // Stage 1: interleave adjacent row pairs at byte granularity —
            // after this, every 2x2 block of bytes within a row pair is
            // transposed.
            let a0 = vtrn1q_u8(r0, r1);
            let a1 = vtrn2q_u8(r0, r1);
            let a2 = vtrn1q_u8(r2, r3);
            let a3 = vtrn2q_u8(r2, r3);
            let a4 = vtrn1q_u8(r4, r5);
            let a5 = vtrn2q_u8(r4, r5);
            let a6 = vtrn1q_u8(r6, r7);
            let a7 = vtrn2q_u8(r6, r7);
            let a8 = vtrn1q_u8(r8, r9);
            let a9 = vtrn2q_u8(r8, r9);
            let a10 = vtrn1q_u8(r10, r11);
            let a11 = vtrn2q_u8(r10, r11);
            let a12 = vtrn1q_u8(r12, r13);
            let a13 = vtrn2q_u8(r12, r13);
            let a14 = vtrn1q_u8(r14, r15);
            let a15 = vtrn2q_u8(r14, r15);
            // Stage 2 helper: trn at 16-bit granularity via reinterpret
            // (NEON has no byte-vector u16-trn directly on u8x16 types).
            macro_rules! trn_h {
                ($lo:expr, $hi:expr, $kind:ident) => {
                    vreinterpretq_u8_u16($kind(
                        vreinterpretq_u16_u8($lo),
                        vreinterpretq_u16_u8($hi),
                    ))
                };
            }
            // Stage 2: 16-bit trn — 4x4 sub-blocks are now transposed.
            let b0 = trn_h!(a0, a2, vtrn1q_u16);
            let b1 = trn_h!(a1, a3, vtrn1q_u16);
            let b2 = trn_h!(a0, a2, vtrn2q_u16);
            let b3 = trn_h!(a1, a3, vtrn2q_u16);
            let b4 = trn_h!(a4, a6, vtrn1q_u16);
            let b5 = trn_h!(a5, a7, vtrn1q_u16);
            let b6 = trn_h!(a4, a6, vtrn2q_u16);
            let b7 = trn_h!(a5, a7, vtrn2q_u16);
            let b8 = trn_h!(a8, a10, vtrn1q_u16);
            let b9 = trn_h!(a9, a11, vtrn1q_u16);
            let b10 = trn_h!(a8, a10, vtrn2q_u16);
            let b11 = trn_h!(a9, a11, vtrn2q_u16);
            let b12 = trn_h!(a12, a14, vtrn1q_u16);
            let b13 = trn_h!(a13, a15, vtrn1q_u16);
            let b14 = trn_h!(a12, a14, vtrn2q_u16);
            let b15 = trn_h!(a13, a15, vtrn2q_u16);
            // Stage 3 helper: trn at 32-bit granularity.
            macro_rules! trn_s {
                ($lo:expr, $hi:expr, $kind:ident) => {
                    vreinterpretq_u8_u32($kind(
                        vreinterpretq_u32_u8($lo),
                        vreinterpretq_u32_u8($hi),
                    ))
                };
            }
            // Stage 3: 32-bit trn — 8x8 sub-blocks are now transposed.
            let d0 = trn_s!(b0, b4, vtrn1q_u32);
            let d1 = trn_s!(b1, b5, vtrn1q_u32);
            let d2 = trn_s!(b2, b6, vtrn1q_u32);
            let d3 = trn_s!(b3, b7, vtrn1q_u32);
            let d4 = trn_s!(b0, b4, vtrn2q_u32);
            let d5 = trn_s!(b1, b5, vtrn2q_u32);
            let d6 = trn_s!(b2, b6, vtrn2q_u32);
            let d7 = trn_s!(b3, b7, vtrn2q_u32);
            let d8 = trn_s!(b8, b12, vtrn1q_u32);
            let d9 = trn_s!(b9, b13, vtrn1q_u32);
            let d10 = trn_s!(b10, b14, vtrn1q_u32);
            let d11 = trn_s!(b11, b15, vtrn1q_u32);
            let d12 = trn_s!(b8, b12, vtrn2q_u32);
            let d13 = trn_s!(b9, b13, vtrn2q_u32);
            let d14 = trn_s!(b10, b14, vtrn2q_u32);
            let d15 = trn_s!(b11, b15, vtrn2q_u32);
            // Stage 4 helper: trn at 64-bit granularity.
            macro_rules! trn_d {
                ($lo:expr, $hi:expr, $kind:ident) => {
                    vreinterpretq_u8_u64($kind(
                        vreinterpretq_u64_u8($lo),
                        vreinterpretq_u64_u8($hi),
                    ))
                };
            }
            // Stage 4: 64-bit trn completes the 16x16 transpose. Register
            // `tN` now holds the 16 channel values of pixel `hw_base + N`.
            let t0 = trn_d!(d0, d8, vtrn1q_u64);
            let t1 = trn_d!(d1, d9, vtrn1q_u64);
            let t2 = trn_d!(d2, d10, vtrn1q_u64);
            let t3 = trn_d!(d3, d11, vtrn1q_u64);
            let t4 = trn_d!(d4, d12, vtrn1q_u64);
            let t5 = trn_d!(d5, d13, vtrn1q_u64);
            let t6 = trn_d!(d6, d14, vtrn1q_u64);
            let t7 = trn_d!(d7, d15, vtrn1q_u64);
            let t8 = trn_d!(d0, d8, vtrn2q_u64);
            let t9 = trn_d!(d1, d9, vtrn2q_u64);
            let t10 = trn_d!(d2, d10, vtrn2q_u64);
            let t11 = trn_d!(d3, d11, vtrn2q_u64);
            let t12 = trn_d!(d4, d12, vtrn2q_u64);
            let t13 = trn_d!(d5, d13, vtrn2q_u64);
            let t14 = trn_d!(d6, d14, vtrn2q_u64);
            let t15 = trn_d!(d7, d15, vtrn2q_u64);
            // Store each transposed row: 16 interleaved channel values for
            // one pixel, written contiguously into NHWC order.
            vst1q_u8(dst_ptr.add((hw_base) * c + c_base), t0);
            vst1q_u8(dst_ptr.add((hw_base + 1) * c + c_base), t1);
            vst1q_u8(dst_ptr.add((hw_base + 2) * c + c_base), t2);
            vst1q_u8(dst_ptr.add((hw_base + 3) * c + c_base), t3);
            vst1q_u8(dst_ptr.add((hw_base + 4) * c + c_base), t4);
            vst1q_u8(dst_ptr.add((hw_base + 5) * c + c_base), t5);
            vst1q_u8(dst_ptr.add((hw_base + 6) * c + c_base), t6);
            vst1q_u8(dst_ptr.add((hw_base + 7) * c + c_base), t7);
            vst1q_u8(dst_ptr.add((hw_base + 8) * c + c_base), t8);
            vst1q_u8(dst_ptr.add((hw_base + 9) * c + c_base), t9);
            vst1q_u8(dst_ptr.add((hw_base + 10) * c + c_base), t10);
            vst1q_u8(dst_ptr.add((hw_base + 11) * c + c_base), t11);
            vst1q_u8(dst_ptr.add((hw_base + 12) * c + c_base), t12);
            vst1q_u8(dst_ptr.add((hw_base + 13) * c + c_base), t13);
            vst1q_u8(dst_ptr.add((hw_base + 14) * c + c_base), t14);
            vst1q_u8(dst_ptr.add((hw_base + 15) * c + c_base), t15);
        }
        // Pixel-axis tail for this channel tile: pixels past the last full
        // 16-wide tile (when hw is not a multiple of 16), channels
        // c_base..c_base+16.
        let hw_tail_start = n_tiles_hw * 16;
        for hw_idx in hw_tail_start..hw {
            for ci in 0..16 {
                let src_idx = (c_base + ci) * hw + hw_idx;
                let dst_idx = hw_idx * c + (c_base + ci);
                // SAFETY: src_idx < c * hw and dst_idx < hw * c by
                // construction, given the length invariants documented in
                // `# Safety`.
                *dst.get_unchecked_mut(dst_idx) = *src.get_unchecked(src_idx);
            }
        }
    }
    // Channel-axis tail: channels past the last full 16-wide tile (when c is
    // not a multiple of 16), for every pixel.
    let c_tail_start = n_tiles_c * 16;
    for hw_idx in 0..hw {
        for ci in c_tail_start..c {
            let src_idx = ci * hw + hw_idx;
            let dst_idx = hw_idx * c + ci;
            // SAFETY: src_idx < c * hw and dst_idx < hw * c by construction,
            // given the length invariants documented in `# Safety`.
            *dst.get_unchecked_mut(dst_idx) = *src.get_unchecked(src_idx);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain scalar NCHW -> NHWC conversion used as ground truth when
    /// cross-checking the optimized path.
    fn reference_nhwc<T: Copy + Default>(src: &[T], h: usize, w: usize, c: usize) -> Vec<T> {
        let hw = h * w;
        let mut out = vec![T::default(); hw * c];
        for pix in 0..hw {
            for ch in 0..c {
                out[pix * c + ch] = src[ch * hw + pix];
            }
        }
        out
    }

    #[test]
    fn nchw_to_nhwc_small_hand_check() {
        // 2x3 image, 2 channels: channel 0 is 0..=5, channel 1 is 6..=11.
        let src: [i32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let mut dst = [0i32; 12];
        nchw_to_nhwc(&src, 2, 3, 2, &mut dst);
        assert_eq!(dst, [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11]);
    }

    #[test]
    fn nchw_to_nhwc_roundtrip_random() {
        let (h, w, c) = (5, 7, 4);
        let total = h * w * c;
        let src: Vec<i16> = (0..total as i16).collect();
        let mut nhwc = vec![0i16; total];
        nchw_to_nhwc(&src, h, w, c, &mut nhwc);
        // Invert the permutation back to NCHW and expect the original data.
        let hw = h * w;
        let mut recovered = vec![0i16; total];
        for pix in 0..hw {
            for ch in 0..c {
                recovered[ch * hw + pix] = nhwc[pix * c + ch];
            }
        }
        assert_eq!(recovered, src);
    }

    #[test]
    fn nchw_to_nhwc_single_channel_passthrough() {
        // With c == 1 the layouts coincide, so the output equals the input.
        let (h, w) = (4, 5);
        let src: Vec<u8> = (0..(h * w) as u8).collect();
        let mut dst = vec![0u8; h * w];
        nchw_to_nhwc(&src, h, w, 1, &mut dst);
        assert_eq!(dst, src);
    }

    #[test]
    fn nchw_to_nhwc_single_pixel_passthrough() {
        // With h == w == 1 the layouts coincide as well.
        let src: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut dst = [0f32; 8];
        nchw_to_nhwc(&src, 1, 1, 8, &mut dst);
        assert_eq!(dst, src);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn nchw_to_nhwc_u8_neon_matches_scalar_aligned_16_16() {
        // Exactly one 16x16 tile in each dimension — no tail loops involved.
        let (h, w, c) = (4, 4, 16);
        let n = h * w * c;
        let src: Vec<u8> = (0..n).map(|i| (i % 251) as u8).collect();
        let mut got = vec![0u8; n];
        super::nchw_to_nhwc(&src, h, w, c, &mut got);
        assert_eq!(
            got,
            reference_nhwc(&src, h, w, c),
            "NEON 16x16 tile mismatched scalar"
        );
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn nchw_to_nhwc_u8_neon_matches_scalar_with_both_tails() {
        // hw = 20 and c = 20: both the pixel and channel tails are exercised.
        let (h, w, c) = (5, 4, 20);
        let n = h * w * c;
        let src: Vec<u8> = (0..n).map(|i| ((i * 37) % 251) as u8).collect();
        let mut got = vec![0u8; n];
        super::nchw_to_nhwc(&src, h, w, c, &mut got);
        assert_eq!(got, reference_nhwc(&src, h, w, c));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn nchw_to_nhwc_u8_neon_realistic_score_shape() {
        // A larger, production-like shape with many tiles plus tails.
        let (h, w, c) = (20, 20, 80);
        let n = h * w * c;
        let src: Vec<u8> = (0..n).map(|i| ((i ^ (i >> 7)) % 251) as u8).collect();
        let mut got = vec![0u8; n];
        super::nchw_to_nhwc(&src, h, w, c, &mut got);
        assert_eq!(got, reference_nhwc(&src, h, w, c));
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn nchw_to_nhwc_i8_neon_matches_scalar() {
        // i8 is byte-sized, so it is dispatched to the same NEON kernel.
        let (h, w, c) = (4, 4, 32);
        let n = h * w * c;
        let src: Vec<i8> = (0..n).map(|i| (i as i32 - 100) as i8).collect();
        let mut got = vec![0i8; n];
        super::nchw_to_nhwc(&src, h, w, c, &mut got);
        assert_eq!(got, reference_nhwc(&src, h, w, c));
    }
}