pub mod scalar;
mod common;
#[cfg(target_arch = "x86_64")]
pub mod avx2;
#[cfg(target_arch = "x86_64")]
pub mod sse42;
#[cfg(target_arch = "aarch64")]
pub mod neon;
use std::fmt;
/// Outcome of [`search`]: `Some(offset)` with the byte offset in the haystack where the needle
/// was found, or `None` when it does not occur.
///
/// An empty needle always yields `Some(0)`.
/// NOTE(review): presumably the offset of the first occurrence; the backends are not visible here.
pub type SearchResult = Option<usize>;

/// SIMD backend selected at runtime by [`detect_platform`].
///
/// Derives `Hash` in addition to `Eq` so platforms can be used as `HashMap`/`HashSet` keys
/// (e.g. for per-backend statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimdPlatform {
    /// x86_64 AVX2 backend (the `avx2` module).
    Avx2,
    /// x86_64 SSE4.2 backend (the `sse42` module).
    Sse42,
    /// aarch64 NEON backend (the `neon` module).
    Neon,
    /// Portable fallback used when no SIMD backend is detected (the `scalar` module).
    Scalar,
}
/// Human-readable backend name: `AVX2`, `SSE4.2`, `NEON` or `Scalar`.
///
/// Formatting goes through [`fmt::Formatter::pad`], so width, fill, alignment and precision
/// flags are honoured (`format!("{:<8}", platform)` pads to eight columns). Without flags the
/// output is exactly the bare name.
impl fmt::Display for SimdPlatform {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            SimdPlatform::Avx2 => "AVX2",
            SimdPlatform::Sse42 => "SSE4.2",
            SimdPlatform::Neon => "NEON",
            SimdPlatform::Scalar => "Scalar",
        };
        // `pad` applies the caller's formatting flags; `write!`/`write_str` would ignore them.
        f.pad(name)
    }
}
/// Picks the best SIMD backend available on the running CPU.
///
/// Preference order is AVX2, then SSE4.2 (x86_64 only), then NEON (aarch64 only). It falls back
/// to [`SimdPlatform::Scalar`] when none of these is detected, or when the target architecture
/// has no SIMD backend in this module. The choice is reported via `log::debug!` on every call.
#[must_use]
pub fn detect_platform() -> SimdPlatform {
    #[cfg(target_arch = "x86_64")]
    {
        // Probe lazily: SSE4.2 is only queried when AVX2 is absent.
        let x86_backend = if is_x86_feature_detected!("avx2") {
            log::debug!("SIMD platform: AVX2");
            Some(SimdPlatform::Avx2)
        } else if is_x86_feature_detected!("sse4.2") {
            log::debug!("SIMD platform: SSE4.2");
            Some(SimdPlatform::Sse42)
        } else {
            None
        };
        if let Some(backend) = x86_backend {
            return backend;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            log::debug!("SIMD platform: NEON");
            return SimdPlatform::Neon;
        }
    }
    // Nothing usable was found (or this architecture has no SIMD path here).
    log::debug!("SIMD platform: Scalar (no SIMD available)");
    SimdPlatform::Scalar
}
/// Looks for `needle` inside `haystack` using the fastest backend the CPU supports.
///
/// Two trivial cases are answered here without touching any backend:
/// - an empty `needle` matches at offset 0;
/// - a `needle` longer than `haystack` yields `None`.
///
/// Otherwise the work goes to the AVX2 or SSE4.2 backend (x86_64) or the NEON backend
/// (aarch64) when detected, and to [`scalar::search`] as the fallback.
#[must_use]
pub fn search(haystack: &[u8], needle: &[u8]) -> SearchResult {
    if needle.is_empty() {
        return Some(0);
    } else if needle.len() > haystack.len() {
        return None;
    }
    // From here on: 1 <= needle.len() <= haystack.len().
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: reached only after runtime detection confirmed AVX2. The backend presumably
            // requires that feature (its definition is not visible here), and the guards above
            // ensure a non-empty needle that fits in the haystack.
            return unsafe { avx2::search(haystack, needle) };
        } else if is_x86_feature_detected!("sse4.2") {
            // SAFETY: reached only after runtime detection confirmed SSE4.2. Same guards as above.
            return unsafe { sse42::search(haystack, needle) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: reached only after runtime detection confirmed NEON. Same guards as above.
            return unsafe { neon::search(haystack, needle) };
        }
    }
    scalar::search(haystack, needle)
}
/// Splits `text` into trigrams using the best SIMD backend available.
///
/// An input shorter than three **bytes** (including `""`) is returned as a single-element vector
/// holding a copy of the whole input. Longer inputs go to the AVX2/SSE4.2 (x86_64) or NEON
/// (aarch64) backend when detected, otherwise to [`scalar::extract_trigrams`]. The module's tests
/// expect every backend to yield the same trigrams as the scalar one, order aside.
///
/// NOTE(review): the cut-off counts bytes, not `char`s. A short string made of multi-byte UTF-8
/// (e.g. "日", one char but three bytes) is therefore still handed to a backend. Confirm this is
/// intended.
#[must_use]
pub fn extract_trigrams(text: &str) -> Vec<String> {
    let shorter_than_trigram = text.len() < 3;
    if shorter_than_trigram {
        return vec![String::from(text)];
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: reached only after runtime detection confirmed AVX2, which the backend
            // presumably requires (its definition is not visible here).
            return unsafe { avx2::extract_trigrams(text) };
        } else if is_x86_feature_detected!("sse4.2") {
            // SAFETY: reached only after runtime detection confirmed SSE4.2.
            return unsafe { sse42::extract_trigrams(text) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: reached only after runtime detection confirmed NEON.
            return unsafe { neon::extract_trigrams(text) };
        }
    }
    scalar::extract_trigrams(text)
}
/// ASCII-lowercases `text` with the fastest backend the running CPU supports.
///
/// Backend preference is AVX2, then SSE4.2 on x86_64, then NEON on aarch64. Otherwise
/// [`scalar::to_lowercase_ascii`] is used. Unlike [`search`] and [`extract_trigrams`], this
/// dispatcher has no short-input special case: even `""` is passed to the selected backend.
///
/// NOTE(review): presumably non-ASCII bytes pass through unchanged, as the name suggests. The
/// backends themselves are not visible in this file.
#[must_use]
pub fn to_lowercase_ascii(text: &str) -> String {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: reached only after runtime detection confirmed AVX2, which the backend
            // presumably requires (its definition is not visible here).
            return unsafe { avx2::to_lowercase_ascii(text) };
        } else if is_x86_feature_detected!("sse4.2") {
            // SAFETY: reached only after runtime detection confirmed SSE4.2.
            return unsafe { sse42::to_lowercase_ascii(text) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: reached only after runtime detection confirmed NEON.
            return unsafe { neon::to_lowercase_ascii(text) };
        }
    }
    scalar::to_lowercase_ascii(text)
}
// Unit tests for the dispatch layer. Inputs longer than one SIMD vector are used so the
// vectorised loops (not just the early-return guards) are compared against the scalar backend.
#[cfg(test)]
mod tests {
    use super::*;

