use rand::Rng;
pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod rvv;
pub mod sve;
/// Packs `bits` into a `crate::bitstream::BitStreamTensor` using the fastest
/// packing kernel available for the current target.
///
/// Dispatch order:
/// - `x86_64`: AVX-512BW, then AVX2 (both detected at runtime), otherwise the
///   portable path.
/// - `aarch64` built with `target_feature = "sve"`: the SVE kernel (selected at
///   compile time; there is no portable fallback on such builds).
/// - every other target: `crate::bitstream::pack_fast`.
///
/// On the SIMD paths the returned tensor's `length` is `bits.len()`.
pub fn pack_dispatch(bits: &[u8]) -> crate::bitstream::BitStreamTensor {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512bw") {
            // SAFETY: AVX-512BW support was confirmed at runtime just above.
            // NOTE(review): assumes CPU support is the kernel's only precondition —
            // confirm against `avx512::pack_avx512`'s `# Safety` docs.
            let data = unsafe { avx512::pack_avx512(bits) };
            return crate::bitstream::BitStreamTensor {
                data,
                length: bits.len(),
            };
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed at runtime just above (same
            // precondition assumption as the AVX-512 arm).
            let data = unsafe { avx2::pack_avx2(bits) };
            return crate::bitstream::BitStreamTensor {
                data,
                length: bits.len(),
            };
        }
    }
    // The two arms below are mutually exclusive. Exactly one survives cfg-stripping
    // and becomes the function's tail expression, so no build contains code after
    // an unconditional `return` (no `unreachable_code` warning). Computing the
    // length inline avoids an unused binding on targets that never reach a SIMD arm.
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: this arm is compiled only when `target_feature = "sve"` is
        // enabled for the build. NOTE(review): assumes that is the kernel's only
        // precondition — confirm against `sve::pack_sve`'s `# Safety` docs.
        let data = unsafe { sve::pack_sve(bits) };
        crate::bitstream::BitStreamTensor {
            data,
            length: bits.len(),
        }
    }
    #[cfg(not(all(target_arch = "aarch64", target_feature = "sve")))]
    {
        crate::bitstream::pack_fast(bits)
    }
}
/// Returns the total number of set bits across all words of `data`, using the
/// fastest popcount kernel available for the current target.
///
/// Dispatch order:
/// - `x86_64`: AVX-512 VPOPCNTDQ, then AVX2 (both detected at runtime),
///   otherwise the portable path.
/// - `aarch64`: SVE when built with `target_feature = "sve"`, otherwise NEON
///   (both selected at compile time).
/// - `riscv64` built with `target_feature = "v"`: RVV.
/// - every other target: `crate::bitstream::popcount_words_portable`.
pub fn popcount_dispatch(data: &[u64]) -> u64 {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512vpopcntdq") {
            // SAFETY: AVX-512 VPOPCNTDQ support was confirmed at runtime just above.
            // NOTE(review): assumes CPU support is the kernel's only precondition —
            // confirm against `avx512::popcount_avx512`'s `# Safety` docs.
            return unsafe { avx512::popcount_avx512(data) };
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed at runtime just above.
            return unsafe { avx2::popcount_avx2(data) };
        }
    }
    // The arms below are mutually exclusive compile-time selections. Exactly one
    // survives cfg-stripping and becomes the tail expression, so no target compiles
    // a fallback placed after an unconditional `return`.
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: compiled only when `target_feature = "sve"` is enabled for the build.
        unsafe { sve::popcount_sve(data) }
    }
    #[cfg(all(target_arch = "aarch64", not(target_feature = "sve")))]
    {
        // SAFETY: presumably relies on NEON being part of the aarch64 baseline —
        // NOTE(review): confirm against `neon::popcount_neon`'s `# Safety` docs.
        unsafe { neon::popcount_neon(data) }
    }
    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        // SAFETY: compiled only when `target_feature = "v"` is enabled for the build.
        unsafe { rvv::popcount_rvv(data) }
    }
    #[cfg(not(any(
        target_arch = "aarch64",
        all(target_arch = "riscv64", target_feature = "v")
    )))]
    {
        crate::bitstream::popcount_words_portable(data)
    }
}
/// Returns the sum over `i` of the set-bit count of `a[i] & b[i]`, using the
/// fastest kernel available for the current target.
///
/// If `a` and `b` differ in length, only their common prefix
/// (`min(a.len(), b.len())` words) is processed. The extra words are ignored.
///
/// Dispatch order:
/// - `x86_64`: AVX-512 VPOPCNTDQ, then AVX2 (both detected at runtime),
///   otherwise the portable path.
/// - `aarch64` built with `target_feature = "sve"`: SVE.
/// - `riscv64` built with `target_feature = "v"`: RVV.
/// - every other target, including aarch64 without SVE: a portable scalar loop.
///
/// NOTE(review): unlike `popcount_dispatch`, aarch64 without SVE gets no NEON
/// kernel here. Confirm whether the `neon` module should provide one.
pub fn fused_and_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
    // Truncate both inputs to the common length so every kernel sees equal lengths.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512vpopcntdq") {
            // SAFETY: AVX-512 VPOPCNTDQ support was confirmed at runtime just above,
            // and `a`/`b` have equal lengths. NOTE(review): confirm these are the
            // kernel's only preconditions.
            return unsafe { avx512::fused_and_popcount_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed at runtime just above; equal lengths.
            return unsafe { avx2::fused_and_popcount_avx2(a, b) };
        }
    }
    // Mutually exclusive compile-time arms. Exactly one becomes the tail expression,
    // so the portable loop is never compiled after an unconditional `return`.
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: compiled only when `target_feature = "sve"` is enabled; equal lengths.
        unsafe { sve::fused_and_popcount_sve(a, b) }
    }
    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        // SAFETY: compiled only when `target_feature = "v"` is enabled; equal lengths.
        unsafe { rvv::fused_and_popcount_rvv(a, b) }
    }
    #[cfg(not(any(
        all(target_arch = "aarch64", target_feature = "sve"),
        all(target_arch = "riscv64", target_feature = "v")
    )))]
    {
        a.iter()
            .zip(b)
            .map(|(&wa, &wb)| u64::from((wa & wb).count_ones()))
            .sum()
    }
}
/// Returns the sum over `i` of the set-bit count of `a[i] ^ b[i]`, using the
/// fastest kernel available for the current target.
///
/// If `a` and `b` differ in length, only their common prefix
/// (`min(a.len(), b.len())` words) is processed. The extra words are ignored.
///
/// Dispatch order:
/// - `x86_64`: AVX-512 VPOPCNTDQ, then AVX2 (both detected at runtime),
///   otherwise the portable path.
/// - `aarch64` built with `target_feature = "sve"`: SVE.
/// - `riscv64` built with `target_feature = "v"`: RVV.
/// - every other target, including aarch64 without SVE: a portable scalar loop.
///
/// NOTE(review): unlike `popcount_dispatch`, aarch64 without SVE gets no NEON
/// kernel here. Confirm whether the `neon` module should provide one.
pub fn fused_xor_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
    // Truncate both inputs to the common length so every kernel sees equal lengths.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512vpopcntdq") {
            // SAFETY: AVX-512 VPOPCNTDQ support was confirmed at runtime just above,
            // and `a`/`b` have equal lengths. NOTE(review): confirm these are the
            // kernel's only preconditions.
            return unsafe { avx512::fused_xor_popcount_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was confirmed at runtime just above; equal lengths.
            return unsafe { avx2::fused_xor_popcount_avx2(a, b) };
        }
    }
    // Mutually exclusive compile-time arms. Exactly one becomes the tail expression,
    // so the portable loop is never compiled after an unconditional `return`.
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: compiled only when `target_feature = "sve"` is enabled; equal lengths.
        unsafe { sve::fused_xor_popcount_sve(a, b) }
    }
    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        // SAFETY: compiled only when `target_feature = "v"` is enabled; equal lengths.
        unsafe { rvv::fused_xor_popcount_rvv(a, b) }
    }
    #[cfg(not(any(
        all(target_arch = "aarch64", target_feature = "sve"),
        all(target_arch = "riscv64", target_feature = "v")
    )))]
    {
        a.iter()
            .zip(b)
            .map(|(&wa, &wb)| u64::from((wa ^ wb).count_ones()))
            .sum()
    }
}
pub fn encode_and_popcount_dispatch<R: Rng + ?Sized>(
weight_words: &[u64],
prob: f64,
length: usize,
rng: &mut R,
) -> u64 {
crate::bitstream::encode_and_popcount(weight_words, prob, length, rng)
}