use super::EspeakG2P;
use std::collections::HashMap;
use std::error::Error;
use std::time::Instant;
/// Turns text (via an eSpeak G2P backend) or ready-made phoneme strings into token-id
/// sequences, using greedy longest-match lookup against a phoneme vocabulary.
pub struct EspeakIpaTokenizer {
    /// Phoneme string -> token id.
    vocab: HashMap<String, i64>,
    /// Char length of the longest vocabulary key; caps the greedy match window.
    max_token_chars: usize,
    /// Id placed at the start of every encoded sequence.
    bos_id: i64,
    /// Id placed at the end of every encoded sequence.
    eos_id: i64,
    /// Grapheme-to-phoneme converter used when encoding plain text.
    g2p: EspeakG2P,
    /// Length cap (BOS/EOS included) applied when the caller does not pass one.
    model_max_length: usize,
}
impl EspeakIpaTokenizer {
    /// Default maximum encoded length (BOS and EOS included) used when callers pass `None`.
    const DEFAULT_MODEL_MAX_LENGTH: usize = 512;

    /// Builds a tokenizer over `vocab` and initialises the eSpeak G2P backend.
    ///
    /// The vocabulary's `"$"` entry is used as both the BOS and the EOS id.
    ///
    /// # Errors
    /// Returns an error if `vocab` has no `"$"` entry or if `EspeakG2P::new` fails.
    pub fn new(vocab: HashMap<String, i64>) -> Result<Self, Box<dyn Error>> {
        let bos_id = *vocab.get("$").ok_or("BOS token '$' not found")?;
        // The same marker id delimits both ends of the sequence.
        let eos_id = bos_id;
        let g2p = EspeakG2P::new()?;
        let max_token_chars = Self::max_token_chars(&vocab);
        Ok(Self {
            vocab,
            bos_id,
            eos_id,
            model_max_length: Self::DEFAULT_MODEL_MAX_LENGTH,
            g2p,
            max_token_chars,
        })
    }

    /// Overrides the default maximum encoded length (BOS and EOS included).
    #[must_use]
    pub fn with_model_max_length(mut self, max_length: usize) -> Self {
        self.model_max_length = max_length;
        self
    }

    /// Rewrites eSpeak IPA output into the phoneme alphabet used by the vocabulary.
    ///
    /// Steps, in order:
    /// 1. Tie bars become `^`.
    /// 2. The ordered substitution table is applied.
    /// 3. Syllabic consonants are rewritten as schwa + consonant.
    /// 4. A few vowel sequences are adjusted.
    /// 5. Length marks (`ː`) and tie markers (`^`) are stripped.
    fn espeak_ipa_to_misaki(&self, ipa: &str) -> String {
        // Ordered substitution table. A pattern must come before any shorter pattern it
        // contains ("e^ɪ" before "e", "ʔn̩" before "ʔn" before "ʔ", "ʲO" before "ʲ"),
        // otherwise the shorter rule would consume part of the longer match first.
        const RULES: &[(&str, &str)] = &[
            ("ʔˌn\u{0329}", "tᵊn"),
            // Without this rule "ʔn̩" became "tᵊn̩" via "ʔn", and then "tᵊᵊn" (double schwa)
            // in the syllabic pass below.
            ("ʔn\u{0329}", "tᵊn"),
            ("a^ɪ", "I"),
            ("a^ʊ", "W"),
            ("d^ʒ", "ʤ"),
            ("e^ɪ", "A"),
            ("t^ʃ", "ʧ"),
            ("ɔ^ɪ", "Y"),
            ("ə^l", "ᵊl"),
            ("ʔn", "tᵊn"),
            ("ɚ", "əɹ"),
            ("ʲO", "jO"),
            ("ʲQ", "jQ"),
            // Drop the combining tilde.
            ("\u{0303}", ""),
            ("e", "A"),
            ("r", "ɹ"),
            ("x", "k"),
            ("ç", "k"),
            ("ɐ", "ə"),
            ("ɬ", "l"),
            ("ʔ", "t"),
            ("ʲ", ""),
        ];
        // Combining vertical line below: marks the preceding consonant as syllabic.
        const SYLLABIC: char = '\u{0329}';

        // The combining tie (U+0361) joins two-symbol phonemes; spelling it as '^' lets the
        // table above match those pairs as plain strings.
        let mut result = ipa.replace('\u{0361}', "^");
        for &(from, to) in RULES {
            result = result.replace(from, to);
        }

        // Move the syllabic mark in front of its consonant as a schwa ("n̩" -> "ᵊn").
        // Whitespace never carries the mark, and any unattached marks are dropped.
        let mut syllabified = String::with_capacity(result.len());
        let mut chars = result.chars().peekable();
        while let Some(c) = chars.next() {
            if !c.is_whitespace() && chars.peek() == Some(&SYLLABIC) {
                chars.next();
                syllabified.push('ᵊ');
            }
            if c != SYLLABIC {
                syllabified.push(c);
            }
        }

        // Vowel adjustments that must run after syllabification, then strip length marks
        // and any remaining tie markers.
        let mut result = syllabified;
        for (from, to) in [("o^ʊ", "O"), ("ɜːɹ", "ɜɹ"), ("ɜː", "ɜɹ"), ("ɪə", "iə")] {
            result = result.replace(from, to);
        }
        result.replace('ː', "").replace('^', "")
    }

