use std::borrow::Cow;
use unicode_general_category::{get_general_category, GeneralCategory};
/// Byte-indexed lookup table used by [`clean_text`] to locate the first byte
/// that rules out returning the input untouched.
///
/// An entry is `true` for:
/// * every non-ASCII byte (`0x80..=0xFF`),
/// * ASCII control bytes below `0x20`, except tab, line feed and carriage return,
/// * the DEL byte (`0x7F`).
static NEEDS_CLEANING: [bool; 256] = {
    let mut flags = [false; 256];
    let mut idx = 0usize;
    while idx < 256 {
        flags[idx] = match idx as u8 {
            // Ordinary whitespace controls do not, on their own, force cleaning.
            b'\t' | b'\n' | b'\r' => false,
            // Remaining C0 controls, DEL, and every byte of a multi-byte sequence.
            0x00..=0x1F | 0x7F..=0xFF => true,
            _ => false,
        };
        idx += 1;
    }
    flags
};
/// Selects which text normalization routine [`Normalizer::normalize`] applies.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Normalizer {
/// Identity: the input is returned borrowed and unchanged. This is the default.
#[default]
None,
/// Dispatches to `bert_uncased_normalize` (cleaning, accent stripping, lowercasing).
BertUncased,
/// Dispatches to `clean_text` (cleaning only; case is preserved).
BertCased,
/// Dispatches to `normalize_nfc` (Unicode Normalization Form C).
Nfc,
/// Dispatches to `metaspace_normalize` (spaces become `▁`, with a leading `▁`).
Metaspace,
/// Dispatches to `sentencepiece_normalize` (NFKC, whitespace collapsing, `▁` markers).
SentencePiece,
/// Dispatches to `sentencepiece_lowercase_normalize` (NFKD, mark stripping, lowercasing, `▁` markers).
SentencePieceLowercase,
/// Dispatches to `metaspace_replace_normalize` (spaces become `▁`, nothing is prepended).
MetaspaceReplace,
}
impl Normalizer {
    /// Normalizes `text` with the routine selected by this variant.
    ///
    /// The result borrows from `text` whenever the selected routine can return
    /// the input as-is; otherwise a freshly allocated string is returned.
    #[inline]
    pub fn normalize<'a>(&self, text: &'a str) -> Cow<'a, str> {
        match *self {
            Self::None => Cow::Borrowed(text),
            Self::BertUncased => bert_uncased_normalize(text),
            Self::BertCased => clean_text(text),
            Self::Nfc => normalize_nfc(text),
            Self::Metaspace => metaspace_normalize(text),
            Self::MetaspaceReplace => metaspace_replace_normalize(text),
            Self::SentencePiece => sentencepiece_normalize(text),
            Self::SentencePieceLowercase => sentencepiece_lowercase_normalize(text),
        }
    }

    /// Returns `true` only for [`Normalizer::None`], i.e. when `normalize`
    /// is guaranteed to hand the input back unchanged.
    #[inline]
    pub fn is_identity(&self) -> bool {
        *self == Self::None
    }
}
/// Converts `text` to Unicode Normalization Form C.
///
/// Input that is already in NFC is returned borrowed; anything else is
/// normalized into a new owned string.
#[inline]
fn normalize_nfc<'a>(text: &'a str) -> Cow<'a, str> {
    let normalizer = icu_normalizer::ComposingNormalizer::new_nfc();
    if normalizer.is_normalized(text) {
        Cow::Borrowed(text)
    } else {
        Cow::Owned(normalizer.normalize(text))
    }
}
/// Replaces every ASCII space with the metaspace marker `▁` and makes sure the
/// result carries a leading marker.
///
/// * Input starting with a space: the replaced text is returned as-is, since
///   the leading space has already become `▁`.
/// * Input starting with a tab: the replaced text is also returned as-is, so
///   no marker is prepended and the tab itself is left untouched (only spaces
///   are replaced).
/// * Any other input, including the empty string: `▁` is prepended.
#[inline]
pub fn metaspace_normalize(text: &str) -> Cow<'_, str> {
    let replaced = fnr(text, " ", "▁");

    // Both characters are single-byte ASCII, so inspecting the first byte is
    // equivalent to checking the first character.
    if matches!(text.as_bytes().first(), Some(b' ' | b'\t')) {
        return replaced;
    }

    // `▁` (U+2581) occupies three bytes in UTF-8.
    let mut prefixed = String::with_capacity(replaced.len() + 3);
    prefixed.push('▁');
    prefixed.push_str(&replaced);
    Cow::Owned(prefixed)
}
/// Replaces every ASCII space in `text` with the metaspace marker `▁`.
///
/// Unlike [`metaspace_normalize`], nothing is prepended; text without spaces
/// is returned borrowed.
#[inline]
pub fn metaspace_replace_normalize(text: &str) -> Cow<'_, str> {
    const SPACE: &str = " ";
    const MARKER: &str = "▁";
    fnr(text, SPACE, MARKER)
}
/// SentencePiece-style normalization that preserves case.
///
/// Steps, in order:
/// 1. NFKC-normalize the input (skipped when it is already normalized).
/// 2. Collapse whitespace and strip controls via
///    `collapse_strip_whitespace_and_controls`.
/// 3. Prepend `▁` and turn every remaining space into `▁`.
///
/// The empty string is returned borrowed and unchanged. Every other input
/// produces an owned string that starts with `▁` (whitespace-only input
/// therefore yields just `"▁"`).
#[inline]
pub fn sentencepiece_normalize(text: &str) -> Cow<'_, str> {
    if text.is_empty() {
        return Cow::Borrowed(text);
    }

    let nfkc = icu_normalizer::ComposingNormalizer::new_nfkc();
    // Only allocate for the NFKC step when the input is not already normalized.
    let recomposed;
    let normalized: &str = if nfkc.is_normalized(text) {
        text
    } else {
        recomposed = nfkc.normalize(text);
        &recomposed
    };
    let squeezed = collapse_strip_whitespace_and_controls(normalized);

    // The extra three bytes cover the leading `▁`; the string grows on demand
    // for any further markers.
    let mut out = String::with_capacity(squeezed.len() + 3);
    out.push('▁');
    out.extend(squeezed.chars().map(|c| if c == ' ' { '▁' } else { c }));
    Cow::Owned(out)
}
/// SentencePiece-style normalization that also strips accents and lowercases.
///
/// Steps, in order:
/// 1. NFKD-decompose the input.
/// 2. Drop combining marks, characters rejected by `is_control`, NUL and
///    U+FFFD.
/// 3. Collapse whitespace runs to single spaces and trim both ends.
/// 4. Prepend `▁`, turn every space into `▁`, and lowercase everything else.
///
/// The empty string is returned borrowed and unchanged; every other input
/// produces an owned string that starts with `▁`.
#[inline]
pub fn sentencepiece_lowercase_normalize(text: &str) -> Cow<'_, str> {
    if text.is_empty() {
        return Cow::Borrowed(text);
    }

    let nfkd = icu_normalizer::DecomposingNormalizer::new_nfkd();
    let mut decomposed = nfkd.normalize(text);
    decomposed.retain(|c| {
        !(is_combining_mark(c) || is_control(c) || c == '\0' || c == '\u{FFFD}')
    });
    let squeezed = collapse_and_strip_whitespace(&decomposed);

    // The extra three bytes cover the leading `▁`.
    let mut out = String::with_capacity(squeezed.len() + 3);
    out.push('▁');
    for c in squeezed.chars() {
        match c {
            ' ' => out.push('▁'),
            // Lowercasing may expand to several characters.
            other => out.extend(other.to_lowercase()),
        }
    }
    Cow::Owned(out)
}
/// Collapses every run of Unicode whitespace in `text` into a single ASCII
/// space and removes leading and trailing whitespace.
fn collapse_and_strip_whitespace(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    // `split_whitespace` yields the non-empty runs between whitespace, using
    // the same `char::is_whitespace` definition as a manual scan would.
    for word in text.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
/// Collapses whitespace-like characters into single spaces, drops control
/// characters, and trims both ends.
///
/// Per character:
/// * NUL is dropped.
/// * General category `Control`: tab, line feed, carriage return and form
///   feed count as a space; every other control character is dropped without
///   affecting space collapsing.
/// * General category `Format`, U+FFFD, and anything `char::is_whitespace`
///   accepts count as a space.
/// * Everything else is copied through.
///
/// Consecutive "space" characters produce at most one ASCII space, and no
/// space is emitted at the start or left at the end of the result.
fn collapse_strip_whitespace_and_controls(text: &str) -> String {
    /// What a single input character contributes to the output.
    enum Action {
        Drop,
        Space,
        Keep,
    }

    let mut out = String::with_capacity(text.len());
    // Starts `true` so that leading space-like characters are swallowed.
    let mut after_space = true;

    for c in text.chars() {
        let action = if c == '\0' {
            Action::Drop
        } else {
            match get_general_category(c) {
                GeneralCategory::Control => {
                    if matches!(c, '\t' | '\n' | '\r' | '\x0C') {
                        Action::Space
                    } else {
                        Action::Drop
                    }
                }
                GeneralCategory::Format => Action::Space,
                _ if c == '\u{FFFD}' || c.is_whitespace() => Action::Space,
                _ => Action::Keep,
            }
        };

        match action {
            Action::Drop => {}
            Action::Space => {
                if !after_space {
                    out.push(' ');
                    after_space = true;
                }
            }
            Action::Keep => {
                out.push(c);
                after_space = false;
            }
        }
    }

    // At most one collapsed space can be left dangling at the end.
    if out.ends_with(' ') {
        out.pop();
    }
    out
}
/// Reports whether `c` should be removed as a control-like character.
///
/// Tab, line feed and carriage return are never treated as control here.
/// Otherwise a character qualifies when its Unicode general category is
/// `Control`, `Format`, `Unassigned` or `PrivateUse`.
#[inline]
fn is_control(c: char) -> bool {
    if matches!(c, '\t' | '\n' | '\r') {
        return false;
    }
    let category = get_general_category(c);
    matches!(
        category,
        GeneralCategory::Control
            | GeneralCategory::Format
            | GeneralCategory::Unassigned
            | GeneralCategory::PrivateUse
    )
}
/// Reports whether `c` counts as whitespace for BERT cleaning: space, tab,
/// line feed, carriage return, or any character whose Unicode general
/// category is `SpaceSeparator`.
#[inline]
fn is_bert_whitespace(c: char) -> bool {
    match c {
        ' ' | '\t' | '\n' | '\r' => true,
        _ => get_general_category(c) == GeneralCategory::SpaceSeparator,
    }
}
/// Removes control characters from `text` and normalizes whitespace to plain
/// spaces, borrowing the input when nothing needs to change.
///
/// The input is scanned with `NEEDS_CLEANING`:
/// * No flagged byte: `text` is returned borrowed. Note that tab, line feed
///   and carriage return are not flagged, so input containing only those is
///   returned with them intact.
/// * A non-ASCII byte at or after the first flagged byte: the work is handed
///   to `clean_text_unicode`.
/// * Otherwise (pure ASCII tail): tab/LF/CR become a space, the remaining
///   bytes below `0x20` and DEL (`0x7F`) are dropped, everything else is
///   copied.
pub fn clean_text<'a>(text: &'a str) -> Cow<'a, str> {
    let bytes = text.as_bytes();
    let first_pos = match bytes.iter().position(|&b| NEEDS_CLEANING[usize::from(b)]) {
        Some(pos) => pos,
        None => return Cow::Borrowed(text),
    };

    let (clean_prefix, rest) = bytes.split_at(first_pos);
    if !rest.is_ascii() {
        return clean_text_unicode(text, first_pos);
    }

    let mut cleaned = Vec::with_capacity(bytes.len());
    cleaned.extend_from_slice(clean_prefix);
    for &byte in rest {
        match byte {
            b'\t' | b'\n' | b'\r' => cleaned.push(b' '),
            // Other C0 controls and DEL are removed outright.
            0x00..=0x1F | 0x7F => {}
            _ => cleaned.push(byte),
        }
    }

    // SAFETY: `clean_prefix` holds no byte flagged by `NEEDS_CLEANING`, which
    // flags every byte >= 0x80, and `rest` was just checked to be ASCII. The
    // buffer therefore contains only ASCII bytes, which is valid UTF-8.
    Cow::Owned(unsafe { String::from_utf8_unchecked(cleaned) })
}
/// Unicode-aware slow path of `clean_text`.
///
/// `first_problem` is the byte offset that splits `text` into a prefix,
/// processed byte by byte, and a suffix, processed character by character.
/// The prefix bytes are pushed as individual characters, so it is expected to
/// be ASCII (`clean_text` passes the offset of the first flagged byte), and
/// `first_problem` must lie on a character boundary or slicing panics.
///
/// If neither half contains anything to clean, `text` is returned borrowed.
/// Otherwise:
/// * prefix: tab/LF/CR become a space, the remaining bytes below `0x20` and
///   DEL are dropped, the rest is copied;
/// * suffix: NUL, U+FFFD and `is_control` characters are dropped,
///   `is_bert_whitespace` characters become a space, the rest is copied.
fn clean_text_unicode<'a>(text: &'a str, first_problem: usize) -> Cow<'a, str> {
    let head = &text.as_bytes()[..first_problem];
    let tail = &text[first_problem..];

    let is_removed = |c: char| c == '\0' || c == '\u{FFFD}' || is_control(c);

    let head_dirty = head.iter().any(|&b| NEEDS_CLEANING[usize::from(b)]);
    // A plain space is already in its final form; any other whitespace is not.
    let tail_dirty = tail
        .chars()
        .any(|c| is_removed(c) || (c != ' ' && is_bert_whitespace(c)));
    if !(head_dirty || tail_dirty) {
        return Cow::Borrowed(text);
    }

    let mut out = String::with_capacity(text.len());
    for &b in head {
        match b {
            b'\t' | b'\n' | b'\r' => out.push(' '),
            0x00..=0x1F | 0x7F => {}
            _ => out.push(char::from(b)),
        }
    }
    out.extend(
        tail.chars()
            .filter(|&c| !is_removed(c))
            .map(|c| if is_bert_whitespace(c) { ' ' } else { c }),
    );
    Cow::Owned(out)
}
/// Removes accents from `text` by NFD-decomposing it and dropping every
/// character that `is_combining_mark` accepts.
///
/// Pure-ASCII input (including the empty string) is returned borrowed. Any
/// input containing a non-ASCII byte yields an owned string, even when no
/// mark ends up being removed.
pub fn strip_accents<'a>(text: &'a str) -> Cow<'a, str> {
    if text.is_ascii() {
        return Cow::Borrowed(text);
    }
    let nfd = icu_normalizer::DecomposingNormalizer::new_nfd();
    let mut decomposed = nfd.normalize(text);
    decomposed.retain(|c| !is_combining_mark(c));
    Cow::Owned(decomposed)
}
/// Reports whether `c` is a combining mark to be stripped.
///
/// Code points below U+0300 are rejected and the block U+0300..=U+036F is
/// accepted without a category lookup; anything above is accepted only if
/// its Unicode general category is `NonspacingMark`.
#[inline]
fn is_combining_mark(c: char) -> bool {
    match u32::from(c) {
        0x0000..=0x02FF => false,
        0x0300..=0x036F => true,
        _ => get_general_category(c) == GeneralCategory::NonspacingMark,
    }
}
/// Byte-indexed lookup table marking bytes that prevent
/// `bert_uncased_normalize` from returning its input untouched.
///
/// An entry is `true` for:
/// * every non-ASCII byte (`0x80..=0xFF`),
/// * ASCII uppercase letters `A`..=`Z`,
/// * ASCII control bytes below `0x20` (NUL included), except tab, line feed
///   and carriage return.
///
/// NOTE(review): DEL (`0x7F`) is not flagged here, whereas `NEEDS_CLEANING`
/// does flag it.
static NEEDS_BERT_NORMALIZE: [bool; 256] = {
    let mut flags = [false; 256];
    let mut idx = 0usize;
    while idx < 256 {
        flags[idx] = match idx as u8 {
            // Ordinary whitespace controls do not, on their own, force a rewrite.
            b'\t' | b'\n' | b'\r' => false,
            0x00..=0x1F | b'A'..=b'Z' | 0x80..=0xFF => true,
            _ => false,
        };
        idx += 1;
    }
    flags
};
/// BERT "uncased" normalization: cleans the text, strips accents and
/// lowercases it, borrowing the input when nothing needs to change.
///
/// The input is scanned for the first byte that is flagged by
/// `NEEDS_BERT_NORMALIZE` or is DEL (`0x7F`):
/// * No such byte: `text` is returned borrowed. Tab, line feed and carriage
///   return do not trigger a rewrite on their own, so input containing only
///   those is returned with them intact.
/// * Pure-ASCII remainder: tab/LF/CR become a space, the remaining bytes
///   below `0x20` and DEL are dropped, `A`..=`Z` are lowercased, everything
///   else is copied.
/// * Otherwise: the whole input is NFD-decomposed; combining marks,
///   `is_control` characters, NUL and U+FFFD are dropped, `is_bert_whitespace`
///   characters become a space, and everything else is lowercased.
///
/// DEL is handled explicitly because the lookup table does not flag it;
/// without that, the ASCII path would keep DEL while the non-ASCII path
/// (where `is_control` rejects it) and `clean_text` both remove it.
pub fn bert_uncased_normalize(text: &str) -> Cow<'_, str> {
    let bytes = text.as_bytes();
    let first_pos = match bytes
        .iter()
        .position(|&b| NEEDS_BERT_NORMALIZE[usize::from(b)] || b == 0x7F)
    {
        Some(pos) => pos,
        None => return Cow::Borrowed(text),
    };

    let mut result = String::with_capacity(text.len());
    // Bytes before `first_pos` are unflagged and therefore ASCII, so only the
    // remainder has to be inspected.
    if bytes[first_pos..].is_ascii() {
        // The prefix contains nothing to rewrite; copy it verbatim.
        result.push_str(&text[..first_pos]);
        for &b in &bytes[first_pos..] {
            match b {
                b'\t' | b'\n' | b'\r' => result.push(' '),
                // Remaining C0 controls (NUL included) and DEL are removed.
                0x00..=0x1F | 0x7F => {}
                _ => result.push(char::from(b.to_ascii_lowercase())),
            }
        }
    } else {
        let nfd = icu_normalizer::DecomposingNormalizer::new_nfd();
        let normalized = nfd.normalize(text);
        for c in normalized.chars() {
            if is_combining_mark(c) || is_control(c) || c == '\0' || c == '\u{FFFD}' {
                continue;
            }
            if is_bert_whitespace(c) {
                result.push(' ');
            } else if c.is_ascii() {
                result.push(c.to_ascii_lowercase());
            } else {
                // Unicode lowercasing may expand to several characters.
                result.extend(c.to_lowercase());
            }
        }
    }
    Cow::Owned(result)
}
/// Find-and-replace: substitutes every non-overlapping occurrence of `needle`
/// in `text` with `replacement`.
///
/// `text` is returned borrowed when it is empty, when `needle` is empty, or
/// when `needle` does not occur; otherwise a new owned string is built.
#[inline]
pub fn fnr<'a>(text: &'a str, needle: &str, replacement: &str) -> Cow<'a, str> {
    use memchr::memmem;

    if text.is_empty() || needle.is_empty() {
        return Cow::Borrowed(text);
    }

    let haystack = text.as_bytes();
    let finder = memmem::Finder::new(needle.as_bytes());
    let mut hits = finder.find_iter(haystack).peekable();
    if hits.peek().is_none() {
        return Cow::Borrowed(text);
    }

    // Capacity is only a heuristic: room for two replacements that grow the text.
    let growth = replacement.len().saturating_sub(needle.len());
    let mut out = String::with_capacity(text.len() + growth * 2);

    let mut cursor = 0;
    for hit in hits {
        out.push_str(&text[cursor..hit]);
        out.push_str(replacement);
        cursor = hit + needle.len();
    }
    out.push_str(&text[cursor..]);
    Cow::Owned(out)
}
/// Find-and-replace helper for a fixed needle: the substring searcher is
/// built once in `new` and reused by every `replace` / `contains` call.
pub struct FnrFinder<'n> {
/// The needle as text; its byte length is used to step past each match.
needle: &'n str,
/// `memchr` substring searcher built from the bytes of `needle`.
finder: memchr::memmem::Finder<'n>,
}
impl<'n> FnrFinder<'n> {
    /// Builds a finder for `needle`, constructing the substring searcher once.
    #[inline]
    pub fn new(needle: &'n str) -> Self {
        let finder = memchr::memmem::Finder::new(needle.as_bytes());
        Self { needle, finder }
    }

