/// Counts the bits that differ between two equal-length byte slices
/// (the Hamming distance), dispatching to the fastest backend detected
/// at runtime: AVX2 or POPCNT on x86_64, NEON on aarch64, and a portable
/// scalar loop everywhere else.
///
/// # Panics
///
/// Panics if the slices have different lengths.
#[inline]
#[must_use]
pub fn simd_popcount_xor(a: &[u8], b: &[u8]) -> u32 {
    // A hard assert (not a debug_assert) is required for soundness: the
    // SIMD backends read `a.len()` bytes from *both* slices, so a shorter
    // `b` would cause an out-of-bounds read in release builds.
    assert_eq!(a.len(), b.len(), "slices must have equal length");
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified at runtime.
            return unsafe { avx2_popcount_xor(a, b) };
        }
        if is_x86_feature_detected!("popcnt") {
            // SAFETY: POPCNT support was just verified at runtime.
            return unsafe { native_popcount_xor(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: NEON support was just verified at runtime.
            return unsafe { neon_popcount_xor(a, b) };
        }
    }
    scalar_popcount_xor(a, b)
}
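/// Portable scalar fallback: XOR each byte pair and sum `count_ones()`.
/// Also serves as the reference implementation for the tests below.
/// Note that `zip` silently truncates to the shorter slice if the length
/// contract is violated in a release build.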
#[inline]
#[must_use]
pub fn scalar_popcount_xor(a: &[u8], b: &[u8]) -> u32 {
debug_assert_eq!(a.len(), b.len());
a.iter()
.zip(b.iter())
.map(|(x, y)| (x ^ y).count_ones())
.sum()
}
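/// AVX2 backend: XORs 32-byte chunks and popcounts the four 64-bit lanes
/// of each result, handling any tail bytes with the scalar loop.
///
/// # Safety
///
/// The caller must verify AVX2 support and pass slices of equal length.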
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::cast_ptr_alignment
)]
unsafe fn avx2_popcount_xor(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::x86_64::{__m256i, _mm256_extract_epi64, _mm256_loadu_si256, _mm256_xor_si256};
    let mut total = 0u32;
    let len = a.len();
    let chunks = len / 32;
    for i in 0..chunks {
        let offset = i * 32;
        // `loadu` performs unaligned loads, so no alignment requirement
        // is imposed on the input slices.
        let va = _mm256_loadu_si256(a.as_ptr().add(offset).cast::<__m256i>());
        let vb = _mm256_loadu_si256(b.as_ptr().add(offset).cast::<__m256i>());
        let xor = _mm256_xor_si256(va, vb);
        // The lane index of `_mm256_extract_epi64` is a const generic, so
        // each 64-bit lane must be extracted with an explicit turbofish.
        let v0 = _mm256_extract_epi64::<0>(xor) as u64;
        let v1 = _mm256_extract_epi64::<1>(xor) as u64;
        let v2 = _mm256_extract_epi64::<2>(xor) as u64;
        let v3 = _mm256_extract_epi64::<3>(xor) as u64;
        total += v0.count_ones() + v1.count_ones() + v2.count_ones() + v3.count_ones();
    }
    // Finish any tail that does not fill a 32-byte vector.
    let remainder_start = chunks * 32;
    for i in remainder_start..len {
        total += (a[i] ^ b[i]).count_ones();
    }
    total
}
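/// POPCNT backend: processes the slices 8 bytes at a time as `u64` words.
/// The `target_feature` attribute is what allows `count_ones()` to compile
/// to a single `popcnt` instruction; without it, the runtime feature check
/// in the dispatcher would buy nothing.
///
/// # Safety
///
/// The caller must verify POPCNT support and pass slices of equal length.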
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
unsafe fn native_popcount_xor(a: &[u8], b: &[u8]) -> u32 {
    let mut count = 0u32;
    for (chunk_a, chunk_b) in a.chunks_exact(8).zip(b.chunks_exact(8)) {
        // `chunks_exact(8)` guarantees each chunk converts to [u8; 8].
        let va = u64::from_le_bytes(chunk_a.try_into().unwrap());
        let vb = u64::from_le_bytes(chunk_b.try_into().unwrap());
        count += (va ^ vb).count_ones();
    }
    // Finish any tail that does not fill a full 8-byte word.
    let remainder_start = (a.len() / 8) * 8;
    for i in remainder_start..a.len() {
        count += (a[i] ^ b[i]).count_ones();
    }
    count
}
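/// NEON backend: XORs 16-byte chunks, popcounts each byte with `vcnt`,
/// then reduces with widening pairwise adds before summing the two
/// 64-bit lanes.
///
/// # Safety
///
/// The caller must verify NEON support and pass slices of equal length.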
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_popcount_xor(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::aarch64::*;
    let mut total = 0u32;
    let len = a.len();
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let va = vld1q_u8(a.as_ptr().add(offset));
        let vb = vld1q_u8(b.as_ptr().add(offset));
        let xor = veorq_u8(va, vb);
        // Per-byte popcount, then widening pairwise adds:
        // 16 x u8 -> 8 x u16 -> 4 x u32 -> 2 x u64.
        let cnt = vcntq_u8(xor);
        let sum16 = vpaddlq_u8(cnt);
        let sum32 = vpaddlq_u16(sum16);
        let sum64 = vpaddlq_u32(sum32);
        // Lane indices are const generics in `std::arch`.
        total += (vgetq_lane_u64::<0>(sum64) + vgetq_lane_u64::<1>(sum64)) as u32;
    }
    // Finish any tail that does not fill a 16-byte vector.
    let remainder_start = chunks * 16;
    for i in remainder_start..len {
        total += (a[i] ^ b[i]).count_ones();
    }
    total
}
#[cfg(test)]
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
mod tests {
use super::*;
#[test]
fn test_scalar_popcount_identical() {
let a = vec![0xAA; 96];
let b = vec![0xAA; 96];
assert_eq!(scalar_popcount_xor(&a, &b), 0);
}
#[test]
fn test_scalar_popcount_opposite() {
let a = vec![0x00; 96];
let b = vec![0xFF; 96];
assert_eq!(scalar_popcount_xor(&a, &b), 768);
}
    #[test]
    fn test_scalar_popcount_half() {
        let a = vec![0xF0; 96];
        let b = vec![0x0F; 96];
        assert_eq!(scalar_popcount_xor(&a, &b), 768);
    }
#[test]
fn test_simd_matches_scalar_96_bytes() {
let a: Vec<u8> = (0..96).map(|i| i as u8).collect();
let b: Vec<u8> = (0..96).map(|i| (i * 2) as u8).collect();
let simd_result = simd_popcount_xor(&a, &b);
let scalar_result = scalar_popcount_xor(&a, &b);
assert_eq!(simd_result, scalar_result);
}
#[test]
fn test_simd_matches_scalar_16_bytes() {
let a = vec![0xAA; 16];
let b = vec![0x55; 16];
let simd_result = simd_popcount_xor(&a, &b);
let scalar_result = scalar_popcount_xor(&a, &b);
assert_eq!(simd_result, scalar_result);
        assert_eq!(simd_result, 128);
    }
#[test]
fn test_simd_matches_scalar_128_bytes() {
let a: Vec<u8> = (0..128).map(|i| (i * 3) as u8).collect();
let b: Vec<u8> = (0..128).map(|i| (i * 7) as u8).collect();
let simd_result = simd_popcount_xor(&a, &b);
let scalar_result = scalar_popcount_xor(&a, &b);
assert_eq!(simd_result, scalar_result);
}
#[test]
fn test_simd_matches_scalar_192_bytes() {
let a: Vec<u8> = (0..192).map(|i| (i * 5) as u8).collect();
let b: Vec<u8> = (0..192).map(|i| (i * 11) as u8).collect();
let simd_result = simd_popcount_xor(&a, &b);
let scalar_result = scalar_popcount_xor(&a, &b);
assert_eq!(simd_result, scalar_result);
}
#[test]
fn test_simd_odd_length() {
let a = vec![0xFF; 7];
let b = vec![0x00; 7];
let simd_result = simd_popcount_xor(&a, &b);
let scalar_result = scalar_popcount_xor(&a, &b);
assert_eq!(simd_result, scalar_result);
        assert_eq!(simd_result, 56);
    }
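    // 100 bytes = 3 full AVX2 chunks (or 6 NEON chunks) plus a 4-byte tail,
    // so every SIMD backend must combine vector chunks with the scalar
    // remainder loop; lengths that are exact multiples never exercise both.
    #[test]
    fn test_simd_matches_scalar_unaligned_tail() {
        let a: Vec<u8> = (0..100).map(|i| (i * 13) as u8).collect();
        let b: Vec<u8> = (0..100).map(|i| (i * 17) as u8).collect();
        assert_eq!(simd_popcount_xor(&a, &b), scalar_popcount_xor(&a, &b));
    }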
#[test]
fn test_simd_empty_slices() {
let a: Vec<u8> = vec![];
let b: Vec<u8> = vec![];
assert_eq!(simd_popcount_xor(&a, &b), 0);
}
#[test]
fn test_simd_single_byte() {
let a = vec![0b1010_1010];
let b = vec![0b0101_0101];
assert_eq!(simd_popcount_xor(&a, &b), 8);
}
#[test]
fn test_simd_large_vectors() {
let a: Vec<u8> = (0..4096).map(|i| (i % 256) as u8).collect();
let b: Vec<u8> = (0..4096).map(|i| ((i + 128) % 256) as u8).collect();
let simd_result = simd_popcount_xor(&a, &b);
let scalar_result = scalar_popcount_xor(&a, &b);
assert_eq!(simd_result, scalar_result);
}
}