//! A language-agnostic tokenizer built on Unicode word segmentation
//! (UAX #29), with optional NFC normalization, lowercasing, per-character
//! splitting of CJK runs, and token-length capping.

use unicode_normalization::UnicodeNormalization;
use unicode_segmentation::UnicodeSegmentation;

use crate::error::Result;
use crate::tokenize::Tokenizer;

/// Returns true if `c` falls in a CJK-related Unicode block
/// (Han ideographs, kana, Hangul syllables, or CJK punctuation).
fn is_cjk_char(c: char) -> bool {
    matches!(c as u32,
        // Han ideographs: Unified, Extension A, Extension B, and the
        // compatibility ideograph blocks.
        0x4E00..=0x9FFF | 0x3400..=0x4DBF | 0x20000..=0x2A6DF | 0xF900..=0xFAFF | 0x2F800..=0x2FA1F
        // CJK Symbols and Punctuation, Hiragana, Katakana (plus phonetic
        // extensions), and Hangul Syllables.
        | 0x3000..=0x303F | 0x3040..=0x309F | 0x30A0..=0x30FF | 0x31F0..=0x31FF | 0xAC00..=0xD7AF
    )
}

/// Returns true if every character in `s` is Unicode whitespace.
fn is_whitespace_segment(s: &str) -> bool {
    s.chars().all(|c| c.is_whitespace())
}

/// Returns true if `s` is non-empty and contains only characters that are
/// neither alphanumeric nor whitespace (punctuation and symbols).
fn is_pure_punctuation(s: &str) -> bool {
    !s.is_empty()
        && s.chars()
            .all(|c| !c.is_alphanumeric() && !c.is_whitespace())
}

/// A configurable, language-agnostic tokenizer based on Unicode word
/// boundaries.
#[derive(Debug, Clone)]
pub struct LanguageAgnosticTokenizer {
    /// Apply NFC normalization before segmentation.
    pub normalize: bool,
    /// Lowercase the text before segmentation.
    pub lowercase: bool,
    /// Emit one token per character for CJK segments instead of keeping runs together.
    pub split_cjk_by_char: bool,
    /// Keep punctuation-only segments as tokens.
    pub preserve_punctuation: bool,
    /// Truncate tokens longer than this many characters.
    pub max_token_len: Option<usize>,
}

impl Default for LanguageAgnosticTokenizer {
fn default() -> Self {
Self {
normalize: true,
lowercase: false,
split_cjk_by_char: true,
preserve_punctuation: true,
max_token_len: None,
}
}
}
impl LanguageAgnosticTokenizer {
    /// Creates a tokenizer with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Splits `text` into tokens according to the configured options.
    pub fn tokenize_str(&self, text: &str) -> Vec<String> {
        // Optionally apply NFC normalization so that canonically equivalent
        // inputs (precomposed vs. decomposed accents) tokenize identically.
        let normalized: String = if self.normalize {
            text.nfc().collect()
        } else {
            text.to_owned()
        };
        // Optionally lowercase after normalization.
        let processed: String = if self.lowercase {
            normalized.to_lowercase()
        } else {
            normalized
        };
        let mut tokens: Vec<String> = Vec::new();
        // Accumulates adjacent CJK segments when `split_cjk_by_char` is false,
        // so a contiguous CJK run is emitted as a single token.
        let mut cjk_run: String = String::new();
        // Emits the pending CJK run (if any) as one token, applying the same
        // length cap as `push_token`.
        let flush_cjk_run = |run: &mut String, tokens: &mut Vec<String>, max: Option<usize>| {
            if !run.is_empty() {
                let s = std::mem::take(run);
                match max {
                    Some(max_len) if s.chars().count() > max_len => {
                        let truncated: String = s.chars().take(max_len).collect();
                        tokens.push(truncated);
                    }
                    _ => tokens.push(s),
                }
            }
        };
        for segment in processed.split_word_bounds() {
            // Whitespace ends any pending CJK run and is never a token itself.
            if is_whitespace_segment(segment) {
                flush_cjk_run(&mut cjk_run, &mut tokens, self.max_token_len);
                continue;
            }
            // Drop punctuation-only segments unless they are being preserved.
            if is_pure_punctuation(segment) && !self.preserve_punctuation {
                flush_cjk_run(&mut cjk_run, &mut tokens, self.max_token_len);
                continue;
            }
            let has_cjk = segment.chars().any(is_cjk_char);
            if has_cjk {
                if self.split_cjk_by_char {
                    flush_cjk_run(&mut cjk_run, &mut tokens, self.max_token_len);
                    // One token per character, skipping any stray whitespace.
                    for ch in segment.chars() {
                        let ch_str = ch.to_string();
                        if !ch_str.trim().is_empty() {
                            self.push_token(&mut tokens, ch_str);
                        }
                    }
                } else {
                    // Defer emission: adjacent CJK segments merge into one run.
                    cjk_run.push_str(segment);
                }
            } else {
                // Non-CJK segment: emit any pending run, then the segment.
                flush_cjk_run(&mut cjk_run, &mut tokens, self.max_token_len);
                self.push_token(&mut tokens, segment.to_owned());
            }
        }
        // Emit whatever CJK run is still pending at the end of input.
        flush_cjk_run(&mut cjk_run, &mut tokens, self.max_token_len);
        tokens
    }
    /// Pushes `token`, truncating it to `max_token_len` characters if a cap
    /// is configured.
    fn push_token(&self, tokens: &mut Vec<String>, token: String) {
        match self.max_token_len {
            Some(max_len) if token.chars().count() > max_len => {
                let truncated: String = token.chars().take(max_len).collect();
                tokens.push(truncated);
            }
            _ => tokens.push(token),
        }
    }
}
impl Tokenizer for LanguageAgnosticTokenizer {
fn tokenize(&self, text: &str) -> Result<Vec<String>> {
Ok(self.tokenize_str(text))
}
fn clone_box(&self) -> Box<dyn Tokenizer + Send + Sync> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn unit_mixed_latin_cjk() {
let t = LanguageAgnosticTokenizer::new();
let tokens = t.tokenize_str("Hello 你好 World");
assert!(
tokens.iter().any(|s| s == "Hello"),
"missing 'Hello': {tokens:?}"
);
assert!(
tokens.iter().any(|s| s == "你" || s == "你好"),
"missing CJK token: {tokens:?}"
);
assert!(
tokens.iter().any(|s| s == "World"),
"missing 'World': {tokens:?}"
);
}
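    #[test]
    fn unit_latin_cjk_adjacent() {
        // Sketch of adjacent Latin/CJK text with no separating whitespace.
        // This assumes UAX #29 word segmentation breaks between Latin letters
        // and ideographs, so the assertions are kept deliberately loose.
        let t = LanguageAgnosticTokenizer::new();
        let tokens = t.tokenize_str("Rust言語");
        assert!(
            tokens.iter().any(|s| s.contains("Rust")),
            "missing Latin part: {tokens:?}"
        );
        assert!(
            tokens.iter().any(|s| s.contains('言')),
            "missing CJK part: {tokens:?}"
        );
    }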
#[test]
fn unit_cjk_split_by_char() {
let t = LanguageAgnosticTokenizer {
split_cjk_by_char: true,
..Default::default()
};
let tokens = t.tokenize_str("日本語");
        assert_eq!(
            tokens.len(),
            3,
            "expected 3 individual CJK chars, got: {tokens:?}"
        );
assert!(tokens.contains(&"日".to_string()), "tokens: {tokens:?}");
assert!(tokens.contains(&"本".to_string()), "tokens: {tokens:?}");
assert!(tokens.contains(&"語".to_string()), "tokens: {tokens:?}");
}
#[test]
fn unit_cjk_no_split() {
let t = LanguageAgnosticTokenizer {
split_cjk_by_char: false,
..Default::default()
};
let tokens = t.tokenize_str("日本語");
assert_eq!(
tokens.len(),
1,
"expected single CJK token, got: {tokens:?}"
);
assert_eq!(tokens[0], "日本語");
}
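    #[test]
    fn unit_hangul_split_by_char() {
        // Hangul syllables fall in the 0xAC00..=0xD7AF range checked by
        // is_cjk_char, so the default config should emit one token per
        // syllable; only the per-token length is asserted, to stay robust to
        // segmenter differences.
        let t = LanguageAgnosticTokenizer::new();
        let tokens = t.tokenize_str("안녕하세요");
        assert!(!tokens.is_empty(), "expected Hangul tokens: {tokens:?}");
        assert!(
            tokens.iter().all(|s| s.chars().count() == 1),
            "expected single-character tokens: {tokens:?}"
        );
    }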
#[test]
fn unit_empty_string() {
let t = LanguageAgnosticTokenizer::new();
assert_eq!(t.tokenize_str(""), Vec::<String>::new());
}
#[test]
fn unit_whitespace_only() {
let t = LanguageAgnosticTokenizer::new();
assert_eq!(t.tokenize_str(" \t\n "), Vec::<String>::new());
}
#[test]
fn unit_lowercase() {
let t = LanguageAgnosticTokenizer {
lowercase: true,
..Default::default()
};
let tokens = t.tokenize_str("Hello World");
assert!(
tokens.iter().any(|s| s == "hello"),
"expected lowercase: {tokens:?}"
);
assert!(
tokens.iter().any(|s| s == "world"),
"expected lowercase: {tokens:?}"
);
}
#[test]
fn unit_preserve_punctuation_true() {
let t = LanguageAgnosticTokenizer {
preserve_punctuation: true,
..Default::default()
};
let tokens = t.tokenize_str("hello, world!");
let has_comma = tokens.iter().any(|s| s.contains(','));
let has_excl = tokens.iter().any(|s| s.contains('!'));
assert!(
has_comma || has_excl,
"expected punctuation preserved: {tokens:?}"
);
}
#[test]
fn unit_preserve_punctuation_false() {
let t = LanguageAgnosticTokenizer {
preserve_punctuation: false,
..Default::default()
};
let tokens = t.tokenize_str("hello, world!");
let has_punc = tokens.iter().any(|s| is_pure_punctuation(s));
assert!(!has_punc, "unexpected punctuation token: {tokens:?}");
assert!(
tokens.iter().any(|s| s == "hello"),
"missing 'hello': {tokens:?}"
);
}
#[test]
fn unit_max_token_len() {
let t = LanguageAgnosticTokenizer {
max_token_len: Some(3),
..Default::default()
};
let tokens = t.tokenize_str("superlongword");
assert!(
tokens.iter().all(|s| s.chars().count() <= 3),
"token exceeded max_len=3: {tokens:?}"
);
}
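    #[test]
    fn unit_lowercase_with_max_token_len() {
        // Sketch combining two options: lowercasing happens before
        // segmentation and the length cap applies per emitted token, so every
        // token should be both lowercased and at most 4 characters long.
        let t = LanguageAgnosticTokenizer {
            lowercase: true,
            max_token_len: Some(4),
            ..Default::default()
        };
        let tokens = t.tokenize_str("Tokenizers ARE Fun");
        assert!(
            tokens
                .iter()
                .all(|s| s.chars().count() <= 4 && *s == s.to_lowercase()),
            "expected lowercased tokens of at most 4 chars: {tokens:?}"
        );
    }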
#[test]
    fn unit_nfc_normalization_equivalence() {
        let t = LanguageAgnosticTokenizer {
            normalize: true,
            ..Default::default()
        };
        // "café" in decomposed (NFD) and precomposed (NFC) form should
        // tokenize identically once NFC normalization is applied.
        let nfd_cafe = "cafe\u{0301}";
        let nfc_cafe = "caf\u{00E9}";
        let t1 = t.tokenize_str(nfd_cafe);
        let t2 = t.tokenize_str(nfc_cafe);
        assert_eq!(t1, t2, "NFD and NFC inputs tokenized differently: {t1:?} vs {t2:?}");
    }
#[test]
fn unit_trait_tokenize_result() {
let t = LanguageAgnosticTokenizer::new();
let result = <LanguageAgnosticTokenizer as Tokenizer>::tokenize(&t, "hello world");
assert!(result.is_ok());
let tokens = result.unwrap_or_default();
assert!(tokens.iter().any(|s| s == "hello"), "tokens: {tokens:?}");
}
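    #[test]
    fn unit_clone_box_matches_original() {
        // clone_box should hand back a boxed tokenizer that behaves exactly
        // like the value it was cloned from.
        let t = LanguageAgnosticTokenizer::new();
        let boxed: Box<dyn Tokenizer + Send + Sync> = t.clone_box();
        let direct = t.tokenize_str("Hello 世界");
        let via_box = boxed.tokenize("Hello 世界").unwrap_or_default();
        assert_eq!(direct, via_box, "boxed clone diverged from the original");
    }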
}