use crate::levenshtein::levenshtein_ratio;
/// Minimum `levenshtein_ratio` score a word must reach against a canonical
/// spelling to count as a match; every fuzzy comparison in this file tests
/// `ratio >= MATCH_THRESHOLD`.
const MATCH_THRESHOLD: f32 = 0.65;
/// Words that the fuzzy matchers reject outright before any similarity score
/// is computed.
///
/// Entries are lowercase and the lookup is exact, so callers are expected to
/// lowercase a word before testing it.
static STOP_WORDS: &[&str] = &[
    "the", "of", "on", "or", "in", "at", "to", "a", "an", "and",
    "as", "is", "it", "be", "do", "so", "up", "by", "if", "no",
    "my", "we", "he", "me", "us", "am", "are", "was", "not", "but",
    "day", "date", "year", "month", "time", "age",
];

/// Returns `true` when `word` is exactly one of the [`STOP_WORDS`] entries
/// (case-sensitive comparison).
fn is_stop_word(word: &str) -> bool {
    STOP_WORDS.iter().any(|&stop| stop == word)
}
/// Canonical cardinal ("one".."nine") and ordinal ("first".."ninth") spellings
/// paired with their numeric value, 1 through 9.
///
/// NOTE: `best_match` resolves equal similarity scores by position in the
/// table, so reordering entries can change which value wins a tie.
static UNITS: &[(&str, i32)] = &[
("one", 1),
("two", 2),
("three", 3),
("four", 4),
("five", 5),
("six", 6),
("seven", 7),
("eight", 8),
("nine", 9),
("first", 1),
("second", 2),
("third", 3),
("fourth", 4),
("fifth", 5),
("sixth", 6),
("seventh", 7),
("eighth", 8),
("ninth", 9),
];
/// Canonical cardinal ("ten".."nineteen") and ordinal ("tenth".."nineteenth")
/// spellings paired with their numeric value, 10 through 19.
///
/// NOTE: `best_match` resolves equal similarity scores by position in the
/// table, so reordering entries can change which value wins a tie.
static TEENS: &[(&str, i32)] = &[
("ten", 10),
("eleven", 11),
("twelve", 12),
("thirteen", 13),
("fourteen", 14),
("fifteen", 15),
("sixteen", 16),
("seventeen", 17),
("eighteen", 18),
("nineteen", 19),
("tenth", 10),
("eleventh", 11),
("twelfth", 12),
("thirteenth", 13),
("fourteenth", 14),
("fifteenth", 15),
("sixteenth", 16),
("seventeenth", 17),
("eighteenth", 18),
("nineteenth", 19),
];
/// Canonical cardinal ("twenty".."ninety") and ordinal
/// ("twentieth".."ninetieth") spellings paired with their numeric value, the
/// multiples of ten from 20 through 90.
///
/// NOTE: `best_match` resolves equal similarity scores by position in the
/// table, so reordering entries can change which value wins a tie.
static TENS: &[(&str, i32)] = &[
("twenty", 20),
("thirty", 30),
("forty", 40),
("fifty", 50),
("sixty", 60),
("seventy", 70),
("eighty", 80),
("ninety", 90),
("twentieth", 20),
("thirtieth", 30),
("fortieth", 40),
("fiftieth", 50),
("sixtieth", 60),
("seventieth", 70),
("eightieth", 80),
("ninetieth", 90),
];
/// Canonical spellings that `match_hundred` compares a word against.
static HUNDREDS: &[&str] = &["hundred", "hundredth"];
/// Canonical spellings that `match_thousand` compares a word against.
static THOUSANDS: &[&str] = &["thousand", "thousandth"];
fn best_match(word: &str, table: &[(&str, i32)]) -> Option<i32> {
if is_stop_word(word) {
return None;
}
table
.iter()
.map(|&(canonical, value)| (levenshtein_ratio(word, canonical), value))
.filter(|&(ratio, _)| ratio >= MATCH_THRESHOLD)
.max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
.map(|(_, value)| value)
}
/// Fuzzy-matches `word` against the `UNITS` table via `best_match`.
///
/// Returns the matched value (1–9), or `None` if `word` is a stop word or no
/// entry reaches `MATCH_THRESHOLD`.
fn match_unit(word: &str) -> Option<i32> {
best_match(word, UNITS)
}
/// Fuzzy-matches `word` against the `TEENS` table via `best_match`.
///
/// Returns the matched value (10–19), or `None` if `word` is a stop word or
/// no entry reaches `MATCH_THRESHOLD`.
fn match_teen(word: &str) -> Option<i32> {
best_match(word, TEENS)
}
/// Fuzzy-matches `word` against the `TENS` table via `best_match`.
///
/// Returns the matched value (20, 30, …, 90), or `None` if `word` is a stop
/// word or no entry reaches `MATCH_THRESHOLD`.
fn match_tens(word: &str) -> Option<i32> {
best_match(word, TENS)
}
/// Which lookup table a single number word was matched against, as reported
/// by `match_best_single_word`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NumberCategory {
/// Matched an entry of `UNITS` (values 1–9).
Unit,
/// Matched an entry of `TEENS` (values 10–19).
Teen,
/// Matched an entry of `TENS` (values 20–90 in steps of ten).
Tens,
}
/// Returns the highest similarity score between `word` and any spelling in
/// `table`, counting only scores of at least [`MATCH_THRESHOLD`].
///
/// Yields `0.0` when `word` is a stop word or nothing in `table` clears the
/// threshold.
fn best_ratio(word: &str, table: &[(&str, i32)]) -> f32 {
    if is_stop_word(word) {
        return 0.0;
    }

    let mut highest = 0.0_f32;
    for &(canonical, _) in table {
        let score = levenshtein_ratio(word, canonical);
        // Scores below the threshold must not influence the result at all.
        if score >= MATCH_THRESHOLD {
            highest = highest.max(score);
        }
    }
    highest
}
/// Classifies a single word as a unit, teen, or tens word and returns its
/// numeric value together with that category.
///
/// The category is the table with the best similarity score. The unit table
/// wins only when it strictly beats both others; otherwise the teen table
/// wins when it strictly beats the tens table; every remaining case
/// (including ties) goes to the tens table. Returns `None` when no table
/// produces a score at all.
fn match_best_single_word(word: &str) -> Option<(i32, NumberCategory)> {
    let unit_score = best_ratio(word, UNITS);
    let teen_score = best_ratio(word, TEENS);
    let tens_score = best_ratio(word, TENS);

    // `best_ratio` reports 0.0 for "nothing matched".
    let nothing_matched = [unit_score, teen_score, tens_score]
        .iter()
        .all(|&score| score == 0.0);
    if nothing_matched {
        return None;
    }

    let (value, category) = if unit_score > teen_score && unit_score > tens_score {
        (match_unit(word), NumberCategory::Unit)
    } else if teen_score > tens_score {
        (match_teen(word), NumberCategory::Teen)
    } else {
        (match_tens(word), NumberCategory::Tens)
    };
    value.map(|number| (number, category))
}
/// Reports whether `word` fuzzily matches one of the [`HUNDREDS`] spellings.
///
/// Stop words never match; otherwise a single spelling scoring at least
/// [`MATCH_THRESHOLD`] is enough.
fn match_hundred(word: &str) -> bool {
    !is_stop_word(word)
        && HUNDREDS
            .iter()
            .map(|&canonical| levenshtein_ratio(word, canonical))
            .any(|score| score >= MATCH_THRESHOLD)
}
/// Reports whether `word` fuzzily matches one of the [`THOUSANDS`] spellings.
///
/// Stop words never match; otherwise a single spelling scoring at least
/// [`MATCH_THRESHOLD`] is enough.
fn match_thousand(word: &str) -> bool {
    !is_stop_word(word)
        && THOUSANDS
            .iter()
            .map(|&canonical| levenshtein_ratio(word, canonical))
            .any(|score| score >= MATCH_THRESHOLD)
}
/// Splits `utterance` into words, returning each word together with the byte
/// offset at which it starts.
///
/// Space, hyphen, tab, line feed and carriage return separate words; runs of
/// separators produce no empty tokens. Every other character, punctuation
/// included, stays part of the word it touches.
fn word_tokens(utterance: &str) -> Vec<(usize, &str)> {
    let is_separator = |character: char| matches!(character, ' ' | '-' | '\t' | '\n' | '\r');

    let mut tokens = Vec::new();
    let mut offset = 0;
    for piece in utterance.split(is_separator) {
        if !piece.is_empty() {
            tokens.push((offset, piece));
        }
        // Every separator is a single-byte ASCII character, so the next piece
        // begins exactly one byte past the end of this one.
        offset += piece.len() + 1;
    }
    tokens
}
fn try_parse_number(tokens: &[(usize, &str)], cursor: usize) -> Option<(i32, usize)> {
let lower_word = |index: usize| -> Option<String> {
tokens.get(index).map(|(_, word)| word.to_ascii_lowercase())
};
let mut position = cursor;
let mut total: i32 = 0;
if let Some(unit_word) = lower_word(position)
&& let Some(unit_value) = match_unit(&unit_word)
&& let Some(thousand_word) = lower_word(position + 1)
&& match_thousand(&thousand_word)
{
total += unit_value * 1000;
position += 2;
}
if let Some(unit_word) = lower_word(position)
&& let Some(unit_value) = match_unit(&unit_word)
&& let Some(hundred_word) = lower_word(position + 1)
&& match_hundred(&hundred_word)
{
total += unit_value * 100;
position += 2;
}
if let Some(word) = lower_word(position) {
match match_best_single_word(&word) {
Some((value, NumberCategory::Tens)) => {
total += value;
position += 1;
if let Some(unit_word) = lower_word(position)
&& let Some((unit_value, _)) = match_best_single_word(&unit_word)
&& matches!(
match_best_single_word(&unit_word),
Some((_, NumberCategory::Unit | NumberCategory::Teen))
)
{
total += unit_value;
position += 1;
}
}
Some((value, NumberCategory::Teen | NumberCategory::Unit)) => {
total += value;
position += 1;
}
None => {
}
}
}
let words_consumed = position - cursor;
if words_consumed == 0 || total <= 0 {
return None;
}
Some((total, words_consumed))
}
/// Rewrites `utterance` so that every spelled-out number recognised by
/// `try_parse_number` is replaced with its decimal digits.
///
/// Text outside the recognised numbers is copied through unchanged. A number
/// spanning several words is replaced as a whole, from the start of its first
/// word to the end of its last, so separators between those words are
/// dropped.
pub fn replace_word_numbers(utterance: &str) -> String {
    let tokens = word_tokens(utterance);
    if tokens.is_empty() {
        return utterance.to_string();
    }

    let mut rewritten = String::with_capacity(utterance.len());
    // Byte offset in `utterance` up to which text has already been emitted.
    let mut copied_until = 0;
    let mut index = 0;

    while let Some(&(span_start, _)) = tokens.get(index) {
        let Some((value, words_consumed)) = try_parse_number(&tokens, index) else {
            index += 1;
            continue;
        };

        // Emit the untouched text preceding the number, then the digits.
        rewritten.push_str(&utterance[copied_until..span_start]);
        rewritten.push_str(&value.to_string());

        let (last_start, last_word) = tokens[index + words_consumed - 1];
        copied_until = last_start + last_word.len();
        index += words_consumed;
    }

    rewritten.push_str(&utterance[copied_until..]);
    rewritten
}