use std::sync::OnceLock;
use fancy_regex::Regex;
use crate::stress::{punct_tag_phoneme, PUNCTS, PUNCT_TAGS, SUBTOKEN_JUNKS};
use crate::tagger;
use crate::token::MToken;
// Characters that `split_punct` peels off the END of a word, one token per character.
const TRAILING_PUNCT: &str = ".!?…,;:)—\"\u{201D}\u{201E}\u{2019}";
// Characters that `split_punct` peels off the START of a word, one token per character.
const LEADING_PUNCT: &str = "($£€¥\"\u{201C}\u{2018}";
/// Returns `true` when `s` is non-empty and every one of its characters
/// appears in `PUNCTS`.
fn is_all_puncts(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    // "All are punctuation" expressed as "none is non-punctuation".
    !s.chars().any(|c| !PUNCTS.contains(c))
}
/// Returns `true` when `s` looks like a plain number: it consists only of
/// ASCII digits, `,` and `.`, contains at least one digit, and does not end
/// with a `.` or `,` separator.
///
/// The empty string is never number-like.
fn is_number_like(s: &str) -> bool {
    let only_numeric_chars = s.chars().all(|c| matches!(c, '0'..='9' | ',' | '.'));
    let has_digit = s.chars().any(|c| c.is_ascii_digit());
    // A trailing separator means sentence punctuation is still glued on ("3." / "3,").
    let ends_with_separator = matches!(s.chars().next_back(), Some('.' | ','));

    only_numeric_chars && has_digit && !ends_with_separator
}
/// Returns `true` when `s` is non-empty and made up solely of the currency
/// symbols `$`, `£` and `€`.
fn is_currency(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    s.chars().all(|c| c == '$' || c == '£' || c == '€')
}
/// Assigns a rule-based tag to `text` without consulting the tagger.
///
/// Checks run in this order: currency symbols (`"$"`), the two parentheses
/// (`"-LRB-"` / `"-RRB-"`), pure punctuation (tag chosen from the first
/// character), number-like text (`"CD"`). Anything else gets `"DEFAULT"`.
fn simple_tag(text: &str) -> &'static str {
    if is_currency(text) {
        return "$";
    }
    match text {
        "(" => return "-LRB-",
        ")" => return "-RRB-",
        _ => {}
    }
    if is_all_puncts(text) {
        // `is_all_puncts` only accepts non-empty text, so a first char exists.
        return match text.chars().next() {
            Some(',') => ",",
            Some(':' | ';' | '—' | '-') => ":",
            // Sentence enders and every other punctuation mark share the "." tag.
            _ => ".",
        };
    }
    if is_number_like(text) {
        "CD"
    } else {
        "DEFAULT"
    }
}
/// Tags a punctuation piece stripped from the front of a word.
///
/// `(` maps to `"-LRB-"`, and the ASCII double quote, U+201C and U+2018 map
/// to the opening-quote tag; everything else defers to `simple_tag`.
fn leading_punct_tag(text: &str) -> &'static str {
    if text == "(" {
        "-LRB-"
    } else if matches!(text, "\"" | "\u{201C}" | "\u{2018}") {
        "``"
    } else {
        simple_tag(text)
    }
}
/// Tags a punctuation piece stripped from the end of a word.
///
/// `)` maps to `"-RRB-"`, and the ASCII double quote, U+201D, U+201E and
/// U+2019 map to the closing-quote tag; everything else defers to
/// `simple_tag`.
fn trailing_punct_tag(text: &str) -> &'static str {
    if text == ")" {
        "-RRB-"
    } else if matches!(text, "\"" | "\u{201D}" | "\u{201E}" | "\u{2019}") {
        "''"
    } else {
        simple_tag(text)
    }
}
/// Splits one whitespace-delimited `word` into head tokens, detaching
/// leading and trailing punctuation into single-character tokens.
///
/// * Empty, number-like and currency-only words are returned as one token.
/// * Words made only of punctuation are split into one token per character.
/// * Otherwise characters from `LEADING_PUNCT` are peeled off the front and
///   characters from `TRAILING_PUNCT` off the back, leaving a core token in
///   the middle.
///
/// `whitespace` is attached to the last token produced; every earlier token
/// gets empty whitespace. All returned tokens have `is_head` set.
fn split_punct(word: &str, whitespace: &str) -> Vec<MToken> {
    /// Builds a token flagged as a head.
    fn head(text: &str, tag: &'static str, ws: &str) -> MToken {
        let mut tok = MToken::new(text, tag, ws);
        tok.underscore.is_head = true;
        tok
    }

    if word.is_empty() || is_number_like(word) || is_currency(word) {
        return vec![head(word, simple_tag(word), whitespace)];
    }

    let chars: Vec<char> = word.chars().collect();
    // `word` is non-empty here, so there is always a final index.
    let final_idx = chars.len() - 1;
    // Only the last emitted token carries the word's trailing whitespace.
    let ws_at = |idx: usize| if idx == final_idx { whitespace } else { "" };

    if is_all_puncts(word) {
        return chars
            .iter()
            .enumerate()
            .map(|(idx, ch)| {
                let piece = ch.to_string();
                head(piece.as_str(), simple_tag(piece.as_str()), ws_at(idx))
            })
            .collect();
    }

    let leading = chars
        .iter()
        .take_while(|&&c| LEADING_PUNCT.contains(c))
        .count();
    let trailing = chars
        .iter()
        .rev()
        .take_while(|&&c| TRAILING_PUNCT.contains(c))
        .count();

    // No core left between the two runs: emit one token per character,
    // preferring the leading-punctuation tag when a character is in both sets.
    if leading + trailing >= chars.len() {
        return chars
            .iter()
            .enumerate()
            .map(|(idx, &ch)| {
                let piece = ch.to_string();
                let tag = if LEADING_PUNCT.contains(ch) {
                    leading_punct_tag(piece.as_str())
                } else {
                    trailing_punct_tag(piece.as_str())
                };
                head(piece.as_str(), tag, ws_at(idx))
            })
            .collect();
    }

    let core_end = chars.len() - trailing;
    let mut pieces = Vec::new();

    pieces.extend(chars[..leading].iter().map(|ch| {
        let piece = ch.to_string();
        head(piece.as_str(), leading_punct_tag(piece.as_str()), "")
    }));

    let core: String = chars[leading..core_end].iter().collect();
    let core_ws = if trailing == 0 { whitespace } else { "" };
    pieces.push(head(core.as_str(), simple_tag(core.as_str()), core_ws));

    pieces.extend(chars[core_end..].iter().enumerate().map(|(offset, ch)| {
        let piece = ch.to_string();
        head(
            piece.as_str(),
            trailing_punct_tag(piece.as_str()),
            ws_at(core_end + offset),
        )
    }));

    pieces
}
/// Tokenizes `text` using only the rule-based splitter.
///
/// The text is cut into whitespace-delimited words, each paired with the run
/// of whitespace that follows it (empty for a final word with nothing after
/// it). Leading whitespace is discarded. Every word is then passed through
/// `split_punct`, and the resulting tokens are concatenated in order.
pub fn tokenize_simple(text: &str) -> Vec<MToken> {
    let mut pieces: Vec<(&str, &str)> = Vec::new();

    // Whitespace before the first word belongs to no token.
    let mut rest = text.trim_start();
    while !rest.is_empty() {
        // `rest` starts with a non-whitespace char, so the word is never empty.
        let word_len = rest.find(char::is_whitespace).unwrap_or(rest.len());
        let (word, after_word) = rest.split_at(word_len);

        let ws_len = after_word
            .find(|c: char| !c.is_whitespace())
            .unwrap_or(after_word.len());
        let (ws, remainder) = after_word.split_at(ws_len);

        pieces.push((word, ws));
        rest = remainder;
    }

    pieces
        .into_iter()
        .flat_map(|(word, ws)| split_punct(word, ws))
        .collect()
}
/// Tokenizes `text` and fills in tags from the global tagger.
///
/// Tokens come from `tokenize_simple`. The token texts are handed to
/// `tagger::global_tagger()`, and each token still tagged `"DEFAULT"` takes
/// the tag at the same position in the tagger output. Tags assigned by the
/// rule-based pass are left untouched.
pub fn tokenize(text: &str) -> Vec<MToken> {
    let mut tokens = tokenize_simple(text);

    let texts: Vec<String> = tokens.iter().map(|tok| tok.text.clone()).collect();
    let words: Vec<&str> = texts.iter().map(String::as_str).collect();
    let tags = tagger::global_tagger().tag(&words);

    let undecided = tokens
        .iter_mut()
        .zip(tags.iter())
        .filter(|(tok, _)| tok.tag == "DEFAULT");
    for (tok, tagged) in undecided {
        tok.tag = tagged.tag.clone();
    }

    tokens
}
// Returns the regex used by `subtokenize`, compiled on first use and cached
// in a function-local `OnceLock` for every later call.
//
// The pattern is written in extended (`(?x)`) mode as an ordered list of
// alternatives: quote runs at the start of the text, an uppercase letter
// followed by an uppercase+lowercase pair, digit groups with optional `,`/`.`
// separators (optionally preceded by a `-` at the start of the text), runs of
// `-`/`_`, quote runs of length two or more, letter sequences ending in a
// lowercase letter that is followed by an uppercase one, plain letter words
// with internal apostrophes, any single other character, and quote runs at
// the end of the text.
//
// NOTE(review): the character class `[''']` repeats the ASCII apostrophe three
// times; presumably it was meant to also list the curly quotes — confirm.
//
// Panics if the pattern fails to compile.
fn subtokenize_regex() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| {
Regex::new(
r"(?x)
^[''']+ |
\p{Lu}(?=\p{Lu}\p{Ll}) |
(?:^-)?(?:\d?[,.]?\d)+ |
[-_]+ |
[''']{2,} |
\p{L}*?(?:[''']\p{L})*?\p{Ll}(?=\p{Lu}) |
\p{L}+(?:[''']\p{L})* |
[^-_\p{L}'''\d] |
[''']+$
",
)
.expect("subtokenize regex should compile")
})
}
/// Splits `word` into sub-tokens using the pattern from `subtokenize_regex`.
///
/// Matches are collected left to right; text between matches that the
/// pattern does not cover is dropped. If nothing matches at all (including
/// for an empty `word`), the result is a single-element vector holding
/// `word` unchanged.
///
/// The search always runs against the whole `word` and resumes from the end
/// of the previous match. Matching against a `&word[pos..]` slice instead
/// would make the pattern's `^` anchors succeed at every resume point, e.g.
/// turning `"test-5"` into `["test", "-5"]` rather than `["test", "-", "5"]`.
///
/// A regex engine error stops the scan; whatever was collected up to that
/// point is returned.
pub fn subtokenize(word: &str) -> Vec<String> {
    let re = subtokenize_regex();
    let mut pieces = Vec::new();
    let mut pos = 0;

    while pos < word.len() {
        // `find_from_pos` keeps the full haystack as context, so `^` and `$`
        // refer to the real start/end of `word`.
        let Ok(Some(found)) = re.find_from_pos(word, pos) else {
            break;
        };

        if found.start() == found.end() {
            // Defensive: step over one character after a zero-width match so
            // the loop always makes progress.
            pos = found.end() + word[found.end()..].chars().next().map_or(1, char::len_utf8);
            continue;
        }

        pieces.push(found.as_str().to_string());
        pos = found.end();
    }

    if pieces.is_empty() {
        vec![word.to_string()]
    } else {
        pieces
    }
}
pub fn fold_left(tokens: Vec<MToken>) -> Vec<MToken> {
let mut result: Vec<MToken> = Vec::new();
for tok in tokens {
if !tok.underscore.is_head && !result.is_empty() {
let prev = result.last_mut().unwrap();
prev.text.push_str(&prev.whitespace);
prev.text.push_str(&tok.text);
prev.whitespace = tok.whitespace;
if let Some(ref tp) = tok.phonemes {
if let Some(ref mut pp) = prev.phonemes {
pp.push_str(tp);
} else {
prev.phonemes = tok.phonemes.clone();
}
}
} else {
result.push(tok);
}
}
result
}
/// One element of the output of `retokenize`: either a token on its own or
/// the ordered sub-tokens obtained by splitting one input token.
#[derive(Clone, Debug)]
pub enum TokenOrGroup {
/// A token emitted by itself.
Single(MToken),
/// The sub-tokens that one input token was split into, in order.
Group(Vec<MToken>),
}
/// Runs `subtokenize` over each token and pre-fills phonemes for symbols.
///
/// * Tokens that already carry an alias or phonemes pass through unchanged.
/// * A token that splits into several sub-tokens becomes one
///   `TokenOrGroup::Group`. Only the first sub-token is a head and only the
///   last keeps the original whitespace. Sub-tokens made solely of
///   `SUBTOKEN_JUNKS` characters are tagged `":"` with empty phonemes and
///   rating 3; the others inherit the parent's tag.
/// * A token that does not split is emitted as `TokenOrGroup::Single`:
///   currency symbols get empty phonemes and rating 4 and are remembered as
///   the pending currency; dash-only `":"` tokens get an em-dash phoneme;
///   other punctuation-tagged tokens get the phoneme from
///   `punct_tag_phoneme`, or else their own `PUNCTS` characters.
/// * While a currency symbol is pending, `"CD"` tokens are stamped with it.
///   A plain token with another tag, or a split token tagged neither `"$"`
///   nor `"CD"`, clears the pending currency.
pub fn retokenize(tokens: Vec<MToken>) -> Vec<TokenOrGroup> {
    let mut output: Vec<TokenOrGroup> = Vec::new();
    let mut pending_currency: Option<char> = None;

    for mut tok in tokens {
        // Already resolved upstream: leave untouched.
        if tok.underscore.alias.is_some() || tok.phonemes.is_some() {
            output.push(TokenOrGroup::Single(tok));
            continue;
        }

        let subtokens = subtokenize(&tok.text);

        if subtokens.len() > 1 {
            let last_idx = subtokens.len() - 1;
            let group: Vec<MToken> = subtokens
                .iter()
                .enumerate()
                .map(|(idx, sub_text)| {
                    let is_junk = sub_text.chars().all(|c| SUBTOKEN_JUNKS.contains(c));
                    let tag = if is_junk {
                        ":".to_string()
                    } else {
                        tok.tag.clone()
                    };
                    let ws = if idx == last_idx {
                        tok.whitespace.clone()
                    } else {
                        String::new()
                    };

                    let mut sub_tok = MToken::new(sub_text.as_str(), tag, ws);
                    sub_tok.underscore.is_head = idx == 0;
                    if is_junk {
                        sub_tok.phonemes = Some(String::new());
                        sub_tok.underscore.rating = Some(3);
                    }
                    if let Some(cur) = pending_currency {
                        if sub_tok.tag == "CD" {
                            sub_tok.underscore.currency = Some(cur);
                        }
                    }
                    sub_tok
                })
                .collect();
            output.push(TokenOrGroup::Group(group));

            if tok.tag != "$" && tok.tag != "CD" {
                pending_currency = None;
            }
            continue;
        }

        // Single (unsplit) token from here on.
        if tok.tag == "$" && is_currency(&tok.text) {
            tok.phonemes = Some(String::new());
            tok.underscore.rating = Some(4);
            pending_currency = tok.text.chars().next();
        } else if tok.tag == ":" && tok.text.chars().all(|c| c == '-' || c == '—') {
            tok.phonemes = Some("\u{2014}".to_string());
            tok.underscore.rating = Some(3);
        } else if PUNCT_TAGS.contains(&tok.tag.as_str())
            && !tok.text.chars().all(|c| c.is_alphabetic())
        {
            let phonemes = match punct_tag_phoneme(&tok.tag) {
                Some(ph) => ph.to_string(),
                // No fixed phoneme for this tag: keep the punctuation characters themselves.
                None => tok.text.chars().filter(|c| PUNCTS.contains(*c)).collect(),
            };
            tok.phonemes = Some(phonemes);
            tok.underscore.rating = Some(3);
        } else if let Some(cur) = pending_currency {
            if tok.tag == "CD" {
                tok.underscore.currency = Some(cur);
            } else {
                pending_currency = None;
            }
        }
        output.push(TokenOrGroup::Single(tok));
    }

    output
}
// Unit tests for the rule-based tokenizer (`tokenize_simple`), the regex
// sub-tokenizer (`subtokenize`) and `fold_left`.
#[cfg(test)]
mod tests {
use super::*;
// Two plain words: both tagged DEFAULT, separator kept on the first token.
#[test]
fn simple_basic_sentence() {
let tokens = tokenize_simple("Hello world");
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].text, "Hello");
assert_eq!(tokens[0].tag, "DEFAULT");
assert_eq!(tokens[0].whitespace, " ");
assert_eq!(tokens[1].text, "world");
assert_eq!(tokens[1].tag, "DEFAULT");
assert_eq!(tokens[1].whitespace, "");
}
// Trailing comma and exclamation mark become their own tokens; the
// whitespace moves from the word to the punctuation that follows it.
#[test]
fn simple_punctuation() {
let tokens = tokenize_simple("Hello, world!");
assert_eq!(tokens.len(), 4);
assert_eq!(tokens[0].text, "Hello");
assert_eq!(tokens[0].whitespace, "");
assert_eq!(tokens[1].text, ",");
assert_eq!(tokens[1].tag, ",");
assert_eq!(tokens[1].whitespace, " ");
assert_eq!(tokens[2].text, "world");
assert_eq!(tokens[2].whitespace, "");
assert_eq!(tokens[3].text, "!");
assert_eq!(tokens[3].tag, ".");
assert_eq!(tokens[3].whitespace, "");
}
// A period standing alone between spaces is tagged ".".
#[test]
fn simple_standalone_punct() {
let tokens = tokenize_simple("a . b");
assert_eq!(tokens[1].text, ".");
assert_eq!(tokens[1].tag, ".");
}
// A trailing period is detached from the word.
#[test]
fn split_period_from_word() {
let tokens = tokenize_simple("Hello.");
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].text, "Hello");
assert_eq!(tokens[1].text, ".");
assert_eq!(tokens[1].tag, ".");
}
// A trailing question mark is detached and tagged ".".
#[test]
fn split_question_mark() {
let tokens = tokenize_simple("Really?");
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].text, "Really");
assert_eq!(tokens[1].text, "?");
assert_eq!(tokens[1].tag, ".");
}
// Parentheses around a word are split off on both sides.
#[test]
fn split_leading_paren() {
let tokens = tokenize_simple("(Hello)");
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[0].text, "(");
assert_eq!(tokens[1].text, "Hello");
assert_eq!(tokens[2].text, ")");
}
// A currency symbol glued to a number is split off and tagged "$".
#[test]
fn split_leading_currency() {
let tokens = tokenize_simple("$100");
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].text, "$");
assert_eq!(tokens[0].tag, "$");
assert_eq!(tokens[1].text, "100");
assert_eq!(tokens[1].tag, "CD");
}
// An apostrophe inside a word does not cause a split.
#[test]
fn no_split_contraction() {
let tokens = tokenize_simple("don't");
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, "don't");
}
// A decimal number stays one CD token.
#[test]
fn no_split_decimal() {
let tokens = tokenize_simple("3.14");
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, "3.14");
assert_eq!(tokens[0].tag, "CD");
}
// Several trailing punctuation marks each become their own token.
#[test]
fn split_multiple_trailing() {
let tokens = tokenize_simple("What!?");
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[0].text, "What");
assert_eq!(tokens[1].text, "!");
assert_eq!(tokens[2].text, "?");
}
// Periods are split per sentence and carry the following whitespace.
#[test]
fn split_sentence_periods() {
let tokens = tokenize_simple("Hello. World.");
assert_eq!(tokens.len(), 4);
assert_eq!(tokens[0].text, "Hello");
assert_eq!(tokens[1].text, ".");
assert_eq!(tokens[1].tag, ".");
assert_eq!(tokens[1].whitespace, " ");
assert_eq!(tokens[2].text, "World");
assert_eq!(tokens[3].text, ".");
assert_eq!(tokens[3].tag, ".");
}
// A word made only of punctuation is split into single characters.
#[test]
fn pure_punct_split_into_chars() {
let tokens = tokenize_simple("...");
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[0].text, ".");
assert_eq!(tokens[1].text, ".");
assert_eq!(tokens[2].text, ".");
}
// Mixed punctuation-only words are also split per character, in order.
#[test]
fn pure_mixed_punct_split() {
let tokens = tokenize_simple("!?!?");
assert_eq!(tokens.len(), 4);
assert_eq!(tokens[0].text, "!");
assert_eq!(tokens[1].text, "?");
assert_eq!(tokens[2].text, "!");
assert_eq!(tokens[3].text, "?");
}
// A single punctuation character yields exactly one token.
#[test]
fn single_punct_not_split_further() {
let tokens = tokenize_simple(".");
assert_eq!(tokens.len(), 1);
assert_eq!(tokens[0].text, ".");
assert_eq!(tokens[0].tag, ".");
}
// Split parentheses receive the bracket tags.
#[test]
fn split_parens_with_tags() {
let tokens = tokenize_simple("(Hello)");
assert_eq!(tokens[0].text, "(");
assert_eq!(tokens[0].tag, "-LRB-");
assert_eq!(tokens[2].text, ")");
assert_eq!(tokens[2].tag, "-RRB-");
}
// ASCII double quotes get the opening tag in front and the closing tag behind.
#[test]
fn split_ascii_quotes() {
let tokens = tokenize_simple("\"Hello\"");
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[0].text, "\"");
assert_eq!(tokens[0].tag, "``"); assert_eq!(tokens[2].text, "\"");
assert_eq!(tokens[2].tag, "''"); }
// Comma, period and both quotes are all detached in a quoted sentence.
#[test]
fn split_quoted_sentence_with_period() {
let tokens = tokenize_simple("She said, \"hello.\"");
let texts: Vec<&str> = tokens.iter().map(|t| t.text.as_str()).collect();
assert!(texts.contains(&","), "comma should be split: {texts:?}");
assert!(texts.contains(&"."), "period should be split: {texts:?}");
assert!(
texts.iter().filter(|t| **t == "\"").count() == 2,
"two quotes should be split: {texts:?}"
);
}
// A sentence-final period after a number is split from the number.
#[test]
fn split_number_trailing_period() {
let tokens = tokenize_simple("I have 3.");
let texts: Vec<&str> = tokens.iter().map(|t| t.text.as_str()).collect();
assert!(texts.contains(&"3"), "number should be separate: {texts:?}");
assert!(texts.contains(&"."), "period should be split: {texts:?}");
}
// A decimal number inside a sentence is kept in one piece.
#[test]
fn no_split_decimal_number() {
let tokens = tokenize_simple("He scored 3.14 points.");
let texts: Vec<&str> = tokens.iter().map(|t| t.text.as_str()).collect();
assert!(
texts.contains(&"3.14"),
"decimal should stay together: {texts:?}"
);
}
// A comma directly after a number is split from the number.
#[test]
fn split_number_trailing_comma() {
let tokens = tokenize_simple("Buy 3, get 1.");
let texts: Vec<&str> = tokens.iter().map(|t| t.text.as_str()).collect();
assert!(texts.contains(&"3"), "number should be separate: {texts:?}");
assert!(texts.contains(&","), "comma should be split: {texts:?}");
}
// A free-standing currency symbol and a number get "$" and "CD".
#[test]
fn simple_currency() {
let tokens = tokenize_simple("$ 100");
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].text, "$");
assert_eq!(tokens[0].tag, "$");
assert_eq!(tokens[1].text, "100");
assert_eq!(tokens[1].tag, "CD");
}
// Integers, comma-grouped numbers and decimals are all tagged "CD".
#[test]
fn simple_number() {
let tokens = tokenize_simple("42 1,000 3.14");
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[0].tag, "CD");
assert_eq!(tokens[1].tag, "CD");
assert_eq!(tokens[2].tag, "CD");
}
// A free-standing em dash is tagged ":".
#[test]
fn simple_dash_tag() {
let tokens = tokenize_simple("a — b");
assert_eq!(tokens.len(), 3);
assert_eq!(tokens[1].text, "—");
assert_eq!(tokens[1].tag, ":");
}
// Every token produced by the simple tokenizer is a head.
#[test]
fn simple_all_heads() {
let tokens = tokenize_simple("one two three");
for tok in &tokens {
assert!(tok.underscore.is_head);
}
}
// Empty input yields no tokens.
#[test]
fn simple_empty_input() {
let tokens = tokenize_simple("");
assert!(tokens.is_empty());
}
// Whitespace-only input yields no tokens.
#[test]
fn simple_whitespace_only() {
let tokens = tokenize_simple(" ");
assert!(tokens.is_empty());
}
// The separator between two words is stored as the first token's whitespace.
#[test]
fn simple_multiple_spaces() {
let tokens = tokenize_simple("hello world");
assert_eq!(tokens.len(), 2);
assert_eq!(tokens[0].whitespace, " ");
}
// A plain lowercase word is returned as a single sub-token.
#[test]
fn subtokenize_simple_word() {
let result = subtokenize("hello");
assert_eq!(result, vec!["hello"]);
}
// A lowercase-to-uppercase transition produces at least two sub-tokens.
#[test]
fn subtokenize_camel_case() {
let result = subtokenize("iPhone");
assert!(
result.len() >= 2,
"Expected split for camelCase: {:?}",
result
);
}
// Letters followed by digits produce at least two sub-tokens.
#[test]
fn subtokenize_number_word() {
let result = subtokenize("test123");
assert!(
result.len() >= 2,
"Expected split for word+number: {:?}",
result
);
}
// A hyphenated word is split around the hyphen, which is kept as a piece.
#[test]
fn subtokenize_hyphenated() {
let result = subtokenize("well-known");
assert!(result.len() >= 3, "Expected split on hyphen: {:?}", result);
assert!(result.contains(&"-".to_string()));
}
// A contraction produces at least one sub-token.
#[test]
fn subtokenize_apostrophe() {
let result = subtokenize("don't");
assert!(result.len() >= 1, "Should handle apostrophe: {:?}", result);
}
// An all-uppercase word produces a non-empty result.
#[test]
fn subtokenize_all_caps() {
let result = subtokenize("NASA");
assert!(!result.is_empty());
}
// A leading apostrophe is emitted as the first sub-token.
#[test]
fn subtokenize_leading_quotes() {
let result = subtokenize("'hello");
assert_eq!(result[0], "'");
}
// A trailing apostrophe ends up in the last sub-token.
#[test]
fn subtokenize_trailing_quotes() {
let result = subtokenize("hello'");
assert!(result.last().unwrap().contains('\''));
}
// A comma-grouped number is kept as one sub-token.
#[test]
fn subtokenize_digits_with_comma() {
let result = subtokenize("1,000");
assert!(
result.contains(&"1,000".to_string()),
"Expected number to stay together: {:?}",
result
);
}
// A non-head token is merged into the preceding token: text and phonemes
// are concatenated and the merged token's whitespace is kept.
#[test]
fn fold_left_merges_non_heads() {
let tok1 = {
let mut t = MToken::new("can", "NN", "");
t.underscore.is_head = true;
t.phonemes = Some("k".into());
t
};
let tok2 = {
let mut t = MToken::new("not", "RB", " ");
t.underscore.is_head = false;
t.phonemes = Some("n".into());
t
};
let result = fold_left(vec![tok1, tok2]);
assert_eq!(result.len(), 1);
assert_eq!(result[0].text, "cannot");
assert_eq!(result[0].phonemes, Some("kn".into()));
assert_eq!(result[0].whitespace, " ");
}
// Consecutive head tokens are left as separate tokens.
#[test]
fn fold_left_preserves_heads() {
let tok1 = {
let mut t = MToken::new("a", "DT", " ");
t.underscore.is_head = true;
t
};
let tok2 = {
let mut t = MToken::new("b", "NN", "");
t.underscore.is_head = true;
t
};
let result = fold_left(vec![tok1, tok2]);
assert_eq!(result.len(), 2);
}
}