    /// `detect_platform` must agree with the CPU feature probes the dispatchers use.
    #[test]
    fn test_detect_platform() {
        let platform = detect_platform();
        #[cfg(target_arch = "x86_64")]
        {
            let expected = if is_x86_feature_detected!("avx2") {
                SimdPlatform::Avx2
            } else if is_x86_feature_detected!("sse4.2") {
                SimdPlatform::Sse42
            } else {
                SimdPlatform::Scalar
            };
            assert_eq!(platform, expected);
        }
        #[cfg(target_arch = "aarch64")]
        {
            let expected = if std::arch::is_aarch64_feature_detected!("neon") {
                SimdPlatform::Neon
            } else {
                SimdPlatform::Scalar
            };
            assert_eq!(platform, expected);
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            assert_eq!(platform, SimdPlatform::Scalar);
        }
    }

    #[test]
    fn test_platform_display() {
        assert_eq!(SimdPlatform::Avx2.to_string(), "AVX2");
        assert_eq!(SimdPlatform::Sse42.to_string(), "SSE4.2");
        assert_eq!(SimdPlatform::Neon.to_string(), "NEON");
        assert_eq!(SimdPlatform::Scalar.to_string(), "Scalar");
    }

    #[test]
    fn test_search_empty_needle() {
        let haystack = b"hello";
        let needle = b"";
        assert_eq!(search(haystack, needle), Some(0));
    }

