use std::simd::{cmp::SimdPartialOrd, num::SimdUint, Simd};
use multiversion::multiversion;
pub fn string_to_bitmask(s: &[u8]) -> u64 {
let mut mask: u64 = 0;
for c in s {
let c = c.to_ascii_uppercase();
if (33..=90).contains(&c) {
mask |= 1 << (c - 33);
}
}
mask
}
const LANES: usize = 8;
#[multiversion(targets(
// x86-64-v4 without lahfsahf
"x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl+avx+avx2+bmi1+bmi2+cmpxchg16b+f16c+fma+fxsr+lzcnt+movbe+popcnt+sse+sse2+sse3+sse4.1+sse4.2+ssse3+xsave",
// x86-64-v3 without lahfsahf
"x86_64+avx+avx2+bmi1+bmi2+cmpxchg16b+f16c+fma+fxsr+lzcnt+movbe+popcnt+sse+sse2+sse3+sse4.1+sse4.2+ssse3+xsave",
// x86-64-v2 without lahfsahf
"x86_64+cmpxchg16b+fxsr+popcnt+sse+sse2+sse3+sse4.1+sse4.2+ssse3",
))]
pub fn string_to_bitmask_simd(s: &[u8]) -> u64 {
let mut mask: u64 = 0;
let zero = Simd::splat(0);
let zero_wide = Simd::splat(0);
let one = Simd::splat(1);
let to_upperacse = Simd::splat(0x20);
let mut i = 0;
while i + LANES <= s.len() {
let simd_chunk = Simd::<u8, LANES>::from_slice(&s[i..i + LANES]);
let is_lower =
simd_chunk.simd_ge(Simd::splat(b'a')) & simd_chunk.simd_le(Simd::splat(b'z'));
let simd_upper = simd_chunk - is_lower.select(to_upperacse, zero);
let in_range =
simd_upper.simd_ge(Simd::splat(33u8)) & simd_upper.simd_le(Simd::splat(90u8));
let indices = in_range.cast::<i64>().select(
one << (simd_upper - Simd::splat(33u8)).cast::<u64>(),
zero_wide,
);
mask |= indices.reduce_or();
i += LANES;
}
for &c in s[i..s.len()].iter() {
let c = c.to_ascii_uppercase();
if (33..=90).contains(&c) {
mask |= 1 << (c - 33);
}
}
mask
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_case_insensitive() {
assert_eq!(
string_to_bitmask("ABC".as_bytes()),
string_to_bitmask("abc".as_bytes())
);
}
#[test]
fn test_letters() {
assert_eq!(
string_to_bitmask("abC".as_bytes()),
0b0000000000000000000000000000011100000000000000000000000000000000
);
}
#[test]
fn test_numbers() {
assert_eq!(
string_to_bitmask("123".as_bytes()),
0b00000000000000000000000000000000000000000001110000000000000000
);
}
#[test]
fn test_symbols() {
assert_eq!(
string_to_bitmask("!\"#$%&'()*+,-./".as_bytes()),
0b00000000000000000000000000000000000000000000000111111111111111
);
}
#[test]
fn test_simd() {
assert_eq!(
string_to_bitmask_simd("abc".as_bytes()),
string_to_bitmask("abc".as_bytes())
);
}
}