    /// Runs eSpeak G2P on `text` and converts the result with `espeak_ipa_to_misaki`.
    ///
    /// When `DEBUG_PHONEMES` is set, the input and both phoneme strings are printed.
    ///
    /// # Errors
    /// Propagates any error from the G2P backend.
    fn text_to_ipa(&self, text: &str) -> Result<String, Box<dyn Error>> {
        let ipa = self.g2p.text_to_ipa(text)?;
        let misaki_phonemes = self.espeak_ipa_to_misaki(&ipa);
        if Self::env_flag("DEBUG_PHONEMES") {
            println!("Input text: '{}'", text);
            println!("Espeak IPA: '{}'", ipa);
            println!("Misaki phonemes: '{}'", misaki_phonemes);
        }
        Ok(misaki_phonemes)
    }

    /// Length in `char`s of the longest vocabulary key (1 for an empty vocabulary).
    fn max_token_chars(vocab: &HashMap<String, i64>) -> usize {
        vocab.keys().map(|k| k.chars().count()).max().unwrap_or(1)
    }

    /// Greedy longest-match tokenization of `ipa` against the vocabulary.
    ///
    /// At each position the longest vocabulary key (up to `max_token_chars` chars) is
    /// consumed. Characters that start no vocabulary key are skipped; non-whitespace ones
    /// are reported on stderr.
    pub fn tokenize_longest(&self, ipa: &str) -> Vec<i64> {
        let chars: Vec<char> = ipa.chars().collect();
        let mut ids = Vec::with_capacity(chars.len());
        // Reused lookup buffer, so candidate substrings do not allocate on every probe.
        let mut candidate = String::new();
        let mut i = 0;
        while i < chars.len() {
            let limit = self.max_token_chars.min(chars.len() - i);
            let matched = (1..=limit).rev().find_map(|len| {
                candidate.clear();
                candidate.extend(&chars[i..i + len]);
                self.vocab.get(candidate.as_str()).map(|&id| (id, len))
            });
            match matched {
                Some((id, len)) => {
                    ids.push(id);
                    i += len;
                }
                None => {
                    if !chars[i].is_whitespace() {
                        eprintln!("Warning: unknown token {:?}", chars[i]);
                    }
                    i += 1;
                }
            }
        }
        ids
    }

    /// Surrounds `inner` with BOS/EOS, keeping at most `max_len` ids in total.
    ///
    /// Truncation drops ids from the end of `inner`; BOS and EOS are always kept. For
    /// `max_len < 2` the result is therefore still `[bos, eos]`, longer than `max_len`.
    fn wrap_and_truncate(&self, inner: &[i64], max_len: usize) -> Vec<i64> {
        let kept = &inner[..inner.len().min(max_len.saturating_sub(2))];
        let mut tokens = Vec::with_capacity(kept.len() + 2);
        tokens.push(self.bos_id);
        tokens.extend_from_slice(kept);
        tokens.push(self.eos_id);
        tokens
    }

    /// Returns `true` when environment variable `name` is set to a valid-Unicode value.
    fn env_flag(name: &str) -> bool {
        std::env::var(name).is_ok()
    }

    /// Prints the final id sequence when `DEBUG_TOKENS` is set.
    fn debug_print_tokens(tokens: &[i64]) {
        if Self::env_flag("DEBUG_TOKENS") {
            println!("tokens = {:?}", tokens);
        }
    }

    /// Encodes an already-phonemized string, skipping G2P.
    ///
    /// `max_length` (default: the model max length) caps the total length including BOS/EOS.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for API compatibility.
    pub fn encode_phonemes(
        &self,
        phonemes: &str,
        max_length: Option<usize>,
    ) -> Result<Vec<i64>, Box<dyn Error>> {
        let start_time = Instant::now();
        let max_len = max_length.unwrap_or(self.model_max_length);
        let inner = self.tokenize_longest(phonemes);
        let tokens = self.wrap_and_truncate(&inner, max_len);
        if Self::env_flag("DEBUG_TIMING") {
            println!(
                "Direct phoneme tokenization time: {:?}",
                start_time.elapsed()
            );
        }
        Self::debug_print_tokens(&tokens);
        Ok(tokens)
    }

    /// Encodes plain text: eSpeak G2P, phoneme mapping, then longest-match tokenization.
    ///
    /// `max_length` (default: the model max length) caps the total length including BOS/EOS.
    ///
    /// # Errors
    /// Propagates any error from the G2P backend.
    pub fn encode(
        &self,
        text: &str,
        max_length: Option<usize>,
    ) -> Result<Vec<i64>, Box<dyn Error>> {
        let start_time = Instant::now();
        let max_len = max_length.unwrap_or(self.model_max_length);
        let ipa_text = self.text_to_ipa(text)?;
        if Self::env_flag("DEBUG_TIMING") {
            println!(
                "Phoneme tokenization (espeak IPA conversion) took: {:?}",
                start_time.elapsed()
            );
        }
        let inner = self.tokenize_longest(&ipa_text);
        let tokens = self.wrap_and_truncate(&inner, max_len);
        if Self::env_flag("DEBUG_TIMING") {
            println!("Total tokenization time: {:?}", start_time.elapsed());
        }
        Self::debug_print_tokens(&tokens);
        Ok(tokens)
    }
}