use alloc::borrow::Cow;
use alloc::string::String;
use crate::tables;
/// Selects which case-folding rules [`casefold`] and [`casefold_char`] apply.
///
/// `Default` yields [`CaseFoldMode::Standard`]. `Hash` is derived so the mode can be used as
/// part of a map/set key (e.g. in caches keyed by `(text, mode)`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum CaseFoldMode {
    /// Locale-independent folding using only the general case-fold table.
    #[default]
    Standard,
    /// Turkish tailoring for the dotted/dotless I (e.g. `'I'` folds to `'ı'` and `'İ'` to
    /// `'i'`). Characters without a Turkish-specific mapping fold as in `Standard`.
    Turkish,
}
/// Folds a single `char` according to `mode`.
///
/// In [`CaseFoldMode::Turkish`] the Turkish-specific table is consulted first. Whenever it has
/// no entry, or in [`CaseFoldMode::Standard`], the general case-fold table is used. A
/// character with no mapping in either table is returned unchanged.
#[inline]
pub fn casefold_char(c: char, mode: CaseFoldMode) -> char {
    // Only Turkish mode has a tailored table; Standard skips straight to the general one.
    let tailored = match mode {
        CaseFoldMode::Turkish => tables::turkish_casefold(c),
        CaseFoldMode::Standard => None,
    };
    tailored
        .or_else(|| tables::lookup_casefold(c))
        .unwrap_or(c)
}
/// Case-folds `input` according to `mode`.
///
/// Returns `Cow::Borrowed(input)` when folding would not change anything (including the
/// empty string). Otherwise it returns a newly allocated, folded `String`.
/// [`CaseFoldMode::Standard`] goes through the ASCII-accelerated path, while
/// [`CaseFoldMode::Turkish`] always uses the per-`char` scalar path.
pub fn casefold<'a>(input: &'a str, mode: CaseFoldMode) -> Cow<'a, str> {
    // Nothing to fold: hand the (empty) input straight back.
    if input.is_empty() {
        return Cow::Borrowed(input);
    }
    match mode {
        CaseFoldMode::Standard => casefold_ascii_fastpath(input),
        CaseFoldMode::Turkish => casefold_scalar(input, mode),
    }
}
/// Folds `input` one `char` at a time with [`casefold_char`].
///
/// The prefix before the first character that changes under folding is copied verbatim. If no
/// character changes, `input` is returned borrowed and nothing is allocated.
fn casefold_scalar<'a>(input: &'a str, mode: CaseFoldMode) -> Cow<'a, str> {
    // Byte offset of the first char whose folded form differs from itself.
    let Some((split_at, _)) = input
        .char_indices()
        .find(|&(_, ch)| casefold_char(ch, mode) != ch)
    else {
        return Cow::Borrowed(input);
    };

    let (unchanged, rest) = input.split_at(split_at);
    let mut folded = String::with_capacity(input.len());
    folded.push_str(unchanged);
    folded.extend(rest.chars().map(|ch| casefold_char(ch, mode)));
    Cow::Owned(folded)
}
/// Standard-mode case folding with an ASCII fast path.
///
/// Phase 1 searches `input` for the first ASCII uppercase byte. It works in 64-byte windows
/// via `crate::simd::scan_chunk`, then byte-by-byte over the remaining tail. If a non-ASCII
/// byte is seen before any uppercase byte, the whole input is delegated to
/// [`casefold_scalar`]. If neither is found, `input` is returned borrowed.
///
/// Phase 2 copies the unchanged ASCII prefix and lowercases ASCII bytes from the first change
/// onwards. At the first non-ASCII byte it switches to per-`char` folding via
/// [`casefold_char`] for the rest of the string.
fn casefold_ascii_fastpath<'a>(input: &'a str) -> Cow<'a, str> {
    // Number of bytes examined by one `scan_chunk` call (one mask bit per byte).
    const CHUNK: usize = 64;

    let bytes = input.as_bytes();
    let len = bytes.len();
    let ptr = bytes.as_ptr();
    let mut pos = 0usize;
    let mut first_change: Option<usize> = None;

    'chunks: while pos + CHUNK <= len {
        // SAFETY: `pos + CHUNK <= len`, so `ptr.add(pos)` stays inside `bytes` and the next
        // CHUNK bytes are readable. NOTE(review): assumes `scan_chunk` reads exactly CHUNK (64)
        // bytes from its pointer — confirm against `crate::simd`.
        let nonascii = unsafe { crate::simd::scan_chunk(ptr.add(pos), 0x80) };
        if nonascii != 0 {
            // Non-ASCII text may need full Unicode folding; let the scalar path do everything.
            return casefold_scalar(input, CaseFoldMode::Standard);
        }
        // SAFETY: same bounds argument as the call above.
        let mut candidates = unsafe { crate::simd::scan_chunk(ptr.add(pos), b'A') };
        // The mask only narrows the search. Each candidate byte is re-checked with
        // `is_ascii_uppercase`, so only genuine uppercase letters end the scan.
        while candidates != 0 {
            let bit = candidates.trailing_zeros() as usize;
            candidates &= candidates.wrapping_sub(1); // clear the lowest set bit
            if bytes[pos + bit].is_ascii_uppercase() {
                first_change = Some(pos + bit);
                break 'chunks;
            }
        }
        pos += CHUNK;
    }

    // Scalar scan of the tail that did not fill a whole chunk.
    if first_change.is_none() {
        for (offset, &b) in bytes[pos..].iter().enumerate() {
            if !b.is_ascii() {
                return casefold_scalar(input, CaseFoldMode::Standard);
            }
            if b.is_ascii_uppercase() {
                first_change = Some(pos + offset);
                break;
            }
        }
    }

    let Some(start) = first_change else {
        return Cow::Borrowed(input);
    };

    let mut out = String::with_capacity(len);
    // Every byte before `start` is ASCII, so `start` is a char boundary. Safe slicing costs one
    // boundary check and turns a violated invariant into a panic instead of UB.
    out.push_str(&input[..start]);
    for (offset, &b) in bytes[start..].iter().enumerate() {
        if !b.is_ascii() {
            // All bytes in `start..i` were ASCII, so `i` is the lead byte of a multi-byte char.
            let i = start + offset;
            out.extend(
                input[i..]
                    .chars()
                    .map(|ch| casefold_char(ch, CaseFoldMode::Standard)),
            );
            return Cow::Owned(out);
        }
        out.push(char::from(b.to_ascii_lowercase()));
    }
    Cow::Owned(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fold_ascii_uppercase() {
        assert_eq!(casefold_char('A', CaseFoldMode::Standard), 'a');
        assert_eq!(casefold_char('Z', CaseFoldMode::Standard), 'z');
    }

    #[test]
    fn fold_ascii_lowercase_unchanged() {
        assert_eq!(casefold_char('a', CaseFoldMode::Standard), 'a');
        assert_eq!(casefold_char('z', CaseFoldMode::Standard), 'z');
    }

    #[test]
    fn fold_digit_unchanged() {
        assert_eq!(casefold_char('0', CaseFoldMode::Standard), '0');
        assert_eq!(casefold_char('9', CaseFoldMode::Standard), '9');
    }

    #[test]
    fn fold_latin_extended() {
        assert_eq!(
            casefold_char('\u{00C0}', CaseFoldMode::Standard),
            '\u{00E0}'
        );
        assert_eq!(
            casefold_char('\u{00D6}', CaseFoldMode::Standard),
            '\u{00F6}'
        );
    }

    #[test]
    fn fold_greek() {
        assert_eq!(
            casefold_char('\u{0391}', CaseFoldMode::Standard),
            '\u{03B1}'
        );
        assert_eq!(
            casefold_char('\u{03A3}', CaseFoldMode::Standard),
            '\u{03C3}'
        );
    }

    #[test]
    fn fold_cyrillic() {
        assert_eq!(
            casefold_char('\u{0410}', CaseFoldMode::Standard),
            '\u{0430}'
        );
    }

    #[test]
    fn fold_micro_sign() {
        assert_eq!(
            casefold_char('\u{00B5}', CaseFoldMode::Standard),
            '\u{03BC}'
        );
    }

    #[test]
    fn fold_sharp_s() {
        assert_eq!(
            casefold_char('\u{1E9E}', CaseFoldMode::Standard),
            '\u{00DF}'
        );
    }

    #[test]
    fn fold_turkish_dotless_i() {
        assert_eq!(casefold_char('I', CaseFoldMode::Standard), 'i');
        assert_eq!(casefold_char('I', CaseFoldMode::Turkish), '\u{0131}');
    }

    #[test]
    fn fold_turkish_dotted_capital_i() {
        assert_eq!(casefold_char('\u{0130}', CaseFoldMode::Turkish), 'i');
    }

    #[test]
    fn fold_turkish_other_chars_unchanged() {
        assert_eq!(casefold_char('A', CaseFoldMode::Turkish), 'a');
        assert_eq!(casefold_char('a', CaseFoldMode::Turkish), 'a');
    }

    #[test]
    fn fold_string_ascii() {
        let result = casefold("Hello World", CaseFoldMode::Standard);
        assert_eq!(&*result, "hello world");
    }

    #[test]
    fn fold_string_already_folded() {
        let result = casefold("hello world", CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(&*result, "hello world");
    }

    #[test]
    fn fold_string_empty() {
        let result = casefold("", CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn fold_string_mixed() {
        let result = casefold("Ströme", CaseFoldMode::Standard);
        assert_eq!(&*result, "ströme");
    }

    #[test]
    fn fold_string_turkish() {
        let result = casefold("Istanbul", CaseFoldMode::Turkish);
        assert_eq!(&*result, "\u{0131}stanbul");
    }

    #[test]
    fn fold_string_turkish_dotted_capital() {
        let result = casefold("\u{0130}zmir", CaseFoldMode::Turkish);
        assert_eq!(&*result, "izmir");
    }

    #[test]
    fn fold_string_all_ascii_lowercase() {
        let result = casefold(
            "abcdefghijklmnopqrstuvwxyz0123456789",
            CaseFoldMode::Standard,
        );
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    // The tests below use inputs of at least 64 bytes so that the chunked (SIMD) scan in the
    // Standard fast path is exercised; every test above stays below one chunk.

    #[test]
    fn fold_long_ascii_lowercase_borrowed() {
        // 130 bytes: two full 64-byte chunks plus a 2-byte tail.
        let input = "abcdefghij".repeat(13);
        let result = casefold(&input, CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(&*result, input.as_str());
    }

    #[test]
    fn fold_long_uppercase_at_chunk_end() {
        // The uppercase byte is the last byte of the first chunk.
        let input = ["a".repeat(63).as_str(), "B"].concat();
        let expected = ["a".repeat(63).as_str(), "b"].concat();
        let result = casefold(&input, CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Owned(_)));
        assert_eq!(&*result, expected.as_str());
    }

    #[test]
    fn fold_long_uppercase_in_second_chunk() {
        let input = ["a".repeat(70).as_str(), "Q", "b".repeat(79).as_str()].concat();
        let expected = ["a".repeat(70).as_str(), "q", "b".repeat(79).as_str()].concat();
        assert_eq!(&*casefold(&input, CaseFoldMode::Standard), expected.as_str());
    }

    #[test]
    fn fold_long_uppercase_in_tail_after_chunk() {
        let input = ["a".repeat(64).as_str(), "Hi"].concat();
        let expected = ["a".repeat(64).as_str(), "hi"].concat();
        assert_eq!(&*casefold(&input, CaseFoldMode::Standard), expected.as_str());
    }

    #[test]
    fn fold_long_uppercase_then_non_ascii() {
        let input = ["X", "a".repeat(80).as_str(), "\u{00D6}", "b"].concat();
        let expected = ["x", "a".repeat(80).as_str(), "\u{00F6}", "b"].concat();
        assert_eq!(&*casefold(&input, CaseFoldMode::Standard), expected.as_str());
    }

    #[test]
    fn fold_long_non_ascii_without_change_borrowed() {
        let input = ["\u{00F6}", "a".repeat(100).as_str()].concat();
        let result = casefold(&input, CaseFoldMode::Standard);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(&*result, input.as_str());
    }

    #[test]
    fn fold_long_non_ascii_in_chunk_uses_scalar() {
        let input = ["a".repeat(10).as_str(), "\u{00C0}", "B".repeat(70).as_str()].concat();
        let expected = ["a".repeat(10).as_str(), "\u{00E0}", "b".repeat(70).as_str()].concat();
        assert_eq!(&*casefold(&input, CaseFoldMode::Standard), expected.as_str());
    }
}