use crate::tokenizer::{NormalizedString, Normalizer, Result};
use serde::{Deserialize, Serialize};
use unicode_categories::UnicodeCategories;
/// Returns `true` when `c` should be treated as whitespace.
///
/// Tab, line feed and carriage return are listed explicitly so that they are
/// always classified as whitespace here; every other character defers to
/// [`char::is_whitespace`].
fn is_whitespace(c: char) -> bool {
    matches!(c, '\t' | '\n' | '\r') || c.is_whitespace()
}
/// Returns `true` when `c` should be treated as a control character.
///
/// Tab, line feed and carriage return are never reported as control
/// characters by this function. Any other character is classified by
/// `UnicodeCategories::is_other`.
fn is_control(c: char) -> bool {
    if matches!(c, '\t' | '\n' | '\r') {
        return false;
    }
    c.is_other()
}
/// Returns `true` when `c` lies in one of the CJK ideograph code point ranges
/// that the normalizer pads with spaces.
///
/// The ranges cover the CJK Unified Ideographs block, its Extensions A
/// through E, and the two CJK Compatibility Ideographs blocks. Scripts such
/// as Hiragana, Katakana and Hangul are outside these ranges and therefore
/// return `false`.
fn is_chinese_char(c: char) -> bool {
    matches!(
        u32::from(c),
        0x4E00..=0x9FFF       // CJK Unified Ideographs
        | 0x3400..=0x4DBF     // Extension A
        | 0x20000..=0x2A6DF   // Extension B
        | 0x2A700..=0x2B73F   // Extension C
        | 0x2B740..=0x2B81F   // Extension D
        // Extension E starts at U+2B820; the previous lower bound of 0x2B920
        // skipped the first 256 code points of the block.
        | 0x2B820..=0x2CEAF   // Extension E
        | 0xF900..=0xFAFF     // CJK Compatibility Ideographs
        | 0x2F800..=0x2FA1F   // CJK Compatibility Ideographs Supplement
    )
}
/// Configuration for BERT-style text normalization.
///
/// Each field switches one step of `Normalizer::normalize` on or off. The
/// steps run in this order: text cleaning, CJK character padding, accent
/// stripping, lowercasing.
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub struct BertNormalizer {
/// When `true`, removes NUL, U+FFFD and control characters and replaces
/// every whitespace character with a plain space.
pub clean_text: bool,
/// When `true`, surrounds each character matched by `is_chinese_char` with
/// spaces.
pub handle_chinese_chars: bool,
/// Whether to apply NFD decomposition and drop nonspacing marks. `None`
/// means "use the value of `lowercase`".
pub strip_accents: Option<bool>,
/// When `true`, lowercases the text.
pub lowercase: bool,
}
impl Default for BertNormalizer {
fn default() -> Self {
Self {
clean_text: true,
handle_chinese_chars: true,
strip_accents: None,
lowercase: true,
}
}
}
impl BertNormalizer {
pub fn new(
clean_text: bool,
handle_chinese_chars: bool,
strip_accents: Option<bool>,
lowercase: bool,
) -> Self {
Self {
clean_text,
handle_chinese_chars,
strip_accents,
lowercase,
}
}
fn do_clean_text(&self, normalized: &mut NormalizedString) {
normalized
.filter(|c| !(c as usize == 0 || c as usize == 0xfffd || is_control(c)))
.map(|c| if is_whitespace(c) { ' ' } else { c });
}
fn do_handle_chinese_chars(&self, normalized: &mut NormalizedString) {
let mut new_chars: Vec<(char, isize)> = vec![];
normalized.for_each(|c| {
if is_chinese_char(c) {
new_chars.extend([(' ', 0), (c, 1), (' ', 1)]);
} else {
new_chars.push((c, 0));
}
});
normalized.transform(new_chars.into_iter(), 0);
}
fn do_strip_accents(&self, normalized: &mut NormalizedString) {
normalized.nfd().filter(|c| !c.is_mark_nonspacing());
}
fn do_lowercase(&self, normalized: &mut NormalizedString) {
normalized.lowercase();
}
}
impl Normalizer for BertNormalizer {
    /// Runs the enabled normalization steps on `normalized`, in a fixed order:
    /// text cleaning, CJK character padding, accent stripping, lowercasing.
    ///
    /// Always returns `Ok(())`.
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
        if self.clean_text {
            self.do_clean_text(normalized);
        }

        if self.handle_chinese_chars {
            self.do_handle_chinese_chars(normalized);
        }

        // When `strip_accents` is unset, accent stripping follows `lowercase`.
        if self.strip_accents.unwrap_or(self.lowercase) {
            self.do_strip_accents(normalized);
        }

        if self.lowercase {
            self.do_lowercase(normalized);
        }

        Ok(())
    }
}