use rand::Rng;
pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod rvv;
pub mod sve;
/// Packs `bits` into a [`crate::bitstream::BitStreamTensor`] using the widest packer available.
///
/// Dispatch order:
/// - x86_64: AVX-512BW, then AVX2 (both detected at runtime).
/// - aarch64 built with the `sve` target feature: SVE.
/// - anything else: [`crate::bitstream::pack_fast`].
///
/// On the SIMD paths the returned tensor's `length` is `bits.len()`.
// NOTE(review): presumably each element of `bits` is expected to be 0 or 1 — confirm against
// the kernels.
pub fn pack_dispatch(bits: &[u8]) -> crate::bitstream::BitStreamTensor {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel below is called only after runtime detection of the feature it is
        // named for; assumes the kernels need nothing beyond that (and the features it implies)
        // — confirm in avx512.rs / avx2.rs.
        if is_x86_feature_detected!("avx512bw") {
            let data = unsafe { avx512::pack_avx512(bits) };
            return crate::bitstream::BitStreamTensor {
                data,
                length: bits.len(),
            };
        }
        if is_x86_feature_detected!("avx2") {
            let data = unsafe { avx2::pack_avx2(bits) };
            return crate::bitstream::BitStreamTensor {
                data,
                length: bits.len(),
            };
        }
    }
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: this arm is compiled only when SVE is enabled at build time.
        let data = unsafe { sve::pack_sve(bits) };
        return crate::bitstream::BitStreamTensor {
            data,
            length: bits.len(),
        };
    }
    // Portable fallback. It is unreachable on SVE builds, so the lint is silenced the same way
    // as in the f64 dispatchers. `length` is computed per arm so that targets with no SIMD arm
    // do not get an unused-variable warning.
    #[allow(unreachable_code)]
    {
        crate::bitstream::pack_fast(bits)
    }
}
/// Returns the total number of set bits across all words of `data`.
///
/// Backends: AVX-512 VPOPCNTDQ, then AVX2 on x86_64 (runtime detection); SVE on aarch64 builds
/// with the `sve` target feature, otherwise NEON on aarch64; RVV on riscv64 builds with the `v`
/// target feature; [`crate::bitstream::popcount_words_portable`] everywhere else.
pub fn popcount_dispatch(data: &[u64]) -> u64 {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel below is called only after runtime detection of the feature it is
        // named for; assumes the kernels need nothing beyond that (and the features it implies)
        // — confirm in avx512.rs / avx2.rs.
        if is_x86_feature_detected!("avx512vpopcntdq") {
            return unsafe { avx512::popcount_avx512(data) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::popcount_avx2(data) };
        }
    }
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: this arm is compiled only when SVE is enabled at build time.
        return unsafe { sve::popcount_sve(data) };
    }
    #[cfg(all(target_arch = "aarch64", not(target_feature = "sve")))]
    {
        // SAFETY: NEON is part of the aarch64 baseline; assumes the kernel needs nothing more.
        return unsafe { neon::popcount_neon(data) };
    }
    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        // SAFETY: this arm is compiled only when the V extension is enabled at build time.
        return unsafe { rvv::popcount_rvv(data) };
    }
    // Portable fallback. It is unreachable on every aarch64 build and on riscv64+v builds,
    // which previously triggered an `unreachable_code` warning.
    #[allow(unreachable_code)]
    {
        crate::bitstream::popcount_words_portable(data)
    }
}
/// Returns `Σ popcount(a[i] & b[i])`, i.e. the number of bit positions set in both packed
/// vectors.
///
/// Both slices are truncated to `min(a.len(), b.len())` before any kernel sees them, so trailing
/// words of the longer slice are ignored.
///
/// Backends: AVX-512 VPOPCNTDQ, then AVX2 on x86_64 (runtime detection); SVE on aarch64 builds
/// with the `sve` target feature; RVV on riscv64 builds with the `v` target feature; otherwise a
/// scalar loop. This dispatcher has no NEON arm, so aarch64 builds without SVE use the scalar
/// loop.
pub fn fused_and_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel below is called only after runtime detection of the feature it is
        // named for, with equal-length slices; assumes the kernels need nothing beyond that —
        // confirm in avx512.rs / avx2.rs.
        if is_x86_feature_detected!("avx512vpopcntdq") {
            return unsafe { avx512::fused_and_popcount_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::fused_and_popcount_avx2(a, b) };
        }
    }
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: this arm is compiled only when SVE is enabled at build time.
        return unsafe { sve::fused_and_popcount_sve(a, b) };
    }
    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        // SAFETY: this arm is compiled only when the V extension is enabled at build time.
        return unsafe { rvv::fused_and_popcount_rvv(a, b) };
    }
    // Scalar fallback. It is unreachable on SVE and RVV builds, hence the allow.
    #[allow(unreachable_code)]
    {
        a.iter()
            .zip(b)
            .map(|(&wa, &wb)| u64::from((wa & wb).count_ones()))
            .sum()
    }
}
/// Returns `Σ popcount(a[i] ^ b[i])`, i.e. the number of bit positions at which the two packed
/// vectors differ.
///
/// Both slices are truncated to `min(a.len(), b.len())` before any kernel sees them, so trailing
/// words of the longer slice are ignored.
///
/// Backends: AVX-512 VPOPCNTDQ, then AVX2 on x86_64 (runtime detection); SVE on aarch64 builds
/// with the `sve` target feature; RVV on riscv64 builds with the `v` target feature; otherwise a
/// scalar loop. This dispatcher has no NEON arm, so aarch64 builds without SVE use the scalar
/// loop.
pub fn fused_xor_popcount_dispatch(a: &[u64], b: &[u64]) -> u64 {
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel below is called only after runtime detection of the feature it is
        // named for, with equal-length slices; assumes the kernels need nothing beyond that —
        // confirm in avx512.rs / avx2.rs.
        if is_x86_feature_detected!("avx512vpopcntdq") {
            return unsafe { avx512::fused_xor_popcount_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::fused_xor_popcount_avx2(a, b) };
        }
    }
    #[cfg(all(target_arch = "aarch64", target_feature = "sve"))]
    {
        // SAFETY: this arm is compiled only when SVE is enabled at build time.
        return unsafe { sve::fused_xor_popcount_sve(a, b) };
    }
    #[cfg(all(target_arch = "riscv64", target_feature = "v"))]
    {
        // SAFETY: this arm is compiled only when the V extension is enabled at build time.
        return unsafe { rvv::fused_xor_popcount_rvv(a, b) };
    }
    // Scalar fallback. It is unreachable on SVE and RVV builds, hence the allow.
    #[allow(unreachable_code)]
    {
        a.iter()
            .zip(b)
            .map(|(&wa, &wb)| u64::from((wa ^ wb).count_ones()))
            .sum()
    }
}
/// Returns the dot product `Σ a[i] * b[i]` over the common prefix of `a` and `b`.
///
/// Both slices are truncated to `min(a.len(), b.len())` *before* dispatch. Every backend
/// (AVX-512F, AVX2+FMA, NEON, scalar) therefore sees equal-length inputs, and no unsafe kernel
/// can read past the end of the shorter slice. Previously only the scalar fallback truncated;
/// this matches the fused popcount dispatchers.
pub fn dot_f64_dispatch(a: &[f64], b: &[f64]) -> f64 {
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel below is called only after runtime detection of the features it
        // is named for, with equal-length slices; assumes the kernels need nothing beyond that —
        // confirm in avx512.rs / avx2.rs.
        if is_x86_feature_detected!("avx512f") {
            return unsafe { avx512::dot_f64_avx512(a, b) };
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return unsafe { avx2::dot_f64_avx2(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is part of the aarch64 baseline and the slices have equal length.
        return unsafe { neon::dot_f64_neon(a, b) };
    }
    // Scalar fallback. It is unreachable on aarch64, hence the allow.
    #[allow(unreachable_code)]
    {
        a.iter().zip(b).map(|(&x, &y)| x * y).sum()
    }
}
/// Returns the largest element of `a`.
///
/// x86_64 selects AVX-512F, then AVX2, at runtime; aarch64 always uses NEON; all other cases
/// use a scalar fold. On the scalar path an empty slice yields `f64::NEG_INFINITY` and NaN
/// elements are skipped (that is how `f64::max` treats NaN). The SIMD kernels' handling of
/// empty/NaN input is not visible from this file.
pub fn max_f64_dispatch(a: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        let wide = is_x86_feature_detected!("avx512f");
        if wide || is_x86_feature_detected!("avx2") {
            // SAFETY: whichever kernel runs had its feature detected on this CPU just above;
            // assumes neither kernel needs more than that — confirm in avx512.rs / avx2.rs.
            return unsafe {
                if wide {
                    avx512::max_f64_avx512(a)
                } else {
                    avx2::max_f64_avx2(a)
                }
            };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is part of the aarch64 baseline; assumes the kernel needs nothing more.
        return unsafe { neon::max_f64_neon(a) };
    }
    // Scalar fallback; never reached on aarch64.
    #[allow(unreachable_code)]
    {
        a.iter().fold(f64::NEG_INFINITY, |best, &x| best.max(x))
    }
}
/// Returns the sum of all elements of `a`.
///
/// x86_64 selects AVX-512F, then AVX2, at runtime; aarch64 always uses NEON; all other cases
/// fall back to `Iterator::sum`.
pub fn sum_f64_dispatch(a: &[f64]) -> f64 {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel runs only after runtime detection of the feature it is named for;
        // assumes neither needs more than that — confirm in avx512.rs / avx2.rs.
        if is_x86_feature_detected!("avx512f") {
            return unsafe { avx512::sum_f64_avx512(a) };
        } else if is_x86_feature_detected!("avx2") {
            return unsafe { avx2::sum_f64_avx2(a) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is part of the aarch64 baseline; assumes the kernel needs nothing more.
        return unsafe { neon::sum_f64_neon(a) };
    }
    // Scalar fallback; never reached on aarch64.
    #[allow(unreachable_code)]
    {
        a.iter().copied().sum::<f64>()
    }
}
/// Multiplies every element of `y` by `alpha`, in place.
///
/// x86_64 selects AVX-512F, then AVX2, at runtime; aarch64 always uses NEON; all other cases
/// use a scalar loop.
pub fn scale_f64_dispatch(alpha: f64, y: &mut [f64]) {
    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel runs only after runtime detection of the feature it is named for;
        // assumes neither needs more than that — confirm in avx512.rs / avx2.rs.
        if is_x86_feature_detected!("avx512f") {
            unsafe { avx512::scale_f64_avx512(alpha, y) };
            return;
        } else if is_x86_feature_detected!("avx2") {
            unsafe { avx2::scale_f64_avx2(alpha, y) };
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is part of the aarch64 baseline; assumes the kernel needs nothing more.
        unsafe { neon::scale_f64_neon(alpha, y) };
        return;
    }
    // Scalar fallback; never reached on aarch64.
    #[allow(unreachable_code)]
    y.iter_mut().for_each(|v| *v *= alpha);
}
/// Returns the Hamming distance between two packed bit vectors: the number of bit positions at
/// which `a` and `b` differ.
///
/// This is exactly `fused_xor_popcount_dispatch(a, b)`. That function compares only the first
/// `min(a.len(), b.len())` words, so bits in the longer slice's tail are not counted.
#[inline]
pub fn hamming_distance_dispatch(a: &[u64], b: &[u64]) -> u64 {
    fused_xor_popcount_dispatch(a, b)
}
/// Replaces `scores` in place with `softmax(scores)`:
/// `s_i ← exp(s_i - max) / Σ_j exp(s_j - max)`.
///
/// Subtracting the maximum before exponentiating keeps `exp` from overflowing.
///
/// Edge cases:
/// - An empty slice is left untouched.
/// - If every score is `-inf` (e.g. a fully masked row), the result is all zeros rather than
///   NaN.
/// - If the exponentials do not sum to a positive value (all zeros, or a NaN is present), they
///   are left unnormalised.
pub fn softmax_inplace_f64_dispatch(scores: &mut [f64]) {
    if scores.is_empty() {
        return;
    }
    let max_val = max_f64_dispatch(scores);
    // With max == -inf, `s - max` would be `-inf - -inf = NaN` for every -inf score. Shifting by
    // 0 instead maps those entries to `exp(-inf) = 0`, which the guard below then leaves as-is.
    let shift = if max_val == f64::NEG_INFINITY {
        0.0
    } else {
        max_val
    };
    for s in scores.iter_mut() {
        *s = (*s - shift).exp();
    }
    let exp_sum = sum_f64_dispatch(scores);
    // For a finite max, the max element contributes exp(0) = 1, so the sum is >= 1. The guard
    // therefore only skips normalisation when the sum is 0 (all -inf) or NaN, where scaling by
    // 1/sum would produce inf/NaN.
    if exp_sum > 0.0 {
        scale_f64_dispatch(1.0 / exp_sum, scores);
    }
}
/// Dispatch entry point for encode-then-popcount.
///
/// It currently forwards unchanged to `crate::bitstream::encode_and_popcount`; no SIMD-specific
/// variant is selected here.
// NOTE(review): judging by the names, this presumably encodes `prob` as a `length`-bit
// stochastic stream drawn from `rng` and popcounts its AND with `weight_words` — confirm in
// bitstream.rs.
pub fn encode_and_popcount_dispatch<R>(
    weight_words: &[u64],
    prob: f64,
    length: usize,
    rng: &mut R,
) -> u64
where
    R: Rng + ?Sized,
{
    crate::bitstream::encode_and_popcount(weight_words, prob, length, rng)
}