use std::collections::HashMap;
use once_cell::sync::Lazy;
use regex::Regex;
// Symbol inventory for the ID vocabulary. `VOCAB` concatenates these in the
// order PAD, PUNCTUATION, LETTERS, IPA_LETTERS and assigns each symbol its
// position as its ID. Reordering, inserting or removing a character therefore
// shifts every later ID.
// NOTE(review): presumably this order must match the symbol table the target
// model was trained with — confirm against the reference vocab before editing.

/// Padding symbol. It is the first symbol fed into `VOCAB`, so its ID is 0.
const PAD: char = '$';
/// Punctuation symbols, placed immediately after `PAD`. Includes the curly
/// double quotes U+201C/U+201D, the ASCII double quote, and a trailing ASCII
/// space (the space is what separates tokens in tokenized input).
const PUNCTUATION: &str = ";:,.!?¡¿—…\u{201C}«»\u{201D}\" ";
/// ASCII letters, upper case then lower case, placed after `PUNCTUATION`.
const LETTERS: &str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
/// IPA letters, diacritics and intonation arrows, placed last. The escapes are
/// U+2019 (right single quotation mark), U+0329 (combining vertical line
/// below) and U+2018 (left single quotation mark).
const IPA_LETTERS: &str =
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘\u{2019}\u{0329}\u{2018}ᵻ";
/// Lazily built symbol-to-ID table.
///
/// Each symbol's ID is its position in the sequence `PAD`, `PUNCTUATION`,
/// `LETTERS`, `IPA_LETTERS`, so `PAD` maps to 0. If a symbol were listed
/// twice, its later position would overwrite the earlier one.
static VOCAB: Lazy<HashMap<char, i64>> = Lazy::new(|| {
    let ordered_symbols = std::iter::once(PAD)
        .chain(PUNCTUATION.chars())
        .chain(LETTERS.chars())
        .chain(IPA_LETTERS.chars());

    let mut table = HashMap::new();
    for (position, symbol) in ordered_symbols.enumerate() {
        table.insert(symbol, position as i64);
    }
    table
});
/// Tokenizer pattern: either a run of word characters, or a single character
/// that is neither a word character nor whitespace (e.g. one punctuation mark).
static RE_TOKENIZE: Lazy<Regex> = Lazy::new(|| {
    // The pattern is a compile-time literal, so a failure here is a programming
    // error rather than a runtime condition.
    Regex::new(r"\w+|[^\w\s]").expect("RE_TOKENIZE pattern must be a valid regex")
});
/// Re-spaces `text` into whitespace-separated tokens.
///
/// A token is either a run of word characters (`\w+`) or a single character
/// that is neither a word character nor whitespace. Tokens are emitted in
/// input order, separated by exactly one space. The original whitespace is
/// discarded, so punctuation attached to a word becomes its own token
/// (`"wɜːld!"` becomes `"wɜːld !"`).
pub fn basic_english_tokenize(text: &str) -> String {
    let mut joined = String::new();
    for (index, token) in RE_TOKENIZE.find_iter(text).enumerate() {
        if index > 0 {
            joined.push(' ');
        }
        joined.push_str(token.as_str());
    }
    joined
}
/// Looks up the vocabulary ID of a single character.
///
/// Returns `None` when `c` is not part of the symbol table.
pub fn char_to_id(c: char) -> Option<i64> {
    let id = VOCAB.get(&c)?;
    Some(*id)
}
pub fn text_to_ids(tokenized: &str) -> Vec<i64> {
let mut ids = vec![0i64]; for ch in tokenized.chars() {
if let Some(id) = char_to_id(ch) {
ids.push(id);
}
}
ids.push(0i64); ids
}
/// Converts an IPA string into pad-wrapped vocabulary IDs.
///
/// The input is first re-spaced by `basic_english_tokenize`: word runs and
/// individual punctuation marks are separated by single spaces. The result is
/// then mapped with `text_to_ids`.
pub fn ipa_to_ids(ipa: &str) -> Vec<i64> {
    text_to_ids(&basic_english_tokenize(ipa))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vocab_not_empty() {
        assert!(!VOCAB.is_empty());
    }

    #[test]
    fn test_pad_is_zero() {
        assert_eq!(char_to_id('$'), Some(0));
        assert_eq!(char_to_id(PAD), Some(0));
    }

    #[test]
    fn test_known_chars() {
        for ch in ";:,.!?".chars() {
            assert!(char_to_id(ch).is_some(), "char {} not in vocab", ch);
        }
        for ch in "¡¿—…«» \"".chars() {
            assert!(char_to_id(ch).is_some(), "char {:?} not in vocab", ch);
        }
        for ch in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz".chars() {
            assert!(char_to_id(ch).is_some(), "char {} not in vocab", ch);
        }
    }

    #[test]
    fn test_unknown_char_returns_none() {
        assert_eq!(char_to_id('\u{0000}'), None);
        assert_eq!(char_to_id('中'), None);
    }

    #[test]
    fn test_ids_have_pads() {
        let ids = ipa_to_ids("hɛloʊ");
        assert_eq!(ids.first(), Some(&0), "should start with pad token 0");
        assert_eq!(ids.last(), Some(&0), "should end with pad token 0");
        assert!(ids.len() > 2, "should have content between pads");
    }

    #[test]
    fn test_empty_input_is_only_pads() {
        assert_eq!(text_to_ids(""), vec![0, 0]);
        assert_eq!(ipa_to_ids(""), vec![0, 0]);
    }

    #[test]
    fn test_unknown_chars_are_dropped() {
        let a = char_to_id('a').unwrap();
        let b = char_to_id('b').unwrap();
        assert_eq!(text_to_ids("a中b"), vec![0, a, b, 0]);
    }

    #[test]
    fn test_basic_english_tokenize() {
        assert_eq!(basic_english_tokenize("hɛloʊ wɜːld!"), "hɛloʊ wɜːld !");
        assert_eq!(basic_english_tokenize("  a \t b\n"), "a b");
        assert_eq!(basic_english_tokenize(""), "");
    }

    #[test]
    fn test_punctuation_becomes_separate_token() {
        // Tokenization inserts a space between a word and attached punctuation,
        // and the space itself is a vocabulary symbol.
        let id = |c| char_to_id(c).unwrap();
        assert_eq!(ipa_to_ids("a!"), vec![0, id('a'), id(' '), id('!'), 0]);
    }

    #[test]
    fn test_vocab_uniqueness() {
        // VOCAB is collected from (symbol, position) pairs, so a repeated symbol
        // silently overwrites its earlier ID while all remaining values stay
        // distinct. Compare against the raw symbol count instead.
        let symbol_count = 1
            + PUNCTUATION.chars().count()
            + LETTERS.chars().count()
            + IPA_LETTERS.chars().count();
        assert_eq!(VOCAB.len(), symbol_count, "a symbol is listed more than once");

        let mut ids: Vec<i64> = VOCAB.values().copied().collect();
        ids.sort_unstable();
        let expected: Vec<i64> = (0..).take(ids.len()).collect();
        assert_eq!(ids, expected, "IDs must be exactly 0..len with no gaps");
    }
}