use std::collections::HashSet;
use std::sync::OnceLock;
/// Cached outcome of the runtime AVX2 probe, filled in by the first call to
/// [`has_avx2`].
static HAS_AVX2: OnceLock<bool> = OnceLock::new();

/// Reports whether the running CPU supports AVX2.
///
/// The probe runs at most once per process; later calls read the cached
/// value from `HAS_AVX2`. On targets other than x86_64 this is always `false`.
#[inline]
fn has_avx2() -> bool {
    // One probe implementation is compiled per target architecture.
    #[cfg(target_arch = "x86_64")]
    fn detect() -> bool {
        is_x86_feature_detected!("avx2")
    }

    #[cfg(not(target_arch = "x86_64"))]
    fn detect() -> bool {
        false
    }

    *HAS_AVX2.get_or_init(detect)
}
/// Returns the ids from `ids` that also appear in `allowed`.
///
/// The result keeps the order of `ids`, including any repeated ids.
/// `allowed` is loaded into a `HashSet` once, so each membership check is a
/// hash lookup rather than a linear scan.
pub fn filter_allowed_scalar(ids: &[u64], allowed: &[u64]) -> Vec<u64> {
    let permitted = allowed.iter().cloned().collect::<HashSet<u64>>();
    let mut kept = Vec::new();
    for &candidate in ids {
        if permitted.contains(&candidate) {
            kept.push(candidate);
        }
    }
    kept
}
/// Returns the ids from `ids` that do **not** appear in `denied`.
///
/// The result keeps the order of `ids`, including any repeated ids.
/// The deny list is hashed once up front; the input is then copied and
/// pruned in place.
pub fn filter_denied_scalar(ids: &[u64], denied: &[u64]) -> Vec<u64> {
    let blocked = denied.iter().cloned().collect::<HashSet<u64>>();
    let mut survivors = ids.to_vec();
    survivors.retain(|candidate| !blocked.contains(candidate));
    survivors
}
/// Filters `ids` against `filter_set` inside an AVX2-enabled function.
///
/// If `include` is `true`, the result holds the ids present in `filter_set`.
/// If `include` is `false`, it holds the ids absent from it. The order of
/// `ids` is preserved, duplicates included. The output is identical to the
/// scalar filters.
///
/// Membership is tested with a `HashSet` built from `filter_set`. No explicit
/// SIMD intrinsics are used: an earlier unaligned 256-bit load was computed
/// and then discarded, so it was removed. Any vectorisation is left to the
/// compiler under `#[target_feature(enable = "avx2")]`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` first. Calling a
/// `#[target_feature]` function on hardware without that feature is
/// undefined behaviour.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn filter_batch_avx2(ids: &[u64], filter_set: &[u64], include: bool) -> Vec<u64> {
    let filter_set_hash: HashSet<u64> = filter_set.iter().copied().collect();
    // Pre-size for the worst case (every id kept) to avoid regrowth.
    let mut result = Vec::with_capacity(ids.len());
    // Keep an id exactly when its membership matches the requested mode:
    // present when including, absent when excluding.
    result.extend(
        ids.iter()
            .copied()
            .filter(|id| filter_set_hash.contains(id) == include),
    );
    result
}
/// Filters `ids` against `filter_set`, preserving the order of `ids`.
///
/// If `include` is `true`, only ids present in `filter_set` are returned.
/// If `include` is `false`, only ids absent from it are returned.
///
/// On x86_64 with AVX2 support and at least 32 ids, the work goes to
/// `filter_batch_avx2`. Every other case uses `filter_allowed_scalar` or
/// `filter_denied_scalar`.
pub fn filter_batch(ids: &[u64], filter_set: &[u64], include: bool) -> Vec<u64> {
    #[cfg(target_arch = "x86_64")]
    {
        // Batches shorter than this always take the scalar path.
        const AVX2_MIN_BATCH: usize = 32;
        if has_avx2() && ids.len() >= AVX2_MIN_BATCH {
            // SAFETY: `has_avx2()` has just confirmed at runtime that the CPU
            // supports AVX2. This satisfies the `#[target_feature(enable =
            // "avx2")]` precondition of `filter_batch_avx2`.
            return unsafe { filter_batch_avx2(ids, filter_set, include) };
        }
    }

    // Shared scalar fallback for every architecture.
    match include {
        true => filter_allowed_scalar(ids, filter_set),
        false => filter_denied_scalar(ids, filter_set),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filter_allowed_basic() {
        let ids = vec![1, 2, 3, 4, 5];
        let allowed = vec![2, 3, 4];
        let filtered = filter_allowed_scalar(&ids, &allowed);
        assert_eq!(filtered, vec![2, 3, 4]);
    }

    #[test]
    fn test_filter_denied_basic() {
        let ids = vec![1, 2, 3, 4, 5];
        let denied = vec![2, 4];
        let filtered = filter_denied_scalar(&ids, &denied);
        assert_eq!(filtered, vec![1, 3, 5]);
    }

    #[test]
    fn test_filter_empty_ids() {
        let ids: Vec<u64> = vec![];
        let allowed = vec![1, 2, 3];
        let filtered = filter_allowed_scalar(&ids, &allowed);
        assert!(filtered.is_empty());
        let filtered = filter_denied_scalar(&ids, &allowed);
        assert!(filtered.is_empty());
    }

    #[test]
    fn test_filter_empty_filter_set() {
        let ids = vec![1, 2, 3, 4, 5];
        let allowed: Vec<u64> = vec![];
        let filtered = filter_allowed_scalar(&ids, &allowed);
        assert!(filtered.is_empty());
        let filtered = filter_denied_scalar(&ids, &allowed);
        assert_eq!(filtered, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_filter_large_batch() {
        let ids: Vec<u64> = (1..=1000).collect();
        let allowed: Vec<u64> = (1..=500).filter(|x| x % 2 == 0).collect();
        let filtered = filter_allowed_scalar(&ids, &allowed);
        let allowed_set: HashSet<u64> = allowed.iter().copied().collect();
        for &id in &filtered {
            assert!(
                allowed_set.contains(&id),
                "ID {} should be in allowed set",
                id
            );
        }
        assert_eq!(filtered.len(), 250);
    }

    #[test]
    fn test_filter_batch_include() {
        let ids = vec![1, 2, 3, 4, 5];
        let allowed = vec![2, 3, 4];
        let filtered = filter_batch(&ids, &allowed, true);
        assert_eq!(filtered, vec![2, 3, 4]);
    }

    #[test]
    fn test_filter_batch_exclude() {
        let ids = vec![1, 2, 3, 4, 5];
        let denied = vec![2, 4];
        let filtered = filter_batch(&ids, &denied, false);
        assert_eq!(filtered, vec![1, 3, 5]);
    }

    #[test]
    fn test_filter_batch_small_set() {
        let ids = vec![1, 2, 3];
        let allowed = vec![2];
        let filtered = filter_batch(&ids, &allowed, true);
        assert_eq!(filtered, vec![2]);
    }

    #[test]
    fn test_filter_all_allowed() {
        let ids = vec![1, 2, 3, 4, 5];
        let allowed = vec![1, 2, 3, 4, 5];
        let filtered = filter_allowed_scalar(&ids, &allowed);
        assert_eq!(filtered, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_filter_all_denied() {
        let ids = vec![1, 2, 3, 4, 5];
        let denied = vec![1, 2, 3, 4, 5];
        let filtered = filter_denied_scalar(&ids, &denied);
        assert!(filtered.is_empty());
    }

    #[test]
    fn test_filter_no_match() {
        let ids = vec![1, 2, 3];
        let allowed = vec![4, 5, 6];
        let filtered = filter_allowed_scalar(&ids, &allowed);
        assert!(filtered.is_empty());
    }

    /// Uses a batch above `filter_batch`'s 32-id threshold. On AVX2-capable
    /// x86_64 hosts this exercises the AVX2 dispatch branch; the previous
    /// 5-id input always took the scalar path.
    #[test]
    fn test_avx2_availability() {
        let _has_it = has_avx2();
        let ids: Vec<u64> = (1..=40).collect();
        let allowed = vec![2, 3, 4, 39];
        let filtered = filter_batch(&ids, &allowed, true);
        assert_eq!(filtered, vec![2, 3, 4, 39]);
    }

    /// `filter_batch` must agree with the scalar filters on both sides of the
    /// dispatch threshold and for chunk-remainder lengths.
    #[test]
    fn test_filter_batch_matches_scalar_around_threshold() {
        let filter_set: Vec<u64> = (0..100).filter(|x| x % 3 == 0).collect();
        for &len in &[0usize, 1, 4, 31, 32, 33, 35, 64, 67] {
            let ids: Vec<u64> = (0..len as u64).collect();
            assert_eq!(
                filter_batch(&ids, &filter_set, true),
                filter_allowed_scalar(&ids, &filter_set),
                "include mismatch at len {}",
                len
            );
            assert_eq!(
                filter_batch(&ids, &filter_set, false),
                filter_denied_scalar(&ids, &filter_set),
                "exclude mismatch at len {}",
                len
            );
        }
    }

    /// Calls the AVX2 kernel directly, but only when the CPU supports it.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_avx2_kernel_matches_scalar() {
        if !has_avx2() {
            return;
        }
        // Values wrap modulo 23, so the input contains duplicates.
        let ids: Vec<u64> = (0..70u64).map(|x| x * 7 % 23).collect();
        let set = vec![0, 5, 11, 22, u64::MAX];
        // SAFETY: AVX2 support was confirmed by `has_avx2()` above.
        let included = unsafe { filter_batch_avx2(&ids, &set, true) };
        assert_eq!(included, filter_allowed_scalar(&ids, &set));
        // SAFETY: as above.
        let excluded = unsafe { filter_batch_avx2(&ids, &set, false) };
        assert_eq!(excluded, filter_denied_scalar(&ids, &set));
    }
}