Skip to main content

voice_g2p/
lib.rs

1pub mod espeak;
2pub mod lexicon;
3pub mod number;
4pub mod stress;
5pub mod tagger;
6pub mod token;
7pub mod tokenizer;
8
9use std::collections::HashMap;
10use std::sync::OnceLock;
11
12use espeak::EspeakFallback;
13use lexicon::Lexicon;
14use stress::{apply_stress, CONSONANTS, NON_QUOTE_PUNCTS, PRIMARY_STRESS, SUBTOKEN_JUNKS, VOWELS};
15use token::{merge_tokens, MToken, TokenContext};
16use tokenizer::TokenOrGroup;
17
/// Errors surfaced by the G2P public API.
///
/// `Display` and `std::error::Error` are derived through `thiserror`; the
/// `#[error(...)]` strings below are the user-facing messages.
#[derive(Debug, thiserror::Error)]
pub enum G2pError {
    /// The `espeak-ng` executable could not be located.
    // NOTE(review): not constructed in this file — presumably raised by the
    // `espeak` module; confirm against that module.
    #[error("espeak-ng not found. Install with: brew install espeak-ng")]
    EspeakNotFound,
    /// `espeak-ng` was invoked but failed; the payload is interpolated
    /// verbatim into the message.
    #[error("espeak-ng failed: {0}")]
    EspeakFailed(String),
    /// Wrapper around an I/O error; `#[from]` lets `?` convert a
    /// `std::io::Error` into this variant automatically.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
27
/// Configuration for external tool paths used by the G2P pipeline.
///
/// Pass an instance to `G2P::with_config`; `G2PConfig::default()` yields the
/// PATH-lookup configuration used by `G2P::new`.
#[derive(Debug, Clone)]
pub struct G2PConfig {
    /// Path to the `espeak-ng` binary for fallback pronunciation.
    /// Defaults to `"espeak-ng"` (PATH lookup).
    pub espeak_path: String,
}
35
36impl Default for G2PConfig {
37    fn default() -> Self {
38        Self {
39            espeak_path: "espeak-ng".to_string(),
40        }
41    }
42}
43
/// The main G2P pipeline, ported from misaki's `en.G2P.__call__()`.
///
/// Per-token resolution order (see `resolve_single_token`): custom
/// `overrides`, then `lexicon`, then the espeak `fallback`.
pub struct G2P {
    /// Dictionary-based pronunciation lookup.
    lexicon: Lexicon,
    /// espeak-ng wrapper consulted when the lexicon yields no phonemes.
    fallback: EspeakFallback,
    /// Text emitted for tokens that end up with no phonemes at all;
    /// `with_config` initialises it to the empty string.
    unk: String,
    /// Caller-supplied pronunciations, looked up by the lowercased token text.
    overrides: HashMap<String, String>,
}
51
52fn global_g2p() -> &'static G2P {
53    static INSTANCE: OnceLock<G2P> = OnceLock::new();
54    INSTANCE.get_or_init(G2P::new)
55}
56
57impl G2P {
58    pub fn new() -> Self {
59        Self::with_config(G2PConfig::default())
60    }
61
62    pub fn with_config(config: G2PConfig) -> Self {
63        Self {
64            lexicon: Lexicon::new(),
65            fallback: EspeakFallback::with_path(config.espeak_path),
66            unk: String::new(),
67            overrides: HashMap::new(),
68        }
69    }
70
71    /// Set custom word-to-phoneme overrides (builder pattern).
72    ///
73    /// Overrides map lowercase words to phoneme strings, checked before
74    /// the lexicon and espeak fallback.
75    pub fn with_overrides(mut self, overrides: HashMap<String, String>) -> Self {
76        self.overrides = overrides;
77        self
78    }
79
80    /// Full pipeline: text -> phoneme string.
81    ///
82    /// Mirrors misaki `G2P.__call__()` from en.py:679-738.
83    pub fn convert(&self, text: &str) -> Result<String, G2pError> {
84        // 1. Tokenize and POS-tag (embedded perceptron tagger)
85        let tokens = tokenizer::tokenize(text);
86
87        // 2. fold_left: merge non-head tokens
88        let tokens = tokenizer::fold_left(tokens);
89
90        // 3. retokenize: subtokenize, handle punctuation/currency
91        let mut items = tokenizer::retokenize(tokens);
92
93        // 4. Right-to-left resolution with TokenContext
94        let mut ctx = TokenContext::default();
95
96        for item in items.iter_mut().rev() {
97            match item {
98                TokenOrGroup::Single(ref mut w) => {
99                    self.resolve_single_token(w, &ctx);
100                    ctx = Self::token_context(&ctx, w.phonemes.as_deref(), w);
101                }
102                TokenOrGroup::Group(ref mut group) => {
103                    self.resolve_group(group, &ctx);
104                    if let Some(first) = group.first() {
105                        ctx = Self::token_context(&ctx, first.phonemes.as_deref(), first);
106                    }
107                }
108            }
109        }
110
111        // 5. Merge groups into single tokens
112        let tokens: Vec<MToken> = items
113            .into_iter()
114            .map(|item| match item {
115                TokenOrGroup::Single(tok) => tok,
116                TokenOrGroup::Group(group) => merge_tokens(&group, Some(&self.unk)),
117            })
118            .collect();
119
120        // 6. Legacy conversion: ɾ->T, ʔ->t
121        let result: String = tokens
122            .iter()
123            .map(|tk| {
124                let ps = match &tk.phonemes {
125                    Some(p) => p.replace('ɾ', "T").replace('ʔ', "t"),
126                    None => self.unk.clone(),
127                };
128                format!("{}{}", ps, tk.whitespace)
129            })
130            .collect();
131
132        Ok(result)
133    }
134
135    /// Resolve a single (non-grouped) token.
136    fn resolve_single_token(&self, w: &mut MToken, ctx: &TokenContext) {
137        if w.phonemes.is_some() {
138            return;
139        }
140
141        // Check custom overrides before lexicon/espeak fallback
142        let lookup_key = w.text.to_lowercase();
143        if let Some(ps) = self.overrides.get(&lookup_key) {
144            w.phonemes = Some(ps.clone());
145            w.underscore.rating = Some(5); // highest priority
146            return;
147        }
148        let (ps, rating) = self.lexicon.call(
149            &w.text,
150            w.underscore.alias.as_deref(),
151            &w.tag,
152            w.underscore.stress,
153            w.underscore.currency,
154            w.underscore.is_head,
155            &w.underscore.num_flags,
156            ctx,
157        );
158        if let Some(ps) = ps {
159            w.phonemes = Some(ps);
160            w.underscore.rating = rating;
161            return;
162        }
163
164        if let Some((ps, rating)) = self.fallback.convert_word(&w.text) {
165            w.phonemes = Some(ps);
166            w.underscore.rating = Some(rating);
167        }
168    }
169
170    /// Resolve a group of subtokens using the left-expand/right-shrink algorithm.
171    ///
172    /// Ported from en.py:694-731.
173    fn resolve_group(&self, group: &mut [MToken], ctx: &TokenContext) {
174        let n = group.len();
175        let mut left = 0;
176        let mut right = n;
177        let mut should_fallback = false;
178
179        while left < right {
180            let has_existing = group[left..right]
181                .iter()
182                .any(|tk| tk.underscore.alias.is_some() || tk.phonemes.is_some());
183
184            let (ps, rating) = if has_existing {
185                (None, None)
186            } else {
187                let merged = merge_tokens(&group[left..right], None);
188                self.lexicon.call(
189                    &merged.text,
190                    merged.underscore.alias.as_deref(),
191                    &merged.tag,
192                    merged.underscore.stress,
193                    merged.underscore.currency,
194                    merged.underscore.is_head,
195                    &merged.underscore.num_flags,
196                    ctx,
197                )
198            };
199
200            if let Some(ps) = ps {
201                group[left].phonemes = Some(ps);
202                group[left].underscore.rating = rating;
203                for x in &mut group[left + 1..right] {
204                    x.phonemes = Some(String::new());
205                    x.underscore.rating = rating;
206                }
207                right = left;
208                left = 0;
209            } else if left + 1 < right {
210                left += 1;
211            } else {
212                right -= 1;
213                let tk = &mut group[right];
214                if tk.phonemes.is_none() {
215                    if tk.text.chars().all(|c| SUBTOKEN_JUNKS.contains(c)) {
216                        tk.phonemes = Some(String::new());
217                        tk.underscore.rating = Some(3);
218                    } else {
219                        should_fallback = true;
220                        break;
221                    }
222                }
223                left = 0;
224            }
225        }
226
227        if should_fallback {
228            let merged = merge_tokens(group, None);
229            if let Some((ps, rating)) = self.fallback.convert_word(&merged.text) {
230                group[0].phonemes = Some(ps);
231                group[0].underscore.rating = Some(rating);
232                for j in 1..group.len() {
233                    group[j].phonemes = Some(String::new());
234                    group[j].underscore.rating = group[0].underscore.rating;
235                }
236            }
237        } else {
238            Self::resolve_tokens(group);
239        }
240    }
241
242    /// Update TokenContext based on resolved phonemes and token.
243    ///
244    /// Ported from en.py:646-650.
245    fn token_context(ctx: &TokenContext, ps: Option<&str>, token: &MToken) -> TokenContext {
246        let mut vowel = ctx.future_vowel;
247
248        if let Some(ps) = ps {
249            for c in ps.chars() {
250                let is_vowel = VOWELS.contains(c);
251                let is_consonant = CONSONANTS.contains(c);
252                let is_punct = NON_QUOTE_PUNCTS.contains(c);
253
254                if is_vowel || is_consonant || is_punct {
255                    vowel = if is_punct { None } else { Some(is_vowel) };
256                    break;
257                }
258            }
259        }
260
261        let future_to = matches!(token.text.as_str(), "to" | "To")
262            || (token.text == "TO" && matches!(token.tag.as_str(), "TO" | "IN"));
263
264        TokenContext {
265            future_vowel: vowel,
266            future_to,
267        }
268    }
269
270    /// Normalize stress across a group of resolved subtokens.
271    ///
272    /// Ported from en.py:652-677.
273    fn resolve_tokens(tokens: &mut [MToken]) {
274        if tokens.is_empty() {
275            return;
276        }
277
278        let text: String = tokens
279            .iter()
280            .enumerate()
281            .map(|(i, tk)| {
282                if i < tokens.len() - 1 {
283                    format!("{}{}", tk.text, tk.whitespace)
284                } else {
285                    tk.text.clone()
286                }
287            })
288            .collect();
289
290        let has_space = text.contains(' ') || text.contains('/');
291        let char_classes: std::collections::HashSet<u8> = text
292            .chars()
293            .filter(|c| !SUBTOKEN_JUNKS.contains(*c))
294            .map(|c| {
295                if c.is_alphabetic() {
296                    0
297                } else if c.is_ascii_digit() {
298                    1
299                } else {
300                    2
301                }
302            })
303            .collect();
304        let prespace = has_space || char_classes.len() > 1;
305
306        let n = tokens.len();
307        for (i, tk) in tokens.iter_mut().enumerate() {
308            if tk.phonemes.is_none() {
309                let last = i == n - 1;
310                if last
311                    && tk.text.len() == 1
312                    && NON_QUOTE_PUNCTS.contains(tk.text.chars().next().unwrap_or(' '))
313                {
314                    tk.phonemes = Some(tk.text.clone());
315                    tk.underscore.rating = Some(3);
316                } else if tk.text.chars().all(|c| SUBTOKEN_JUNKS.contains(c)) {
317                    tk.phonemes = Some(String::new());
318                    tk.underscore.rating = Some(3);
319                }
320            } else if i > 0 {
321                tk.underscore.prespace = prespace;
322            }
323        }
324
325        if prespace {
326            return;
327        }
328
329        let indices: Vec<(bool, usize, usize)> = tokens
330            .iter()
331            .enumerate()
332            .filter_map(|(i, tk)| {
333                tk.phonemes.as_ref().filter(|p| !p.is_empty()).map(|p| {
334                    let has_primary = p.contains(PRIMARY_STRESS);
335                    let weight = token::stress_weight(Some(p));
336                    (has_primary, weight, i)
337                })
338            })
339            .collect();
340
341        if indices.len() == 2 && tokens[indices[0].2].text.len() == 1 {
342            let i = indices[1].2;
343            if let Some(ref ps) = tokens[i].phonemes {
344                tokens[i].phonemes = Some(apply_stress(ps, Some(-0.5)));
345            }
346            return;
347        }
348
349        if indices.len() < 2 {
350            return;
351        }
352        let primary_count: usize = indices.iter().filter(|(b, _, _)| *b).count();
353        if primary_count <= indices.len().div_ceil(2) {
354            return;
355        }
356
357        let mut sorted = indices.clone();
358        sorted.sort();
359        let half = sorted.len() / 2;
360        for &(_, _, i) in &sorted[..half] {
361            if let Some(ref ps) = tokens[i].phonemes {
362                tokens[i].phonemes = Some(apply_stress(ps, Some(-0.5)));
363            }
364        }
365    }
366}
367
368impl Default for G2P {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374// ---------------------------------------------------------------------------
375// Public API (backward-compatible)
376// ---------------------------------------------------------------------------
377
378/// Convert English text to a Kokoro-compatible phoneme string.
379///
380/// Uses misaki-style dictionary lookup with espeak-ng fallback for unknown words.
381pub fn english_to_phonemes(text: &str) -> Result<String, G2pError> {
382    global_g2p().convert(text)
383}
384
385/// Convert English text to phonemes with custom word overrides.
386///
387/// Overrides map lowercase words to phoneme strings, checked before
388/// the lexicon and espeak fallback.
389pub fn english_to_phonemes_with_overrides(
390    text: &str,
391    overrides: &HashMap<String, String>,
392) -> Result<String, G2pError> {
393    let g2p = G2P::new().with_overrides(overrides.clone());
394    g2p.convert(text)
395}
396
/// Post-process espeak-ng IPA output into Kokoro phoneme format.
///
/// Kept for backward compatibility. New code should use `english_to_phonemes()`.
///
/// The rewrites run in table order: affricates become ligatures, the NURSE
/// vowel is normalised to `ɜɹ`, diphthongs collapse to single capital
/// letters, remaining length marks are dropped, and the flap `ɾ` becomes `T`.
pub fn espeak_ipa_to_kokoro(ipa: &str) -> String {
    // Order matters: `ɜːɹ` must be handled before `ɜː`, and both before the
    // bare length mark `ː` is stripped.
    const REWRITES: [(&str, &str); 11] = [
        ("dʒ", "ʤ"),
        ("tʃ", "ʧ"),
        ("ɜːɹ", "ɜɹ"),
        ("ɜː", "ɜɹ"),
        ("aɪ", "I"),
        ("aʊ", "W"),
        ("eɪ", "A"),
        ("oʊ", "O"),
        ("ɔɪ", "Y"),
        ("ː", ""),
        ("ɾ", "T"),
    ];

    REWRITES
        .iter()
        .fold(ipa.to_string(), |acc, &(from, to)| acc.replace(from, to))
}
417
418/// Split text into chunks whose phoneme representations fit within the model's
419/// 510-character context limit.
420pub fn text_to_phoneme_chunks(text: &str) -> Result<Vec<String>, G2pError> {
421    const MAX_PHONEME_LEN: usize = 500;
422
423    let mut chunks = Vec::new();
424
425    for paragraph in text.split('\n') {
426        let paragraph = paragraph.trim();
427        if paragraph.is_empty() {
428            continue;
429        }
430
431        let phonemes = english_to_phonemes(paragraph)?;
432        if phonemes.len() <= MAX_PHONEME_LEN {
433            chunks.push(phonemes);
434            continue;
435        }
436
437        let sentences = split_sentences(paragraph);
438        let mut current_phonemes = String::new();
439
440        for sentence in &sentences {
441            let sentence = sentence.trim();
442            if sentence.is_empty() {
443                continue;
444            }
445            let sent_phonemes = english_to_phonemes(sentence)?;
446
447            if current_phonemes.is_empty() {
448                current_phonemes = sent_phonemes;
449            } else if current_phonemes.len() + 1 + sent_phonemes.len() <= MAX_PHONEME_LEN {
450                current_phonemes.push(' ');
451                current_phonemes.push_str(&sent_phonemes);
452            } else {
453                chunks.push(current_phonemes);
454                current_phonemes = sent_phonemes;
455            }
456        }
457
458        if !current_phonemes.is_empty() {
459            chunks.push(current_phonemes);
460        }
461    }
462
463    if chunks.is_empty() {
464        chunks.push(String::new());
465    }
466
467    Ok(chunks)
468}
469
470/// Split text into chunks whose phoneme representations fit within the model's
471/// 510-character context limit, with custom word-to-phoneme overrides.
472///
473/// Overrides map lowercase words to phoneme strings, checked before
474/// the lexicon and espeak fallback.
475pub fn text_to_phoneme_chunks_with_overrides(
476    text: &str,
477    overrides: &HashMap<String, String>,
478) -> Result<Vec<String>, G2pError> {
479    const MAX_PHONEME_LEN: usize = 500;
480
481    let mut chunks = Vec::new();
482
483    for paragraph in text.split('\n') {
484        let paragraph = paragraph.trim();
485        if paragraph.is_empty() {
486            continue;
487        }
488
489        let phonemes = english_to_phonemes_with_overrides(paragraph, overrides)?;
490        if phonemes.len() <= MAX_PHONEME_LEN {
491            chunks.push(phonemes);
492            continue;
493        }
494
495        let sentences = split_sentences(paragraph);
496        let mut current_phonemes = String::new();
497
498        for sentence in &sentences {
499            let sentence = sentence.trim();
500            if sentence.is_empty() {
501                continue;
502            }
503            let sent_phonemes = english_to_phonemes_with_overrides(sentence, overrides)?;
504
505            if current_phonemes.is_empty() {
506                current_phonemes = sent_phonemes;
507            } else if current_phonemes.len() + 1 + sent_phonemes.len() <= MAX_PHONEME_LEN {
508                current_phonemes.push(' ');
509                current_phonemes.push_str(&sent_phonemes);
510            } else {
511                chunks.push(current_phonemes);
512                current_phonemes = sent_phonemes;
513            }
514        }
515
516        if !current_phonemes.is_empty() {
517            chunks.push(current_phonemes);
518        }
519    }
520
521    if chunks.is_empty() {
522        chunks.push(String::new());
523    }
524
525    Ok(chunks)
526}
527
/// Split `text` into sentences, keeping terminators and leading whitespace.
///
/// A sentence ends after a run of `.`, `!` or `?` (optionally followed by
/// closing quotes or brackets) when the next character is whitespace or the
/// text ends. Requiring that boundary keeps decimals ("3.14"), dotted
/// abbreviations ("e.g.") and ellipses ("...") inside one sentence instead of
/// fragmenting them at every dot. Trailing text without a terminator is
/// returned as a final sentence unless it is only whitespace. Concatenating
/// the returned pieces reproduces `text` up to that dropped whitespace.
fn split_sentences(text: &str) -> Vec<String> {
    let mut sentences = Vec::new();
    let mut current = String::new();
    // True while the characters just pushed form "terminator(s) + closers".
    let mut after_terminator = false;
    let mut chars = text.chars().peekable();

    while let Some(ch) = chars.next() {
        current.push(ch);
        after_terminator = match ch {
            '.' | '!' | '?' => true,
            // Closing quotes/brackets stay attached to the sentence they end.
            '"' | '\'' | '”' | '’' | ')' | ']' => after_terminator,
            _ => false,
        };

        let at_boundary = chars.peek().map_or(true, |next| next.is_whitespace());
        if after_terminator && at_boundary {
            sentences.push(std::mem::take(&mut current));
            after_terminator = false;
        }
    }

    if !current.trim().is_empty() {
        sentences.push(current);
    }

    sentences
}
546
// Unit tests. The `espeak_ipa_to_kokoro` and `split_sentences` tests are pure
// string checks; the `G2P` / `english_to_phonemes` tests run the full pipeline
// and therefore depend on the crate's lexicon data (and possibly espeak-ng).
#[cfg(test)]
mod tests {
    use super::*;

    // -- espeak_ipa_to_kokoro -------------------------------------------------

    #[test]
    fn test_affricate_conversion() {
        assert_eq!(espeak_ipa_to_kokoro("dʒʌmp"), "ʤʌmp");
        assert_eq!(espeak_ipa_to_kokoro("tʃɪp"), "ʧɪp");
    }

    #[test]
    fn test_diphthong_collapse() {
        assert_eq!(espeak_ipa_to_kokoro("haɪ"), "hI");
        assert_eq!(espeak_ipa_to_kokoro("naʊ"), "nW");
        assert_eq!(espeak_ipa_to_kokoro("deɪ"), "dA");
        assert_eq!(espeak_ipa_to_kokoro("goʊ"), "gO");
        assert_eq!(espeak_ipa_to_kokoro("bɔɪ"), "bY");
    }

    #[test]
    fn test_nurse_vowel() {
        assert_eq!(espeak_ipa_to_kokoro("wɜːɹld"), "wɜɹld");
        assert_eq!(espeak_ipa_to_kokoro("bɜːd"), "bɜɹd");
    }

    #[test]
    fn test_length_mark_removal() {
        assert_eq!(espeak_ipa_to_kokoro("siː"), "si");
        assert_eq!(espeak_ipa_to_kokoro("fuːd"), "fud");
    }

    #[test]
    fn test_flap_to_t() {
        assert_eq!(espeak_ipa_to_kokoro("wɑɾɚ"), "wɑTɚ");
    }

    #[test]
    fn test_full_espeak_output() {
        let input = "həlˈoʊ wˈɜːld";
        let expected = "həlˈO wˈɜɹld";
        assert_eq!(espeak_ipa_to_kokoro(input), expected);
    }

    // -- split_sentences ------------------------------------------------------

    // Each piece keeps its terminator and any leading whitespace.
    #[test]
    fn test_split_sentences() {
        let sentences = split_sentences("Hello world. How are you? I'm fine!");
        assert_eq!(
            sentences,
            vec!["Hello world.", " How are you?", " I'm fine!"]
        );
    }

    // -- G2P pipeline smoke tests ---------------------------------------------

    #[test]
    fn test_g2p_convert_hello() {
        let g2p = G2P::new();
        let result = g2p.convert("hello").unwrap();
        assert!(!result.is_empty());
        assert!(
            result.contains('O') || result.contains('o'),
            "Expected phonemes for 'hello', got: {}",
            result
        );
    }

    #[test]
    fn test_g2p_convert_sentence() {
        let g2p = G2P::new();
        let result = g2p.convert("Hello world").unwrap();
        assert!(!result.is_empty());
        assert!(
            result.contains(' '),
            "Expected space between words in: {}",
            result
        );
    }

    // Exercises the right-to-left context: "the" before a vowel-initial word.
    #[test]
    fn test_g2p_convert_the_context() {
        let g2p = G2P::new();
        let result = g2p.convert("the apple").unwrap();
        assert!(
            result.contains("ði"),
            "Expected 'ði' (the before vowel) in: {}",
            result
        );
    }

    #[test]
    fn test_g2p_convert_number() {
        let g2p = G2P::new();
        let result = g2p.convert("42").unwrap();
        assert!(!result.is_empty(), "Should produce phonemes for numbers");
    }

    // Goes through the shared global pipeline rather than a local `G2P`.
    #[test]
    fn test_english_to_phonemes_api() {
        let result = english_to_phonemes("hello world");
        assert!(result.is_ok());
        let phonemes = result.unwrap();
        assert!(!phonemes.is_empty());
    }

    // -- Punctuation preservation tests --------------------------------------

    #[test]
    fn test_period_preserved() {
        let result = english_to_phonemes("Hello.").unwrap();
        assert!(
            result.contains('.'),
            "Period should appear in phonemes: {result}"
        );
    }

    #[test]
    fn test_comma_preserved() {
        let result = english_to_phonemes("Hello, world.").unwrap();
        assert!(
            result.contains(','),
            "Comma should appear in phonemes: {result}"
        );
        assert!(
            result.contains('.'),
            "Period should appear in phonemes: {result}"
        );
    }

    #[test]
    fn test_question_mark_preserved() {
        let result = english_to_phonemes("Hello?").unwrap();
        assert!(
            result.contains('?'),
            "Question mark should appear in phonemes: {result}"
        );
    }

    #[test]
    fn test_exclamation_preserved() {
        let result = english_to_phonemes("Hello!").unwrap();
        assert!(
            result.contains('!'),
            "Exclamation mark should appear in phonemes: {result}"
        );
    }

    #[test]
    fn test_two_sentences_have_period_between() {
        let result = english_to_phonemes("Hello. World.").unwrap();
        // Should have at least one period (ideally two) in the phoneme output
        let period_count = result.chars().filter(|c| *c == '.').count();
        assert!(
            period_count >= 1,
            "Expected period(s) between sentences, got: {result}"
        );
    }

    #[test]
    fn test_mixed_punctuation() {
        let result = english_to_phonemes("Wait! What? Really.").unwrap();
        assert!(
            result.contains('!'),
            "Exclamation should appear in phonemes: {result}"
        );
        assert!(
            result.contains('?'),
            "Question mark should appear in phonemes: {result}"
        );
        assert!(
            result.contains('.'),
            "Period should appear in phonemes: {result}"
        );
    }

    #[test]
    fn test_semicolon_preserved() {
        let result = english_to_phonemes("Hello; world.").unwrap();
        assert!(
            result.contains(';'),
            "Semicolon should appear in phonemes: {result}"
        );
    }
}