#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
/// Computes the Shannon entropy of `chunk` in bits (base-2 logarithms), using
/// AVX-512 for the large-histogram case.
///
/// Returns `0.0` for an empty chunk, or when the histogram reports a zero total.
/// For a 256-symbol alphabet the exact value is at most 8 bits; the SIMD path
/// may differ from it by floating-point rounding.
///
/// The byte histogram comes from `crate::entropy_fast::histogram_8way`, which
/// returns `(counts, active_len)`. NOTE(review): `active_len` is taken to be the
/// histogram total (sum of `counts`) and `counts` to hold 256 32-bit counters.
/// Confirm both against `histogram_8way`, since the 256-bit loads below depend
/// on the 32-bit element width.
///
/// Two strategies are used:
/// * `active_len <= 255`: `H = log2(N) - (1/N) * sum(table[c])` over non-zero
///   counts `c`, with `table = crate::entropy_fast::get_log2_table()`.
///   NOTE(review): this formula presumes `table[c] == c * log2(c)`. Confirm.
/// * otherwise: `H = -sum(p * log2(p))` with `p = c / N`, evaluated eight
///   counts at a time using `approx_log2_pd`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx512f` and
/// `avx512bw` target features (e.g. checked with `is_x86_feature_detected!`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
// Nearly every statement is a SIMD intrinsic covered by the single `# Safety`
// contract above. Per-call `unsafe {}` blocks would add noise, not precision.
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn calculate_shannon_entropy(chunk: &[u8]) -> f64 {
    if chunk.is_empty() {
        return 0.0;
    }
    let (counts, active_len) = crate::entropy_fast::histogram_8way(chunk);
    if active_len == 0 {
        return 0.0;
    }
    let total = active_len as f64;

    if active_len <= 255 {
        // Every individual count is <= active_len <= 255, so each lookup stays
        // inside the precomputed table and no approximation is involved.
        let table = crate::entropy_fast::get_log2_table();
        let weighted: f64 = counts
            .iter()
            .filter(|&&count| count > 0)
            .map(|&count| table[count as usize])
            .sum();
        return total.log2() - weighted / total;
    }

    let zero = _mm512_setzero_pd();
    let total_v = _mm512_set1_pd(total);
    let mut acc = zero;
    // `chunks_exact(8)` guarantees every 256-bit load below has exactly eight
    // 32-bit counters behind it.
    for lanes in counts.chunks_exact(8) {
        let raw = _mm256_loadu_si256(lanes.as_ptr() as *const __m256i);
        // Unsigned widening: a count above i32::MAX must not read as negative
        // (and then be dropped by the `> 0` mask below).
        let counts_f = _mm512_cvtepu32_pd(raw);
        let nonzero = _mm512_cmp_pd_mask(counts_f, zero, _CMP_GT_OQ);
        if nonzero == 0 {
            continue;
        }
        let p = _mm512_maskz_div_pd(nonzero, counts_f, total_v);
        let term = _mm512_mul_pd(p, approx_log2_pd(p));
        // Zero-count lanes compute 0 * log2(0), which is not finite. The masked
        // subtract keeps those lanes of the accumulator unchanged.
        acc = _mm512_mask_sub_pd(acc, nonzero, acc, term);
    }
    _mm512_reduce_add_pd(acc)
}
/// Lane-wise base-2 logarithm of `|x|` for eight `f64` lanes, using only
/// AVX-512F instructions.
///
/// Each lane is decomposed as `x = 2^e * m` with `getexp`/`getmant`, which also
/// normalise subnormal inputs. The mantissa is then re-centred to
/// `m` in `[sqrt(2)/2, sqrt(2)]`, and `log2(m)` comes from the odd series
/// `ln(m) = 2 * atanh(s) = 2 * (s + s^3/3 + s^5/5 + ...)`, `s = (m - 1) / (m + 1)`,
/// truncated after the `s^15` term. Because `|s| <= 3 - 2*sqrt(2) ~= 0.1716`,
/// the truncation error is below ~2e-14. Results for positive, finite, non-zero
/// inputs therefore agree with `f64::log2` to about 1e-13.
///
/// Zero lanes yield a non-finite value (`getexp(0)` is `-inf`), and infinities
/// or NaNs give no meaningful result. Callers are expected to mask such lanes.
///
/// # Safety
///
/// The running CPU must support `avx512f`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
// The body consists solely of intrinsic calls covered by the `# Safety`
// contract above.
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn approx_log2_pd(x: __m512d) -> __m512d {
    // x = 2^e * m with m in [1, 2). The sign is discarded, giving log2(|x|).
    let mut e = _mm512_getexp_pd(x);
    let mut m = _mm512_getmant_pd(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_ZERO);

    // Re-centre the mantissa on 1: m in (sqrt2, 2) becomes m/2 in (sqrt2/2, 1),
    // with e + 1. Both adjustments are exact.
    let one = _mm512_set1_pd(1.0);
    let sqrt2 = _mm512_set1_pd(core::f64::consts::SQRT_2);
    let upper = _mm512_cmp_pd_mask(m, sqrt2, _CMP_GT_OQ);
    m = _mm512_mask_mul_pd(m, upper, m, _mm512_set1_pd(0.5));
    e = _mm512_mask_add_pd(e, upper, e, one);

    // s = (m - 1) / (m + 1) with |s| <= 0.1716; m - 1 is exact here.
    let s = _mm512_div_pd(_mm512_sub_pd(m, one), _mm512_add_pd(m, one));
    let w = _mm512_mul_pd(s, s);

    // P(w) = 1 + w/3 + w^2/5 + ... + w^7/15, evaluated with Horner's rule.
    let mut poly = _mm512_set1_pd(1.0 / 15.0);
    for c in [1.0 / 13.0, 1.0 / 11.0, 1.0 / 9.0, 1.0 / 7.0, 1.0 / 5.0, 1.0 / 3.0, 1.0] {
        poly = _mm512_fmadd_pd(poly, w, _mm512_set1_pd(c));
    }

    // log2(m) = (2 / ln 2) * s * P(w); the result is e + log2(m).
    let scaled_s = _mm512_mul_pd(_mm512_set1_pd(2.0 * core::f64::consts::LOG2_E), s);
    _mm512_fmadd_pd(scaled_s, poly, e)
}