#![allow(clippy::manual_swap)]
use fearless_simd::prelude::*;
use fearless_simd::{f32x4, f32x8, f64x4, f64x8, Simd};
// Side length (in elements) of the square stage buffer used by the tiled
// ("cobravo") path. A full tile is TILE_SIDE² elements: 64·64·4 B = 16 KiB for
// f32 and 32·32·8 B = 8 KiB for f64.
// NOTE(review): presumably sized to keep a tile cache-resident — confirm with
// benchmarks on target hardware.
const TILE_SIDE_F32: usize = 64; const TILE_SIDE_F64: usize = 32;
// Minimum multiple of full tiles before the tiled path is preferred over
// running the buffer kernel directly on the whole slice (see the dispatcher
// generated by `impl_bit_rev_bravo!`).
const MIN_TILES: usize = 16;
/// Gathers one tile from `data` into the contiguous stage buffer `buf`.
///
/// `data` is treated as a `TILE_SIDE × strip_stride` grid of rows; row `u` of
/// the tile lives at `data[u * strip_stride + tile]`, i.e. `tile` selects a
/// column of that grid.
#[inline(always)]
fn stage_in<T: Copy, const TILE_SIDE: usize>(
    data: &[[T; TILE_SIDE]],
    buf: &mut [[T; TILE_SIDE]; TILE_SIDE],
    tile: usize,
) {
    let strip_stride = data.len() / TILE_SIDE;
    for (u, row) in buf.iter_mut().enumerate() {
        *row = data[u * strip_stride + tile];
    }
}
/// Scatters the stage buffer `buf` back into `data` at column `tile`.
///
/// Exact inverse of `stage_in` for the same `tile`: row `u` of the buffer is
/// written to `data[u * strip_stride + tile]`.
#[inline(always)]
fn stage_out<T: Copy, const TILE_SIDE: usize>(
    buf: &[[T; TILE_SIDE]; TILE_SIDE],
    data: &mut [[T; TILE_SIDE]],
    tile: usize,
) {
    let strip_stride = data.len() / TILE_SIDE;
    for (u, row) in buf.iter().enumerate() {
        data[u * strip_stride + tile] = *row;
    }
}
/// Exchanges the staged tile in `buf` with the tile at column `tile_rev` of
/// `data`: after the call, `buf` holds the partner tile's rows and `data`
/// holds what was previously staged in `buf`.
#[inline(always)]
fn stage_swap<T: Copy, const TILE_SIDE: usize>(
    data: &mut [[T; TILE_SIDE]],
    buf: &mut [[T; TILE_SIDE]; TILE_SIDE],
    tile_rev: usize,
) {
    // Rows of one tile are strided `strip_stride` apart in `data`, matching
    // the layout used by `stage_in`/`stage_out`.
    let strip_stride = data.len() / TILE_SIDE;
    for (u, row) in buf.iter_mut().enumerate() {
        // `mem::swap` replaces the manual temp-variable swap that required
        // the `clippy::manual_swap` allow.
        std::mem::swap(row, &mut data[u * strip_stride + tile_rev]);
    }
}
/// Generates three bit-reversal permutation routines for one element type /
/// SIMD vector width pair (names supplied by the caller):
///
/// * `$buf_fn_name` — in-place SIMD bit reversal of a `2^n`-element slice,
///   intended for data small enough to be cache-resident;
/// * `$cobravo_fn_name` — tiled variant that stages `TILE_SIDE × TILE_SIDE`
///   element tiles through a contiguous buffer and runs `$buf_fn_name` on
///   each staged tile (COBRA-style blocking — TODO confirm attribution);
/// * `$fn_name` — entry point that picks the scalar / buffer / tiled path by
///   problem size.
macro_rules! impl_bit_rev_bravo {
    ($fn_name:ident, $buf_fn_name:ident, $cobravo_fn_name:ident, $elem_ty:ty, $vec_ty:ty, $lanes:expr, $tile_side:expr) => {
        /// Buffer kernel: bit-reverses `data` (length `2^n`) in place.
        ///
        /// The slice is viewed as `LANES`-element chunks; a chunk position is
        /// decomposed as `class_idx + j * num_classes`, so a global element
        /// index splits (high → low) into the `LOG_W` bits of `j`, the
        /// `class_bits` of `class_idx`, and the `LOG_W` lane bits. The zip
        /// rounds handle the `j`/lane bits in registers, while storing at
        /// `class_idx_rev` reverses the middle bits.
        #[inline(always)]
        fn $buf_fn_name<S: Simd>(simd: S, data: &mut [$elem_ty], n: usize) {
            type Chunk<S> = $vec_ty;
            const LANES: usize = $lanes;
            // The vector type's lane count must agree with $lanes.
            assert!(<Chunk<S>>::N == LANES);
            let big_n = 1usize << n;
            assert_eq!(data.len(), big_n, "Data length must be 2^n");
            const LOG_W: usize = LANES.ilog2() as usize;
            // Number of distinct middle-bit values; assumes big_n >= LANES².
            let num_classes = big_n / (LANES * LANES);
            let class_bits = n - 2 * LOG_W;
            let (data_chunks, _) = data.as_chunks_mut::<LANES>();
            // Two register tiles: `chunks_a` for the current class,
            // `chunks_b` for its bit-reversed partner class.
            let mut chunks_a: [Chunk<S>; LANES] = [Chunk::splat(simd, Default::default()); LANES];
            let mut chunks_b: [Chunk<S>; LANES] = [Chunk::splat(simd, Default::default()); LANES];
            for class_idx in 0..num_classes {
                let class_idx_rev = if class_bits > 0 {
                    class_idx.reverse_bits() >> (usize::BITS - class_bits as u32)
                } else {
                    0
                };
                // Handle each {class, reversed-class} pair exactly once,
                // from the smaller index.
                if class_idx > class_idx_rev {
                    continue;
                }
                // Gather the LANES chunks that share this class index.
                for j in 0..LANES {
                    chunks_a[j] =
                        Chunk::from_slice(simd, &data_chunks[class_idx + j * num_classes]);
                }
                // LOG_W butterfly rounds of zip_low/zip_high; in effect this
                // bit-reverses the 2·LOG_W bits spanned by the chunk-row and
                // lane indices of the LANES×LANES register tile.
                for round in 0..LOG_W {
                    let stride = 1 << round;
                    let mut i = 0;
                    while i < LANES {
                        for offset in 0..stride {
                            let idx0 = i + offset;
                            let idx1 = i + offset + stride;
                            let vec0 = chunks_a[idx0];
                            let vec1 = chunks_a[idx1];
                            chunks_a[idx0] = vec0.zip_low(vec1);
                            chunks_a[idx1] = vec0.zip_high(vec1);
                        }
                        i += stride * 2;
                    }
                }
                if class_idx == class_idx_rev {
                    // Self-paired class: write the permuted tile back in place.
                    for j in 0..LANES {
                        chunks_a[j].store_slice(&mut data_chunks[class_idx + j * num_classes]);
                    }
                } else {
                    // Distinct partner class: permute its tile with the same
                    // butterfly rounds, then store the two tiles at each
                    // other's locations.
                    for j in 0..LANES {
                        chunks_b[j] =
                            Chunk::from_slice(simd, &data_chunks[class_idx_rev + j * num_classes]);
                    }
                    for round in 0..LOG_W {
                        let stride = 1 << round;
                        let mut i = 0;
                        while i < LANES {
                            for offset in 0..stride {
                                let idx0 = i + offset;
                                let idx1 = i + offset + stride;
                                let vec0 = chunks_b[idx0];
                                let vec1 = chunks_b[idx1];
                                chunks_b[idx0] = vec0.zip_low(vec1);
                                chunks_b[idx1] = vec0.zip_high(vec1);
                            }
                            i += stride * 2;
                        }
                    }
                    for j in 0..LANES {
                        chunks_a[j].store_slice(&mut data_chunks[class_idx_rev + j * num_classes]);
                        chunks_b[j].store_slice(&mut data_chunks[class_idx + j * num_classes]);
                    }
                }
            }
        }
        /// Tiled driver for large inputs: stages one `TILE_SIDE × TILE_SIDE`
        /// element tile at a time through `buf`, bit-reverses the buffer with
        /// `$buf_fn_name`, and writes it to the tile's bit-reversed
        /// destination — pairing `tile` with `tile_rev` the same way the
        /// buffer kernel pairs classes.
        #[inline(always)]
        fn $cobravo_fn_name<S: Simd>(simd: S, data: &mut [$elem_ty], n: usize) {
            const TILE_SIDE: usize = $tile_side;
            // The stage buffer holds 2^N_BUF elements; the remaining
            // `tile_bits` index bits select which tile.
            const N_BUF: usize = 2 * TILE_SIDE.ilog2() as usize; let tile_bits = n - N_BUF;
            let num_tiles = 1usize << tile_bits;
            let (data_tiles, _) = data.as_chunks_mut::<TILE_SIDE>();
            let mut buf = [[<$elem_ty>::default(); TILE_SIDE]; TILE_SIDE];
            for tile in 0..num_tiles {
                let tile_rev = reverse_bits_scalar(tile, tile_bits as u32);
                // Handle each {tile, reversed-tile} pair exactly once.
                if tile > tile_rev {
                    continue;
                }
                stage_in(data_tiles, &mut buf, tile);
                $buf_fn_name(simd, buf.as_flattened_mut(), N_BUF);
                if tile == tile_rev {
                    // Self-paired tile: write back in place.
                    stage_out(&buf, data_tiles, tile);
                } else {
                    // Swap with the partner tile, permute it too, then write
                    // it back to this tile's slot.
                    stage_swap(data_tiles, &mut buf, tile_rev);
                    $buf_fn_name(simd, buf.as_flattened_mut(), N_BUF);
                    stage_out(&buf, data_tiles, tile);
                }
            }
        }
        /// Entry point: bit-reverses `data` (length `2^n`) in place, choosing
        /// a strategy by size.
        #[inline(always)] fn $fn_name<S: Simd>(simd: S, data: &mut [$elem_ty], n: usize) {
            const LANES: usize = $lanes;
            let big_n = 1usize << n;
            assert_eq!(data.len(), big_n, "Data length must be 2^n");
            // Too small for the kernel's LANES×LANES register tile.
            if big_n < LANES * LANES {
                scalar_bit_reversal(data, n);
                return;
            }
            const TILE_SIDE: usize = $tile_side;
            // Small enough to run the buffer kernel on the whole slice.
            if big_n <= TILE_SIDE * TILE_SIDE * MIN_TILES {
                simd.vectorize(
                    #[inline(always)]
                    || $buf_fn_name(simd, data, n),
                );
                return;
            }
            simd.vectorize(
                #[inline(always)]
                || $cobravo_fn_name(simd, data, n),
            );
        }
    };
}
// Concrete instantiations: one (dispatch, buffer-kernel, tiled) triple per
// element type and vector width. `_chunk_4_*` names operate on 4-lane
// vectors, `_chunk_8_*` on 8-lane vectors; tile sides come from the
// TILE_SIDE_* constants.
impl_bit_rev_bravo!(
    bit_rev_bravo_chunk_4_f32,
    bravo_on_buf_chunk_4_f32,
    cobravo_chunk_4_f32,
    f32,
    f32x4<S>,
    4,
    TILE_SIDE_F32
);
impl_bit_rev_bravo!(
    bit_rev_bravo_chunk_8_f32,
    bravo_on_buf_chunk_8_f32,
    cobravo_chunk_8_f32,
    f32,
    f32x8<S>,
    8,
    TILE_SIDE_F32
);
impl_bit_rev_bravo!(
    bit_rev_bravo_chunk_4_f64,
    bravo_on_buf_chunk_4_f64,
    cobravo_chunk_4_f64,
    f64,
    f64x4<S>,
    4,
    TILE_SIDE_F64
);
impl_bit_rev_bravo!(
    bit_rev_bravo_chunk_8_f64,
    bravo_on_buf_chunk_8_f64,
    cobravo_chunk_8_f64,
    f64,
    f64x8<S>,
    8,
    TILE_SIDE_F64
);
/// Bit-reverses `data` (length `2^n`) in place, selecting the f32 kernel by
/// the SIMD level's native f32 vector width: 4-lane levels use the `f32x4`
/// kernel, anything wider the `f32x8` kernel.
#[inline(always)]
pub fn bit_rev_bravo_f32<S: Simd>(simd: S, data: &mut [f32], n: usize) {
    let native_lanes = <S::f32s>::N;
    if native_lanes == 4 {
        bit_rev_bravo_chunk_4_f32(simd, data, n)
    } else {
        bit_rev_bravo_chunk_8_f32(simd, data, n)
    }
}
/// Bit-reverses `data` (length `2^n`) in place, selecting the f64 kernel by
/// the SIMD level's native f64 vector width. NOTE(review): 2-lane native
/// vectors map to the `f64x4` kernel (2× native width) — presumably
/// intentional; confirm against benchmarks.
#[inline(always)]
pub fn bit_rev_bravo_f64<S: Simd>(simd: S, data: &mut [f64], n: usize) {
    let native_lanes = <S::f64s>::N;
    if native_lanes == 2 {
        bit_rev_bravo_chunk_4_f64(simd, data, n)
    } else {
        bit_rev_bravo_chunk_8_f64(simd, data, n)
    }
}
/// Baseline scalar bit-reversal permutation: swaps `data[i]` with
/// `data[reverse_bits_scalar(i, n)]`, visiting each pair exactly once.
/// Used by the generated dispatchers for inputs too small for the SIMD path.
///
/// Bounds relaxed from `Default + Copy + Clone` to `Copy`: `Default` was
/// never used, and `Clone` is implied by `Copy`. Backward-compatible for all
/// existing callers.
fn scalar_bit_reversal<T: Copy>(data: &mut [T], n: usize) {
    let big_n = data.len();
    for i in 0..big_n {
        let j = reverse_bits_scalar(i, n as u32);
        // Swap each (i, j) pair once; fixed points (i == j) need no work.
        if i < j {
            data.swap(i, j);
        }
    }
}
/// Reverses the low `bits` bits of `x`; bits above `bits` are discarded.
///
/// The `bits == 0` case is handled separately because the general path would
/// shift by `usize::BITS`, which is an invalid shift amount.
const fn reverse_bits_scalar(x: usize, bits: u32) -> usize {
    match bits {
        0 => 0,
        b => x.reverse_bits() >> (usize::BITS - b),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use fearless_simd::{dispatch, Level};

    /// Reference implementation: recursively deinterleaves even/odd
    /// positions (decimation in time), which places element `i` at the
    /// bit-reversal of `i` for power-of-two lengths.
    fn top_down_bit_reverse_permutation<T: Copy + Clone>(x: &[T]) -> Vec<T> {
        if x.len() == 1 {
            return x.to_vec();
        }
        let half = x.len() >> 1;
        let mut evens = Vec::with_capacity(half);
        let mut odds = Vec::with_capacity(half);
        for pair in x.chunks_exact(2) {
            evens.push(pair[0]);
            odds.push(pair[1]);
        }
        let mut result = top_down_bit_reverse_permutation(&evens);
        result.extend_from_slice(&top_down_bit_reverse_permutation(&odds));
        result
    }

    #[test]
    fn test_bravo_bit_reversal_f64() {
        // Cover every dispatcher path: scalar, buffer-only, and tiled.
        for n in 2..24 {
            let big_n = 1 << n;
            let input: Vec<f64> = (0..big_n).map(f64::from).collect();
            let expected = top_down_bit_reverse_permutation(&input);
            let mut actual_re = input.clone();
            let mut actual_im = input.clone();
            let simd_level = Level::new();
            dispatch!(simd_level, simd => bit_rev_bravo_f64(simd, &mut actual_re, n));
            dispatch!(simd_level, simd => bit_rev_bravo_f64(simd, &mut actual_im, n));
            assert_eq!(actual_re, expected);
            assert_eq!(actual_im, expected);
        }
    }

    #[test]
    fn test_bravo_bit_reversal_f32() {
        // Cover every dispatcher path: scalar, buffer-only, and tiled.
        for n in 2..24 {
            let big_n = 1 << n;
            let input: Vec<f32> = (0..big_n).map(|i| i as f32).collect();
            let expected = top_down_bit_reverse_permutation(&input);
            let mut actual_re = input.clone();
            let mut actual_im = input.clone();
            let simd_level = Level::new();
            dispatch!(simd_level, simd => bit_rev_bravo_f32(simd, &mut actual_re, n));
            dispatch!(simd_level, simd => bit_rev_bravo_f32(simd, &mut actual_im, n));
            assert_eq!(actual_re, expected);
            assert_eq!(actual_im, expected);
        }
    }
}