    #[test]
    fn test_search_needle_too_long() {
        let haystack = b"hi";
        let needle = b"hello";
        assert_eq!(search(haystack, needle), None);
    }

    /// Inputs with a single possible match, so any correct backend must report that offset.
    #[test]
    fn test_search_unique_match() {
        assert_eq!(search(b"hello world", b"world"), Some(6));
        assert_eq!(search(b"hello world", b"xyz"), None);
        assert_eq!(search(b"hello", b"hello"), Some(0));
        // Well past one 32-byte vector, with the only match at the very end.
        let long = format!("{}abc", "ab".repeat(100));
        assert_eq!(search(long.as_bytes(), b"abc"), Some(200));
        assert_eq!(search(long.as_bytes(), b"abd"), None);
    }

    #[test]
    fn test_search_matches_scalar() {
        let haystack: &[u8] =
            b"the quick brown fox jumps over the lazy dog; the quick brown fox jumps again!";
        // Every needle is non-empty and fits in the haystack, so the scalar backend is called
        // under the same preconditions the dispatcher guarantees.
        let needles: [&[u8]; 9] = [
            b"t", b"!", b"the", b"fox", b"dog;", b"again!", b"thx", b"cat", haystack,
        ];
        for &needle in &needles {
            assert_eq!(
                search(haystack, needle),
                scalar::search(haystack, needle),
                "SIMD ≡ scalar mismatch for needle: {:?}",
                String::from_utf8_lossy(needle)
            );
        }
    }

    #[test]
    fn test_extract_trigrams_short_string() {
        assert_eq!(extract_trigrams("ab"), vec!["ab"]);
        assert_eq!(extract_trigrams(""), vec![""]);
    }

    #[test]
    fn test_to_lowercase_ascii_empty() {
        assert_eq!(to_lowercase_ascii(""), "");
    }

    #[test]
    fn test_to_lowercase_ascii_matches_scalar_and_std() {
        let ascii_inputs = [
            "HELLO",
            "createCompilerHost",
            // '@', '[', '`', '{' sit just outside A-Z / a-z and catch off-by-one range checks.
            "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_abcdefghijklmnopqrstuvwxyz@[`{",
            "MiXeD CaSe WiTh SpAcEs AnD PuNcTuAtIoN!!! 1234567890",
        ];
        for input in &ascii_inputs {
            let dispatched = to_lowercase_ascii(input);
            assert_eq!(
                dispatched,
                scalar::to_lowercase_ascii(input),
                "SIMD ≡ scalar mismatch for ASCII input: {input}"
            );
            assert_eq!(
                dispatched,
                input.to_ascii_lowercase(),
                "mismatch with str::to_ascii_lowercase for: {input}"
            );
        }
        let non_ascii_inputs = ["ÉCOLE", "Straße", "日本語ABC"];
        for input in &non_ascii_inputs {
            assert_eq!(
                to_lowercase_ascii(input),
                scalar::to_lowercase_ascii(input),
                "SIMD ≡ scalar mismatch for non-ASCII input: {input}"
            );
        }
    }

    #[test]
    fn test_extract_trigrams_ascii_matches_scalar() {
        let inputs = [
            "hello",
            "abc",
            "abcdefghijklmnopqrstuvwxyz0123456789",
            "createCompilerHost",
            "aaaa",
            "HELLO_WORLD",
        ];
        for input in &inputs {
            let mut dispatched = extract_trigrams(input);
            let mut scalar_result = scalar::extract_trigrams(input);
            dispatched.sort();
            scalar_result.sort();
            assert_eq!(
                dispatched, scalar_result,
                "SIMD ≡ scalar mismatch for ASCII input: {input}"
            );
        }
    }

    #[test]
    fn test_extract_trigrams_non_ascii_matches_scalar() {
        let inputs = ["héllo", "日本語", "café", "naïve", "über"];
        for input in &inputs {
            let mut dispatched = extract_trigrams(input);
            let mut scalar_result = scalar::extract_trigrams(input);
            dispatched.sort();
            scalar_result.sort();
            assert_eq!(
                dispatched, scalar_result,
                "SIMD ≡ scalar mismatch for non-ASCII input: {input}"
            );
        }
    }
}