use std::borrow::Cow;
use std::mem::transmute;
use std::slice::ChunksExact;
use super::{SizeMismatch, UnalignedVector, UnalignedVectorCodec};
/// Unsigned integer word the codec packs bits into; each bit stores one vector dimension.
type BitPackedWord = u64;
/// Number of bits (i.e. packed dimensions) held by one `BitPackedWord`.
const PACKED_WORD_BITS: usize = BitPackedWord::BITS as usize;
/// Number of bytes occupied by one `BitPackedWord`.
const PACKED_WORD_BYTES: usize = std::mem::size_of::<BitPackedWord>();
/// Marker type selecting the binary (one bit per dimension) vector codec.
///
/// The enum has no variants, so it can never be instantiated; it is only used at the
/// type level, e.g. as the parameter of `UnalignedVector<Binary>`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Binary {}
impl UnalignedVectorCodec for Binary {
    /// Reinterprets `bytes` as a borrowed bit-packed vector without copying.
    ///
    /// # Errors
    ///
    /// Returns [`SizeMismatch`] when `bytes.len()` is not a multiple of
    /// `PACKED_WORD_BYTES`; `rem` carries the number of trailing bytes.
    fn from_bytes(bytes: &[u8]) -> Result<Cow<'_, UnalignedVector<Self>>, SizeMismatch> {
        let rem = bytes.len() % PACKED_WORD_BYTES;
        if rem == 0 {
            // SAFETY: NOTE(review) — this relies on `UnalignedVector<Self>` having the
            // same layout as `[u8]` (presumably `#[repr(transparent)]`); the type is
            // declared in the parent module, so confirm there.
            Ok(Cow::Borrowed(unsafe { transmute::<&[u8], &UnalignedVector<Self>>(bytes) }))
        } else {
            Err(SizeMismatch { vector_codec: "binary", rem })
        }
    }

    /// Packs `slice` into an owned bit vector, one bit per scalar.
    ///
    /// Uses the NEON kernel when it is compiled in and detected at runtime, and the
    /// portable implementation otherwise.
    fn from_slice(slice: &[f32]) -> Cow<'static, UnalignedVector<Self>> {
        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                // SAFETY: NEON availability was just confirmed at runtime.
                return Cow::Owned(unsafe { from_slice_neon(slice) });
            }
        }
        Cow::Owned(from_slice_non_optimized(slice))
    }

    /// Packs an owned `Vec<f32>`; equivalent to [`Self::from_slice`] on its contents.
    fn from_vec(vec: Vec<f32>) -> Cow<'static, UnalignedVector<Self>> {
        // `from_slice` already returns an owned, `'static` value, so there is no need
        // to unwrap and re-wrap it.
        Self::from_slice(&vec)
    }

    /// Expands the packed bits back into one `f32` (`0.0` or `1.0`) per bit.
    ///
    /// Dispatches to a SIMD kernel when one is available for the running CPU and
    /// falls back to the portable implementation otherwise.
    fn to_vec(vec: &UnalignedVector<Self>) -> Vec<f32> {
        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                // SAFETY: NEON availability was just confirmed at runtime.
                return unsafe { to_vec_neon(vec) };
            }
        }
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            // `sse` alone only guarantees the original single-precision float
            // instructions. SSE4.1 implies SSE2 and everything in between, so it is a
            // sufficient gate for any intrinsic up to that level (integer compares,
            // blends); CPUs without it use the portable fallback instead of risking
            // an illegal instruction.
            if is_x86_feature_detected!("sse4.1") {
                // SAFETY: the required instruction sets were just confirmed at runtime.
                return unsafe { to_vec_sse(vec) };
            }
        }
        to_vec_non_optimized(vec)
    }

    /// Returns an iterator yielding every packed bit as an `f32`, word by word,
    /// starting from the least significant bit of each word.
    fn iter(vec: &UnalignedVector<Self>) -> impl ExactSizeIterator<Item = f32> + '_ {
        BinaryIterator {
            current_element: 0,
            // Starting "exhausted" forces the first `next()` call to load a word.
            current_iteration: PACKED_WORD_BITS,
            iter: vec.vector.chunks_exact(PACKED_WORD_BYTES),
        }
    }

    /// Number of dimensions stored: whole words in the buffer times bits per word.
    ///
    /// Because storage is word-granular, this counts every bit of every word,
    /// padding bits included.
    fn len(vec: &UnalignedVector<Self>) -> usize {
        (vec.vector.len() / PACKED_WORD_BYTES) * PACKED_WORD_BITS
    }

    /// Returns `true` when every byte of the vector is zero.
    fn is_zero(vec: &UnalignedVector<Self>) -> bool {
        vec.as_bytes().iter().all(|b| *b == 0)
    }

    /// Number of dimensions packed into a single storage word.
    fn word_size() -> usize {
        PACKED_WORD_BITS
    }
}
pub(super) fn from_slice_non_optimized(slice: &[f32]) -> Vec<u8> {
let mut output = Vec::with_capacity(slice.len().div_ceil(PACKED_WORD_BITS));
for chunk in slice.chunks(PACKED_WORD_BITS) {
let mut word: BitPackedWord = 0;
let mut bits: u32;
for scalar in chunk.iter().rev() {
word <<= 1;
bits = scalar.to_bits();
word += (bits < 0x8000_0000 && bits > 0x0000_0000) as BitPackedWord;
}
output.extend_from_slice(&word.to_ne_bytes());
}
output
}
/// NEON implementation of the f32 → bit-packed conversion.
///
/// Produces one output byte per 8 input scalars (a set bit for every scalar the vector
/// compare reports as greater than zero) and pads the output with zero bytes up to a
/// multiple of `PACKED_WORD_BYTES`.
///
/// # Safety
///
/// Uses NEON intrinsics; the caller must make sure NEON is available on the running CPU.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
unsafe fn from_slice_neon(slice: &[f32]) -> Vec<u8> {
use core::arch::aarch64::*;
// `PACKED_WORD_BYTES` (8) is used below as "scalars per output byte" (8 bits per byte);
// it only coincidentally equals the word size in bytes.
let iterations = slice.len() / PACKED_WORD_BYTES;
let remaining = slice.len() % PACKED_WORD_BYTES;
// Output length in bytes: the full bytes, plus room for a partial byte, rounded up to
// a whole number of words.
let mut len = iterations;
if !len.is_multiple_of(PACKED_WORD_BYTES) {
len += PACKED_WORD_BYTES - len % PACKED_WORD_BYTES;
} else if remaining != 0 {
// The full bytes already fill whole words, so the partial byte needs a fresh word.
len += PACKED_WORD_BYTES;
}
let mut ret = vec![0; len];
let ptr = slice.as_ptr();
// Per-lane bit weights: lane k of `low` contributes bit k, lane k of `high` bit k + 4.
let low: [u32; 4] = [
0b_00000000_00000000_00000000_00000001,
0b_00000000_00000000_00000000_00000010,
0b_00000000_00000000_00000000_00000100,
0b_00000000_00000000_00000000_00001000,
];
let high: [u32; 4] = [
0b_00000000_00000000_00000000_00010000,
0b_00000000_00000000_00000000_00100000,
0b_00000000_00000000_00000000_01000000,
0b_00000000_00000000_00000000_10000000,
];
for i in 0..iterations {
unsafe {
let mut byte = 0;
for (idx, mask) in [low, high].iter().enumerate() {
// Load 4 scalars; `i * 8 + 4 * idx + 3 < slice.len()` because `i < iterations`.
let lane = vld1q_f32(ptr.add(i * 8 + 4 * idx));
// Lanes greater than zero become all-ones, the others all-zeros.
let lane = vcgtzq_f32(lane);
let mask = vld1q_u32(mask.as_ptr());
// Keep only each lane's own bit weight, then sum the lanes into one nibble.
let lane = vandq_u32(lane, mask);
byte |= vaddvq_u32(lane) as u8;
}
// `i < iterations <= ret.len()`, so the unchecked write stays in bounds.
*ret.get_unchecked_mut(i) = byte;
}
}
if remaining != 0 {
// Scalar tail: pack the last `remaining` (< 8) scalars into one more byte, first
// scalar in bit 0.
// NOTE(review): this bit-pattern test also accepts positive NaN, whereas the vector
// compare above presumably rejects NaN — confirm whether the difference matters.
let mut rem: u8 = 0;
let mut bits: u32;
for r in slice[slice.len() - remaining..].iter().rev() {
rem <<= 1;
bits = r.to_bits();
let r = bits < 0x8000_0000 && bits > 0x0000_0000;
rem |= r as u8;
}
ret[iterations] = rem;
}
// NOTE(review): bytes are written in increasing significance, which matches the
// native-endian words of `from_slice_non_optimized` only on little-endian targets — confirm.
ret
}
/// Portable bit-packed → f32 conversion.
///
/// Collects every scalar yielded by the vector's iterator into a freshly allocated `Vec`.
pub(super) fn to_vec_non_optimized(vec: &UnalignedVector<Binary>) -> Vec<f32> {
    let scalars = vec.iter();
    scalars.collect::<Vec<f32>>()
}
/// NEON implementation of the bit-packed → f32 conversion.
///
/// Expands every byte of the vector into 8 consecutive `f32` values, `1.0` for a set bit
/// and `0.0` for a clear bit, starting from the least significant bit.
///
/// # Safety
///
/// Uses NEON intrinsics; the caller must make sure NEON is available on the running CPU.
/// The raw stores assume `vec.len()` is at least 8 × the number of bytes
/// (NOTE(review): confirm against `UnalignedVector::len`).
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
unsafe fn to_vec_neon(vec: &UnalignedVector<Binary>) -> Vec<f32> {
use core::arch::aarch64::*;
let mut output: Vec<f32> = vec![0.0; vec.len()];
let output_ptr = output.as_mut_ptr();
let bytes = vec.as_bytes();
// Per-lane bit selectors: `low_mask` tests bits 0..=3 of a byte, `high_mask` bits 4..=7.
let low_mask = [0b_0000_0001, 0b_0000_0010, 0b_0000_0100, 0b_0000_1000];
let high_mask = [0b_0001_0000, 0b_0010_0000, 0b_0100_0000, 0b_1000_0000];
// Constant vectors the select below picks from.
let ones = unsafe { vld1q_dup_f32(&1.0) };
let zeros = unsafe { vld1q_dup_f32(&0.0) };
for (current_byte, base) in bytes.iter().enumerate() {
unsafe {
// Broadcast the byte to all four 32-bit lanes.
let base = *base as u32;
let base = vld1q_dup_u32(&base);
for (i, mask) in [low_mask, high_mask].iter().enumerate() {
let mask = vld1q_u32(mask.as_ptr());
// Isolate one bit per lane ...
let mask = vandq_u32(base, mask);
// ... and widen it to an all-ones / all-zeros lane mask.
let mask = vcgtq_u32(mask, vdupq_n_u32(0));
// Select `1.0` where the bit was set, `0.0` elsewhere.
let lane = vbslq_f32(mask, ones, zeros);
// Each byte owns 8 output slots; the low nibble fills the first 4, the high nibble the rest.
let offset = output_ptr.add(current_byte * 8 + i * 4);
vst1q_f32(offset, lane);
}
}
}
output
}
/// SSE implementation of the bit-packed → f32 conversion.
///
/// Expands every byte of the vector into 8 consecutive `f32` values, `1.0` for a set bit
/// and `0.0` for a clear bit, starting from the least significant bit.
///
/// Only SSE/SSE2 instructions are used, and all stores are unaligned, so no particular
/// alignment of the output allocation is required.
///
/// # Safety
///
/// The caller must make sure the running CPU supports SSE2 (always true on `x86_64`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn to_vec_sse(vec: &UnalignedVector<Binary>) -> Vec<f32> {
    // `core::arch::x86_64` does not exist on 32-bit x86, so import per architecture.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut output: Vec<f32> = vec![0.0; vec.len()];
    let bytes = vec.as_bytes();
    // Every byte is expanded into 8 floats through raw pointer stores; make sure the
    // destination really has room for them.
    debug_assert!(output.len() >= bytes.len() * 8);
    let output_ptr = output.as_mut_ptr();

    let ones = unsafe { _mm_set1_ps(1.0) };
    // Per-lane bit selectors, built once: the first vector tests bits 0..=3 of a byte,
    // the second bits 4..=7. `_mm_set_epi32` takes its arguments from the highest lane
    // down to the lowest.
    let nibble_masks = unsafe {
        [
            _mm_set_epi32(0b_0000_1000, 0b_0000_0100, 0b_0000_0010, 0b_0000_0001),
            _mm_set_epi32(0b_1000_0000, 0b_0100_0000, 0b_0010_0000, 0b_0001_0000),
        ]
    };

    for (current_byte, base) in bytes.iter().enumerate() {
        unsafe {
            // Broadcast the byte to all four 32-bit lanes.
            let base = _mm_set1_epi32(*base as i32);
            for (i, selector) in nibble_masks.iter().enumerate() {
                let bit = _mm_and_si128(base, *selector);
                // All-ones in the lanes whose bit is clear, all-zeros where it is set.
                let is_clear = _mm_cmpeq_epi32(bit, _mm_setzero_si128());
                // `!is_clear & ones`: 1.0 where the bit is set, +0.0 elsewhere. This is
                // the SSE-level equivalent of an SSE4.1 blend between 1.0 and 0.0.
                let lane = _mm_andnot_ps(_mm_castsi128_ps(is_clear), ones);
                let offset = output_ptr.add(current_byte * 8 + i * 4);
                // A `Vec<f32>` is only guaranteed to be 4-byte aligned, so the aligned
                // store (`_mm_store_ps`) would be undefined behaviour here.
                _mm_storeu_ps(offset, lane);
            }
        }
    }
    output
}
/// Iterator that yields each packed bit of a binary vector as an `f32` (`0.0` or `1.0`).
pub struct BinaryIterator<'a> {
/// Word currently being consumed; bits that were already yielded have been shifted out.
current_element: BitPackedWord,
/// Number of bits already taken from `current_element`; once it reaches
/// `PACKED_WORD_BITS` the next word is loaded from `iter`.
current_iteration: usize,
/// Bytes not yet loaded, split into word-sized chunks.
iter: ChunksExact<'a, u8>,
}
impl Iterator for BinaryIterator<'_> {
    type Item = f32;

    /// Yields the next bit as `0.0` or `1.0`, least significant bit of each word first.
    ///
    /// Returns `None` once the current word is exhausted and no further word is left.
    fn next(&mut self) -> Option<Self::Item> {
        let word_exhausted = self.current_iteration >= PACKED_WORD_BITS;
        if word_exhausted {
            // Refill from the next word-sized chunk, or stop when none is left.
            let raw = self.iter.next()?;
            self.current_iteration = 0;
            self.current_element = BitPackedWord::from_ne_bytes(raw.try_into().unwrap());
        }
        let lowest_bit = self.current_element & 1;
        self.current_iteration += 1;
        self.current_element >>= 1;
        Some(lowest_bit as f32)
    }

    /// Exact number of bits left: the unread part of the current word plus every bit
    /// of the words not loaded yet.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending_bits = PACKED_WORD_BITS - self.current_iteration;
        let total_bits = |words: usize| words * PACKED_WORD_BITS + pending_bits;
        let (min_words, max_words) = self.iter.size_hint();
        (total_bits(min_words), max_words.map(total_bits))
    }
}
impl ExactSizeIterator for BinaryIterator<'_> {
    /// Number of bits still to be yielded, taken from the (exact) size hint.
    fn len(&self) -> usize {
        let hint = self.size_hint();
        // The hint is expected to be exact: both bounds must agree.
        debug_assert_eq!(hint.1, Some(hint.0));
        hint.0
    }
}