use std::sync::OnceLock;
/// Function-pointer type shared by every scan kernel.
///
/// Arguments, in order: `packed`, `n_words`, `n`, `q_packed`, `mask`, `out_agree`
/// (the same parameters, with the same meaning, as [`scan`]).
type ScanFn = fn(&[u64], usize, usize, &[u64], u64, &mut [u32]);

/// Kernel picked by `select_impl` on the first call to [`scan`] and reused afterwards.
static SCAN_IMPL: OnceLock<ScanFn> = OnceLock::new();
/// Portable reference implementation of [`scan`]; also the fallback kernel.
///
/// `packed` holds `n` rows of `n_words` words each (row `i` is
/// `packed[i * n_words..(i + 1) * n_words]`). For each row, `out_agree[i]`
/// receives the number of bit positions at which the row equals `q_packed`,
/// i.e. the popcount of `!(row ^ q_packed)`. The last word of each row is
/// ANDed with `mask` before counting, so bits outside `mask` are ignored;
/// `mask == !0` counts every bit. With `n_words == 0` every count is `0`.
///
/// # Panics
///
/// Panics if `packed.len() != n * n_words`, `q_packed.len() != n_words`, or
/// `out_agree.len() != n`.
#[inline]
pub fn scan_scalar(
    packed: &[u64],
    n_words: usize,
    n: usize,
    q_packed: &[u64],
    mask: u64,
    out_agree: &mut [u32],
) {
    // Checked in release builds as well, so a length mismatch is a panic
    // rather than an out-of-bounds access.
    assert!(
        n.checked_mul(n_words) == Some(packed.len()),
        "packed.len() must equal n * n_words"
    );
    assert_eq!(q_packed.len(), n_words, "q_packed.len() must equal n_words");
    assert_eq!(out_agree.len(), n, "out_agree.len() must equal n");

    if n_words == 0 {
        // No bits to compare. This also avoids `chunks_exact(0)` (which panics)
        // and the `n_words - 1` underflow below.
        out_agree.fill(0);
        return;
    }

    // Each item pairs one row of `packed` with its output slot.
    let rows = packed.chunks_exact(n_words).zip(out_agree.iter_mut());

    if mask == !0u64 && n_words == 2 {
        // Hot path: 128-bit rows with every bit valid.
        let (q0, q1) = (q_packed[0], q_packed[1]);
        for (row, out) in rows {
            *out = (!(row[0] ^ q0)).count_ones() + (!(row[1] ^ q1)).count_ones();
        }
        return;
    }

    let last = n_words - 1;
    let (q_full, q_last) = (&q_packed[..last], q_packed[last]);
    for (row, out) in rows {
        let full: u32 = row[..last]
            .iter()
            .zip(q_full)
            .map(|(&w, &q)| (!(w ^ q)).count_ones())
            .sum();
        // `mask == !0` leaves the last word intact, so this also covers full rows.
        *out = full + (!(row[last] ^ q_last) & mask).count_ones();
    }
}
/// Popcount-based implementation of the [`scan_scalar`] contract for x86-64
/// CPUs with `avx2` and `popcnt`.
///
/// Each 64-bit word's agreement is counted with the `popcnt` instruction
/// (`_popcnt64`). Rows are processed four at a time, followed by a one-row
/// tail. For 128-bit rows with a full mask (`mask == !0 && n_words == 2`),
/// each iteration reads the eight contiguous words of four rows.
///
/// # Safety
///
/// * The executing CPU must support `avx2` and `popcnt`.
/// * `packed.len() == n * n_words`, `q_packed.len() == n_words` and
///   `out_agree.len() == n` must hold. They are only `debug_assert!`ed; every
///   row read and output write below goes through an unchecked raw pointer.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,popcnt")]
unsafe fn scan_avx2(
    packed: &[u64],
    n_words: usize,
    n: usize,
    q_packed: &[u64],
    mask: u64,
    out_agree: &mut [u32],
) {
    use core::arch::x86_64::_popcnt64;
    debug_assert_eq!(packed.len(), n * n_words);
    debug_assert_eq!(q_packed.len(), n_words);
    debug_assert_eq!(out_agree.len(), n);
    if n_words == 0 {
        // Zero-width rows agree on no bits. Returning here also keeps the masked
        // branch below from computing `n_words - 1`, which would underflow and
        // drive out-of-bounds reads in release builds.
        out_agree.fill(0);
        return;
    }
    let aligned = mask == !0u64;
    let p = packed.as_ptr();
    let o = out_agree.as_mut_ptr();
    if aligned && n_words == 2 {
        let q0 = q_packed[0] as i64;
        let q1 = q_packed[1] as i64;
        // Largest multiple of 4 not exceeding `n`: rows handled by the unrolled loop.
        let n4 = n & !3usize;
        let mut i = 0usize;
        while i < n4 {
            // Rows i..i+4 occupy the eight contiguous words starting at `b`.
            let b = p.add(i * 2);
            let w0 = *b as i64;
            let w1 = *b.add(1) as i64;
            let w2 = *b.add(2) as i64;
            let w3 = *b.add(3) as i64;
            let w4 = *b.add(4) as i64;
            let w5 = *b.add(5) as i64;
            let w6 = *b.add(6) as i64;
            let w7 = *b.add(7) as i64;
            let a0: i32 = _popcnt64(!(w0 ^ q0)) + _popcnt64(!(w1 ^ q1));
            let a1: i32 = _popcnt64(!(w2 ^ q0)) + _popcnt64(!(w3 ^ q1));
            let a2: i32 = _popcnt64(!(w4 ^ q0)) + _popcnt64(!(w5 ^ q1));
            let a3: i32 = _popcnt64(!(w6 ^ q0)) + _popcnt64(!(w7 ^ q1));
            *o.add(i) = a0 as u32;
            *o.add(i + 1) = a1 as u32;
            *o.add(i + 2) = a2 as u32;
            *o.add(i + 3) = a3 as u32;
            i += 4;
        }
        while i < n {
            let b = p.add(i * 2);
            let a: i32 = _popcnt64(!((*b as i64) ^ q0)) + _popcnt64(!((*b.add(1) as i64) ^ q1));
            *o.add(i) = a as u32;
            i += 1;
        }
        return;
    }
    let n4 = n & !3usize;
    let mut i = 0usize;
    if aligned {
        // Every word of every row is fully valid.
        while i < n4 {
            let mut a0: i32 = 0;
            let mut a1: i32 = 0;
            let mut a2: i32 = 0;
            let mut a3: i32 = 0;
            for w in 0..n_words {
                let qi = *q_packed.get_unchecked(w) as i64;
                a0 += _popcnt64(!((*p.add(i * n_words + w) as i64) ^ qi));
                a1 += _popcnt64(!((*p.add((i + 1) * n_words + w) as i64) ^ qi));
                a2 += _popcnt64(!((*p.add((i + 2) * n_words + w) as i64) ^ qi));
                a3 += _popcnt64(!((*p.add((i + 3) * n_words + w) as i64) ^ qi));
            }
            *o.add(i) = a0 as u32;
            *o.add(i + 1) = a1 as u32;
            *o.add(i + 2) = a2 as u32;
            *o.add(i + 3) = a3 as u32;
            i += 4;
        }
        while i < n {
            let mut a: i32 = 0;
            for w in 0..n_words {
                let qi = *q_packed.get_unchecked(w) as i64;
                a += _popcnt64(!((*p.add(i * n_words + w) as i64) ^ qi));
            }
            *o.add(i) = a as u32;
            i += 1;
        }
    } else {
        // All words but the last are fully valid; the last is ANDed with `mask`.
        let last = n_words - 1;
        let m = mask as i64;
        while i < n4 {
            let mut a0: i32 = 0;
            let mut a1: i32 = 0;
            let mut a2: i32 = 0;
            let mut a3: i32 = 0;
            for w in 0..last {
                let qi = *q_packed.get_unchecked(w) as i64;
                a0 += _popcnt64(!((*p.add(i * n_words + w) as i64) ^ qi));
                a1 += _popcnt64(!((*p.add((i + 1) * n_words + w) as i64) ^ qi));
                a2 += _popcnt64(!((*p.add((i + 2) * n_words + w) as i64) ^ qi));
                a3 += _popcnt64(!((*p.add((i + 3) * n_words + w) as i64) ^ qi));
            }
            let qi = *q_packed.get_unchecked(last) as i64;
            a0 += _popcnt64(!((*p.add(i * n_words + last) as i64) ^ qi) & m);
            a1 += _popcnt64(!((*p.add((i + 1) * n_words + last) as i64) ^ qi) & m);
            a2 += _popcnt64(!((*p.add((i + 2) * n_words + last) as i64) ^ qi) & m);
            a3 += _popcnt64(!((*p.add((i + 3) * n_words + last) as i64) ^ qi) & m);
            *o.add(i) = a0 as u32;
            *o.add(i + 1) = a1 as u32;
            *o.add(i + 2) = a2 as u32;
            *o.add(i + 3) = a3 as u32;
            i += 4;
        }
        while i < n {
            let mut a: i32 = 0;
            for w in 0..last {
                let qi = *q_packed.get_unchecked(w) as i64;
                a += _popcnt64(!((*p.add(i * n_words + w) as i64) ^ qi));
            }
            let qi = *q_packed.get_unchecked(last) as i64;
            a += _popcnt64(!((*p.add(i * n_words + last) as i64) ^ qi) & m);
            *o.add(i) = a as u32;
            i += 1;
        }
    }
}
/// Safe, [`ScanFn`]-compatible shim around the `#[target_feature]` kernel
/// [`scan_avx2`], which cannot itself be stored as a plain `fn` pointer.
#[cfg(target_arch = "x86_64")]
fn scan_avx2_dispatch(
    packed: &[u64],
    n_words: usize,
    n: usize,
    q_packed: &[u64],
    mask: u64,
    out_agree: &mut [u32],
) {
    // SAFETY: `select_impl` hands out this shim only after detecting `avx2` and
    // `popcnt`, the features `scan_avx2` is compiled for. The kernel's
    // slice-length preconditions are not checked here; callers reaching this
    // through a `ScanFn` pointer must uphold them.
    unsafe { scan_avx2(packed, n_words, n, q_packed, mask, out_agree) }
}
/// AVX-512 (VPOPCNTDQ) implementation of the [`scan_scalar`] contract.
///
/// For 128-bit rows with a full mask (`mask == !0 && n_words == 2`), one
/// 512-bit load covers four rows, and the eight per-lane popcounts are summed in
/// pairs. Otherwise each row is processed in 8-word (512-bit) chunks. Leftover
/// words and the masked last word are then handled one word at a time.
///
/// # Safety
///
/// * The executing CPU must support every feature enabled below: `avx2`,
///   `avx512f`, `avx512bw` and `avx512vpopcntdq`.
/// * `packed.len() == n * n_words`, `q_packed.len() == n_words` and
///   `out_agree.len() == n` must hold. They are only `debug_assert!`ed; all
///   row loads and output stores go through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,avx512f,avx512bw,avx512vpopcntdq")]
#[allow(clippy::incompatible_msrv)]
unsafe fn scan_avx512(
    packed: &[u64],
    n_words: usize,
    n: usize,
    q_packed: &[u64],
    mask: u64,
    out_agree: &mut [u32],
) {
    use core::arch::x86_64::{
        __m512i, _mm512_loadu_si512, _mm512_popcnt_epi64, _mm512_reduce_add_epi64,
        _mm512_set1_epi64, _mm512_set_epi64, _mm512_storeu_si512, _mm512_xor_si512,
    };
    debug_assert_eq!(packed.len(), n * n_words);
    debug_assert_eq!(q_packed.len(), n_words);
    debug_assert_eq!(out_agree.len(), n);
    if n_words == 0 {
        // Zero-width rows agree on no bits. Returning here also keeps
        // `last_idx` below from computing `n_words - 1`, which would underflow
        // and drive out-of-bounds reads in release builds.
        out_agree.fill(0);
        return;
    }
    let aligned = mask == !0u64;
    let p = packed.as_ptr();
    let o = out_agree.as_mut_ptr();
    // XOR with all-ones turns `row ^ query` into XNOR (bit set where they agree).
    let ones = _mm512_set1_epi64(-1i64);
    if aligned && n_words == 2 {
        let q0 = q_packed[0] as i64;
        let q1 = q_packed[1] as i64;
        // `_mm512_set_epi64` takes lanes high-to-low, so lane 0 = q0, lane 1 = q1,
        // matching four consecutive 2-word rows.
        let qvec = _mm512_set_epi64(q1, q0, q1, q0, q1, q0, q1, q0);
        let n4 = n & !3usize;
        let mut i = 0usize;
        while i < n4 {
            let v = _mm512_loadu_si512(p.add(i * 2) as *const __m512i);
            let x = _mm512_xor_si512(_mm512_xor_si512(v, qvec), ones);
            let pc = _mm512_popcnt_epi64(x);
            let mut tmp = [0i64; 8];
            _mm512_storeu_si512(tmp.as_mut_ptr() as *mut __m512i, pc);
            // Lanes (2k, 2k + 1) hold the two word counts of row i + k.
            *o.add(i) = (tmp[0] + tmp[1]) as u32;
            *o.add(i + 1) = (tmp[2] + tmp[3]) as u32;
            *o.add(i + 2) = (tmp[4] + tmp[5]) as u32;
            *o.add(i + 3) = (tmp[6] + tmp[7]) as u32;
            i += 4;
        }
        while i < n {
            let b = p.add(i * 2);
            let w0 = *b as i64;
            let w1 = *b.add(1) as i64;
            let a = (!(w0 ^ q0)).count_ones() + (!(w1 ^ q1)).count_ones();
            *o.add(i) = a;
            i += 1;
        }
        return;
    }
    // Words `0..last_idx` are fully valid; with a partial mask the last word is
    // handled separately below.
    let last_idx = if aligned { n_words } else { n_words - 1 };
    let chunks = last_idx / 8;
    let rem_start = chunks * 8;
    let m = mask;
    for i in 0..n {
        let base = i * n_words;
        let mut a: u32 = 0;
        for c in 0..chunks {
            let w_off = c * 8;
            let vv = _mm512_loadu_si512(p.add(base + w_off) as *const __m512i);
            let qv = _mm512_loadu_si512(q_packed.as_ptr().add(w_off) as *const __m512i);
            let x = _mm512_xor_si512(_mm512_xor_si512(vv, qv), ones);
            let pc = _mm512_popcnt_epi64(x);
            a += _mm512_reduce_add_epi64(pc) as u32;
        }
        for w in rem_start..last_idx {
            let wi = *p.add(base + w);
            let qi = *q_packed.get_unchecked(w);
            a += (!(wi ^ qi)).count_ones();
        }
        if !aligned {
            let wi = *p.add(base + n_words - 1);
            let qi = *q_packed.get_unchecked(n_words - 1);
            a += (!(wi ^ qi) & m).count_ones();
        }
        *o.add(i) = a;
    }
}
/// Safe, [`ScanFn`]-compatible shim around the `#[target_feature]` kernel
/// [`scan_avx512`], which cannot itself be stored as a plain `fn` pointer.
#[cfg(target_arch = "x86_64")]
fn scan_avx512_dispatch(
    packed: &[u64],
    n_words: usize,
    n: usize,
    q_packed: &[u64],
    mask: u64,
    out_agree: &mut [u32],
) {
    // SAFETY: sound only when the CPU supports every feature `scan_avx512` is
    // compiled for (`avx2`, `avx512f`, `avx512bw`, `avx512vpopcntdq`);
    // `select_impl` is responsible for handing out this shim only in that case.
    // The kernel's slice-length preconditions are not checked here; callers
    // reaching this through a `ScanFn` pointer must uphold them.
    unsafe { scan_avx512(packed, n_words, n, q_packed, mask, out_agree) }
}
/// Counts, for each row of `packed`, how many bits agree with `q_packed`.
///
/// `packed` holds `n` rows of `n_words` words each (row `i` is
/// `packed[i * n_words..(i + 1) * n_words]`). `out_agree[i]` receives the
/// popcount of `!(row ^ q_packed)`. The last word of each row is ANDed with
/// `mask` before counting; `mask == !0` counts every bit. With `n_words == 0`
/// every count is `0`.
///
/// The kernel is chosen by `select_impl` on the first call (based on runtime
/// CPU feature detection) and cached in `SCAN_IMPL`.
///
/// # Panics
///
/// Panics if `packed.len() != n * n_words`, `q_packed.len() != n_words`, or
/// `out_agree.len() != n`.
#[inline]
pub fn scan(
    packed: &[u64],
    n_words: usize,
    n: usize,
    q_packed: &[u64],
    mask: u64,
    out_agree: &mut [u32],
) {
    // The SIMD kernels read rows and write results through unchecked raw
    // pointers and only `debug_assert!` these lengths, so they are enforced
    // here to keep this safe function sound in release builds.
    assert!(
        n.checked_mul(n_words) == Some(packed.len()),
        "packed.len() must equal n * n_words"
    );
    assert_eq!(q_packed.len(), n_words, "q_packed.len() must equal n_words");
    assert_eq!(out_agree.len(), n, "out_agree.len() must equal n");
    if n_words == 0 {
        // Nothing to compare. The masked kernel paths would otherwise compute
        // `n_words - 1` and underflow.
        out_agree.fill(0);
        return;
    }
    let kernel = SCAN_IMPL.get_or_init(select_impl);
    kernel(packed, n_words, n, q_packed, mask, out_agree);
}
/// Picks the fastest [`ScanFn`] the running CPU can execute, falling back to
/// [`scan_scalar`].
///
/// A SIMD kernel is returned only when *every* feature in its
/// `#[target_feature(enable = ...)]` list is detected at runtime. Executing a
/// kernel on a CPU missing any of them is undefined behavior.
fn select_impl() -> ScanFn {
    #[cfg(target_arch = "x86_64")]
    {
        // Must mirror `scan_avx512`'s full feature list, including `avx512bw`.
        // Some CPUs report AVX-512F + VPOPCNTDQ without BW (e.g. Knights Mill).
        if std::is_x86_feature_detected!("avx2")
            && std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512vpopcntdq")
        {
            return scan_avx512_dispatch;
        }
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("popcnt") {
            return scan_avx2_dispatch;
        }
    }
    scan_scalar
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether the host can run `scan_avx2`; mirrors its `target_feature` list.
    #[cfg(target_arch = "x86_64")]
    fn has_avx2() -> bool {
        std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("popcnt")
    }

    /// Whether the host can run `scan_avx512`. Every feature in its
    /// `target_feature` list must be present; AVX-512F + VPOPCNTDQ alone is not
    /// enough because the kernel is also compiled for `avx2` and `avx512bw`.
    #[cfg(target_arch = "x86_64")]
    fn has_avx512() -> bool {
        std::is_x86_feature_detected!("avx2")
            && std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512vpopcntdq")
    }

    /// Deterministic xorshift data: `n` rows of `dim` bits, a query, and the
    /// last-word mask (valid bits are the high bits of the last word). Bits
    /// outside the mask are cleared in every row and in the query.
    /// Requires `dim >= 1`.
    fn random_packed(dim: usize, n: usize, seed: u64) -> (Vec<u64>, Vec<u64>, u64) {
        let mut s = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) | 1;
        let mut step = || -> u64 {
            s ^= s << 13;
            s ^= s >> 7;
            s ^= s << 17;
            s
        };
        let n_words = dim.div_ceil(64);
        let mut packed = vec![0u64; n * n_words];
        for w in &mut packed {
            *w = step();
        }
        let mut q = vec![0u64; n_words];
        for w in &mut q {
            *w = step();
        }
        let valid_bits = dim - 64 * (n_words - 1);
        let mask = if valid_bits == 64 {
            !0u64
        } else {
            !0u64 << (64 - valid_bits)
        };
        for row in packed.chunks_exact_mut(n_words) {
            row[n_words - 1] &= mask;
        }
        q[n_words - 1] &= mask;
        (packed, q, mask)
    }

    /// Checks the dispatcher and every kernel the host supports against `scan_scalar`.
    fn run_both(dim: usize, n: usize, seed: u64) {
        let (packed, q, mask) = random_packed(dim, n, seed);
        let n_words = dim.div_ceil(64);
        let mut out_scalar = vec![0u32; n];
        scan_scalar(&packed, n_words, n, &q, mask, &mut out_scalar);
        let mut out_dispatch = vec![0u32; n];
        scan(&packed, n_words, n, &q, mask, &mut out_dispatch);
        assert_eq!(
            out_scalar, out_dispatch,
            "dispatcher output diverged from scalar at dim={dim} n={n}"
        );
        #[cfg(target_arch = "x86_64")]
        if has_avx2() {
            let mut out_avx2 = vec![0u32; n];
            // SAFETY: `has_avx2` confirmed the kernel's target features, and the
            // slices come from `random_packed` with matching lengths.
            unsafe {
                scan_avx2(&packed, n_words, n, &q, mask, &mut out_avx2);
            }
            assert_eq!(
                out_scalar, out_avx2,
                "AVX2 output diverged from scalar at dim={dim} n={n}"
            );
        }
        #[cfg(target_arch = "x86_64")]
        if has_avx512() {
            let mut out_avx512 = vec![0u32; n];
            // SAFETY: `has_avx512` confirmed every target feature of the kernel,
            // and the slices come from `random_packed` with matching lengths.
            unsafe {
                scan_avx512(&packed, n_words, n, &q, mask, &mut out_avx512);
            }
            assert_eq!(
                out_scalar, out_avx512,
                "AVX-512 output diverged from scalar at dim={dim} n={n}"
            );
        }
    }

    #[test]
    fn scan_agree_matches_scalar_at_d128() {
        run_both(128, 1000, 0xA5A5_5A5A_1234_CAFE);
    }

    #[test]
    fn scan_avx512_matches_scalar() {
        #[cfg(target_arch = "x86_64")]
        {
            if !has_avx512() {
                eprintln!(
                    "scan_avx512_matches_scalar: host lacks avx2+avx512f+avx512bw+avx512vpopcntdq, skipping"
                );
                return;
            }
            let dim = 128usize;
            let n = 1000usize;
            let (packed, q, mask) = random_packed(dim, n, 0xC001_FACE_D00D_BEEF);
            let n_words = dim.div_ceil(64);
            let mut out_scalar = vec![0u32; n];
            scan_scalar(&packed, n_words, n, &q, mask, &mut out_scalar);
            let mut out_avx512 = vec![0u32; n];
            // SAFETY: `has_avx512` confirmed every target feature of the kernel.
            unsafe {
                scan_avx512(&packed, n_words, n, &q, mask, &mut out_avx512);
            }
            assert_eq!(
                out_scalar, out_avx512,
                "scan_avx512 diverged from scan_scalar at D=128, n=1000"
            );
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            eprintln!("scan_avx512_matches_scalar: non-x86_64 host, skipping");
        }
    }

    #[test]
    fn scan_agree_matches_scalar_at_d64_and_d192() {
        run_both(64, 777, 0x0123_4567_89AB_CDEF);
        run_both(192, 513, 0xFEDC_BA98_7654_3210);
        run_both(100, 641, 0xDEAD_BEEF_CAFE_F00D);
        run_both(200, 333, 0x1357_9BDF_2468_ACE0);
        run_both(128, 1023, 0x4242_4242_4242_4242);
        run_both(128, 7, 0x9999_AAAA_BBBB_CCCC);
    }

    /// Rows of 8+ words reach the AVX-512 512-bit chunk loop, both with and
    /// without leftover words and a masked last word.
    #[test]
    fn scan_agree_matches_scalar_on_wide_rows() {
        run_both(512, 37, 0x1111_2222_3333_4444);
        run_both(577, 41, 0x5555_6666_7777_8888);
        run_both(1000, 29, 0x0F0F_F0F0_0F0F_F0F0);
        run_both(1024, 33, 0xABCD_EF01_2345_6789);
    }

    /// Row counts below the 4x unroll, empty input, and 1-bit / near-boundary widths.
    #[test]
    fn scan_agree_matches_scalar_on_small_inputs() {
        run_both(1, 9, 0x0000_0000_0000_0001);
        run_both(63, 5, 0x7777_0000_7777_0000);
        run_both(65, 3, 0x3141_5926_5358_9793);
        run_both(128, 0, 0x2718_2818_2845_9045);
        run_both(128, 1, 0x1618_0339_8874_9894);
        run_both(192, 2, 0x1414_2135_6237_3095);
    }

    /// A row equal to the query agrees on all `dim` valid bits; its complement
    /// (restricted to valid bits) agrees on none.
    #[test]
    fn identical_and_complement_rows_have_extreme_agreement() {
        for &dim in &[1usize, 63, 64, 65, 128, 200, 577, 1024] {
            let (_, q, mask) = random_packed(dim, 0, dim as u64);
            let n_words = q.len();
            let mut complement: Vec<u64> = q.iter().map(|&w| !w).collect();
            complement[n_words - 1] &= mask;
            let n = 6;
            let packed: Vec<u64> = (0..n)
                .flat_map(|i| if i % 2 == 0 { q.clone() } else { complement.clone() })
                .collect();
            let expected: Vec<u32> = (0..n)
                .map(|i| if i % 2 == 0 { dim as u32 } else { 0 })
                .collect();

            let mut out = vec![u32::MAX; n];
            scan_scalar(&packed, n_words, n, &q, mask, &mut out);
            assert_eq!(out, expected, "scan_scalar at dim={dim}");

            let mut out = vec![u32::MAX; n];
            scan(&packed, n_words, n, &q, mask, &mut out);
            assert_eq!(out, expected, "scan at dim={dim}");
        }
    }
}