    /// Substitutes every non-overlapping occurrence of the needle in `text`
    /// with `replacement`.
    ///
    /// `text` is returned borrowed when it is empty, when the needle is
    /// empty, or when the needle does not occur.
    #[inline]
    pub fn replace<'a>(&self, text: &'a str, replacement: &str) -> Cow<'a, str> {
        if text.is_empty() || self.needle.is_empty() {
            return Cow::Borrowed(text);
        }

        let haystack = text.as_bytes();
        let mut hits = self.finder.find_iter(haystack).peekable();
        if hits.peek().is_none() {
            return Cow::Borrowed(text);
        }

        // Capacity is only a heuristic: room for two replacements that grow the text.
        let needle_len = self.needle.len();
        let growth = replacement.len().saturating_sub(needle_len);
        let mut out = String::with_capacity(text.len() + growth * 2);

        let mut cursor = 0;
        for hit in hits {
            out.push_str(&text[cursor..hit]);
            out.push_str(replacement);
            cursor = hit + needle_len;
        }
        out.push_str(&text[cursor..]);
        Cow::Owned(out)
    }

    /// Reports whether the searcher finds the needle anywhere in `text`.
    #[inline]
    pub fn contains(&self, text: &str) -> bool {
        let haystack = text.as_bytes();
        self.finder.find(haystack).is_some()
    }
}
// Unit tests for `fnr`, `FnrFinder` and the `Normalizer` dispatch.
#[cfg(test)]
mod tests {
use super::*;
// A single match is replaced and the result is an owned string.
#[test]
fn test_fnr_basic() {
let result = fnr("hello world", "world", "rust");
assert_eq!(result, "hello rust");
assert!(matches!(result, Cow::Owned(_)));
}
// Every occurrence is replaced, not just the first.
#[test]
fn test_fnr_multiple() {
let result = fnr("foo bar foo baz foo", "foo", "x");
assert_eq!(result, "x bar x baz x");
}
// Without a match the input is returned borrowed and unchanged.
#[test]
fn test_fnr_no_match() {
let result = fnr("hello world", "xyz", "abc");
assert!(matches!(result, Cow::Borrowed(_)));
assert_eq!(result, "hello world");
}
// An empty needle short-circuits to a borrow.
#[test]
fn test_fnr_empty_needle() {
let result = fnr("hello", "", "x");
assert!(matches!(result, Cow::Borrowed(_)));
}
// An empty text short-circuits to a borrow.
#[test]
fn test_fnr_empty_text() {
let result = fnr("", "foo", "bar");
assert!(matches!(result, Cow::Borrowed(_)));
}
// An empty replacement deletes each match.
#[test]
fn test_fnr_empty_replacement() {
let result = fnr("hello world", " ", "");
assert_eq!(result, "helloworld");
}
// The replacement may be longer than the needle.
#[test]
fn test_fnr_longer_replacement() {
let result = fnr("a b c", " ", "---");
assert_eq!(result, "a---b---c");
}
// Matches at the start, at the end, and spanning the whole text.
#[test]
fn test_fnr_at_boundaries() {
let result = fnr("foo bar", "foo", "baz");
assert_eq!(result, "baz bar");
let result = fnr("bar foo", "foo", "baz");
assert_eq!(result, "bar baz");
let result = fnr("foo", "foo", "bar");
assert_eq!(result, "bar");
}
// A multi-byte needle is replaced while other non-ASCII text is preserved.
#[test]
fn test_fnr_unicode() {
let result = fnr("héllo wörld", "ö", "o");
assert_eq!(result, "héllo world");
}
// One `FnrFinder` serves several `replace` and `contains` calls.
#[test]
fn test_fnr_finder_reuse() {
let finder = FnrFinder::new("foo");
let r1 = finder.replace("foo bar foo", "baz");
assert_eq!(r1, "baz bar baz");
let r2 = finder.replace("no match", "baz");
assert!(matches!(r2, Cow::Borrowed(_)));
assert_eq!(r2, "no match");
assert!(finder.contains("has foo"));
assert!(!finder.contains("no match"));
}
// `Normalizer::None` returns the input borrowed and unchanged.
#[test]
fn test_none_normalizer() {
let norm = Normalizer::None;
let text = "Hello World!";
let result = norm.normalize(text);
assert!(matches!(result, Cow::Borrowed(_)));
assert_eq!(result, "Hello World!");
}
// Uppercase ASCII forces an owned lowercase copy; already-lowercase text is borrowed.
#[test]
fn test_bert_uncased_lowercase() {
let norm = Normalizer::BertUncased;
let result = norm.normalize("Hello World!");
assert!(matches!(result, Cow::Owned(_)));
assert_eq!(result, "hello world!");
let result = norm.normalize("hello world!");
assert!(matches!(result, Cow::Borrowed(_)));
assert_eq!(result, "hello world!");
}
// Accents are stripped and letters lowercased; `ß` is left as-is.
#[test]
fn test_bert_uncased_unicode() {
let norm = Normalizer::BertUncased;
let result = norm.normalize("HÉLLO");
assert_eq!(result, "hello");
let result = norm.normalize("straße");
assert_eq!(result, "straße");
let result = norm.normalize("café résumé naïve");
assert_eq!(result, "cafe resume naive");
}
// ASCII is already NFC (borrowed); a decomposed `e` + U+0301 is composed into `é`.
#[test]
fn test_nfc_normalizer() {
let norm = Normalizer::Nfc;
let result = norm.normalize("hello");
assert!(matches!(result, Cow::Borrowed(_)));
let nfd = "e\u{0301}"; let result = norm.normalize(nfd);
assert!(matches!(result, Cow::Owned(_)));
assert_eq!(result, "é");
}
// The cased BERT normalizer keeps case and borrows clean ASCII input.
#[test]
fn test_bert_cased() {
let norm = Normalizer::BertCased;
let text = "Hello World!";
let result = norm.normalize(text);
assert!(matches!(result, Cow::Borrowed(_)));
assert_eq!(result, "Hello World!"); }
}