/// Returns a copy of `s` in which every ASCII uppercase letter (`A`–`Z`) is
/// replaced by its lowercase counterpart. All other characters, including
/// non-ASCII ones, are copied unchanged.
#[inline]
pub fn to_ascii_lowercase(s: &str) -> String {
    let mut owned = String::from(s);
    // SAFETY: `to_ascii_lowercase_inplace` only rewrites bytes in b'A'..=b'Z'
    // into bytes in b'a'..=b'z'. Both ranges are single-byte ASCII, and bytes
    // belonging to multi-byte UTF-8 sequences (all >= 0x80) are never touched,
    // so `owned` is still valid UTF-8 afterwards.
    to_ascii_lowercase_inplace(unsafe { owned.as_bytes_mut() });
    owned
}
/// Lowercases every ASCII uppercase letter (`A`–`Z`) in `bytes` in place,
/// leaving every other byte value untouched. `bytes` need not be valid UTF-8.
///
/// Chooses a kernel at runtime:
/// - AVX-512BW for inputs of at least 64 bytes (x86_64).
/// - AVX2 for inputs of at least 32 bytes (x86_64).
/// - NEON for inputs of at least 16 bytes (aarch64).
/// - A scalar loop otherwise.
///
/// Every kernel applies the same `A`–`Z` → `a`–`z` mapping.
#[inline]
pub fn to_ascii_lowercase_inplace(bytes: &mut [u8]) {
    if bytes.is_empty() {
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        // The length bound is checked first so short inputs skip the
        // feature-detection lookup entirely.
        if bytes.len() >= 64
            && is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
        {
            // SAFETY: `lowercase_avx512` is compiled with both `avx512f` and
            // `avx512bw`, and both were just confirmed at runtime.
            unsafe { lowercase_avx512(bytes) };
            return;
        }
        if bytes.len() >= 32 && is_x86_feature_detected!("avx2") {
            // SAFETY: the CPU reports AVX2, the only feature `lowercase_avx2` enables.
            unsafe { lowercase_avx2(bytes) };
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if bytes.len() >= 16 && std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: the CPU reports NEON, the only feature `lowercase_neon` enables.
            unsafe { lowercase_neon(bytes) };
            return;
        }
    }
    lowercase_scalar(bytes);
}
/// Portable fallback: maps each byte in b'A'..=b'Z' to b'a'..=b'z' and leaves
/// all other bytes as they are.
#[inline]
fn lowercase_scalar(bytes: &mut [u8]) {
    // For an ASCII uppercase letter, setting bit 5 is the same as adding 32.
    // The std routine does exactly that and ignores every other byte.
    bytes.make_ascii_lowercase();
}
/// AVX2 kernel: lowercases `A`–`Z` 32 bytes at a time, then finishes any
/// tail shorter than 32 bytes with the scalar loop.
///
/// # Safety
///
/// The CPU must support `avx2`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn lowercase_avx2(bytes: &mut [u8]) {
    use std::arch::x86_64::*;
    // `_mm256_cmpgt_epi8` is a strict *signed* comparison, so the inclusive
    // range 'A'..='Z' is tested as ('A' - 1) < b && b < ('Z' + 1). Bytes
    // >= 0x80 are negative as i8 and thus never satisfy the lower bound.
    let below_a = _mm256_set1_epi8(b'A' as i8 - 1);
    let above_z = _mm256_set1_epi8(b'Z' as i8 + 1);
    let flip = _mm256_set1_epi8(32);
    let mut blocks = bytes.chunks_exact_mut(32);
    for block in &mut blocks {
        let p = block.as_mut_ptr().cast::<__m256i>();
        let v = _mm256_loadu_si256(p);
        let mask = _mm256_and_si256(_mm256_cmpgt_epi8(v, below_a), _mm256_cmpgt_epi8(above_z, v));
        _mm256_storeu_si256(p, _mm256_add_epi8(v, _mm256_and_si256(mask, flip)));
    }
    lowercase_scalar(blocks.into_remainder());
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn lowercase_avx512(bytes: &mut [u8]) {
use std::arch::x86_64::*;
let upper_a = _mm512_set1_epi8(b'A' as i8);
let upper_z = _mm512_set1_epi8(b'Z' as i8);
let case_diff = _mm512_set1_epi8(32);
let chunks = bytes.len() / 64;
let ptr = bytes.as_mut_ptr();
for i in 0..chunks {
let chunk_ptr = ptr.add(i * 64) as *mut i32;
let data = _mm512_loadu_si512(chunk_ptr as *const i32);
let ge_a = _mm512_cmpge_epi8_mask(data, upper_a);
let le_z = _mm512_cmple_epi8_mask(data, upper_z);
let is_upper = ge_a & le_z;
let result = _mm512_mask_add_epi8(data, is_upper, data, case_diff);
_mm512_storeu_si512(chunk_ptr, result);
}
let remainder = &mut bytes[chunks * 64..];
if remainder.len() >= 32 && is_x86_feature_detected!("avx2") {
lowercase_avx2(remainder);
} else {
lowercase_scalar(remainder);
}
}
/// NEON kernel: lowercases `A`–`Z` 16 bytes at a time, then finishes any
/// tail shorter than 16 bytes with the scalar loop.
///
/// # Safety
///
/// The CPU must support `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn lowercase_neon(bytes: &mut [u8]) {
    use std::arch::aarch64::*;
    let lo = vdupq_n_u8(b'A');
    let hi = vdupq_n_u8(b'Z');
    let delta = vdupq_n_u8(32);
    let mut lanes = bytes.chunks_exact_mut(16);
    for lane in &mut lanes {
        let p = lane.as_mut_ptr();
        let v = vld1q_u8(p);
        // Unsigned compares: bytes >= 0x80 are simply above 'Z'.
        let in_range = vandq_u8(vcgeq_u8(v, lo), vcleq_u8(v, hi));
        vst1q_u8(p, vaddq_u8(v, vandq_u8(in_range, delta)));
    }
    lowercase_scalar(lanes.into_remainder());
}
/// Returns a copy of `s` in which every ASCII lowercase letter (`a`–`z`) is
/// replaced by its uppercase counterpart. All other characters, including
/// non-ASCII ones, are copied unchanged.
#[inline]
pub fn to_ascii_uppercase(s: &str) -> String {
    let mut owned = String::from(s);
    // SAFETY: `to_ascii_uppercase_inplace` only rewrites bytes in b'a'..=b'z'
    // into bytes in b'A'..=b'Z'. Both ranges are single-byte ASCII, and bytes
    // belonging to multi-byte UTF-8 sequences (all >= 0x80) are never touched,
    // so `owned` is still valid UTF-8 afterwards.
    to_ascii_uppercase_inplace(unsafe { owned.as_bytes_mut() });
    owned
}
/// Uppercases every ASCII lowercase letter (`a`–`z`) in `bytes` in place,
/// leaving every other byte value untouched. `bytes` need not be valid UTF-8.
///
/// Chooses a kernel at runtime:
/// - AVX-512BW for inputs of at least 64 bytes (x86_64).
/// - AVX2 for inputs of at least 32 bytes (x86_64).
/// - NEON for inputs of at least 16 bytes (aarch64).
/// - A scalar loop otherwise.
///
/// Every kernel applies the same `a`–`z` → `A`–`Z` mapping.
#[inline]
pub fn to_ascii_uppercase_inplace(bytes: &mut [u8]) {
    if bytes.is_empty() {
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        // The length bound is checked first so short inputs skip the
        // feature-detection lookup entirely.
        if bytes.len() >= 64
            && is_x86_feature_detected!("avx512f")
            && is_x86_feature_detected!("avx512bw")
        {
            // SAFETY: `uppercase_avx512` is compiled with both `avx512f` and
            // `avx512bw`, and both were just confirmed at runtime.
            unsafe { uppercase_avx512(bytes) };
            return;
        }
        if bytes.len() >= 32 && is_x86_feature_detected!("avx2") {
            // SAFETY: the CPU reports AVX2, the only feature `uppercase_avx2` enables.
            unsafe { uppercase_avx2(bytes) };
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if bytes.len() >= 16 && std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: the CPU reports NEON, the only feature `uppercase_neon` enables.
            unsafe { uppercase_neon(bytes) };
            return;
        }
    }
    uppercase_scalar(bytes);
}
/// Portable fallback: maps each byte in b'a'..=b'z' to b'A'..=b'Z' and leaves
/// all other bytes as they are.
#[inline]
fn uppercase_scalar(bytes: &mut [u8]) {
    // For an ASCII lowercase letter, clearing bit 5 is the same as
    // subtracting 32. The std routine does exactly that and ignores every
    // other byte.
    bytes.make_ascii_uppercase();
}
/// AVX2 kernel: uppercases `a`–`z` 32 bytes at a time, then finishes any
/// tail shorter than 32 bytes with the scalar loop.
///
/// # Safety
///
/// The CPU must support `avx2`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn uppercase_avx2(bytes: &mut [u8]) {
    use std::arch::x86_64::*;
    // `_mm256_cmpgt_epi8` is a strict *signed* comparison, so the inclusive
    // range 'a'..='z' is tested as ('a' - 1) < b && b < ('z' + 1). Bytes
    // >= 0x80 are negative as i8 and thus never satisfy the lower bound.
    let below_a = _mm256_set1_epi8(b'a' as i8 - 1);
    let above_z = _mm256_set1_epi8(b'z' as i8 + 1);
    let flip = _mm256_set1_epi8(32);
    let mut blocks = bytes.chunks_exact_mut(32);
    for block in &mut blocks {
        let p = block.as_mut_ptr().cast::<__m256i>();
        let v = _mm256_loadu_si256(p);
        let mask = _mm256_and_si256(_mm256_cmpgt_epi8(v, below_a), _mm256_cmpgt_epi8(above_z, v));
        _mm256_storeu_si256(p, _mm256_sub_epi8(v, _mm256_and_si256(mask, flip)));
    }
    uppercase_scalar(blocks.into_remainder());
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
unsafe fn uppercase_avx512(bytes: &mut [u8]) {
use std::arch::x86_64::*;
let lower_a = _mm512_set1_epi8(b'a' as i8);
let lower_z = _mm512_set1_epi8(b'z' as i8);
let case_diff = _mm512_set1_epi8(32);
let chunks = bytes.len() / 64;
let ptr = bytes.as_mut_ptr();
for i in 0..chunks {
let chunk_ptr = ptr.add(i * 64) as *mut i32;
let data = _mm512_loadu_si512(chunk_ptr as *const i32);
let ge_a = _mm512_cmpge_epi8_mask(data, lower_a);
let le_z = _mm512_cmple_epi8_mask(data, lower_z);
let is_lower = ge_a & le_z;
let result = _mm512_mask_sub_epi8(data, is_lower, data, case_diff);
_mm512_storeu_si512(chunk_ptr, result);
}
let remainder = &mut bytes[chunks * 64..];
if remainder.len() >= 32 && is_x86_feature_detected!("avx2") {
uppercase_avx2(remainder);
} else {
uppercase_scalar(remainder);
}
}
/// NEON kernel: uppercases `a`–`z` 16 bytes at a time, then finishes any
/// tail shorter than 16 bytes with the scalar loop.
///
/// # Safety
///
/// The CPU must support `neon`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn uppercase_neon(bytes: &mut [u8]) {
    use std::arch::aarch64::*;
    let lo = vdupq_n_u8(b'a');
    let hi = vdupq_n_u8(b'z');
    let delta = vdupq_n_u8(32);
    let mut lanes = bytes.chunks_exact_mut(16);
    for lane in &mut lanes {
        let p = lane.as_mut_ptr();
        let v = vld1q_u8(p);
        // Unsigned compares: bytes >= 0x80 are simply above 'z'.
        let in_range = vandq_u8(vcgeq_u8(v, lo), vcleq_u8(v, hi));
        vst1q_u8(p, vsubq_u8(v, vandq_u8(in_range, delta)));
    }
    uppercase_scalar(lanes.into_remainder());
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lengths that straddle every SIMD width (16, 32 and 64 bytes) and the
    /// AVX-512 -> AVX2 -> scalar remainder hand-offs.
    const LENGTHS: [usize; 15] = [0, 1, 15, 16, 17, 31, 32, 33, 63, 64, 65, 96, 100, 127, 1000];

    #[test]
    fn test_to_ascii_lowercase_empty() {
        assert_eq!(to_ascii_lowercase(""), "");
    }

    #[test]
    fn test_to_ascii_lowercase_basic() {
        assert_eq!(to_ascii_lowercase("HELLO"), "hello");
        assert_eq!(to_ascii_lowercase("Hello World"), "hello world");
        assert_eq!(to_ascii_lowercase("already lowercase"), "already lowercase");
    }

    #[test]
    fn test_to_ascii_lowercase_mixed() {
        assert_eq!(to_ascii_lowercase("HeLLo WoRLD 123!"), "hello world 123!");
    }

    #[test]
    fn test_to_ascii_lowercase_long() {
        let input = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(100);
        let expected = "abcdefghijklmnopqrstuvwxyz".repeat(100);
        assert_eq!(to_ascii_lowercase(&input), expected);
    }

    #[test]
    fn test_to_ascii_lowercase_non_ascii() {
        // Only ASCII letters change; multi-byte UTF-8 sequences pass through intact.
        assert_eq!(to_ascii_lowercase("Héllo Wörld"), "héllo wörld");
        assert_eq!(to_ascii_lowercase("ÀÉÎ"), "ÀÉÎ");
    }

    #[test]
    fn test_to_ascii_uppercase_empty() {
        assert_eq!(to_ascii_uppercase(""), "");
    }

    #[test]
    fn test_to_ascii_uppercase_basic() {
        assert_eq!(to_ascii_uppercase("hello"), "HELLO");
        assert_eq!(to_ascii_uppercase("Hello World"), "HELLO WORLD");
        assert_eq!(to_ascii_uppercase("ALREADY UPPERCASE"), "ALREADY UPPERCASE");
    }

    #[test]
    fn test_to_ascii_uppercase_long() {
        let input = "abcdefghijklmnopqrstuvwxyz".repeat(100);
        let expected = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(100);
        assert_eq!(to_ascii_uppercase(&input), expected);
    }

    #[test]
    fn test_to_ascii_uppercase_non_ascii() {
        assert_eq!(to_ascii_uppercase("héllo wörld"), "HéLLO WöRLD");
        assert_eq!(to_ascii_uppercase("àéî"), "àéî");
    }

    #[test]
    fn test_inplace_lowercase() {
        let mut bytes = b"HELLO WORLD".to_vec();
        to_ascii_lowercase_inplace(&mut bytes);
        assert_eq!(&bytes, b"hello world");
    }

    #[test]
    fn test_inplace_uppercase() {
        let mut bytes = b"hello world".to_vec();
        to_ascii_uppercase_inplace(&mut bytes);
        assert_eq!(&bytes, b"HELLO WORLD");
    }

    #[test]
    fn test_simd_scalar_equivalence() {
        for len in [0, 1, 15, 16, 31, 32, 63, 64, 100, 1000] {
            let input: String = (0..len).map(|i| if i % 2 == 0 { 'A' } else { 'z' }).collect();
            let mut scalar = input.as_bytes().to_vec();
            lowercase_scalar(&mut scalar);
            let simd = to_ascii_lowercase(&input);
            assert_eq!(scalar, simd.as_bytes(), "Lowercase mismatch at length {}", len);
        }
        for len in [0, 1, 15, 16, 31, 32, 63, 64, 100, 1000] {
            let input: String = (0..len).map(|i| if i % 2 == 0 { 'a' } else { 'Z' }).collect();
            let mut scalar = input.as_bytes().to_vec();
            uppercase_scalar(&mut scalar);
            let simd = to_ascii_uppercase(&input);
            assert_eq!(scalar, simd.as_bytes(), "Uppercase mismatch at length {}", len);
        }
    }

    #[test]
    fn test_inplace_matches_std_for_all_bytes() {
        // Rotating the starting value places each of the 256 byte values at
        // every lane position. That includes the range edges
        // ('@', 'A', 'Z', '[', '`', 'a', 'z', '{') and the bytes >= 0x80
        // that signed SIMD compares see as negative.
        for &len in LENGTHS.iter() {
            for start in 0..256usize {
                let input: Vec<u8> = (0..len).map(|i| ((start + i) % 256) as u8).collect();

                let mut lower = input.clone();
                to_ascii_lowercase_inplace(&mut lower);
                assert_eq!(
                    lower,
                    input.to_ascii_lowercase(),
                    "Lowercase mismatch at length {}, start {}",
                    len,
                    start
                );

                let mut upper = input.clone();
                to_ascii_uppercase_inplace(&mut upper);
                assert_eq!(
                    upper,
                    input.to_ascii_uppercase(),
                    "Uppercase mismatch at length {}, start {}",
                    len,
                    start
                );
            }
        }
    }

    #[test]
    fn test_non_ascii_across_chunk_boundaries() {
        // Two-byte 'É' sequences land at shifting offsets relative to the SIMD
        // chunk boundaries; the result must still match std exactly.
        for &len in LENGTHS.iter() {
            let input: String = "AÉz".chars().cycle().take(len).collect();
            assert_eq!(to_ascii_lowercase(&input), input.to_ascii_lowercase());
            assert_eq!(to_ascii_uppercase(&input), input.to_ascii_uppercase());
        }
    }

    #[test]
    fn test_preserves_non_letters() {
        let input = "ABC123!@#xyz";
        assert_eq!(to_ascii_lowercase(input), "abc123!@#xyz");
        assert_eq!(to_ascii_uppercase(input), "ABC123!@#XYZ");
    }
}