use crate::{pack, supported_bit_width, unpack, BitPacking, FastLanes};
use paste::paste;
/// Frame-of-reference (FoR) encoding combined with packing, operating on
/// fixed blocks of 1024 elements.
///
/// Both methods are generic over two const parameters:
/// - `W`: the width argument forwarded to the packing kernels (presumably the
///   number of bits kept per packed element — confirm against `BitPacking`).
/// - `B`: the length, in elements of `Self`, of the packed buffer. The
///   implementations in this file require `B == 1024 * W / Self::T` at compile
///   time (`Self::T` is presumably the bit width of `Self` — TODO confirm).
pub trait FoR: BitPacking {
/// Encodes `input` relative to `reference` and packs the result into `output`.
///
/// The implementations in this file compute `input[i].wrapping_sub(reference)`
/// for each element and hand that value to the crate's `pack!` kernel.
fn for_pack<const W: usize, const B: usize>(
input: &[Self; 1024],
reference: Self,
output: &mut [Self; B],
);
/// Unpacks `input` and adds `reference` back, writing 1024 elements to `output`.
///
/// The implementations in this file store `elem.wrapping_add(reference)` for
/// each element produced by the crate's `unpack!` kernel.
fn unfor_pack<const W: usize, const B: usize>(
input: &[Self; B],
reference: Self,
output: &mut [Self; 1024],
);
}
// Generates `impl FoR for $T` for a single integer type `$T`.
//
// Both generated methods iterate over `0..Self::LANES` and delegate the actual
// packing/unpacking to the crate's `pack!` / `unpack!` macros, supplying a
// closure-like body that applies the frame-of-reference arithmetic.
macro_rules! impl_for {
($T:ty) => {
impl FoR for $T {
// Subtracts `reference` from each input element (wrapping) and packs it.
fn for_pack<const W: usize, const B: usize>(
input: &[Self; 1024],
reference: Self,
output: &mut [Self; B],
) {
// Compile-time validation of the const parameters: `W` must be accepted
// by `supported_bit_width` for this type's size in bits, and `B` must
// equal `1024 * W / Self::T`.
const {
assert!(supported_bit_width(W, 8 * core::mem::size_of::<$T>()));
assert!(B == 1024 * W / Self::T);
}
for lane in 0..Self::LANES {
// `$idx` is not a metavariable of `impl_for!` (only `$T` is bound), so
// it is handed to `pack!` as-is; `pack!` presumably binds it to the
// input element index — NOTE(review): confirm against `pack!`.
pack!($T, W, output, lane, |$idx| {
input[$idx].wrapping_sub(reference)
});
}
}
// Unpacks each element and adds `reference` back (wrapping).
fn unfor_pack<const W: usize, const B: usize>(
input: &[Self; B],
reference: Self,
output: &mut [Self; 1024],
) {
// Same compile-time checks on `W` and `B` as in `for_pack`.
const {
assert!(supported_bit_width(W, 8 * core::mem::size_of::<$T>()));
assert!(B == 1024 * W / Self::T);
}
for lane in 0..Self::LANES {
// As above, `$idx` and `$elem` are forwarded to `unpack!`, which
// presumably binds them to the output index and the unpacked value —
// NOTE(review): confirm against `unpack!`.
unpack!($T, W, input, lane, |$idx, $elem| {
output[$idx] = $elem.wrapping_add(reference)
});
}
}
}
};
}
// Instantiate `FoR` for the unsigned integer types `u8`, `u16`, `u32` and `u64`.
impl_for!(u8);
impl_for!(u16);
impl_for!(u32);
impl_for!(u64);
#[cfg(test)]
mod test {
    use super::*;
    use core::mem::size_of;

    /// Packs 1024 `u16` values with `for_pack` (reference 10, width 15), decodes
    /// them with plain `BitPacking::unpack`, and checks that every decoded value
    /// equals the low 15 bits of `value - 10` (wrapping).
    #[test]
    fn test_ffor() {
        const W: usize = 15;
        const B: usize = 1024 * W / u16::T;
        const REFERENCE: u16 = 10;
        // Keeps only the low `W` bits of a value.
        const MASK: u16 = (1 << W) - 1;

        // Input block: element `i` holds `i` reduced modulo `2^W`.
        let values: [u16; 1024] = core::array::from_fn(|i| (i % (1 << W)) as u16);

        let mut packed = [0u16; 128 * W / size_of::<u16>()];
        FoR::for_pack::<W, B>(&values, REFERENCE, &mut packed);

        // Decode without re-adding the reference so the raw deltas are visible.
        let mut unpacked = [0u16; 1024];
        BitPacking::unpack::<W, B>(&packed, &mut unpacked);

        for (i, (original, decoded)) in values.iter().zip(unpacked.iter()).enumerate() {
            let expected = original.wrapping_sub(REFERENCE) & MASK;
            assert_eq!(expected, *decoded, "Mismatch at index {i}");
        }
    }
}